SkillAgentSearch skills...

Whisfusion

Whisfusion: Parallel ASR Decoding via a Diffusion Transformer

Install / Use

/learn @taeyoun811/Whisfusion
About this skill

Quality Score

0/100

Supported Platforms

Universal

README

Whisfusion: Parallel ASR Decoding via a Diffusion Transformer

arXiv Hugging Face

Official implementation of Whisfusion - the first Diffusion Transformer ASR framework that fuses a Whisper encoder with a diffusion decoder for faster, non-autoregressive transcription.

<p align="center"> <img src="assets/inference.gif" width="80%"> </p>

🏗️ Architecture

Whisfusion combines three key components:

  • Whisper-small Encoder (88.2M params): Pre-trained acoustic feature extractor from OpenAI
  • SMDM-170M Decoder: Masked diffusion model for parallel text generation
  • Cross-Attention Adapter (42.5M params): Lightweight bridge between modalities

🎯 Key Features

  • Parallel decoding architecture with 14× higher throughput than autoregressive models (3,180 vs 230 tokens/s)
  • Superior accuracy - lower WER than Whisper-tiny (8.3% vs 9.7%) with comparable latency
  • Scalable inference - constant time regardless of sequence length (up to 2.6× faster on long audio >20s)

🚀 Quick Start

Installation

# Clone the repository
git clone https://github.com/taeyoun811/Whisfusion.git
cd Whisfusion

# Install PyTorch (CUDA 12.1)
pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121

# Install FlashAttention (required for efficient attention)
git clone --recurse-submodules --branch v2.6.3 https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
pip install .
cd csrc/rotary && pip install .
cd ../layer_norm && pip install .
cd ../xentropy && pip install .
cd ../../.. && rm -rf flash-attention

# Install other dependencies
pip install -r requirements.txt

Download Pre-trained Models

from huggingface_hub import hf_hub_download

# Download Stage 2 model (full model with decoder fine-tuning)
hf_hub_download(
    repo_id="taeyoun811/whisfusion",
    filename="whisfusion_stage2_decoder.pt",
    local_dir="out/",
)

# Download Stage 1 adapter (optional, adapter-only checkpoint)
hf_hub_download(
    repo_id="taeyoun811/whisfusion",
    filename="whisfusion_stage1_adapter.pt",
    local_dir="out/",
)

Then use the downloaded models with the evaluation scripts in src/evaluation/.

💻 Training Requirements

  • GPUs: 4× NVIDIA A100 (40GB)
  • Storage: ~700GB for full LibriSpeech 960h dataset and preprocessed features

📚 Training & Evaluation

Data Preparation

# 1. Download LibriSpeech (~60GB)
bash scripts/00_download_librispeech.sh

# 2. Preprocess audio to Whisper features
bash scripts/01_preprocess_librispeech.sh

# 3. Download pre-trained SMDM weights
bash scripts/02_download_pretrained_model.sh

Two-Stage Training

# Stage 1: Train adapter only (frozen encoder/decoder)
bash scripts/03_train_stage1_adapter.sh

# Stage 2: Fine-tune full decoder (update --pretrain_path to Stage 1 checkpoint)
bash scripts/04_train_stage2_decoder_high_ratio.sh

Evaluation

# Evaluate Whisfusion (update --adapter_path to Stage 2 checkpoint)
bash scripts/05_evaluate_whisfusion.sh

# Compare with Whisper baselines
bash scripts/06_evaluate_whisper.sh

Results will be saved in evaluation_results/ directory as JSON files.

📁 Project Structure

Whisfusion/
├── scripts/                   # Training and evaluation scripts
├── src/
│   ├── data/                  # Data preprocessing and dataset
│   ├── training/              # Training scripts for both stages
│   ├── evaluation/            # Evaluation and inference code
│   └── lit_gpt/               # Model architecture
├── data/                      # Dataset storage
│   ├── raw/                   # Original LibriSpeech files
│   └── processed/             # Preprocessed Whisper features
├── out/                       # Training outputs and downloaded models
└── pretrained_models/         # Pre-trained (SMDM-170M) checkpoints

📖 Citation

If you find this work useful, please cite:

@article{kwon2025whisfusion,
  title={Whisfusion: Parallel ASR Decoding via a Diffusion Transformer},
  author={Kwon, Taeyoun and Ahn, Junhyuk and Yun, Taegeun and Jwa, Heeju and Choi, Yoonchae and Park, Siwon and Kim, Nam-Joon and Kim, Jangchan and Ryu, Hyun Gon and Lee, Hyuk-Jae},
  journal={arXiv preprint arXiv:2508.07048},
  year={2025}
}

🙏 Acknowledgments

This project builds upon the following excellent works:

SMDM

Whisper

  • Repository: openai/whisper
  • We use the pre-trained Whisper encoder for acoustic feature extraction

📧 Contact

For questions, discussions, or contributions:

View on GitHub
GitHub Stars23
CategoryDevelopment
Updated4d ago
Forks3

Languages

Python

Security Score

95/100

Audited on Apr 4, 2026

No findings