
MARS

The official implementation of MARS: Unleashing the Power of Variance Reduction for Training Large Models


MARS: Unleashing the Power of Variance Reduction for Training Large Models

This repository contains the official code for the paper MARS: Unleashing the Power of Variance Reduction for Training Large Models.

Authors: Huizhuo Yuan*, Yifeng Liu*, Shuang Wu, Xun Zhou, Quanquan Gu

🔔 NEWS

  • [05/01/2025] Our paper has been accepted to ICML 2025 🎉🎉.
  • [03/04/2025] The MARS-AdamW CUDA implementation is available.
  • [02/10/2025] Our paper has been updated on arXiv: https://arxiv.org/pdf/2411.10438v2.
  • [01/12/2025] Updated scripts for reproducing the GPT-2 XL and FineWeb-Edu results.
  • [01/12/2025] Our pretraining results on FineWeb-Edu are available: GPT-2 XL reaches a HellaSwag accuracy of 56.52 with 50B tokens.
  • [11/26/2024] Vision tasks added.
  • [11/18/2024] Our code is open-sourced!
  • [11/15/2024] Our paper is released on arXiv: https://arxiv.org/abs/2411.10438.

About MARS

MARS (Make vAriance Reduction Shine) is a unified optimization framework designed to address the inherent challenges of training large models. Traditional adaptive gradient methods such as Adam and AdamW often suffer from high stochastic gradient variance, while variance reduction techniques have struggled to gain practical impact in deep learning. At its core, MARS comprises two major components: (1) a scaled stochastic recursive momentum, which provides a variance-reduced estimator of the full gradient for better gradient complexity; and (2) a preconditioned update, which approximates second-order Newton's method for better per-iteration complexity. By combining preconditioned gradient methods with variance reduction, MARS achieves the best of both worlds, accelerating the search for critical points in optimization.

The MARS framework is built on the following preconditioned variance-reduced updates:

$$ \mathbf{c}_t = \nabla f(\mathbf{x}_t, \mathbf{\xi}_t)+\underbrace{{\color{red}\gamma_t} \frac{\beta_{1}}{1-\beta_{1}} \left(\nabla f(\mathbf{x}_t, \mathbf{\xi}_t)-\nabla f(\mathbf{x}_{t-1}, \mathbf{\xi}_t)\right)}_{\text{scaled gradient correction}} $$

$$ \tilde{\mathbf{c}}_t = \text{Clip}(\mathbf{c}_t,1) = \begin{cases} \frac{\mathbf{c}_t}{\|\mathbf{c}_t\|_2} & \text{if } \|\mathbf{c}_t\|_2 > 1, \\ \mathbf{c}_t & \text{otherwise}. \end{cases} $$

$$ \mathbf{m}_t = \beta_1 \mathbf{m}_{t-1} + (1-\beta_{1})\tilde{\mathbf{c}}_t $$

$$ \mathbf{x}_{t+1} = \arg\min_{\mathbf{x} \in \mathbb{R}^d} \left\{\eta_t \left\langle \mathbf{m}_t, \mathbf{x} \right\rangle + \frac{1}{2} \|\mathbf{x} - \mathbf{x}_t \|_{\mathbf{H}_t}^2\right\} $$

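Here ${\color{red}\gamma_t}$ is a scaling parameter that controls the strength of the gradient correction, $\eta_t$ is the learning rate, and $\mathbf{H}_t$ is the preconditioning matrix (an approximation of the Hessian).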
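To make the first three updates concrete, here is a minimal pure-Python sketch of one momentum step (gradient correction, clipping, and the exponential moving average) on flat lists of floats. The function and argument names are hypothetical and this is not the code in mars.py; `grad_prev` stands for $\nabla f(\mathbf{x}_{t-1}, \mathbf{\xi}_t)$.

```python
import math
from typing import List

Vector = List[float]


def l2_norm(v: Vector) -> float:
    return math.sqrt(sum(x * x for x in v))


def mars_momentum_step(
    m_prev: Vector,
    grad_t: Vector,
    grad_prev: Vector,
    beta1: float,
    gamma: float,
) -> Vector:
    """One MARS momentum update (hypothetical helper): correction, clipping, EMA.

    grad_t    -- stochastic gradient at x_t on sample xi_t
    grad_prev -- stochastic gradient at x_{t-1} on the same sample xi_t
    """
    scale = gamma * beta1 / (1.0 - beta1)
    # c_t = grad_t + gamma_t * beta1 / (1 - beta1) * (grad_t - grad_prev)
    c = [g + scale * (g - gp) for g, gp in zip(grad_t, grad_prev)]
    # Clip(c_t, 1): rescale to unit L2 norm only when the norm exceeds 1
    norm = l2_norm(c)
    if norm > 1.0:
        c = [x / norm for x in c]
    # m_t = beta1 * m_{t-1} + (1 - beta1) * clipped c_t
    return [beta1 * m + (1.0 - beta1) * x for m, x in zip(m_prev, c)]
```

Setting `gamma=0` removes the correction term, so the step reduces to an exponential moving average of clipped stochastic gradients.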

Instantiations of MARS

Under the MARS framework, we provide three instantiations based on different Hessian matrix approximations: MARS-AdamW, MARS-Lion, and MARS-Shampoo. Please note that the hyperparameters in this framework are tuned on MARS-AdamW. When using the other instantiations, it is essential to tune the hyperparameters (particularly the learning rate) for optimal performance.

MARS-AdamW

(Enable with mars_type="mars-adamw" in mars.py)

The Hessian matrix approximation is defined as:

$$ \mathbf{v}_t =\beta_2 \mathbf{v}_{t-1}+(1-\beta_2) \big(\nabla f(\mathbf{x}_t, \mathbf{\xi}_t)\big)^2 $$

$$ \mathbf{H}_t := \sqrt{\text{diag}\Big(\mathbf{v}_t\Big)}\cdot \frac{1 - \beta_1^t}{\sqrt{1 - \beta_2^t}}. $$
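Because $\mathbf{H}_t$ is diagonal here, the argmin in the general update has the closed form $\mathbf{x}_{t+1} = \mathbf{x}_t - \eta_t \mathbf{H}_t^{-1}\mathbf{m}_t$, which with the bias factor above equals the bias-corrected Adam step $\eta_t \hat{\mathbf{m}}_t / \sqrt{\hat{\mathbf{v}}_t}$. The sketch below applies that step in pure Python. It assumes an added `eps` (and its placement) for numerical stability and omits weight decay; both are assumptions rather than details taken from mars.py, and the function name is hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]


def mars_adamw_update(
    x: Vector,
    m: Vector,
    v_prev: Vector,
    grad: Vector,
    t: int,
    lr: float,
    beta1: float,
    beta2: float,
    eps: float = 1e-8,
) -> Tuple[Vector, Vector]:
    """Apply x_{t+1} = x_t - lr * H_t^{-1} m_t with the diagonal H_t above.

    m is the already-updated MARS momentum m_t; t is the 1-based step count.
    Returns (x_next, v_t).
    """
    # v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2 (elementwise)
    v = [beta2 * vp + (1.0 - beta2) * g * g for vp, g in zip(v_prev, grad)]
    # diag(H_t) = sqrt(v_t) * (1 - beta1^t) / sqrt(1 - beta2^t)
    bias = (1.0 - beta1 ** t) / math.sqrt(1.0 - beta2 ** t)
    # eps keeps the division finite when v_t is zero (assumed, as in AdamW)
    x_next = [xi - lr * mi / (math.sqrt(vi) * bias + eps)
              for xi, mi, vi in zip(x, m, v)]
    return x_next, v
```

On the first step with zero state (for example `m = 0.1 * grad` when `beta1 = 0.9`), each coordinate moves by almost exactly `lr`, as in Adam.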

MARS-Lion

(Enable with mars_type="mars-lion" in mars.py)

The Hessian matrix approximation is defined as:

$$ \mathbf{H}_t := \sqrt{\text{diag}(\mathbf{m}_t^2)}. $$
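With $\mathbf{H}_t = \sqrt{\text{diag}(\mathbf{m}_t^2)} = \text{diag}(|\mathbf{m}_t|)$, the closed-form step $\mathbf{H}_t^{-1}\mathbf{m}_t$ is the elementwise sign of $\mathbf{m}_t$, i.e. a Lion-style sign update. A minimal sketch with a hypothetical function name, assuming coordinates with $m_i = 0$ are left unchanged (the formula is undefined there):

```python
from typing import List

Vector = List[float]


def mars_lion_update(x: Vector, m: Vector, lr: float) -> Vector:
    """x_{t+1} = x_t - lr * H_t^{-1} m_t with H_t = diag(|m_t|).

    Elementwise m / |m| is sign(m), so every coordinate with nonzero momentum
    moves by exactly lr; zero-momentum coordinates stay put (an assumption).
    """
    def sign(value: float) -> float:
        return float((value > 0) - (value < 0))

    return [xi - lr * sign(mi) for xi, mi in zip(x, m)]
```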

MARS-Shampoo

(Enable with mars_type="mars-shampoo" in mars.py)

The preconditioner can be seen as an orthogonal mapping operator:

$$ \mathbf{U}_t, \mathbf{\Sigma}_t, \mathbf{V}_t = \text{SVD}(\mathbf{G}_t),\qquad \mathbf{x}_{t+1} =\mathbf{x}_t-\eta_t\mathbf{U}_t\mathbf{V}_t^\top. $$

In practice, we use the Newton-Schulz iteration to approximate $\mathbf{U}_t\mathbf{V}_t^\top$ efficiently instead of solving the SVD problem exactly.
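A minimal pure-Python sketch of the idea, using the classic cubic Newton-Schulz iteration $\mathbf{X} \leftarrow \tfrac{3}{2}\mathbf{X} - \tfrac{1}{2}\mathbf{X}\mathbf{X}^\top\mathbf{X}$. It pushes every singular value toward 1 while leaving the singular vectors unchanged, so the iterate approaches $\mathbf{U}_t\mathbf{V}_t^\top$. The coefficients, step count, and normalization used in mars.py may differ; all helper names here are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def newton_schulz_orthogonalize(g: Matrix, steps: int = 30) -> Matrix:
    """Approximate U V^T from G = U Sigma V^T without computing the SVD.

    Each step maps every singular value s to 1.5 s - 0.5 s^3, which converges
    to 1 for s in (0, sqrt(3)); U and V are unchanged by the iteration.
    """
    # Divide by the Frobenius norm so all singular values are <= 1.
    fro = math.sqrt(sum(v * v for row in g for v in row))
    x = [[v / fro for v in row] for row in g]
    for _ in range(steps):
        xxt_x = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * a - 0.5 * b for a, b in zip(ra, rb)]
             for ra, rb in zip(x, xxt_x)]
    return x
```

For a full-rank square input, the returned matrix `R` is approximately the orthogonal polar factor of $\mathbf{G}_t$, so `R @ R.T` is close to the identity.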

Performance of MARS Compared to Baselines

Experiments on OpenWebText

Experimental results for MARS are based on the MARS-AdamW instantiation, unless otherwise stated. In our experiments, gradients are computed once per sample and per update (MARS-approx in our paper). The exact form of MARS, which evaluates two gradients per update, can slightly improve performance but doubles the computational cost. For more details, refer to our paper.
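The two variants differ only in where $\nabla f(\mathbf{x}_{t-1}, \cdot)$ comes from. The toy sketch below assumes that MARS-approx reuses the stored gradient $\nabla f(\mathbf{x}_{t-1}, \mathbf{\xi}_{t-1})$ from the previous step, while MARS-exact re-evaluates $\nabla f(\mathbf{x}_{t-1}, \mathbf{\xi}_t)$ on the current sample. It runs a 1-D problem with the hypothetical loss $f(x,\xi) = \tfrac12 (x-\xi)^2$ and an identity preconditioner (also an assumption) and counts gradient evaluations.

```python
import random
from typing import Callable, Optional, Tuple


def toy_grad(x: float, xi: float) -> float:
    """Stochastic gradient of the hypothetical loss f(x, xi) = 0.5 * (x - xi)^2."""
    return x - xi


def run_mars(
    grad_fn: Callable[[float, float], float],
    x0: float,
    steps: int,
    exact: bool,
    lr: float = 0.1,
    beta1: float = 0.9,
    gamma: float = 0.025,
    seed: int = 0,
) -> Tuple[float, int]:
    """Toy 1-D MARS loop; returns (final x, number of gradient evaluations).

    Uses H_t = I (an assumption) to isolate how the gradient correction is formed.
    """
    rng = random.Random(seed)
    scale = gamma * beta1 / (1.0 - beta1)
    x_prev, x, m = x0, x0, 0.0
    last_grad: Optional[float] = None
    evals = 0
    for _ in range(steps):
        xi = rng.gauss(0.0, 1.0)  # fresh sample xi_t
        g = grad_fn(x, xi)
        evals += 1
        if exact:
            # MARS-exact: second evaluation at x_{t-1} on the same sample xi_t
            g_prev = grad_fn(x_prev, xi)
            evals += 1
        else:
            # MARS-approx: reuse grad f(x_{t-1}, xi_{t-1}) from the previous step
            g_prev = g if last_grad is None else last_grad
        c = g + scale * (g - g_prev)
        c = c / abs(c) if abs(c) > 1.0 else c  # Clip(c_t, 1) in one dimension
        m = beta1 * m + (1.0 - beta1) * c
        x_prev, x = x, x - lr * m
        last_grad = g
    return x, evals
```

With the same number of steps, the exact variant performs twice as many gradient evaluations, which is the doubled cost mentioned above.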

MARS consistently outperforms the AdamW and Muon optimizers across GPT-2 model sizes:

| GPT-2 small | GPT-2 medium | GPT-2 large |
| --- | --- | --- |
| <img src="assets/val_small.png" width="350"> | <img src="assets/val_medium.png" width="350"> | <img src="assets/val_large.png" width="350"> |

| Best Val Loss | GPT-2 Small (5B tokens) | GPT-2 Medium (5B tokens) | GPT-2 Large (5B tokens) | GPT-2 Small (20B tokens) | GPT-2 Medium (20B tokens) | GPT-2 Large (20B tokens) | GPT-2 Small (50B tokens) | GPT-2 Medium (50B tokens) | GPT-2 Large (50B tokens) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| AdamW | 3.193 | 3.084 | 3.013 | 3.024 | 2.821 | 2.741 | 2.885 | 2.691 | 2.561 |
| Muon | 3.165 | 3.009 | 2.915 | 3.006 | 2.813 | 2.691 | 2.901 | 2.688 | 2.573 |
| MARS-exact | 3.107 | - | - | 2.980 | - | - | 2.847 | - | - |
| MARS-approx | 3.108 | 2.969 | 2.876 | 2.981 | 2.763 | 2.647 | 2.849 | 2.636 | 2.518 |

Efficiency of MARS

The MARS algorithm can achieve better performance not only within the same number of training steps, but also within the same training time:

| GPT-2 small | GPT-2 medium | GPT-2 large |
| --- | --- | --- |
| <img src="assets/time_small.png" width="350"> | <img src="assets/time_medium.png" width="350"> | <img src="assets/time_large.png" width="350"> |


Experiments on FineWeb-Edu

Below are the training and validation loss curves for GPT-2 Small and GPT-2 XL when training with MARS versus AdamW. MARS often converges faster and achieves consistently lower losses across training steps.

| Model | GPT-2 small | GPT-2 XL |
| --- | --- | --- |
| Train Loss | <img src="assets/small_train.png" width="350"> | <img src="assets/xl_train.png" width="350"> |
| Validation Loss | <img src="assets/small_val.png" width="350"> | <img src="assets/xl_val.png" width="350"> |

Evaluation Metrics

Below, we present the evaluation metrics on the FineWeb-Edu dataset for GPT-2 Small and GPT-2 XL, comparing the OpenAI GPT-2 baseline, AdamW, and our MARS-AdamW optimizer.

<img src="assets/fineweb_hella.png" width="350">

Results on GPT-2 small

MARS-AdamW shows a clear improvement over AdamW and the OpenAI baseline across multiple tasks, with the highest average score of 45.93 on GPT-2 Small.

| Method/Task | ARC-E | ARC-C | BoolQ | HellaSwag | OBQA | PIQA | WG | MMLU | SciQ | Avg. |
|--------------|-------|-------|-------|-----------|-------|-------|-------|-------|-------|-------|
| OpenAI-Comm. | 39.48 | 22.70 | 48.72 | 31.14 | 27.20 | 62.51 | 51.62 | 22.92 | 64.40 | 41.19 |
| AdamW | 51.43 | 26.54 | 55.78 | 36.26 | 30.60 | 64.53 | 50.36 | 24.49 | 71.50 | 45.72 |
| MARS-AdamW | 52.23 | 27.39 | 55.84 | 36.91 | 32.20 | 64.80 | 49.96 | 22.95 | 71.10 | 45.93 |
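The Avg. column appears to be the unweighted mean of the nine task scores; a quick check using the numbers from the GPT-2 Small table:

```python
# Per-task GPT-2 Small scores copied from the table (ARC-E ... SciQ, in order).
scores = {
    "OpenAI-Comm.": [39.48, 22.70, 48.72, 31.14, 27.20, 62.51, 51.62, 22.92, 64.40],
    "AdamW": [51.43, 26.54, 55.78, 36.26, 30.60, 64.53, 50.36, 24.49, 71.50],
    "MARS-AdamW": [52.23, 27.39, 55.84, 36.91, 32.20, 64.80, 49.96, 22.95, 71.10],
}

# Unweighted mean over the nine tasks, rounded to two decimals like the table.
averages = {name: round(sum(vals) / len(vals), 2) for name, vals in scores.items()}
print(averages)  # {'OpenAI-Comm.': 41.19, 'AdamW': 45.72, 'MARS-AdamW': 45.93}
```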

Results on GPT-2 XL

On GPT-2 XL, MARS-AdamW continues to outperform AdamW on most tasks, reaching a HellaSwag accuracy of 56.52.

| Method/Task | ARC-E | ARC-C | BoolQ | HellaSwag | OBQA | PIQA | WG | MMLU | SciQ | Avg. |
|--------------|-------|-------|-------|-----------|-------|-------|-------|-------|-------|-------|
| OpenAI-Comm. | 51.05 | 28.50 | 61.77 | 50.89 | 32.00 |
