
Ddae

[ICCV 2023 Oral] Official Implementation of "Denoising Diffusion Autoencoders are Unified Self-supervised Learners"

🆕 [2025] Please check out the more recent study DDAE++ continuing this line of work.

Denoising Diffusion Autoencoders (DDAE)

<p align="center"> <img src="https://github.com/FutureXiang/ddae/assets/33350017/b0825947-e58f-4c5e-b672-ec59465ac14d" width="480"> </p>

This is a multi-GPU PyTorch implementation of the paper Denoising Diffusion Autoencoders are Unified Self-supervised Learners:

@inproceedings{ddae2023,
  title={Denoising Diffusion Autoencoders are Unified Self-supervised Learners},
  author={Xiang, Weilai and Yang, Hongyu and Huang, Di and Wang, Yunhong},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year={2023}
}

:star: (News) Our paper is cited by Kaiming He's new paper Deconstructing Denoising Diffusion Models for Self-Supervised Learning. Check it out! :fire:

Overview

This repo contains:

  • [x] Pre-training, sampling and FID evaluation code for diffusion models, including
    • Frameworks:
      • [x] DDPM & DDIM
      • [x] EDM (w/ or w/o data augmentation)
    • Networks:
      • [x] The basic 35.7M DDPM UNet
      • [x] A larger 56M DDPM++ UNet
    • Datasets:
      • [x] CIFAR-10
      • [ ] Tiny-ImageNet
  • [x] Feature quality evaluation code, including
    • [x] Linear probing and grid searching
    • [x] Contrastive metrics, i.e., alignment and uniformity
    • [ ] Fine-tuning
  • [x] Noise-conditional classifier training and evaluation, including
    • [x] MLP classifier based on DDPM/EDM features
    • [x] WideResNet with VP/VE perturbation
  • [x] Evaluation code for ImageNet-256 pre-trained DiT-XL/2 checkpoint

Requirements

  • In addition to a PyTorch environment, please install:
    conda install pyyaml
    pip install pytorch-fid ema-pytorch
    
  • We used 4 or 8 RTX 3080 Ti GPUs for all experiments presented in the paper. With automatic mixed precision (AMP) enabled, training the basic 35.7M UNet on CIFAR-10 with 4 GPUs takes about 14 hours.
  • pytorch-fid computes the FID metric from image files, so the CIFAR-10 training set must first be unpacked into 50000 .png images. Please refer to extract_cifar10_pngs.ipynb for this step.
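The notebook itself is not reproduced here. As a rough, standard-library-only sketch of what such an unpacking step involves, the code below assumes the standard "CIFAR-10 python version" archive layout: pickled files `data_batch_1` to `data_batch_5`, each with a `b"data"` entry whose rows hold 3072 uint8 values (a 32×32 red plane, then green, then blue, each in row-major order). The helper names and the `00000.png` naming scheme are hypothetical and may not match what the notebook produces.

```python
import os
import pickle
import struct
import zlib


def row_to_png_scanlines(row, size=32):
    """Convert one flat CIFAR-10 row (R plane, G plane, B plane) into raw PNG scanlines."""
    plane = size * size
    scanlines = []
    for y in range(size):
        line = bytearray([0])  # PNG filter type 0 (None) at the start of every scanline
        for x in range(size):
            i = y * size + x
            line += bytes((row[i], row[plane + i], row[2 * plane + i]))  # interleave RGB
        scanlines.append(bytes(line))
    return b"".join(scanlines)


def write_png(path, raw_scanlines, width=32, height=32):
    """Write 8-bit RGB data (scanlines already prefixed with filter bytes) as a PNG file."""
    def chunk(tag, data):
        body = tag + data
        return struct.pack(">I", len(data)) + body + struct.pack(">I", zlib.crc32(body) & 0xFFFFFFFF)

    ihdr = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)  # bit depth 8, color type 2 (RGB)
    with open(path, "wb") as f:
        f.write(b"\x89PNG\r\n\x1a\n")
        f.write(chunk(b"IHDR", ihdr))
        f.write(chunk(b"IDAT", zlib.compress(raw_scanlines)))
        f.write(chunk(b"IEND", b""))


def unpack_cifar10_train(batch_dir, out_dir):
    """Hypothetical helper: dump all training images as numbered PNG files; returns the count."""
    os.makedirs(out_dir, exist_ok=True)
    index = 0
    for k in range(1, 6):
        with open(os.path.join(batch_dir, f"data_batch_{k}"), "rb") as f:
            batch = pickle.load(f, encoding="bytes")  # unpickling the arrays requires NumPy
        for row in batch[b"data"]:
            write_png(os.path.join(out_dir, f"{index:05d}.png"), row_to_png_scanlines(row))
            index += 1
    return index
```

With the standard archive, `unpack_cifar10_train("cifar-10-batches-py", "data/cifar10-pngs")` should return 50000; the output directory here simply mirrors the path used in the FID command later in this README.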

Main results

We present generative and discriminative evaluation results obtained with this codebase. EDM_ddpmpp_aug.yaml is trained on 8 GPUs; all other models are trained on 4 GPUs.

Please note that this is an over-simplified DDPM / EDM implementation: some network details, initialization, and hyper-parameters may differ from the official ones. Please refer to their respective official codebases to reproduce the exact results reported in the paper.

<table class="tg">
<thead>
<tr> <th class="tg-uzvj" rowspan="2">Config</th> <th class="tg-uzvj" rowspan="2">Model</th> <th class="tg-uzvj" rowspan="2">Network</th> <th class="tg-7btt" colspan="3">Best linear probe checkpoint</th> <th class="tg-amwm" colspan="3">Best FID checkpoint</th> </tr>
<tr> <th class="tg-7btt">epoch</th> <th class="tg-7btt">FID</th> <th class="tg-7btt">acc (%)</th> <th class="tg-amwm">epoch</th> <th class="tg-amwm">FID</th> <th class="tg-amwm">acc (%)</th> </tr>
</thead>
<tbody>
<tr> <td class="tg-0pky">DDPM_ddpm.yaml</td> <td class="tg-0pky">DDPM</td> <td class="tg-0pky">35.7M UNet</td> <td class="tg-0pky">800</td> <td class="tg-0pky">4.09</td> <td class="tg-0pky">90.05</td> <td class="tg-0lax">1999</td> <td class="tg-0lax">3.62</td> <td class="tg-0lax">88.23</td> </tr>
<tr> <td class="tg-0pky">EDM_ddpm.yaml</td> <td class="tg-0pky">EDM</td> <td class="tg-0pky">35.7M UNet</td> <td class="tg-0pky">1200</td> <td class="tg-0pky">3.97</td> <td class="tg-0pky">90.44</td> <td class="tg-0lax">1999</td> <td class="tg-0lax">3.56</td> <td class="tg-0lax">89.71</td> </tr>
<tr> <td class="tg-0lax">DDPM_ddpmpp.yaml</td> <td class="tg-0lax">DDPM</td> <td class="tg-0lax">56.5M DDPM++</td> <td class="tg-0lax">1200</td> <td class="tg-0lax">3.08</td> <td class="tg-0lax">93.97</td> <td class="tg-0lax">1999</td> <td class="tg-0lax">2.98</td> <td class="tg-0lax">93.03</td> </tr>
<tr> <td class="tg-0lax">EDM_ddpmpp.yaml</td> <td class="tg-0lax">EDM</td> <td class="tg-0lax">56.5M DDPM++</td> <td class="tg-0lax">1200</td> <td class="tg-0lax">2.23</td> <td class="tg-0lax">94.50</td> <td class="tg-baqh" colspan="3">(same)</td> </tr>
<tr> <td class="tg-0lax">EDM_ddpmpp_aug.yaml</td> <td class="tg-0lax">EDM + data aug</td> <td class="tg-0lax">56.5M DDPM++</td> <td class="tg-0lax">2000</td> <td class="tg-0lax">2.34</td> <td class="tg-1wig">95.49</td> <td class="tg-0lax">3200</td> <td class="tg-1wig">2.12</td> <td class="tg-0lax">95.19</td> </tr>
</tbody>
</table>

FIDs are calculated using 50000 images generated by the deterministic fast sampler (DDIM 100 steps or EDM 18 steps).
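For reference, the deterministic EDM sampler discretizes noise levels with the schedule of Karras et al. (2022), σ_i = (σ_max^(1/ρ) + i/(N−1) · (σ_min^(1/ρ) − σ_max^(1/ρ)))^ρ for i = 0, …, N−1, followed by σ_N = 0, and integrates the probability-flow ODE with Heun's method. The sketch below uses the EDM paper's CIFAR-10 defaults (σ_min = 0.002, σ_max = 80, ρ = 7, N = 18); this codebase's exact values are not stated here and may differ, and `denoise(x, sigma)` is a hypothetical stand-in for the trained denoiser. Because the Heun correction is skipped on the final step to σ = 0, 18 steps cost 35 network evaluations.

```python
def edm_sigma_steps(num_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Karras et al. noise levels t_0 > ... > t_{N-1}, with t_N = 0 appended."""
    inv_rho = 1.0 / rho
    steps = [
        (sigma_max ** inv_rho + i / (num_steps - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
        for i in range(num_steps)
    ]
    return steps + [0.0]


def edm_heun_sample(denoise, x, sigmas):
    """Deterministic 2nd-order (Heun) EDM sampler; x starts at noise level sigmas[0]."""
    for t_cur, t_next in zip(sigmas[:-1], sigmas[1:]):
        d_cur = (x - denoise(x, t_cur)) / t_cur        # ODE slope dx/dsigma at t_cur
        x_next = x + (t_next - t_cur) * d_cur          # Euler step
        if t_next > 0:                                 # Heun correction, except on the last step
            d_next = (x_next - denoise(x_next, t_next)) / t_next
            x_next = x + (t_next - t_cur) * 0.5 * (d_cur + d_next)
        x = x_next
    return x
```

Scalars are used here only to keep the sketch dependency-free; the same arithmetic applies elementwise to image tensors.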

Latent-space DiT

We evaluate pre-trained Transformer-based diffusion networks (DiT) from the perspective of transfer learning. Please refer to the ddae/DiT subfolder.

Usage

Diffusion pre-training

To train a DDAE model and generate 50000 image samples with 4 GPUs, for example, run:

# diffusion pre-training with AMP enabled
python -m torch.distributed.launch --nproc_per_node=4 train.py --config config/DDPM_ddpm.yaml --use_amp

# deterministic fast sampling (i.e. DDIM 100 steps / EDM 18 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400

# stochastic sampling (i.e. DDPM 1000 steps)
python -m torch.distributed.launch --nproc_per_node=4 sample.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --mode DDPM
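The deterministic DDIM mode (η = 0) can be summarized in a few lines: at each sub-sampled timestep, the network's noise prediction gives an estimate of x₀, which is then re-noised to the next, less noisy timestep using the same predicted noise. This sketch assumes the linear β schedule from DDPM (β from 1e-4 to 0.02 over T = 1000 steps) and evenly spaced timesteps; the schedule and spacing actually used by `sample.py` may differ, and `eps_model(x, t)` is a hypothetical stand-in for the trained noise-prediction network.

```python
import math


def linear_alpha_bar(T=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule."""
    alpha_bar, prod = [], 1.0
    for t in range(T):
        beta = beta_start + (beta_end - beta_start) * t / (T - 1)
        prod *= 1.0 - beta
        alpha_bar.append(prod)
    return alpha_bar


def ddim_timesteps(T=1000, num_steps=100):
    """Evenly spaced timesteps from most to least noisy: 990, 980, ..., 0."""
    return list(range(0, T, T // num_steps))[::-1]


def ddim_sample(eps_model, x, alpha_bar, timesteps):
    """Deterministic DDIM (eta = 0); x is the noisy sample at timesteps[0]."""
    for i, t in enumerate(timesteps):
        a_t = alpha_bar[t]
        # alpha_bar of the next (less noisy) timestep; 1.0 after the final step
        a_prev = alpha_bar[timesteps[i + 1]] if i + 1 < len(timesteps) else 1.0
        eps = eps_model(x, t)
        x0_pred = (x - math.sqrt(1.0 - a_t) * eps) / math.sqrt(a_t)  # predicted clean sample
        x = math.sqrt(a_prev) * x0_pred + math.sqrt(1.0 - a_prev) * eps
    return x
```

Since no fresh noise is injected, the same starting noise always maps to the same sample, which is what makes this mode suitable for reproducible FID evaluation.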

To calculate the FID between the generated samples and the CIFAR-10 training set, for example, run:

python -m pytorch_fid   data/cifar10-pngs/  output_DDPM_ddpm/EMAgenerated_ep400_ddim_steps100_eta0.0/pngs/
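pytorch-fid summarizes each image folder by the mean and covariance of its Inception-v3 pool features and reports the Fréchet distance between the two Gaussians, ‖μ₁ − μ₂‖² + Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^(1/2)). The NumPy sketch below computes only that final formula from precomputed statistics. It is not pytorch-fid's own code (which uses `scipy.linalg.sqrtm`); here the trace of the matrix square root is taken from the eigenvalues of Σ₁^(1/2) Σ₂ Σ₁^(1/2), which has the same spectrum as Σ₁Σ₂.

```python
import numpy as np


def _psd_sqrt(mat):
    """Symmetric square root of a positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative eigenvalues caused by round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.asarray(mu1, dtype=float), np.asarray(mu2, dtype=float)
    sigma1, sigma2 = np.asarray(sigma1, dtype=float), np.asarray(sigma2, dtype=float)
    root1 = _psd_sqrt(sigma1)
    # Tr((S1 S2)^{1/2}) equals the sum of sqrt-eigenvalues of S1^{1/2} S2 S1^{1/2}
    middle = root1 @ sigma2 @ root1
    eigvals = np.clip(np.linalg.eigvalsh((middle + middle.T) / 2.0), 0.0, None)
    tr_covmean = np.sqrt(eigvals).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```

Identical statistics give a distance of 0; with 2048-dimensional features, the covariance estimates need many samples, which is one reason the evaluation above uses 50000 generated images.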

Features produced by DDAE

To evaluate the features produced by pre-trained DDAE, for example, run:

# grid search over (layer, noise level) combinations
python -m torch.distributed.launch --nproc_per_node=4 linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400 --grid

# linear probing with the layer-noise combination specified in the config file
python -m torch.distributed.launch --nproc_per_node=4 linear.py --config config/DDPM_ddpm.yaml --use_amp --epoch 400

# alignment and uniformity metrics across different checkpoints
python -m torch.distributed.launch --nproc_per_node=4 contrastive.py --config config/DDPM_ddpm.yaml --use_amp
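The alignment and uniformity metrics follow Wang & Isola (2020), whose align_uniform code is listed in the acknowledgments. On L2-normalized features, alignment is the mean of ‖f(x) − f(y)‖^α over positive pairs (α = 2 by default), and uniformity is log E[exp(−t‖f(a) − f(b)‖²)] over all distinct pairs (t = 2 by default). The pure-Python sketch below implements just those two definitions; how `contrastive.py` forms positive pairs (for example, two augmented views of one image) is an assumption left to the caller.

```python
import math


def l2_normalize(v):
    """Project a feature vector onto the unit hypersphere."""
    norm = math.sqrt(sum(a * a for a in v))
    return [a / norm for a in v]


def sq_dist(u, v):
    """Squared Euclidean distance between two vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def alignment(feats_x, feats_y, alpha=2):
    """Mean ||f(x) - f(y)||^alpha over positive pairs (row i of x pairs with row i of y)."""
    dists = [math.sqrt(sq_dist(u, v)) ** alpha for u, v in zip(feats_x, feats_y)]
    return sum(dists) / len(dists)


def uniformity(feats, t=2):
    """Log of the mean Gaussian potential exp(-t ||f(a) - f(b)||^2) over all distinct pairs."""
    n = len(feats)
    potentials = [
        math.exp(-t * sq_dist(feats[i], feats[j]))
        for i in range(n) for j in range(i + 1, n)
    ]
    return math.log(sum(potentials) / len(potentials))
```

Lower is better for both metrics: small alignment means positive pairs map close together, and low uniformity means features spread out over the hypersphere.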

Noise-conditional classifier

To train WideResNet-based classifiers from scratch:

# VP (DDPM) perturbation
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_WRN.py --mode DDPM
# VE (EDM) perturbation
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_WRN.py --mode EDM

Then compare their noise-conditional recognition rates with those of DDAE-based MLP classifier heads:

# using DDPM DDAE encoder
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_DDAE.py --config config/DDPM_ddpm.yaml --use_amp --epoch 1999
# using EDM DDAE encoder
python -m torch.distributed.launch --nproc_per_node=4 noisy_classifier_DDAE.py --config config/EDM_ddpmpp.yaml --use_amp --epoch 1200
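The two perturbation types differ only in how noise is mixed into the input. Under VP (DDPM) corruption the classifier sees x_t = √ᾱ_t · x₀ + √(1 − ᾱ_t) · ε, and under VE (EDM) corruption it sees x = x₀ + σ · ε. Dividing a VP sample by √ᾱ_t yields a VE sample with σ = √((1 − ᾱ_t)/ᾱ_t), which is one way to line up noise levels between the two settings. The sketch below uses plain Python lists as stand-ins for image tensors and does not reproduce the scripts' actual noise sampling or schedules.

```python
import math


def vp_perturb(x0, alpha_bar_t, eps):
    """Variance-preserving (DDPM) corruption: sqrt(a) * x0 + sqrt(1 - a) * eps."""
    a = alpha_bar_t
    return [math.sqrt(a) * p + math.sqrt(1.0 - a) * e for p, e in zip(x0, eps)]


def ve_perturb(x0, sigma, eps):
    """Variance-exploding (EDM) corruption: x0 + sigma * eps."""
    return [p + sigma * e for p, e in zip(x0, eps)]


def vp_to_ve_sigma(alpha_bar_t):
    """Noise level sigma at which a VP sample rescaled by 1/sqrt(alpha_bar_t) matches a VE sample."""
    return math.sqrt((1.0 - alpha_bar_t) / alpha_bar_t)
```

For example, ᾱ_t = 0.5 corresponds to σ = 1; as ᾱ_t approaches 0 (the noisiest VP steps), the equivalent σ grows without bound, which matches the large σ_max used by VE models.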

Acknowledgments

This repository is built on numerous open-source codebases such as DDPM, DDPM-pytorch, DDIM, EDM, Score-based SDE, DiT, and align_uniform.
