
EnergyMatching

[NeurIPS 2025] Official repository for "Energy Matching: Unifying Flow Matching and Energy-Based Models for Generative Modeling"


Energy Matching

<img align="right" src="media/EM_2D.png" width="30%" alt="Energy Matching Illustration" /> Energy Matching unifies flow matching and energy-based models in a single time-independent scalar field, enabling efficient transport between the source and target distributions while retaining explicit likelihood information for flexible, high-quality generation. [NeurIPS 2025]
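The core mechanism can be illustrated with a toy sketch. Samples move down the gradient of one time-independent scalar energy E(x), optionally with Langevin noise. The quadratic energy, the function names and the step sizes below are illustrative assumptions. They are not the repository's model or sampler.

```python
import math
import random


def grad_energy(x, mu):
    # Gradient of the toy energy E(x) = 0.5 * ||x - mu||^2
    # (hypothetical stand-in for the learned scalar network)
    return [xi - mi for xi, mi in zip(x, mu)]


def langevin_step(x, mu, dt, epsilon, rng):
    # One overdamped Langevin step: x <- x - dt * grad E(x) + sqrt(2 * dt * epsilon) * noise.
    # With epsilon = 0 this reduces to deterministic gradient descent on E.
    g = grad_energy(x, mu)
    scale = math.sqrt(2.0 * dt * epsilon)
    return [xi - dt * gi + scale * rng.gauss(0.0, 1.0) for xi, gi in zip(x, g)]


def transport(x0, mu, n_steps=500, dt=0.01, epsilon=0.0, seed=0):
    # Repeatedly apply the Langevin step, starting from x0
    rng = random.Random(seed)
    x = list(x0)
    for _ in range(n_steps):
        x = langevin_step(x, mu, dt, epsilon, rng)
    return x
```

With `epsilon=0.0`, `transport([3.0, -2.0], [0.5, 0.5], n_steps=2000)` ends essentially at the energy minimum `[0.5, 0.5]`. A positive `epsilon` instead leaves samples spread around the low-energy region.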

Version 0.9 – This is the official repository for the paper Energy Matching.

Checkpoints

  • CIFAR-10 (Image → Scalar, 50M parameters): warm-up and main-training checkpoints are on Hugging Face. The main-training checkpoint reaches FID ≈ 3.3 at around T=3.25.
  • ImageNet32 (Image → Scalar, 50M parameters): warm-up and main-training checkpoints are on Hugging Face. The main-training checkpoint reaches FID ≈ 6.6 at around T=2.50.

Setup (CUDA)

  1. Create and activate a Python environment (conda example):
    conda create -n energy-matching python=3.10 -y
    conda activate energy-matching
    
  2. Install PyTorch with CUDA support and the project requirements:
    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    pip install -r requirements.txt
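Before launching the multi-GPU jobs, you may want to confirm that the environment is usable. The helper below is a hypothetical convenience script and is not part of the repository. It relies only on `importlib`, `torch.cuda.is_available()` and `torch.version.cuda`. The 3.10 threshold mirrors the conda example above and is not a stated requirement.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 10)):
    """Report whether the interpreter and PyTorch look usable for the CUDA setup."""
    report = {
        "python_ok": sys.version_info[:2] >= min_python,
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": False,
        "torch_cuda_version": None,
    }
    if report["torch_installed"]:
        import torch  # imported lazily so the check also runs without PyTorch

        report["cuda_available"] = torch.cuda.is_available()
        # torch.version.cuda is None for CPU-only builds
        report["torch_cuda_version"] = torch.version.cuda
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

If `torch_cuda_version` prints `None`, a CPU-only wheel was installed. In that case, reinstall from the CUDA index URL shown above.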
    

Running the examples

2D Playground – Eight Gaussians to Two Moons

Experiment with the core idea in a lightweight setting using the notebook at experiments/toy2d/tutorial_2D.ipynb. It visualizes how the potential energy field transports particles in 2D from eight Gaussians to two moons.
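If you want the toy data outside the notebook, one common way to generate the two distributions in plain Python is shown below. The radii, noise levels and function names are assumptions. They may differ from what tutorial_2D.ipynb uses.

```python
import math
import random


def eight_gaussians(n, radius=2.0, std=0.1, seed=0):
    # Sample n points from 8 isotropic Gaussians centred evenly on a circle
    rng = random.Random(seed)
    centers = [
        (radius * math.cos(2 * math.pi * k / 8), radius * math.sin(2 * math.pi * k / 8))
        for k in range(8)
    ]
    points = []
    for _ in range(n):
        cx, cy = rng.choice(centers)
        points.append((cx + rng.gauss(0.0, std), cy + rng.gauss(0.0, std)))
    return points


def two_moons(n, noise=0.05, seed=0):
    # Sample n points from two interleaving half-circles
    rng = random.Random(seed)
    points = []
    for i in range(n):
        theta = rng.uniform(0.0, math.pi)
        if i % 2 == 0:
            x, y = math.cos(theta), math.sin(theta)  # upper moon
        else:
            x, y = 1.0 - math.cos(theta), 0.5 - math.sin(theta)  # lower moon, shifted
        points.append((x + rng.gauss(0.0, noise), y + rng.gauss(0.0, noise)))
    return points
```

The source samples come from `eight_gaussians` and the targets from `two_moons`. The learned potential should then carry the first cloud onto the second.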

CIFAR‑10 Training and Evaluation

<p align="center"> <strong>Langevin MCMC (unconditional) <br> Trajectory from T = 0 to T = 4 (FID = 3.3)</strong><br> <img src="media/cifar10_FID_3_3.gif" width="60%" alt="Trajectory Animation"> </p>

Initial training (warm-up, Algorithm 1):

torchrun --nproc_per_node=4 experiments/cifar10/train_cifar_multigpu.py \
    --total_steps 145000 \
    --lr 1.2e-3 \
    --batch_size 128 \
    --time_cutoff 1.0 \
    --epsilon_max 0.0 \
    --lambda_cd 0. \
    --n_gibbs 0 \
    --ema_decay 0.9999 \
    --save_step 5000
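Algorithm 1 is defined in the paper. As rough intuition only, the sketch below shows a flow-matching-style warm-up. It regresses the velocity induced by the potential, -∇E(x_t), onto straight-line velocities between noise and data. This is a simplification built on assumptions. The function names are hypothetical, and the sketch does not model how `--time_cutoff` or the other flags enter the repository's loss.

```python
import random


def interpolate(x0, x1, t):
    # Straight-line interpolant x_t = (1 - t) * x0 + t * x1
    return [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]


def warmup_loss(grad_energy, noise_batch, data_batch, rng):
    """Mean squared error between -grad E(x_t) and the velocity x1 - x0.

    grad_energy: hypothetical callable returning the gradient of a scalar energy.
    """
    total = 0.0
    for x0, x1 in zip(noise_batch, data_batch):
        t = rng.random()  # t ~ U[0, 1)
        xt = interpolate(x0, x1, t)
        target = [b - a for a, b in zip(x0, x1)]
        pred = [-g for g in grad_energy(xt)]  # velocity induced by the potential
        total += sum((p - v) ** 2 for p, v in zip(pred, target))
    return total / len(noise_batch)
```

Because E does not depend on t, a single gradient field has to serve every point on every path. This is why the warm-up can only approximate the straight-line velocities.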

Main training with contrastive divergence (Algorithm 2):

torchrun --nproc_per_node=4 experiments/cifar10/train_cifar_multigpu.py \
    --resume_ckpt /PATH/TO/warm_up_checkpoint.pt \
    --total_steps 147000 \
    --lr 1.2e-3 \
    --batch_size 128 \
    --time_cutoff 1.0 \
    --epsilon_max 0.01 \
    --lambda_cd 1e-3 \
    --n_gibbs 200 \
    --ema_decay 0.99 \
    --save_step 100 \
    --dt_gibbs 0.01 \
    --cd_neg_clamp 0.02  \
    --split_negative True \
    --same_temperature_scheduler True
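The contrastive-divergence term in Algorithm 2 follows the usual energy-based-model recipe. It lowers the energy of data samples and raises the energy of negatives drawn by running a short Langevin chain. The sketch maps `--n_gibbs`, `--dt_gibbs`, `--epsilon_max` and `--lambda_cd` onto a textbook sampler and CD loss by name only. That mapping and all function names are assumptions. The sketch also omits `--cd_neg_clamp`, `--split_negative` and the temperature scheduler, whose behaviour is defined in the training script.

```python
import math
import random


def langevin_negatives(grad_energy, starts, n_steps, dt, epsilon, rng):
    # Run n_steps of Langevin dynamics from each start point to obtain negative samples
    scale = math.sqrt(2.0 * dt * epsilon)
    samples = []
    for x in starts:
        x = list(x)
        for _ in range(n_steps):
            g = grad_energy(x)
            x = [xi - dt * gi + scale * rng.gauss(0.0, 1.0) for xi, gi in zip(x, g)]
        samples.append(x)
    return samples


def contrastive_divergence(energy, positives, negatives, lambda_cd):
    # Push energy down on data (positives) and up on sampler outputs (negatives).
    # In a real training loop the negatives are treated as constants (no gradient through the chain).
    mean_pos = sum(energy(x) for x in positives) / len(positives)
    mean_neg = sum(energy(x) for x in negatives) / len(negatives)
    return lambda_cd * (mean_pos - mean_neg)
```

For example, with `n_gibbs=200` and `dt_gibbs=0.01`, each negative would be the end point of a 200-step chain under this reading.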

Evaluate FID along the sampling trajectory at times T=1.0 to T=5.0 (Heun solver):

python experiments/cifar10/fid_cifar_heun_1gpu.py \
    --resume_ckpt=/PATH/TO/main_training_checkpoint.pt \
    --output_dir=./sampling_results \
    --use_ema True \
    --time_cutoff 1.0 \
    --epsilon_max 0.01 \
    --batch_size 64 \
    --dt_gibbs 0.01
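The flag name suggests that `--use_ema True` evaluates the exponential-moving-average copy of the weights. That copy is maintained during training with `--ema_decay` (0.9999 in warm-up, 0.99 in main training). The sketch below shows the standard EMA update on plain lists. We assume the training script applies this textbook form per parameter tensor. The helper names are hypothetical.

```python
def ema_update(ema_params, params, decay):
    """Return decay * ema + (1 - decay) * params, element-wise."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]


def track_ema(param_history, decay):
    # Fold a sequence of parameter snapshots into a single EMA copy
    ema = list(param_history[0])
    for params in param_history[1:]:
        ema = ema_update(ema, params, decay)
    return ema
```

A decay of 0.99 averages over roughly the last 100 updates, while 0.9999 averages over roughly 10,000. This fits the short main-training phase and the long warm-up, respectively.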

Pretrained CIFAR-10 checkpoints are available on Hugging Face. cifar10_warm_up_145000.pt is the checkpoint after the warm-up phase (pass it as --resume_ckpt to start main training), and cifar10_main_training_147000.pt is the checkpoint after main training. The latter obtains an FID of 3.3 at around T=3.25.

To generate CIFAR-10 images using unconditional Langevin Monte Carlo sampling from the trained Energy Matching model, run:

python experiments/cifar10/sample_cifar_heun_1gpu.py \
    --resume_ckpt=/PATH/TO/main_training_checkpoint.pt \
    --batch_size 128 \
    --time_cutoff 1.0 \
    --epsilon_max 0.01 \
    --dt_gibbs 0.01 \
    --use_ema True \
    --t_end=3.25

Here, t_end corresponds to the sampling time $\tau_s$.
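The script names indicate a Heun solver, and t_end sets how far the trajectory is integrated. The sketch below applies Heun's method (explicit trapezoidal rule) to the deterministic gradient flow dx/dt = -∇E(x) up to t_end. This is an assumption-labelled illustration. The repository's sampler may also inject Langevin noise, controlled by flags such as `--epsilon_max`, and may discretize time differently.

```python
def heun_gradient_flow(x0, grad_energy, t_end, dt):
    """Integrate dx/dt = -grad E(x) from t = 0 to t_end with Heun's method."""
    x = list(x0)
    n_steps = max(1, round(t_end / dt))
    h = t_end / n_steps  # adjust the step so the final time lands exactly on t_end
    for _ in range(n_steps):
        k1 = [-g for g in grad_energy(x)]
        x_pred = [xi + h * ki for xi, ki in zip(x, k1)]  # Euler predictor
        k2 = [-g for g in grad_energy(x_pred)]
        x = [xi + 0.5 * h * (a + b) for xi, a, b in zip(x, k1, k2)]  # trapezoidal corrector
    return x
```

For the quadratic energy E(x) = 0.5 x², the exact solution is x0·exp(-t). Calling `heun_gradient_flow([1.0], lambda x: list(x), 3.25, 0.01)` reproduces it to about 1e-5.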

ImageNet32 Training and Evaluation

Download the downsampled ImageNet32 training batches (train_data_batch_1 ... train_data_batch_10) from the official release and place them under experiments/imagenet/data/Imagenet32_train/ (or point the IMAGENET32_PATH environment variable to that folder) before launching training.
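To check the data location before a long run, the hypothetical helper below mirrors the lookup described above. It uses `IMAGENET32_PATH` if that variable is set, and otherwise falls back to the in-repo folder. It then lists any of the ten batch files that are missing. We assume the files are named exactly train_data_batch_1 through train_data_batch_10, with no extension, as listed above.

```python
import os
from pathlib import Path

DEFAULT_DIR = Path("experiments/imagenet/data/Imagenet32_train")
EXPECTED_FILES = [f"train_data_batch_{i}" for i in range(1, 11)]


def resolve_imagenet32_dir(env=None):
    # Prefer the IMAGENET32_PATH environment variable, else fall back to the in-repo default
    env = os.environ if env is None else env
    return Path(env.get("IMAGENET32_PATH", DEFAULT_DIR))


def missing_batches(data_dir):
    # List the expected training batches that are not present in data_dir
    return [name for name in EXPECTED_FILES if not (Path(data_dir) / name).is_file()]


if __name__ == "__main__":
    data_dir = resolve_imagenet32_dir()
    print(data_dir, "missing:", missing_batches(data_dir) or "none")
```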

Initial training (warm-up, Algorithm 1):

torchrun --nproc_per_node=7 experiments/imagenet/train_imagenet_multigpu.py \
    --total_steps 640000 \
    --lr 6e-4 \
    --batch_size 128 \
    --time_cutoff 1.0 \
    --epsilon_max 0. \
    --lambda_cd 0. \
    --ema_decay 0.9999 \
    --save_step 80000

Main training with contrastive divergence (Algorithm 2):

torchrun --nproc_per_node=7 experiments/imagenet/train_imagenet_multigpu.py \
    --resume_ckpt=/PATH/TO/warm_up_imagenet_checkpoint.pt \
    --total_steps 641000 \
    --lr 6e-4 \
    --batch_size 128 \
    --time_cutoff 1.0 \
    --epsilon_max 0.01 \
    --lambda_cd 0.001 \
    --n_gibbs 200 \
    --ema_decay 0.99 \
    --save_step 100 \
    --cd_neg_clamp 0.02 \
    --split_negative True \
    --same_temperature_scheduler True
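The step counts in the two phases suggest that --total_steps is cumulative when resuming: from 145000 to 147000 for CIFAR-10, and from 640000 to 641000 for ImageNet32. Under that reading the contrastive-divergence phase is short. This is an assumption we have not verified against the training scripts. Under it, the arithmetic is:

```python
def resume_schedule(warmup_steps, total_steps, save_step):
    """Extra optimisation steps and checkpoint count in a resumed run, assuming total_steps is cumulative."""
    extra = total_steps - warmup_steps
    return extra, extra // save_step


if __name__ == "__main__":
    print("CIFAR-10:", resume_schedule(145000, 147000, 100))
    print("ImageNet32:", resume_schedule(640000, 641000, 100))
```

This gives 2000 extra steps (20 checkpoints) for CIFAR-10 and 1000 extra steps (10 checkpoints) for ImageNet32, given --save_step 100.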

FID evaluation across trajectory times T=0.75 to T=4.0 (Heun solver):

torchrun --nproc_per_node=1 experiments/imagenet/fid_imagenet_heun_multigpu.py \
    --resume_ckpt=/PATH/TO/main_training_imagenet_checkpoint.pt \
    --output_dir=./sampling_results \
    --use_ema True \
    --time_cutoff 1.0 \
    --epsilon_max 0.01 \
    --batch_size 128 \
    --dt_gibbs 0.01

Pretrained ImageNet32 checkpoints (warm-up and main training) are hosted on Hugging Face.

Protein inverse design

Train the model with:

python experiments/proteins/train_proteins.py \
    --epsilon_max 0.1 \
    --time_cutoff 0.9 \
    --n_gibbs 200 \
    --dt_gibbs 0.01
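The protein run uses a smaller --time_cutoff (0.9) and a larger --epsilon_max (0.1) than the image experiments. One plausible reading of these flags is a noise-temperature schedule that stays at zero up to time_cutoff and then rises to epsilon_max. We have not checked this reading against the code. The function below is a hypothetical illustration of that reading only.

```python
def temperature(t, time_cutoff, epsilon_max):
    """Hypothetical schedule: zero up to time_cutoff, linear ramp to epsilon_max at t = 1, constant after."""
    if t <= time_cutoff:
        return 0.0
    if t >= 1.0:
        return epsilon_max
    # Only reached when time_cutoff < t < 1, so the denominator is positive
    return epsilon_max * (t - time_cutoff) / (1.0 - time_cutoff)
```

Under this reading, time_cutoff = 1.0 (as in the image runs) turns the ramp into a step at t = 1. A cutoff of 0.9 spreads the increase over the last tenth of unit time.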

Pretrained AAV medium/hard checkpoints are available on Hugging Face. Run conditional sampling with:

python experiments/proteins/sampling.py

The VAE used to build the continuous latent space and the dataset are both already provided.

Citation

If you find our work useful, please consider citing:

@article{balcerak2025energy,
  title={Energy Matching: Unifying Flow Matching and Energy-Based Models for Generative Modeling},
  author={Balcerak, Michal and Amiranashvili, Tamaz and Terpin, Antonio and Shit, Suprosanna and Bogensperger, Lea and Kaltenbach, Sebastian and Koumoutsakos, Petros and Menze, Bjoern},
  journal={arXiv preprint arXiv:2504.10612},
  year={2025}
}