Maximal Update Parametrization (μP) and Hyperparameter Transfer (μTransfer)

Paper link | Blog link | YouTube link

In Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer, we show that optimal hyperparameters become stable across neural network sizes when we parametrize the model in maximal update parametrization (μP). This can be used to tune extremely large neural networks such as large pretrained transformers, as we have done in our work. More generally, μP reduces the fragility and uncertainty of the transition from exploration to scaling up, issues that are rarely discussed explicitly in the deep learning literature.

Figure above: Training loss against learning rate on Transformers of varying d_model trained with Adam.

μP turns out to be the unique "natural" parametrization that has this hyperparameter stability property across width, as empirically verified in the gif below on MLPs trained with SGD. Here, across time, we interpolate between PyTorch default and μP's learning rate and initialization scalings (right), and we scale up the width-256 model (log2(width)=8) to width 2^13 = 8192 using this interpolated scaling rule (left).

This repo contains the source code for the mup package, our tool that makes the implementation of μP in PyTorch models effortless and less error-prone.

Installation

pip install mup

Install From Source

Clone this repo, change to its directory, and do

pip install -r requirements.txt
pip install -e .

Basic Usage

from torch import nn
import mup  # needed for the `mup.init` and `mup.MuAdam` references below
from mup import MuReadout, make_base_shapes, set_base_shapes, MuSGD, MuAdam

class MyModel(nn.Module):
    def __init__(self, width, ...):
        ...
        ### In model definition, replace output layer with MuReadout
        # readout = nn.Linear(width, d_out)
        readout = MuReadout(width, d_out)
        ### If tying weights with an input nn.Embedding layer, do
        # readout = MuSharedReadout(input_layer.weight)
        ...
    def forward(self, ...):
        ...
        ### If using a transformer, make sure to use
        ###   1/d instead of 1/sqrt(d) attention scaling
        # attention_scores = query @ key.T / d**0.5
        attention_scores = query @ key.T * 8 / d
        ### We use 8/d instead of 1/d here to be backward compatible
        ###   with 1/d**0.5 when d=64, a common head dimension.
        ...

### Instantiate a base model
base_model = MyModel(width=1)
### Optionally, use `torchdistx.deferred_init.deferred_init` to avoid instantiating the parameters
### Simply install `torchdistx` and use
# base_model = torchdistx.deferred_init.deferred_init(MyModel, width=1)
### Instantiate a "delta" model that differs from the base model
###   in all dimensions ("widths") that one wishes to scale.
### Here it's simple, but e.g., in a Transformer, you may want to scale
###   both nhead and dhead, so the delta model should differ in both.
delta_model = MyModel(width=2) # Optionally use `torchdistx` to avoid instantiating

### Instantiate the target model (the model you actually want to train).
### This should be the same as the base model except 
###   the widths could be potentially different.
### In particular, base_model and model should have the same depth.
model = MyModel(width=100)

### Set base shapes
### When `model` has same parameter shapes as `base_model`,
###   `model` behaves exactly the same as `base_model`
###   (which is in PyTorch's default parametrization).
###   This provides backward compatibility at this particular model size.
###   Otherwise, `model`'s init and LR are scaled by μP.
### IMPORTANT: this should be called as soon as possible,
###   before re-initialization and optimizer definition.
set_base_shapes(model, base_model, delta=delta_model)

### Alternatively, one can save the base model shapes in a file
# make_base_shapes(base_model, delta_model, filename)
### and later set base shapes directly from the filename
# set_base_shapes(model, filename)
### This is useful when one cannot fit both 
###   base_model and model in memory at the same time

### Replace your custom init, if any
for param in model.parameters():
    ### If initializing manually with fixed std or bounds,
    ### then replace with same function from mup.init
    # torch.nn.init.uniform_(param, -0.1, 0.1)
    mup.init.uniform_(param, -0.1, 0.1)
    ### Likewise, if using
    ###   `xavier_uniform_, xavier_normal_, kaiming_uniform_, kaiming_normal_`
    ### from `torch.nn.init`, replace with the same functions from `mup.init`

### Use the optimizers from `mup.optim` instead of `torch.optim`
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = MuSGD(model.parameters(), lr=0.1)

### Then just train normally

Note that the base and delta models do not need to be trained; we only extract parameter shape information from them. Therefore, optionally, we can avoid instantiating these potentially large models by using the deferred_init function in torchdistx. After installing torchdistx, use torchdistx.deferred_init.deferred_init(MyModel, **args) instead of MyModel(**args). See this page for more detail. In the MLP and Transformer examples we provided (not mutransformers), you can activate this feature by passing --deferred_init.
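The attention-scaling comment in the usage snippet can be checked with plain arithmetic. The sketch below is only an illustration: the helper names `standard_scale` and `mup_scale` are hypothetical and are not part of mup. It shows that 8/d matches 1/sqrt(d) at head dimension 64 and shrinks faster for larger d.

```python
import math


def standard_scale(d: int) -> float:
    """Usual attention scaling: 1 / sqrt(d)."""
    return 1.0 / math.sqrt(d)


def mup_scale(d: int) -> float:
    """Scaling from the usage snippet: 8 / d (hypothetical helper name)."""
    return 8.0 / d


# At d = 64 the two scalings coincide, which is the backward-compatibility point.
same_at_64 = standard_scale(64) == mup_scale(64)

# At larger head dimensions the 8/d scaling is smaller than 1/sqrt(d).
ratio_at_256 = standard_scale(256) / mup_scale(256)
```

Here `ratio_at_256` is 2: quadrupling d halves the standard scale but quarters the 8/d scale.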

How mup Works Under the Hood

By invoking set_base_shapes(model, ...), each parameter tensor p of model gets a p.infshape attribute that stores, for each of its dimensions, the corresponding base dimension and whether that dimension should be considered infinite (i.e. will be scaled up/down, e.g., d_model of a Transformer) or finite (i.e. will be fixed, e.g., vocabulary size). This information is used in the initializers and optimizers to automatically scale the parameters or learning rates to be compliant with μP. For example, the Adam learning rate of hidden weights p is calculated as globalLR / p.infshape.width_mult(), where p.infshape.width_mult() essentially calculates fan_in / base_fan_in.
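The bookkeeping described above can be sketched in plain Python. This is a simplified illustration, not mup's implementation: `DimInfo`, `infer_dims`, and `hidden_adam_lr` are hypothetical names. The sketch assumes a dimension counts as infinite exactly when the base and delta shapes differ in it, and that the weight uses the (fan_out, fan_in) layout.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DimInfo:
    """Hypothetical stand-in for the per-dimension record stored on a parameter."""
    base_dim: int
    dim: int
    infinite: bool

    def width_mult(self) -> float:
        # Finite dimensions are fixed, so they contribute no rescaling.
        return self.dim / self.base_dim if self.infinite else 1.0


def infer_dims(base_shape, delta_shape, target_shape):
    """Assumption: a dimension is infinite when base and delta disagree on it."""
    return [
        DimInfo(base_dim=b, dim=t, infinite=(b != d))
        for b, d, t in zip(base_shape, delta_shape, target_shape)
    ]


def hidden_adam_lr(global_lr: float, dims) -> float:
    """globalLR / width_mult, where width_mult = fan_in / base_fan_in.

    Assumes the (fan_out, fan_in) layout, so fan_in is the last dimension.
    """
    return global_lr / dims[-1].width_mult()


# Hidden weight: both dimensions grow with width (64 -> 512).
hidden = infer_dims((64, 64), (128, 128), (512, 512))
hidden_lr = hidden_adam_lr(1e-3, hidden)

# Embedding-like tensor: vocabulary size is fixed, width grows.
embedding = infer_dims((1000, 64), (1000, 128), (1000, 512))
embedding_flags = [d.infinite for d in embedding]
```

With these shapes the hidden weight's width multiplier is 512 / 64 = 8, so its Adam learning rate is one eighth of the global value, while the vocabulary dimension of the embedding is marked finite.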

Current Limitations

  • set_base_shapes(model, ...) assumes that model has just been randomly initialized in the standard way and rescales its parameters using the base shape information so the model is in μP.
  • If you want data parallelism, please use torch.nn.parallel.DistributedDataParallel instead of torch.nn.DataParallel. This is because the latter removes the attributes the mup package adds to each parameter tensor of the model. Also, for performance, PyTorch recommends the former anyway.
  • We scale the learning rate according to μP explicitly by creating refined parameter groups from what is passed to the mup optimizer and by manipulating the lr attribute in those groups. This is compatible with PyTorch's learning rate schedulers. However, if you roll your own, make sure the scheduler sets the learning rate relative to what is currently in the refined parameter groups. The following is an example of what not to do and what is OK:
optimizer = mup.MuAdam(model.parameters(), lr=1e-3)
for pg in optimizer.param_groups:
  # what NOT to do: setting learning rate absolutely
  # pg['lr'] = 1e-3 * 2
  # what is an OK alternative: setting it relatively
  pg['lr'] *= 2
  • By default, any parameter matrix that has 2 "infinite" dimensions (i.e. dimensions that differ from the base dimensions) is considered by mup to have shape (fan_out, fan_in), i.e., in the forward pass, this matrix multiplies its input on the right. This is the case for all nn.Linear weights in PyTorch. If you have a custom parameter, say W, that violates this convention, you can manually set W.infshape.main_idx = 0; W.infshape.main = W.infshape[0] to let mup know that its shape corresponds to (fan_in, fan_out). A similar discussion applies if you have a parameter tensor with many dimensions but exactly 2 "infinite" dimensions, of which the first is fan_in and the second is fan_out.
  • Currently, torch.save does not save the infshape objects attached to each parameter tensor. Until this is fixed, you have to set the base shapes manually after loading a model checkpoint, like so:
model = torch.load('my/model/path.pt')
# Important: note the flag `rescale_params=False`!
set_base_shapes(model, 'my/base/shape/path.bsh', rescale_params=False)

(set_base_shapes by default rescales the parameters of model, assuming it's freshly initialized by PyTorch, to be consistent with μP. The rescale_params=False flag turns off this behavior.)
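For the custom-scheduler limitation above, one way to stay relative is to remember each group's starting learning rate and always multiply it by the schedule factor. The sketch below is a minimal illustration under stated assumptions: `make_relative_scheduler` is a hypothetical helper, and plain dictionaries stand in for the refined `optimizer.param_groups` that a mup optimizer would hold.

```python
def make_relative_scheduler(param_groups):
    """Return a setter that scales every group's lr relative to its starting value."""
    # Snapshot the per-group learning rates, which already include any per-group scaling.
    initial_lrs = [pg["lr"] for pg in param_groups]

    def set_factor(factor: float) -> None:
        for pg, lr0 in zip(param_groups, initial_lrs):
            # Relative update: the ratio between groups is preserved.
            pg["lr"] = lr0 * factor

    return set_factor


# Stand-in for refined parameter groups: the second group has a width-scaled lr.
groups = [{"lr": 1e-3}, {"lr": 1e-3 / 8}]
set_factor = make_relative_scheduler(groups)

set_factor(0.5)  # e.g. halve the learning rate at some step
ratio_after = groups[0]["lr"] / groups[1]["lr"]
```

Because every group is multiplied by the same factor, the ratio between groups stays at 8, which is the property an absolute assignment such as `pg['lr'] = 1e-3 * 2` would destroy.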

Checking Correctness of Parametrization

Coord Check

Just like gradient checking is a simple way of verifying the correctness of an autograd implementation, *coordinate c
