# HGEMM
⚡️Write HGEMM from scratch using Tensor Cores with WMMA, MMA and CuTe API, achieving peak⚡️ performance.
## ⚡️⚡️Toy-HGEMM: Achieve 98%~100% of cuBLAS TFLOPS 🎉🎉
📖Toy-HGEMM Library⚡️⚡️ is a library of HGEMM kernels written from scratch using Tensor Cores with WMMA, MMA PTX and the CuTe API, and it can achieve 98%~100% of cuBLAS performance. The code here is sourced from 📖LeetCUDA and exported as a standalone library, so please check out LeetCUDA for the latest updates. Welcome to 🌟👆🏻star this repo to support me, many thanks ~ 🎉🎉
Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop GPUs, the HGEMM (WMMA/MMA/CuTe) kernels implemented in this repo (blue🔵) achieve 98%~100% of the performance of cuBLAS's default Tensor Cores math algorithm, CUBLAS_GEMM_DEFAULT_TENSOR_OP (orange🟠). Please check the toy-hgemm library⚡️⚡️ for more details.
|📚Feature |📚Feature |📚Feature |📚Feature|
|:---:|:---:|:---:|:---:|
|✔️CUDA/Tensor Cores|✔️Loop over K|✔️Tile Block(BMxBK)|✔️Tile Threads(T 8x8)|
|✔️WMMA(m16n16k16)|✔️MMA(m16n8k16)|✔️Pack LDST(128 bits)|✔️SMEM Padding|
|✔️Copy Async|✔️Tile MMAs|✔️Tile Warps|✔️Multi Stages(2~4)|
|✔️Register Double Buffers|✔️Block Swizzle|✔️Warp Swizzle|✔️SMEM Swizzle(CuTe/MMA)|
|✔️Collective Store(Shfl)|✔️Layout NN|✔️Layout TN|✔️SGEMM FP32/TF32|
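To make the WMMA(m16n16k16) entry in the table above concrete, here is a minimal, unoptimized sketch: one warp computes one 16x16 tile of C with an NN row-major layout, and M/N/K are assumed to be multiples of 16. This is an illustrative toy, not one of the repo's tuned kernels, which add block/warp tiling, cp.async multi-stage pipelines, swizzling and collective stores on top of this pattern.

```cuda
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// Naive WMMA m16n16k16 HGEMM sketch: one warp per 16x16 tile of C,
// looping over K in steps of 16. A is MxK, B is KxN, C is MxN (row-major).
__global__ void hgemm_wmma_naive_sketch(const half* A, const half* B, half* C,
                                        int M, int N, int K) {
  const int tile_m = blockIdx.y * 16;
  const int tile_n = blockIdx.x * 16;
  if (tile_m >= M || tile_n >= N) return;

  wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
  wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
  wmma::fragment<wmma::accumulator, 16, 16, 16, half> c_frag;
  wmma::fill_fragment(c_frag, __float2half(0.0f));

  for (int k = 0; k < K; k += 16) {
    wmma::load_matrix_sync(a_frag, A + tile_m * K + k, K);  // leading dim K
    wmma::load_matrix_sync(b_frag, B + k * N + tile_n, N);  // leading dim N
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
  }
  wmma::store_matrix_sync(C + tile_m * N + tile_n, c_frag, N, wmma::mem_row_major);
}
```

A matching launch would be something like `dim3 grid(N / 16, M / 16); hgemm_wmma_naive_sketch<<<grid, 32>>>(A, B, C, M, N, K);` with one warp (32 threads) per block.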
## ©️Citations🎉🎉
```bibtex
@misc{HGEMM@2024,
  title={HGEMM: Write HGEMM from scratch using Tensor Cores with WMMA, MMA PTX and CuTe API.},
  url={https://github.com/xlite-dev/HGEMM},
  note={Open-source software available at https://github.com/xlite-dev/HGEMM},
  author={xlite-dev etc},
  year={2024}
}
```
## 📖 HGEMM CUDA Kernels in Toy-HGEMM Library 🎉🎉
<div id="kernels"></div>void hgemm_naive_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_sliced_k_f16(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4_pack(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x4_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x8_pack_bcf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k_f16x8_pack_bcf_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k16_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_8x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_t_16x8_sliced_k32_f16x8_pack_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_cublas_tensor_op_nn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_cublas_tensor_op_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m32n8k16_mma2x4_warp2x4_dbuf_async(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x2_warp2x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x2_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_wmma_m16n16k16_mma4x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_naive(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_mma_m16n8k16_mma2x4_warp4x4(torch::Tensor a, torch::Tensor b, torch::Tensor c);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_stages_block_swizzle_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_tn_swizzle_x4(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
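These launchers are PyTorch C++ extension entry points: each takes `torch::Tensor a, b, c`, and the staged variants additionally take `stages`, `swizzle` and `swizzle_stride` tuning parameters. One plausible way they are exposed to the `hgemm.py` test script is a pybind11 module along the lines sketched below; this is an illustrative sketch (assuming the declarations above are visible), not necessarily the repo's actual binding code.

```c++
#include <torch/extension.h>

// Hypothetical binding sketch: register each kernel launcher under a
// Python-visible name so the test script can call, e.g.,
// lib.hgemm_wmma_m16n16k16_naive(a, b, c) on half-precision CUDA tensors.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("hgemm_naive_f16", &hgemm_naive_f16);
  m.def("hgemm_cublas_tensor_op_nn", &hgemm_cublas_tensor_op_nn);
  m.def("hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle",
        &hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle);
  // ... the remaining kernels follow the same pattern.
}
```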
## 📖 Contents
- 📖 Prerequisites
- 📖 Installation
- 📖 Python Testing
- 📖 C++ Testing
- 📖 NVIDIA L20 bench
- 📖 NVIDIA RTX 4090 bench
- 📖 NVIDIA RTX 3080 Laptop bench
- 📖 Docs
- 📖 References
## 📖 Prerequisites
<div id="prerequisites"></div>- PyTorch >= 2.0, CUDA >= 12.0
- Recommended: PyTorch 2.5.1, CUDA 12.5
## 📖 Installation
<div id="install"></div>The HGEMM implemented in this repo can be install as a python library, namely, toy-hgemm library (optional).
```bash
cd kernels/hgemm
git submodule update --init --recursive --force # Fetch `CUTLASS` submodule, needed
python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # pip uninstall toy-hgemm -y
```
## 📖 Python Testing
<div id="test"></div>CUTLASS: Fetch CUTLASS submodule. Currently, I use v3.5.1 for HGEMM CuTe kernel.
git submodule update --init --recursive --force
You can test the many custom HGEMM kernels via the Python script and compare their performance.
```bash
# You can test Ada or Ampere only; also Volta, Ampere, Ada, Hopper, ...
export TORCH_CUDA_ARCH_LIST=Ada # for Ada only
export TORCH_CUDA_ARCH_LIST=Ampere # for Ampere only
python3 hgemm.py --wmma # test default wmma kernels for all MNK
python3 hgemm.py --mma # test default mma kernels for all MNK
python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --wmma # test default wmma kernels for specific MNK
python3 hgemm.py --M 16384 --N 16384 --K 8192 --i 10 --mma # test default mma kernels for specific MNK
python3 hgemm.py --wmma-all # test all wmma kernels for all MNK
python3 hgemm.py --mma-all # test all mma kernels for all MNK
python3 hgemm.py --cuda-all --wmma-all --mma-all # test all kernels for all MNK
python3 hgemm.py --cute-tn --no-default # test cute hgemm kernels with smem swizzle for all MNK
```
If you want to draw a TFLOPS curve, you need to install matplotlib first and pass the --plot-flops (or --plot) option.
```bash
python3 -m pip install matplotlib
# Specify topk to plot only the top-k kernels with the best performance.
python3 hgemm.py --mma-all --plot --topk 8
# test default mma kernels & cute hgemm kernels with smem swizzle for all MNK
python3 hgemm.py --cute-tn --mma --plot
```
## 📖 C++ Testing
<div id="test-cpp"></div>The HGEMM benchmark also supports C++ testing. Currently, it supports comparisons between the following implementations:
- MMA HGEMM NN implemented in this repository
- CuTe HGEMM TN implemented in this repository
- cuBLAS HGEMM TN using the default Tensor Cores math algorithm
Performance data obtained from C++ binary tests tend to be slightly better than those from Python tests. This difference may be attributed to additional overhead introduced by the PyTorch Python bindings.
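For reference, the cuBLAS baseline mentioned above (CUBLAS_GEMM_DEFAULT_TENSOR_OP) corresponds to the kind of call sketched below. This is a generic illustration, not the repo's exact benchmark code; cuBLAS is column-major, so the transpose flags and leading dimensions have to be adapted to the repo's actual NN/TN row-major layouts.

```c++
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Hedged sketch of an FP16 GEMM through cublasGemmEx with the default
// Tensor Core algorithm. Computes column-major C(MxN) = A(MxK) * B(KxN).
void cublas_hgemm_reference(cublasHandle_t handle,
                            const half* A, const half* B, half* C,
                            int M, int N, int K) {
  const half alpha = __float2half(1.0f), beta = __float2half(0.0f);
  cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
               M, N, K,
               &alpha,
               A, CUDA_R_16F, M,   // lda = M (column-major)
               B, CUDA_R_16F, K,   // ldb = K
               &beta,
               C, CUDA_R_16F, M,   // ldc = M
               CUBLAS_COMPUTE_16F,
               CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```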
```bash
make
./hgemm_mma_stage.bin
```

```bash
# NVIDIA L20
ALGO = MMA16816 HGEMM NN MMA=2x4 WARP=4x4x2 STAGES=2 BLOCK SWIZZLE=2048
M N K = 12544 12544 12544, Time = 0.03445555 0.03446098 0.03447399 s, AVG Performance = 114.5541 Tflops
M N K = 15360 15360 15360, Time = 0.06307226 0.06307789 0.06308864 s, AVG Performance = 114.9017 Tflops
M N K = 15616 15616 15616, Time = 0.06612480 0.06612798 0.06613094 s, AVG Performance = 115.1739 Tflops
M N K = 15872 15872 15872, Time = 0.06969549 0.06970215 0.06971290 s, AVG Performance = 114.7305 Tflops
M N K = 16128 16128 16128, Time = 0.072950
```
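The "AVG Performance" column is simply the standard GEMM FLOP count divided by the measured average time; a minimal sketch of that arithmetic (assuming the benchmark counts 2*M*N*K floating-point operations per GEMM):

```c++
#include <cstdint>

// Assumed formula: an MxNxK GEMM performs about 2*M*N*K floating-point operations,
// so TFLOPS = 2*M*N*K / seconds / 1e12.
double gemm_tflops(int64_t M, int64_t N, int64_t K, double avg_seconds) {
  return 2.0 * double(M) * double(N) * double(K) / avg_seconds / 1e12;
}
// Example: M=N=K=12544 at 0.03446098 s -> ~114.55 TFLOPS, matching the first log line above.
```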
