
PMRF

[ICLR 2025] Official implementation of Posterior-Mean Rectified Flow: Towards Minimum MSE Photo-Realistic Image Restoration


<div align="center">

Posterior-Mean Rectified Flow:<br />Towards Minimum MSE Photo-Realistic Image Restoration<br />(ICLR 2025)

[Paper] [Project Page] [Demo]

Guy Ohayon, Tomer Michaeli, Michael Elad<br /> Technion—Israel Institute of Technology

</div>

PMRF is a novel photo-realistic image restoration algorithm. It (provably) approximates the optimal estimator that minimizes the Mean Squared Error (MSE) under a perfect perceptual quality constraint.
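In standard notation (the symbols here are our own choice: X is the clean image, Y its degraded measurement, and X̂ any estimator computed from Y), the estimator described above solves the constrained problem below. Perfect perceptual quality means that the distribution of the outputs matches the distribution of the clean images:

```latex
\hat{X}_0 = \arg\min_{\hat{X}} \; \mathbb{E}\left[\lVert X - \hat{X} \rVert_2^2\right]
\quad \text{subject to} \quad p_{\hat{X}} = p_X
```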

<div align="center"> <img src="assets/flow.png" width="2000"> </div>

📈 Some results from our paper

CelebA-Test quantitative comparison

Red, blue and green indicate the best, the second best and the third best scores, respectively. <img src="assets/celeba-test-table.png"/>

WIDER-Test visual comparison

<img src="assets/wider.png"/>

WebPhoto-Test visual comparison

<img src="assets/webphoto.png"/>

⚙️ Installation

Note for Windows users: Several Windows users have been unable to install the natten package, which is required to use the HDiT model architecture in PMRF. A solution that worked for several people is suggested here. If you cannot resolve this issue, you may train PMRF using a different architecture (e.g., UNet) and avoid the natten package.

We created a conda environment by running the following commands, exactly in the given order (the same commands are provided in the install.sh file):

conda create -n pmrf python=3.10
conda activate pmrf
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=11.8 -c pytorch -c nvidia
conda install lightning==2.3.3 -c conda-forge
pip install opencv-python==4.10.0.84 timm==1.0.8 wandb==0.17.5 lovely-tensors==0.1.16 torch-fidelity==0.3.0 einops==0.8.0 dctorch==0.1.2 torch-ema==0.3
pip install natten==0.17.1+torch230cu118 -f https://shi-labs.com/natten/wheels
pip install nvidia-cuda-nvcc-cu11
pip install basicsr==1.4.2
pip install git+https://github.com/toshas/torch-fidelity.git
pip install lpips==0.1.4
pip install piq==0.8.0
pip install huggingface_hub==0.24.5
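Because the environment depends on exact versions, it can help to confirm what was actually installed. The following is a minimal sketch using only the Python standard library (importlib.metadata). The PINS dictionary copies a subset of the pins from the commands above, and the helper names are hypothetical, not part of the PMRF repository. Local version labels such as +cu118 are ignored in the comparison.

```python
"""Compare installed package versions against the pins used in the commands above.

Minimal sketch: PINS is a subset copied from the install commands, and
installed_version / find_mismatches are hypothetical helpers.
"""
from importlib import metadata
from typing import Callable, Dict, List, Optional

# Subset of the pinned versions from the installation commands.
PINS: Dict[str, str] = {
    "torch": "2.3.1",
    "torchvision": "0.18.1",
    "lightning": "2.3.3",
    "timm": "1.0.8",
    "einops": "0.8.0",
    "basicsr": "1.4.2",
    "lpips": "0.1.4",
    "piq": "0.8.0",
    "huggingface_hub": "0.24.5",
}


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], Optional[str]] = installed_version,
) -> List[str]:
    """Describe every pinned package that is missing or differs from its pin."""
    problems = []
    for name, wanted in sorted(pins.items()):
        found = get_version(name)
        if found is None:
            problems.append(f"{name}: not installed (expected {wanted})")
        elif found.split("+", 1)[0] != wanted:  # drop local labels like "+cu118"
            problems.append(f"{name}: found {found}, expected {wanted}")
    return problems


if __name__ == "__main__":
    for line in find_mismatches(PINS) or ["all pinned packages match"]:
        print(line)
```

Run it inside the activated pmrf environment. Each printed line points to a package whose installed version differs from the pin.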
  1. Note that the package natten is required for the HDiT architecture used by PMRF. Make sure to replace natten==0.17.1+torch230cu118 with the wheel that matches the PyTorch and CUDA versions installed on your system. Check out https://shi-labs.com/natten/ for the available versions.
  2. We installed nvidia-cuda-nvcc-cu11 because otherwise torch.compile hung on our system for an unknown reason. torch.compile may work on your system without this package. In any case, you may skip this package and/or remove all the torch.compile lines from our code.
  3. Due to a compatibility issue in basicsr, you will need to modify one of the files in this package. Open /path/to/env/pmrf/lib/python3.10/site-packages/basicsr/data/degradations.py, where /path/to/env is the directory in which conda created the pmrf environment. Then, change the line
from torchvision.transforms.functional_tensor import rgb_to_grayscale

to

from torchvision.transforms.functional import rgb_to_grayscale
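The same one-line edit can be scripted. This is a minimal sketch, assuming basicsr is installed in the active environment. It locates the package with importlib.util.find_spec, which for a top-level package does not execute basicsr's __init__ and therefore does not trigger the broken import. The helper names are hypothetical.

```python
"""Apply the one-line basicsr fix described above (sketch; helper names are hypothetical)."""
import importlib.util
from pathlib import Path

OLD = "from torchvision.transforms.functional_tensor import rgb_to_grayscale"
NEW = "from torchvision.transforms.functional import rgb_to_grayscale"


def patch_file(path: Path) -> bool:
    """Replace the broken import in `path`; return True if the file was changed."""
    text = path.read_text(encoding="utf-8")
    if OLD not in text:
        return False  # already patched, or a basicsr version without this line
    path.write_text(text.replace(OLD, NEW), encoding="utf-8")
    return True


def degradations_path() -> Path:
    """Locate basicsr/data/degradations.py without importing basicsr itself."""
    spec = importlib.util.find_spec("basicsr")
    if spec is None or spec.origin is None:
        raise RuntimeError("basicsr is not installed in this environment")
    return Path(spec.origin).parent / "data" / "degradations.py"


if __name__ == "__main__":
    try:
        target = degradations_path()
    except RuntimeError as err:
        print(err)
    else:
        print("patched" if patch_file(target) else "nothing to patch")
```

Running the script twice is harmless: the second run finds nothing to replace.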

⬇️ Downloads

🌐 Model checkpoints

We provide our blind face image restoration model checkpoint on Hugging Face and on Google Drive. The checkpoints for section 5.2 in the paper (the controlled experiments) can be downloaded from Google Drive. Please keep the same folder structure as in Google Drive:

checkpoints/
├── blind_face_restoration_pmrf.ckpt    # Checkpoint of our blind face image restoration model.
├── swinir_restoration512_L1.pth    # Checkpoint of the SwinIR model trained by DifFace
├── controlled_experiments/     # Checkpoints for the controlled experiments
│   ├── colorization_gaussian_noise_025/
│   │   ├── pmrf/
│   │   │   └── epoch=999-step=273000.ckpt
│   │   ├── mmse/
│   │   │   └── epoch=999-step=273000.ckpt
.   .   .
.   .   .
.   .   .
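To catch a misplaced download early, the following minimal sketch checks the files named explicitly in the tree above. The elided entries are deliberately not guessed, so extend EXPECTED with the experiments you actually downloaded. missing_files is a hypothetical helper, not part of the PMRF code.

```python
"""Check that downloaded checkpoints follow the folder structure shown above (sketch)."""
from pathlib import Path
from typing import Iterable, List

# Only the files named explicitly in the tree; add the elided experiments you use.
EXPECTED = [
    "blind_face_restoration_pmrf.ckpt",
    "swinir_restoration512_L1.pth",
    "controlled_experiments/colorization_gaussian_noise_025/pmrf/epoch=999-step=273000.ckpt",
    "controlled_experiments/colorization_gaussian_noise_025/mmse/epoch=999-step=273000.ckpt",
]


def missing_files(root: Path, relative_paths: Iterable[str]) -> List[str]:
    """Return the relative paths that do not exist as regular files under `root`."""
    return [rel for rel in relative_paths if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = missing_files(Path("checkpoints"), EXPECTED)
    print("missing:", missing if missing else "none")
```

Run it from the repository root, where the checkpoints/ folder is expected to live.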

To evaluate the landmark distance (LMD in the paper) and the identity metric (Deg in the paper), you will also need to download the resnet18_110.pth and alignment_WFLW_4HG.pth checkpoints from the Google Drive of VQFR. Place these checkpoints in the evaluation/metrics_ckpt/ folder.

🌐 Test data sets for blind face image restoration

  1. Download WebPhoto-Test, LFW-Test, and CelebA-Test (HQ and LQ) from https://xinntao.github.io/projects/gfpgan.
  2. Download WIDER-Test from https://shangchenzhou.com/projects/CodeFormer/.
  3. Put these data sets wherever you want in your system.

🧑 Blind face image restoration (section 5.1 in the paper)

⚡ Quick inference

To quickly use our model, we provide a Hugging Face checkpoint which is automatically downloaded. Simply run

python inference.py \
--ckpt_path ohayonguy/PMRF_blind_face_image_restoration \
--ckpt_path_is_huggingface \
--lq_data_path /path/to/lq/images \
--output_dir /path/to/results/dir \
--batch_size 64 \
--num_flow_steps 25

You may change --num_flow_steps as you wish (this is the hyper-parameter K in our paper).
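The role of K can be illustrated with a generic flow sampler. The learned ODE dx/dt = v(x, t) is integrated from t = 0 to t = 1 in K equal steps, so a larger K means more network evaluations and a finer approximation of the flow. This is a minimal sketch, assuming a plain fixed-step Euler solver. `velocity` is a hypothetical stand-in for the trained network, and this is not the code in inference.py. In PMRF, as the method's name indicates, the flow starts from the posterior-mean (MMSE) prediction rather than from pure noise.

```python
"""Fixed-step Euler integration of a flow ODE (illustrative sketch, not inference.py)."""
from typing import Callable, List

Vector = List[float]


def euler_flow(
    x0: Vector,
    velocity: Callable[[Vector, float], Vector],
    num_steps: int,
) -> Vector:
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 with `num_steps` equal Euler steps.

    `num_steps` plays the role of K: each step costs one call to `velocity`
    (in a real model, one network evaluation).
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    dt = 1.0 / num_steps
    x = list(x0)
    for k in range(num_steps):
        t = k * dt
        v = velocity(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


if __name__ == "__main__":
    start, target = [0.0, 1.0], [2.0, -1.0]

    def straight_line_velocity(x: Vector, t: float) -> Vector:
        # Hypothetical field: constant velocity along the straight path start -> target,
        # the idealized case that rectified flows are trained to approximate.
        return [b - a for a, b in zip(start, target)]

    print(euler_flow(start, straight_line_velocity, num_steps=25))
```

For a perfectly straight flow, a single step would already suffice. The straighter the learned flow, the less a small K costs in accuracy.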

You may also provide a local model checkpoint (e.g., if you train your own PMRF model, or if you wish to use our Google Drive checkpoint instead of the Hugging Face one). Simply run

python inference.py \
--ckpt_path ./checkpoints/blind_face_restoration_pmrf.ckpt \
--lq_data_path /path/to/lq/images \
--output_dir /path/to/results/dir \
--batch_size 64 \
--num_flow_steps 25
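If you prefer to launch inference from Python (e.g., to sweep over several values of --num_flow_steps), the commands above can be assembled as an argument list. This is a minimal sketch that uses only the flags shown in this README. build_inference_cmd is a hypothetical helper. sys.executable is used instead of a bare python so that the interpreter of the active environment is reused.

```python
"""Build the inference.py command lines shown above from Python (sketch)."""
import sys
from typing import List


def build_inference_cmd(
    ckpt_path: str,
    lq_data_path: str,
    output_dir: str,
    batch_size: int = 64,
    num_flow_steps: int = 25,
    from_huggingface: bool = False,
) -> List[str]:
    """Return an argv list equivalent to the shell commands in this section."""
    cmd = [sys.executable, "inference.py", "--ckpt_path", ckpt_path]
    if from_huggingface:
        cmd.append("--ckpt_path_is_huggingface")
    cmd += [
        "--lq_data_path", lq_data_path,
        "--output_dir", output_dir,
        "--batch_size", str(batch_size),
        "--num_flow_steps", str(num_flow_steps),
    ]
    return cmd


if __name__ == "__main__":
    cmd = build_inference_cmd(
        "ohayonguy/PMRF_blind_face_image_restoration",
        "/path/to/lq/images",
        "/path/to/results/dir",
        from_huggingface=True,
    )
    print(" ".join(cmd))
```

To execute the command from the repository root, pass the list to subprocess.run(cmd, check=True).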

Importantly, note that our blind face image restoration model is trained to handle square, aligned face images. To restore general images that contain faces (e.g., images with more than one face), you may use our Hugging Face demo.
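Because the model expects square, aligned crops, a quick pre-check for non-square inputs can save a wasted run. This is a minimal, standard-library-only sketch. It reads the width and height from the IHDR chunk of PNG files and skips other formats. png_size and non_square_pngs are hypothetical helpers, and the check cannot tell whether a face is aligned.

```python
"""Flag PNG inputs that are not square before running the face model (sketch)."""
import struct
from pathlib import Path
from typing import List, Optional, Tuple

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_size(path: Path) -> Optional[Tuple[int, int]]:
    """Return (width, height) for a PNG file, or None if it is not a PNG."""
    with open(path, "rb") as f:
        header = f.read(24)
    # Layout: 8-byte signature, 4-byte chunk length, b"IHDR", then big-endian width and height.
    if len(header) < 24 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        return None
    width, height = struct.unpack(">II", header[16:24])
    return width, height


def non_square_pngs(folder: Path) -> List[str]:
    """List PNG file names in `folder` whose width differs from their height."""
    bad = []
    for path in sorted(folder.glob("*.png")):
        size = png_size(path)
        if size is not None and size[0] != size[1]:
            bad.append(path.name)
    return bad
```

For example, print(non_square_pngs(Path("/path/to/lq/images"))) lists the PNG files that would need cropping first.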

🔬 Evaluation

We downloaded the resnet18_110.pth and alignment_WFLW_4HG.pth checkpoints from the Google Drive of VQFR and put them in the folder evaluation/metrics_ckpt/. To evaluate the results on CelebA-Test, run:
cd evaluation
python compute_metrics_blind.py \
--parent_ffhq_512_path /path/to/parent/of/ffhq512 \
--rec_path /path/to/celeba-512-test/restored/images \
--gt_path /path/to/celeba-512-test/ground-truth/images

To evaluate the results on the real-world data sets, run:

cd evaluation
python compute_metrics_blind.py \
--parent_ffhq_512_path /path/to/parent/of/ffhq512 \
--rec_path /path/to/real-world/restored/images \
--mmse_rec_path /path/to/mmse/restored/images

The --mmse_rec_path argument is optional. It allows you to compute IndRMSE, an indicator of the true RMSE for real-world degraded images, for which no ground truth is available. Note that the MMSE reconstructions are saved automatically when you run inference.py, since the MMSE model is included in the PMRF checkpoint.
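Why is a distance to the MMSE output informative? For any estimator X̂ computed from Y, the orthogonality principle gives E‖X - X̂‖² = E‖X - E[X|Y]‖² + E‖E[X|Y] - X̂‖². With the MMSE model standing in for E[X|Y], the RMSE between a reconstruction and the MMSE reconstruction therefore tracks the error that a method adds on top of the minimum. The sketch below computes such an RMSE over flattened pixel values. Treating IndRMSE as exactly this quantity, and the pixel scaling, are assumptions, so compute_metrics_blind.py remains the reference implementation.

```python
"""RMSE between a restored image and the corresponding MMSE reconstruction (sketch)."""
import math
from typing import Sequence


def rmse(a: Sequence[float], b: Sequence[float]) -> float:
    """Root-mean-squared difference between two equally sized, flattened images."""
    if len(a) != len(b) or len(a) == 0:
        raise ValueError("inputs must be non-empty and of the same length")
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))
```

Averaging rmse(restored, mmse) over all images of a data set gives a data-set-level indicator. Whether the reference script averages per image or over all pixels is not specified here.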

💻 Training

The scripts/ folder contains the training scripts we used for blind face image restoration, as well as those for training the baseline models. Each script must be executed from the root folder (where train.py is located). Training requires the FFHQ data set: we downloaded the original 1024x1024 FFHQ data set and down-sampled the images to 512x512 using bicubic down-sampling.

  1. Copy the train_pmrf.sh file (located in scripts/train/blind_face_restoration) to the root folder.
  2. Adjust the arguments --train_data_root and --val_data_root according to the location of the training and validation data in your system.
  3. The SwinIR model which was trained by DifFace is provided in the checkpoints/ folder. We downloaded it via
wget https://github.com/zsyOAOA/DifFace/releases/download/V1.0/swinir_restoration512_L1.pth
  4. Adjust the argument --mmse_model_ckpt_path to the path of the SwinIR model.
  5. Adjust the arguments --num_gpus and --num_workers according to your system.
  6. Run the script train_pmrf.sh to train our model.
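Steps 2, 4 and 5 above can also be applied programmatically to the copied script. This is a minimal sketch, assuming each argument appears in train_pmrf.sh as `--name value` separated by whitespace, which is how the flags are written in this README. set_arg is a hypothetical helper, and every value in UPDATES except the SwinIR checkpoint path is a placeholder.

```python
"""Update `--name value` arguments in a copy of train_pmrf.sh (sketch)."""
import re
from pathlib import Path

# Placeholder values: adjust to your system before running.
UPDATES = {
    "train_data_root": "/path/to/ffhq512/train",
    "val_data_root": "/path/to/ffhq512/val",
    "mmse_model_ckpt_path": "./checkpoints/swinir_restoration512_L1.pth",
    "num_gpus": "1",
    "num_workers": "8",
}


def set_arg(script: str, name: str, value: str) -> str:
    """Replace the value that follows `--name` in `script`; raise KeyError if the flag is absent."""
    # Require whitespace after the flag name so that e.g. --num_gpus does not match --num_gpus_x.
    pattern = re.compile(rf"(--{re.escape(name)}[ \t]+)(\S+)")
    new_script, count = pattern.subn(lambda m: m.group(1) + value, script)
    if count == 0:
        raise KeyError(f"--{name} not found in script")
    return new_script


if __name__ == "__main__":
    script_path = Path("train_pmrf.sh")  # the copy in the repository root (step 1)
    if script_path.exists():
        text = script_path.read_text()
        for flag, val in UPDATES.items():
            text = set_arg(text, flag, val)
        script_path.write_text(text)
```

Trailing shell line continuations (` \`) are preserved, because only the token right after each flag is replaced.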

👩‍🔬 Controlled experiments (section 5.2 in the paper)

We provide training and evaluation code for the controlled experiments in our paper, where we compare PMRF with the following baseline methods:

  1. Flow conditioned on Y: A rectified flow model which is conditioned on the input measurement, and learns to flow from pure noise to the ground-truth data distribution.
  2. Flow conditioned on the posterior mean predictor: A rectified flow model which is conditioned on the posterior mean prediction, and learns to flow from pure noise to the ground-truth data distribution.
  3. Flow from Y: A rectified flow model which flows