ThunderKittens
Tile primitives for speedy kernels
ThunderKittens is a framework to make it easy to write fast deep learning kernels in CUDA. It is built around three key principles:
- Simplicity. ThunderKittens is stupidly simple to write.
- Extensibility. ThunderKittens is natively embedded into CUDA, so if you need more than ThunderKittens can offer, it won’t get in the way of you building it yourself.
- Speed. Kernels written in ThunderKittens should be at least as fast as those written from scratch -- especially because ThunderKittens can do things the “right” way under the hood. We think our FlashAttention-3 implementation speaks for this point.
ThunderKittens began as an internal art project and is maintained by graduate students at the Hazy Research Lab. Nonetheless, many AI companies use it for production-scale training and inference (e.g., Together AI, Jump Trading, and Cursor).
ThunderKittens is built for NVIDIA GPUs. For AMD GPUs, check out HipKittens.
Recent Updates
Jan 11, 2026: ThunderKittens 2.0 is out!
- This release brings full support for Blackwell GPUs along with MXFP8 and NVFP4 precision, and merges major contributions from across the industry.
- The repository structure has changed. We no longer support the repo as a Python package (i.e., a top-level setup.py). Kernels under the /kernels directory must now be compiled individually. Makefiles, tests, and benchmarks reside alongside their corresponding kernel source files.
- We no longer actively support Ampere GPUs. While ThunderKittens should still work on Ampere, we do not plan to bring further support to it.
Overview
ThunderKittens is built from the hardware up; we do what the silicon tells us. And modern GPUs tell us that they want to work with fairly small tiles of data. A GPU is not really a 1000x1000 matrix multiply machine (even if it is often used as such); it’s a manycore processor where each core can efficiently run ~16x16 matrix multiplies. Consequently, ThunderKittens is built around manipulating tiles of data no smaller than 16x16 values.
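To make the tile-centric view concrete, here is a plain C++ sketch (CPU only, not ThunderKittens code) that computes C = A·B by looping over 16x16 tiles and doing one small 16x16x16 multiply-accumulate per step, roughly the granularity at which a tensor core works. The names TILE and tiled_matmul are hypothetical, and the sketch assumes every matrix dimension is a multiple of 16.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Tile edge length: the ~16x16 granularity discussed above (illustrative only).
constexpr std::size_t TILE = 16;

// Computes C (M x N) = A (M x K) * B (K x N); all matrices are row-major.
std::vector<float> tiled_matmul(const std::vector<float>& A,
                                const std::vector<float>& B,
                                std::size_t M, std::size_t N, std::size_t K) {
    // This sketch only handles dimensions that are whole numbers of tiles.
    assert(M % TILE == 0 && N % TILE == 0 && K % TILE == 0);
    std::vector<float> C(M * N, 0.0f);
    for (std::size_t tm = 0; tm < M; tm += TILE)            // output tile row
        for (std::size_t tn = 0; tn < N; tn += TILE)        // output tile column
            for (std::size_t tk = 0; tk < K; tk += TILE) {  // reduction tile
                // One 16x16x16 tile multiply-accumulate: C_tile += A_tile * B_tile.
                for (std::size_t i = 0; i < TILE; ++i)
                    for (std::size_t k = 0; k < TILE; ++k) {
                        float a = A[(tm + i) * K + (tk + k)];
                        for (std::size_t j = 0; j < TILE; ++j)
                            C[(tm + i) * N + (tn + j)] += a * B[(tk + k) * N + (tn + j)];
                    }
            }
    return C;
}
```

For a 1024x1024x1024 product this loop performs (1024/16)^3 = 262,144 tile multiplies; on a GPU, many of those tile operations run concurrently across the SMs.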
ThunderKittens makes it easy to do a few tricky things that are needed for high utilization on modern hardware:
- Tensor cores. ThunderKittens can call fast tensor core functions, including asynchronous WGMMA calls on H100 GPUs and TCGEN05 calls on B200 GPUs.
- Shared Memory. I got ninety-nine problems but a bank conflict ain’t one.
- Loads and stores. Hide latencies with asynchronous copies, and offload address generation to TMA.
- Distributed Shared Memory. L2 is so last year.
- Worker overlapping. Use our Load-Store-Compute-Finish template to overlap work and I/O.
- GPU networking. ThunderKittens lets you transfer data over NVLink and utilize NVSwitch acceleration for fast multi-GPU operations.
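As a rough illustration of why shared-memory layout matters (this is not ThunderKittens' actual layout code), the sketch below models a bf16 tile with 64 columns per row, so each row spans 128 bytes, exactly one pass over 32 four-byte banks. Reading the same column from 8 consecutive rows then hits a single bank eight times. XOR-ing each 16-byte chunk index with (row % 8), a scheme similar in spirit to the 128-byte swizzle mode of Hopper's TMA, spreads those reads across 8 distinct banks. The bank model and all function names here are assumptions of the sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <set>

constexpr std::size_t NUM_BANKS   = 32;   // assumed shared-memory bank count
constexpr std::size_t BANK_BYTES  = 4;    // assumed bytes per bank
constexpr std::size_t ELEM_BYTES  = 2;    // bf16
constexpr std::size_t ROW_BYTES   = 128;  // 64 bf16 columns per row
constexpr std::size_t CHUNK_BYTES = 16;   // swizzle granularity

// Byte offset of (row, col) in a plain row-major layout.
std::size_t linear_offset(std::size_t row, std::size_t col) {
    return row * ROW_BYTES + col * ELEM_BYTES;
}

// Byte offset after XOR-ing the 16-byte chunk index with (row % 8).
// Chunks are only permuted within their own row.
std::size_t swizzled_offset(std::size_t row, std::size_t col) {
    assert(col < ROW_BYTES / ELEM_BYTES);
    std::size_t byte_in_row = col * ELEM_BYTES;
    std::size_t chunk  = byte_in_row / CHUNK_BYTES;  // 0..7
    std::size_t within = byte_in_row % CHUNK_BYTES;
    return row * ROW_BYTES + (chunk ^ (row % 8)) * CHUNK_BYTES + within;
}

std::size_t bank_of(std::size_t byte_offset) {
    return (byte_offset / BANK_BYTES) % NUM_BANKS;
}

// Number of distinct banks touched when rows 0..7 all read column `col`.
template <typename OffsetFn>
std::size_t distinct_banks(OffsetFn offset, std::size_t col) {
    std::set<std::size_t> banks;
    for (std::size_t row = 0; row < 8; ++row)
        banks.insert(bank_of(offset(row, col)));
    return banks.size();
}
```

With the linear layout, distinct_banks(linear_offset, 0) is 1 (an 8-way conflict); with the swizzled layout every column reaches 8 distinct banks, while each row still occupies the same 128 bytes.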
Example: A Simple Matrix Multiplication Kernel
Here’s what a simple matrix multiplication kernel for an H100 looks like when written in ThunderKittens.
#include "kittens.cuh"
#include "prototype.cuh"
using namespace kittens;
using namespace kittens::prototype;
using namespace kittens::prototype::lcf;
template<int M_BLOCK, int N_BLOCK>
struct matmul_layout {
    using base_tile = st_bf<64, 64>;
    using global_layout = gl<bf16, 1, 1, -1, -1, base_tile>;
    struct globals { global_layout A, B, C; };
    struct input_block { base_tile a[M_BLOCK], b[N_BLOCK]; };
    struct finish_block { base_tile c[M_BLOCK][N_BLOCK]; };
    struct common_state { int2 coord; };
    struct consumer_state { rt_fl<16, N_BLOCK*base_tile::cols> accum; };
};
template<int _M_BLOCK=2, int _N_BLOCK=4, int _SUPER_M=12>
struct matmul_template {
    static constexpr int M_BLOCK = _M_BLOCK, N_BLOCK = _N_BLOCK, SUPER_M = _SUPER_M;
    using layout = matmul_layout<M_BLOCK, N_BLOCK>;
    using wide_tile = st_bf<64, 64*N_BLOCK>;
    static constexpr int NUM_CONSUMER_WARPS=M_BLOCK*4, INPUT_PIPE_STAGES=4, PRODUCER_BARRIER_ARRIVALS=1;
    // Helper functions
    // Persistent grid: one block per SM (132 on an H100 SXM).
    // Otherwise: one block per (M_BLOCK*64) x (N_BLOCK*64) output block.
    template<bool PERSISTENT_GRID=true> __host__ static inline dim3 grid(int M, int N, int K) {
        return dim3(PERSISTENT_GRID ? 132 : M*N/(M_BLOCK*N_BLOCK*layout::base_tile::num_elements));
    }
    // ThunderKittens template functions
    __device__ static inline void common_setup(common_setup_args<layout> args) {
        // Map this block's task_id to an output block, walking SUPER_M block-rows at a time.
        int Rblocks = args.globals.C.rows() / (M_BLOCK*64), Cblocks = args.globals.C.cols() / (N_BLOCK*64);
        int super_rows = (Rblocks/SUPER_M)*SUPER_M,
            final_rows = Rblocks - super_rows,
            super_repeat = SUPER_M*Cblocks;
        int task_id = args.task_iter*gridDim.x + blockIdx.x;
        if (task_id < super_rows * Cblocks)
            args.common.coord = { SUPER_M*(task_id/super_repeat) + task_id%SUPER_M,
                                  (task_id%super_repeat)/SUPER_M };
        else if (task_id < Rblocks*Cblocks) {
            int remainder_id = task_id - super_rows*Cblocks;
            args.common.coord = { super_rows + (remainder_id%final_rows), remainder_id/final_rows };
        }
        else { // Id is too high, no more work to do
            args.num_iters = -1;
            return;
        }
        args.num_iters = args.globals.A.cols()/64;
        int id = warpgroup::groupid() == NUM_CONSUMER_WARPS/4 ? 0 : warpgroup::groupid(); // producer sets as 0
        args.common.coord = { args.common.coord.x*M_BLOCK + id, args.common.coord.y*N_BLOCK };
    }
    struct producer {
        __device__ static void setup(producer_setup_args<layout> args) {
            warpgroup::decrease_registers<40>(); // decrease registers for producers
        }
        __device__ static void load(producer_load_args<layout> args) {
            if (warpgroup::laneid() == 0) {
                tma::expect(args.inputs_arrived, args.input);
                for(int i = 0; i < M_BLOCK; i++)
                    tma::load_async(args.input.a[i], args.globals.A,
                                    {args.common.coord.x+i, args.iter}, args.inputs_arrived);
                for(int i = 0; i < N_BLOCK; i++)
                    tma::load_async(args.input.b[i], args.globals.B,
                                    {args.iter, args.common.coord.y+i}, args.inputs_arrived);
            }
        }
    };
    struct consumer {
        __device__ static void setup(consumer_setup_args<layout> args) {
            warpgroup::increase_registers<232>(); // increase registers for consumers
            kittens::warp::zero(args.state.accum);
        }
        __device__ static void compute(consumer_compute_args<layout> args) {
            warpgroup::mma_AB(
                args.state.accum, // dest registers
                args.input.a[warpgroup::groupid()], // A matrix
                reinterpret_cast<wide_tile&>(args.input.b) // B matrix
            );
            warpgroup::mma_async_wait();
            if (warp::laneid() == 0) arrive(args.inputs_finished);
        }
        __device__ static void finish(consumer_finish_args<layout> args) {
            warpgroup::store(reinterpret_cast<wide_tile&>(args.finish.c[warpgroup::groupid()]), args.state.accum);
            warpgroup::sync(warpgroup::groupid()+4);
            if (warpgroup::laneid() == 0) for(int i = 0; i < N_BLOCK; i++) {
                tma::store_async(args.globals.C, args.finish.c[warpgroup::groupid()][i],
                                 {args.common.coord.x, args.common.coord.y+i});
                tma::store_async_read_wait(); // wait that store is finished before reusing finish memory
            }
            kittens::warp::zero(args.state.accum);
            if (warp::laneid() == 0) arrive(args.finish_finished);
        }
    };
};
Altogether, this is fewer than 100 lines of code, and it achieves about 855 TFLOPs on an H100 (86% of the theoretical max). We’ll go through some of these primitives more carefully later, in the ThunderKittens Manual section.
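The least obvious part of the kernel is how common_setup turns a block's task_id into an output block. The host-side C++ sketch below re-implements only that index arithmetic, so you can check that with a persistent grid every output block is assigned exactly once. The function names block_for_task and coverage are hypothetical, and the later scaling of the coordinate by M_BLOCK/N_BLOCK plus the per-warpgroup offset are deliberately left out.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical host-side port of the task-to-block mapping in common_setup.
// Returns {row, col} in units of output blocks, or {-1, -1} when there is no more work.
std::pair<int, int> block_for_task(int task_id, int Rblocks, int Cblocks, int SUPER_M) {
    int super_rows   = (Rblocks / SUPER_M) * SUPER_M;  // rows covered by full super-row groups
    int final_rows   = Rblocks - super_rows;           // leftover rows
    int super_repeat = SUPER_M * Cblocks;              // tasks per super-row group
    if (task_id < super_rows * Cblocks)
        return {SUPER_M * (task_id / super_repeat) + task_id % SUPER_M,
                (task_id % super_repeat) / SUPER_M};
    if (task_id < Rblocks * Cblocks) {
        int remainder_id = task_id - super_rows * Cblocks;
        return {super_rows + remainder_id % final_rows, remainder_id / final_rows};
    }
    return {-1, -1};  // the kernel sets num_iters = -1 in this case
}

// Simulates a persistent grid: each block handles task_iter * grid_size + block
// until it runs out of work. Returns how many times each output block was assigned.
std::vector<int> coverage(int Rblocks, int Cblocks, int SUPER_M, int grid_size) {
    assert(Rblocks > 0 && Cblocks > 0 && SUPER_M > 0 && grid_size > 0);
    std::vector<int> hits(Rblocks * Cblocks, 0);
    for (int block = 0; block < grid_size; ++block)
        for (int task_iter = 0;; ++task_iter) {
            auto [r, c] = block_for_task(task_iter * grid_size + block, Rblocks, Cblocks, SUPER_M);
            if (r < 0) break;
            ++hits[r * Cblocks + c];
        }
    return hits;
}
```

For a 4096x4096 output with the default M_BLOCK=2, N_BLOCK=4, and SUPER_M=12, there are 32 x 16 output blocks, and coverage(32, 16, 12, 132) reports one hit per block. Consecutive task IDs walk down 12 block-rows of one block-column before moving to the next column; we read this as a way to keep the A and B tiles of concurrently running blocks resident in L2, but that motivation is our interpretation rather than something the code states.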
Installation
ThunderKittens is a header-only library, so it does not require any installation: just clone the repo and include kittens.cuh. Easy money.
Hardware requirements
- ThunderKittens is mainly built and tested for Hopper and Blackwell GPUs.
- We no longer actively support Ampere GPUs. However, contributions are welcome!
Build requirements
ThunderKittens does use a bunch of modern stuff, so it has fairly aggressive requirements.
- CUDA 12.8+. We want our kittens to play in the nicest, most modern environment possible. Make sure you run the following to set up your CUDA environment properly:
export CUDA_HOME=/usr/local/cuda-<YOUR-CUDA-VERSION> # ex. cuda-12.8
export PATH=${CUDA_HOME}/bin:${PATH}
export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:$LD_LIBRARY_PATH
- C++20. TK runs on concepts. If you get weird compilation errors, chances are your gcc is out of date. Update your compiler with:
sudo apt update
sudo apt install gcc-11 g++-11
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 100 --slave /usr/bin/g++ g++ /usr/bin/g++-11
sudo apt update
sudo apt install clang-11
Sometimes, there's a libc10.so error, which you can fix with:
# Take the <PRINTED_PATH> from below
python -c "import os, torch; print(os.path.dirname(torch.__file__))"
# And run the command below
export LD_LIBRARY_PATH=<PRINTED_PATH>/lib:$LD_LIBRARY_PATH
ThunderKittens Manual
ThunderKittens is actually a pretty small library, in terms of what it gives you.
- Data
