Gangealing

Official PyTorch Implementation of "GAN-Supervised Dense Visual Alignment" (CVPR 2022 Oral, Best Paper Finalist)

GAN-Supervised Dense Visual Alignment (GANgealing)<br><sub>Official PyTorch Implementation of the CVPR 2022 Paper (Oral, Best Paper Finalist)</sub>

This repo contains training, evaluation, and visualization code for the GANgealing algorithm from our GAN-Supervised Dense Visual Alignment paper. Please see our project page for high-quality results.

GAN-Supervised Dense Visual Alignment<br> William Peebles, Jun-Yan Zhu, Richard Zhang, Antonio Torralba, Alexei Efros, Eli Shechtman<br> UC Berkeley, Carnegie Mellon University, Adobe Research, MIT CSAIL<br> CVPR 2022 - Oral, Best Paper Finalist

GAN-Supervised Learning is a method for learning discriminative models and their GAN-generated training data jointly end-to-end. We apply our framework to the dense visual alignment problem. Inspired by the classic Congealing method, our GANgealing algorithm trains a Spatial Transformer to warp random samples from a GAN trained on unaligned data to a common, jointly-learned target mode. The target mode is updated to make the Spatial Transformer's job "as easy as possible." The Spatial Transformer is trained exclusively on GAN images and generalizes to real images at test time automatically.
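To make the alternating structure concrete, here is a toy, pure-Python loop in the spirit of congealing. It is not GANgealing's training procedure, which learns a neural Spatial Transformer and a target latent code by gradient descent on GAN samples. In this sketch each "image" is a 1D list of numbers, the "Spatial Transformer" is reduced to a per-signal circular shift, and the target mode is the mean of the currently aligned signals. All names (`shift`, `toy_congeal`, etc.) are hypothetical.

```python
from typing import List

Signal = List[float]


def shift(signal: Signal, k: int) -> Signal:
    """Circularly shift a 1D signal right by k positions (toy stand-in for a warp)."""
    n = len(signal)
    k %= n
    return signal[-k:] + signal[:-k] if k else list(signal)


def sq_dist(a: Signal, b: Signal) -> float:
    """Sum of squared differences between two equal-length signals."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def mean(signals: List[Signal]) -> Signal:
    """Element-wise mean of a list of equal-length signals."""
    return [sum(vals) / len(signals) for vals in zip(*signals)]


def toy_congeal(signals: List[Signal], n_iters: int = 5) -> List[int]:
    """Alternate between (1) choosing each signal's best shift toward the current
    template and (2) re-estimating the template as the mean aligned signal."""
    n = len(signals[0])
    shifts = [0] * len(signals)
    template = signals[0]  # anchor the template on the first signal
    for _ in range(n_iters):
        # (1) "alignment" step: best shift per signal given the template
        shifts = [
            min(range(n), key=lambda k: sq_dist(shift(s, k), template))
            for s in signals
        ]
        # (2) "target mode" step: the mean minimizes total squared error for fixed shifts
        template = mean([shift(s, k) for s, k in zip(signals, shifts)])
    return shifts
```

Updating the template to the mean is the choice that minimizes the total squared alignment error for fixed shifts, a crude analog of making the aligner's job "as easy as possible." The recovered shifts are only defined up to a global offset, which this toy fixes by initializing the template from the first signal.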

Once trained, the average aligned image is a template from which you can propagate anything. For example, by drawing cartoon eyes on our average congealed cat image, you can propagate them realistically to any video or image of a cat.
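One way to see why a single edit on the template transfers to every image: if each image is mapped onto the template by an invertible warp, a point drawn on the template can be pulled back into the image through the inverse warp. The sketch below assumes the warp is a 2D similarity transform (scale, rotation, translation) from image coordinates to template coordinates; GANgealing's transformers can also predict dense flow, so treat this as a simplified illustration. The function names are hypothetical.

```python
import math
from typing import Tuple

Point = Tuple[float, float]


def similarity_forward(p: Point, scale: float, theta: float, t: Point) -> Point:
    """Map an image point into template coordinates: x' = s * R(theta) @ x + t."""
    x, y = p
    c, s = math.cos(theta), math.sin(theta)
    return (scale * (c * x - s * y) + t[0], scale * (s * x + c * y) + t[1])


def similarity_inverse(q: Point, scale: float, theta: float, t: Point) -> Point:
    """Pull a template point back into image coordinates: x = R(-theta) @ (x' - t) / s."""
    x, y = q[0] - t[0], q[1] - t[1]
    c, s = math.cos(theta), math.sin(theta)
    # R(theta)^T = R(-theta) undoes the rotation; dividing by scale undoes the scaling.
    return ((c * x + s * y) / scale, (-s * x + c * y) / scale)
```

For example, a cartoon eye placed at template coordinates `(0.3, -0.2)` would land at `similarity_inverse((0.3, -0.2), scale, theta, t)` in an image whose estimated warp has parameters `scale`, `theta`, and `t`.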

This repository contains:

  • 🎱 Pre-trained GANgealing models for eight datasets, including both the Spatial Transformers and generators
  • 💥 Training code which fully supports Distributed Data Parallel and the torchrun API
  • 🎥 Scripts and a self-contained Colab notebook for running mixed reality with our Spatial Transformers
  • ⚡ A lightning-fast CUDA implementation of splatting to generate high-quality warping visualizations
  • 🚀 An implementation of anti-aliased grid sampling useful for Spatial Transformers (thanks Tim Brooks!)
  • 🎆 Several additional evaluation and visualization scripts to reproduce results from our paper and website

Setup

First, download the repo and add it to your PYTHONPATH:

git clone https://github.com/wpeebles/gangealing.git
cd gangealing
export PYTHONPATH="${PYTHONPATH}:${PWD}"

We provide an environment.yml file that can be used to create a Conda environment:

conda env create -f environment.yml
conda activate gg

This will install PyTorch with a recent version of CUDA/cuDNN. To install CUDA 10.2/cuDNN 7.6.5 specifically, you can use environment_cu102.yml in the above command. See below for details on performance differences between CUDA/cuDNN versions.

If you use your own environment, you need a recent version of PyTorch (1.10.1+). Older versions of PyTorch will likely have problems building the StyleGAN2 extensions.
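If you manage your own environment, a quick check of the installed PyTorch version can save a failed extension build. Comparing version strings lexicographically is wrong ("1.9" sorts after "1.10"), so the sketch below compares numeric components instead. The helper names are hypothetical, and local or pre-release suffixes such as `+cu113` are simply stripped.

```python
import re
from typing import Tuple


def version_tuple(version: str) -> Tuple[int, ...]:
    """Parse '1.10.1+cu113' -> (1, 10, 1); drops local/pre-release suffixes."""
    core = version.split("+")[0]
    parts = []
    for piece in core.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str = "1.10.1") -> bool:
    """True if `installed` is at least `minimum`, comparing numeric components."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so (1, 11) compares correctly against (1, 10, 1).
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))
```

In practice you would call `meets_minimum(torch.__version__)` after `import torch`.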

Running Pre-trained Models

The applications directory contains several files for evaluating and visualizing pre-trained GANgealing models.

Using our Pre-trained Models: We provide several pre-trained GANgealing models: bicycle, cat, celeba, cub, dog and tvmonitor. We also have pre-trained checkpoints for our car and horse clustering models. You can use any of these models by specifying them with the --ckpt argument; this will automatically download and cache the weights. The relevant hyperparameters for running the model (most importantly, the --iters argument) will be automatically loaded as well. If you want to use your own test-time hyperparameters, add --override to the command.

The --output_resolution argument controls the size of congealed images output by the Spatial Transformer. For the highest-quality results, we recommend setting this equal to the value you provide to --real_size (default value is 128).
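Following the recommendation above, the hypothetical helper below assembles an argument list for one of the application scripts and defaults --output_resolution to --real_size. It only uses flags named in this README (--ckpt, --real_data_path, --real_size, --output_resolution, --override); which other flags a given script needs, and how it interprets them, is defined by the script itself.

```python
import sys
from typing import List, Optional


def build_app_command(
    script: str,
    ckpt: str,
    real_data_path: str,
    real_size: int = 128,
    output_resolution: Optional[int] = None,
    override: bool = False,
) -> List[str]:
    """Build an argv list for an applications/ script (hypothetical helper).

    --output_resolution defaults to --real_size, as recommended above.
    """
    if output_resolution is None:
        output_resolution = real_size
    cmd = [
        sys.executable, script,
        "--ckpt", ckpt,
        "--real_data_path", real_data_path,
        "--real_size", str(real_size),
        "--output_resolution", str(output_resolution),
    ]
    if override:
        # Use your own test-time hyperparameters instead of the checkpoint's.
        cmd.append("--override")
    return cmd
```

The resulting list can be passed to `subprocess.run(..., check=True)`, adding any script-specific flags (for example --resolution) to the list first.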

Preparing Real Data

We use LMDBs for storing data. You can use prepare_data.py to pre-process input datasets. Note that setting up real data is not required for training.

LSUN: The following command will automatically download and pre-process the first 10,000 images from LSUN Cats (you can change --lsun_category and --max_images):

python prepare_data.py --input_is_lmdb --lsun_category cat --out data/lsun_cats --size 512 --max_images 10000

If you previously downloaded an LSUN LMDB yourself (e.g., at path_to_lsun_cats_download), you can instead use the following command:

python prepare_data.py --input_is_lmdb --path path_to_lsun_cats_download --out data/lsun_cats --size 512 --max_images 10000

Image Folders: For any dataset where you have all images in a single folder, you can pre-process them with:

python prepare_data.py --path folder_of_images --out data/my_new_dataset --pad [center/border/zero] --size S

where S is the square resolution of the resized images.

SPair-71K: You can download and prepare SPair for PCK evaluation (e.g., for Cats) with:

python prepare_data.py --spair_category cat --spair_split test --out data/spair_cats_test --size 256

CUB: We closely follow the pre-processing steps used by ACSM for CUB PCK evaluation. You can download and prepare the CUB validation split with:

python prepare_data.py --cub_acsm --out data/cub_val --size 256
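Both SPair-71K and CUB are evaluated with PCK (percentage of correct keypoints): a transferred keypoint counts as correct if it lands within alpha times a reference size of its ground-truth location. The sketch below shows only that arithmetic. The value of alpha and the choice of reference size (for example the object bounding box or the image side) follow each benchmark's protocol and are parameters here, not a description of this repo's evaluation code.

```python
import math
from typing import List, Optional, Tuple

Point = Tuple[float, float]


def pck(
    pred: List[Point],
    gt: List[Point],
    ref_size: float,
    alpha: float = 0.1,
    visible: Optional[List[bool]] = None,
) -> float:
    """Fraction of visible keypoints whose error is <= alpha * ref_size."""
    if visible is None:
        visible = [True] * len(gt)
    threshold = alpha * ref_size
    hits, total = 0, 0
    for p, g, v in zip(pred, gt, visible):
        if not v:
            continue  # skip keypoints that are not annotated in both images
        total += 1
        if math.dist(p, g) <= threshold:
            hits += 1
    return hits / total if total else float("nan")
```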

Congealing and Dense Correspondence Visualization

vis_correspondence.py produces a video depicting real images being gradually aligned with our Spatial Transformer network. It can also be used to visualize label/object propagation:

python applications/vis_correspondence.py --ckpt cat --real_data_path data/lsun_cats --vis_in_stages --real_size 512 --output_resolution 512 --resolution 256 --label_path assets/masks/cat_mask.png --dset_indices 2363 9750 7432 1946

Mixed Reality (Object Lenses)

<table cellpadding="0" cellspacing="0" > <tr> <td align="center">Dense Tracking<br> <img src="images/snowpuppy_track.gif" width=240px></td> <td align="center">Object Propagation<br> <img src="images/snowpuppy_object.gif" width=240px></td> <td align="center">Congealed Video<br> <img src="images/snowpuppy_congealed.gif" width=240px></td> </tr> </table>

mixed_reality.py applies a pre-trained Spatial Transformer per-frame to an input video. We include several objects and masks you can propagate in the assets folder.

The first step is to prepare the video dataset. If you have the video saved as an image folder (with filenames in order based on timestamp), you can run:

python prepare_data.py --path folder_of_frames --out data/my_video_dataset --pad center --size 1024

This command will make the images square by center-cropping and resize them to 1024x1024 resolution. You can specify --pad border to perform border padding instead of cropping or --pad resize_small_side to preserve aspect ratio. No matter what you choose for --pad, the value you use for --size needs to be a multiple of 128.
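For reference, the geometry of center-cropping is simple: the square side equals the shorter image side, and the crop is centered along the longer one. The hypothetical helpers below compute that box (in the `(left, top, right, bottom)` convention used by PIL's `Image.crop`) and check the multiple-of-128 rule for --size. They do not reproduce prepare_data.py's resampling or its border and resize_small_side modes.

```python
from typing import Tuple


def center_crop_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of the largest centered square crop."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side


def check_size(size: int) -> int:
    """Validate a --size value for video datasets (must be a positive multiple of 128)."""
    if size <= 0 or size % 128 != 0:
        raise ValueError(f"--size must be a positive multiple of 128, got {size}")
    return size
```

For a 1920x1080 frame, the crop keeps the central 1080x1080 region, which is then resized to the requested --size.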

If your video is saved in mp4, mov, etc. format, we provide a script that will convert it into frames via FFmpeg:

./process_video.sh path_to_video

This will save a folder of frames in the data/video_frames folder, which you can then run prepare_data.py on as described above.
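If you prefer not to use the shell script, you can call FFmpeg directly along the same lines. The sketch below builds and runs a basic FFmpeg frame-extraction command; the exact options used by process_video.sh are not shown here and may differ (for example frame rate or image format). Zero-padded output names such as 000001.png keep the frames in timestamp order when sorted by filename, which the image-folder route above relies on. The helper names are hypothetical.

```python
import shutil
import subprocess
from pathlib import Path
from typing import List


def ffmpeg_extract_cmd(video_path: str, out_dir: str) -> List[str]:
    """Build an FFmpeg command that writes every frame as a zero-padded PNG."""
    pattern = str(Path(out_dir) / "%06d.png")  # 000001.png, 000002.png, ...
    return ["ffmpeg", "-i", video_path, pattern]


def extract_frames(video_path: str, out_dir: str) -> None:
    """Run the command above (requires the ffmpeg binary on PATH)."""
    if shutil.which("ffmpeg") is None:
        raise RuntimeError("ffmpeg not found on PATH")
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    subprocess.run(ffmpeg_extract_cmd(video_path, out_dir), check=True)
```

For example, `extract_frames("my_video.mp4", "data/video_frames")` followed by prepare_data.py on that folder.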

Now we can run GANgealing on the video. For example, this will propagate a cartoon face via our LSUN Cats model:

torchrun --nproc_per_node=NUM_GPUS applications/mixed_reality.py --ckpt cat --objects --label_path assets/objects/cat/cat_cartoon.png --sigma 0.3 --opacity 1 --real_size 1024 --resolution 8192 --real_data_path path_to_my_video --no_flip_inference

This will efficiently parallelize the evaluation of the video over NUM_GPUS. Here is a quick overview of some arguments you can use with this script (see mixed_reality.py for all options):

  • --save_frames can be specified to significantly reduce GPU memory usage (at the cost of speed)
  • --label_path points to the RGBA png file containing the object/mask you are propagating
  • --objects will propagate RGB values from your label_path image. If you omit this argument, only the alpha channel of the label_path image will be used, and an RGB colorscale will be created.
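The NUM_GPUS parallelization described above amounts to giving each process a disjoint slice of the video's frames. torchrun sets the RANK and WORLD_SIZE environment variables for every process it launches; the hypothetical helper below uses them to split frame indices into contiguous, near-equal chunks. It illustrates the general pattern only and is not necessarily how mixed_reality.py partitions its work.

```python
import os
from typing import List, Optional


def frames_for_rank(
    num_frames: int, rank: Optional[int] = None, world_size: Optional[int] = None
) -> List[int]:
    """Contiguous, near-equal share of frame indices for one process.

    Defaults to the RANK/WORLD_SIZE variables that torchrun exports,
    falling back to a single process when they are absent.
    """
    if rank is None:
        rank = int(os.environ.get("RANK", 0))
    if world_size is None:
        world_size = int(os.environ.get("WORLD_SIZE", 1))
    base, extra = divmod(num_frames, world_size)
    # The first `extra` ranks each get one additional frame.
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return list(range(start, stop))
```

Contiguous chunks keep consecutive frames on the same process, so each rank's output can be written in order and concatenated afterwards.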