Triton

GitHub mirror of the triton-lang/triton repo.

Install / Use

/learn @facebookexperimental/Triton

README

TLX - Triton Low-level Language Extensions

Introduction

TLX (Triton Low-level Language Extensions) is a low-level, warp-aware, hardware-near extension of the Triton DSL. It offers intrinsics and warp-specialized operations for fine-grained GPU control, hardware-oriented primitives for advanced kernel development, and explicit constructs for GPU memory, computation, and asynchronous control flow. TLX is designed for expert users pushing Triton closer to the metal.

Primarily targeting NVIDIA GPUs (for now), TLX extends Triton to support:

  • Hardware-specific intrinsics (e.g., wgmma, async_copy, barrier)
  • Shared and local memory allocation
  • Instruction-level scheduling and control
  • Cross-warpgroup synchronization

While this approach places more responsibility on the user, it reduces the compiler's role as a performance bottleneck. Although it may introduce divergence across hardware platforms, it empowers users to perform deeper, architecture-specific optimizations without relying solely on compiler heuristics.

The DSL Extension

Local buffer operations

  • buffers = tlx.local_alloc(shape, dtype, NUM_BUFFERS)

    Allocate NUM_BUFFERS buffers in local memory per thread block, each with the given shape and dtype. The memory layout is inferred from its consumers.

  • buffers = tlx.local_alloc(shape, dtype, NUM_BUFFERS, tlx.storage_kind.tmem)

    Allocate NUM_BUFFERS buffers in tensor memory (TMEM) per thread block, each with the given shape and dtype. The memory layout is inferred from its consumers.

  • buffers = tlx.local_alloc(shape, dtype, NUM_BUFFERS, reuse=other_buffers)

    Alias this allocation to an existing buffered_tensor so multiple logical buffers reuse the same underlying local storage (SMEM or TMEM) without reallocation.

  • buffer = tlx.local_view(buffers, buffer_idx) or buffer = buffers[buffer_idx]

    Return a subview of the buffer indexed by buffer_idx from buffers. Both the explicit local_view() call and the indexing syntax [] are supported.

  • distributed_tensor = tlx.local_load(buffer, optional_token)

    Load the buffer from local memory or tensor memory into a distributed tensor.

  • tlx.local_store(buffer, distributed_tensor)

    Store a distributed tensor into a buffer in local memory or tensor memory.

  • buffer = tlx.local_trans(buffer, dims)

    Permute the dimensions of a local buffer according to dims.

  • buffer = tlx.local_slice(buffer, offsets=[m, n], shapes=[M, N])

    Return an M x N slice of the buffer starting at offset (m, n).
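
The allocation and view primitives above are typically combined into a circular multi-buffering pattern: iteration i of a pipelined loop selects `buffers[i % NUM_BUFFERS]`. The tlx calls only run on-device, so the index rotation can be sketched in plain Python (hypothetical helper, not part of the tlx API):

```python
NUM_BUFFERS = 3  # assumed depth of the software pipeline

def buffer_schedule(num_iters, num_buffers=NUM_BUFFERS):
    """Return, per loop iteration, which buffer index a kernel
    doing `buf = buffers[i % NUM_BUFFERS]` would select."""
    return [i % num_buffers for i in range(num_iters)]

# With 3 buffers, iterations rotate through slots 0, 1, 2, 0, 1, 2, ...
print(buffer_schedule(7))  # → [0, 1, 2, 0, 1, 2, 0]
```

The rotation lets a producer fill slot (i + 1) % NUM_BUFFERS while a consumer still reads slot i, which is the point of allocating more than one buffer.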

Buffer Reuse

TLX lets you reuse the same allocated buffer across multiple disjoint stages of your kernel. This is useful for additional pipelining when you do not have enough free SMEM or TMEM for separate allocations.

  • tlx.storage_alias_spec(storage=storage_kind)

    Defines a storage specification for a buffer that will be shared across multiple aliases. The storage can be either SMEM or TMEM. To use it in an allocation, pass the spec as the reuse argument of local_alloc. Here is the example from the Flash Attention (FA) kernel.

# Create the storage alias spec for all shared buffers. Cannot be directly
# indexed.
qk_storage_alias = tlx.storage_alias_spec(storage=tlx.storage_kind.tmem)

# Allocate all buffers referencing the same spec
qk_tiles = tlx.local_alloc(
    (BLOCK_M_SPLIT, BLOCK_N), qk_dtype, NUM_MMA_GROUPS,
    tlx.storage_kind.tmem, reuse=qk_storage_alias,
)
p_tiles = tlx.local_alloc(
    (BLOCK_M_SPLIT, BLOCK_N // NUM_MMA_SLICES), tlx.dtype_of(desc_v),
    NUM_MMA_GROUPS * NUM_MMA_SLICES, tlx.storage_kind.tmem,
    reuse=qk_storage_alias,
)
alpha_tiles = tlx.local_alloc(
    (BLOCK_M_SPLIT, 1), tl.float32, NUM_MMA_GROUPS * NUM_BUFFERS_QK,
    tlx.storage_kind.tmem, reuse=qk_storage_alias,
)
l_tiles = tlx.local_alloc(
    (BLOCK_M_SPLIT, 1), tl.float32, NUM_MMA_GROUPS * NUM_BUFFERS_QK,
    tlx.storage_kind.tmem, reuse=qk_storage_alias,
)
m_tiles = tlx.local_alloc(
    (BLOCK_M_SPLIT, 1), tl.float32, NUM_MMA_GROUPS * NUM_BUFFERS_QK,
    tlx.storage_kind.tmem, reuse=qk_storage_alias,
)
  • tlx.reuse_group(*tensors, group_type=REUSE_TYPE, group_size=SUBTILE_SIZE)

    A reuse group expresses how you intend to access the shared buffer. There are two types: shared or distinct. Members of a shared group occupy the same memory, so the same index must not be accessed by two members at the same time. Members of a distinct group remain accessible at the same index at the same time; the compiler isolates their buffer locations and potentially expands the buffer allocation to enforce this guarantee, which is helpful with buffers of unequal sizes.

    The group_size argument enables subtiling a buffer: it ensures that for every one index of a buffer, SUBTILE_SIZE indices of the other buffer/group can be accessed. Reuse groups can be nested to express more complex relationships. Currently a reuse group is not applied unless you assign it to a buffer with spec.set_buffer_overlap.

    Here is the example implementation from Flash Attention. In this kernel, as the comment suggests, QK is shared with P, l, m, and alpha, and P is additionally subtiled.

# Define the buffer overlap strategy:
#   QK : |                                                   BLK_M/2 * BLOCK_N * fp32                         |
#   P:   |  BLK_M/(2*SLICES) * fp16| BLK_M/(2*SLICES) * fp16|...
# Alpha:                                                        |BLK_M/2*1*fp32|
#   l  :                                                                        |BLK_M/2*1*fp32|
#   m  :                                                                                       |BLK_M/2*1*fp32|
qk_storage_alias.set_buffer_overlap(
    tlx.reuse_group(
        qk_tiles,
        tlx.reuse_group(
            tlx.reuse_group(p_tiles, group_size=NUM_MMA_SLICES),
            alpha_tiles, l_tiles, m_tiles,
            group_type=tlx.reuse_group_type.distinct,
        ),
        group_type=tlx.reuse_group_type.shared,
    )
)
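
The footprint arithmetic behind the layout comment can be checked off-device with a toy model (this is an illustration, not the compiler's actual allocation algorithm): members of a shared group overlap, so the group's footprint is the max of its members, while members of a distinct group sit side by side, so the footprint is the sum. The tile sizes below (BLOCK_M_SPLIT=64, BLOCK_N=128, NUM_MMA_SLICES=2) are assumed values, not taken from the kernel:

```python
# Toy model of reuse-group footprints, in bytes.
def shared(*sizes): return max(sizes)     # members overlap
def distinct(*sizes): return sum(sizes)   # members laid out side by side

# Hypothetical tile sizes (fp32 = 4 B, fp16 = 2 B)
BLOCK_M_SPLIT, BLOCK_N, NUM_MMA_SLICES = 64, 128, 2
qk = BLOCK_M_SPLIT * BLOCK_N * 4                      # one QK tile
p  = BLOCK_M_SPLIT * (BLOCK_N // NUM_MMA_SLICES) * 2  # one P slice
st = BLOCK_M_SPLIT * 1 * 4                            # one alpha / l / m tile

p_group = distinct(*[p] * NUM_MMA_SLICES)  # subtiled P slices, side by side
inner   = distinct(p_group, st, st, st)    # P group, alpha, l, m
total   = shared(qk, inner)                # the whole distinct group overlaps QK

assert inner <= qk           # distinct members fit inside one QK tile
assert total == qk == 32768  # so the alias costs no more TMEM than QK alone
```

Under these assumed sizes, the distinct group (17,152 B) fits inside a single QK tile (32,768 B), matching the picture in the comment where P, alpha, l, and m are packed into QK's footprint.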

Compiler Pipeline Inspection

To introspect the pipeline's add_stages before running your kernels, set the add_stages_inspection_hook like so:

def inspect_stages(_self, stages, options, language, capability):
    # Inspect or modify the list of add_stages here.
    pass

triton.knobs.runtime.add_stages_inspection_hook = inspect_stages


Remote buffer operations

  • buffer = tlx.remote_view(buffer, remote_cta_rank)

    Return a remote view of the buffer living in another CTA in the same cluster with ID remote_cta_rank. NOTE: for now we only support barrier as buffer, not general SMEM.

  • tlx.remote_shmem_store(dst, src, remote_cta_rank)

    Store a distributed tensor into a buffer in the remote shared memory of a cluster (synchronous).

    Parameters:

    • dst: The destination buffer in local shared memory (will be internally mapped to the remote CTA)
    • src: The source distributed tensor to store
    • remote_cta_rank: The rank (unique ID) of the remote CTA within the cluster

    Example:

    # Allocate shared memory buffer
    buffer = tlx.local_alloc((BLOCK_M, BLOCK_N), tl.float16, 1)
    
    # Store to remote CTA's shared memory (synchronous)
    tlx.remote_shmem_store(buffer[0], src_tensor, remote_cta_rank=1)
    

Async memory access

  • tlx.async_descriptor_load(desc, buffer, offsets, barrier, pred=None, cache_modifier="", eviction_policy="", multicast_targets=[])

    Load a chunk of data from global memory into a local memory buffer using TMA. The global address, strides, and buffer size are defined by the tensor descriptor. A barrier object is provided and signaled upon completion of the operation.

    Parameters:

    • desc: Tensor descriptor for the source
    • buffer: Destination buffer in shared memory
    • offsets: List of offsets for each dimension
    • barrier: mbarrier to signal upon completion
    • pred: Optional predicate to guard the load
    • cache_modifier: Cache modifier hint (e.g., "", "evict_first")
    • eviction_policy: L2 cache eviction policy ("", "evict_first", "evict_last")
    • multicast_targets: Optional list of multicast targets for cluster-wide loads
  • tlx.async_descriptor_prefetch_tensor(memdesc, [offsets], pred, eviction_policy)

    Hint the hardware to prefetch a chunk of data from global memory into the L2 cache in preparation for upcoming async_descriptor_load operations.

  • tlx.async_descriptor_store(desc, source, offsets, eviction_policy="", store_reduce="")

    Store a chunk of data from shared memory into global memory using TMA. The global address, strides, and buffer size are defined by the tensor descriptor.

    Supports optional atomic reduction (store_reduce) and L2 cache eviction hints (eviction_policy). Both regular stores and atomic reduce stores support cache eviction policies.

    Parameters:

    • desc: Tensor descriptor for the destination
    • source: Source buffer in shared memory
    • offsets: List of offsets for each dimension
    • eviction_policy: L2 cache eviction policy ("", "evict_first", "evict_last")
    • store_reduce: Atomic reduction kind ("", "add", "min", "max", "and", "or", "xor")

    Example:

    # Regular TMA store with L2 evict_first hint
    tlx.async_descriptor_store(desc_c, c_buf[0], [offs_m, offs_n], eviction_policy="evict_first")
    
    # TMA atomic reduce-add with L2 evict_first hint
    tlx.async_descriptor_store(desc_c, c_buf[0], [offs_m, offs_n],
                               eviction_policy="evict_first", store_reduce="add")
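
    The difference between the two calls above can be emulated element-wise: an empty store_reduce overwrites the destination region, while "add" accumulates into it. A toy CPU-side emulation (illustration only, not the TMA hardware semantics; emulate_store is a hypothetical helper):

```python
def emulate_store(dst, src, offset, store_reduce=""):
    """Write `src` into dst[offset : offset + len(src)], either
    overwriting (plain store) or reducing into it ("add")."""
    for i, v in enumerate(src):
        if store_reduce == "add":
            dst[offset + i] += v   # reduce-add into global memory
        else:
            dst[offset + i] = v    # plain store overwrites

gmem = [10, 10, 10, 10]
emulate_store(gmem, [1, 2], 1)                      # overwrite
assert gmem == [10, 1, 2, 10]
emulate_store(gmem, [1, 2], 1, store_reduce="add")  # accumulate
assert gmem == [10, 2, 4, 10]
```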
    
  • tlx.async_remote_shmem_store(dst, src, remote_cta_rank, barrier)

    Store a distributed tensor into a buffer in the remote shared memory of a cluster asynchronously. Signals the provided mbarrier when the store completes.

    Parameters:

    • dst: The destination buffer in local shared memory (will be internally mapped to the remote CTA)
    • src: The source distributed tensor to store
    • barrier: mbarrier to signal when the store completes
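
Several of the async operations above (async_descriptor_load, async_remote_shmem_store) signal an mbarrier on completion: the hardware arrives on the barrier when the copy lands, and consumers wait on a phase bit before touching the buffer. A minimal CPU-side model of that arrive/wait protocol (a sketch of the concept only; ToyMBarrier is hypothetical and simplifies the real hardware semantics):

```python
class ToyMBarrier:
    """CPU-side sketch of an mbarrier: flips its phase once the
    expected number of arrivals (e.g. completed async copies) lands."""
    def __init__(self, expected):
        self.expected = expected
        self.count = 0
        self.phase = 0

    def arrive(self):
        self.count += 1
        if self.count == self.expected:  # all outstanding copies done
            self.count = 0
            self.phase ^= 1              # flip phase; waiters may proceed

    def try_wait(self, parity):
        # A consumer spins until the barrier's phase moves past `parity`.
        return self.phase != parity

bar = ToyMBarrier(expected=2)   # e.g. two async loads signal this barrier
assert not bar.try_wait(parity=0)  # nothing landed yet; consumer must wait
bar.arrive(); bar.arrive()
assert bar.try_wait(parity=0)      # both copies landed; safe to read buffer
```

The phase bit is what makes the barrier reusable across pipeline iterations: waiters track which parity they expect, so the same barrier object can guard buffer slot i on even phases and the refilled slot on odd phases.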