# GAN-Supervised Dense Visual Alignment (GANgealing)<br><sub>Official PyTorch Implementation of the CVPR 2022 Paper (Oral, Best Paper Finalist)</sub>
Paper | Project Page | Video | Two Minute Papers | Mixed Reality Playground 

This repo contains training, evaluation, and visualization code for the GANgealing algorithm from our GAN-Supervised Dense Visual Alignment paper. Please see our project page for high-quality results.
GAN-Supervised Dense Visual Alignment<br> William Peebles, Jun-Yan Zhu, Richard Zhang, Antonio Torralba, Alexei Efros, Eli Shechtman<br> UC Berkeley, Carnegie Mellon University, Adobe Research, MIT CSAIL<br> CVPR 2022 - Oral, Best Paper Finalist
GAN-Supervised Learning is a method for learning discriminative models and their GAN-generated training data jointly end-to-end. We apply our framework to the dense visual alignment problem. Inspired by the classic Congealing method, our GANgealing algorithm trains a Spatial Transformer to warp random samples from a GAN trained on unaligned data to a common, jointly-learned target mode. The target mode is updated to make the Spatial Transformer's job "as easy as possible." The Spatial Transformer is trained exclusively on GAN images and generalizes to real images at test time automatically.
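For a rough intuition of the congealing idea, consider a toy 1D analogue: each sample is a circular shift of an unknown template, and we alternate between (1) warping every sample toward the current average and (2) updating the average target from the warped samples. This sketch is purely illustrative; it is not the paper's method, which trains a neural Spatial Transformer end-to-end with GAN supervision.

```python
# Toy 1D "congealing" sketch (illustrative only, not GANgealing itself):
# alternate between aligning each sample to the running average target
# and re-estimating the target from the aligned samples.

def shift(x, k):
    """Circularly shift list x right by k positions."""
    k %= len(x)
    return x[-k:] + x[:-k]

def best_shift(x, target):
    """Brute-force the circular shift of x that best matches target."""
    def err(a, b):
        return sum((u - v) ** 2 for u, v in zip(a, b))
    return min(range(len(x)), key=lambda k: err(shift(x, k), target))

def congeal(samples, iters=5):
    # Initial target: the (blurry) average of the unaligned samples.
    target = [sum(col) / len(samples) for col in zip(*samples)]
    for _ in range(iters):
        # Warp each sample so it matches the current target as well as possible.
        aligned = [shift(x, best_shift(x, target)) for x in samples]
        # Jointly update the target, making the warps' job "as easy as possible."
        target = [sum(col) / len(aligned) for col in zip(*aligned)]
    return aligned, target
```

After a few iterations, all samples collapse onto a single shared mode, mirroring how the Spatial Transformer and target mode are optimized jointly.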
Once trained, the average aligned image is a template from which you can propagate anything. For example, by drawing cartoon eyes on our average congealed cat image, you can propagate them realistically to any video or image of a cat.
This repository contains:
- 🎱 Pre-trained GANgealing models for eight datasets, including both the Spatial Transformers and generators
- 💥 Training code which fully supports Distributed Data Parallel and the torchrun API
- 🎥 Scripts and a self-contained Colab notebook for running mixed reality with our Spatial Transformers
- ⚡ A lightning-fast CUDA implementation of splatting to generate high-quality warping visualizations
- 🚀 An implementation of anti-aliased grid sampling useful for Spatial Transformers (thanks Tim Brooks!)
- 🎆 Several additional evaluation and visualization scripts to reproduce results from our paper and website
## Setup
First, download the repo and add it to your `PYTHONPATH`:

```bash
git clone https://github.com/wpeebles/gangealing.git
cd gangealing
export PYTHONPATH="${PYTHONPATH}:${PWD}"
```
We provide an `environment.yml` file that can be used to create a Conda environment:

```bash
conda env create -f environment.yml
conda activate gg
```
This will install PyTorch with a recent version of CUDA/cuDNN. To install CUDA 10.2/cuDNN 7.6.5 specifically, you can use `environment_cu102.yml` in the above command. See below for details on performance differences between CUDA/cuDNN versions.
If you use your own environment, you need a recent version of PyTorch (1.10.1+); older versions will likely have problems building the StyleGAN2 extensions.
## Running Pre-trained Models
The `applications` directory contains several files for evaluating and visualizing pre-trained GANgealing models.
**Using our Pre-trained Models:** We provide several pre-trained GANgealing models: `bicycle`, `cat`, `celeba`, `cub`, `dog` and `tvmonitor`. We also have pre-trained checkpoints for our `car` and `horse` clustering models. You can use any of these models by specifying them with the `--ckpt` argument; this will automatically download and cache the weights. The relevant hyperparameters for running the model (most importantly, the `--iters` argument) will be automatically loaded as well. If you want to use your own test-time hyperparameters, add `--override` to the command; see an example here.
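The precedence just described (checkpoint hyperparameters win unless `--override` is passed) can be sketched as follows. The function and argument names here are hypothetical, chosen only to illustrate the merge logic; they are not the repo's actual code:

```python
# Hypothetical sketch of hyperparameter precedence (not the repo's code):
# values stored in the checkpoint replace CLI values, unless the user
# explicitly asks to override them.

def resolve_args(cli_args, ckpt_args, override=False):
    """Merge CLI and checkpoint hyperparameters; checkpoint wins by default."""
    merged = dict(cli_args)
    if not override:
        merged.update(ckpt_args)  # checkpoint values take precedence
    return merged
```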
The `--output_resolution` argument controls the size of congealed images output by the Spatial Transformer. For the highest quality results, we recommend setting this equal to the value you provide to `--real_size` (default value is 128).
## Preparing Real Data
We use LMDBs for storing data. You can use `prepare_data.py` to pre-process input datasets. Note that setting up real data is not required for training.
**LSUN:** The following command will automatically download and pre-process the first 10,000 images from LSUN Cats (you can change `--lsun_category` and `--max_images`):

```bash
python prepare_data.py --input_is_lmdb --lsun_category cat --out data/lsun_cats --size 512 --max_images 10000
```
If you previously downloaded an LSUN LMDB yourself (e.g., at `path_to_lsun_cats_download`), you can instead use the following command:

```bash
python prepare_data.py --input_is_lmdb --path path_to_lsun_cats_download --out data/lsun_cats --size 512 --max_images 10000
```
**Image Folders:** For any dataset where you have all images in a single folder, you can pre-process them with:

```bash
python prepare_data.py --path folder_of_images --out data/my_new_dataset --pad [center/border/zero] --size S
```

where `S` is the square resolution of the resized images.
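To make the three `--pad` modes concrete, here is a minimal, self-contained sketch (not the repo's `prepare_data.py` implementation) of the square-ification strategies on a toy "image" stored as nested lists of pixel values:

```python
# Illustrative sketch of the --pad modes for making an H x W image square
# before resizing to S x S. Not the repo's actual pre-processing code.

def to_square(img, mode):
    """img: list of rows (H x W). Returns a square image per `mode`."""
    h, w = len(img), len(img[0])
    if mode == "center":                      # crop the longer side, centered
        s = min(h, w)
        top, left = (h - s) // 2, (w - s) // 2
        return [row[left:left + s] for row in img[top:top + s]]
    s = max(h, w)                             # pad the shorter side
    top, left = (s - h) // 2, (s - w) // 2
    out = []
    for i in range(s):
        row = []
        for j in range(s):
            if 0 <= i - top < h and 0 <= j - left < w:
                row.append(img[i - top][j - left])   # inside the original image
            elif mode == "zero":
                row.append(0)                        # zero padding
            else:                                    # "border": replicate edges
                ii = min(max(i - top, 0), h - 1)
                jj = min(max(j - left, 0), w - 1)
                row.append(img[ii][jj])
        out.append(row)
    return out
```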
**SPair-71K:** You can download and prepare SPair for PCK evaluation (e.g., for Cats) with:

```bash
python prepare_data.py --spair_category cat --spair_split test --out data/spair_cats_test --size 256
```
**CUB:** We closely follow the pre-processing steps used by ACSM for CUB PCK evaluation. You can download and prepare the CUB validation split with:

```bash
python prepare_data.py --cub_acsm --out data/cub_val --size 256
```
## Congealing and Dense Correspondence Visualization

`vis_correspondence.py` produces a video depicting real images being gradually aligned with our Spatial Transformer network. It can also be used to visualize label/object propagation:

```bash
python applications/vis_correspondence.py --ckpt cat --real_data_path data/lsun_cats --vis_in_stages --real_size 512 --output_resolution 512 --resolution 256 --label_path assets/masks/cat_mask.png --dset_indices 2363 9750 7432 1946
```
## Mixed Reality (Object Lenses)

`mixed_reality.py` applies a pre-trained Spatial Transformer per-frame to an input video. We include several objects and masks you can propagate in the `assets` folder.
The first step is to prepare the video dataset. If you have the video saved as an image folder (with filenames ordered by timestamp), you can run:

```bash
python prepare_data.py --path folder_of_frames --out data/my_video_dataset --pad center --size 1024
```
This command will pre-process the images to square with center-cropping and resize them to 1024x1024 resolution.
You can specify `--pad border` to perform border padding instead of cropping, or `--pad resize_small_side` to preserve aspect ratio. No matter what you choose for `--pad`, the value you use for `--size` needs to be a multiple of 128.
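If your frames are not already sized to a multiple of 128, a small helper like the following can pick a valid `--size`. This is a hypothetical convenience function, not part of the repo:

```python
# Hypothetical helper (not in the repo): round a desired resolution to the
# nearest positive multiple of 128, as required by the --size argument.

def nearest_valid_size(s, multiple=128):
    """Round s to the nearest positive multiple of `multiple`."""
    return max(multiple, round(s / multiple) * multiple)
```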
If your video is saved in mp4, mov, etc. format, we provide a script that will convert it into frames via FFmpeg:

```bash
./process_video.sh path_to_video
```

This will save a folder of frames in the `data/video_frames` folder, which you can then run `prepare_data.py` on as described above.
Now we can run GANgealing on the video. For example, this will propagate a cartoon face via our LSUN Cats model:

```bash
torchrun --nproc_per_node=NUM_GPUS applications/mixed_reality.py --ckpt cat --objects --label_path assets/objects/cat/cat_cartoon.png --sigma 0.3 --opacity 1 --real_size 1024 --resolution 8192 --real_data_path path_to_my_video --no_flip_inference
```
This will efficiently parallelize the evaluation of the video over `NUM_GPUS` GPUs. Here is a quick overview of some arguments you can use with this script (see `mixed_reality.py` for all options):
- `--save_frames` can be specified to significantly reduce GPU memory usage (at the cost of speed)
- `--label_path` points to the RGBA `png` file containing the object/mask you are propagating
- `--objects` will propagate RGB values from your `--label_path` image. If you omit this argument, only the alpha channel of the `--label_path` image will be used, and an RGB colorscale will be created
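Conceptually, propagating an RGBA object onto a frame is alpha compositing: the object's alpha channel (scaled by `--opacity`) controls how much of the object's color replaces the frame's color at each pixel. The per-pixel sketch below is illustrative only; the repo's implementation differs and operates on whole congealed images:

```python
# Illustrative per-pixel alpha compositing (not the repo's implementation):
# blend one RGBA object pixel over one RGB frame pixel.

def composite(frame_px, object_rgba, opacity=1.0):
    """Return the blended RGB pixel; alpha and opacity scale the object."""
    r, g, b, a = object_rgba
    a = (a / 255.0) * opacity                 # effective blend weight in [0, 1]
    return tuple(round(a * c + (1 - a) * f) for c, f in zip((r, g, b), frame_px))
```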

