<img src="assets/logo.png" alt="Medusa" width="100" align="left"><div align="center"><h1> Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads</h1></div>
<p align="center"> | <a href="https://sites.google.com/view/medusa-llm"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2401.10774"><b>Report</b></a> | <a href="ROADMAP.md"><b>Roadmap</b></a> | </p>

## News 🔥
- [2024/1] Medusa technical report is now available on arXiv. We've added multiple new features, including Medusa-2 recipe for full-model training, self-distillation for adding Medusa to any fine-tuned LLM, etc. The new results show a 2.2-3.6x speedup over the original model on a range of LLMs.
## Introduction
Medusa is a simple framework that democratizes the acceleration techniques for LLM generation with multiple decoding heads.
<div align="center"> <picture> <img src="assets/medusa_demo.gif" width="80%"> </picture> <br> <div align="center" width="80%"> <em>Medusa-1 on Vicuna-7b.</em> </div> <br> </div>

We aim to tackle the three pain points of popular acceleration techniques like speculative decoding:
- Requirement of a good draft model.
- System complexity.
- Inefficiency when using sampling-based generation.
Medusa addresses these challenges with the following ideas:
- Instead of introducing a new model, we train multiple decoding heads on the same model.
- The training is parameter-efficient so that even the "GPU-Poor" can do it. And since there is no additional model, there is no need to adjust the distributed computing setup.
- Relaxing the requirement of matching the distribution of the original model makes the non-greedy generation even faster than greedy decoding.
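The core idea of the extra decoding heads can be illustrated with a small sketch. This is a toy rendition of the concept only, not the repo's actual implementation: each head is assumed to be a small residual MLP over the base model's last hidden states, predicting the token one additional step ahead (the base LM head still predicts the immediate next token).

```python
# Toy sketch of Medusa-style decoding heads (illustrative, not the repo's code).
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """Residual block: x + SiLU(Linear(x)), keeping the hidden size unchanged."""
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))

class MedusaHeads(nn.Module):
    """num_heads extra heads on top of the frozen base model's hidden states.
    Head k speculates the token k+1 positions ahead of the base LM head."""
    def __init__(self, hidden_size, vocab_size, num_heads=4):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Sequential(ResBlock(hidden_size), nn.Linear(hidden_size, vocab_size))
            for _ in range(num_heads)
        )

    def forward(self, hidden):  # hidden: (batch, seq, hidden_size)
        # Returns one logit tensor per head, each (batch, seq, vocab_size).
        return [head(hidden) for head in self.heads]

hidden = torch.randn(1, 8, 64)           # stand-in for base-model hidden states
logits = MedusaHeads(64, 100)(hidden)    # 4 heads -> 4 logit tensors
```

Because only these small heads are trained in Medusa-1, the base model's weights (and its output distribution) stay untouched.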
In the initial release, our primary focus is on optimizing Medusa for a batch size of 1—a setting commonly utilized for local model hosting. In this configuration, Medusa delivers approximately a 2x speed increase across a range of Vicuna models. We are actively working to extend Medusa's capabilities by integrating it into additional inference frameworks, with the aim of achieving even greater performance gains and extending Medusa to broader settings.
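The speedup comes from accepting several speculated tokens per forward pass: the heads propose a continuation, the base model verifies it in the same pass, and the longest agreeing prefix is kept. A toy illustration of a greedy acceptance rule (simplified; the actual library uses tree attention over many candidates):

```python
# Illustrative greedy acceptance for speculative decoding (not the library's code).
def accept_speculation(proposed, verified):
    """proposed: tokens guessed by the heads for positions t+1..t+k.
    verified: tokens the base model actually predicts at those positions.
    Returns the tokens accepted in this decoding step."""
    accepted = []
    for p, v in zip(proposed, verified):
        accepted.append(v)   # the base model's own token is always valid
        if p != v:           # first mismatch invalidates the later guesses
            break
    return accepted

# Heads guess [5, 9, 2, 7]; base model predicts [5, 9, 4, 7].
# The step emits [5, 9, 4]: three tokens for a single forward pass.
result = accept_speculation([5, 9, 2, 7], [5, 9, 4, 7])
```

Even a partial match yields more than one token per pass, which is where the measured ~2x speedup comes from.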
<p align="center"> <picture> <img src="assets/medusa_speedup_cmp.jpg" width="45%"> </picture> </p>

In the updated version, we add support for full-model training, called Medusa-2 (compared to Medusa-1, which only trains the new heads), which requires a special recipe that adds the speculative prediction ability while keeping the original model's performance.
We also add support for self-distillation, which allows us to add Medusa to any fine-tuned LLM without requiring the availability of the original training data.
## Contents
- Introduction
- Contents
- Installation
- Citation
- Codebase Guide
- Community Adoption
- Contributing
- Acknowledgements
## Installation

### Method 1: With pip (may not be the latest version)

```bash
pip install medusa-llm
```

### Method 2: From the source (recommended)

```bash
git clone https://github.com/FasterDecoding/Medusa.git
cd Medusa
pip install -e .
```
## Model Weights

### Medusa-1

| Size | Chat Command | Hugging Face Repo |
| ---- | ------------ | ----------------- |
| 7B | `python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-7b-v1.3` | [FasterDecoding/medusa-vicuna-7b-v1.3](https://huggingface.co/FasterDecoding/medusa-vicuna-7b-v1.3) |
| 13B | `python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-13b-v1.3` | [FasterDecoding/medusa-vicuna-13b-v1.3](https://huggingface.co/FasterDecoding/medusa-vicuna-13b-v1.3) |
| 33B | `python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-33b-v1.3` | [FasterDecoding/medusa-vicuna-33b-v1.3](https://huggingface.co/FasterDecoding/medusa-vicuna-33b-v1.3) |
### Medusa-2

| Model | Chat Command | Hugging Face Repo |
| ----- | ------------ | ----------------- |
| Zephyr-7B-Beta | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-zephyr-7b-beta` | [FasterDecoding/medusa-1.0-zephyr-7b-beta](https://huggingface.co/FasterDecoding/medusa-1.0-zephyr-7b-beta) |
| Vicuna-7B-v1.5 | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-7b-v1.5` | [FasterDecoding/medusa-1.0-vicuna-7b-v1.5](https://huggingface.co/FasterDecoding/medusa-1.0-vicuna-7b-v1.5) |
| Vicuna-13B-v1.5 | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-13b-v1.5` | [FasterDecoding/medusa-1.0-vicuna-13b-v1.5](https://huggingface.co/FasterDecoding/medusa-1.0-vicuna-13b-v1.5) |
| Vicuna-33B-v1.5 | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-33b-v1.5` | [FasterDecoding/medusa-1.0-vicuna-33b-v1.5](https://huggingface.co/FasterDecoding/medusa-1.0-vicuna-33b-v1.5) |
## Inference
We currently support single-GPU inference with a batch size of 1, which is the most common setup for local model hosting. We are actively working to extend Medusa's capabilities by integrating it into other inference frameworks; please don't hesitate to reach out if you are interested in contributing to this effort.
You can use the following command to launch a CLI interface:
```bash
CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli --model [path of medusa model]
```
You can also pass `--load-in-8bit` or `--load-in-4bit` to load the base model in quantized format. If you downloaded the base model elsewhere, you can override the base model name or path with `--base-model [path of base model]`.
## Training
In the updated version, we use the amazing axolotl library to manage the training process. Please refer to our fork for the training code. The major code modifications are in `src/axolotl/utils/models.py`. The training configs can be found in `examples/medusa`. A typical training command is as follows:
```bash
accelerate launch -m axolotl.cli.train examples/medusa/your_config.yml
```
The data preparation code for self-distillation can be found in the `data_generation` folder of this repo. For other datasets, you can directly download the data from the corresponding Hugging Face dataset repo.
### Training on various architectures
The following instructions are for the initial release of Medusa; they provide a minimal example of how to train a Medusa-1 model. For the updated version, please refer to the previous section.
For training, please install:

```bash
pip install -e ".[train]"
```
#### Prepare the data
We take a public version of the ShareGPT dataset, which is a subset of the Vicuna training data. For other models, you can use the corresponding training dataset.
```bash
git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered
```
Remark: If you haven't installed git-lfs, please install it before cloning:

```bash
git lfs install
```
#### Adapt the data to the model you want to enable Medusa on
Start by launching an inference server of your choice to run the model you want to train on. Let's use `mistralai/Mistral-7B-Instruct-v0.2` as an example.
For instance, you can use [text-generation-inference](https://github.com/huggingface/text-generation-inference), which you can also use after you've trained the Medusa heads.
```bash
model=mistralai/Mistral-7B-Instruct-v0.2
volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run

docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --max-input-length 4000 --max-total-tokens 4096 --max-batch-prefill-tokens 4000
```
Some of the ShareGPT sequences are relatively long, so make sure your server can run inference on them. If there is not enough room, the script will simply skip those long conversations. This shouldn't impact downstream performance much, but more data is always better. You can use various trade-offs to speed up inference, but the defaults should be good enough in most cases.
```bash
python create_data.py --input-filename ShareGPT_Vicuna_unfilte
```
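Conceptually, this adaptation step replaces each assistant turn with output from the locally served target model, so the Medusa heads learn that model's distribution. The following is a hypothetical sketch of that idea, not the repo's actual `create_data.py`: the prompt format and helper names are assumptions, while the `/generate` endpoint and its `inputs`/`parameters` payload follow text-generation-inference's REST API.

```python
# Hypothetical sketch of the data-adaptation step (names are illustrative).
import json
import urllib.request

def build_prompt(turns):
    """Flatten the conversation so far into a single prompt string.
    (A real script would use the model's chat template instead.)"""
    return "\n".join(f"{t['from']}: {t['value']}" for t in turns)

def regenerate(conversation, url="http://localhost:8080/generate"):
    """Re-answer each assistant ('gpt') turn with the served model."""
    history = []
    for turn in conversation:
        if turn["from"] == "gpt":
            payload = json.dumps({
                "inputs": build_prompt(history),
                "parameters": {"max_new_tokens": 512},
            }).encode()
            req = urllib.request.Request(
                url, data=payload, headers={"Content-Type": "application/json"}
            )
            with urllib.request.urlopen(req) as resp:
                generated = json.loads(resp.read())["generated_text"]
            turn = dict(turn, value=generated)  # swap in the model's own answer
        history.append(turn)
    return history
```

Running this over every conversation yields a self-distilled dataset in the same ShareGPT schema, ready for Medusa training.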
