ConfDiff
Official Implementation of ConfDiff (ICML'24) - Protein Conformation Generation via Force-Guided SE(3) Diffusion Models
<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a> <a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a> <a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a> <a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a><br>
</div> <!-- <div align="center"> <img src="assets/bpti.gif" alt="Image description" style="width:600px;"> </div> -->
This repository is the official implementation of the ICML'24 paper Protein Conformation Generation via Force-Guided SE(3) Diffusion Models, which introduces ConfDiff, a force-guided SE(3) diffusion model for protein conformation generation. ConfDiff generates diverse protein conformations while preserving high fidelity. Physics-based energy and force guidance steers the diffusion sampler toward low-energy conformations that better match the underlying Boltzmann distribution.
<!-- Based on the protein backbone model from ICML, we extend the model to a **full-atom model** by incorporating side-chain predictions (no need for packing side-chains when calculating energy). Additionally, we also include the test results on ATLAS dataset. -->
With recent progress in protein conformation prediction, we extend ConfDiff to ConfDiff-FullAtom, a set of diffusion models for full-atom protein conformation prediction. The current models include the following updates:
- Integrated a regression module that predicts atomic coordinates for side-chain heavy atoms.
- Provided model options with four folding-model representations: ESMFold or OpenFold, each with 0 or 3 recycles.
- Used all feature outputs of the folding model (node and edge) to train the diffusion model.
- Released a version of the sequence-conditional models fine-tuned on the ATLAS MD dataset.
Installation
# clone project
git clone https://url/to/this/repo/ConfDiff.git
cd ConfDiff
# create conda virtual environment
conda env create -f env.yml
conda activate confdiff
# install openfold
git clone https://github.com/aqlaboratory/openfold.git
pip install -e openfold
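After installation, a quick way to confirm that the environment can locate the main dependencies is to look them up with Python's `importlib`. This is a minimal sketch, not part of the repository; the module names checked below (`torch`, `hydra`, `openfold`) are assumptions about what the environment should provide, and locating a module does not guarantee that its compiled extensions load.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be located in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumed module names; adjust them to match env.yml if they differ.
    missing = missing_modules(["torch", "hydra", "openfold"])
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All checked modules were found.")
```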
Pretrained Representation
We precompute ESMFold and OpenFold representations as model inputs. See the README in the pretrained_repr/ folder for the detailed generation pipeline.
ConfDiff-BASE
ConfDiff-BASE employs a sequence-based conditional score network to guide an unconditional score model using classifier-free guidance, enabling diverse conformation sampling while ensuring structural fidelity to the input sequence.
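How a guidance strength can mix the two score networks is sketched below. It assumes linear interpolation, `s = w * s_cond + (1 - w) * s_uncond`, which is the same as `s_uncond + w * (s_cond - s_uncond)`; whether `clsfree_guidance_strength` enters ConfDiff's score network exactly this way is an assumption. Plain Python lists stand in for per-residue score tensors.

```python
def classifier_free_score(s_cond, s_uncond, strength):
    """Blend conditional and unconditional scores: strength * s_cond + (1 - strength) * s_uncond.

    strength = 1.0 returns the conditional score; strength = 0.0 returns the unconditional one.
    Values below 1.0 lean toward the unconditional model, which typically trades some
    sequence fidelity for more diverse samples.
    """
    if len(s_cond) != len(s_uncond):
        raise ValueError("score vectors must have the same length")
    return [strength * c + (1.0 - strength) * u for c, u in zip(s_cond, s_uncond)]


if __name__ == "__main__":
    # Toy 3-element "scores" with a strength of 0.8.
    print(classifier_free_score([1.0, 2.0, -1.0], [0.0, 1.0, 1.0], 0.8))
```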
Prepare datasets
We train ConfDiff-BASE on protein structures from the Protein Data Bank (PDB) and evaluate it on several datasets: fast-folding, BPTI, apo-holo, and ATLAS.
Details on dataset preparation and evaluation can be found in the dataset folder.
The following datasets and precomputed representations are required to train ConfDiff-BASE:
- RCSB PDB dataset: see `dataset/rcsb` for details. Once prepared, specify `csv_path` and `pdb_dir` in the configuration file `configs/paths/default.yaml`.
- ESMFold or OpenFold representations: see `pretrained_repr` for details. Once prepared, specify the `data_root` of `esmfold_repr`/`openfold_repr` in the configuration file `configs/paths/default.yaml`.
Training
ConfDiff-BASE consists of a sequence-conditional model and an unconditional model.
To train the conditional model:
python3 src/train.py \
task_name=cond \
experiment=full_atom \
data/dataset=rcsb \
data/repr_loader=openfold \
data.repr_loader.num_recycles=3 \
data.train_batch_size=4
The detailed training configuration can be found in configs/experiment/full_atom.yaml.
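Each `key=value` argument in the training commands is a Hydra override: `experiment=full_atom` selects `configs/experiment/full_atom.yaml`, a key containing `/` such as `data/repr_loader=openfold` selects an option from a nested config group, and a dotted key such as `data.repr_loader.num_recycles=3` sets a single field. If you launch runs from a script, a small helper like the hypothetical `build_hydra_command` below keeps the overrides in one dict; it is a sketch, not a utility shipped with the repository.

```python
import shlex


def build_hydra_command(script, overrides):
    """Build an argv list such as ['python3', 'src/train.py', 'task_name=cond', ...]."""
    return ["python3", script] + [f"{key}={value}" for key, value in overrides.items()]


# Overrides copied from the conditional-model training command above.
COND_TRAIN_OVERRIDES = {
    "task_name": "cond",
    "experiment": "full_atom",
    "data/dataset": "rcsb",
    "data/repr_loader": "openfold",
    "data.repr_loader.num_recycles": 3,
    "data.train_batch_size": 4,
}

if __name__ == "__main__":
    argv = build_hydra_command("src/train.py", COND_TRAIN_OVERRIDES)
    print(shlex.join(argv))  # inspect the full command before launching it
```

To actually start a run, pass `argv` to `subprocess.run(argv, check=True)` from the repository root.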
To train the unconditional model:
python3 src/train.py \
task_name=uncond \
experiment=uncond_model \
data/dataset=rcsb \
data.train_batch_size=4
The detailed training configuration can be found in configs/experiment/uncond_model.yaml.
Model Checkpoints
Pretrained ConfDiff-BASE models are available with different pretrained representations:
| Model name | Repr type | Num. of recycles |
|----------------------|----------|------------------|
| ConfDiff-ESM-r0-COND | ESMFold | 0 |
| ConfDiff-ESM-r3-COND | ESMFold | 3 |
| ConfDiff-OF-r0-COND | OpenFold | 0 |
| ConfDiff-OF-r3-COND | OpenFold | 3 |
| ConfDiff-UNCOND | / | / |
<!-- The results of ConfDiff-BASE with different pretrained representations are shown in the table below. -->
Inference
To sample conformations using the ConfDiff-BASE model:
# Note: the model checkpoint and the representation must be compatible.
python3 src/eval.py \
task_name=eval_base_bpti \
experiment=clsfree_guide \
data/repr_loader=openfold \
data.repr_loader.num_recycles=3 \
paths.guidance.cond_ckpt=/path/to/your/cond_model \
paths.guidance.uncond_ckpt=/path/to/your/uncond_model \
data.dataset.test_gen_dataset.csv_path=/path/to/your/testset_csv \
data.dataset.test_gen_dataset.num_samples=1000 \
data.gen_batch_size=20 \
model.score_network.cfg.clsfree_guidance_strength=0.8
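The comment in the command above says the checkpoint and the representation must be compatible. Using the checkpoint table, that rule can be checked before launching a run; ConfDiff-UNCOND does not use a representation, so only the conditional checkpoint needs checking. This is a minimal sketch: it assumes the conditional checkpoint is identified by the model names in the table (actual checkpoint file names may differ), and `check_base_compatibility` is a hypothetical helper, not part of the repository.

```python
# (repr_loader option, num_recycles) for each conditional checkpoint in the table above.
BASE_COND_CHECKPOINTS = {
    "ConfDiff-ESM-r0-COND": ("esmfold", 0),
    "ConfDiff-ESM-r3-COND": ("esmfold", 3),
    "ConfDiff-OF-r0-COND": ("openfold", 0),
    "ConfDiff-OF-r3-COND": ("openfold", 3),
}


def check_base_compatibility(cond_model, repr_loader, num_recycles):
    """Raise if data/repr_loader and data.repr_loader.num_recycles do not match the checkpoint."""
    if cond_model not in BASE_COND_CHECKPOINTS:
        raise KeyError(f"unknown conditional model: {cond_model}")
    expected = BASE_COND_CHECKPOINTS[cond_model]
    if (repr_loader, num_recycles) != expected:
        raise ValueError(
            f"{cond_model} expects data/repr_loader={expected[0]} "
            f"and data.repr_loader.num_recycles={expected[1]}, "
            f"got {repr_loader} and {num_recycles}"
        )


if __name__ == "__main__":
    # Matches the example command (OpenFold representation, 3 recycles).
    check_base_compatibility("ConfDiff-OF-r3-COND", "openfold", 3)
    print("compatible")
```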
ConfDiff-FORCE/ENERGY
By utilizing prior information from the MD force field, our model reweights the generated conformations so that they better follow the equilibrium distribution.
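To make the reweighting idea concrete: suppose, for illustration only, that the target is the ConfDiff-BASE distribution tilted by the Boltzmann factor `exp(-E/kT)`. Samples drawn from ConfDiff-BASE would then receive self-normalized importance weights proportional to `exp(-E/kT)`. The sketch below computes such weights from energies in kJ/mol at an assumed temperature of 300 K; it illustrates the target, not the ConfDiff training or sampling procedure.

```python
import math

# Boltzmann constant in kJ/(mol*K), matching OpenMM's default kJ/mol energy units.
KB_KJ_PER_MOL_K = 0.0083144626


def boltzmann_weights(energies_kj_mol, temperature_k=300.0):
    """Self-normalized weights w_i = exp(-E_i/kT) / sum_j exp(-E_j/kT)."""
    kt = KB_KJ_PER_MOL_K * temperature_k
    log_w = [-e / kt for e in energies_kj_mol]
    shift = max(log_w)  # subtract the max (log-sum-exp trick) to avoid overflow
    unnorm = [math.exp(lw - shift) for lw in log_w]
    total = sum(unnorm)
    return [u / total for u in unnorm]


if __name__ == "__main__":
    # Three hypothetical samples; the lowest-energy one receives the largest weight.
    print(boltzmann_weights([-1200.0, -1195.0, -1180.0]))
```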
Data
Protein conformations with force or energy labels are required to train ConfDiff-FORCE or ConfDiff-ENERGY, respectively. We use OpenMM to evaluate the energies and forces of conformation samples generated by ConfDiff-BASE.
To evaluate force and energy labels using OpenMM and prepare the training data:
python3 src/utils/protein/openmm_energy.py \
--input-root /path/to/your/generated/samples \
--output-root /path/to/your/output_dir
The output directory /path/to/your/output_dir contains force annotation files (with the suffix *force.npy), energy-optimized PDB files, and train and validation CSV files with energy labels.
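Before training, it can help to confirm that the OpenMM step produced what training expects. The sketch below counts the `*force.npy` annotation files and checks for `train.csv` and `val.csv`, the file names used for `paths.guidance.train_csv` and `paths.guidance.val_csv` in the ConfDiff-FORCE training command. Searching subdirectories recursively is an assumption, since the exact directory layout is not documented here, and `summarize_openmm_output` is a hypothetical helper.

```python
import sys
from pathlib import Path


def summarize_openmm_output(output_root):
    """Count force annotation files and report which split CSVs are missing."""
    root = Path(output_root)
    # Recursive search: force files may sit in per-protein subfolders (assumed layout).
    force_files = sorted(root.rglob("*force.npy"))
    missing_csvs = [name for name in ("train.csv", "val.csv") if not (root / name).is_file()]
    return {"num_force_files": len(force_files), "missing_csvs": missing_csvs}


if __name__ == "__main__":
    if len(sys.argv) > 1:
        print(summarize_openmm_output(sys.argv[1]))
    else:
        print("usage: python check_openmm_output.py /path/to/your/output_dir")
```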
Training
Before training, please ensure that the pretrained representations for the training proteins have been prepared.
To train ConfDiff-FORCE:
# Example: training ConfDiff-FORCE
python3 src/train.py \
experiment=force_guide \
data/repr_loader=esmfold \
data.repr_loader.num_recycles=3 \
paths.guidance.cond_ckpt=/path/to/your/cond_model \
paths.guidance.uncond_ckpt=/path/to/your/uncond_model \
paths.guidance.train_csv=/path/to/your/output_dir/train.csv \
paths.guidance.val_csv=/path/to/your/output_dir/val.csv \
paths.guidance.pdb_dir=/path/to/your/output_dir/ \
data.train_batch_size=4
Similarly, the ConfDiff-ENERGY model can be trained by setting experiment=energy_guide.
Detailed training configurations can be found in configs/experiment/force_guide.yaml (or energy_guide.yaml).
Model Checkpoints
Pretrained ConfDiff-FORCE/ENERGY models with ESMFold representations are available for different datasets:
| Model name | Dataset | Repr type | Num. of recycles |
|------------------------|--------------|-----------|------------------|
| ConfDiff-ESM-r0-FORCE | fast-folding | ESMFold | 0 |
| ConfDiff-ESM-r0-ENERGY | fast-folding | ESMFold | 0 |
| ConfDiff-ESM-r3-FORCE | bpti | ESMFold | 3 |
| ConfDiff-ESM-r3-ENERGY | bpti | ESMFold | 3 |
We found that using only the node representation on the fast-folding dataset yields better results. To train models with only the node representation, set data.repr_loader.edge_size=0.
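The table above and the note on node-only representations can be combined into the eval overrides for each pretrained FORCE/ENERGY checkpoint. This is a hypothetical helper, assuming the checkpoint is passed via `ckpt_path` and that the FORCE and ENERGY variants use `experiment=force_guide` and `experiment=energy_guide` respectively, as in the training instructions.

```python
# (dataset, num_recycles) for each pretrained checkpoint in the table above (all use ESMFold).
GUIDED_CHECKPOINTS = {
    "ConfDiff-ESM-r0-FORCE": ("fast-folding", 0),
    "ConfDiff-ESM-r0-ENERGY": ("fast-folding", 0),
    "ConfDiff-ESM-r3-FORCE": ("bpti", 3),
    "ConfDiff-ESM-r3-ENERGY": ("bpti", 3),
}


def guided_eval_overrides(model_name, ckpt_path):
    """Return Hydra overrides matching a pretrained ConfDiff-FORCE/ENERGY checkpoint."""
    dataset, num_recycles = GUIDED_CHECKPOINTS[model_name]
    experiment = "force_guide" if model_name.endswith("-FORCE") else "energy_guide"
    overrides = {
        "experiment": experiment,
        "data/repr_loader": "esmfold",
        "data.repr_loader.num_recycles": num_recycles,
        "ckpt_path": ckpt_path,
    }
    if dataset == "fast-folding":
        # Per the note above, the fast-folding checkpoints use node representations only.
        overrides["data.repr_loader.edge_size"] = 0
    return overrides


if __name__ == "__main__":
    print(guided_eval_overrides("ConfDiff-ESM-r0-ENERGY", "/path/to/your/model/ckpt/"))
```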
Inference
To sample conformations using the ConfDiff-FORCE/ENERGY model:
# Example: sampling with ConfDiff-FORCE
python3 src/eval.py \
task_name=eval_force \
experiment=force_guide \
data/repr_loader=esmfold \
data.repr_loader.num_recycles=3 \
ckpt_path=/path/to/your/model/ckpt/ \
data.dataset.test_gen_dataset.csv_path=/path/to/your/testset_csv \
data.dataset.test_gen_dataset.num_samples=1000 \
data.gen_batch_size=20 \
model.score_network.cfg.clsfree_guidance_strength=0.8 \
model.score_network.cfg.force_guidance_strength=1.0
# Add data.repr_loader.edge_size=0 when using the pretrained fast-folding checkpoints.
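The two strength parameters in this command suggest a score that combines classifier-free guidance with a force (or energy) term. The sketch below assumes the simplest such composition, where the guidance model's output is scaled by `force_guidance_strength` and added to the classifier-free score; the exact form used in ConfDiff is an assumption here, so consult the score network code before relying on it. Lists stand in for per-residue tensors, and `guidance_term` is a hypothetical stand-in for the output of the trained FORCE/ENERGY guidance model.

```python
def force_guided_score(cfg_score, guidance_term, force_strength):
    """Add a scaled guidance term to the classifier-free score, elementwise.

    Under this assumed form, force_strength = 0.0 leaves the classifier-free score
    unchanged; larger values push samples harder along the guidance direction.
    """
    if len(cfg_score) != len(guidance_term):
        raise ValueError("score and guidance vectors must have the same length")
    return [s + force_strength * g for s, g in zip(cfg_score, guidance_term)]


if __name__ == "__main__":
    # force_guidance_strength=1.0, as in the command above.
    print(force_guided_score([0.8, 1.6], [0.1, -0.2], 1.0))
```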
Fine-tuning on ATLAS
See `datasets/
