DiGIT

[NeurIPS 2024] Stabilize the Latent Space for Image Autoregressive Modeling: A Unified Perspective



Overview

[Figure: overview of DiGIT]

We present DiGIT, an autoregressive generative model that performs next-token prediction in an abstract latent space derived from self-supervised learning (SSL) models. By applying K-Means clustering to the hidden states of the DINOv2 model, we obtain a novel discrete tokenizer. This approach significantly improves image generation on the ImageNet dataset, achieving FID scores of 4.59 for class-unconditional and 3.39 for class-conditional generation. The model also improves image understanding, reaching a linear-probe top-1 accuracy of 80.3%.
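The tokenizer described above turns each SSL feature vector into the index of its nearest K-Means centroid. The sketch below illustrates that assignment step with NumPy, assuming features of shape (num_patches, dim) and a centroid array of shape (num_clusters, dim), such as one loaded from the K-Means .npy files this README provides. The function name `tokenize_features` is hypothetical and not part of the DiGIT codebase.

```python
import numpy as np


def tokenize_features(features: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    """Map each feature vector to the index of its nearest centroid (squared L2).

    features:  (num_patches, dim) array, e.g. SSL hidden states of one image.
    centroids: (num_clusters, dim) array of K-Means centroids.
    Returns an int64 array of shape (num_patches,) holding discrete token ids.
    """
    # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2. The ||x||^2 term is the same for
    # every centroid in a row, so it can be dropped before taking the argmin.
    cross = features @ centroids.T                 # (num_patches, num_clusters)
    c_sq = np.sum(centroids ** 2, axis=1)          # (num_clusters,)
    dist = c_sq[None, :] - 2.0 * cross
    return np.argmin(dist, axis=1).astype(np.int64)


if __name__ == "__main__":
    centroids = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
    features = np.array([[1.0, 1.0], [9.0, 1.0], [-1.0, 8.0], [0.2, -0.3]])
    print(tokenize_features(features, centroids))  # [0 1 2 0]
```

In practice `centroids` would come from `np.load` on a centroid file and `features` from the DINOv2 hidden states of one image; both are assumptions about the data layout here.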

Experimental Results

Linear-Probe Accuracy on ImageNet

| Methods | # Tokens | Feature Dim. | # Params | Top-1 Acc. (%) $\uparrow$ |
|---------|----------|--------------|----------|---------------------------|
| iGPT-L | 32 $\times$ 32 | 1536 | 1362M | 60.3 |
| iGPT-XL | 64 $\times$ 64 | 3072 | 6801M | 68.7 |
| VIM+VQGAN | 32 $\times$ 32 | 1024 | 650M | 61.8 |
| VIM+dVAE | 32 $\times$ 32 | 1024 | 650M | 63.8 |
| VIM+ViT-VQGAN | 32 $\times$ 32 | 1024 | 650M | 65.1 |
| VIM+ViT-VQGAN | 32 $\times$ 32 | 2048 | 1697M | 73.2 |
| AIM | 16 $\times$ 16 | 1536 | 0.6B | 70.5 |
| DiGIT (Ours) | 16 $\times$ 16 | 1024 | 219M | 71.7 |
| DiGIT (Ours) | 16 $\times$ 16 | 1536 | 732M | 80.3 |

Class-Unconditional Image Generation on ImageNet (Resolution: 256 $\times$ 256)

| Type | Methods | # Params | # Epochs | FID $\downarrow$ | IS $\uparrow$ |
|------|---------|----------|----------|------------------|---------------|
| GAN | BigGAN | 70M | - | 38.6 | 24.70 |
| Diff. | LDM | 395M | - | 39.1 | 22.83 |
| Diff. | ADM | 554M | - | 26.2 | 39.70 |
| MIM | MAGE | 200M | 1600 | 11.1 | 81.17 |
| MIM | MAGE | 463M | 1600 | 9.10 | 105.1 |
| MIM | MaskGIT | 227M | 300 | 20.7 | 42.08 |
| MIM | DiGIT (+MaskGIT) | 219M | 200 | 9.04 | 75.04 |
| AR | VQGAN | 214M | 200 | 24.38 | 30.93 |
| AR | DiGIT (+VQGAN) | 219M | 400 | 9.13 | 73.85 |
| AR | DiGIT (+VQGAN) | 732M | 200 | 4.59 | 141.29 |

Class-Conditional Image Generation on ImageNet (Resolution: 256 $\times$ 256)

| Type | Methods | # Params | # Epochs | FID $\downarrow$ | IS $\uparrow$ |
|------|---------|----------|----------|------------------|---------------|
| GAN | BigGAN | 160M | - | 6.95 | 198.2 |
| Diff. | ADM | 554M | - | 10.94 | 101.0 |
| Diff. | LDM-4 | 400M | - | 10.56 | 103.5 |
| Diff. | DiT-XL/2 | 675M | - | 9.62 | 121.50 |
| Diff. | L-DiT-7B | 7B | - | 6.09 | 153.32 |
| MIM | CQR-Trans | 371M | 300 | 5.45 | 172.6 |
| MIM+AR | VAR | 310M | 200 | 4.64 | - |
| MIM+AR | VAR | 310M | 200 | 3.60* | 257.5* |
| MIM+AR | VAR | 600M | 250 | 2.95* | 306.1* |
| MIM | MAGVIT-v2 | 307M | 1080 | 3.65 | 200.5 |
| AR | VQVAE-2 | 13.5B | - | 31.11 | 45 |
| AR | RQ-Trans | 480M | - | 15.72 | 86.8 |
| AR | RQ-Trans | 3.8B | - | 7.55 | 134.0 |
| AR | ViTVQGAN | 650M | 360 | 11.20 | 97.2 |
| AR | ViTVQGAN | 1.7B | 360 | 5.3 | 149.9 |
| MIM | MaskGIT | 227M | 300 | 6.18 | 182.1 |
| MIM | DiGIT (+MaskGIT) | 219M | 200 | 4.62 | 146.19 |
| AR | VQGAN | 227M | 300 | 18.65 | 80.4 |
| AR | DiGIT (+VQGAN) | 219M | 400 | 4.79 | 142.87 |
| AR | DiGIT (+VQGAN) | 732M | 200 | 3.39 | 205.96 |

*: VAR results marked with * use classifier-free guidance; all other results are reported without guidance.

Checkpoints

The K-Means npy file and model checkpoints can be downloaded from:

| Model | Link |
|:----------:|:-----:|
| HF weights🤗 | Huggingface |

We use DINOv2-base for the base model and DINOv2-large for the large model. The VQGAN is the same one used by MAGE.

DiGIT
├── data/
│   └── ILSVRC2012
│       ├── dinov2_base_short_224_l3
│       │   └── km_8k.npy
│       └── dinov2_large_short_224_l3
│           └── km_16k.npy
├── outputs/
│   ├── base_8k_stage1
│   └── ...
└── models/
    ├── vqgan_jax_strongaug.ckpt
    ├── dinov2_vitb14_reg4_pretrain.pth
    └── dinov2_vitl14_reg4_pretrain.pth
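Before training, it can help to confirm that the files in the layout above are in place. The sketch below lists which of those paths are missing under a given repository root. The helper name `missing_paths` is hypothetical, and the outputs/ entries are left out on the assumption that they are created during training; a base-size run presumably needs only the base-size files.

```python
from pathlib import Path

# Relative paths copied from the directory layout above (outputs/ is created by training).
EXPECTED_FILES = [
    "data/ILSVRC2012/dinov2_base_short_224_l3/km_8k.npy",
    "data/ILSVRC2012/dinov2_large_short_224_l3/km_16k.npy",
    "models/vqgan_jax_strongaug.ckpt",
    "models/dinov2_vitb14_reg4_pretrain.pth",
    "models/dinov2_vitl14_reg4_pretrain.pth",
]


def missing_paths(root: str, expected=EXPECTED_FILES) -> list:
    """Return the expected relative paths that do not exist under `root`."""
    base = Path(root)
    return [rel for rel in expected if not (base / rel).exists()]


if __name__ == "__main__":
    missing = missing_paths(".")  # run from the DiGIT repository root
    if missing:
        print("Missing files:")
        for rel in missing:
            print("  " + rel)
    else:
        print("All expected files found.")
```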

Preparation

Installation

  1. Download the code:
git clone https://github.com/DAMO-NLP-SG/DiGIT.git
cd DiGIT
  2. Install fairseq via pip install fairseq.

Dataset Preparation

Download the ImageNet dataset and place it in your dataset directory, $PATH_TO_YOUR_WORKSPACE/dataset/ILSVRC2012.

Tokenizer

Extract SSL features and save them as .npy files, then compute the centroids by running K-Means with faiss. Alternatively, you can use our pre-trained centroids available on Hugging Face.

bash preprocess/run.sh
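The centroid computation above uses faiss K-Means. To show what that step does, the sketch below swaps in a plain NumPy implementation of Lloyd's algorithm; it is not the repository's code. It uses a deterministic farthest-first initialization so the toy example is reproducible. The function name `fit_kmeans` and the iteration count are assumptions, and for ImageNet-scale features with thousands of clusters faiss remains the practical choice.

```python
import numpy as np


def fit_kmeans(x: np.ndarray, k: int, n_iter: int = 20) -> np.ndarray:
    """Lloyd's K-Means on the rows of x (n, d); returns centroids of shape (k, d)."""
    x = x.astype(np.float64)
    # Farthest-first initialisation: start from the first row, then repeatedly
    # add the point farthest from every centroid chosen so far.
    chosen = [x[0]]
    min_dist = np.sum((x - x[0]) ** 2, axis=1)
    for _ in range(1, k):
        nxt = x[np.argmax(min_dist)]
        chosen.append(nxt)
        min_dist = np.minimum(min_dist, np.sum((x - nxt) ** 2, axis=1))
    centroids = np.stack(chosen)

    for _ in range(n_iter):
        # Assignment step: nearest centroid under squared L2 distance.
        dist = np.sum(centroids ** 2, axis=1)[None, :] - 2.0 * x @ centroids.T
        labels = np.argmin(dist, axis=1)
        # Update step: move each centroid to the mean of its assigned points.
        new = centroids.copy()
        for j in range(k):
            members = x[labels == j]
            if len(members) > 0:  # keep the previous centroid if a cluster empties
                new[j] = members.mean(axis=0)
        if np.allclose(new, centroids):
            break  # converged
        centroids = new
    return centroids


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Toy stand-in for extracted SSL features: three well-separated 2-D blobs.
    blobs = np.concatenate(
        [rng.normal(loc=c, scale=0.1, size=(100, 2)) for c in (0.0, 5.0, 10.0)]
    )
    centroids = fit_kmeans(blobs, k=3)
    print(centroids.shape)  # (3, 2)
```

The resulting array can be written with np.save to a .npy file, the same format as the provided centroid files.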

Training Scripts

Step1

Train a GPT model on the tokens produced by the discriminative tokenizer. The training script is scripts/train_stage1_ar.sh and the hyperparameters are in config/stage1/dino_base.yaml. For class-conditional generation, see scripts/train_stage1_classcond.sh.
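Stage 1 is next-token prediction over discriminative token ids. The sketch below shows how one flattened 16 × 16 token grid could be turned into input/target pairs, with a start token that doubles as the class token for class-conditional training. The raster flattening order, the start/class token id scheme, the codebook size of 8192, and the helper name `make_ar_pair` are all assumptions for illustration, not taken from the DiGIT scripts.

```python
from typing import Optional

import numpy as np


def make_ar_pair(tokens, vocab_size: int, class_id: Optional[int] = None):
    """Build (input, target) sequences for next-token prediction.

    tokens:     (seq_len,) discriminative token ids in [0, vocab_size).
    vocab_size: number of K-Means clusters (codebook size).
    class_id:   optional ImageNet label for class-conditional training.

    A start token is prepended (hypothetical id scheme): id vocab_size for
    unconditional training, or vocab_size + 1 + class_id when a label is given.
    """
    seq = np.asarray(tokens, dtype=np.int64)
    start = vocab_size if class_id is None else vocab_size + 1 + class_id
    seq = np.concatenate([np.array([start], dtype=np.int64), seq])
    # Shift by one: the model reads seq[:-1] and is trained to predict seq[1:].
    return seq[:-1], seq[1:]


if __name__ == "__main__":
    # Stand-in for one image's 16x16 grid of stage-1 token ids, flattened in raster order.
    grid = np.arange(16 * 16).reshape(16, 16)
    inp, tgt = make_ar_pair(grid.reshape(-1), vocab_size=8192, class_id=207)
    print(inp.shape, tgt.shape)                # (256,) (256,)
    print(inp[:3].tolist(), tgt[:3].tolist())  # [8400, 0, 1] [0, 1, 2]
```

With this layout the targets are exactly the image tokens, so every token, including the first, is predicted from a preceding context.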

Step2

Train a pixel decoder, either an autoregressive (AR) or a non-autoregressive (NAR) model, conditioned on the discriminative tokens. The AR training script is scripts/train_stage2_ar.sh and the NAR training script is scripts/train_stage2_nar.sh.
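One way to condition an autoregressive pixel decoder on stage-1 tokens is to place them as a prefix and compute the loss only on the VQ tokens that follow. The sketch below builds such a sequence and its loss mask. The prefix layout, the id offset that keeps the two vocabularies disjoint, and the helper name `make_stage2_sequence` are hypothetical and may differ from what scripts/train_stage2_ar.sh does.

```python
import numpy as np


def make_stage2_sequence(disc_tokens, vq_tokens, disc_vocab: int):
    """Concatenate conditioning and target tokens for a prefix-conditioned AR decoder.

    disc_tokens: (n,) discriminative token ids in [0, disc_vocab).
    vq_tokens:   (m,) VQGAN token ids; shifted by disc_vocab so both vocabularies
                 can share one embedding table without colliding (hypothetical layout).
    Returns (inputs, targets, loss_mask), each of length n + m - 1.
    """
    disc = np.asarray(disc_tokens, dtype=np.int64)
    vq = np.asarray(vq_tokens, dtype=np.int64) + disc_vocab
    seq = np.concatenate([disc, vq])
    inputs, targets = seq[:-1], seq[1:]
    # targets[i] is seq[i + 1]; only positions whose target is a VQ token count
    # towards the loss, because the discriminative prefix is given, not predicted.
    loss_mask = np.arange(1, len(seq)) >= len(disc)
    return inputs, targets, loss_mask


if __name__ == "__main__":
    inp, tgt, mask = make_stage2_sequence([4, 7], [1, 2, 3], disc_vocab=10)
    print(inp.tolist())   # [4, 7, 11, 12]
    print(tgt.tolist())   # [7, 11, 12, 13]
    print(mask.tolist())  # [False, True, True, True]
```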

Checkpoints are saved to outputs/EXP_NAME/checkpoints, TensorBoard log files to outputs/EXP_NAME/tb, and training logs to outputs/EXP_NAME/train.log.

You can monitor the training process using tensorboard --logdir=outputs/EXP_NAME/tb.

Sampling Scripts

First, sample discriminative tokens with scripts/infer_stage1_ar.sh. We recommend topk=200 for the base model and topk=400 for the large model.
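The topk setting restricts each sampling step to the k highest-scoring tokens. The sketch below shows generic top-k sampling over a single step's logits with NumPy; it illustrates the effect of the parameter and is not the sampler in scripts/infer_stage1_ar.sh. The function name `sample_top_k`, the temperature argument, and the 8192-entry vocabulary in the usage example are assumptions.

```python
import numpy as np


def sample_top_k(logits, k: int, rng: np.random.Generator, temperature: float = 1.0) -> int:
    """Sample one token id from the k highest-scoring entries of `logits`."""
    logits = np.asarray(logits, dtype=np.float64) / temperature
    k = min(k, logits.shape[-1])
    # Indices of the k largest logits (in no particular order); all others are excluded.
    top_idx = np.argpartition(logits, -k)[-k:]
    top_logits = logits[top_idx]
    # Softmax over the retained logits only (subtract the max for numerical stability).
    probs = np.exp(top_logits - top_logits.max())
    probs /= probs.sum()
    return int(rng.choice(top_idx, p=probs))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    step_logits = rng.normal(size=8192)  # stand-in for one step of model output
    token = sample_top_k(step_logits, k=200, rng=rng)  # base-size setting from above
    print(token)
```

A larger k keeps more of the tail of the distribution, which trades some per-token fidelity for diversity.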

Then run scripts/infer_stage2_ar.sh to sample VQ tokens based on the previously sampled discriminative tokens.

Generated tokens and synthesized images will be stored in a directory named outputs/EXP_NAME/results.

FID and IS evaluation

Prepare the ImageNet validation set for FID evaluation:

python prepare_imgnet_val.py --data_path $PATH_TO_YOUR_WORKSPACE/dataset/ILSVRC2012 --output_dir imagenet-val
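FID compares the Gaussian statistics (mean μ and covariance Σ) of Inception features for real and generated images: FID = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). The sketch below computes this distance from precomputed statistics using only NumPy, to show the formula rather than to replace the evaluation tool. Feature extraction with an Inception network is omitted, and the function names are hypothetical.

```python
import numpy as np


def _sqrtm_psd(a: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    w, v = np.linalg.eigh(a)
    w = np.clip(w, 0.0, None)  # clip tiny negative eigenvalues caused by round-off
    return (v * np.sqrt(w)) @ v.T


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2})."""
    mu1, mu2 = np.asarray(mu1, float), np.asarray(mu2, float)
    sigma1, sigma2 = np.asarray(sigma1, float), np.asarray(sigma2, float)
    diff = mu1 - mu2
    # Tr((sigma1 sigma2)^{1/2}) equals Tr((s1 sigma2 s1)^{1/2}) with s1 = sigma1^{1/2};
    # the latter matrix is symmetric PSD, so eigh can be used throughout.
    s1 = _sqrtm_psd(sigma1)
    m = s1 @ sigma2 @ s1
    m = (m + m.T) / 2.0  # remove numerical asymmetry before eigh
    tr_covmean = np.trace(_sqrtm_psd(m))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    real = rng.normal(size=(2000, 4))           # stand-ins for Inception features
    fake = rng.normal(loc=0.5, size=(2000, 4))
    fid = frechet_distance(real.mean(0), np.cov(real, rowvar=False),
                           fake.mean(0), np.cov(fake, rowvar=False))
    print(round(fid, 3))
```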

Install the evaluation tool by running:
