TextPruner
A PyTorch-based model pruning toolkit for pre-trained language models
TextPruner is a model pruning toolkit for pre-trained language models. It provides low-cost, training-free methods to reduce model size and speed up inference by removing redundant neurons.
You may also be interested in:
- Knowledge Distillation Toolkit - TextBrewer: https://github.com/airaria/TextBrewer
- Chinese MacBERT: https://github.com/ymcui/MacBERT
- Chinese ELECTRA: https://github.com/ymcui/Chinese-ELECTRA
- Chinese XLNet: https://github.com/ymcui/Chinese-XLNet
- CINO: https://github.com/ymcui/Chinese-Minority-PLM
News
- [Mar 21, 2022] (new functionality in v1.1) Added vocabulary pruning for XLM, BART, T5 and mT5 models.
- [Mar 4, 2022] We are delighted to announce that the TextPruner paper, TextPruner: A Model Pruning Toolkit for Pre-Trained Language Models, has been accepted to the ACL 2022 demo track.
- [Jan 26, 2022] (new functionality in v1.0.1) Added support for self-supervised pruning via the `use_logits` option in `TransformerPruningConfig`.
Table of Contents
| Section | Contents |
|-|-|
| Introduction | Introduction to TextPruner |
| Installation | Requirements and how to install |
| Pruning Modes | A brief introduction to the three pruning modes |
| Usage | A quick guide on how to use TextPruner |
| Experiments | Pruning experiments on typical tasks |
| FAQ | Frequently asked questions |
| Follow Us | - |
Introduction
TextPruner is a toolkit for pruning pre-trained transformer-based language models written in PyTorch. It offers structured training-free pruning methods and a user-friendly interface.
The main features of TextPruner include:
- Compatibility: TextPruner is compatible with different NLU pre-trained models. You can use it to prune your own models for various NLP tasks as long as they are built on the standard pre-trained models.
- Usability: TextPruner can be used as a package or a CLI tool. They are both easy to use.
- Efficiency: TextPruner reduces the model size in a simple and fast way. TextPruner uses structured training-free methods to prune models. It is much faster than distillation and other pruning methods that involve training.
TextPruner currently supports vocabulary pruning and transformer pruning. For the explanation of the pruning modes, see Pruning Modes.
To use TextPruner, users can either import TextPruner into the python scripts or run the TextPruner command line tool. See the examples in Usage.
For the performance of the pruned model on typical tasks, see Experiments.
Paper: TextPruner: A Model Pruning Toolkit for Pre-Trained Language Models
Supporting Models
TextPruner currently supports the following pre-trained models in transformers:
| Model | Vocabulary Pruning | Transformer Pruning |
|-|:-:|:-:|
| BERT | :heavy_check_mark: | :heavy_check_mark: |
| ALBERT | :heavy_check_mark: | :heavy_check_mark: |
| RoBERTa | :heavy_check_mark: | :heavy_check_mark: |
| ELECTRA| :heavy_check_mark: | :heavy_check_mark: |
| XLM-RoBERTa | :heavy_check_mark: | :heavy_check_mark: |
|XLM | :heavy_check_mark: | :x: |
|BART | :heavy_check_mark: | :x: |
|T5 | :heavy_check_mark: | :x: |
|mT5 | :heavy_check_mark: | :x: |
See the online documentation for the API reference.
Installation
- Requirements
  - Python >= 3.7
  - torch >= 1.7
  - transformers >= 4.0
  - sentencepiece
  - protobuf
- Install with pip

  ```bash
  pip install textpruner
  ```

- Install from the source

  ```bash
  git clone https://github.com/airaria/TextPruner.git
  pip install ./textpruner
  ```
Pruning Modes
In TextPruner, there are three pruning modes: vocabulary pruning, transformer pruning and pipeline pruning.

Vocabulary Pruning
Pre-trained models usually have a large vocabulary, but some tokens rarely appear in the datasets of the downstream tasks. These tokens can be removed to reduce the model size and accelerate MLM pre-training.
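The idea can be sketched in a few lines of plain Python. This is a toy illustration of the principle, not TextPruner's internals; the toy vocabulary and whitespace tokenizer are our assumptions:

```python
# Toy illustration of vocabulary pruning: tokenize the downstream corpus,
# record which token ids actually occur, and keep only those rows of the
# embedding matrix.
corpus = ["hello world", "hello pruning"]
vocab = {"hello": 0, "world": 1, "pruning": 2, "unused": 3}  # toy vocabulary

used_ids = {vocab[tok] for text in corpus for tok in text.split()}
kept_rows = sorted(used_ids)  # embedding rows to keep; row 3 ("unused") is dropped
```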
Transformer Pruning
Another approach is pruning the transformer blocks. Studies have shown that not all attention heads are equally important in transformers. TextPruner reduces the model size while keeping performance as high as possible by locating and removing the unimportant attention heads and feed-forward network neurons.
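The head-scoring idea can be sketched with a gradient-based mask, a common technique for estimating head importance. This is an illustration of the general approach under our own assumptions (toy tensors and a stand-in loss), not TextPruner's exact scoring function:

```python
import torch

num_heads, head_dim = 4, 8
head_outputs = torch.randn(num_heads, head_dim)        # stand-in for per-head activations
head_mask = torch.ones(num_heads, requires_grad=True)  # one scalar gate per head

# Stand-in loss over the gated heads; in practice this is the task loss
# (or a self-supervised loss) from a forward pass on real data.
loss = ((head_mask.unsqueeze(-1) * head_outputs).sum(dim=-1) ** 2).sum()
loss.backward()

importance = head_mask.grad.abs()   # |dL/d mask_h| scores each head
prune_order = importance.argsort()  # least important heads are pruned first
```

Because the mask gradients come from a plain backward pass, the scores are collected without any weight updates, which is what makes this style of pruning training-free.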
Pipeline Pruning
In pipeline pruning, TextPruner performs transformer pruning and vocabulary pruning successively to fully reduce the model size.
Usage
The pruners perform the pruning process, and the configurations control their behavior. Their names are self-explanatory:
- Pruners
  - `textpruner.VocabularyPruner`
  - `textpruner.TransformerPruner`
  - `textpruner.PipelinePruner`
- Configurations
  - `textpruner.GeneralConfig`
  - `textpruner.VocabularyPruningConfig`
  - `textpruner.TransformerPruningConfig`
See the online documentation for the API reference.
The Configurations are explained in Configurations.
We demonstrate the basic usage below.
Vocabulary Pruning
To perform vocabulary pruning, users should provide a text file or a list of strings. The tokens that do not appear in the texts are removed from the model and the tokenizer.
See the examples at examples/vocabulary_pruning and examples/vocabulary_pruning_xnli.
Use TextPruner as a package
Pruning the vocabulary in 3 lines of code:
```python
from textpruner import VocabularyPruner
pruner = VocabularyPruner(model, tokenizer)
pruner.prune(dataiter=texts)
```
- `model` is the pre-trained model for the MLM task or other NLP tasks.
- `tokenizer` is the corresponding tokenizer.
- `texts` is a list of strings. The tokens that do not appear in the texts are removed from the model and the tokenizer.
VocabularyPruner accepts a GeneralConfig and a VocabularyPruningConfig for fine-grained control; both can be omitted to use the defaults. See the API reference for details.
Use TextPruner-CLI tool
```bash
textpruner-cli \
  --pruning_mode vocabulary \
  --configurations gc.json vc.json \
  --model_class XLMRobertaForSequenceClassification \
  --tokenizer_class XLMRobertaTokenizer \
  --model_path /path/to/model/and/config/directory \
  --vocabulary /path/to/a/text/file
```
- `configurations`: configuration files in JSON format. See Configurations for details.
- `model_class`: the class name of the model. It must be accessible from the current directory. For example, if `model_class` is `modeling.ModelClassName`, there should be a `modeling.py` in the current directory. If there is no module name in `model_class`, TextPruner will try to import `model_class` from the transformers library, as shown above.
- `tokenizer_class`: the class name of the tokenizer. It must be accessible from the current directory. If there is no module name in `tokenizer_class`, TextPruner will try to import `tokenizer_class` from the transformers library.
- `model_path`: the directory that contains the weights and the configurations of the model and the tokenizer.
- `vocabulary`: a text file used for generating the new vocabulary. The tokens that do not appear in this file are removed from the model and the tokenizer.
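The lookup rule for `model_class` and `tokenizer_class` can be mimicked with `importlib`. The helper `resolve_class` below is our own hypothetical illustration, not part of TextPruner:

```python
import importlib

def resolve_class(qualified_name, fallback_module="transformers"):
    """Resolve 'module.ClassName' relative to the current directory,
    or a bare 'ClassName' from the fallback module."""
    module_name, _, class_name = qualified_name.rpartition(".")
    module = importlib.import_module(module_name or fallback_module)
    return getattr(module, class_name)
```

For example, `resolve_class("modeling.ModelClassName")` imports from a local `modeling.py`, while a bare `resolve_class("XLMRobertaTokenizer")` falls back to the transformers library.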
Transformer Pruning
- To perform transformer pruning on a dataset, a `dataloader` of the dataset should be provided. The `dataloader` should return both the inputs and the labels.
- TextPruner needs the loss returned by the model to calculate neuron importance scores. TextPruner will try to guess which element in the model output is the loss. If none of the following is true:
  - the model returns a single element, which is the loss;
  - the model output is a list or a tuple, and the loss is its first element;
  - the loss can be accessed by `output['loss']` or `output.loss`, where `output` is the model output;

  users should provide an `adaptor` function (which takes the output of the model and returns the loss) to the `TransformerPruner`.
- If running in self-supervised mode, TextPruner needs the logits returned by the model to calculate importance scores. In this case, the `adaptor` should return the logits. Check the `use_logits` option in `TransformerPruningConfig` for details.
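A minimal adaptor for the fallback case might look like this. The dictionary keys are hypothetical; return whatever field of your model's output actually holds the loss (or the logits in self-supervised mode):

```python
def adaptor(model_outputs):
    # Return the scalar loss from the model's raw output; the key
    # "my_loss" is hypothetical -- adjust it to your model's output.
    return model_outputs["my_loss"]

def logits_adaptor(model_outputs):
    # In self-supervised mode (use_logits=True), return the logits instead.
    return model_outputs["my_logits"]
```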
See the examples at examples/transformer_pruning.
For self-supervised pruning, see the example at examples/transformer_pruning_xnli.
Use TextPruner as a package
from textpruner import TransformerPruner, TransformerPruningConfig
transformer_pruning_config = TransformerPruningCon
