Dion
Dion optimizer algorithm
Welcome to the Microsoft/Dion Codebase
This repository provides efficient implementations of orthonormal optimizers for distributed ML training. You can find the following optimizers:
- Dion2
- Muon
- NorMuon
- Dion
Table of Contents
- Requirements
- Quick Start
- Introduction
- Optimizers
- Building Parameter Groups
- Distributed Training Configuration
- Compressed Data-Parallel Gradient Sync
- Best Practices
- Experimental Features
- Citation
Requirements
This code is written for modern PyTorch (version 2.7 or newer) using DTensor-based parallelism. This includes FSDP2 with fully_shard and tensor parallelism (TP) with parallelize_module. Support for other distributed training APIs is not implemented.
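To fail fast on an unsupported environment, a small check like the following can compare an installed PyTorch version string against the 2.7 minimum. This is a hypothetical helper, not part of the dion package; in practice you would pass it torch.__version__.

```python
import re


def torch_version_ok(version: str, minimum: tuple = (2, 7)) -> bool:
    """Return True if a version string such as '2.7.1+cu126' is at least `minimum`.

    Hypothetical helper (not part of the dion package). Only the major and
    minor components are compared; local suffixes like '+cu126' are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    # Compare integer tuples rather than strings, so that '2.10' > '2.7'.
    return (int(match.group(1)), int(match.group(2))) >= minimum
```

Comparing integer tuples avoids the string-ordering pitfall where "2.10" would sort before "2.7".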
Quick Start
Our implementations are available as a pip package! Install to use in your project:
pip install git+https://github.com/microsoft/dion.git
Then in your code, you can use:
from dion import Dion2, Muon, NorMuon, Dion
Please read this README carefully for detailed instructions on using our optimizers. They differ in major ways from PyTorch built-in optimizers such as Adam/AdamW.
Running Our Sample Training Script
First clone this repo, then install the dependencies for both Dion and the training code:
git clone https://github.com/microsoft/dion.git
cd dion
pip install -e .[train]
Download the pretokenized FineWeb dataset:
python data/cached_fineweb10B.py 30
Distributed Data Parallel (DDP) Training
To train a GPT-small model using Dion2 with 4 GPUs (adjust as needed for your setup):
torchrun --standalone --nproc_per_node=4 train.py --config configs/dion2_160m.yaml
This will launch Distributed Data Parallel (DDP) training.
Distributed Training: FSDP / TP / Hybrid Sharding
Fully Sharded Data Parallel (FSDP)
To enable FSDP, specify the FSDP group size using --fs_size:
torchrun --standalone --nproc_per_node=4 train.py \
--config configs/dion2_160m.yaml \
--fs_size 4
This configuration trains a GPT-small model using Dion2 with FSDP sharding across all 4 GPUs (a single FSDP group of size 4).
Hybrid Sharded Data Parallel (HSDP)
To use Hybrid Sharded Data Parallel, where multiple FSDP groups are replicated using Data Parallel (DP), set --fs_size smaller than the total number of GPUs and specify the data parallel dimension via --dp_size:
torchrun --standalone --nproc_per_node=4 train.py \
--config configs/dion2_160m.yaml \
--fs_size 2 \
--dp_size 2
This configuration creates:
- 2 FSDP groups, each spanning 2 GPUs
- 2-way data parallelism across the FSDP groups
- Total: 4 GPUs with 2-way FSDP × 2-way DP
The product dp_size × fs_size must equal world_size. Any unspecified dimension defaults to 1.
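The sizing rule above can be expressed as a short check. The function below is a hypothetical sketch (train.py's actual argument handling may differ): unspecified dimensions default to 1, and their product must equal the number of launched processes. It also includes the tp_size dimension used for tensor parallelism.

```python
import math


def resolve_parallel_dims(world_size, dp_size=None, fs_size=None, tp_size=None):
    """Fill unspecified parallel dimensions with 1 and validate their product.

    Hypothetical helper mirroring the README rule; not train.py's actual code.
    """
    dims = {
        "dp_size": 1 if dp_size is None else dp_size,
        "fs_size": 1 if fs_size is None else fs_size,
        "tp_size": 1 if tp_size is None else tp_size,
    }
    total = math.prod(dims.values())
    if total != world_size:
        raise ValueError(
            f"dp_size x fs_size x tp_size = {total}, but world_size = {world_size}"
        )
    return dims
```

With 4 GPUs, fs_size=4 alone is valid (dp_size defaults to 1), as is dp_size=2 with fs_size=2; passing only fs_size=2 fails because 1 × 2 × 1 ≠ 4.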
Tensor Parallelism (TP)
Note: Currently, only Dion (our legacy implementation) supports Tensor Parallelism.
You can combine all three parallelism strategies (DP × FSDP × TP). For example, a 2 × 2 × 2 configuration across 8 GPUs:
torchrun --standalone --nproc_per_node=8 train.py \
--config configs/dion_160m.yaml \
--dp_size 2 \
--fs_size 2 \
--tp_size 2
This configuration creates:
- 2-way data parallelism (outer replication)
- 2-way FSDP
- 2-way tensor parallelism
- Total: 8 GPUs with 2-way DP × 2-way FSDP × 2-way TP
The product dp_size × fs_size × tp_size must equal world_size. Any unspecified dimension defaults to 1.
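To see which GPUs end up sharing each kind of group in the 2 × 2 × 2 example, the sketch below enumerates the groups for a row-major layout with DP outermost and TP innermost. That rank ordering is an assumption for illustration; check train.py for the mesh layout it actually builds.

```python
from itertools import product


def mesh_groups(dp_size, fs_size, tp_size):
    """List the rank groups along each parallel dimension.

    Assumes a row-major (dp, fs, tp) layout, i.e.
    rank = (d * fs_size + f) * tp_size + t, with TP innermost.
    """
    def rank(d, f, t):
        return (d * fs_size + f) * tp_size + t

    # Ranks that differ only in t share a TP group, and likewise for FSDP and DP.
    tp_groups = [[rank(d, f, t) for t in range(tp_size)]
                 for d, f in product(range(dp_size), range(fs_size))]
    fs_groups = [[rank(d, f, t) for f in range(fs_size)]
                 for d, t in product(range(dp_size), range(tp_size))]
    dp_groups = [[rank(d, f, t) for d in range(dp_size)]
                 for f, t in product(range(fs_size), range(tp_size))]
    return {"dp": dp_groups, "fs": fs_groups, "tp": tp_groups}
```

Under this layout, ranks 0 and 1 form a TP group, ranks 0 and 2 shard the same parameters via FSDP, and ranks 0 and 4 hold replicas across the DP dimension. Placing TP innermost keeps its frequent communication between adjacent ranks, which usually share a node.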
Introduction
Optimization algorithms are essential to training neural networks, converting gradients into model weight updates to minimize loss. For many years, the method of choice has been Adam/AdamW. However, recent work has shown that orthonormal optimizers can significantly accelerate model convergence. Check out blog posts by Jeremy Bernstein and Laker Newhouse for more details.
The practical effectiveness of orthonormal optimizers was first demonstrated by Muon in the NanoGPT speedrun, and has since been validated at scale by models such as Kimi K2 and GLM-4.5. Muon implements orthonormalization via Newton-Schulz iterations, which rely on repeated matrix-matrix multiplications. However, large-scale training relies on model sharding, where weight matrices and optimizer states are distributed across multiple processes. As discussed by Essential AI, orthonormalizing a sharded matrix with Newton-Schulz iterations requires the communication-intensive step of reconstructing the full matrices from their individual shards.
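To make that cost concrete, here is a minimal pure-Python sketch of orthonormalization by the classic cubic Newton-Schulz iteration X ← 1.5X − 0.5(XXᵀ)X, which drives every singular value of X toward 1 and so approaches the orthonormal factor UVᵀ of G = USVᵀ. Muon itself uses a tuned higher-order polynomial on GPU tensors, so treat this only as an illustration: each entry of XXᵀ is an inner product of two full rows, which is why a sharded X must be gathered first.

```python
import math


def matmul(a, b):
    """Multiply matrices stored as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def newton_schulz(g, steps=30):
    """Approximate the orthonormal factor U V^T of g (illustrative cubic iteration)."""
    # Scale by the Frobenius norm so all singular values lie in (0, 1],
    # inside the cubic iteration's convergence region (0, sqrt(3)).
    norm = math.sqrt(sum(v * v for row in g for v in row))
    x = [[v / norm for v in row] for row in g]
    for _ in range(steps):
        # X <- 1.5 X - 0.5 (X X^T) X; every step touches the whole matrix.
        xxtx = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * a - 0.5 * c for a, c in zip(ra, rc)] for ra, rc in zip(x, xxtx)]
    return x
```

For a symmetric positive-definite input such as [[3, 1], [1, 2]], the orthonormal factor is the identity, so the iteration converges to I; for a general full-rank input, the result satisfies XXᵀ ≈ I.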
Dion and Dion2 are our scalable, communication-efficient orthonormal optimizers. Like Muon, they compute matrix weight updates via matrix orthonormalization and share its practical benefits. The key difference is that Dion and Dion2 shrink the matrix before orthonormalizing it, reducing both computation and communication costs. Dion uses power iteration to compute a low-rank approximation, while Dion2 applies a simple submatrix-selection procedure. To reduce information loss, both methods include an error-feedback mechanism that tracks the discrepancy between the original matrix and its compressed approximation.
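The low-rank idea can be sketched in a few lines. The function below is a hypothetical, single-device simplification loosely modeled on Dion: one power-iteration step on B = momentum + gradient gives P (orthonormal columns) and R = BᵀP, so PRᵀ is a rank-r approximation of B. Error feedback keeps the residual B − PRᵀ in full and decays only the captured part by the momentum factor mu. The names, the mu-decay form, and the column normalization of R are assumptions here, learning-rate and shape scaling are omitted, and dion_reference.py is the version intended to follow the paper.

```python
import math


def matmul(a, b):
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def orthonormalize_columns(p):
    """Gram-Schmidt on the columns of p (assumes they are linearly independent)."""
    basis = []
    for col in transpose(p):
        for q in basis:
            dot = sum(x * y for x, y in zip(col, q))
            col = [x - dot * y for x, y in zip(col, q)]
        norm = math.sqrt(sum(x * x for x in col))
        basis.append([x / norm for x in col])
    return transpose(basis)


def normalize_columns(r):
    """Scale each column of r to unit L2 norm."""
    cols = []
    for col in transpose(r):
        norm = math.sqrt(sum(x * x for x in col))
        cols.append([x / norm for x in col])
    return transpose(cols)


def low_rank_step(momentum, grad, q, mu=0.95):
    """Hypothetical rank-r power-iteration step with error feedback.

    momentum, grad: m x n lists of rows; q: n x r right factor from the last step.
    """
    b = [[mv + gv for mv, gv in zip(rm, rg)] for rm, rg in zip(momentum, grad)]
    p = orthonormalize_columns(matmul(b, q))  # m x r, orthonormal columns
    r = matmul(transpose(b), p)               # n x r, so P R^T approximates B
    approx = matmul(p, transpose(r))
    # Error feedback: keep the residual B - P R^T, decay only the captured part.
    new_momentum = [[bv - (1 - mu) * av for bv, av in zip(rb, ra)]
                    for rb, ra in zip(b, approx)]
    q_next = normalize_columns(r)             # starting point for the next step
    update = matmul(p, transpose(q_next))     # low-rank update direction P Q^T
    return update, new_momentum, q_next
```

When B has rank at most r, PRᵀ reproduces B exactly, so the returned momentum is simply mu · B. Reusing q_next as the next step's starting point amortizes the power iteration across optimizer steps, on the assumption that the momentum changes slowly.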
Optimizers
Our current implementations support the following parallelization techniques:
| Parallelization | Dion | Dion2 | Muon | NorMuon |
|--------------------|------|-------|------|---------|
| Single device | Yes | Yes | Yes | Yes |
| PyTorch DDP | Yes | Yes | Yes | Yes |
| PyTorch FSDP2 | Yes | Yes | Yes | Yes |
| PyTorch FSDP2 + TP | Yes | No | No | No |
For faster performance, these optimizers will process parameters in batches and interleave multiple batches to overlap compute with communication.
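As a rough illustration of batching, the hypothetical helper below groups same-shaped matrices and chunks each group into batches of world_size, padding the last batch so that every rank gets one slot. Assigning slot i of each batch to rank i is one simple way to spread orthonormalization work; the real optimizers' batching, ordering, and compute/communication overlap are more involved and are not reproduced here.

```python
from collections import defaultdict


def make_batches(shapes, world_size):
    """Group parameter indices by shape, then split each group into world_size-sized batches.

    Hypothetical sketch: slot i of every batch would be processed by rank i,
    and None marks padding where a rank has no matrix in that batch.
    """
    by_shape = defaultdict(list)
    for idx, shape in enumerate(shapes):
        by_shape[shape].append(idx)

    batches = []
    for shape, idxs in by_shape.items():
        for start in range(0, len(idxs), world_size):
            batch = idxs[start:start + world_size]
            batch += [None] * (world_size - len(batch))  # pad the final batch
            batches.append((shape, batch))
    return batches
```

For shapes [(4, 4), (4, 8), (4, 4), (4, 4), (4, 8)] on 2 ranks, this yields batches [0, 2] and [3, None] for the (4, 4) matrices and [1, 4] for the (4, 8) ones. Keeping each batch to a single shape lets it be stacked into one tensor, so one collective can move the whole batch.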
We include optimizer implementations in the dion/ directory of this repo.
- dion.py: High-performance version of Dion. Depending on how each batch of matrices is sharded, we select the best communication patterns to compute Dion's orthonormal update. All-reduce operations may be split into reduce-scatter and all-gather across the batch dimension to more efficiently distribute work and avoid redundant computation.
- muon.py: High-performance version of Muon. For sharded matrices, all-to-all communication is used to simultaneously unshard and distribute a batch of matrices. For replicated matrices, Muon will distribute work across all devices and all-gather final results.
- dion2.py: High-performance implementation of Dion2, using a similar all-to-all communication pattern for distributed orthonormalization. Only an α-fraction of the momentum matrix is communicated and orthonormalized, significantly reducing both communication overhead and computation cost.
- normuon.py: A variant of the Muon optimizer that introduces neuron-wise normalization to improve stability and convergence efficiency, modified to take similar arguments as muon.py. See the paper for more details.
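The α-fraction idea behind dion2.py can be pictured with the hypothetical sketch below. It selects ceil(α · m) rows of B = momentum + gradient, returns only that k × n submatrix as the part that would be communicated and orthonormalized, and applies error feedback so the unselected rows carry over in full. Ranking rows by L2 norm and decaying the selected rows by mu are assumptions made for illustration; consult dion2.py for the actual selection rule and update.

```python
import math


def select_rows(b, alpha):
    """Return the sorted indices of the ceil(alpha * m) rows of b with largest L2 norm.

    Ranking by row norm is an illustrative assumption, not dion2.py's rule.
    """
    m = len(b)
    k = max(1, math.ceil(alpha * m))
    norms = [math.sqrt(sum(x * x for x in row)) for row in b]
    top = sorted(range(m), key=lambda i: norms[i], reverse=True)[:k]
    return sorted(top)


def submatrix_step(momentum, grad, alpha=0.25, mu=0.95):
    """Hypothetical submatrix-selection step with error feedback.

    Returns the selected row indices, the k x n submatrix that would be
    communicated and orthonormalized, and the new momentum buffer.
    """
    b = [[mv + gv for mv, gv in zip(rm, rg)] for rm, rg in zip(momentum, grad)]
    rows = select_rows(b, alpha)
    sub = [b[i][:] for i in rows]
    # Error feedback: unselected rows stay in the buffer untouched; selected rows,
    # which feed this step's update, are decayed by mu.
    new_momentum = [row[:] for row in b]
    for i in rows:
        new_momentum[i] = [mu * x for x in b[i]]
    return rows, sub, new_momentum
```

With α = 0.25 on a 4096 × 4096 matrix, only 1024 rows (a quarter of the elements) would be communicated; that submatrix would then go through an orthonormalization step such as Newton-Schulz.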
We also provide some reference implementations:
- dion_reference.py: An implementation without batching, communication overlapping, or split all-reduce. This version of Dion is intended to closely follow the algorithms as described in our Dion paper.
- dion_simple.py: A simplified illustration of the Dion update rule in a single Python function, provided for educational value.
- muon_reference.py: A version of Muon by Moonshot AI, modified to take similar arguments as muon.py.
Building Parameter Groups
Unlike typical PyTorch optimizers (e.g. Adam/AdamW), Dion and Muon require separating your model's parameters into different groups (similar in spirit to Modula). These orthonormal optimization algorithms apply only to two-dimensional matrix weights. Non-matrix parameters require a different, scalar optimizer algorithm (element-wise updates) and may also use a different learning rate. We currently support Lion and AdamW for these parameters.
The details of parameter grouping are dependent on model architecture and implementation. Therefore, we leave it up to you to categorize your model's parameters and create the necessary parameter groups.
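A minimal starting point for this split, assuming only that each parameter exposes .ndim (as torch.nn.Parameter does): two-dimensional weights go to the matrix group for Dion or Muon, and everything else goes to the scalar group for Lion or AdamW. This is a hypothetical helper; depending on your architecture, some 2-D weights may still belong in the scalar group, and how each group is then configured is not shown here.

```python
def split_params(named_params):
    """Split (name, param) pairs into 2-D matrix parameters and all others.

    Hypothetical helper: works with model.named_parameters() or any iterable of
    (name, obj) pairs where obj has an .ndim attribute.
    """
    matrix_params, scalar_params = [], []
    for name, param in named_params:
        if param.ndim == 2:
            matrix_params.append((name, param))   # candidates for Dion / Muon
        else:
            scalar_params.append((name, param))   # e.g. biases, norm gains: Lion / AdamW
    return matrix_params, scalar_params
```

With a PyTorch model you would call split_params(model.named_parameters()) and build the optimizer's parameter groups from the two lists.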
- In transformer models and many other neural networks, most parameters are nn.Linear layers with two-dimensional weight matrices. These parameters should use Dion or Muon. A shape-dependent learning rate scale facto
