CubeCL
Multi-platform high-performance compute language extension for Rust.
</div>

TL;DR
With CubeCL, you can program your GPU using Rust, taking advantage of zero-cost abstractions to develop maintainable, flexible, and efficient compute kernels. CubeCL also comes with optimized runtimes that handle memory management and lazy execution on any platform.
Supported Platforms
| Platform | Runtime | Compiler    | Hardware                      |
| -------- | ------- | ----------- | ----------------------------- |
| WebGPU   | wgpu    | WGSL        | Most GPUs                     |
| CUDA     | CUDA    | C++ (CUDA)  | NVIDIA GPUs                   |
| ROCm     | HIP     | C++ (HIP)   | AMD GPUs                      |
| Metal    | wgpu    | C++ (Metal) | Apple GPUs                    |
| Vulkan   | wgpu    | SPIR-V      | Most GPUs on Linux & Windows  |
| CPU      | cpu     | Rust        | All CPUs, SIMD with most CPUs |
Not all platforms support the same features. For instance, Tensor Core acceleration isn't supported on WebGPU yet. Using an instruction that isn't available on a platform will result in a compilation error at runtime. The launch function is normally responsible for dispatching the right kernel based on device properties.
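As a rough illustration of what dispatching "the right kernel based on device properties" can look like, the plain-Rust sketch below picks a kernel variant from capability flags before launch, so no unsupported instruction is ever emitted. `DeviceProperties`, `MatmulVariant`, and `select_matmul` are hypothetical names invented for this example, not CubeCL APIs.

```rust
/// Hypothetical summary of what a device reports (not a CubeCL type).
#[derive(Debug, Clone, Copy)]
pub struct DeviceProperties {
    pub has_tensor_cores: bool,
    pub max_vector_width: usize,
}

/// Hypothetical kernel variants a launch function could choose between.
#[derive(Debug, PartialEq, Eq)]
pub enum MatmulVariant {
    /// Uses matrix-multiply-accumulate hardware such as Tensor Cores.
    Accelerated,
    /// Plain multiply-add loops that every platform can compile.
    Fallback { vector_width: usize },
}

/// Choose a variant at launch time, based only on what the device supports.
pub fn select_matmul(props: &DeviceProperties) -> MatmulVariant {
    if props.has_tensor_cores {
        MatmulVariant::Accelerated
    } else {
        // Use the widest vectorization the device reports, but at least 1.
        MatmulVariant::Fallback {
            vector_width: props.max_vector_width.max(1),
        }
    }
}
```

Keeping the choice in the launch path means the kernel source itself can stay the same across platforms.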
Example
Simply annotate functions with the cube attribute to indicate that they should run on the GPU.
```rust
use cubecl::prelude::*;

#[cube(launch_unchecked)]
/// A [Vector] represents a contiguous series of elements where SIMD operations may be available.
/// The runtime will automatically use SIMD instructions when possible for improved performance.
fn gelu_array<F: Float, N: Size>(input: &Array<Vector<F, N>>, output: &mut Array<Vector<F, N>>) {
    if ABSOLUTE_POS < input.len() {
        output[ABSOLUTE_POS] = gelu_scalar(input[ABSOLUTE_POS]);
    }
}

#[cube]
fn gelu_scalar<F: Float, N: Size>(x: Vector<F, N>) -> Vector<F, N> {
    // Execute the sqrt function at comptime.
    let sqrt2 = F::new(comptime!(2.0f32.sqrt()));
    let tmp = x / Vector::new(sqrt2);

    x * (Vector::erf(tmp) + 1.0) / 2.0
}
```
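To make the kernel's math concrete, here is a plain-Rust CPU sketch (no CubeCL involved) of what each unit of `gelu_array` does: one unit per vector index, a bounds check, and the formula `x * (erf(x / sqrt(2)) + 1) / 2` applied lane by lane. The fixed-size arrays standing in for `Vector<F, N>`, the sequential `gelu_array_cpu` loop standing in for a launch, and the `erf` approximation (Abramowitz and Stegun 7.1.26, used because a stable `erf` may not be available in Rust's standard library) are all assumptions of this sketch.

```rust
/// erf approximation (Abramowitz & Stegun 7.1.26, absolute error < 1.5e-7).
fn erf(x: f32) -> f32 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.327_591_1 * x);
    let poly = t
        * (0.254_829_6
            + t * (-0.284_496_7 + t * (1.421_413_8 + t * (-1.453_152 + t * 1.061_405_4))));
    // erf is odd, so compute it for |x| and restore the sign.
    sign * (1.0 - poly * (-x * x).exp())
}

/// Scalar GELU, the same formula as `gelu_scalar`.
fn gelu(x: f32) -> f32 {
    x * (erf(x / std::f32::consts::SQRT_2) + 1.0) / 2.0
}

/// Sequential stand-in for a launch: each "unit" handles the vector at
/// `absolute_pos`, skipping positions past the end just like the kernel.
fn gelu_array_cpu<const N: usize>(input: &[[f32; N]], output: &mut [[f32; N]], num_units: usize) {
    for absolute_pos in 0..num_units {
        if absolute_pos < input.len() {
            // Apply GELU to every lane of the vector.
            output[absolute_pos] = input[absolute_pos].map(gelu);
        }
    }
}
```

Launching more units than there are vectors is harmless here for the same reason it is in the kernel: the `absolute_pos < input.len()` check makes the extra units do nothing.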
You can then launch the kernel using the autogenerated gelu_array::launch_unchecked function.
```rust
pub fn launch<R: Runtime>(device: &R::Device) {
    let client = R::client(device);
    let input = &[-1., 0., 1., 5.];
    let vectorization = 4;
    let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());
    let input_handle = client.create(f32::as_bytes(input));

    unsafe {
        gelu_array::launch_unchecked::<f32, R>(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new_1d(input.len() as u32 / vectorization),
            vectorization,
            ArrayArg::from_raw_parts(&input_handle, input.len()),
            ArrayArg::from_raw_parts(&output_handle, input.len()),
        )
    };

    let bytes = client.read_one(output_handle);
    let output = f32::from_bytes(&bytes);

    // Should be [-0.1587, 0.0000, 0.8413, 5.0000]
    println!("Executed gelu with runtime {:?} => {output:?}", R::name(&client));
}
```
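In `launch`, four `f32` values with a vectorization factor of 4 form a single vector, so `CubeDim::new_1d(input.len() as u32 / vectorization)` requests one unit in one cube. For larger inputs, one plausible way to size a 1D launch is sketched below with a hypothetical `launch_geometry` helper; it is plain Rust, not a CubeCL API, and the 256-unit cube size used in the usage is an arbitrary assumption.

```rust
/// Hypothetical description of a 1D launch (not a CubeCL type).
#[derive(Debug, PartialEq, Eq)]
pub struct Geometry {
    pub num_vectors: u32,
    pub cube_dim_x: u32,
    pub cube_count_x: u32,
}

/// Hypothetical helper mirroring the arithmetic used in `launch`.
pub fn launch_geometry(num_elements: u32, vectorization: u32, max_units_per_cube: u32) -> Geometry {
    assert!(vectorization > 0 && max_units_per_cube > 0);
    assert!(
        num_elements % vectorization == 0,
        "length must be a multiple of the vectorization factor"
    );
    // One unit per vector, not per element.
    let num_vectors = num_elements / vectorization;
    // Small inputs fit in a single cube; larger ones use full cubes.
    let cube_dim_x = num_vectors.clamp(1, max_units_per_cube);
    // Round up so every vector gets a unit. The last cube may contain idle
    // units, which is why the kernel checks `ABSOLUTE_POS < input.len()`.
    let cube_count_x = num_vectors.div_ceil(cube_dim_x);
    Geometry { num_vectors, cube_dim_x, cube_count_x }
}
```

For example, 4000 elements with a vectorization factor of 4 and 256-unit cubes need 1000 vectors and 4 cubes, leaving 24 idle units in the last cube.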
To see it in action, run the working GELU example with one of the following commands:
```bash
cargo run --example gelu --features cpu  # cpu/simd runtime
cargo run --example gelu --features cuda # cuda runtime
cargo run --example gelu --features wgpu # wgpu runtime
```
Motivation
The goal of CubeCL is to ease the pain of writing highly optimized compute kernels that are portable across hardware. There is currently no adequate solution when you want optimal performance while still being multi-platform: you have to write custom kernels for each kind of hardware, often in different languages or toolchains such as CUDA, Metal, or ROCm. To fix this, we created a Just-in-Time compiler with three core features: automatic vectorization, comptime, and autotune!
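Autotune is only named here, so as an illustration of the general technique, the sketch below benchmarks every candidate kernel once per problem key and caches the fastest choice for later calls. `Autotuner` and `pick` are hypothetical names; CubeCL's actual autotune system is not shown and likely differs, for example by warming up and repeating measurements.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Hypothetical autotuner: times each candidate once per key, caches the winner.
pub struct Autotuner {
    cache: HashMap<String, usize>,
}

impl Autotuner {
    pub fn new() -> Self {
        Self { cache: HashMap::new() }
    }

    /// Returns the index of the fastest candidate for `key`. Candidates are
    /// only timed the first time a key is seen; afterwards the cache answers.
    pub fn pick(&mut self, key: &str, candidates: &[&dyn Fn()]) -> usize {
        if let Some(&best) = self.cache.get(key) {
            return best;
        }
        let mut best = (0, Duration::MAX);
        for (i, run) in candidates.iter().enumerate() {
            let start = Instant::now();
            run();
            let elapsed = start.elapsed();
            if elapsed < best.1 {
                best = (i, elapsed);
            }
        }
        self.cache.insert(key.to_string(), best.0);
        best.0
    }
}
```

Keying the cache by problem shape (for instance, matrix sizes) lets different inputs settle on different kernels while paying the benchmarking cost only once per shape.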
These features are extremely useful for anyone writing high-performance kernels, even when portability is not a concern. They improve code composability, reusability, testability, and maintainability, all while staying optimal. CubeCL also ships with a memory management strategy optimized for throughput with heavy buffer reuse to avoid allocations.
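As a rough picture of buffer reuse, the hypothetical `BufferPool` below keeps released buffers in per-size free lists and hands them back out for later requests of the same size, so a steady-state loop allocates only once. This is a simplified sketch under that assumption, not CubeCL's memory manager, which manages device memory rather than `Vec<u8>`.

```rust
use std::collections::HashMap;

/// Hypothetical size-bucketed pool (illustration only).
#[derive(Default)]
pub struct BufferPool {
    /// Released buffers, grouped by their length in bytes.
    free: HashMap<usize, Vec<Vec<u8>>>,
    /// Number of fresh allocations performed so far.
    pub allocations: usize,
}

impl BufferPool {
    /// Returns a buffer of `size` bytes, reusing a released one when possible.
    /// Reused buffers keep their old contents; callers must overwrite them.
    pub fn acquire(&mut self, size: usize) -> Vec<u8> {
        if let Some(buf) = self.free.get_mut(&size).and_then(|bufs| bufs.pop()) {
            return buf;
        }
        self.allocations += 1;
        vec![0; size]
    }

    /// Gives a buffer back to the pool instead of freeing it.
    pub fn release(&mut self, buf: Vec<u8>) {
        self.free.entry(buf.len()).or_default().push(buf);
    }
}
```

A loop that acquires and releases a 1 KiB buffer 100 times performs a single allocation; a second allocation only happens when two buffers of that size are alive at the same time.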
Our goal extends beyond providing an optimized compute language; we aim to develop an ecosystem of high-performance and scientific computing in Rust. To achieve this, we're developing linear algebra components that you can integrate into your own kernels. We currently have a highly optimized matrix multiplication module that leverages Tensor Cores on NVIDIA hardware where available, while gracefully falling back to basic instructions on other platforms. While there's room for improvement, particularly in using custom instructions from newer NVIDIA GPUs, our implementation already delivers impressive performance.
We are a small team also building Burn, so don't hesitate to contribute and port algorithms; it can help more than you would imagine!
How it works
CubeCL leverages Rust's proc macro system in a unique two-step process:
- Parsing: The proc macro parses the GPU kernel code using the syn crate.
- Expansion: Instead of immediately generating an Intermediate Representation (IR), the macro generates a new Rust function.
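The two-step idea can be pictured with a toy, std-only sketch. Instead of computing `x * (erf(x / sqrt(2)) + 1) / 2`, a hypothetical `gelu_expand` function with the same structure appends instructions to a hypothetical `Scope`, so calling it produces IR rather than a number. `Op`, `Instr`, `Scope`, and `gelu_expand` are illustrative names, not CubeCL's actual IR or generated code.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Op {
    Add,
    Mul,
    Div,
}

/// Toy IR: every instruction defines one new variable, identified by `out`.
#[derive(Debug, Clone, PartialEq)]
pub enum Instr {
    Input { out: usize },
    Const { out: usize, value: f32 },
    Binary { out: usize, op: Op, lhs: usize, rhs: usize },
    Erf { out: usize, input: usize },
}

/// Collects instructions while an "expanded" function runs.
#[derive(Default)]
pub struct Scope {
    pub instrs: Vec<Instr>,
}

impl Scope {
    fn push(&mut self, make: impl FnOnce(usize) -> Instr) -> usize {
        // One instruction per variable, so the variable id is the instruction index.
        let out = self.instrs.len();
        self.instrs.push(make(out));
        out
    }
    pub fn input(&mut self) -> usize {
        self.push(|out| Instr::Input { out })
    }
    pub fn constant(&mut self, value: f32) -> usize {
        self.push(|out| Instr::Const { out, value })
    }
    pub fn binary(&mut self, op: Op, lhs: usize, rhs: usize) -> usize {
        self.push(|out| Instr::Binary { out, op, lhs, rhs })
    }
    pub fn erf(&mut self, input: usize) -> usize {
        self.push(|out| Instr::Erf { out, input })
    }
}

/// Same shape as the GELU formula, but each operation records an instruction.
pub fn gelu_expand(scope: &mut Scope, x: usize) -> usize {
    let sqrt2 = scope.constant(std::f32::consts::SQRT_2);
    let tmp = scope.binary(Op::Div, x, sqrt2);
    let e = scope.erf(tmp);
    let one = scope.constant(1.0);
    let sum = scope.binary(Op::Add, e, one);
    let prod = scope.binary(Op::Mul, x, sum);
    let two = scope.constant(2.0);
    scope.binary(Op::Div, prod, two)
}
```

Calling `gelu_expand` on a fresh input yields nine instructions, and the returned id names the variable holding the result.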
The generated function, semantically similar to the original, is responsible for creating the IR when called. This approach differs from traditional compilers, which typically generate IR directly after parsing. Our method enables several key features:
- Comptime: Because the original code is not transformed, integrating compile-time optimizations becomes remarkably easy.
- Automatic Vectorization: By simply vectorizing the inputs of a CubeCL function, we can determine the vectorization factor of each intermediate variable during the expansion.
- Rust Integration: The generated code remains valid Rust code, allowing it to be bundled without any dependency on the specific runtime.
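Both features follow from expansion being ordinary Rust execution. In the toy sketch below (hypothetical `Var`, `Scope`, and `repeat_add_expand`, not CubeCL APIs), each variable carries only a vectorization factor, which `add` infers from its operands, and a plain `u32` argument plays the role of a comptime value: the loop over it runs during expansion, so it is unrolled into separate instructions.

```rust
/// Expansion-time variable: metadata only, no runtime value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Var {
    pub id: usize,
    pub vectorization: u32,
}

/// Records (operation name, result variable) pairs.
#[derive(Default)]
pub struct Scope {
    pub ops: Vec<(&'static str, Var)>,
}

impl Scope {
    fn new_var(&mut self, op: &'static str, vectorization: u32) -> Var {
        let var = Var { id: self.ops.len(), vectorization };
        self.ops.push((op, var));
        var
    }

    /// Inputs are the only place a vectorization factor is chosen.
    pub fn input(&mut self, vectorization: u32) -> Var {
        self.new_var("input", vectorization)
    }

    /// The result's vectorization is inferred: a scalar (1) broadcasts,
    /// otherwise both operands must agree.
    pub fn add(&mut self, lhs: Var, rhs: Var) -> Var {
        let v = match (lhs.vectorization, rhs.vectorization) {
            (1, v) | (v, 1) => v,
            (a, b) if a == b => a,
            (a, b) => panic!("vectorization mismatch: {a} vs {b}"),
        };
        self.new_var("add", v)
    }
}

/// `times` is known while expanding, so this loop emits `times` add instructions.
pub fn repeat_add_expand(scope: &mut Scope, x: Var, times: u32) -> Var {
    let mut acc = x;
    for _ in 0..times {
        acc = scope.add(acc, x);
    }
    acc
}
```

Vectorizing only the input (factor 4) is enough for every downstream variable to end up with factor 4, and `times = 3` leaves no loop in the recorded operations, just three adds.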
Design
CubeCL is designed around, you guessed it, cubes! More specifically, it's based on cuboids, because not all axes are the same size. Since all compute APIs need to map to the hardware, which is organized as tiles that can be addressed with a 3D representation, our topology maps easily to the concepts of other APIs.
<div align="center">CubeCL - Topology
<img src="./assets/cubecl.drawio.svg" width="100%"/> <br /> </div> <br />
A cube is composed of units, so a 3x3x3 cube has 27 units that can be accessed by their positions along the x, y, and z axes. Likewise, a hyper-cube is composed of cubes, and each cube in the hyper-cube can be accessed by its position relative to the hyper-cube along the x, y, and z axes. Hence, a 3x3x3 hyper-cube has 27 cubes, and in this example the total number of working units is 27 x 27 = 729.
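The counting above, and the idea of flattening a 3D position into a single index, can be checked with a small plain-Rust sketch. `Dim3`, `linear`, and `global_index` are hypothetical helpers; the x-fastest ordering and the cube-major layout are one valid convention chosen for illustration and may not match how CubeCL itself computes absolute positions.

```rust
/// A 3D size or position along x, y, and z.
#[derive(Debug, Clone, Copy)]
pub struct Dim3 {
    pub x: u32,
    pub y: u32,
    pub z: u32,
}

impl Dim3 {
    /// Number of elements in a box of this size.
    pub fn volume(self) -> u32 {
        self.x * self.y * self.z
    }

    /// Flatten a position inside a box of this size: x varies fastest, then y, then z.
    pub fn linear(self, pos: Dim3) -> u32 {
        pos.x + pos.y * self.x + pos.z * self.x * self.y
    }
}

/// Global index of a unit: index of its cube in the hyper-cube, times the
/// number of units per cube, plus the unit's index inside its cube.
pub fn global_index(cube_count: Dim3, cube_dim: Dim3, cube_pos: Dim3, unit_pos: Dim3) -> u32 {
    cube_count.linear(cube_pos) * cube_dim.volume() + cube_dim.linear(unit_pos)
}
```

With a 3x3x3 hyper-cube of 3x3x3 cubes, every (cube, unit) pair maps to a distinct index in `0..729`, matching the 27 x 27 count.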
<details> <summary>Topology Equivalence 👇</summary> <br />
Since all topology variables are constant within the kernel entry point, we chose to use the Rust constant syntax with capital letters. When writing kernels, we often care less about a unit's position within a cube along each axis than about its overall position. Therefore, each kind of variable also has its own axis-independent variable, which is often not present in other languages.
<br />

| CubeCL       | CUDA        | WebGPU                 | Metal                            |
| ------------ | ----------- | ---------------------- | -------------------------------- |
| CUBE_COUNT   | N/A         | N/A                    | N/A                              |
| CUBE_COUNT_X | gridDim.x   | num_workgroups.x       | threadgroups_per_grid.x          |
| CUBE_COUNT_Y | gridDim.y   | num_workgroups.y       | threadgroups_per_grid.y          |
| CUBE_COUNT_Z | gridDim.z   | num_workgroups.z       | threadgroups_per_grid.z          |
| CUBE_POS     | N/A         | N/A                    | N/A                              |
| CUBE_POS_X   | blockIdx.x  | workgroup_id.x         | threadgroup_position_in_grid.x   |
| CUBE_POS_Y   | blockIdx.y  | workgroup_id.y         | threadgroup_position_in_grid.y   |
| CUBE_POS_Z   | blockIdx.z  | workgroup_id.z         | threadgroup_position_in_grid.z   |
| CUBE_DIM     | N/A         | N/A                    | N/A                              |
| CUBE_DIM_X   | blockDim.x  | workgroup_size.x       | threads_per_threadgroup.x        |
| CUBE_DIM_Y   | blockDim.y  | workgroup_size.y       | threads_per_threadgroup.y        |
| CUBE_DIM_Z   | blockDim.z  | workgroup_size.z       | threads_per_threadgroup.z        |
| UNIT_POS     | N/A         | local_invocation_index | thread_index_in_threadgroup      |
| UNIT_POS_X   | threadIdx.x | local_invocation_id.x  | thread_position_in_threadgroup.x |
| UNIT_POS_Y   | threadIdx.y | local_invocation_id.y  | thread_position_in_threadgroup.y |
