Rai
RAI: Rust ML framework with composable transformations like JAX.
RAI
An ML framework for Rust with ergonomic APIs, lazy computation, and composable transformations like JAX.
Installation
cargo add rai
Code snippets
Function transformations (jvp, vjp, grad, value_and_grad)
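use rai::{grad, Cpu, Tensor, F32};
// f(x) = sin(x)
fn f(x: &Tensor) -> Tensor {
    x.sin()
}
fn main() {
    // Composing grad twice yields the second derivative: d²/dx² sin(x) = -sin(x)
    let grad_fn = grad(grad(f));
    let x = &Tensor::ones([1], F32, &Cpu);
    let grad = grad_fn(x);
    // Print the recorded computation graph, then the resulting tensor
    println!("{}", grad.dot_graph());
    println!("{}", grad);
}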
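For `f(x) = sin x`, the snippet's `grad(grad(f))` should evaluate to the second derivative, `-sin x`. The self-contained sketch below illustrates why derivative transformations compose, using nested dual numbers (forward-mode differentiation) in plain Rust with no dependencies. `Scalar`, `Dual`, `grad` and `second_derivative_of_sin` are hypothetical names for this illustration only; this is not rai's API and says nothing about how rai implements `grad`.

```rust
use std::ops::{Add, Mul, Neg};

/// Hypothetical minimal scalar interface (not rai's API): just enough to differentiate sin.
trait Scalar: Copy + Add<Output = Self> + Mul<Output = Self> + Neg<Output = Self> {
    fn zero() -> Self;
    fn one() -> Self;
    fn sin(self) -> Self;
    fn cos(self) -> Self;
}

impl Scalar for f64 {
    fn zero() -> Self { 0.0 }
    fn one() -> Self { 1.0 }
    // Delegate to the inherent f64 methods
    fn sin(self) -> Self { f64::sin(self) }
    fn cos(self) -> Self { f64::cos(self) }
}

/// Dual number `re + eps·ε` with ε² = 0; `eps` carries the derivative.
#[derive(Clone, Copy)]
struct Dual<T> {
    re: T,
    eps: T,
}

impl<T: Scalar> Add for Dual<T> {
    type Output = Self;
    fn add(self, o: Self) -> Self {
        Dual { re: self.re + o.re, eps: self.eps + o.eps }
    }
}

impl<T: Scalar> Mul for Dual<T> {
    type Output = Self;
    // Product rule: (a + a'ε)(b + b'ε) = ab + (ab' + a'b)ε
    fn mul(self, o: Self) -> Self {
        Dual { re: self.re * o.re, eps: self.re * o.eps + self.eps * o.re }
    }
}

impl<T: Scalar> Neg for Dual<T> {
    type Output = Self;
    fn neg(self) -> Self {
        Dual { re: -self.re, eps: -self.eps }
    }
}

// Duals are themselves scalars, so they can be nested: Dual<Dual<f64>>
impl<T: Scalar> Scalar for Dual<T> {
    fn zero() -> Self { Dual { re: T::zero(), eps: T::zero() } }
    fn one() -> Self { Dual { re: T::one(), eps: T::zero() } }
    // Chain rule: sin(a + a'ε) = sin(a) + cos(a)·a'ε
    fn sin(self) -> Self { Dual { re: self.re.sin(), eps: self.re.cos() * self.eps } }
    // Chain rule: cos(a + a'ε) = cos(a) - sin(a)·a'ε
    fn cos(self) -> Self { Dual { re: self.re.cos(), eps: -self.re.sin() * self.eps } }
}

/// Forward-mode derivative: seed the input's tangent with 1 and read back `eps`.
fn grad<T: Scalar>(f: impl Fn(Dual<T>) -> Dual<T>) -> impl Fn(T) -> T {
    move |x| f(Dual { re: x, eps: T::one() }).eps
}

/// The snippet's f(x) = sin(x), written generically over the scalar type.
fn f<T: Scalar>(x: T) -> T {
    x.sin()
}

/// Each grad adds one Dual layer, so grad(grad(f)) evaluates f on Dual<Dual<f64>>.
fn second_derivative_of_sin(x: f64) -> f64 {
    let grad_fn = grad(grad(f::<Dual<Dual<f64>>>));
    grad_fn(x)
}

fn main() {
    let x = 1.0;
    println!("grad(grad(sin))(1) = {}", second_derivative_of_sin(x));
    println!("-sin(1)            = {}", -f64::sin(x));
}
```

Nesting works here because the inner `grad` sees a scalar type (`Dual<f64>`) that already carries the outer derivative, so the two differentiations stack without either knowing about the other.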
NN Modules, Optimizer and loss functions
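// Loss: mean softmax cross-entropy; the logits are returned as auxiliary output
fn loss_fn<M: TrainableModule<Input = Tensor, Output = Tensor>>(
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) -> (Tensor, Aux<Tensor>) {
    let logits = model.forward(input);
    let loss = softmax_cross_entropy(&logits, labels).mean(..);
    (loss, Aux(logits))
}
fn train_step<M: TrainableModule<Input = Tensor, Output = Tensor>, O: Optimizer>(
    optimizer: &mut O,
    model: &M,
    input: &Tensor,
    labels: &Tensor,
) {
    // Compute the loss (with aux logits) and its gradients in one call
    let vg_fn = value_and_grad(loss_fn);
    let ((_loss, Aux(_logits)), (grads, ..)) = vg_fn((model, input, labels));
    // The optimizer derives updated parameters from the gradients
    let mut params = optimizer.step(&grads);
    // Computation is lazy: evaluate the updated parameters
    eval(&params);
    // Write the updated parameters back into the model
    model.update_params(&mut params);
}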
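The `loss_fn` / `train_step` pair computes the loss (plus auxiliary logits) together with its gradients, lets the optimizer produce updated parameters, and writes them back into the model. To make that arithmetic concrete, the sketch below replays the same pattern for a tiny linear classifier in plain Rust with no dependencies. It is not rai code: `Linear`, `softmax`, `loss_and_grads`, `sgd_step` and `train` are hypothetical names, labels are assumed to be class indices, the gradient is the hand-derived closed form for mean softmax cross-entropy (softmax minus one-hot) rather than automatic differentiation, and the optimizer is plain SGD with an assumed learning rate of 0.5.

```rust
/// Hypothetical linear classifier: logits = W·x + b (also reused to hold gradients).
struct Linear {
    w: Vec<Vec<f64>>, // [classes][features]
    b: Vec<f64>,      // [classes]
}

impl Linear {
    fn forward(&self, x: &[f64]) -> Vec<f64> {
        self.w
            .iter()
            .zip(&self.b)
            .map(|(row, b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f64>() + b)
            .collect()
    }
}

/// Numerically stable softmax: subtract the max logit before exponentiating.
fn softmax(z: &[f64]) -> Vec<f64> {
    let m = z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = z.iter().map(|zi| (zi - m).exp()).collect();
    let s: f64 = exps.iter().sum();
    exps.iter().map(|e| e / s).collect()
}

/// Mean softmax cross-entropy over the batch and its gradient w.r.t. W and b,
/// using dL/dlogits = softmax(logits) - one_hot(label).
fn loss_and_grads(model: &Linear, inputs: &[Vec<f64>], labels: &[usize]) -> (f64, Linear) {
    let n = inputs.len() as f64;
    let mut gw = vec![vec![0.0; model.w[0].len()]; model.w.len()];
    let mut gb = vec![0.0; model.b.len()];
    let mut loss = 0.0;
    for (x, &y) in inputs.iter().zip(labels) {
        let z = model.forward(x);
        // Per-sample loss = logsumexp(z) - z[y], computed stably
        let m = z.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let lse = m + z.iter().map(|zi| (zi - m).exp()).sum::<f64>().ln();
        loss += (lse - z[y]) / n;
        for (c, &pc) in softmax(&z).iter().enumerate() {
            let target = if c == y { 1.0 } else { 0.0 };
            let d = (pc - target) / n;
            gb[c] += d;
            for (g, xi) in gw[c].iter_mut().zip(x) {
                *g += d * xi;
            }
        }
    }
    (loss, Linear { w: gw, b: gb })
}

/// Plain SGD applied in place (the rai snippet instead returns new params from `step`).
fn sgd_step(model: &mut Linear, grads: &Linear, lr: f64) {
    for (row, grow) in model.w.iter_mut().zip(&grads.w) {
        for (w, g) in row.iter_mut().zip(grow) {
            *w -= lr * g;
        }
    }
    for (b, g) in model.b.iter_mut().zip(&grads.b) {
        *b -= lr * g;
    }
}

/// Trains on a tiny hypothetical dataset; returns (initial loss, final loss).
fn train(steps: usize) -> (f64, f64) {
    let inputs = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.2], vec![0.1, 1.0]];
    let labels = vec![0, 1, 0, 1];
    let mut model = Linear { w: vec![vec![0.0; 2]; 2], b: vec![0.0; 2] };
    let (initial, _) = loss_and_grads(&model, &inputs, &labels);
    for _ in 0..steps {
        // value_and_grad, then optimizer step, then parameter update
        let (_loss, grads) = loss_and_grads(&model, &inputs, &labels);
        sgd_step(&mut model, &grads, 0.5);
    }
    let (final_loss, _) = loss_and_grads(&model, &inputs, &labels);
    (initial, final_loss)
}

fn main() {
    let (initial, final_loss) = train(100);
    println!("loss: {initial:.4} -> {final_loss:.4}");
}
```

With all-zero weights every class gets probability 1/2, so the initial loss is ln 2; the hand-written gradient in `loss_and_grads` is the part that `value_and_grad(loss_fn)` takes over in the rai snippet.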
Examples
- linear_regression
cargo run --bin linear_regression --release
- mnist
cargo run --bin mnist --release
cargo run --bin mnist --release --features=cuda
- mnist-cnn
cargo run --bin mnist-cnn --release
cargo run --bin mnist-cnn --release --features=cuda
- phi2
cargo run --bin phi2 --release
cargo run --bin phi2 --release --features=cuda
- phi3
cargo run --bin phi3 --release
cargo run --bin phi3 --release --features=cuda
- qwen2
cargo run --bin qwen2 --release
cargo run --bin qwen2 --release --features=cuda
- gemma
  - accept the license agreement at https://huggingface.co/google/gemma-2b
  - install huggingface_hub: pip install huggingface_hub
  - log in to Hugging Face: huggingface-cli login
cargo run --bin gemma --release
cargo run --bin gemma --release --features=cuda
- vit
cargo run --bin vit --release
cargo run --bin vit --release --features=cuda
LICENSE
This project is licensed under either of
- Apache License, Version 2.0, (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
- MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT)
at your option.
