LSGM
The Official PyTorch Implementation of "LSGM: Score-based Generative Modeling in Latent Space" (NeurIPS 2021)
<div align="center"> <a href="http://latentspace.cc/arash_vahdat/" target="_blank">Arash Vahdat*</a>   <b>·</b>   <a href="https://karstenkreis.github.io/" target="_blank">Karsten Kreis*</a>   <b>·</b>   <a href="http://jankautz.com/" target="_blank">Jan Kautz</a> <br> <br> (*equal contribution) <br> <br> <a href="https://nvlabs.github.io/LSGM/" target="_blank">Project Page</a> </div> <br> <br>LSGM trains a score-based generative model (a.k.a. a denoising diffusion model) in the latent space of a variational autoencoder. It currently achieves state-of-the-art generative performance on several image datasets.
<p align="center"> <img src="img/LSGM.png" width="1200"> </p>
Requirements
LSGM is built in Python 3.8 using PyTorch 1.8.0. Please use the following command to install the requirements:
pip install -r requirements.txt
Optionally, you can also install NVIDIA Apex. When Apex is installed, our training scripts use the Adam optimizer from this library, which is faster than PyTorch's native Adam.
Set up file paths and data
This work builds on top of our previous work NVAE. Please follow the instructions in the
NVAE repository to prepare your data for training and evaluation. Small datasets such as CIFAR-10, MNIST, and OMNIGLOT
do not require any data preparation as they will be downloaded automatically. Below, $DATA_DIR indicates
the path to a data directory that will contain all the datasets.
$CHECKPOINT_DIR is a directory used for storing checkpoints, and $EXPR_ID is a unique ID for the experiment.
$IP_ADDR is the IP address of the machine that will host the process with rank 0 during training
(see here).
$NODE_RANK is the index of each node among all the nodes that are running the job
(setting $IP_ADDR and $NODE_RANK is only required for multi-node training).
$FID_STATS_DIR is a directory containing the FID statistics computed on each dataset (see below).
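Since the commands below depend on these variables, it can help to confirm they are set before launching a long job. The following is a minimal sketch of such a check. The helper `missing_variables` is hypothetical and not part of this repository; only the variable names come from the description above.

```python
import os

# Variable names come from the text above; this helper is a hypothetical
# convenience and is not part of the LSGM repository.
SINGLE_NODE_VARS = ["DATA_DIR", "CHECKPOINT_DIR", "EXPR_ID", "FID_STATS_DIR"]
MULTI_NODE_ONLY = ["IP_ADDR", "NODE_RANK"]


def missing_variables(env, multi_node=False):
    """Return the names of expected variables that are unset or empty in env."""
    required = SINGLE_NODE_VARS + (MULTI_NODE_ONLY if multi_node else [])
    return [name for name in required if not env.get(name)]


if __name__ == "__main__":
    missing = missing_variables(os.environ, multi_node=False)
    if missing:
        print("Unset variables:", ", ".join(missing))
    else:
        print("All single-node variables are set.")
```

Pass `multi_node=True` to also check $IP_ADDR and $NODE_RANK, which are only needed for multi-node training.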
Precomputing feature statistics on each dataset for FID evaluation
You can use the following command to compute FID statistics on the CIFAR-10 dataset as an example:
python scripts/precompute_fid_statistics.py --data $DATA_DIR/cifar10 --dataset cifar10 --fid_dir $FID_STATS_DIR
which will save the FID-related statistics in a directory under $FID_STATS_DIR. For other datasets, simply change
--data and --dataset accordingly.
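As an illustration of changing --data and --dataset together, the sketch below builds the same command from a dataset name. The function `fid_stats_command` is hypothetical, and it assumes the data lives in a subdirectory named after the dataset, as in the CIFAR-10 example; other datasets may use a different layout, and the accepted --dataset values are defined by the script itself.

```python
import shlex


def fid_stats_command(dataset, data_dir, fid_dir):
    """Build the argument list for precomputing FID statistics on one dataset.

    Assumes the data directory is data_dir/dataset, mirroring the CIFAR-10
    example; this layout is an assumption for other datasets.
    """
    return [
        "python", "scripts/precompute_fid_statistics.py",
        "--data", f"{data_dir}/{dataset}",
        "--dataset", dataset,
        "--fid_dir", fid_dir,
    ]


if __name__ == "__main__":
    # Prints the CIFAR-10 command; the list can also be passed to subprocess.run.
    cmd = fid_stats_command("cifar10", "/path/to/data", "/path/to/fid_stats")
    print(shlex.join(cmd))
```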
Training and evaluation
Training LSGM is often done in two stages. In the first stage, we train our VAE backbone assuming that the prior is
a standard Normal distribution. In the second stage, we swap the standard Normal prior with a score-based prior and
we jointly train both the VAE backbone and the score-based prior in an end-to-end fashion. Please check Appendix G in
our paper for implementation details. Below, we provide commands used for both stages. If for any reason your training
is stopped, use the exact same command with the addition of --cont_training to continue training from the last
saved checkpoint. If you observe NaN, continuing the training using this flag will usually not fix the NaN issue.
We train 3 different VAEs with the following commands (see Table 7 in the paper).
- 20 group NVAE with full KL annealing for the "balanced" model (using 8 16GB V100 GPUs):
python train_vae.py --data $DATA_DIR/cifar10 --root $CHECKPOINT_DIR --save $EXPR_ID/vae1 --dataset cifar10 \
--num_channels_enc 128 --num_channels_dec 128 --num_postprocess_cells 2 --num_preprocess_cells 2 \
--num_latent_scales 1 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 --num_preprocess_blocks 1 \
--num_postprocess_blocks 1 --num_latent_per_group 9 --num_groups_per_scale 20 --epochs 600 --batch_size 32 \
--weight_decay_norm 1e-2 --num_nf 0 --kl_anneal_portion 0.5 --kl_max_coeff 1.0 --channel_mult 1 2 --seed 1 \
--arch_instance res_bnswish --num_process_per_node 8 --use_se
- 20 group NVAE with partial KL annealing for the model with best FID (using 8 16GB V100 GPUs):
python train_vae.py --data $DATA_DIR/cifar10 --root $CHECKPOINT_DIR --save $EXPR_ID/vae2 --dataset cifar10 \
--num_channels_enc 128 --num_channels_dec 128 --num_postprocess_cells 2 --num_preprocess_cells 2 \
--num_latent_scales 1 --num_cell_per_cond_enc 2 --num_cell_per_cond_dec 2 --num_preprocess_blocks 1 \
--num_postprocess_blocks 1 --num_latent_per_group 9 --num_groups_per_scale 20 --epochs 400 --batch_size 32 \
--weight_decay_norm 1e-2 --num_nf 0 --kl_anneal_portion 1.0 --kl_max_coeff 0.7 --channel_mult 1 2 --seed 1 \
--arch_instance res_bnswish --num_process_per_node 8 --use_se
- 4 group NVAE with partial KL annealing for the model with best NLL (using 4 16GB V100 GPUs):
python train_vae.py --data $DATA_DIR/cifar10 --root $CHECKPOINT_DIR --save $EXPR_ID/vae3 --dataset cifar10 \
--num_channels_enc 256 --num_channels_dec 256 --num_postprocess_cells 3 --num_preprocess_cells 3 \
--num_latent_scales 1 --num_cell_per_cond_enc 3 --num_cell_per_cond_dec 3 --num_preprocess_blocks 1 \
--num_postprocess_blocks 1 --num_latent_per_group 45 --num_groups_per_scale 4 --epochs 400 --batch_size 64 \
--weight_decay_norm 1e-2 --num_nf 2 --kl_anneal_portion 1.0 --kl_max_coeff 0.7 --channel_mult 1 2 --seed 1 \
--arch_instance res_bnswish --num_process_per_node 4 --use_se
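The "full" and "partial" KL annealing settings above differ in --kl_anneal_portion and --kl_max_coeff. As a rough illustration, the sketch below assumes a linear warm-up of the KL weight from 0 to --kl_max_coeff over the first --kl_anneal_portion of training. The function `kl_coeff` is hypothetical, and the schedule in train_vae.py may differ in details (for example, an initial constant phase or a nonzero minimum coefficient).

```python
def kl_coeff(step, total_steps, kl_anneal_portion, kl_max_coeff):
    """KL weight at a training step under an assumed linear warm-up."""
    anneal_steps = kl_anneal_portion * total_steps
    progress = min(step / anneal_steps, 1.0)  # fraction of the warm-up completed
    return kl_max_coeff * progress


if __name__ == "__main__":
    total = 1000  # illustrative number of training steps
    for step in (0, 250, 500, 1000):
        full = kl_coeff(step, total, kl_anneal_portion=0.5, kl_max_coeff=1.0)
        partial = kl_coeff(step, total, kl_anneal_portion=1.0, kl_max_coeff=0.7)
        print(f"step {step:4d}: full={full:.3f} partial={partial:.3f}")
```

Under this assumption, the "balanced" VAE reaches a KL weight of 1.0 halfway through training and keeps it there, while the other two VAEs ramp up over the whole run and end at 0.7.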
With the resulting VAE checkpoints, we can train the three different LSGMs. The models are trained with the following commands on 2 nodes with 8 32GB V100 GPUs each.
- LSGM (balanced):
mpirun --allow-run-as-root -np 2 -npernode 1 bash -c \
'python train_vada.py --fid_dir $FID_STATS_DIR --data $DATA_DIR/cifar10 --root $CHECKPOINT_DIR \
--save $EXPR_ID/lsgm1 --vae_checkpoint $EXPR_ID/vae1/checkpoint.pt --train_vae --custom_conv_dae --apply_sqrt2_res \
--fir --dae_arch ncsnpp --embedding_scale 1000 --dataset cifar10 --learning_rate_dae 1e-4 \
--learning_rate_min_dae 1e-4 --epochs 1875 --dropout 0.2 --batch_size 16 --num_channels_dae 512 --num_scales_dae 3 \
--num_cell_per_scale_dae 8 --sde_type vpsde --beta_start 0.1 --beta_end 20.0 --sigma2_0 0.0 \
--weight_decay_norm_dae 1e-2 --weight_decay_norm_vae 1e-2 --time_eps 0.01 --train_ode_eps 1e-6 --eval_ode_eps 1e-6 \
--train_ode_solver_tol 1e-5 --eval_ode_solver_tol 1e-5 --iw_sample_p drop_all_iw --iw_sample_q reweight_p_samples \
--arch_instance_dae res_ho_attn --num_process_per_node 8 --use_se --node_rank $NODE_RANK --num_proc_node 2 \
--master_address $IP_ADDR '
- LSGM (best FID):
mpirun --allow-run-as-root -np 2 -npernode 1 bash -c \
'python train_vada.py --fid_dir $FID_STATS_DIR --data $DATA_DIR/cifar10 --root $CHECKPOINT_DIR \
--save $EXPR_ID/lsgm2 --vae_checkpoint $EXPR_ID/vae2/checkpoint.pt --train_vae --custom_conv_dae --apply_sqrt2_res \
--fir --cont_kl_anneal --dae_arch ncsnpp --embedding_scale 1000 --dataset cifar10 --learning_rate_dae 1e-4 \
--learning_rate_min_dae 1e-4 --epochs 1875 --dropout 0.2 --batch_size 16 --num_channels_dae 512 --num_scales_dae 3 \
--num_cell_per_scale_dae 8 --sde_type vpsde --beta_start 0.1 --beta_end 20.0 --sigma2_0 0.0 \
--weight_decay_norm_dae 1e-2 --weight_decay_norm_vae 1e-2 --time_eps 0.01 --train_ode_eps 1e-6 --eval_ode_eps 1e-6 \
--train_ode_solver_tol 1e-5 --eval_ode_solver_tol 1e-5 --iw_sample_p drop_all_iw --iw_sample_q reweight_p_samples \
--arch_instance_dae res_ho_attn --num_process_per_node 8 --use_se --node_rank $NODE_RANK --num_proc_node 2 \
--master_address $IP_ADDR '
- LSGM (best NLL):
mpirun --allow-run-as-root -np 2 -npernode 1 bash -c \
'python train_vada.py --fid_dir $FID_STATS_DIR --data $DATA_DIR/cifar10 --root $CHECKPOINT_DIR \
--save $EXPR_ID/lsgm3 --vae_checkpoint $EXPR_ID/vae3/checkpoint.pt --train_vae --apply_sqrt2_res --fir \
--cont_kl_anneal --dae_arch ncsnpp --embedding_scale 1000 --dataset cifar10 --learning_rate_dae 1e-4 \
--learning_rate_min_dae 1e-4 --epochs 1875 --dropout 0.2 --batch_size 16 --num_channels_dae 512 --num_scales_dae 3 \
--num_cell_per_scale_dae 8 --sde_type geometric_sde --sigma2_min 3e-5 --sigma2_max 0.999 --sigma2_0 3e-5 \
--weight_decay_norm_dae 1e-2 --weight_decay_norm_vae 1e-2 --time_eps 0.0 --train_ode_eps 1e-6 --eval_ode_eps 1e-6 \
--train_ode_solver_tol 1e-5 --eval_ode_solver_tol 1e-5 --iw_sample_p ll_uniform --iw_sample_q reweight_p_samples \
--arch_instance_dae res_ho_attn --num_process_per_node 8 --use_se --node_rank $NODE_RANK --num_proc_node 2 \
--master_address $IP_ADDR '
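The commands above select the forward diffusion with --sde_type. To give a feel for what the VPSDE settings mean, the sketch below evaluates the standard variance of a variance-preserving SDE with a linear beta(t), using the --beta_start, --beta_end, and --sigma2_0 values from the first two commands. The geometric schedule is an assumption (geometric interpolation of the variance between --sigma2_min and --sigma2_max) and both function names are hypothetical; the repository's implementation may differ.

```python
import math


def vpsde_variance(t, beta_start=0.1, beta_end=20.0, sigma2_0=0.0):
    """Variance at time t in [0, 1] for a VPSDE with linear
    beta(t) = beta_start + t * (beta_end - beta_start), for a sample
    that starts with variance sigma2_0 at t = 0."""
    integral = beta_start * t + 0.5 * (beta_end - beta_start) * t * t
    return 1.0 - (1.0 - sigma2_0) * math.exp(-integral)


def geometric_variance(t, sigma2_min=3e-5, sigma2_max=0.999):
    """Assumed schedule for geometric_sde: the variance grows geometrically
    from sigma2_min at t = 0 to sigma2_max at t = 1."""
    return sigma2_min * (sigma2_max / sigma2_min) ** t


if __name__ == "__main__":
    # 0.01 matches --time_eps in the VPSDE commands.
    for t in (0.01, 0.5, 1.0):
        print(f"t={t:.2f}  vpsde={vpsde_variance(t):.6f}  "
              f"geometric={geometric_variance(t):.6f}")
```

In both cases the variance approaches 1 at t = 1, so the diffused latents end close to a standard Normal distribution.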
The following command can be used to evaluate the negative variational bound on the data log-likelihood as well as the FID score for any of the LSGMs trained on CIFAR-10 (on 2 nodes with 8 32GB V100 GPUs each):
mpirun --allow-run-as-root -np 2 -npernode 1 bash -c \
'python evaluate_vada.py --data $DATA_DIR/cifar10 --root $CHECKPOINT_DIR --save $EXPR_ID/eval -
