GDDIM
[ICLR'23 Spotlight] gDDIM: analyze and accelerate general diffusion models, isotropic and non-isotropic
<p align="center">gDDIM: Generalized denoising diffusion implicit models</p>
<div align="center"> <a href="https://qsh-zh.github.io/" target="_blank">Qinsheng Zhang</a>   <b>·</b>   <a href="https://mtao8.math.gatech.edu/" target="_blank">Molei Tao</a>   <b>·</b>   <a href="https://yongxin.ae.gatech.edu/" target="_blank">Yongxin Chen</a> <br> <br> <a href="https://arxiv.org/abs/2206.05564" target="_blank">Paper</a>   </div> <br><br>TLDR: We uncover the mechanism behind the acceleration of DDIMs, based on a Dirac approximation, and generalize it to general diffusion models, both isotropic and non-isotropic.
<!-- When applied to the critically-damped Langevin diffusion model, it achieves an FID score of 2.26 on CIFAR10 with 50 steps. -->

Setup
The codebase has only been tested in a Docker environment.
Docker
- The Dockerfile lists the steps and packages needed to set up the training / testing environments.
- We provide a Docker image on Docker Hub.
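A minimal sketch of building and entering the environment locally. It assumes the Dockerfile sits in the current directory and that the host has the NVIDIA Container Toolkit installed. The tag `gddim-dev` and the `/workspace` mount point are hypothetical choices for illustration, not names published by the authors. By default the script only prints the commands; set `DRY_RUN=0` to run them.

```shell
#!/bin/sh
set -eu

# DRY_RUN=1 (default in this sketch) prints the docker commands instead of running them.
DRY_RUN="${DRY_RUN:-1}"
# Hypothetical local image tag; not the name of the published Docker Hub image.
IMAGE="${IMAGE:-gddim-dev}"

# Print the command in dry-run mode, otherwise execute it.
run() {
  if [ "$DRY_RUN" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Build the image from the Dockerfile in the current directory (assumed location).
run docker build -t "$IMAGE" .

# Start an interactive container with GPU access and the repository mounted.
# --gpus all requires the NVIDIA Container Toolkit on the host.
run docker run --gpus all -it --rm -v "$PWD:/workspace" -w /workspace "$IMAGE" bash
```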
Reproduce results
gDDIM on CLD
Training on cifar10
cd ${gDDIM_PROJECT_FOLDER}/cld_jax
wandb login ${WANDB_KEY}
python main.py --config configs/accr_dcifar10_config.py --mode train --workdir logs/accr_dcifar_nomixed --wandb --config.seed=8
- I tried seeds 1, 8, and 123. seed=8 (checkpoint 15) gives the best FID on CIFAR10; the lowest FIDs from the other two seeds are slightly higher (around 2.30).
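The three seeds above can be run back to back with a small loop. This is a sketch, not a script from the repository: the `_seed<N>` workdir suffix is an assumption added so that runs do not overwrite each other (the training command above writes to `logs/accr_dcifar_nomixed`). Run it from `cld_jax`; it prints the commands unless `DRY_RUN=0` is set.

```shell
#!/bin/sh
set -eu

# DRY_RUN=1 (default in this sketch) prints each command instead of launching training.
DRY_RUN="${DRY_RUN:-1}"

# train_one SEED: launch (or, in dry-run mode, print) CLD training for one seed.
# Assumption: the per-seed workdir suffix is illustrative, not from the repository.
train_one() {
  seed=$1
  set -- python main.py --config configs/accr_dcifar10_config.py \
    --mode train --workdir "logs/accr_dcifar_nomixed_seed${seed}" \
    --wandb "--config.seed=${seed}"
  if [ "$DRY_RUN" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Seeds reported in the README.
for seed in 1 8 123; do
  train_one "$seed"
done
```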
Eval on cifar10
- Download CIFAR stats to ${gDDIM_PROJECT_FOLDER}/cld_jax/assets/stats/.
- We provide a pretrained model checkpoint. It reaches an FID of 2.2565 on my machine with 50 NFE.
- Users can evaluate FID via
cd ${gDDIM_PROJECT_FOLDER}/cld_jax
python main.py --config configs/accr_dcifar10_config.py --mode check --result_folder logs/fid --ckpt ${CLD_BEST_PATH} --config.sampling.deis_order=2 --config.sampling.nfe=50
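Evaluation depends on two inputs from the steps above: the downloaded CIFAR stats and the checkpoint. A small pre-flight check can catch a missing file before the run starts. The helper below is a hypothetical convenience, not part of the repository; it only tests that the stats directory is non-empty and that the checkpoint path exists.

```shell
#!/bin/sh
set -eu

# check_eval_inputs STATS_DIR CKPT_PATH
# Hypothetical helper: returns 0 when STATS_DIR is a non-empty directory
# and CKPT_PATH exists (file or directory); prints the problem otherwise.
check_eval_inputs() {
  stats_dir=$1
  ckpt=$2
  if [ ! -d "$stats_dir" ] || [ -z "$(ls -A "$stats_dir")" ]; then
    echo "missing CIFAR stats in: $stats_dir" >&2
    return 1
  fi
  if [ ! -e "$ckpt" ]; then
    echo "missing checkpoint: $ckpt" >&2
    return 1
  fi
  return 0
}
```

For example, `check_eval_inputs "${gDDIM_PROJECT_FOLDER}/cld_jax/assets/stats" "${CLD_BEST_PATH}"` can be chained with `&&` in front of the evaluation command.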
Blur diffusion model
Training on cifar10
cd ${gDDIM_PROJECT_FOLDER}/blur_jax
wandb login ${WANDB_KEY}
python main.py --config configs/ddpm_deep_cifar10_config.py --mode train --workdir logs/ddpm_deep_sigma${sigma}_seed${seed} --wandb --config.model.sigma_blur_max=${sigma} --config.seed=${seed}
Eval on cifar10
- Download CIFAR stats to ${gDDIM_PROJECT_FOLDER}/blur_jax/assets/stats/.
- We provide a pretrained model checkpoint.
- Users can evaluate FID via
cd ${gDDIM_PROJECT_FOLDER}/blur_jax
python main.py --config configs/accr_dcifar10_config.py --mode check --result_folder logs/fid --ckpt ${CLD_BEST_PATH} --config.sampling.deis_order=2 --config.sampling.nfe=50
Reference
@misc{zhang2022gddim,
title={gDDIM: Generalized denoising diffusion implicit models},
author={Qinsheng Zhang and Molei Tao and Yongxin Chen},
year={2022},
eprint={2206.05564},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
Related works
@inproceedings{song2020denoising,
title={Denoising diffusion implicit models},
author={Song, Jiaming and Meng, Chenlin and Ermon, Stefano},
booktitle={International Conference on Learning Representations (ICLR)},
year={2021}
}
@inproceedings{dockhorn2022score,
title={Score-Based Generative Modeling with Critically-Damped Langevin Diffusion},
author={Tim Dockhorn and Arash Vahdat and Karsten Kreis},
booktitle={International Conference on Learning Representations (ICLR)},
year={2022}
}
@article{hoogeboom2022blurring,
title={Blurring diffusion models},
author={Hoogeboom, Emiel and Salimans, Tim},
journal={arXiv preprint arXiv:2209.05557},
year={2022}
}
Miscellaneous
The project is built upon score-sde, developed by Yang Song. Additionally, the sampling code has been adapted from DEIS.
