AGE
Code for the paper "Adversarial Generator-Encoder Networks"
This repository contains code for the paper
"Adversarial Generator-Encoder Networks" (AAAI'18) by Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky.

Pretrained models
This is how you can access the models used to generate figures in the paper.
- First install the dev version of `pytorch 0.2` and make sure you have `jupyter notebook` ready.
- Then download the models with the script: `bash download_pretrained.sh`
- Run `jupyter notebook` and go through `evaluate.ipynb`.
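
For quick experiments outside the notebook, the sketch below shows how a downloaded generator checkpoint might be loaded and sampled from. The checkpoint filename, the latent size `nz`, and the DCGAN-style `(N, nz, 1, 1)` input shape are assumptions inferred from the training commands below, and the snippet targets a recent PyTorch API rather than the 0.2 release the repository was written for.

```python
# Hedged sampling sketch: assumes the .pth file stores the whole generator
# module (as produced by torch.save(netG, ...)) and a DCGAN-style latent
# input of shape (N, nz, 1, 1). Paths and nz are illustrative.
import torch

nz = 64                                             # latent size used at training time (assumption)
netG = torch.load('netG.pth', map_location='cpu')   # hypothetical checkpoint filename
netG.eval()

with torch.no_grad():
    z = torch.randn(16, nz, 1, 1)
    # The default prior is uniform on the unit sphere (--noise sphere),
    # so normalize each latent vector to unit length.
    z = z / z.flatten(1).norm(dim=1).view(-1, 1, 1, 1)
    samples = netG(z)                               # a (16, 3, H, W) batch of images
```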
Here is an example of samples and reconstructions for the imagenet, celeba and cifar10 datasets, generated with `evaluate.ipynb`.
- Celeba: samples and reconstructions (figure omitted)
- Cifar10: samples and reconstructions (figure omitted)
- Tiny ImageNet: samples and reconstructions (figure omitted)
Training
Use the `age.py` script to train a model. Here are the most important parameters:

- `--dataset`: one of `[celeba, cifar10, imagenet, svhn, mnist]`
- `--dataroot`: for datasets included in `torchvision` it is a directory where everything will be downloaded to; for the imagenet and celeba datasets it is a path to a directory with folders `train` and `val` inside
- `--image_size`: spatial size of the images (e.g. 64 for celeba, 32 for cifar10 and imagenet)
- `--save_dir`: path to a folder where checkpoints will be stored
- `--nz`: dimensionality of the latent space
- `--batch_size`: batch size. Default 64
- `--netG`: `.py` file with the generator definition; searched in the `models` directory
- `--netE`: `.py` file with the encoder definition; searched in the `models` directory
- `--netG_chp`: path to a generator checkpoint to load from
- `--netE_chp`: path to an encoder checkpoint to load from
- `--nepoch`: number of epochs to run
- `--start_epoch`: epoch number to start from; useful for fine-tuning
- `--e_updates`: update plan for the encoder: `<num steps>;KL_fake:<weight>,KL_real:<weight>,match_z:<weight>,match_x:<weight>` (a parsing sketch of this format follows this list)
- `--g_updates`: update plan for the generator: `<num steps>;KL_fake:<weight>,match_z:<weight>,match_x:<weight>`
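
The update-plan strings follow a simple `<num steps>;name:<weight>,...` grammar. As a reading aid only (not the actual parser in `age.py`), a plan string could be interpreted like this:

```python
def parse_update_plan(plan):
    """Turn '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' into
    (1, {'KL_fake': 1.0, 'KL_real': 1.0, 'match_z': 0.0, 'match_x': 10.0}).
    Illustrative sketch of the format; age.py's own parsing may differ."""
    steps, terms = plan.split(';')
    weights = {}
    for term in terms.split(','):
        name, value = term.split(':')
        weights[name] = float(value)
    return int(steps), weights

print(parse_update_plan('3;KL_fake:1,match_z:1000,match_x:0'))
# -> (3, {'KL_fake': 1.0, 'match_z': 1000.0, 'match_x': 0.0})
```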
And misc arguments:
- `--workers`: number of dataloader workers
- `--ngf`: controls the number of channels in the generator
- `--ndf`: controls the number of channels in the encoder
- `--beta1`: parameter for the Adam optimizer
- `--cpu`: do not use GPU
- `--criterion`: parametric (`param`) or non-parametric (`nonparam`) way to compute the KL. Parametric fits a Gaussian to the data, non-parametric is based on nearest neighbors. Default: `param` (a closed-form sketch of the parametric case follows this list)
- `--KL`: which KL to compute, `qp` or `pq`. Default is `qp`
- `--noise`: `sphere` for uniform on the sphere, or `gaussian`. Default `sphere`
- `--match_z`: loss to use as the reconstruction loss in latent space, one of `L1 | L2 | cos`. Default `cos`
- `--match_x`: loss to use as the reconstruction loss in data space, one of `L1 | L2 | cos`. Default `L1`
- `--drop_lr`: the learning rate is dropped every `drop_lr` epochs
- `--save_every`: controls how often intermediate results are stored. Default `50`
- `--manual_seed`: random seed. Default `123`
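
For intuition about `--criterion param` with the default `--KL qp`: fitting a diagonal Gaussian to a batch of latent codes and comparing it against a unit Gaussian prior has a standard closed form. The sketch below shows that closed form as a reading aid; it is an assumption about the idea, not a copy of the loss in `age.py` (which, with the default `sphere` noise, works with codes normalized to the unit sphere).

```python
import torch

def parametric_kl_qp(codes):
    """Closed-form KL( N(mu, diag(var)) || N(0, I) ), with mu and var
    estimated per latent dimension from a batch of encoder outputs.
    Illustrative sketch of the 'parametric' criterion; the details in
    age.py (constants, sphere prior) may differ."""
    mu = codes.mean(dim=0)
    var = codes.var(dim=0, unbiased=False)
    return (0.5 * (var + mu ** 2) - 0.5 * torch.log(var) - 0.5).sum()

codes = torch.randn(64, 128)        # stand-in for a batch of latent codes, nz = 128
print(parametric_kl_qp(codes))      # close to 0 when the codes match the prior
```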
Here are commands you can start with:
Celeba
Let `<data_root>` be a directory with two folders, `train` and `val`, each containing the images for the corresponding split.
```bash
python age.py --dataset celeba --dataroot <data_root> --image_size 64 --save_dir <save_dir> --lr 0.0002 --nz 64 --batch_size 64 --netG dcgan64px --netE dcgan64px --nepoch 5 --drop_lr 5 --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '3;KL_fake:1,match_z:1000,match_x:0'
```
It is then beneficial to fine-tune the model with a larger `batch_size` and a stronger matching weight:
```bash
python age.py --dataset celeba --dataroot <data_root> --image_size 64 --save_dir <save_dir> --start_epoch 5 --lr 0.0002 --nz 64 --batch_size 256 --netG dcgan64px --netE dcgan64px --nepoch 6 --drop_lr 5 --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:15' --g_updates '3;KL_fake:1,match_z:1000,match_x:0' --netE_chp <save_dir>/netE_epoch_5.pth --netG_chp <save_dir>/netG_epoch_5.pth
```
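
Once the checkpoints `netE_epoch_5.pth` and `netG_epoch_5.pth` exist, reconstructions can be inspected by chaining the encoder and generator. The sketch below assumes full-module checkpoints and an encoder whose output shape matches the generator's input; it is an illustration, not the evaluation code from the repository.

```python
# Hedged reconstruction sketch: x -> E(x) -> G(E(x)).
# Assumes the checkpoints store whole modules and that netE's output
# shape is exactly what netG expects as input. Paths are illustrative.
import torch

netE = torch.load('netE_epoch_5.pth', map_location='cpu')
netG = torch.load('netG_epoch_5.pth', map_location='cpu')
netE.eval(); netG.eval()

with torch.no_grad():
    x = torch.rand(8, 3, 64, 64) * 2 - 1    # stand-in batch in [-1, 1]; use real CelebA images in practice
    z = netE(x)                              # latent codes
    x_rec = netG(z)                          # reconstructions to compare with x
```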
Imagenet
```bash
python age.py --dataset imagenet --dataroot /path/to/imagenet_dir/ --save_dir <save_dir> --image_size 32 --lr 0.0002 --nz 128 --netG dcgan32px --netE dcgan32px --nepoch 6 --drop_lr 3 --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '2;KL_fake:1,match_z:2000,match_x:0' --workers 12
```
It can be beneficial to switch to a batch size of 256 after several epochs.
Cifar10
```bash
python age.py --dataset cifar10 --image_size 32 --save_dir <save_dir> --lr 0.0002 --nz 128 --netG dcgan32px --netE dcgan32px --nepoch 150 --drop_lr 40 --e_updates '1;KL_fake:1,KL_real:1,match_z:0,match_x:10' --g_updates '2;KL_fake:1,match_z:1000,match_x:0'
```
Tested with Python 2.7.
The implementation is based on the PyTorch DCGAN example code.
Citation
If you find this code useful, please cite our paper:
```
@inproceedings{DBLP:conf/aaai/UlyanovVL18,
  author    = {Dmitry Ulyanov and
               Andrea Vedaldi and
               Victor S. Lempitsky},
  title     = {It Takes (Only) Two: Adversarial Generator-Encoder Networks},
  booktitle = {{AAAI}},
  publisher = {{AAAI} Press},
  year      = {2018}
}
```
