RigL
End-to-end training of sparse deep neural networks with little-to-no performance loss.
Rigging the Lottery: Making All Tickets Winners
<img src="https://github.com/google-research/rigl/blob/master/imgs/flops8.jpg" alt="80% Sparse Resnet-50" width="45%" align="middle">

Paper: https://arxiv.org/abs/1911.11134
15min Presentation [pml4dc] [icml]
ML Reproducibility Challenge 2020 report
Colabs for Calculating FLOPs of Sparse Models
Best Sparse Models
Parameters are stored as floats, so each parameter takes 4 bytes. The uniform sparsity distribution keeps the first layer dense and therefore has a slightly larger size and parameter count. ERK is applied to all layers except in the 99% sparse model, where we keep the first layer dense, since otherwise we observe much worse performance.
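To make the difference between the two distributions concrete, here is a minimal sketch of per-layer density allocation. It assumes that ERK (Erdős-Rényi-Kernel) sets each layer's density proportional to sum(shape) / prod(shape), which for a conv kernel is (n_in + n_out + w + h) / (n_in * n_out * w * h). The densities are scaled so that the layers in the budget hit the target overall density. Any layer that would exceed density 1 is capped at dense, and its share of the budget is redistributed to the other layers. The function names, the `dense_layers` argument, and the choice to leave forced-dense layers out of the budget are illustrative assumptions, not the repository's API.

```python
from math import prod
from typing import Dict, Iterable, Tuple

Shape = Tuple[int, ...]


def uniform_densities(shapes: Dict[str, Shape], sparsity: float,
                      dense_layers: Iterable[str] = ()) -> Dict[str, float]:
    """Hypothetical helper: density 1 - sparsity everywhere except forced-dense layers."""
    dense = set(dense_layers)
    return {name: 1.0 if name in dense else 1.0 - sparsity for name in shapes}


def erk_densities(shapes: Dict[str, Shape], sparsity: float,
                  dense_layers: Iterable[str] = ()) -> Dict[str, float]:
    """Hypothetical ERK allocation: density_l ~ sum(shape_l) / prod(shape_l)."""
    forced = set(dense_layers)
    # Assumption: only layers that are not forced dense share the sparsity budget.
    budget = {n: s for n, s in shapes.items() if n not in forced}
    densities = {n: 1.0 for n in shapes}
    if not budget:
        return densities
    target_nonzeros = (1.0 - sparsity) * sum(prod(s) for s in budget.values())
    saturated = set()  # budget layers that ERK itself caps at density 1
    while True:
        sparse = [n for n in budget if n not in saturated]
        remaining = target_nonzeros - sum(prod(budget[n]) for n in saturated)
        # density_l * N_l == eps * sum(shape_l), so eps spreads `remaining`
        # over the still-sparse layers in proportion to their summed dims.
        eps = remaining / sum(sum(budget[n]) for n in sparse)
        raw = {n: eps * sum(budget[n]) / prod(budget[n]) for n in sparse}
        worst = max(raw, key=raw.get)
        if raw[worst] <= 1.0:
            break
        saturated.add(worst)  # cap the most over-full layer, then re-solve
    densities.update(raw)
    return densities


# Toy network with hypothetical layer names and shapes.
shapes = {"conv1": (3, 3, 3, 64), "conv2": (3, 3, 64, 128), "fc": (128, 10)}
print(erk_densities(shapes, 0.9))
print(uniform_densities(shapes, 0.9, dense_layers=["conv1"]))
```

With ERK, the small `fc` and `conv1` layers end up much denser than the large `conv2` layer, while the total number of nonzeros still matches the 90% sparsity target. Calling `erk_densities(shapes, 0.99, dense_layers=["conv1"])` mimics keeping the first layer dense at 99% sparsity.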
Extended Training Results
Performance of RigL increases significantly with extended training. In this section we extend the training of sparse models by 5x. Note that sparse models require far fewer FLOPs per training iteration, so most of the extended training runs still cost fewer FLOPs than the baseline dense training.
Since performance kept improving, we wanted to find where the performance of sparse networks saturates. Our longest run trained for 100x the length of the original 100-epoch ImageNet training. This run costs 5.8x the FLOPs of the original dense training, and the resulting 99% sparse ResNet-50 achieves an impressive 68.15% test accuracy (vs. 61.86% with 5x training).
| Sparsity Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc (%) | Ckpt |
|-----------------------|----------|----------------|-----------------|-----------------|---------------|------|
| - (DENSE) | 0 | 3.2e18 | 8.2e9 | 102.122 | 76.8 | - |
| ERK | 0.8 | 2.09x | 0.42x | 23.683 | 77.17 | link |
| Uniform | 0.8 | 1.14x | 0.23x | 23.685 | 76.71 | link |
| ERK | 0.9 | 1.23x | 0.24x | 13.499 | 76.42 | link |
| Uniform | 0.9 | 0.66x | 0.13x | 13.532 | 75.73 | link |
| ERK | 0.95 | 0.63x | 0.12x | 8.399 | 74.63 | link |
| Uniform | 0.95 | 0.42x | 0.08x | 8.433 | 73.22 | link |
| ERK | 0.965 | 0.45x | 0.09x | 6.904 | 72.77 | link |
| Uniform | 0.965 | 0.34x | 0.07x | 6.904 | 71.31 | link |
| ERK | 0.99 | 0.29x | 0.05x | 4.354 | 61.86 | link |
| ERK | 0.99 | 0.58x | 0.05x | 4.354 | 63.89 | link |
| ERK | 0.99 | 2.32x | 0.05x | 4.354 | 66.94 | link |
| ERK | 0.99 | 5.8x | 0.05x | 4.354 | 68.15 | link |
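For sparse runs, the Training FLOPs column is a multiple of the dense training cost. That makes the extended-training rows easy to sanity-check: dividing the relative training FLOPs by the training-length multiplier gives the cost of one sparse step relative to one dense step. The sketch below applies this to the four ERK 0.99 rows. The 5x and 100x multipliers come from the text. Pairing the 0.58x and 2.32x rows with 10x and 40x training is inferred from the ratios and is an assumption.

```python
DENSE_TRAIN_FLOPS = 3.2e18  # dense ResNet-50 training FLOPs (DENSE row of the table)

# (training-length multiplier, training FLOPs relative to dense) for ERK 0.99.
# 5x and 100x are stated in the text; 10x and 40x are inferred (assumption).
ERK_099_RUNS = [(5, 0.29), (10, 0.58), (40, 2.32), (100, 5.8)]


def per_step_cost(multiplier: float, relative_flops: float) -> float:
    """FLOPs of one sparse training step as a fraction of one dense step."""
    return relative_flops / multiplier


for multiplier, rel in ERK_099_RUNS:
    total = rel * DENSE_TRAIN_FLOPS
    print(f"{multiplier:>3}x training: {per_step_cost(multiplier, rel):.3f} "
          f"of a dense step, {total:.2e} FLOPs in total")
```

Every row works out to about 0.058 of a dense step, which is slightly above the 0.05x inference ratio. In absolute terms, the 100x run costs about 1.86e19 FLOPs.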
We also ran extended training with MobileNet-v1. Again, even when training 100x longer, we were not able to saturate the performance: training longer consistently gave better results.
| Sparsity Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc (%) | Ckpt |
|-----------------------|----------|----------------|-----------------|-----------------|---------------|------|
| - (DENSE) | 0 | 4.5e17 | 1.14e9 | 16.864 | 72.1 | - |
| ERK | 0.89 | 1.39x | 0.21x | 2.392 | 69.31 | link |
| ERK | 0.89 | 2.79x | 0.21x | 2.392 | 70.63 | link |
| Uniform | 0.89 | 1.25x | 0.09x | 2.392 | 69.28 | link |
| Uniform | 0.89 | 6.25x | 0.09x | 2.392 | 70.25 | link |
| Uniform | 0.89 | 12.5x | 0.09x | 2.392 | 70.59 | link |
1x Training Results
| Sparsity Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc (%) | Ckpt |
|-----------------------|----------|----------------|-----------------|-----------------|---------------|------|
| ERK | 0.8 | 0.42x | 0.42x | 23.683 | 75.12 | link |
| Uniform | 0.8 | 0.23x | 0.23x | 23.685 | 74.60 | link |
| ERK | 0.9 | 0.24x | 0.24x | 13.499 | 73.07 | link |
| Uniform | 0.9 | 0.13x | 0.13x | 13.532 | 72.02 | link |
Results w/o label smoothing
| Sparsity Distribution | Sparsity | Training FLOPs | Inference FLOPs | Model Size (MB) | Top-1 Acc (%) | Ckpt |
|-----------------------|----------|----------------|-----------------|-----------------|---------------|------|
| ERK | 0.8 | 0.42x | 0.42x | 23.683 | 75.02 | link |
| ERK | 0.8 | 2.09x | 0.42x | 23.683 | 76.17 | link |
| ERK | 0.9 | 0.24x | 0.24x | 13.499 | 73.4 | link |
| ERK | 0.9 | 1.23x | 0.24x | 13.499 | 75.9 | link |
| ERK | 0.95 | 0.13x | 0.12x | 8.399 | 70.39 | link |
| ERK | 0.95 | 0.63x | 0.12x | 8.399 | 74.36 | link |
Evaluating checkpoints
Download the checkpoints and run the evaluation on ERK checkpoints with the following:
python imagenet_train_eval.py --mode=eval_once --output_dir=path/to/ckpt/folder \
--eval_once_ckpt_prefix=model.ckpt-3200000 --use_folder_stub=False \
--training_method=rigl --mask_init_method=erdos_renyi_kernel \
--first_layer_sparsity=-1
When evaluating checkpoints with the uniform sparsity distribution, use --mask_init_method=random and --first_layer_sparsity=0. Set --model_architecture=mobilenet_v1 when evaluating MobileNet checkpoints.
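The flag combinations above can be captured in a small helper that assembles the evaluation command for each checkpoint type. `build_eval_command` is a hypothetical helper and emits only the flags shown in this section. Its default checkpoint prefix is the one from the ERK example, so other checkpoints may need a different step number.

```python
import shlex
from typing import List


def build_eval_command(output_dir: str,
                       ckpt_prefix: str = "model.ckpt-3200000",
                       distribution: str = "erk",
                       mobilenet: bool = False) -> List[str]:
    """Hypothetical helper: argv for evaluating one downloaded checkpoint."""
    if distribution == "erk":
        mask_flags = ["--mask_init_method=erdos_renyi_kernel",
                      "--first_layer_sparsity=-1"]
    elif distribution == "uniform":
        mask_flags = ["--mask_init_method=random", "--first_layer_sparsity=0"]
    else:
        raise ValueError(f"unknown sparsity distribution: {distribution!r}")
    cmd = ["python", "imagenet_train_eval.py", "--mode=eval_once",
           f"--output_dir={output_dir}",
           f"--eval_once_ckpt_prefix={ckpt_prefix}",
           "--use_folder_stub=False", "--training_method=rigl", *mask_flags]
    if mobilenet:
        cmd.append("--model_architecture=mobilenet_v1")
    return cmd


# Print a shell-ready command for a uniform-sparsity MobileNet checkpoint.
print(shlex.join(build_eval_command("path/to/ckpt/folder",
                                    distribution="uniform", mobilenet=True)))
```

To run the evaluation instead of printing it, pass the returned list to subprocess.run(cmd, check=True) from the repository directory.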
Sparse Training Algorithms
In this repository we implement the following dynamic sparsity strategies:
- SET: Implements Sparse Evolutionary Training (SET), which replaces low-magnitude connections with new, randomly chosen ones.
- SNFS: Implements momentum-based training without sparsity redistribution.
- RigL: Our method, RigL, removes a fraction of connections based on weight magnitudes and activates new ones using instantaneous gradient information.
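The drop-and-grow update shared by SET and RigL can be sketched on a single flattened layer. This is a minimal pure-Python illustration, not the repository's implementation. The hypothetical `drop_and_grow` function first drops the lowest-magnitude active weights. It then regrows the same number of connections, either uniformly at random (SET) or by largest gradient magnitude (RigL). As an assumption of this sketch, new connections are chosen only from those that were inactive before the update. Passing momentum values instead of gradients as `grow_scores` roughly corresponds to SNFS-style growth without redistribution. The paper initializes newly grown RigL connections to zero, but this sketch only updates the mask.

```python
import random
from typing import List, Optional, Sequence


def drop_and_grow(weights: Sequence[float], grow_scores: Sequence[float],
                  mask: Sequence[bool], drop_fraction: float,
                  grow: str = "gradient",
                  rng: Optional[random.Random] = None) -> List[bool]:
    """Hypothetical single mask update for one flattened layer."""
    active = [i for i, on in enumerate(mask) if on]
    inactive = [i for i, on in enumerate(mask) if not on]
    # Drop and grow the same count, so the layer's sparsity is unchanged.
    k = min(int(drop_fraction * len(active)), len(inactive))
    new_mask = list(mask)
    # Drop: deactivate the k active connections with the smallest |w|.
    for i in sorted(active, key=lambda j: abs(weights[j]))[:k]:
        new_mask[i] = False
    # Grow: pick k connections that were inactive before this update.
    if grow == "gradient":    # RigL-style: largest |gradient|
        grown = sorted(inactive, key=lambda j: abs(grow_scores[j]),
                       reverse=True)[:k]
    elif grow == "random":    # SET-style: uniformly at random
        grown = (rng or random.Random()).sample(inactive, k)
    else:
        raise ValueError(f"unknown grow mode: {grow!r}")
    for i in grown:
        new_mask[i] = True
    return new_mask


w = [0.9, -0.1, 0.5, 0.0, 0.0, 0.0]
g = [0.0, 0.0, 0.0, 0.2, -0.7, 0.1]
m = [True, True, True, False, False, False]
print(drop_and_grow(w, g, m, drop_fraction=0.34))  # -> [True, False, True, False, True, False]
```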
And the following one-shot pruning algorithm:
- SNIP: Single-shot Network Pruning based on connection sensitivity, which prunes the least salient connections before training.
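SNIP's connection-sensitivity criterion can be sketched in the same way. The sketch assumes that gradients of the loss with respect to each weight have already been computed on one mini-batch at initialization. It scores each connection by |w * g| and keeps the top (1 - sparsity) fraction across the flattened network. SNIP also normalizes the scores by their sum, but this does not change the ranking, so it is omitted. `snip_mask` is a hypothetical name.

```python
from typing import List, Sequence


def snip_mask(weights: Sequence[float], grads: Sequence[float],
              sparsity: float) -> List[bool]:
    """Hypothetical SNIP-style mask over a flattened network."""
    # Connection sensitivity |w * dL/dw|; normalizing by the sum of scores
    # would not change the ranking, so it is skipped here.
    scores = [abs(w * g) for w, g in zip(weights, grads)]
    keep = round((1.0 - sparsity) * len(scores))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    kept = set(ranked[:keep])
    return [i in kept for i in range(len(scores))]


w = [0.5, -2.0, 0.1, 1.0]
g = [1.0, 0.1, 3.0, -0.05]
print(snip_mask(w, g, sparsity=0.5))  # -> [True, False, True, False]
```

The large weight -2.0 is pruned because its gradient is small, which illustrates how connection sensitivity differs from plain magnitude pruning.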
