Tokamax
Tokamax: A GPU and TPU kernel library.
Install / Use
/learn @openxla/TokamaxREADME
Tokamax
Tokamax is a library of custom accelerator kernels, supporting both NVIDIA GPUs and Google TPUs. Tokamax provides state-of-the-art custom kernel implementations built on top of JAX and Pallas.
Tokamax also provides tooling for users to build and autotune their own custom accelerator kernels.
Status
Tokamax is still heavily under development. Incomplete features and API changes are to be expected.
We currently support the following GPU kernels:
tokamax.dot_product_attention(FlashAttention).tokamax.gated_linear_unit(Gated linear units (SwiGLU etc)).tokamax.layer_norm(Layer normalization and Root Mean Squared normalization).
And the following for both GPU and TPU:
tokamax.ragged_dot(Mixture of Experts).
And the following TPU kernels:
tokamax.linear_softmax_cross_entropy_loss(Memory Efficient Linear Cross Entropy Loss Kernel)
Installation
The latest Tokamax PyPI release:
pip install -U tokamax
The latest bleeding edge version from Github, with no stability guarantees:
pip install git+https://github.com/openxla/tokamax.git
Using Tokamax
Consider a function containing Tokamax functions running on an H100 GPU:
import jax
import jax.numpy as jnp
import tokamax
def loss(x, scale):
x = tokamax.layer_norm(
x, scale=scale, offset=None, implementation="triton"
)
x = tokamax.dot_product_attention(x, x, x, implementation="xla_chunked")
x = tokamax.layer_norm(x, scale=scale, offset=None, implementation=None)
x = tokamax.dot_product_attention(x, x, x, implementation="mosaic")
return jnp.sum(x)
f_grad = jax.jit(jax.grad(loss))
With implementation=None, Tokamax is allowed to select the best implementation
for each kernel shape. It is even allowed to choose different implementations
for the forward pass and gradient. It will also always be supported, as it can
fall back to an XLA implementation implementation='xla'.
However, you may want to choose a specific implementation of the kernel, and
fail if it is unsupported. For instance, implementation="mosaic" will try to
use a Pallas:Mosaic GPU
kernel if possible, and throw an exception if this is unsupported for any
reason. For example, using FP64 inputs are unsupported, or older GPUs.
Evaluate the Gradient
channels, seq_len, batch_size, num_heads = (64, 2048, 32, 16)
scale = jax.random.normal(jax.random.key(0), (channels,), dtype=jnp.float32)
x = jax.random.normal(
jax.random.key(1),
(batch_size, seq_len, num_heads, channels),
dtype=jnp.bfloat16,
)
out = f_grad(x, scale)
Autotuning
To get the best performance, autotune all Tokamax kernels in f_grad:
autotune_result: tokamax.AutotuningResult = tokamax.autotune(f, x, scale)
autotune_result can be used as a context-manager, using the autotuned configs
for all Tokamax kernels in f_grad:
with autotune_result:
out_autotuned = f_grad(x, scale)
To serialize and reuse the result of a potentially expensive
tokamax.autotuning call:
autotune_result_json: str = autotune_result.dumps()
autotune_result = tokamax.AutotuningResult.loads(autotune_result_json)
Users can autotune their own kernels with tokamax.autotune by inheriting from
the tokamax.Op class and overriding the tokamax.Op._get_autotuning_configs
method to define the autotuning search-space.
Note that autotuning is fundamentally non-deterministic: measuring kernel execution times is noisy. As different configs chosen during autotuning can lead to different numerics, this is a potential source of numerical non-determinism. Serializing and reusing fixed autotuning results is a way to ensure the same numerics across sessions.
Serialization
Kernels can be serialized to StableHLO. Kernel
calls are JAX custom calls, which are by default banned in jax.export,
requiring the use of
tokamax.DISABLE_JAX_EXPORT_CHECKS to allow all Tokamax kernels to be exported:
from jax import export
f_grad_exported = export.export(f_grad, disabled_checks=tokamax.DISABLE_JAX_EXPORT_CHECKS)(
jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(scale.shape, scale.dtype),
)
Note that functions serialized with Tokamax kernels lose the device-independence of standard StableHLO. Tokamax makes two serialization guarantees:
- A deserialized function serialized on a specific device will be guaranteed to run on the exact device it was serialized for.
- Tokamax gives the same compatibility guarantees as JAX: 6 month backward compatibility.
Benchmarking
JAX Python overhead is often much larger than the actual accelerator kernel
execution time. This means the usual approach of timing
jax.block_until_ready(f_grad(x, scale)) won't be useful. Tokamax has utilities
for only measuring accelerator execution time:
f_std, args = tokamax.standardize_function(f, kwargs={'x': x, 'scale': scale})
bench: tokamax.BenchmarkData = tokamax.benchmark(f_std, args)
There are different measurement techniques: for example, on GPU, there is the
CUPTI profiler that can be specified via
tokamax.benchmark(f_std, args, method='cupti'). This instruments the
kernel and adds some a small overhead. The default method=None allows Tokamax
to choose the method, and works for both TPU and GPU. Benchmark noise can be
reduced by increasing the number of iterations:
tokamax.benchmark(f_std, args, iterations=10)
Disclaimer
This is not an official Google product.
Related Skills
node-connect
339.5kDiagnose OpenClaw node connection and pairing failures for Android, iOS, and macOS companion apps
frontend-design
83.9kCreate distinctive, production-grade frontend interfaces with high design quality. Use this skill when the user asks to build web components, pages, or applications. Generates creative, polished code that avoids generic AI aesthetics.
openai-whisper-api
339.5kTranscribe audio via OpenAI Audio Transcriptions API (Whisper).
commit-push-pr
83.9kCommit, push, and open a PR
