Feather: An Elegant Solution to Effective DNN Sparsification
This repository contains the source code accompanying the paper Feather: An Elegant Solution to Effective DNN Sparsification by Athanasios Glentis Georgoulakis, George Retsinas and Petros Maragos, accepted at the 2023 British Machine Vision Conference (BMVC).
Overview

Figure: (a) The Feather module, utilizing the new thresholding operator and the gradient scaling mask; (b) the proposed family of thresholding operators for varying values of $p$. We adopt $p=3$, resulting in a fine balance between the two extremes of hard and soft thresholding.
Feather is a module that enables effective sparsification of neural networks during the standard course of training. The pruning process relies on an enhanced version of the Straight-Through Estimator (STE), utilizing a new thresholding operator and a gradient scaling technique, resulting in sparse yet highly accurate models suitable for compact applications.
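To illustrate the mechanism, here is a minimal PyTorch sketch of an STE of this kind. The shrinkage family $x\,(1 - (\theta/|x|)^p)$, the `FeatherSTE` name and the `grad_scale=0.5` factor are assumptions for exposition, not the exact formulation from the paper; see sparse_utils.py for the actual implementation.

```python
import torch

class FeatherSTE(torch.autograd.Function):
    """Illustrative straight-through estimator with a p-parameterized
    thresholding operator and gradient scaling on pruned weights."""

    @staticmethod
    def forward(ctx, w, thresh, p, grad_scale):
        keep = w.abs() > thresh
        # assumed thresholding family: p=1 recovers soft thresholding,
        # p -> infinity approaches hard thresholding; the paper adopts p=3
        shrink = 1.0 - (thresh / w.abs().clamp_min(1e-12)) ** p
        w_sparse = torch.where(keep, w * shrink, torch.zeros_like(w))
        ctx.save_for_backward(keep)
        ctx.grad_scale = grad_scale
        return w_sparse

    @staticmethod
    def backward(ctx, grad_out):
        (keep,) = ctx.saved_tensors
        # straight-through with scaling: surviving weights get the full
        # gradient, pruned weights a scaled one instead of none at all
        grad_w = torch.where(keep, grad_out, ctx.grad_scale * grad_out)
        return grad_w, None, None, None

# the sparsified copy is what the forward pass would use, e.g.:
# w_used = FeatherSTE.apply(layer.weight, thresh, 3.0, 0.5)
```

The point of scaling (rather than blocking) the gradients of pruned weights is that the dense weights keep receiving updates, so weights pruned early can recover later in training.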
Feather is versatile and not bound to a particular pruning framework. When used with a backbone based on global magnitude thresholding (i.e., a single threshold selected for all layers) and a sparsity ratio that increases incrementally over the course of training, Feather(-Global) yields sparse models that meet the exact requested sparsity at the end of training and are more accurate than the current state of the art by a considerable margin.
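To make the global backbone concrete, the sketch below picks a single threshold as the weight-magnitude quantile matching the current target sparsity, which is ramped up over training. The quantile computation follows the description above; the cubic ramp and the restriction to weight tensors (`p.dim() > 1`) are assumptions for illustration, not necessarily the paper's exact scheduler.

```python
import torch

def global_threshold(model, sparsity):
    # one threshold for all layers: the k-th smallest weight magnitude,
    # where k matches the requested fraction of pruned weights
    mags = torch.cat([p.detach().abs().flatten()
                      for p in model.parameters() if p.dim() > 1])
    k = int(sparsity * mags.numel())
    return torch.kthvalue(mags, k).values.item() if k > 0 else 0.0

def current_sparsity(step, total_steps, final_rate):
    # sparsity ratio that increases incrementally over training
    # (a cubic ramp is assumed here for illustration)
    t = min(step / total_steps, 1.0)
    return final_rate * (1.0 - (1.0 - t) ** 3)
```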
Library Usage
We provide a sketch of how the library is used to perform DNN pruning with Feather and the Global pruning backbone:

```python
import torch
from sparse_utils import Pruner

train_loader = ...
epochs = ...
model = ...

# create a Pruner class instance
pruner = Pruner(model, device=..., final_rate=..., nbatches=..., epochs=...)

optimizer = ...
loss_fn = ...

for epoch in range(epochs):
    for data, target in train_loader:
        # update the pruning threshold based on the iteration number and the scheduler used
        pruner.update_thresh()
        output = model(data)
        loss = loss_fn(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # update the pruning threshold after the last step of the optimizer
        pruner.update_thresh(end_of_batch=True)

# finalize sparse model
pruner.desparsify()
```
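Assuming desparsify() writes the thresholded values back into the model's weight tensors, the achieved sparsity can then be verified directly:

```python
# count zeroed entries in the weight tensors (biases/BN excluded here)
total = sum(p.numel() for p in model.parameters() if p.dim() > 1)
zeros = sum((p == 0).sum().item() for p in model.parameters() if p.dim() > 1)
print(f'achieved sparsity: {zeros / total:.4f}')
```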
Provided files
- sparse_utils.py: contains an implementation of Feather and the global magnitude pruning framework (Feather-Global) as described in our paper, including all necessary utility functions.
- main.py: contains an example use case of Feather-Global on CIFAR-100, using the ResNet-20, MobileNetV1 and DenseNet40-24 architectures provided in the archs directory.
- args.py: contains the user-defined arguments regarding training and sparsification choices.
Tested on PyTorch 1.13 (the numpy, torch, torchvision, tensorboard and argparse packages are required).
Main options: (--gpu, --batch-size, --lr, --wd, --epochs, --model, --ptarget, --sname)
- gpu: select GPU device id
- batch-size: batch size for training (default: 128)
- lr: initial learning rate (default: 0.1; the scheduler used is cosine annealing without warm restarts)
- wd: weight decay (default: 5e-4)
- epochs: number of overall epochs (default: 160)
- model: model architecture to train
- ptarget: final target pruning ratio
- sname: folder name for tensorboard log file (final name will be in the form: datetime_sname)
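For orientation, the flags above map to an argparse setup roughly like the following. This is a hypothetical sketch rather than the exact contents of args.py; defaults not stated in this README (model, ptarget, sname) are placeholders.

```python
import argparse

parser = argparse.ArgumentParser(description='Feather-Global training on CIFAR-100')
parser.add_argument('--gpu', type=int, default=0, help='GPU device id')
parser.add_argument('--batch-size', type=int, default=128, help='training batch size')
parser.add_argument('--lr', type=float, default=0.1, help='initial learning rate')
parser.add_argument('--wd', type=float, default=5e-4, help='weight decay')
parser.add_argument('--epochs', type=int, default=160, help='number of training epochs')
parser.add_argument('--model', type=str, default='resnet20', help='architecture to train')
parser.add_argument('--ptarget', type=float, default=0.90, help='final target pruning ratio')
parser.add_argument('--sname', type=str, default='run', help='tensorboard log folder name')
args = parser.parse_args()
```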
Examples
- python main.py --gpu=0 --wd=5e-4 --epochs=160 --model=resnet20 --ptarget=0.90 --sname='resnet20_ep=160_wd=5e-4_pt=0.90'
- python main.py --gpu=0 --wd=5e-4 --epochs=160 --model=resnet20 --ptarget=0.99 --sname='resnet20_ep=160_wd=5e-4_pt=0.99'
Citation
If you find this work useful for your research, please cite our paper:
```bibtex
@inproceedings{glentis2023feather,
  author    = {Glentis Georgoulakis, Athanasios and Retsinas, George and Maragos, Petros},
  title     = {Feather: An Elegant Solution to Effective DNN Sparsification},
  booktitle = {Proceedings of the British Machine Vision Conference ({BMVC})},
  year      = {2023}
}
```