Triton
GitHub mirror of the triton-lang/triton repo.
TLX - Triton Low-level Language Extensions
Introduction
TLX (Triton Low-level Language Extensions) is a low-level, warp-aware, hardware-near extension of the Triton DSL. It offers intrinsics and warp-specialized operations for fine-grained GPU control, hardware-oriented primitives for advanced kernel development, and explicit constructs for GPU memory, computation, and asynchronous control flow. TLX is designed for expert users pushing Triton closer to the metal.
Primarily targeting NVIDIA GPUs (for now), TLX extends Triton to support:
- Hardware-specific intrinsics (e.g., wgmma, async_copy, barrier)
- Shared and local memory allocation
- Instruction-level scheduling and control
- Cross-warpgroup synchronization
While this approach places more responsibility on the user, it reduces the compiler's role as a performance bottleneck. Although it may introduce divergence across hardware platforms, it empowers users to perform deeper, architecture-specific optimizations without relying solely on compiler heuristics.
The DSL Extension
Local buffer operations
- `buffers = tlx.local_alloc(shape, dtype, NUM_BUFFERS)`
  Allocate `NUM_BUFFERS` buffers in local (shared) memory per thread block, each of shape `shape`. The memory layout is inferred from the buffers' consumers.
- `buffers = tlx.local_alloc(shape, dtype, NUM_BUFFERS, tlx.storage_kind.tmem)`
  Allocate `NUM_BUFFERS` buffers in tensor memory per thread block, each of shape `shape`. The memory layout is inferred from the buffers' consumers.
- `buffers = tlx.local_alloc(shape, dtype, NUM_BUFFERS, reuse=other_buffers)`
  Alias this allocation to an existing `buffered_tensor` so multiple logical buffers reuse the same underlying local storage (SMEM or TMEM) without reallocation.
- `buffer = tlx.local_view(buffers, buffer_idx)` or `buffer = buffers[buffer_idx]`
  Return a subview of `buffers` at index `buffer_idx`. Both the explicit `local_view()` call and the indexing syntax `[]` are supported.
- `distributed_tensor = tlx.local_load(buffer, optional_token)`
  Load a buffer from local memory or tensor memory into a distributed tensor.
- `tlx.local_store(buffer, distributed_tensor)`
  Store a distributed tensor into a buffer in local memory or tensor memory.
- `buffer = tlx.local_trans(buffer, dims)`
  Permute the dimensions of a buffer.
- `buffer = tlx.local_slice(buffer, offsets=[m, n], shapes=[M, N])`
  Slice an `M x N` region of the buffer starting at offset `(m, n)`.
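Taken together, these primitives compose naturally inside a kernel. The following is a minimal illustrative sketch, not a complete or runnable kernel: the import path, the kernel arguments, and the doubling computation are assumptions, and a real kernel would also synchronize any asynchronous producers of the buffers.

```python
import triton
import triton.language as tl
import triton.language.extra.tlx as tlx  # exact import path is an assumption


@triton.jit
def local_buffer_sketch(BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                        NUM_BUFFERS: tl.constexpr):
    # Allocate NUM_BUFFERS shared-memory buffers of shape (BLOCK_M, BLOCK_N).
    buffers = tlx.local_alloc((BLOCK_M, BLOCK_N), tl.float16, NUM_BUFFERS)
    for i in range(NUM_BUFFERS):
        # Pick one buffer; indexing is equivalent to tlx.local_view(buffers, i).
        buf = buffers[i]
        # Materialize the buffer into registers as a distributed tensor ...
        tile = tlx.local_load(buf)
        # ... operate on it, then write the result back into shared memory.
        tlx.local_store(buf, tile * 2)
```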
Buffer Reuse
TLX lets you reuse the same allocated buffer across multiple disjoint steps of your kernel. This enables additional pipelining when you do not have enough free SMEM or TMEM for separate allocations.
- `tlx.storage_alias_spec(storage=storage_kind)`
  Defines a buffer that you want to share across multiple aliases. The storage can be either SMEM or TMEM. To use it in an allocation, pass the spec in the `reuse` argument of `local_alloc`. Here is the example from the FA kernel.
```python
# Create the storage alias spec for all shared buffers. Cannot be directly
# indexed.
qk_storage_alias = tlx.storage_alias_spec(storage=tlx.storage_kind.tmem)

# Allocate all buffers referencing the same spec
qk_tiles = tlx.local_alloc(
    (BLOCK_M_SPLIT, BLOCK_N), qk_dtype, NUM_MMA_GROUPS,
    tlx.storage_kind.tmem, reuse=qk_storage_alias,
)
p_tiles = tlx.local_alloc(
    (BLOCK_M_SPLIT, BLOCK_N // NUM_MMA_SLICES), tlx.dtype_of(desc_v),
    NUM_MMA_GROUPS * NUM_MMA_SLICES, tlx.storage_kind.tmem,
    reuse=qk_storage_alias,
)
alpha_tiles = tlx.local_alloc(
    (BLOCK_M_SPLIT, 1), tl.float32, NUM_MMA_GROUPS * NUM_BUFFERS_QK,
    tlx.storage_kind.tmem, reuse=qk_storage_alias,
)
l_tiles = tlx.local_alloc(
    (BLOCK_M_SPLIT, 1), tl.float32, NUM_MMA_GROUPS * NUM_BUFFERS_QK,
    tlx.storage_kind.tmem, reuse=qk_storage_alias,
)
m_tiles = tlx.local_alloc(
    (BLOCK_M_SPLIT, 1), tl.float32, NUM_MMA_GROUPS * NUM_BUFFERS_QK,
    tlx.storage_kind.tmem, reuse=qk_storage_alias,
)
```
- `tlx.reuse_group(*tensors, group_type=REUSE_TYPE, group_size=SUBTILE_SIZE)`
  A reuse group expresses how you intend to access the shared buffer. There are two group types, shared and distinct. Buffers in a shared group occupy the same memory, so the same index must not be accessed by more than one of them at a time. Buffers in a distinct group remain accessible at the same index at the same time; the compiler isolates their buffer locations and may expand the buffer allocation to enforce this guarantee, which is helpful for buffers of unequal sizes.
  The `group_size` argument enables subtiling a buffer: it ensures that for every one index of a buffer, `SUBTILE_SIZE` indices of the other buffer/group can be accessed. Reuse groups can be nested to express more complex relationships. Currently a reuse group is not applied unless you assign it to a buffer with `spec.set_buffer_overlap`. Here is the example implementation for Flash Attention. As the comments suggest, in this kernel QK is shared with P, alpha, l, and m, and P is subtiled.
```python
# Define the buffer overlap strategy:
# QK   : | BLK_M/2 * BLOCK_N * fp32 |
# P    : | BLK_M/(2*SLICES) * fp16 | BLK_M/(2*SLICES) * fp16 | ...
# alpha: | BLK_M/2 * 1 * fp32 |
# l    : | BLK_M/2 * 1 * fp32 |
# m    : | BLK_M/2 * 1 * fp32 |
qk_storage_alias.set_buffer_overlap(
    tlx.reuse_group(
        qk_tiles,
        tlx.reuse_group(
            tlx.reuse_group(p_tiles, group_size=NUM_MMA_SLICES),
            alpha_tiles, l_tiles, m_tiles,
            group_type=tlx.reuse_group_type.distinct,
        ),
        group_type=tlx.reuse_group_type.shared,
    )
)
```
Compiler Pipeline Inspection Steps
To introspect the pipeline's `add_stages` before running your kernels, set the `add_stages_inspection_hook` like so:

```python
def inspect_stages(_self, stages, options, language, capability):
    # Inspect or modify add_stages here.
    pass

triton.knobs.runtime.add_stages_inspection_hook = inspect_stages
```
Remote buffer operations
- `buffer = tlx.remote_view(buffer, remote_cta_rank)`
  Return a remote view of `buffer` living in another CTA of the same cluster with ID `remote_cta_rank`. NOTE: for now only barriers are supported as `buffer`, not general SMEM.
- `tlx.remote_shmem_store(dst, src, remote_cta_rank)`
  Store a distributed tensor into a buffer in the remote shared memory of a cluster (synchronous).

  Parameters:
  - `dst`: The destination buffer in local shared memory (internally mapped to the remote CTA)
  - `src`: The source distributed tensor to store
  - `remote_cta_rank`: The rank (unique ID) of the remote CTA within the cluster

  Example:

  ```python
  # Allocate shared memory buffer
  buffer = tlx.local_alloc((BLOCK_M, BLOCK_N), tl.float16, 1)
  # Store to remote CTA's shared memory (synchronous)
  tlx.remote_shmem_store(buffer[0], src_tensor, remote_cta_rank=1)
  ```
Async memory access
- `tlx.async_descriptor_load(desc, buffer, offsets, barrier, pred=None, cache_modifier="", eviction_policy="", multicast_targets=[])`
  Load a chunk of data from global memory into a local memory buffer using TMA. The global address, strides, and buffer size are defined by the tensor descriptor. The provided barrier object is signaled upon completion of the operation.

  Parameters:
  - `desc`: Tensor descriptor for the source
  - `buffer`: Destination buffer in shared memory
  - `offsets`: List of offsets for each dimension
  - `barrier`: mbarrier to signal upon completion
  - `pred`: Optional predicate to guard the load
  - `cache_modifier`: Cache modifier hint (e.g., `""`, `"evict_first"`)
  - `eviction_policy`: L2 cache eviction policy (`""`, `"evict_first"`, `"evict_last"`)
  - `multicast_targets`: Optional list of multicast targets for cluster-wide loads
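  In practice, an `async_descriptor_load` is paired with an mbarrier that a consumer waits on before reading the buffer. A minimal hedged sketch of that pattern, assuming TLX barrier helpers named `tlx.alloc_barriers`, `tlx.barrier_expect_bytes`, and `tlx.barrier_wait` (these names are not defined in this section and are assumptions):

  ```python
  # Sketch only: barrier helper names are assumptions; desc_a, offs_m, and
  # offs_k come from the surrounding kernel.
  buf = tlx.local_alloc((BLOCK_M, BLOCK_K), tl.float16, 1)
  bars = tlx.alloc_barriers(num_barriers=1)                 # assumed helper
  tlx.barrier_expect_bytes(bars[0], BLOCK_M * BLOCK_K * 2)  # fp16 = 2 bytes
  tlx.async_descriptor_load(desc_a, buf[0], [offs_m, offs_k], bars[0])
  tlx.barrier_wait(bars[0], 0)   # wait for the TMA copy to signal completion
  tile = tlx.local_load(buf[0])  # now safe to read
  ```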
- `tlx.async_descriptor_prefetch_tensor(memdesc, [offsets], pred, eviction_policy)`
  Hint the hardware to prefetch a chunk of data from global memory into the L2 cache in preparation for upcoming `async_descriptor_load` operations.
- `tlx.async_descriptor_store(desc, source, offsets, eviction_policy="", store_reduce="")`
  Store a chunk of data from shared memory into global memory using TMA. The global address, strides, and buffer size are defined by the tensor descriptor. Supports optional atomic reduction (`store_reduce`) and L2 cache eviction hints (`eviction_policy`); both regular stores and atomic reduce stores support eviction policies.

  Parameters:
  - `desc`: Tensor descriptor for the destination
  - `source`: Source buffer in shared memory
  - `offsets`: List of offsets for each dimension
  - `eviction_policy`: L2 cache eviction policy (`""`, `"evict_first"`, `"evict_last"`)
  - `store_reduce`: Atomic reduction kind (`""`, `"add"`, `"min"`, `"max"`, `"and"`, `"or"`, `"xor"`)

  Example:

  ```python
  # Regular TMA store with L2 evict_first hint
  tlx.async_descriptor_store(desc_c, c_buf[0], [offs_m, offs_n],
                             eviction_policy="evict_first")
  # TMA atomic reduce-add with L2 evict_first hint
  tlx.async_descriptor_store(desc_c, c_buf[0], [offs_m, offs_n],
                             eviction_policy="evict_first", store_reduce="add")
  ```
- `tlx.async_remote_shmem_store(dst, src, remote_cta_rank, barrier)`
  Store a distributed tensor into a buffer in the remote shared memory of a cluster asynchronously. Signals the provided mbarrier when the store completes.

  Parameters:
  - `dst`: The destination buffer in local shared memory (internally mapped to the remote CTA)
  - `src`: The source distributed tensor to store
  - `remote_cta_rank`: The rank (unique ID) of the remote CTA within the cluster
  - `barrier`: mbarrier signaled when the store completes

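A hedged sketch of the asynchronous remote store, mirroring the synchronous example above. The barrier helpers `tlx.alloc_barriers` and `tlx.barrier_wait` are assumed names not defined in this section:

```python
# Sketch only: barrier helper names are assumptions; src_tensor comes from
# the surrounding kernel.
buffer = tlx.local_alloc((BLOCK_M, BLOCK_N), tl.float16, 1)
bars = tlx.alloc_barriers(num_barriers=1)  # assumed helper
# Asynchronously store into CTA 1's shared memory; bars[0] is signaled
# when the store completes.
tlx.async_remote_shmem_store(buffer[0], src_tensor,
                             remote_cta_rank=1, barrier=bars[0])
tlx.barrier_wait(bars[0], 0)  # assumed wait helper
```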