Ankh
Ankh: Optimized Protein Language Model
Ankh is the first general-purpose protein language model trained on Google's TPU-v4. It surpasses state-of-the-art performance with dramatically fewer parameters, promoting accessibility to research innovation via attainable resources.
<div align="center"><img width=500 height=350 src="https://github.com/agemagician/Ankh/blob/main/images/AnkhGIF.gif?raw=true"></div>

This repository will be updated regularly with new pre-trained models for proteins, as part of supporting the biotech community in revolutionizing protein engineering using AI.
Table of Contents
- Installation
- Models Availability
- Dataset Availability
- Usage
- Original downstream Predictions
- Followup use-cases
- Comparisons to other tools
- Community and Contributions
- Have a question?
- Found a bug?
- Requirements
- Sponsors
- Team
- License
- Citation
<a name="install"></a>
Installation
python -m pip install ankh
<a name="models"></a>
Models Availability
| Model | Ankh | Huggingface |
|------------------------------------|-----------------------------------|-------------------------------------------------------------|
| Ankh Large | ankh.load_large_model() |Ankh Large |
| Ankh Base | ankh.load_base_model() |Ankh Base |
| Ankh3 Large | ankh.load_ankh3_large() |Ankh3 Large|
| Ankh3 XL | ankh.load_ankh3_xl() |Ankh3 XL |
<a name="datasets"></a>
Dataset Availability
| Dataset | Huggingface |
| ----------------------------- |---------------------------------------------------------------------------------------------------|
| Remote Homology | load_dataset("proteinea/remote_homology") |
| CASP12 | load_dataset("proteinea/secondary_structure_prediction", data_files={'test': ['CASP12.csv']})|
| CASP14 | load_dataset("proteinea/secondary_structure_prediction", data_files={'test': ['CASP14.csv']})|
| CB513 | load_dataset("proteinea/secondary_structure_prediction", data_files={'test': ['CB513.csv']}) |
| TS115 | load_dataset("proteinea/secondary_structure_prediction", data_files={'test': ['TS115.csv']}) |
| DeepLoc | load_dataset("proteinea/deeploc") |
| Fluorescence | load_dataset("proteinea/fluorescence") |
| Solubility | load_dataset("proteinea/solubility") |
| Nearest Neighbor Search | load_dataset("proteinea/nearest_neighbor_search") |
<a name="usage"></a>
Usage
- Loading pre-trained models:
import ankh
# Load Ankh base.
model, tokenizer = ankh.load_ankh_base()
model.eval()
# Load Ankh large.
model, tokenizer = ankh.load_ankh_large()
model.eval()
# Load Ankh3 Large
model, tokenizer = ankh.load_ankh3_large()
model.eval()
# Load Ankh3 XL
model, tokenizer = ankh.load_ankh3_xl()
model.eval()
- Feature extraction example using Ankh Large:
import torch
import ankh

model, tokenizer = ankh.load_ankh_large()
model.eval()

protein_sequences = [
    'MKALCLLLLPVLGLLVSSKTLCSMEEAINERIQEVAGSLIFRAISSIGLECQSVTSRGDLATCPRGFAVTGCTCGSACGSWDVRAETTCHCQCAGMDWTGARCCRVQPLEHHHHHH',
    'GSHMSLFDFFKNKGSAATATDRLKLILAKERTLNLPYMEEMRKEIIAVIQKYTKSSDIHFKTLDSNQSVETIEVEIILPR',
]

# Split each sequence into a list of residues, because the tokenizer
# is called with is_split_into_words=True.
protein_sequences = [list(seq) for seq in protein_sequences]

outputs = tokenizer(
    protein_sequences,
    add_special_tokens=True,
    padding=True,
    is_split_into_words=True,
    return_tensors="pt",
)

# Disable gradient tracking for inference.
with torch.no_grad():
    embeddings = model(input_ids=outputs['input_ids'], attention_mask=outputs['attention_mask'])
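
A common next step is to reduce the per-residue embeddings to one fixed-size vector per protein. The sketch below shows masked mean pooling on plain nested lists, so it runs without downloading a model. The function name `masked_mean_pool` and the toy values are illustrative and not part of the `ankh` package; it assumes hidden states shaped (batch, length, dim) and an attention mask that marks padding with 0.

```python
def masked_mean_pool(hidden_states, attention_mask):
    """Average the vectors at unmasked positions for each sequence.

    hidden_states: nested list shaped (batch, length, dim)
    attention_mask: nested list shaped (batch, length); 1 = real token, 0 = padding
    """
    pooled = []
    for vectors, mask in zip(hidden_states, attention_mask):
        # Keep only the vectors at positions the mask marks as real tokens.
        kept = [vec for vec, keep in zip(vectors, mask) if keep]
        if not kept:
            raise ValueError("sequence has no unmasked positions")
        dim = len(kept[0])
        pooled.append([sum(vec[d] for vec in kept) / len(kept) for d in range(dim)])
    return pooled


# Toy batch: two "proteins", the second padded to length 3.
hidden_states = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[2.0, 0.0], [4.0, 2.0], [9.0, 9.0]],  # last position is padding
]
attention_mask = [[1, 1, 1], [1, 1, 0]]

protein_embeddings = masked_mean_pool(hidden_states, attention_mask)
print(protein_embeddings)  # [[3.0, 4.0], [3.0, 1.0]]
```

With real tensors, the same reduction would apply to the model's per-residue hidden states (assuming a Hugging Face-style output exposing `last_hidden_state`) and `outputs['attention_mask']`. Note that the mask also covers any special tokens added by `add_special_tokens=True`, so those would need to be excluded first if a residue-only average is wanted.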
- Loading downstream models example:
import ankh

# Downstream model for binary classification.
# input_dim must match the embedding size of the encoder that produced the features.
binary_classification_model = ankh.ConvBERTForBinaryClassification(
    input_dim=768,
    nhead=4,
    hidden_dim=384,
    num_hidden_layers=1,
    num_layers=1,
    kernel_size=7,
    dropout=0.2,
    pooling='max',
)

# Downstream model for multiclass classification.
multiclass_classification_model = ankh.ConvBERTForMultiClassClassification(
    num_tokens=2,
    input_dim=768,
    nhead=4,
    hidden_dim=384,
    num_hidden_layers=1,
    num_layers=1,
    kernel_size=7,
    dropout=0.2,
)

# Downstream model for regression.
# training_labels_mean is optional: it initializes the output layer's bias
# with the mean of the training labels, which helps the model converge faster.
regression_model = ankh.ConvBERTForRegression(
    input_dim=768,
    nhead=4,
    hidden_dim=384,
    num_hidden_layers=1,
    num_layers=1,
    kernel_size=7,
    dropout=0,
    pooling='max',
    training_labels_mean=0.38145,
)
- Calculating likelihood:
import ankh

sequence = "MDDADPEERNYDNMLKMLSDLNKDLEKLLEEMEKISVQATWMAYDMVVMRTNPTLAESMRRLEDAFVNCKEEMEKNWQELLHETKQRL"
likelihood = ankh.compute_pseudo_likelihood(
    "ankh_base",
    sequence,
    device="cpu",
    shard_input=True,
    shard_batch_size=32,
    verbose=True,
)
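
For orientation, a pseudo-log-likelihood is commonly defined as the sum, over positions, of the log-probability the model assigns to the true residue when that position is masked. The sketch below illustrates that definition and the idea of scoring the masked variants in batches. It is not the implementation behind `ankh.compute_pseudo_likelihood`: `MASK`, `toy_log_prob`, and `pseudo_log_likelihood` are hypothetical names standing in for a real model call, and whether `batch_size` here corresponds to `shard_batch_size` is an assumption.

```python
import math

MASK = "<mask>"  # hypothetical placeholder token, not Ankh's actual sentinel


def toy_log_prob(masked_tokens, position, residue):
    """Stand-in scorer: a uniform distribution over the 20 standard amino acids."""
    return math.log(1.0 / 20.0)


def pseudo_log_likelihood(sequence, log_prob_fn, batch_size=32):
    """Sum log p(x_i | sequence with position i masked) over all positions."""
    tokens = list(sequence)

    # Build one masked variant per position.
    variants = []
    for i, residue in enumerate(tokens):
        masked = tokens[:i] + [MASK] + tokens[i + 1:]
        variants.append((masked, i, residue))

    total = 0.0
    # Score the variants in fixed-size batches to bound memory use.
    for start in range(0, len(variants), batch_size):
        batch = variants[start:start + batch_size]
        total += sum(log_prob_fn(m, i, r) for m, i, r in batch)
    return total


score = pseudo_log_likelihood("MDDADPEERN", toy_log_prob, batch_size=4)
print(score)
```

A real scorer would run each masked input through the model and read off the log-probability of the original residue. Higher (less negative) totals indicate sequences the model finds more plausible.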
<a name="results"></a>
Original downstream Predictions
<a name="q3"></a>
- <b> Secondary Structure Prediction (Q3):</b><br/>
| Model                    |      CASP12      |    CASP14     |     TS115     |    CB513     |
|--------------------------|:----------------:|:-------------:|:-------------:|:------------:|
| Ankh3 XLarge (NLU)       |      84.40%      |    82.19%     |       -       |      -       |
| Ankh3 XLarge (S2S)       |      83.76%      |    82.30%     |       -       |      -       |
| Ankh3 Large (NLU)        |      78.03%      |    79.28%     |       -       |      -       |
| Ankh3 Large (S2S)        |      75.49%      |    77.96%     |       -       |      -       |
| Ankh 2 Large             |      84.18%      |    76.82%     |    88.59%     |    88.78%    |
| Ankh Large               |      83.59%      |    77.48%     |    88.22%     |    88.48%    |
| Ankh Base                |      80.81%      |    76.67%     |    86.92%     |    86.94%    |
| ProtT5-XL-UniRef50       |      83.34%      |    75.09%     |    86.82%     |    86.64%    |
| ESM2-15B                 |      83.16%      |    76.56%     |    87.50%     |    87.35%    |
| ESM2-3B                  |      83.14%      |    76.75%     |    87.50%     |    87.44%    |
| ESM2-650M                |      82.43%      |    76.97%     |    87.22%     |    87.18%    |
| ESM-1b                   |      79.45%      |    75.39%     |    85.02%     |    84.31%    |
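
Q3 is per-residue accuracy over three secondary-structure states (helix, strand, coil). A minimal sketch of the metric follows. The function name `q_accuracy`, the `ignore` marker, and the toy label strings are illustrative; this version pools residues across proteins, and the exact evaluation protocol behind the numbers above (for example, how unresolved residues are handled) is not specified here.

```python
def q_accuracy(true_labels, predicted_labels, ignore="-"):
    """Fraction of residues whose predicted state matches the true state.

    Both arguments are lists of per-protein label strings of equal length.
    Positions whose true label equals `ignore` are skipped.
    """
    correct = 0
    total = 0
    for truth, pred in zip(true_labels, predicted_labels):
        if len(truth) != len(pred):
            raise ValueError("label strings must have the same length")
        for t, p in zip(truth, pred):
            if t == ignore:
                continue
            total += 1
            correct += t == p
    return correct / total if total else 0.0


truth = ["HHHHCCEEEE", "CCHH-H"]
preds = ["HHHCCCEEEC", "CCHHHH"]
print(f"Q3 = {q_accuracy(truth, preds):.2%}")  # Q3 = 86.67%
```

The same function works for eight-state labels, since it only compares characters position by position.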
<a name="q8"></a>
- <b> Secondary Structure Prediction (Q8):</b><br/>
| Model                    |      CASP12      |    CASP14     |     TS115     |    CB513     |
|--------------------------|:----------------:|:-------------:|:-------------:|:------------:|
| Ankh3 XLarge (NLU)       |      72.53%      |    69.85%     |       -       |      -       |
| Ankh3 XLarge (S2S)       |      72.25%      |    69.51%     |       -       |      -       |
| Ankh3 Large (NLU)        |      65.29%      |    65.50%     |       -       |      -       |
| Ankh3 Large (S2S)        |      62.74%      |    65.88%     |       -       |      -       |
| Ankh 2 Large             |      72.90%      |    62.84%     |    79.88%     |    79.01%    |
| Ankh Large               |      71.69%      |    63.17%     |    79.10%     |    78.45%    |
| Ankh Base                |      68.85%      |    62.33%     |    77.08%     |    75.83%    |
| ProtT5-XL-UniRef50       |      70.47%      |    59.71%     |    76.91%     |    74.81%    |
| ESM2-15B                 |      71.17%      |    61.81%     |    77.67%     |    75.88%    |
| ESM2-3B                  |      71.69%      |    61.52%     |    77.62%     |    75.95%    |
| ESM2-650M                |      70.50%      |    62.10%     |    77.68%     |    75.89%    |
| ESM-1b                   |      66.02%      |    60.34%     |    73.82%     |    71.55%    |
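
Q8 uses the eight DSSP states, which are often collapsed to three for Q3. The mapping below is one widely used convention (G, H, I to helix; B, E to strand; everything else to coil). Other reductions exist, and this README does not state which one the datasets use, so treat the table `DSSP8_TO_3` and the helper `reduce_q8_to_q3` as assumptions for illustration.

```python
# One common DSSP 8-state -> 3-state reduction (an assumption, see above).
DSSP8_TO_3 = {
    "H": "H", "G": "H", "I": "H",  # helices
    "E": "E", "B": "E",            # strands and bridges
    "T": "C", "S": "C", "C": "C",  # turns, bends, coil
}


def reduce_q8_to_q3(labels):
    """Map a string of 8-state labels to 3-state labels; unknown states become coil."""
    return "".join(DSSP8_TO_3.get(state, "C") for state in labels)


print(reduce_q8_to_q3("CCHHHHGGGTTEEEEBSC"))  # CCHHHHHHHCCEEEEECC
```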
<a name="CP"></a>
- <b> Contact Prediction Long Precision Using Embeddings:</b><br/>
| Model                    | ProteinNet (L/1) | ProteinNet (L/5) | CASP14 (L/1)  | CASP14 (L/5) |
|--------------------------|:----------------:|:----------------:|:-------------:|:------------:|
| Ankh 2 Large             |   In Progress    |   In Progress    |  In Progress  | In Progress  |
| Ankh Large               |      48.93%      |      73.49%      |    16.01%     |    29.91%    |
| Ankh Base                |      43.21%      |      66.63%      |    13.50%     |    28.65%    |
| ProtT5-XL-UniRef50       |      44.74%      |      68.95%      |    11.95%     |    24.45%    |
| ESM2-15B                 |      31.62%      |      52.97%      |    14.44%     |    26.61%    |
| ESM2-3B                  |      30.24%
