DisBack
The official implementation of Distribution Backtracking Distillation for One-step Diffusion Models
🔥DisBack (ICLR 2025)🔥
Distribution Backtracking Builds A Faster Convergence Trajectory for One-step Diffusion Distillation
by Shengyuan Zhang<sup>1</sup>, Ling Yang<sup>2</sup>, Zejian Li*<sup>1</sup>, An Zhao<sup>1</sup>, Chenye Meng<sup>1</sup>, Changyuan Yang<sup>3</sup>, Guang Yang<sup>3</sup>, Zhiyuan Yang<sup>3</sup>, Lingyun Sun<sup>1</sup>
<sup>1</sup>Zhejiang University <sup>2</sup>Peking University <sup>3</sup>Alibaba Group
Abstract
Accelerating the sampling speed of diffusion models remains a significant challenge. Recent score distillation methods distill a heavy teacher model into an efficient student generator, which is optimized by computing the difference between the two score functions on samples generated by the student. However, a score mismatch arises in the early stage of distillation, because existing methods mainly use the endpoint of pre-trained diffusion models as the teacher, overlooking the importance of the convergence trajectory between the one-step generator and the teacher model. To address this issue, we extend the score distillation process with the entire convergence trajectory of teacher models and propose Distribution Backtracking Distillation (DisBack) for distilling one-step generators. DisBack is composed of two stages: Degradation Recording and Distribution Backtracking. Degradation Recording obtains the convergence trajectory of teacher models by recording the degradation path from the trained teacher model to the untrained initial student. The degradation path implicitly represents the intermediate distributions of teacher models. Distribution Backtracking then trains a student generator to backtrack through the intermediate distributions, approximating the convergence trajectory of teacher models. Extensive experiments show that DisBack achieves faster and better convergence than existing distillation methods and accomplishes comparable generation performance. Notably, DisBack is easy to implement and can be generalized to existing distillation methods to boost performance.
The structure of DisBack

Samples of DisBack

Using DisBack
Environment setup
conda create -n disback python=3.8 -y
conda activate disback
pip install --upgrade anyio
pip install -r requirements.txt
python setup.py develop
Inference
The distilled SDXL model is available on Hugging Face.
One-step text-to-image generation
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
from huggingface_hub import hf_hub_download
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
repo_name = "SYZhang0805/DisBack"
ckpt_name = "SDXL_DisBack.bin"
unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cuda"))
pipe = DiffusionPipeline.from_pretrained(base_model_id, unet=unet, torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
prompt = "A photo of a dog."
image = pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[399], height=1024, width=1024).images[0]
image.save("output.png", "PNG")
Training and Evaluation
In the text-to-image setting, DisBack is trained on top of DMD2. The pre-trained DMD2 model can be downloaded from here.
Download Base Diffusion Models and Training Data
export CHECKPOINT_PATH="" # change this to your own checkpoint folder (this should be a central directory shared across nodes)
export WANDB_ENTITY="" # change this to your own wandb entity
export WANDB_PROJECT="" # change this to your own wandb project
export MASTER_IP="" # change this to your own master ip
# On our system, the following line was necessary for the accelerate package to work (we are not sure why).
# Change YOUR_MASTER_IP/YOUR_MASTER_NODE_NAME to the correct value
echo "YOUR_MASTER_IP YOUR_MASTER_NODE_NAME" | sudo tee -a /etc/hosts
# Create the FSDP configs for accelerate launch. Change EXP_NAME to your own experiment name.
python main/sdxl/create_sdxl_fsdp_configs.py --folder fsdp_configs/EXP_NAME --master_ip $MASTER_IP --num_machines 8 --sharding_strategy 4
mkdir $CHECKPOINT_PATH
mkdir $CHECKPOINT_PATH/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode_checkpoint_model_024000/
bash scripts/download_sdxl.sh $CHECKPOINT_PATH
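Before launching, it can help to confirm that the variables exported above are actually set. The following is a minimal Python sketch; the helper name `missing_training_vars` is hypothetical and not part of the repository, and the variable list simply mirrors the export lines above.

```python
import os

# Variable names taken from the export lines above.
REQUIRED_VARS = ("CHECKPOINT_PATH", "WANDB_ENTITY", "WANDB_PROJECT", "MASTER_IP")


def missing_training_vars(env=None):
    """Return the required variables that are unset or empty (hypothetical helper)."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name, "").strip()]


if __name__ == "__main__":
    missing = missing_training_vars()
    if missing:
        print("Unset or empty:", ", ".join(missing))
    else:
        print("All training variables are set.")
```

Running the script in the same shell that will launch training reports any variable that was left as an empty string.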
Degradation recording stage
bash experiments/sdxl/degradation_sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode.sh
The degradation path is saved as follows:
path_directory/
│
├── checkpoint_model_024100/
│ └── pytorch_model_1.bin
├── checkpoint_model_024200/
│ └── pytorch_model_1.bin
...
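To use the recorded path in the backtracking stage, the checkpoints have to be read in step order and then reversed. The following is a minimal sketch, assuming the directories follow the `checkpoint_model_<step>` pattern shown above and each contains a `pytorch_model_1.bin`. The function name `list_degradation_path` is hypothetical, and the repository's own scripts may load the path differently.

```python
import re
from pathlib import Path

# Directory naming assumed from the layout shown above.
CKPT_DIR_PATTERN = re.compile(r"^checkpoint_model_(\d+)$")


def list_degradation_path(path_directory, weight_name="pytorch_model_1.bin"):
    """Return checkpoint weight files sorted by training step (hypothetical helper)."""
    found = []
    for child in Path(path_directory).iterdir():
        match = CKPT_DIR_PATTERN.match(child.name)
        weight_file = child / weight_name
        # Keep only matching directories that actually contain the weight file.
        if match and child.is_dir() and weight_file.is_file():
            found.append((int(match.group(1)), weight_file))
    found.sort(key=lambda item: item[0])
    return [weight_file for _, weight_file in found]
```

Reversing the returned list (`list_degradation_path(...)[::-1]`) gives the order in which the checkpoints would serve as targets during backtracking.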
Distribution backtracking stage
bash experiments/sdxl/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode.sh
Compatibility with other models
DisBack can be applied to the score distillation process using the following pseudocode.
# Degradation
s_theta = UNet()  # Pre-trained diffusion model.
s_theta_prime, G_stu = s_theta.clone(), s_theta.clone()  # Initialize the generator and the start of the degradation path.
path_degradation = []
for idx in range(num_iter_1st_stage):
    x_0 = one_step_sample(G_stu)
    x_t, t, epsilon = addnoise(x_0)
    ckpt = train_score_model(s_theta_prime, x_t, t, epsilon)  # Training strategy depends on the type of pre-trained model used. Eq.(7) in the paper.
    if idx % interval_1st == 0:
        path_degradation.append(ckpt)  # Add an intermediate checkpoint to the degradation path.
else:
    path_degradation.append(ckpt)  # for-else: once the loop finishes, also record the final checkpoint.

# Backtracking
path_backtracking = path_degradation[::-1]  # The reverse of the degradation path is viewed as the convergence trajectory.
s_phi = path_backtracking[0].clone()  # Use the first checkpoint of the convergence trajectory as the initial s_phi.
target = 1
for idx in range(num_iter_2nd_stage):
    s_target = path_backtracking[target]
    x_0 = one_step_sample(G_stu)  # One-step generation.
    x_t, t, epsilon = addnoise(x_0)
    x_t.backward(s_phi(x_t, t) - s_target(x_t, t))  # VSD loss. Eq.(8) in the paper.
    update(G_stu)  # Optimize G by gradient descent.
    train_score_model(s_phi, x_t, t, epsilon)  # Eq.(5) in the paper.
    if idx % interval_2nd == 0 and idx > 1:  # Switch the target.
        target += 1
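The target-switching rule of the backtracking loop can be isolated and checked on its own. This is a sketch under stated assumptions: the switch is taken to fire every `interval_2nd` iterations (a modulo test), and the index is clamped to the last checkpoint once the path is exhausted, which the pseudocode does not specify. The name `target_schedule` is hypothetical.

```python
def target_schedule(num_iter_2nd_stage, interval_2nd, num_checkpoints):
    """Yield the path_backtracking index used as s_target at each iteration.

    Replays the switching rule from the pseudocode. Clamping to the last
    checkpoint is an added assumption, not something the pseudocode states.
    """
    target = 1
    for idx in range(num_iter_2nd_stage):
        yield target  # Index read at the start of iteration idx.
        if idx % interval_2nd == 0 and idx > 1:
            target = min(target + 1, num_checkpoints - 1)


if __name__ == "__main__":
    print(list(target_schedule(10, 3, 4)))  # [1, 1, 1, 1, 2, 2, 2, 3, 3, 3]
```

Printing the schedule for a short run makes it easy to see how many generator updates each intermediate target receives before the student moves on to the next one.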
Citation
If you find our paper useful or relevant to your research, please cite it:
@article{zhang2024distributionbacktrackingbuildsfaster,
title={Distribution Backtracking Builds A Faster Convergence Trajectory for One-step Diffusion Distillation},
author={Shengyuan Zhang and Ling Yang and Zejian Li and An Zhao and Chenye Meng and Changyuan Yang and Guang Yang and Zhiyuan Yang and Lingyun Sun},
journal={arXiv 2408.15991},
year={2024}
}
Credits
DisBack builds heavily on the following open-source projects:
DMD2: Improved Distribution Matching Distillation for Fast Image Synthesis
Diff-Instruct: Diff-Instruct: A Universal Approach for Transferring Knowledge From Pre-trained Diffusion Models
ScoreGAN: Unifying GANs and Score-Based Diffusion as Generative Particle Models
Thanks to the maintainers of these projects for their contribution to this project!
