LayerSkip
Code for "LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding", ACL 2024
<a href='https://huggingface.co/collections/facebook/layerskip-666b25c50c8ae90e1965727a'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
This code base is the implementation of LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding.
<div align="center"> <img src="https://github.com/user-attachments/assets/1fdd91d9-37ea-4b42-b5be-579fb5e1f2f2" width="500"> </div>
News
- [2024/11] 🤗 LayerSkip inference has been integrated into Hugging Face transformers.
- [2024/12] 🔥 LayerSkip training recipe has been integrated into PyTorch torchtune.
- [2025/03] 🤗 LayerSkip training recipe implemented in Hugging Face trl.
Getting Started
- Clone repo:
$ git clone git@github.com:facebookresearch/LayerSkip.git
$ cd LayerSkip
- Setup environment:
$ conda create --name layer_skip python=3.10
$ conda activate layer_skip
$ pip install -r requirements.txt
- Access models: To observe speedup, you need access to LLMs that have been trained using the LayerSkip recipe. We provide 6 checkpoints on HuggingFace of different Llama models continually pretrained using the LayerSkip recipe; they are listed in the Hugging Face collection linked at the top of this README.
To access each model:
1. Visit the model's corresponding link above, and make sure you are logged in to the HuggingFace website with your account.
2. Fill in the request form and submit it. Approval may take a while; you should receive an email notification once access to the model is granted.
3. Follow the steps here to obtain a user access token.
4. On the command line, run huggingface-cli login; you will be prompted to provide the token you obtained in Step 3.
Once you have completed these steps, the commands below for running the LayerSkip checkpoints should work.
Generate
To run one of our models in interactive mode using regular autoregressive decoding:
$ torchrun generate.py --model facebook/layerskip-llama2-7B \
--sample True \
--max_steps 512
To observe speedup, use self-speculative decoding to generate tokens: set --generation_strategy self_speculative and specify --exit_layer, the layer at which the draft stage exits, and --num_speculations, the number of draft tokens:
$ torchrun generate.py --model facebook/layerskip-llama2-7B \
--sample True \
--max_steps 512 \
--generation_strategy self_speculative \
--exit_layer 8 \
--num_speculations 6
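The draft-then-verify loop behind self-speculative decoding can be illustrated with a minimal, framework-free sketch. This is not the code in generate.py: the names `toy_full_next`, `toy_draft_next`, and `self_speculative_decode` are hypothetical, the "models" are toy functions over integer tokens, and verification is written as a Python loop for readability, whereas a real implementation verifies all draft tokens in a single forward pass. The sketch assumes greedy decoding (no sampling).

```python
from typing import Callable, List

NextTokenFn = Callable[[List[int]], int]
VOCAB_SIZE = 50  # hypothetical toy vocabulary size


def toy_full_next(tokens: List[int]) -> int:
    """Stand-in for the full model: greedy next token given the prefix."""
    return (3 * tokens[-1] + len(tokens)) % VOCAB_SIZE


def toy_draft_next(tokens: List[int]) -> int:
    """Stand-in for the early-exit draft: agrees with the full model
    except when the prefix length is a multiple of 4."""
    guess = toy_full_next(tokens)
    return guess if len(tokens) % 4 else (guess + 1) % VOCAB_SIZE


def autoregressive_decode(full_next: NextTokenFn, prompt: List[int],
                          max_new_tokens: int) -> List[int]:
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(full_next(tokens))
    return tokens[len(prompt):]


def self_speculative_decode(draft_next: NextTokenFn, full_next: NextTokenFn,
                            prompt: List[int], max_new_tokens: int,
                            num_speculations: int) -> List[int]:
    tokens = list(prompt)
    target_len = len(prompt) + max_new_tokens
    while len(tokens) < target_len:
        # Draft stage: propose num_speculations tokens with the cheap predictor.
        draft: List[int] = []
        for _ in range(num_speculations):
            draft.append(draft_next(tokens + draft))
        # Verify stage: keep the full model's token at each position and
        # stop at the first position where the draft disagreed.
        accepted: List[int] = []
        for proposed in draft:
            verified = full_next(tokens + accepted)
            accepted.append(verified)
            if verified != proposed:
                break
        else:
            # Every draft token matched, so the full model yields one extra token.
            accepted.append(full_next(tokens + accepted))
        tokens.extend(accepted)
    # The last iteration may overshoot; trim to the requested length.
    return tokens[len(prompt):target_len]


if __name__ == "__main__":
    prompt = [1, 2, 3]
    print(self_speculative_decode(toy_draft_next, toy_full_next, prompt, 20, 6))
```

Because every emitted token is the one the full model would have chosen, the output under greedy decoding matches plain autoregressive decoding; the draft stage only affects how many tokens are emitted per verification.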
Tips:
- You may change --model to any HuggingFace model, but in order to observe speedup with self-speculative decoding, use a model trained using the LayerSkip recipe, such as those we have open sourced on HuggingFace.
- By default we enable sampling. You may change the sampling behaviour using the --sample, --temperature, --top_p, and --top_k arguments.
- You may run python generate.py --help for details on different command-line arguments.
Benchmark
To benchmark on a dataset:
$ torchrun benchmark.py --model facebook/layerskip-llama2-7B \
--dataset cnn_dm_summarization \
--num_samples 100 \
--generation_strategy self_speculative \
--exit_layer 8 \
--num_speculations 6 \
--output_dir ./logs
Tips:
- You can specify different tasks by modifying the --dataset argument:
  - cnn_dm_summarization: CNN/DM Summarization
  - xsum_summarization: XSUM Summarization
  - cnn_dm_lm: CNN/DM Language Modeling (given the first few words of an article, generate the remaining article)
  - human_eval: HumanEval Coding
- By default, the tasks run as 0-shot. You can run n-shot instead by specifying the --n_shot argument.
- By default we enable sampling, while the results reported in the paper were obtained with greedy decoding (no sampling). You may change the sampling behaviour using the --sample, --temperature, --top_p, and --top_k arguments.
- You may run python benchmark.py --help for details on different command-line arguments.
Evaluate
We have integrated our generation scripts with Eleuther Language Model Evaluation Harness to enable a large number of tasks and properly post-process generated text.
$ torchrun eval.py --model facebook/layerskip-llama2-7B \
--tasks gsm8k \
--limit 10 \
--generation_strategy self_speculative \
--exit_layer 8 \
--num_speculations 6 \
--output_dir ./logs
Tips:
- Note that with speculative decoding we can only obtain speedups from generation tasks (e.g., gsm8k or cnn_dailymail), while classification tasks, i.e., multiple choice question tasks (e.g., piqa, social_iqa) or True/False question tasks (e.g., boolq), will not lead to speedup.
- You can specify an arbitrary number of tasks supported by Eleuther Evaluation Harness using the --tasks argument. To get a list of all possible tasks, check this link.
- Similar to the generate.py and benchmark.py scripts, you may specify different models, datasets, and sampling parameters.
- You may run python eval.py --help for details on different command-line arguments.
Sweep
Our inference hyperparameters, exit_layer and num_speculations, determine the speedup during inference:
- exit_layer:
  - smaller means a faster but less accurate draft stage
  - larger means a more accurate but slower draft stage
- num_speculations:
  - smaller means a higher acceptance rate, but the verification stage will amortize the draft stage less
  - larger means the verification stage will better amortize the draft stage, but the acceptance rate decreases
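The trade-off above can be made concrete with a rough back-of-the-envelope model. This sketch is an assumption for illustration, not something measured by or taken from this repository: it assumes each draft token is accepted independently with a fixed probability `acceptance_rate`, that a draft token costs `exit_layer / num_layers` of a full forward pass, and that verification costs one full forward pass regardless of how many tokens it checks. It ignores the compute and KV cache shared between the draft and verification stages, and it treats `acceptance_rate` as an input even though in practice it rises with `exit_layer`. The function names are hypothetical.

```python
def expected_tokens_per_iteration(acceptance_rate: float,
                                  num_speculations: int) -> float:
    """Expected tokens emitted by one draft+verify iteration, assuming each
    draft token is accepted independently with probability acceptance_rate."""
    if acceptance_rate >= 1.0:
        # All drafts accepted, plus the extra token from verification.
        return float(num_speculations + 1)
    # Geometric series: 1 + a + a^2 + ... + a^num_speculations
    return (1.0 - acceptance_rate ** (num_speculations + 1)) / (1.0 - acceptance_rate)


def estimated_speedup(acceptance_rate: float, num_speculations: int,
                      exit_layer: int, num_layers: int) -> float:
    """Estimated speedup over autoregressive decoding, where the baseline
    emits one token per full forward pass."""
    draft_cost = num_speculations * exit_layer / num_layers  # in full-pass units
    verify_cost = 1.0  # assumed: one full forward pass per iteration
    tokens = expected_tokens_per_iteration(acceptance_rate, num_speculations)
    return tokens / (draft_cost + verify_cost)


if __name__ == "__main__":
    # Hypothetical inputs: 32-layer model, 80% per-token acceptance.
    for k in (2, 4, 6, 8, 12):
        print(k, round(estimated_speedup(0.8, k, exit_layer=8, num_layers=32), 3))
```

Under these assumptions the estimate first grows and then shrinks as num_speculations increases, which is why a grid search over both hyperparameters is useful.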
The optimal combination of exit_layer and num_speculations may change with the model, dataset, and sampling parameters. Hence, we provide a script to sweep over a grid of different exit_layer and num_speculations values:
$ torchrun sweep.py --model facebook/layerskip-llama2-7B \
--dataset human_eval \
--generation_strategy self_speculative \
--num_samples 150 \
--max_steps 256 \
--output_dir ./logs/ \
--sample False
This will create a CSV file in the directory specified by the --output_dir argument.
Tips:
- Similar to the generate.py and benchmark.py scripts, you may specify different models, datasets, and sampling parameters.
- You may run python sweep.py --help for details on different command-line arguments.
Correctness
In order to verify that the generated tokens of our self-speculative decoding algorithm are correct, we have created a script to compare the outputs of autoregressive decoding with self-speculative decoding. Note that we can only guarantee that the outputs are equivalent when there is no sampling (i.e., --sample False):
$ torchrun correctness.py --model facebook/layerskip-llama2-7B \
--dataset human_eval \
--generation_strategy self_speculative \
--num_speculations 6 \
--exit_layer 4 \
--num_samples 10 \
--sample False \
--output_dir ./logs
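The kind of comparison involved can be sketched as a small helper that reports where two token sequences first diverge. This is an illustrative assumption about what such a check looks like, not the logic of correctness.py; the function name `first_mismatch` is hypothetical.

```python
from typing import Optional, Sequence


def first_mismatch(reference: Sequence[int],
                   candidate: Sequence[int]) -> Optional[int]:
    """Return the index of the first position where the two token sequences
    differ, or None if they are identical (including length)."""
    for index, (ref_token, cand_token) in enumerate(zip(reference, candidate)):
        if ref_token != cand_token:
            return index
    if len(reference) != len(candidate):
        # One sequence is a strict prefix of the other.
        return min(len(reference), len(candidate))
    return None


if __name__ == "__main__":
    autoregressive_tokens = [5, 17, 42, 8]
    speculative_tokens = [5, 17, 42, 8]
    print(first_mismatch(autoregressive_tokens, speculative_tokens))
```

With sampling disabled, a check of this kind is expected to find no mismatch; with sampling enabled, differences do not by themselves indicate a bug.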
Using Docker
Kindly check DOCKER.md to set up the project using Docker.
Other Implementations
We also have other implementations of LayerSkip inference:
- gpt-fast: gpt-fast is a simple and efficient PyTorch-native transformer text generation implementation. We have implemented LayerSkip in the gpt-fast codebase to enable compounding it with other optimizations such as torch.compile(), quantization, and tensor parallelism.
- Native HuggingFace: in the model card of each of our HuggingFace models, we have provided simple code snippets that leverage HuggingFace speculative decoding capabilities, using a simple trick to clone the earlier layers of the main model without cloning its weights. Although this implementation is simple and does not require implementing other functions or importing other libraries, it does not share the KV cache or execution between the draft and verification stages.
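For the Hugging Face transformers integration mentioned under News, a minimal sketch might look like the following. It assumes a recent transformers release in which `generate()` accepts the `assistant_early_exit` argument, and that you have been granted access to the checkpoint; the function name `generate_with_early_exit` and its default values are hypothetical, and this is not the snippet from the model cards.

```python
def generate_with_early_exit(prompt: str,
                             model_name: str = "facebook/layerskip-llama2-7B",
                             exit_layer: int = 4,
                             max_new_tokens: int = 64) -> str:
    """Greedy generation that drafts tokens from an early layer of the same
    model and verifies them with the full model (hedged sketch)."""
    # Imported lazily so that defining this function does not require the
    # libraries to be installed.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Default precision is used for brevity; a 7B model needs substantial memory.
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        assistant_early_exit=exit_layer,  # assumed available in your transformers version
        do_sample=False,
        max_new_tokens=max_new_tokens,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
```

Calling `generate_with_early_exit("Once upon a time")` downloads the checkpoint on first use, so run huggingface-cli login beforehand.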
Training
Our training implementation is a work in progress. You can check this pull request for details and discussions.
Tests
To run unit/integration tests:
$ pytest ./tests/
