Gemlite

Fast low-bit matmul kernels in Triton

Install / Use

/learn @dropbox/Gemlite
README

GemLite

<div align="center" style="margin-bottom: 1em;"> <h2>Triton Kernels for Efficient Low-Bit Matrix Multiplication</h2> <img src="images/gemlite%20banner.png" alt="GemLite Logo" height="150">

[![Twitter][mobius-twitter-badge]][mobius-twitter]

Made with ❤ by the team at Mobius Labs for the 'Aana' (ആന: Elephant) suite of multimodal products.

</div>

GemLite is a collection of Triton kernels designed for efficient low-bit matrix multiplication, emphasizing simplicity and reusability. It provides a practical solution for achieving significant performance gains, delivering up to 7-8x faster prefill and 3-6x faster decoding compared to default Torch AO kernels. For more detailed benchmarks, check the Performance section.

GemLite strikes the perfect balance between flexibility and performance, allowing users to easily use and modify the codebase to develop high-performance kernels optimized for their specific hardware. We have included multiple versions of the kernels to maximize performance across different matrix shapes.

The project started with CUDA kernels, but we have switched to <a href="https://github.com/triton-lang/triton/">Triton</a> for enhanced flexibility. For the old CUDA version, please refer to <a href="https://github.com/dropbox/gemlite/tree/stable_cuda_only">this branch.</a>

Result Teaser

| End-to-end Performance (Llama3 8-bit) | Matmul Performance (A16W8) |
| ------------------------------------- | -------------------------- |
| End-to-end performance plot | Matmul performance plot |

Extensive performance results across different bitwidths, batch sizes, and devices are available in the Performance section below.

Recent Highlights

  • GemLite now supports MXFP for Blackwell!
  • GemLite now supports vLLM V1 (torch.compile compatible)!
  • GemLite now supports bfloat16!
  • GemLite is now available in <a href="https://github.com/vllm-project/vllm/">vllm</a> via the <a href="https://github.com/dropbox/hqq/">hqq</a> lib!
  • GemLite is now integrated with <a href="https://github.com/pytorch/ao">TorchAO</a>/<a href="https://github.com/sgl-project/sglang">SGLang</a> for 4-bit quantization. Check-out the <a href="https://pytorch.org/blog/accelerating-llm-inference/">blogpost</a>!
  • Major performance improvement: especially on the A100 and H100.
  • Flexible bitpacking: use 8-bit packing for improved batched performance on the A100 and H100 with packed data.
  • Autotune caching: save/load the best autotune configs across all the kernels with a single line of code.
  • Helper functions: helper functions make it easier to get started, especially useful for dynamic quantization.
  • New GEMV RevSplitK algorithm: outperforms GEMM Split-K and GEMV for batch-size=1 with packed data.
  • Channel-wise scaling: Added support for channel-wise scaling for weights, activations, and both.
  • Precision support: Includes FP16 x Wn, FP8 x FP8, FP8 x Wn, INT8 x INT8, INT8 x Wn, MXFPn x MXFPn.
  • torch.compile() support.
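
The Split-K idea behind the GEMM Split-K and RevSplitK kernels mentioned above can be illustrated with a NumPy sketch (shapes are made up; this is not the Triton implementation): the K dimension is split into chunks that are reduced independently and then accumulated.

```python
import numpy as np

rng = np.random.default_rng(1)
M, K, N, SPLIT_K = 2, 64, 4, 4  # illustrative sizes only
A = rng.standard_normal((M, K)).astype(np.float32)
B = rng.standard_normal((K, N)).astype(np.float32)

# Split-K: each of the SPLIT_K workers reduces one K-chunk;
# the partial products are then summed into the final output.
chunk = K // SPLIT_K
partials = [A[:, i*chunk:(i+1)*chunk] @ B[i*chunk:(i+1)*chunk, :]
            for i in range(SPLIT_K)]
C = np.sum(partials, axis=0)

assert np.allclose(C, A @ B, atol=1e-4)
```

In the real kernels each chunk is reduced by a separate program instance, which improves occupancy for small batch sizes where a plain GEMM would leave most of the GPU idle.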

Getting Started

Installation

Latest Stable Version

pip install gemlite

Latest (Recommended)

pip install git+https://github.com/dropbox/gemlite/

Usage

import gemlite
from gemlite import DType, GemLiteLinear

#Reset the default cache to get the best perf but warm-up will be slow. 
#gemlite.reset_cache()

#Set autotune mode: fast: fast start-up (default), max: long start-up but best perf, default/False: no autotune
#gemlite.set_autotune("fast")

#Enable kernel caching: makes some kernels faster, but might break with some torch.compile settings
#gemlite.set_kernel_caching(True)

#Main constructor
gemlite_linear = GemLiteLinear(
    W_nbits, #weight quantization bitwidth. supported: [8, 4, 2, 1]
    group_size=group_size, # any group_size divisible by 32 - enable autotune for group_size < 128 (!)
    in_features=in_features, # input size
    out_features=out_features, #output size
    input_dtype=DType.FP16, #FP16, BF16, FP8, INT8
    output_dtype=DType.FP16, #FP16, BF16, FP32, FP8, INT32
    scaled_activations=False, #If the activations are scaled or not
)

#Packing: we follow the hqq format (W_q - zeros) * scales ~ W (https://github.com/dropbox/hqq/)
gemlite_linear.pack(W_q, scales, zeros, bias)

#Forward
out = gemlite_linear(x)

#Save the cache if you want to re-use the same autotune config
#gemlite.cache_config('gemlite_config.json')
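
The hqq-style relation used by `pack` above, (W_q - zeros) * scales ≈ W, can be sketched numerically (a minimal NumPy illustration with made-up shapes and a group_size of 32 chosen for the example; GemLite's actual packed layout differs):

```python
import numpy as np

# Hypothetical shapes for illustration only
out_features, in_features, group_size = 4, 64, 32
n_groups = (out_features * in_features) // group_size

rng = np.random.default_rng(0)
# 4-bit codes in [0, 15], one scale/zero-point per group of 32 values
W_q = rng.integers(0, 16, size=(out_features, in_features)).astype(np.uint8)
scales = rng.random((n_groups, 1)).astype(np.float32) * 0.01
zeros = np.full((n_groups, 1), 8.0, dtype=np.float32)

# Dequantize group-wise: (W_q - zeros) * scales ~ W
W_approx = ((W_q.reshape(-1, group_size).astype(np.float32) - zeros) * scales)
W_approx = W_approx.reshape(out_features, in_features)
print(W_approx.shape)  # (4, 64)
```

The asymmetric zero-point lets the quantizer center each group, which is why `pack` takes `zeros` and `scales` separately rather than a single scale.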

Helper Functions

Additionally, we offer helper functions that operate as follows:

from gemlite.helper import *
device, dtype = 'cuda:0', torch.float16

#AxWy: x: activation precision in bits, y: weight precision in bits.

#Weight-only
gemlite_linear = A16W8_INT8(device=device, dtype=dtype).from_linear(layer)
gemlite_linear = A16W8_FP8(device=device, dtype=dtype).from_linear(layer)
gemlite_linear = A16W8_HQQ_INT(device=device, dtype=dtype).from_hqqlinear(hqq_layer)
gemlite_linear = A16W4_HQQ_INT(device=device, dtype=dtype).from_hqqlinear(hqq_layer)
gemlite_linear = A16W2_HQQ_INT(device=device, dtype=dtype).from_hqqlinear(hqq_layer)
gemlite_linear = A16W158_INT(device=device, dtype=dtype).from_bitlinear(bitlinear_layer)

#8-bit activation dynamic quant
gemlite_linear = A8W8_INT8_dynamic(device=device, dtype=dtype).from_linear(layer)
gemlite_linear = A8W8_FP8_dynamic(device=device, dtype=dtype).from_linear(layer)
gemlite_linear = A8W4_HQQ_INT_dynamic(device=device, dtype=dtype).from_hqqlinear(hqq_layer)
gemlite_linear = A8W158_INT_dynamic(device=device, dtype=dtype).from_bitlinear(bitlinear_layer)

#MXFP weight-only
gemlite_linear = A16W8_MXFP(device=device, dtype=dtype).from_linear(layer)
gemlite_linear = A16W4_MXFP(device=device, dtype=dtype).from_linear(layer)

#MXFP/NVFP dynamic quant - if post_scale=True, uses channel-wise activation quant.
#Support depends on triton's ability to support native mxfp/nvfp mma.
gemlite_linear = A8W8_MXFP_dynamic(device=device, dtype=dtype, post_scale=False).from_linear(layer)
gemlite_linear = A8W8_MXFP_dynamic(device=device, dtype=dtype, post_scale=True).from_linear(layer)
gemlite_linear = A8W4_MXFP_dynamic(device=device, dtype=dtype, post_scale=False).from_linear(layer)
gemlite_linear = A8W4_MXFP_dynamic(device=device, dtype=dtype, post_scale=True).from_linear(layer)
gemlite_linear = A4W4_MXFP_dynamic(device=device, dtype=dtype).from_linear(layer)
gemlite_linear = A4W4_NVFP_dynamic(device=device, dtype=dtype).from_linear(layer)

You can also patch the whole model (even from CPU) as follows:

from gemlite.helper import *
patch_model(model, device=device, processor=A8W8_INT8_dynamic())

Config Caching

Triton autotuning can be time-consuming. To accelerate this process, we provide tools to automatically cache and load the optimal autotuning configurations for all kernels:

import gemlite
gemlite.reset_config() #resets cache config for all kernels
gemlite.cache_config('gemlite_config.json') #Cache
gemlite.load_config('gemlite_config.json') #Load

Ensure that you have one JSON cache file per GPU model. When the cache is loaded, the kernels will skip autotuning, leading to a faster startup time.
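
One way to keep a separate cache file per GPU model is to derive the filename from the device name (a sketch; the naming scheme below is our own convention, and in practice you would pass `torch.cuda.get_device_name()` and feed the result to `gemlite.load_config`):

```python
import re

def config_path_for(device_name: str) -> str:
    # Turn a device name like "NVIDIA A100-SXM4-80GB" into a filesystem-safe
    # slug, so each GPU model gets its own autotune cache file.
    slug = re.sub(r'[^a-z0-9]+', '_', device_name.lower()).strip('_')
    return f"gemlite_{slug}.json"

print(config_path_for("NVIDIA A100-SXM4-80GB"))  # gemlite_nvidia_a100_sxm4_80gb.json
```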

You can warm-up with specific shapes via the following helper function:

import gemlite

#Ignore pre-loaded configs - if you want to start from scratch (Optional)
#gemlite.reset_config() 

#Set autotune mode: fast or max
#gemlite.set_autotune("max")

#Autotune with the default batch-sizes
warmup(A8W8_INT8_dynamic(), shapes=[(4096, 4096), (2048, 4096)])

#You can specify the batch-sizes too
warmup(A8W8_INT8_dynamic(), shapes=[(4096, 4096), (2048, 4096)], batch_sizes=[1, 8, 64, 128])

#If you want to specify the group-size for HQQ-style quantization
warmup(A16W4_HQQ_INT(), shapes=[(4096, 4096), (2048, 4096)], group_size=64)

#Cache your new config
gemlite.cache_config('new_config.json')

VLLM

You can use GemLite with vLLM via <a href="https://github.com/pytorch/ao/">torchao</a> or <a href="https://github.com/dropbox/hqq/">hqq</a> as follows:

from hqq.utils.vllm import set_vllm_onthefly_hqq_quant
skip_modules = ['lm_head', 'visual', 'vision']

#Select one of the following modes:

#INT/FP format
set_vllm_onthefly_hqq_quant(weight_bits=8, group_size=None, quant_mode='int8_weightonly', skip_modules=skip_modules) #A16W8 - INT8 weight only
set_vllm_onthefly_hqq_quant(weight_bits=4, group_size=128, quant_mode='int4_weightonly', skip_modules=skip_modules) #A16W4 - HQQ weight only
set_vllm_onthefly_hqq_quant(weight_bits=8, quant_mode='int8_dynamic', skip_modules=skip_modules) #A8W8 - INT8 x INT8 dynamic
set_vllm_onthefly_hqq_quant(weight_bits=8, quant_mode='fp8_dynamic', skip_modules=skip_modules) #A8W8 - FP8 x FP8 dynamic

#MXFP format
set_vllm_onthefly_hqq_quant(weight_bits=8, group_size=None, quant_mode='mxfp8_dynamic', skip_modules=skip_modules) #A8W8 - MXFP8 x MXFP8 - post_scale=True
set_vllm_onthefly_hqq_quant(weight_bits=8, group_size=32, quant_mode='mxfp8_dynamic', skip_modules=skip_modules) #A8W8 - MXFP8 x MXFP8 - post_scale=False
set_vllm_onthefly_hqq_quant(weight_bits=4, quant_mode='mxfp4_weightonly', skip_modules=skip_modules) #A16W4 - MXFP4 weight-only
set_vllm_onthefly_hqq_quant(weight_bits=4, quant_mode='mxfp8_dynamic', skip_modules=skip_modules) #A8W4 - MXFP8 x MXFP4 dynamic
set_vllm_onthefly_hqq_quant(weight_bits=4, quant_mode='mxfp4_dynamic', skip_modules=skip_modules) #A4W4 - MXFP4 x MXFP4 dynamic
set_vllm_onthefly_hqq_quant(weight_bits=4, quant_mode='nvfp4_dynamic', skip_modules=skip_modules) #A4W4 - NVFP4 x NVFP4 dynamic
