# Tutel
Tutel MoE: An Optimized Mixture-of-Experts Implementation, and the first parallel solution to propose "No-penalty Parallelism/Sparsity/Capacity/.. Switching" for modern training and inference workloads with dynamic behaviors.
- Supported framework: PyTorch (>= 2.0 recommended)
- Supported GPUs: CUDA (fp64/fp32/fp16/bf16), ROCm (fp64/fp32/fp16/bf16)
- Supported CPUs: fp64/fp32
- Direct NVFP4/MXFP4/Blockwise-FP8 inference for MoE-based DeepSeek / Kimi / Qwen3 / GPT-OSS on A100/A800/H100/MI300/..
> [!TIP]
> Steps for DeepSeek V3.2 (Long-Context Mode):

**[Model Downloads]**
```sh
pip3 install -U "huggingface_hub[cli]"
hf download nvidia/Kimi-K2.5-NVFP4 --local-dir nvidia/Kimi-K2.5-NVFP4
hf download nvidia/Kimi-K2-Thinking-NVFP4 --local-dir nvidia/Kimi-K2-Thinking-NVFP4
hf download nvidia/DeepSeek-V3.2-NVFP4 --local-dir nvidia/DeepSeek-V3.2-NVFP4
```

**[DeepSeek V3.2 Long-Context (A100/H100/B200 only)]**
```sh
docker run -e LOCAL_SIZE=8 -e WORKER=1 -it --rm --ipc=host --net=host --shm-size=8g \
    --ulimit memlock=-1 --ulimit stack=67108864 -v /:/host -w /host$(pwd) -v /tmp:/tmp \
    -v /usr/lib/x86_64-linux-gnu/libcuda.so.1:/usr/lib/x86_64-linux-gnu/libcuda.so.1 --privileged \
    tutelgroup/deepseek-671b:a100x8-chat-20260327 --serve=webui --listen_port 8000 \
    --try_path nvidia/Kimi-K2.5-NVFP4 \
    --try_path nvidia/Kimi-K2-Thinking-NVFP4 \
    --try_path nvidia/DeepSeek-V3.2-NVFP4 \
    --try_path nvidia/DeepSeek-R1-NVFP4 \
    --max_seq_len 32768
```

**[DeepSeek V3.2 Long-Context (MI300 only)]**
```sh
docker run -e LOCAL_SIZE=8 -e WORKER=1 -it --rm --ipc=host --net=host --shm-size=8g \
    --ulimit memlock=-1 --ulimit stack=67108864 --device=/dev/kfd --device=/dev/dri --group-add=video \
    --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /:/host -w /host$(pwd) -v /tmp:/tmp \
    tutelgroup/deepseek-671b:mi300x8-chat-20260327 --serve=webui --listen_port 8000 \
    --try_path nvidia/Kimi-K2.5-NVFP4 \
    --try_path nvidia/Kimi-K2-Thinking-NVFP4 \
    --try_path nvidia/DeepSeek-V3.2-NVFP4 \
    --try_path nvidia/DeepSeek-R1-NVFP4 \
    --max_seq_len 1000000
```

**[OpenAI/Ollama/Direct Request]**
```sh
curl -N -X POST http://0.0.0.0:8000/chat -d '{"text": "Write a Python code of the Quicksort algorithm."}'
python3 -m tutel.examples.oai_request_stream --url '0.0.0.0:8000' --prompt 'Write a Python code of the Quicksort algorithm.'
```

**[Open-WebUI URL for Web Browsers]**
```sh
xdg-open http://0.0.0.0:8000
```
> [!TIP]
> Steps for Microsoft VibeVoice (Multimodality Mode):

**[Model Downloads]**
```sh
pip3 install -U "huggingface_hub[cli]"
hf download microsoft/VibeVoice-1.5B --local-dir microsoft/VibeVoice-1.5B
hf download Qwen/Qwen2.5-1.5B --local-dir Qwen/Qwen2.5-1.5B
hf download aoi-ot/VibeVoice-Large --local-dir microsoft/VibeVoice-Large
hf download Qwen/Qwen2.5-7B --local-dir Qwen/Qwen2.5-7B
```

**[Microsoft VibeVoice (A100/H100/B200 only)]**
```sh
docker run -e LOCAL_SIZE=1 -it --rm -p 8001:8000 --shm-size=8g \
    --ulimit memlock=-1 --ulimit stack=67108864 -v /:/host -w /host$(pwd) -v /tmp:/tmp \
    -v /usr/lib/x86_64-linux-gnu/libcuda.so.1:/usr/lib/x86_64-linux-gnu/libcuda.so.1 --privileged \
    -e VOICES="https://homepages.inf.ed.ac.uk/htang2/notes/speech-samples/103-1240-0000.wav" \
    tutelgroup/deepseek-671b:a100x8-chat-20251222 --serve=core \
    --try_path ./microsoft/VibeVoice-1.5B \
    --try_path ./microsoft/VibeVoice-Large
```

**[Microsoft VibeVoice (MI300 only)]**
```sh
docker run -e LOCAL_SIZE=1 -it --rm -p 8001:8000 --shm-size=8g \
    --ulimit memlock=-1 --ulimit stack=67108864 --device=/dev/kfd --device=/dev/dri --group-add=video \
    --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /:/host -w /host$(pwd) -v /tmp:/tmp \
    -e VOICES="https://homepages.inf.ed.ac.uk/htang2/notes/speech-samples/103-1240-0000.wav" \
    tutelgroup/deepseek-671b:mi300x8-chat-20251222 --serve=core \
    --try_path ./microsoft/VibeVoice-1.5B \
    --try_path ./microsoft/VibeVoice-Large
```

**[Audio Generation Request]**
```sh
curl -X POST http://0.0.0.0:8001/chat -d '{"text": "VibeVoice is a novel framework designed for generating expressive, long-form, multi-speaker conversational audio, such as podcasts, from text."}' > sound_output.mp3
```
Inference TPS (decoding tokens per second) for DeepSeek-MoE / Qwen3-MoE / Kimi-K2-MoE / GPT-OSS-MoE / ..:
| Model & Machine Type | Precision | SGL | Tutel |
| ---- | ---- | ---- | ---- |
| deepseek-ai/DeepSeek-V3.2 (671B, A100 × 8) | nvfp4 | - | 102 |
| deepseek-ai/DeepSeek-V3.2 (671B, MI300 × 8) | nvfp4 | - | 151 |
| moonshotai/Kimi-K2-Instruct (1T, A100 × 8) | nvfp4 | - | 104 |
| moonshotai/Kimi-K2-Instruct (1T, MI300 × 8) | fp8b128 | 49 | 153 |
| NVFP4/Qwen3-235B-A22B-Instruct-2507-FP4 (A100 × 8) | nvfp4 | - | 114 |
| NVFP4/Qwen3-235B-A22B-Instruct-2507-FP4 (MI300 × 8) | nvfp4 | - | 122 |
| openai/gpt-oss-120b (120B, A100 × 1) | mxfp4 | 127 | 212 |
| openai/gpt-oss-120b (120B, MI300 × 1) | mxfp4 | 191 | 311 |
| microsoft/VibeVoice-1.5B (A100 × 1) | bf16 | - | rtf=0.07 |
| microsoft/VibeVoice-1.5B (MI300 × 1) | bf16 | - | rtf=0.06 |
What's New:
- Image-20260327: Add support for Kimi-K2.5.
- Image-20260306: Support DeepSeek V3.2 long-context mode on A100/H100/MI300/B200.
- Image-20251222: Fine-tune A100 performance for most models.
- Image-20251111: Integrate the Tutel LLM module into VibeVoice for accelerated inference (rtf = 0.07 on a single A100).
- Image-20251006: Resolve compatibility with DeepSeek-V3.2-Exp.
- Image-20250827: Add distributed support for OpenAI GPT-OSS 20B/120B with MXFP4 inference.
- Image-20250801: Support the Qwen3 MoE series; integrate Open-WebUI.
- Image-20250712: Support Kimi K2 1T MoE inference with NVFP4 on NVIDIA/AMD GPUs.
- Image-20250601: Improve decoding performance for DeepSeek 671B on MI300 x 8 to 140-150 TPS.
- More image versions can be found here.
Tutel v0.4.2: Add R1-FP4/Qwen3MoE-FP8 Support for NVIDIA and AMD GPUs & Fast Gating APIs:
>> Example:

```python
import torch
from tutel import ops

# Qwen3 fast MoE gating for 128 experts, with routed weights normalized to 1.0
logits_fp32 = torch.softmax(torch.randn([32, 128]), -1, dtype=torch.float32).cuda()
topk_weights, topk_ids = ops.qwen3_moe_scaled_topk(logits_fp32)
print(topk_weights, topk_ids, topk_weights.sum(-1))

# DeepSeek V3/R1 fast MoE gating for 256 experts, with routed weights normalized to 2.5
logits_bf16 = torch.randn([32, 256], dtype=torch.bfloat16).cuda()
correction_bias_bf16 = torch.randn([logits_bf16.size(-1)], dtype=torch.bfloat16).cuda()
topk_weights, topk_ids = ops.deepseek_moe_sigmoid_scaled_topk(logits_bf16, correction_bias_bf16, None, None)
print(topk_weights, topk_ids, topk_weights.sum(-1))
```
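For intuition, the arithmetic behind such fused gating ops can be sketched in plain Python. The function below (an illustrative sketch with a hypothetical name, not Tutel's implementation) selects the top-k gate probabilities and rescales the routed weights so they sum to a target value, as in 1.0 for the Qwen3-style op or 2.5 for the DeepSeek-style op; note the real DeepSeek gate additionally uses sigmoid scoring with a correction bias, which is omitted here:

```python
# Plain-Python sketch of top-k MoE gating with routed-weight renormalization.
# The fused GPU kernels perform this (and more) in a single pass.
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def scaled_topk(logits, k, scale=1.0):
    """Pick the top-k gate probabilities and rescale them to sum to `scale`."""
    probs = softmax(logits)
    ids = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    picked = [probs[i] for i in ids]
    total = sum(picked)
    weights = [scale * w / total for w in picked]
    return weights, ids

weights, ids = scaled_topk([0.1, 2.0, -1.0, 1.5], k=2)
print(ids)                     # indices of the two largest logits: [1, 3]
print(round(sum(weights), 6))  # renormalized to 1.0
```

The key property the fused ops preserve is that the returned weights always sum to the model's routed scaling factor, regardless of which experts were selected.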
Tutel v0.4.1: Support DeepSeek R1 FP8 on NVIDIA GPUs (A100 / A800).
Tutel v0.4.0: Accelerate DeepSeek R1 full-precision chat on AMD MI300 x 8:
>> Example:

```sh
# Step 1: download the DeepSeek R1 671B model
huggingface-cli download deepseek-ai/DeepSeek-R1 --local-dir ./deepseek-ai/DeepSeek-R1

# Step 2: serve DeepSeek R1 chat on local port 8000 using 8 MI300 GPUs
docker run -it --rm --ipc=host --privileged -p 8000:8000 \
    -v /:/host -w /host$(pwd) tutelgroup/deepseek-671b:mi300x8-chat-20250224 \
    --model_path ./deepseek-ai/DeepSeek-R1

# Step 3: issue a prompt request with curl
curl -X POST http://0.0.0.0:8000/chat -d '{"text": "Calculate the result of: 1 / (sqrt(5) - sqrt(3))"}'
```
Tutel v0.3.3: Add all-to-all benchmark:
>> Example:

```sh
python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.bandwidth_test --size_mb=256
```
Tutel v0.3.2: Add tensorcore option for extra benchmarks / Extend the example for custom experts / Allow NCCL timeout settings:
>> Example of using tensorcore:

```sh
python3 -m tutel.examples.helloworld --dtype=float32
python3 -m tutel.examples.helloworld --dtype=float32 --use_tensorcore
python3 -m tutel.examples.helloworld --dtype=float16
python3 -m tutel.examples.helloworld --dtype=float16 --use_tensorcore
```

>> Example of custom gates/experts:

```sh
python3 -m tutel.examples.helloworld_custom_gate_expert --batch_size=16
```

>> Example of NCCL timeout settings:

```sh
TUTEL_GLOBAL_TIMEOUT_SEC=60 python3 -m torch.distributed.run --nproc_per_node=8 -m tutel.examples.helloworld --use_tensorcore
```
Tutel v0.3.1: Add NCCL all_to_all_v and all_gather_v for arbitrary-length message transfers:
>> Example:

```sh
# All_to_All_v:
python3 -m torch.distributed.run --nproc_per_node=2 --master_port=7340 -m tutel.examples.nccl_all_to_all_v
# All_Gather_v:
python3 -m torch.distributed.run --nproc_per_node=2 --master_port=7340 -m tutel.examples.nccl_all_gather_v
```

>> How to:

```python
net.batch_all_to_all_v([t_x_cuda, t_y_cuda, ..], common_send_counts)
net.batch_all_gather_v([t_x_cuda, t_y_cuda, ..])
```
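The semantics of all_to_all_v can be illustrated without GPUs or NCCL: each rank sends a variable-length slice of its flat buffer to every other rank, and receives the slices addressed to it in rank order. Below is a single-process pure-Python sketch (names and shapes are illustrative; the real Tutel call exchanges CUDA tensors over NCCL and shares one set of send counts across the batched tensors):

```python
# Pure-Python simulation of all_to_all_v across `world` ranks:
# rank r sends send_counts[r][p] elements of its flat buffer to rank p.
def all_to_all_v(buffers, send_counts):
    world = len(buffers)
    # Split each rank's flat buffer into per-destination chunks.
    chunks = []
    for r in range(world):
        out, ofs = [], 0
        for p in range(world):
            out.append(buffers[r][ofs:ofs + send_counts[r][p]])
            ofs += send_counts[r][p]
        chunks.append(out)
    # Each rank receives, in rank order, the chunk addressed to it.
    return [sum((chunks[p][r] for p in range(world)), []) for r in range(world)]

bufs = [[1, 2, 3], [4, 5, 6]]   # flat send buffers of rank 0 and rank 1
counts = [[1, 2], [2, 1]]       # rank 0 sends 1 elem to rank 0 and 2 to rank 1; etc.
print(all_to_all_v(bufs, counts))  # [[1, 4, 5], [2, 3, 6]]
```

Because the per-destination lengths are data-dependent, this primitive is what lets MoE dispatch move exactly as many tokens as the router assigned, with no padding to a fixed capacity.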
Tutel v0.3: Add the Megablocks solution to improve decoder inference on a single GPU with num_local_experts >= 2:
>> Example (capacity_factor=0 required by dropless-MoE):

```sh
# Using BatchMatmul:
python3 -m tutel.examples.helloworld --megablocks_size=0 --batch_size=1 --num_tokens=32 --top=1 --eval --num_local_experts=128 --capacity_factor=0
# Using Megablocks with block_size = 1:
python3 -m tutel.examples.helloworld --megablocks_size=1 --batch_size=1 --num_tokens=32 --top=1 -
```
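A positive capacity_factor caps how many tokens each expert may receive, while capacity_factor=0 means dropless routing: every expert accepts however many tokens the router assigns it. The sketch below uses the common capacity convention (ceil of factor × tokens × top-k / experts); it is illustrative and not necessarily Tutel's exact sizing rule:

```python
# Sketch of the per-expert capacity implied by a capacity factor.
# capacity_factor == 0 is treated as the dropless case (no fixed bound).
import math

def expert_capacity(num_tokens, top_k, num_experts, capacity_factor):
    if capacity_factor == 0:
        return None  # dropless: buffers sized to the actual routed counts
    return math.ceil(capacity_factor * num_tokens * top_k / num_experts)

print(expert_capacity(32, 1, 128, 1.25))  # 1 token per expert at most
print(expert_capacity(32, 1, 128, 0))     # None -> dropless
```

With a fixed capacity, tokens routed beyond the bound are dropped or overflowed; dropless mode avoids that at the cost of variable-sized expert batches, which is exactly what the Megablocks path is designed to compute efficiently.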
