Triton
Development repository for the Triton language and compiler
Triton Conference 2025
The 3rd Triton Developer Conference took place on October 21, 2025 at the Microsoft Silicon Valley Campus in Mountain View, California.
Conference Materials
Conference recordings and materials are now available online:
- Conference Videos: YouTube Playlist
- Conference Slides: Google Drive Folder
For previous conference materials, see:
Triton
This is the development repository of Triton, a language and compiler for writing highly efficient custom deep-learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.
The foundations of this project are described in the following MAPL2019 publication: Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations. Please consider citing this work if you use Triton!
The official documentation contains installation instructions and tutorials. See also these third-party Triton puzzles, which can all be run using the Triton interpreter -- no GPU required.
Quick Installation
You can install the latest stable release of Triton from pip:
pip install triton
Binary wheels are available for CPython 3.10-3.14.
Install from source
git clone https://github.com/triton-lang/triton.git
cd triton
pip install -r python/requirements.txt # build-time dependencies
pip install -e .
Or with a virtualenv:
git clone https://github.com/triton-lang/triton.git
cd triton
python -m venv .venv --prompt triton
source .venv/bin/activate
pip install -r python/requirements.txt # build-time dependencies
pip install -e .
Building with a custom LLVM
Triton uses LLVM to generate code for GPUs and CPUs. Normally, the Triton build downloads a prebuilt LLVM, but you can also build and use LLVM from source.
LLVM does not have a stable API, so the Triton build will not work at an arbitrary LLVM version.
For convenience, use the following command to build LLVM and install Triton with the custom LLVM:
make dev-install-llvm
<details>
<summary>
Alternatively, follow these steps to build LLVM from source manually.
</summary>
- Find the version of LLVM that Triton builds against. Check `cmake/llvm-hash.txt` to see the current version. For example, if it says:
  49af6502c6dcb4a7f7520178bd14df396f78240c
  This means that the version of Triton you have builds against LLVM 49af6502.
- `git checkout` LLVM at this revision. Optionally, make additional modifications to LLVM.
- Build LLVM. For example, you might run:
  $ cd $HOME/llvm-project  # your clone of LLVM.
  $ mkdir build
  $ cd build
  $ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm;lld;clang" -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU"
  $ ninja
- Grab a snack, this will take a while.
- Build Triton as above, but set the following environment variables:
  # Modify as appropriate to point to your LLVM build.
  $ export LLVM_BUILD_DIR=$HOME/llvm-project/build
  $ cd <triton install>
  $ LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \
    LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \
    LLVM_SYSPATH=$LLVM_BUILD_DIR \
    pip install -e .
</details>
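The manual steps above can also be scripted. The Python sketch below is not part of Triton's build tooling: the helper names (`parse_llvm_hash`, `read_llvm_hash`, `short_hash`, `llvm_build_env`, `pip_install_with_llvm`) are hypothetical, and it assumes `cmake/llvm-hash.txt` holds a single 40-character commit hash, as in the example above. It reads the pinned hash and passes the three LLVM variables to `pip install -e .` through a copied environment instead of exporting them in the shell.

```python
import os
import re
import subprocess
import sys
from pathlib import Path
from typing import Dict, Mapping, Optional


def parse_llvm_hash(text: str) -> str:
    """Return the pinned LLVM commit from the contents of cmake/llvm-hash.txt."""
    commit = text.strip()
    # Assumption: the file holds exactly one full 40-character hex SHA.
    if not re.fullmatch(r"[0-9a-f]{40}", commit):
        raise ValueError(f"unexpected llvm-hash.txt contents: {commit!r}")
    return commit


def read_llvm_hash(triton_dir: Path) -> str:
    """Read and validate cmake/llvm-hash.txt inside a Triton checkout."""
    return parse_llvm_hash((triton_dir / "cmake" / "llvm-hash.txt").read_text())


def short_hash(commit: str, length: int = 8) -> str:
    """Abbreviate a commit, e.g. 49af6502c6dc... -> 49af6502."""
    return commit[:length]


def llvm_build_env(llvm_build_dir: Path,
                   base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Copy the environment and add the LLVM variables from the last step above."""
    env = dict(os.environ if base is None else base)
    env["LLVM_INCLUDE_DIRS"] = str(llvm_build_dir / "include")
    env["LLVM_LIBRARY_DIR"] = str(llvm_build_dir / "lib")
    env["LLVM_SYSPATH"] = str(llvm_build_dir)
    return env


def pip_install_with_llvm(triton_dir: Path, llvm_build_dir: Path) -> None:
    """Run `pip install -e .` in triton_dir against the given LLVM build."""
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", "."],
        cwd=triton_dir,
        env=llvm_build_env(llvm_build_dir),
        check=True,
    )
```

The `git checkout` and LLVM build steps are not covered by this sketch; it only automates reading the pinned revision and the final Triton install.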
Tips for building
- Set `TRITON_BUILD_WITH_CLANG_LLD=true` as an environment variable to use clang and lld. lld in particular results in faster builds.
- Set `TRITON_BUILD_WITH_CCACHE=true` to build with ccache.
- Set `TRITON_HOME=/some/path` to change the location of the `.triton` directory where Triton's cache is located and downloads are stored during the build. By default, this is the user's home directory. It can be changed at any time.
- If you're running out of memory when building Triton, set the `MAX_JOBS` environment variable for the `pip install -e .` command to limit the number of parallel build jobs.
- Pass `--no-build-isolation` to `pip install` to make no-op builds faster. Without it, every invocation of `pip install` uses a different symlink to cmake, which forces ninja to rebuild most of the `.a` files.
- The build system creates a `compile_commands.json` file under the Triton repo directory. This file is used by VSCode IntelliSense and clangd to provide code completion and other features for C++ code. If IntelliSense does not work, you can try the following steps:
  - Do a local build: run `pip install -e .`.
  - Get the full path to the `compile_commands.json` file produced by the build: `find ./build -name 'compile_commands.json' | xargs readlink -f`. You might get a full path similar to `/Users/{username}/triton/build/cmake.macosx-11.1-arm64-cpython-3.12/compile_commands.json`.
  - In VSCode, install the C/C++ extension, then open the command palette (`Shift + Command + P` on Mac, or `Shift + Ctrl + P` on Windows/Linux) and open `C/C++: Edit Configurations (UI)`.
  - Open "Advanced Settings" and paste the full path to `compile_commands.json` into the "Compile Commands" textbox.
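The build tips above can be combined in a small Python helper. This is a sketch under assumptions: the function names are hypothetical, it only sets the variables and flags named in this section, and `--no-build-isolation` assumes the build-time dependencies from `python/requirements.txt` are already installed. `find_compile_commands` mirrors the `find ... | xargs readlink -f` command used for IntelliSense.

```python
import os
import sys
from pathlib import Path
from typing import Dict, List, Mapping, Optional


def build_env(use_clang_lld: bool = True, use_ccache: bool = False,
              max_jobs: Optional[int] = None, triton_home: Optional[str] = None,
              base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of the environment with the build tips applied."""
    env = dict(os.environ if base is None else base)
    if use_clang_lld:
        env["TRITON_BUILD_WITH_CLANG_LLD"] = "true"  # build with clang and lld
    if use_ccache:
        env["TRITON_BUILD_WITH_CCACHE"] = "true"
    if max_jobs is not None:
        env["MAX_JOBS"] = str(max_jobs)  # fewer parallel jobs, less memory
    if triton_home is not None:
        env["TRITON_HOME"] = triton_home  # directory that will hold .triton
    return env


def pip_install_cmd(no_build_isolation: bool = True) -> List[str]:
    """Build the `pip install -e .` command line."""
    cmd = [sys.executable, "-m", "pip", "install"]
    if no_build_isolation:
        # Assumes build-time deps from python/requirements.txt are installed.
        cmd.append("--no-build-isolation")
    return cmd + ["-e", "."]


def find_compile_commands(build_dir: Path) -> List[Path]:
    """Like `find ./build -name 'compile_commands.json' | xargs readlink -f`."""
    return sorted(p.resolve() for p in build_dir.rglob("compile_commands.json"))
```

Passing the results to `subprocess.run(pip_install_cmd(), env=build_env(...), cwd=<your Triton checkout>, check=True)` would then run the build with these settings.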
Running tests
There currently isn't a turnkey way to run all the Triton tests, but you can use the following recipe:
# One-time setup. Note this will reinstall local Triton because torch
# overwrites it with the public version.
$ make dev-install
# To run all tests (requires a GPU)
$ make test
# Or, to run tests without a GPU
$ make test-nogpu
Tips for hacking
For detailed instructions on how to debug Triton's frontend, please refer to this tutorial. The following includes additional tips for hacking on Triton's backend.
Configuration knobs
See python/triton/knobs.py for the full list of configuration knobs. You can set these knobs directly in Python or control them through environment variables. Below are some of the environment variables you can specify:
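As one illustration of the environment-variable route, the Python sketch below builds an environment for a single debugging run and launches a script in a child process, so the variables are already set when that process starts. The helper names (`triton_debug_env`, `run_with_env`) and the script path are hypothetical; only the variable names come from the list below.

```python
import os
import subprocess
import sys
from typing import Dict, Iterable, Mapping, Optional


def triton_debug_env(mlir_dump: Optional[str] = None,
                     dump_path: Optional[str] = None,
                     llvm_debug_only: Iterable[str] = (),
                     interpret: bool = False,
                     base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of the environment with selected Triton debug variables set."""
    env = dict(os.environ if base is None else base)
    if mlir_dump is not None:
        # "1" dumps IR for all kernels; a kernel name restricts it to that kernel.
        env["MLIR_ENABLE_DUMP"] = mlir_dump
    if dump_path is not None:
        env["MLIR_DUMP_PATH"] = dump_path  # stderr is used when unset
    names = list(llvm_debug_only)
    if names:
        # Comma-separated, like LLVM's -debug-only option.
        env["TRITON_LLVM_DEBUG_ONLY"] = ",".join(names)
    if interpret:
        env["TRITON_INTERPRET"] = "1"  # use the Triton interpreter instead of the GPU
    return env


def run_with_env(script: str, env: Mapping[str, str]) -> int:
    """Run a Python script in a child process with the given environment."""
    return subprocess.run([sys.executable, script], env=dict(env)).returncode
```

For example, `run_with_env("my_kernel.py", triton_debug_env(mlir_dump="1", dump_path="/tmp/triton-mlir"))` would run a hypothetical `my_kernel.py` with IR dumps enabled, leaving the parent shell's environment untouched.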
- `MLIR_ENABLE_DUMP=1` dumps the IR before every MLIR pass Triton runs, for all kernels. Use `MLIR_ENABLE_DUMP=kernelName` to dump for a specific kernel only.
  - The Triton cache can interfere with the dump. In cases where `MLIR_ENABLE_DUMP=1` does not work, try clearing your Triton cache: `rm -r ~/.triton/cache/*`.
- `MLIR_DUMP_PATH` specifies where `MLIR_ENABLE_DUMP` will dump to. If unset, it dumps to stderr.
- `LLVM_IR_ENABLE_DUMP=1` dumps the IR before every pass run over the LLVM IR.
- `TRITON_REPRODUCER_PATH=<reproducer_path>` will generate an MLIR reproducer file at `<reproducer_path>` before each MLIR compiler stage. If any of the stages fails, `<reproducer_path>` will be a local MLIR reproducer captured right before the failing pass.
- `TRITON_INTERPRET=1` uses the Triton interpreter instead of running on the GPU. You can insert Python breakpoints in your kernel code!
- `TRITON_ENABLE_LLVM_DEBUG=1` passes `-debug` to LLVM, printing a lot of debugging information to stdout. If this is too noisy, run with just `TRITON_LLVM_DEBUG_ONLY` instead to limit the output.
  - An alternative way to reduce output noise is to run with `LLVM_IR_ENABLE_DUMP=1`, extract the IR before the LLVM pass of interest, and then run LLVM's `opt` standalone, perhaps passing `-debug-only=foo` on the command line.
- `TRITON_LLVM_DEBUG_ONLY=<comma-separated>` is the equivalent of LLVM's `-debug-only` command-line option. It limits the LLVM debug output to specific pass or component names (which are specified using `#define DEBUG_TYPE` throughout LLVM and Triton) so that the debug output is less noisy. `TRITON_LLVM_DEBUG_ONLY` accepts one or more comma-separated values (e.g. `TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions"` or `TRITON_LLVM_DEBUG_ONLY="tritongpu-remove-layout-conversions,regalloc"`).
- `TRITON_ENABLE_ASAN=1` invokes the LLVM address sanitizer for memory leak and out-of-bounds access detection. This is currently only supported on the AMD backend and must be run using the ASAN libraries documented here.
  - When enabling the address sanitizer, it is recommended to disable various memory caching strategies both within the ROCm stack and PyTorch. This gives the address sanitizer the best chance of finding the memory fault where it originates. See this test for more details.
- `USE_IR_LOC={ttir,ttgir}` reparses the IR such that the location information will be the line number of the IR file with that particular extension, instead of the line number of the Python file. This can provide a di
