Block Diffusion: Interpolating Between Autoregressive and Diffusion Language Models (ICLR 2025 Oral)
By Marianne Arriola, Aaron Gokaslan, Justin T Chiu, Zhihan Yang, Zhixuan Qi, Jiaqi Han, Subham Sekhar Sahoo, Volodymyr Kuleshov
We introduce BD3-LMs, a family of Block Discrete Denoising Diffusion Language Models that achieve state-of-the-art likelihoods among diffusion models and enable generation of arbitrary-length sequences. BD3-LMs combine the strengths of autoregressive and diffusion language models by decomposing a token sequence into blocks, modeling the blocks autoregressively, and performing discrete diffusion within each block. By tuning the block size, we interpolate between autoregressive and diffusion models, which introduces a trade-off between quality and sample efficiency. We propose a recipe for building effective BD3-LMs that includes an efficient training algorithm, estimators of gradient variance, and data-driven noise schedules that minimize this variance.
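As a compact sketch of the decomposition described above (the notation here is ours, not taken from the codebase): split a sequence x of L tokens into B = L / L' contiguous blocks x^1, ..., x^B of L' tokens each. The likelihood then factorizes autoregressively over blocks, and each block conditional is the part modeled with discrete diffusion:

```latex
\log p_\theta(x) \;=\; \sum_{b=1}^{B} \log p_\theta\!\left(x^{b} \,\middle|\, x^{<b}\right),
\qquad B = L / L',
\qquad x^{b} = \left(x_{(b-1)L'+1}, \dots, x_{bL'}\right)
```

With L' = 1 every block is a single token and the factorization reduces to the usual autoregressive chain rule; with L' = L there is a single block, i.e., full-sequence diffusion. Block sizes in between trade off between the two regimes.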
In this repo, we provide:
- The BD3-LM framework
- Block-autoregressive likelihood parameterization
- Data-driven noise schedules to reduce training variance
- Arbitrary-length discrete diffusion samplers
- Baseline implementations
Code Organization
- main.py: Routines for training and evaluation
- noise_schedule.py: Noise schedules
- diffusion.py: Forward/reverse diffusion
- dataloader.py: Dataloaders
- utils.py: LR scheduler, logging, fsspec handling
- models/: Network architectures. Supports DiT and AR transformer
- configs/: Config files for datasets/models/noise schedules/LR schedules
- scripts/: Shell scripts for training/evaluation
  - train/: Training scripts (LM1B, OWT)
  - ppl/: Likelihood evaluation on the pretraining set (LM1B, OWT)
  - zs_ppl/: Zero-shot likelihood evaluation on GPT2 benchmark datasets
  - gen_ppl/: Sample quality (generative perplexity under GPT2)
  - var_len/: Arbitrary-length sequence generation
- ssd-lm/: SSD-LM codebase
  - run_generate_text_batch.sh: Generates SSD-LM samples
  - report_genppl.py: Reports generative perplexity of SSD-LM samples
Getting Started
To get started, create a conda environment containing the required dependencies.
conda create --name bd3lm python=3.9
conda activate bd3lm
pip install -r requirements.txt
BD3-LMs do not require FlashAttention, but evaluating the MDLM baselines requires flash-attn==2.5.6.
Create the following directories to store saved models and Slurm logs:
mkdir outputs watch_folder logs sample_logs
Then launch training as a Slurm batch job:
sbatch scripts/train/train_owt_bd3lm.sh
Checkpoints
We have uploaded BD3-LMs trained on OpenWebText using block sizes 4, 8, and 16 for 1M training steps to HuggingFace 🤗: kuleshov-group/bd3-lms. BD3-LMs are fine-tuned from an MDLM checkpoint trained on OpenWebText for 850K gradient updates. We release this pretraining checkpoint on HuggingFace: kuleshov-group/bd3lm-owt-block_size1024-pretrain.
The MDLM baseline is also available on HuggingFace: kuleshov-group/mdlm-owt. The AR and SEDD baselines trained on OpenWebText are available in this Google Drive folder.
For arbitrary-length sequence generation, we compare with AR, SEDD, MDLM (for which arbitrary-length generation is supported only as an inference-time technique, with no corresponding training objective), and SSD-LM. To generate sequences longer than the training context size (fixed at 1024 tokens for OWT), we retrained AR and MDLM from Sahoo et al. without artificially injecting BOS/EOS tokens into the context. We also provide these checkpoints on HuggingFace: kuleshov-group/mdlm-noeos-owt, kuleshov-group/sedd-noeos-owt, kuleshov-group/ar-noeos-owt.
Reproducing Experiments
Below, we describe the steps required for reproducing the experiments in the paper.
Throughout, the main entry point for running experiments is the main.py script.
We also provide sample Slurm scripts for launching pre-training and downstream fine-tuning experiments in the scripts/ directory.
Generate Arbitrary-Length Sequences
To generate arbitrary-length sequences, set mode=sample_eval. Example scripts are provided in scripts/var_len/var_len*.sh. Here's an example script using BD3-LM:
HuggingFace model
BLOCK_SIZE=4 # 4, 8, 16
LENGTH=2048 # arbitrary; needs to be a multiple of the block size
python -u main.py \
loader.eval_batch_size=1 \
model=small \
algo=bd3lm \
algo.T=5000 \
algo.backbone=hf_dit \
data=openwebtext-split \
model.length=$LENGTH \
block_size=$BLOCK_SIZE \
wandb=null \
mode=sample_eval \
eval.checkpoint_path=kuleshov-group/bd3lm-owt-block_size${BLOCK_SIZE} \
model.attn_backend=sdpa \
sampling.nucleus_p=0.9 \
sampling.kv_cache=true \
sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize${BLOCK_SIZE}
Local checkpoint
BLOCK_SIZE=4 # 4, 8, 16
LENGTH=2048 # arbitrary; needs to be a multiple of the block size
python -u main.py \
loader.eval_batch_size=1 \
model=small \
algo=bd3lm \
algo.T=5000 \
data=openwebtext-split \
model.length=$LENGTH \
block_size=$BLOCK_SIZE \
wandb=null \
mode=sample_eval \
eval.checkpoint_path=/path/to/checkpoint/bd3lm-owt-block_size${BLOCK_SIZE} \
model.attn_backend=sdpa \
sampling.nucleus_p=0.9 \
sampling.kv_cache=true \
sampling.logdir=$PWD/sample_logs/samples_genlen_bd3lm_blocksize${BLOCK_SIZE}
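Because the sequence is decomposed into blocks, LENGTH must be a multiple of BLOCK_SIZE (and, for training, BLOCK_SIZE must divide the 1024-token context). The helper below is a hypothetical pre-flight check, not part of the repository, that rejects a bad combination before a job is submitted and prints the resulting number of blocks.

```shell
#!/usr/bin/env bash
# Hypothetical pre-flight helper (not part of the BD3-LM repo): verifies that
# LENGTH is a positive multiple of BLOCK_SIZE and prints the number of blocks.
check_block_layout() {
  local length=$1 block_size=$2
  if (( length <= 0 || block_size <= 0 )); then
    echo "error: LENGTH and BLOCK_SIZE must be positive" >&2
    return 1
  fi
  if (( length % block_size != 0 )); then
    echo "error: LENGTH=$length is not a multiple of BLOCK_SIZE=$block_size" >&2
    return 1
  fi
  # Number of blocks the sequence is split into.
  echo $(( length / block_size ))
}

# Same values as the sampling example above.
BLOCK_SIZE=4
LENGTH=2048
check_block_layout "$LENGTH" "$BLOCK_SIZE"   # prints 512
```

Running the same check with LENGTH=1024 before a training job confirms that the chosen block size divides the context length.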
Likelihood Evaluation
To compute test perplexity, use mode=ppl_eval. Example scripts are provided in scripts/ppl/eval_owt_*.sh. Here's an example evaluation script on OpenWebText:
BLOCK_SIZE=4 # 4, 8, 16
python -u main.py \
loader.eval_batch_size=16 \
model=small \
algo=bd3lm \
algo.backbone=hf_dit \
data=openwebtext-split \
data.insert_valid_special=False \
model.length=1024 \
model.attn_backend=flex \
block_size=${BLOCK_SIZE} \
eval.checkpoint_path=kuleshov-group/bd3lm-owt-block_size${BLOCK_SIZE} \
wandb=null \
mode=ppl_eval > logs/bd3lm_owt_block_size${BLOCK_SIZE}.log
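To evaluate all three released checkpoints in one go, the loop below reuses exactly the flags of the evaluation command above and varies only BLOCK_SIZE. It is a convenience sketch (the helper name and the DRY_RUN switch are our own), not one of the provided scripts/ppl/ scripts. By default it only prints the commands so they can be inspected first.

```shell
#!/usr/bin/env bash
# Sketch: evaluate the released BD3-LM checkpoints (block sizes 4, 8, 16) with
# the same flags as the command above. Commands are only printed by default;
# run with DRY_RUN=0 to actually launch them.
DRY_RUN=${DRY_RUN:-1}

# Hypothetical helper: fills the global array CMD with the ppl_eval command
# for one block size.
build_ppl_eval_cmd() {
  local block_size=$1
  CMD=(python -u main.py
    loader.eval_batch_size=16
    model=small
    algo=bd3lm
    algo.backbone=hf_dit
    data=openwebtext-split
    data.insert_valid_special=False
    model.length=1024
    model.attn_backend=flex
    "block_size=${block_size}"
    "eval.checkpoint_path=kuleshov-group/bd3lm-owt-block_size${block_size}"
    wandb=null
    mode=ppl_eval)
}

for BLOCK_SIZE in 4 8 16; do
  build_ppl_eval_cmd "$BLOCK_SIZE"
  LOG="logs/bd3lm_owt_block_size${BLOCK_SIZE}.log"
  if [[ "$DRY_RUN" == 1 ]]; then
    echo "${CMD[*]} > $LOG"
  else
    mkdir -p logs
    "${CMD[@]}" > "$LOG"
  fi
done
```

Keeping the command in a bash array preserves each Hydra override as a single argument, so the printed and executed commands are identical.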
Training Pipeline
To train BD3-LMs, use mode=train (default mode). Example scripts are provided in scripts/train/train_owt*.sh. Here's an example training script on OpenWebText:
BLOCK_SIZE=4 # we recommend 4, 8, or 16. must be a factor of the context length
PRETRAIN_CKPT=kuleshov-group/bd3lm-owt-block_size1024-pretrain # to train from scratch, set to null
python -u main.py \
loader.global_batch_size=512 \
loader.eval_global_batch_size=512 \
loader.batch_size=16 \
loader.eval_batch_size=16 \
model=small \
algo=bd3lm \
algo.clip_search_widths=[0.5,0.6,0.7,0.8,0.9] \
data=openwebtext-split \
model.length=1024 \
block_size=$BLOCK_SIZE \
wandb.name=bd3lm-owt-block_size${BLOCK_SIZE} \
mode=train \
model.attn_backend=flex \
training.resample=True \
training.from_pretrained=$PRETRAIN_CKPT
The arguments loader.batch_size and loader.eval_batch_size control the batch size per GPU. If loader.batch_size * num_gpus is less than loader.global_batch_size, PyTorch Lightning uses gradient accumulation to reach the global batch size. You can also launch a training job on Slurm using the command: sbatch scripts/train/train_owt_bd3lm.sh.
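For a fixed global batch, the number of accumulation steps follows from the per-GPU batch size and the GPU count. The function below is a hypothetical helper illustrating that arithmetic. It assumes num_gpus counts every GPU across all nodes and that the global batch divides evenly; it does not read the repo's configs, which may round differently.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the repo): gradient-accumulation steps needed
# so that per_gpu_batch * num_gpus * steps == global_batch.
# Assumes num_gpus is the total GPU count across all nodes.
accum_steps() {
  local global_batch=$1 per_gpu_batch=$2 num_gpus=$3
  # Samples processed per forward/backward pass across all GPUs.
  local per_step=$(( per_gpu_batch * num_gpus ))
  if (( per_step <= 0 )); then
    echo "error: batch size and GPU count must be positive" >&2
    return 1
  fi
  if (( global_batch % per_step != 0 )); then
    echo "error: global batch $global_batch is not divisible by $per_step" >&2
    return 1
  fi
  echo $(( global_batch / per_step ))
}

# loader.global_batch_size=512 and loader.batch_size=16, as in the training command above:
accum_steps 512 16 8    # 8 GPUs  -> prints 4
accum_steps 512 16 4    # 4 GPUs  -> prints 8
accum_steps 512 16 32   # 32 GPUs -> prints 1 (no accumulation needed)
```

The GPU counts in the usage lines are illustrative only; the point is that halving the number of GPUs doubles the accumulation steps while the effective batch stays at 512.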
Acknowledgements
This repository was built off of MDLM and SEDD.
Citation
@inproceedings{
arriola2025block,
title={Block Diffusion: Interpolating Between Autoregressive and Diffusion Language Models},
author={Marianne Arriola and Aaron Gokaslan and Justin T Chiu and Zhihan Yang and Zhixuan Qi and Jiaqi Han and Subham Sekhar Sahoo and Volodymyr Kuleshov},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025}
}
