HGEMM

⚡️Write HGEMM from scratch using Tensor Cores with WMMA, MMA and CuTe API, Achieve Peak⚡️ Performance.

⚡️⚡️Toy-HGEMM: Achieve the 98%~100% TFLOPS of cuBLAS 🎉🎉

📖Toy-HGEMM Library⚡️⚡️ is a library of HGEMM kernels written from scratch using Tensor Cores with the WMMA, MMA PTX and CuTe APIs; the best kernels achieve 98%~100% of cuBLAS performance. The code here is sourced from 📖LeetCUDA and exported as a standalone library; please check out LeetCUDA for the latest updates. You are welcome to 🌟👆🏻star this repo to support me, many thanks ~ 🎉🎉

<div id="hgemm-sgemm"></div> <div align='center'> <img src='https://github.com/user-attachments/assets/71927ac9-72b3-4ce9-b0e2-788b5885bc99' height="170px" width="270px"> <img src='https://github.com/user-attachments/assets/05ef4f5e-d999-48ea-b58e-782cffb24e85' height="170px" width="270px"> <img src='https://github.com/user-attachments/assets/9472e970-c083-4b31-9252-3eeecc761078' height="170px" width="270px"> </div>

Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's default Tensor Cores math algorithm CUBLAS_GEMM_DEFAULT_TENSOR_OP, the HGEMM (WMMA/MMA/CuTe) implemented in this repo (blue🔵) can achieve 98%~100% of its (orange🟠) performance. Please check toy-hgemm library⚡️⚡️ for more details.

|📚Feature |📚Feature |📚Feature |📚Feature|
|:---:|:---:|:---:|:---:|
|✔️CUDA/Tensor Cores|✔️Loop over K|✔️Tile Block(BMxBK)|✔️Tile Threads(T 8x8)|
|✔️WMMA(m16n16k16)|✔️MMA(m16n8k16)|✔️Pack LDST(128 bits)|✔️SMEM Padding|
|✔️Copy Async|✔️Tile MMAs|✔️Tile Warps|✔️Multi Stages(2~4)|
|✔️Register Double Buffers|✔️Block Swizzle|✔️Warp Swizzle|✔️SMEM Swizzle(CuTe/MMA)|
|✔️Collective Store(Shfl)|✔️Layout NN|✔️Layout TN|✔️SGEMM FP32/TF32|
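
To make two of the features above concrete ("Tile Block" and "Loop over K"), here is a minimal plain-Python sketch of a tiled matrix multiply. It is only an illustration of the loop structure, not the repo's CUDA code: the function name `tiled_matmul` and the tile sizes `bm`, `bn`, `bk` are assumptions chosen for readability, and a real kernel would map the two outer loops to thread blocks and stage each K slice through shared memory.

```python
# Sketch only: shows block tiling plus a loop over K in plain Python.
# `tiled_matmul`, `bm`, `bn` and `bk` are illustrative names, not identifiers from the repo.

def tiled_matmul(a, b, bm=2, bn=2, bk=2):
    """Compute C = A @ B one (bm x bn) output tile at a time."""
    m, k = len(a), len(a[0])
    n = len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for m0 in range(0, m, bm):            # tile rows of C (one "block" in M)
        for n0 in range(0, n, bn):        # tile columns of C (one "block" in N)
            for k0 in range(0, k, bk):    # loop over K in slices of width bk
                for i in range(m0, min(m0 + bm, m)):
                    for j in range(n0, min(n0 + bn, n)):
                        acc = c[i][j]     # partial sum carried across K slices
                        for p in range(k0, min(k0 + bk, k)):
                            acc += a[i][p] * b[p][j]
                        c[i][j] = acc
    return c


a = [[1, 2, 3], [4, 5, 6]]
b = [[7, 8], [9, 10], [11, 12]]
print(tiled_matmul(a, b))
```

Each output tile accumulates partial sums over successive K slices, which is the same accumulation pattern the tiled kernels rely on, regardless of how the slices are loaded.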

©️Citations🎉🎉

@misc{HGEMM@2024,
  title={HGEMM: Write HGEMM from scratch using Tensor Cores with WMMA, MMA PTX and CuTe API.},
  url={https://github.com/xlite-dev/HGEMM},
  note={Open-source software available at https://github.com/xlite-dev/HGEMM},
  author={xlite-dev etc},
  year={2024}
}

📖 HGEMM CUDA Kernels in Toy-HGEMM Library 🎉🎉

<div id="kernels"></div>
void hgemm_naive_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_sliced_k_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4_pack(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x8_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_cublas_tensor_op_nn(torch::Tensor a, torch::Tensor b, torch::Tensor c); 
void hgemm_cublas_tensor_op_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_mma_m16n8k16_mma2x4_warp4x4(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_stages_block_swizzle_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_tn_swizzle_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);

📖 Prerequisites

<div id="prerequisites"></div>
  • PyTorch >= 2.0, CUDA >= 12.0
  • Recommended: PyTorch 2.5.1, CUDA 12.5
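
A quick way to check these minimums before building is sketched below. The helper names (`version_tuple`, `meets_minimum`, `check_prerequisites`) are hypothetical and not part of this repo; the sketch only relies on `torch.__version__` and `torch.version.cuda`, and treats a missing PyTorch install as a reported problem rather than an error.

```python
import re

# Sketch only: helper names are hypothetical, not part of the toy-hgemm library.

def version_tuple(version):
    """Extract the leading numeric components, e.g. '2.5.1+cu124' -> (2, 5, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    return tuple(int(part) for part in match.group(0).split(".")) if match else ()


def meets_minimum(version, minimum):
    """Return True if `version` (a string) is at least `minimum` (a tuple of ints)."""
    parsed = version_tuple(version)
    # Pad short versions such as '2' so that (2,) compares like (2, 0).
    parsed += (0,) * (len(minimum) - len(parsed))
    return parsed >= minimum


def check_prerequisites():
    """Return a list of human-readable problems; an empty list means all checks passed."""
    try:
        import torch
    except ImportError:
        return ["PyTorch is not installed"]
    problems = []
    if not meets_minimum(torch.__version__, (2, 0)):
        problems.append(f"PyTorch {torch.__version__} is older than 2.0")
    cuda = torch.version.cuda  # None for CPU-only PyTorch builds
    if cuda is None or not meets_minimum(cuda, (12, 0)):
        problems.append(f"CUDA version reported by PyTorch is {cuda}; 12.0 or newer is needed")
    return problems
```

Calling `check_prerequisites()` in a Python session returns an empty list when both minimums are met.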

📖 Installation

<div id="install"></div>

The HGEMM kernels implemented in this repo can optionally be installed as a Python library, namely the toy-hgemm library.

cd kernels/hgemm
git submodule update --init --recursive --force # fetch the CUTLASS submodule (required)
python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # to uninstall: pip uninstall toy-hgemm -y

📖 Python Testing

<div id="test"></div>

CUTLASS: fetch the CUTLASS submodule first. Currently, I use v3.5.1 for the HGEMM CuTe kernel.

git submodule update --init --recursive --force

You can test the custom HGEMM kernels via the Python script and compare their performance.

# Restrict the build to one GPU architecture (e.g. Volta, Ampere, Ada, Hopper, ...); set only one of these.
export TORCH_CUDA_ARCH_LIST=Ada # for Ada only
export TORCH_CUDA_ARCH_LIST=Ampere # for Ampere only
python3 hgemm.py --wmma # test default wmma kernels for all MNK
python3 hgemm.py --mma  # test default mma kernels for all MNK
python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --wmma # test default wmma kernels for specific MNK
python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --mma # test default mma kernels for specific MNK
python3 hgemm.py --wmma-all # test all wmma kernels for all MNK
python3 hgemm.py --mma-all # test all mma kernels for all MNK
python3 hgemm.py --cuda-all --wmma-all --mma-all # test all kernels for all MNK
python3 hgemm.py --cute-tn --no-default # test cute hgemm kernels with smem swizzle for all MNK
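
To sweep a few specific problem sizes, the commands above can be generated from a list. The sketch below only builds and prints the command lines, using the `--M`, `--N`, `--K`, `--i` and `--mma` options shown above. The helper `build_commands` is hypothetical, and the 4096 problem size is an arbitrary example, not a size taken from this README.

```python
import shlex

# Sketch only: `build_commands` is a hypothetical helper, not part of hgemm.py.

def build_commands(sizes, flag="--mma", iters=10):
    """Build one `python3 hgemm.py ...` argument list per (M, N, K) triple."""
    commands = []
    for m, n, k in sizes:
        commands.append([
            "python3", "hgemm.py",
            "--M", str(m), "--N", str(n), "--K", str(k),
            "--i", str(iters), flag,
        ])
    return commands


# (16384, 16384, 8192) appears in the commands above; (4096, 4096, 4096) is an assumed example.
sizes = [(4096, 4096, 4096), (16384, 16384, 8192)]
for cmd in build_commands(sizes):
    print(shlex.join(cmd))
```

To actually run the sweep, each argument list could be passed to `subprocess.run(cmd, check=True)` from the directory that contains `hgemm.py`.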

If you want to draw a TFLOPS curve, you need to install matplotlib first and set the --plot-flops (or --plot) option.

python3 -m pip install matplotlib
# Specify topk to plot only the top k kernels with the best performance.
python3 hgemm.py --mma-all --plot --topk 8
# test default mma kernels & cute hgemm kernels with smem swizzle for all MNK
python3 hgemm.py --cute-tn --mma --plot 

📖 C++ Testing

<div id="test-cpp"></div>

The HGEMM benchmark also supports C++ testing. Currently, it supports comparisons between the following implementations:

  • MMA HGEMM NN implemented in this repository
  • CuTe HGEMM TN implemented in this repository
  • cuBLAS HGEMM TN using the default Tensor Cores math algorithm

Performance data obtained from C++ binary tests tend to be slightly better than those from Python tests. This difference may be attributed to additional overhead introduced by the PyTorch Python bindings.

make
./hgemm_mma_stage.bin
# NVIDIA L20
ALGO = MMA16816 HGEMM NN MMA=2x4 WARP=4x4x2 STAGES=2 BLOCK SWIZZLE=2048
M N K =  12544  12544  12544, Time =   0.03445555   0.03446098   0.03447399 s, AVG Performance =   114.5541 Tflops
M N K =  15360  15360  15360, Time =   0.06307226   0.06307789   0.06308864 s, AVG Performance =   114.9017 Tflops
M N K =  15616  15616  15616, Time =   0.06612480   0.06612798   0.06613094 s, AVG Performance =   115.1739 Tflops
M N K =  15872  15872  15872, Time =   0.06969549   0.06970215   0.06971290 s, AVG Performance =   114.7305 Tflops
M N K =  16128  16128  16128, Time =   0.072950
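
The reported TFLOPS can be reproduced from the timings, assuming the usual GEMM cost of 2·M·N·K floating-point operations and assuming the middle of the three `Time` values is the average. The function name `hgemm_tflops` is hypothetical; the inputs are copied from the first two complete rows of the L20 output above.

```python
# Sketch only: `hgemm_tflops` is a hypothetical helper, not part of the benchmark code.

def hgemm_tflops(m, n, k, seconds):
    """TFLOPS for one M x N x K GEMM, counting one multiply and one add per MAC."""
    flops = 2.0 * m * n * k
    return flops / seconds / 1e12


# (M = N = K, assumed-average time in seconds) taken from the output above.
rows = [
    (12544, 0.03446098),
    (15360, 0.06307789),
]
for size, seconds in rows:
    print(f"M=N=K={size}: {hgemm_tflops(size, size, size, seconds):.4f} TFLOPS")
```

Under these assumptions the computed values agree with the `AVG Performance` column for those rows to within rounding.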