├── output ├── clinvar │ └── .gitkeep ├── rocklin │ └── .gitkeep ├── mave_val │ └── .gitkeep ├── proteingym │ └── .gitkeep ├── scannet │ ├── metrics │ │ └── .gitkeep │ └── models │ │ ├── optimizer │ │ └── .gitkeep │ │ └── transformer │ │ └── .gitkeep └── train │ ├── metrics │ └── .gitkeep │ └── models │ ├── gvp │ └── .gitkeep │ ├── optimizer │ └── .gitkeep │ └── msa_transformer │ └── .gitkeep ├── data ├── train │ └── cath │ │ ├── msa │ │ └── .gitkeep │ │ ├── getCATH.sh │ │ └── CATH_get_pdbids.py └── test │ ├── clinvar │ ├── msa │ │ └── .gitkeep │ ├── raw │ │ └── .gitkeep │ └── structure │ │ ├── raw │ │ └── .gitkeep │ │ └── cleaned │ │ └── .gitkeep │ ├── mave_val │ ├── msa │ │ └── .gitkeep │ └── structure │ │ └── cleaned │ │ └── .gitkeep │ ├── proteingym │ ├── exp │ │ └── .gitkeep │ ├── msa │ │ └── .gitkeep │ ├── raw │ │ └── .gitkeep │ ├── structure │ │ ├── raw │ │ │ └── .gitkeep │ │ └── cleaned │ │ │ └── .gitkeep │ ├── acknowledgements.txt │ └── all_models_substitutions_Spearman_Uniprot_level.csv │ ├── rocklin │ ├── exp │ │ └── .gitkeep │ ├── msa │ │ └── .gitkeep │ ├── raw │ │ └── .gitkeep │ └── structure │ │ ├── raw │ │ └── .gitkeep │ │ └── cleaned │ │ └── .gitkeep │ └── scannet │ ├── labels │ └── .gitkeep │ ├── msa │ └── .gitkeep │ ├── raw │ └── .gitkeep │ └── structure │ ├── raw │ └── .gitkeep │ └── cleaned │ └── .gitkeep ├── src ├── models │ ├── msa_transformer │ │ ├── version.py │ │ ├── constants.py │ │ ├── acknowledgements.txt │ │ ├── __init__.py │ │ ├── model.py │ │ ├── axial_attention.py │ │ ├── data.py │ │ ├── modules.py │ │ └── multihead_attention.py │ ├── gvp │ │ ├── acknowledgements.txt │ │ ├── models.py │ │ ├── data.py │ │ └── __init__.py │ └── scannet │ │ └── model.py ├── pdb_parser_scripts │ ├── clean_pdbs.sh │ ├── parse_pdbs.py │ └── clean_pdb.py ├── merge_and_sort_msas.py ├── run_test_proteingym.py ├── run_test_mave.py ├── run_test_rocklin.py ├── run_test_clinvar.py ├── run_pipeline.py └── run_pipeline_scannet.py ├── LICENSE └── README.md 
/output/clinvar/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /output/rocklin/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/train/cath/msa/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /output/mave_val/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /output/proteingym/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/clinvar/msa/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/clinvar/raw/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/mave_val/msa/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/proteingym/exp/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/proteingym/msa/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/data/test/proteingym/raw/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/rocklin/exp/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/rocklin/msa/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/rocklin/raw/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/scannet/labels/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/scannet/msa/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/scannet/raw/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /output/scannet/metrics/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /output/train/metrics/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /output/train/models/gvp/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/output/train/models/optimizer/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/clinvar/structure/raw/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/proteingym/structure/raw/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/rocklin/structure/raw/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/scannet/structure/raw/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /output/scannet/models/optimizer/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /output/scannet/models/transformer/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/clinvar/structure/cleaned/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/mave_val/structure/cleaned/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/proteingym/structure/cleaned/.gitkeep: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/rocklin/structure/cleaned/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/scannet/structure/cleaned/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /output/train/models/msa_transformer/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/train/cath/getCATH.sh: -------------------------------------------------------------------------------- 1 | wget http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/chain_set.jsonl 2 | wget http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/chain_set_splits.json 3 | -------------------------------------------------------------------------------- /src/models/msa_transformer/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | version = "2.0.1" 7 | -------------------------------------------------------------------------------- /src/models/msa_transformer/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # fmt: off 7 | proteinseq_toks = { 8 | 'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-'] 9 | } 10 | # fmt: on 11 | -------------------------------------------------------------------------------- /src/models/msa_transformer/acknowledgements.txt: -------------------------------------------------------------------------------- 1 | Large parts of the code presented in this directory has been taken from: 2 | https://github.com/facebookresearch/esm/tree/main/esm 3 | 4 | @article{rao2021msa, 5 | author = {Rao, Roshan and Liu, Jason and Verkuil, Robert and Meier, Joshua and Canny, John F. and Abbeel, Pieter and Sercu, Tom and Rives, Alexander}, 6 | title={MSA Transformer}, 7 | year={2021}, 8 | doi={10.1101/2021.02.12.430858}, 9 | url={https://www.biorxiv.org/content/10.1101/2021.02.12.430858v1}, 10 | journal={bioRxiv} 11 | } 12 | -------------------------------------------------------------------------------- /data/test/proteingym/acknowledgements.txt: -------------------------------------------------------------------------------- 1 | The ProteinGym data files have been taken from: 2 | https://github.com/OATML-Markslab/ProteinGym 3 | 4 | @article{notin2023a, 5 | title={Proteingym: Large-scale benchmarks for protein design and fitness prediction}, 6 | author={Notin, Pascal and Kollasch, Aaron W and Ritter, Daniel and van Niekerk, Lood and Paul, Steffanie and Spinner, Hansen and Rollins, Nathan and Shaw, Ada and Weitzman, Ruben and Frazer, Jonathan and others}, 7 | journal={bioRxiv}, 8 | pages={2023--12}, 9 | year={2023}, 10 | publisher={Cold Spring Harbor Laboratory} 11 | } 12 | -------------------------------------------------------------------------------- /src/pdb_parser_scripts/clean_pdbs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Settings 4 | counter=1 5 | dir=$(pwd)/pdb_parser_scripts/ 6 | 
pdb_dir=$1 7 | pdbs=$pdb_dir/raw/*.pdb 8 | n_pdbs=$(echo $pdbs | wc -w) 9 | 10 | # Create data directories 11 | mkdir -p $pdb_dir/cleaned 12 | 13 | # Clean pdbs 14 | for pdb in $pdbs; 15 | do 16 | python $dir/clean_pdb.py --pdb_file_in $pdb \ 17 | --out_dir $pdb_dir/cleaned/ \ 18 | #&> /dev/null 19 | 20 | # Check for exit code 0 and skip file if not 0. 21 | if [ $? -eq 0 ] 22 | then 23 | echo "Successfully cleaned $pdb. $counter/$n_pdbs." 24 | else 25 | echo "Error when cleaning $pdb. Skipping.." >&2 26 | fi 27 | counter=$((counter+1)) 28 | done 29 | -------------------------------------------------------------------------------- /data/train/cath/CATH_get_pdbids.py: -------------------------------------------------------------------------------- 1 | import json # Opening JSON file f = open('chain_set_splits.json',) data = json.load(f) # Split and process #pdbids_train = [x[:-2].upper() for x in data["train"]] pdbids_train = [x[:-2].upper() for x in data["train"]] pdbids_val = [x[:-2].upper() for x in data["validation"]] pdbids_test = [x[:-2].upper() for x in data["test"]] # Write # tf = open("CATH_pdbids_train.txt", "w") # for element in pdbids_train: # tf. write(element + "\n") # tf.close() tf = open("CATH_pdbids_train.txt", "w") for element in pdbids_train: tf. write(element + "\n") tf.close() tf = open("CATH_pdbids_val.txt", "w") for element in pdbids_val: tf. write(element + "\n") tf.close() tf = open("CATH_pdbids_test.txt", "w") for element in pdbids_test: tf. 
write(element + "\n") tf.close() -------------------------------------------------------------------------------- /src/models/gvp/acknowledgements.txt: -------------------------------------------------------------------------------- 1 | Large parts of the code presented in this directory has been taken from: 2 | https://github.com/drorlab/gvp-pytorch/tree/main 3 | 4 | Citations: 5 | @inproceedings{ 6 | jing2021learning, 7 | title={Learning from Protein Structure with Geometric Vector Perceptrons}, 8 | author={Bowen Jing and Stephan Eismann and Patricia Suriana and Raphael John Lamarre Townshend and Ron Dror}, 9 | booktitle={International Conference on Learning Representations}, 10 | year={2021}, 11 | url={https://openreview.net/forum?id=1YLJDvSx6J4} 12 | } 13 | 14 | @article{jing2021equivariant, 15 | title={Equivariant Graph Neural Networks for 3D Macromolecular Structure}, 16 | author={Jing, Bowen and Eismann, Stephan and Soni, Pratham N and Dror, Ron O}, 17 | journal={arXiv preprint arXiv:2106.03843}, 18 | year={2021} 19 | } 20 | 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Linderstrøm-Lang Centre for Protein Science, University of Copenhagen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/models/scannet/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from tempfile import TemporaryDirectory 4 | from typing import Tuple 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 9 | from torch.utils.data import dataset 10 | 11 | class TransformerModel(nn.Module): 12 | def __init__(self, ntoken: int, nhead: int, d_hid: int, 13 | nlayers: int, dropout: float = 0.5): 14 | super().__init__() 15 | self.model_type = 'Transformer' 16 | self.pos_encoder = PositionalEncoding(d_hid, dropout) 17 | encoder_layers = TransformerEncoderLayer(d_hid, nhead, d_hid, dropout) 18 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 19 | self.linear = nn.Linear(d_hid, 1) 20 | 21 | self.init_weights() 22 | 23 | def init_weights(self) -> None: 24 | initrange = 0.1 25 | self.linear.bias.data.zero_() 26 | self.linear.weight.data.uniform_(-initrange, initrange) 27 | 28 | def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor: 29 | src = self.pos_encoder(src) 30 | output = self.transformer_encoder(src, src_mask) 31 | output = self.linear(output) 32 | return output 33 | 34 | class PositionalEncoding(nn.Module): 35 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1024): 36 | super().__init__() 37 | 
self.dropout = nn.Dropout(p=dropout) 38 | 39 | position = torch.arange(max_len).unsqueeze(1) 40 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 41 | pe = torch.zeros(max_len, 1, d_model) 42 | pe[:, 0, 0::2] = torch.sin(position * div_term) 43 | pe[:, 0, 1::2] = torch.cos(position * div_term) 44 | self.register_buffer('pe', pe) 45 | 46 | def forward(self, x: Tensor) -> Tensor: 47 | x = x + self.pe[:x.size(0)] 48 | return self.dropout(x) 49 | 50 | -------------------------------------------------------------------------------- /src/merge_and_sort_msas.py: -------------------------------------------------------------------------------- 1 | import string 2 | from Bio import SeqIO 3 | from typing import List, Tuple 4 | import numpy as np 5 | import glob 6 | import sys 7 | import subprocess 8 | 9 | # Initialize 10 | deletekeys = dict.fromkeys(string.ascii_lowercase) 11 | deletekeys["."] = None 12 | deletekeys["*"] = None 13 | translation = str.maketrans(deletekeys) 14 | 15 | 16 | def remove_insertions(sequence: str): 17 | """Removes any insertions into the sequence. 
Needed to load aligned sequences in an MSA.""" 18 | return sequence.translate(translation) 19 | 20 | 21 | def read_msa(filename: str) -> List[Tuple[str, str]]: 22 | """Reads the first nseq sequences from an MSA file, automatically removes insertions.""" 23 | return [ 24 | (record.description, remove_insertions(str(record.seq))) 25 | for record in SeqIO.parse(filename, "fasta") 26 | ] 27 | 28 | 29 | def hamming_distance(string1, string2): 30 | return sum(c1 != c2 for c1, c2 in zip(string1, string2)) 31 | 32 | 33 | # Initialize 34 | msa_dir = sys.argv[1] 35 | subprocess.run(["mkdir", "-p", f"{msa_dir}_tmp"]) 36 | 37 | # Load MSA files 38 | msa_files = sorted(glob.glob(f"{msa_dir}/*.a3m")) 39 | 40 | # Loop 41 | for i, _file in enumerate(msa_files): 42 | print(f"Processing MSA: {i+1}/{len(msa_files)}") 43 | 44 | msa = read_msa(_file) 45 | 46 | seqs = [x for x in msa] 47 | 48 | query = seqs[0] 49 | seqs = seqs[1:] 50 | ham_dists = np.zeros(len(seqs)) 51 | 52 | for j, seq in enumerate(seqs): 53 | assert len(query) == len(seq) 54 | ham_dists[j] = hamming_distance(query[1], seq[1]) 55 | 56 | # Rank indices 57 | rank_indices = np.argsort(ham_dists) 58 | 59 | # Remove query duplicates 60 | if 0 in ham_dists: 61 | query_idx = np.argwhere(ham_dists == 0)[0] 62 | rank_indices = np.delete( 63 | rank_indices, np.argwhere(np.isin(rank_indices, query_idx)) 64 | ) 65 | 66 | # Construct new sorted MSA 67 | seqs_new = [] 68 | for idx in rank_indices: 69 | seqs_new.append(seqs[idx]) 70 | 71 | # Write to new file 72 | outfile = open(f"{msa_dir}_tmp/{query[0]}.a3m", "w") 73 | outfile.write(f">{query[0]}\n") 74 | outfile.write(f"{query[1]}\n") 75 | 76 | for seq in seqs_new: 77 | outfile.write(f">{seq[0]}\n") 78 | outfile.write(f"{seq[1]}\n") 79 | outfile.close() 80 | 81 | # Delete tmp directory 82 | subprocess.run(["rm", "-r", f"{msa_dir}"]) 83 | subprocess.run(["mv", f"{msa_dir}_tmp", msa_dir]) 84 | -------------------------------------------------------------------------------- 
/src/pdb_parser_scripts/parse_pdbs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import enum 3 | import os 4 | import sys 5 | import json 6 | import glob 7 | import Bio 8 | import Bio.PDB 9 | from Bio import SeqIO 10 | from Bio.SeqRecord import SeqRecord 11 | import pickle 12 | 13 | def parse(pdb_dir): 14 | # Load PDBS 15 | pdb_filenames = sorted(glob.glob(f"{pdb_dir}/cleaned/*.pdb")) 16 | 17 | # Create fasta file 18 | fh = open(f"{pdb_dir}/seqs.fasta","w") 19 | 20 | # Initialize list of pdb dicts 21 | pdb_dict_list = [] 22 | 23 | # Loop over proteins 24 | for pdb_filename in pdb_filenames: 25 | 26 | # Parse structure with Biopython 27 | pdb_parser = Bio.PDB.PDBParser() 28 | pdb_id = os.path.basename(pdb_filename).split("/")[-1][:-4] 29 | structure = pdb_parser.get_structure(pdb_id, pdb_filename) 30 | first_model = structure.get_list()[0] 31 | first_model.child_list = sorted(first_model.child_list) # Sort chains alphabetically 32 | 33 | # Iterate over chain,residue,atoms and extract features 34 | for chain in first_model: # Loop over chains even though there is only 1 35 | 36 | # Initialize 37 | chain_id = chain.id 38 | seq = [] 39 | coords = [] 40 | pdb_dict = {} 41 | 42 | for j, residue in enumerate(chain): 43 | atom_names = [] 44 | backbone_coords = [] 45 | 46 | for atom in residue: 47 | # Extract atom features 48 | if atom.name in ["N","CA","C","O"]: 49 | atom_names.append(atom.name) 50 | #backbone_coords.append(list(atom.coord)) 51 | backbone_coords.append([str(x) for x in atom.coord]) 52 | 53 | # Check that all backbone atoms are present 54 | if atom_names == ["N","CA","C","O"] and len(backbone_coords)==4 and residue._id[0].startswith("H_") == False: # HETATM check 55 | 56 | # Add coordinates 57 | coords.append(backbone_coords) 58 | 59 | # Add residue to sequence 60 | seq.append(Bio.PDB.Polypeptide.three_to_one(residue.resname)) 61 | 62 | # Save coords+seq to dict 63 | pdb_dict["name"] = pdb_id 64 
| pdb_dict["coords"] = coords 65 | pdb_dict["seq"] = "".join(seq) 66 | pdb_dict_list.append(pdb_dict) 67 | 68 | # Output seq to fasta 69 | fh.write(f">{pdb_dict['name']}\n") 70 | fh.write("".join(seq)) 71 | fh.write("\n") 72 | fh.close() 73 | 74 | # Save total coord dict 75 | with open(f'{pdb_dir}/coords.json', 'w') as fp: 76 | json.dump(pdb_dict_list, fp) 77 | fp.close() 78 | 79 | if __name__ == '__main__': 80 | parse_pdbs() 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A joint embedding of protein sequence and structure enables robust variant effect predictions 2 | 3 | ## Introduction 4 | This repository contains scripts and data to repeat the analyses in Blaabjerg et al.:
5 | [*"A joint embedding of protein sequence and structure enables robust variant effect predictions"*](https://www.biorxiv.org/content/10.1101/2023.12.14.571755v1). 6 | 7 | ## Execution 8 | Execute the pipeline using `src/run_pipeline.py`.
9 | This main script will call other scripts in the `src` directory to train, validate and test the SSEmb model as described in the paper. 10 | 11 | ## Requirements 12 | The code has been developed and tested in a Unix environment using the following packages:
13 | * `python==3.7.16` 14 | * `pytorch==1.13.1` 15 | * `pyg==2.2.0` 16 | * `pytorch-scatter==2.1.0` 17 | * `pytorch-cluster==1.6.0` 18 | * `fair-esm==2.0.0` 19 | * `numpy==1.21.6` 20 | * `pandas==1.3.5` 21 | * `biopython==1.79` 22 | * `openmm==7.6.0` 23 | * `pdbfixer==1.8.1` 24 | * `scipy==1.7.3` 25 | * `scikit-learn==1.0.2` 26 | * `tqdm==4.64.1` 27 | * `pytz==2022.7` 28 | * `matplotlib==3.2.2` 29 | * `mpl-scatter-density==0.7` 30 | 31 | ## Downloads 32 | Data related to the paper can be downloaded here: [https://zenodo.org/records/12798019](https://zenodo.org/records/12798019).
33 | The `data` directory contains the following subdirectories:
34 | * `train` 35 | * `model_weights`: Final weights for the SSEmb-MSATransformer and SSEmb-GVPGNN modules. 36 | * `optimizer_weights`: Parameters for the optimizer at time of early-stopping. 37 | * `msa`: MSAs for the proteins in the training set. 38 | * `mave_val`: 39 | * `msa`: MSAs for the proteins in the MAVE validation set. 40 | * `rocklin`: 41 | * `msa`: MSAs for the proteins in the mega-scale stability change test set. 42 | * `proteingym`: 43 | * `structure`: AlphaFold-2 generated structures used for the ProteinGym test set. 44 | * `msa`: MSAs for the proteins in the ProteinGym test set. 45 | * `scannet`: 46 | * `model_weights`: Final weights for the SSEmb downstream model trained on the ScanNet data set. 47 | * `optimizer_weights`: Parameters for the optimizer at time of early-stopping. 48 | * `msa`: MSAs for the proteins in the ScanNet data set. 49 | * `clinvar`: 50 | * `structure`: AlphaFold-2 generated structures used for the ClinVar test set. 51 | * `msa`: MSAs for the proteins in the ClinVar test set. 52 | 53 | A copy of this repository can be found on Zenodo here: [https://zenodo.org/doi/10.5281/zenodo.13765792](https://zenodo.org/doi/10.5281/zenodo.13765792).
54 | 55 | ## SSEmbLab webserver 56 | We have created an online Colab-based webserver for making SSEmb predictions called SSEmbLab. The webserver can be accessed [here](https://colab.research.google.com/github/KULL-Centre/_2023_Blaabjerg_SSEmb/blob/main/SSEmbLab.ipynb). 57 | 58 | ## License 59 | Source code and model weights are licensed under the MIT License. 60 | 61 | ## Acknowledgements 62 | We thank Milot Mirdita and the rest of the ColabFold Search team for help in setting up the Colab SSEmb webserver.
63 |
64 | Code for the original MSA Transformer was developed by the ESM team at Meta Research:
65 | [https://github.com/facebookresearch/esm](https://github.com/facebookresearch/esm). 66 |

67 | Code for the original GVP-GNN was developed by Jing et al.:
68 | [https://github.com/drorlab/gvp-pytorch](https://github.com/drorlab/gvp-pytorch). 69 | 70 | ## Citation 71 | Please cite: 72 | 73 | *Lasse M. Blaabjerg, Nicolas Jonsson, Wouter Boomsma, Amelie Stein, Kresten Lindorff-Larsen (2023). A joint embedding of protein sequence and structure enables robust variant effect predictions. bioRxiv, 2023.12.* 74 | 75 | ``` 76 | @article {Blaabjerg2023.12.14.571755, 77 | author = {Lasse M. Blaabjerg and Nicolas Jonsson and Wouter Boomsma and Amelie Stein and Kresten Lindorff-Larsen}, 78 | title = {A joint embedding of protein sequence and structure enables robust variant effect predictions}, 79 | elocation-id = {2023.12.14.571755}, 80 | year = {2023}, 81 | doi = {10.1101/2023.12.14.571755}, 82 | URL = {https://www.biorxiv.org/content/early/2023/12/16/2023.12.14.571755}, 83 | eprint = {https://www.biorxiv.org/content/early/2023/12/16/2023.12.14.571755.full.pdf}, 84 | journal = {bioRxiv} 85 | } 86 | ``` 87 | -------------------------------------------------------------------------------- /src/models/msa_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | import torch, functools 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import MessagePassing 5 | from torch_scatter import scatter_add 6 | 7 | def tuple_sum(*args): 8 | ''' 9 | Sums any number of tuples (s, V) elementwise. 10 | ''' 11 | return tuple(map(sum, zip(*args))) 12 | 13 | def tuple_cat(*args, dim=-1): 14 | ''' 15 | Concatenates any number of tuples (s, V) elementwise. 16 | 17 | :param dim: dimension along which to concatenate when viewed 18 | as the `dim` index for the scalar-channel tensors. 19 | This means that `dim=-1` will be applied as 20 | `dim=-2` for the vector-channel tensors. 
21 | ''' 22 | dim %= len(args[0][0].shape) 23 | s_args, v_args = list(zip(*args)) 24 | return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim) 25 | 26 | def tuple_index(x, idx): 27 | ''' 28 | Indexes into a tuple (s, V) along the first dimension. 29 | 30 | :param idx: any object which can be used to index into a `torch.Tensor` 31 | ''' 32 | return x[0][idx], x[1][idx] 33 | 34 | def randn(n, dims, device="cpu"): 35 | ''' 36 | Returns random tuples (s, V) drawn elementwise from a normal distribution. 37 | 38 | :param n: number of data points 39 | :param dims: tuple of dimensions (n_scalar, n_vector) 40 | 41 | :return: (s, V) with s.shape = (n, n_scalar) and 42 | V.shape = (n, n_vector, 3) 43 | ''' 44 | return torch.randn(n, dims[0], device=device), \ 45 | torch.randn(n, dims[1], 3, device=device) 46 | 47 | def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True): 48 | ''' 49 | L2 norm of tensor clamped above a minimum value `eps`. 50 | 51 | :param sqrt: if `False`, returns the square of the L2 norm 52 | ''' 53 | out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps) 54 | return torch.sqrt(out) if sqrt else out 55 | 56 | def _split(x, nv): 57 | ''' 58 | Splits a merged representation of (s, V) back into a tuple. 59 | Should be used only with `_merge(s, V)` and only if the tuple 60 | representation cannot be used. 61 | 62 | :param x: the `torch.Tensor` returned from `_merge` 63 | :param nv: the number of vector channels in the input to `_merge` 64 | ''' 65 | v = torch.reshape(x[..., -3*nv:], x.shape[:-1] + (nv, 3)) 66 | s = x[..., :-3*nv] 67 | return s, v 68 | 69 | def _merge(s, v): 70 | ''' 71 | Merges a tuple (s, V) into a single `torch.Tensor`, where the 72 | vector channels are flattened and appended to the scalar channels. 73 | Should be used only if the tuple representation cannot be used. 74 | Use `_split(x, nv)` to reverse. 
75 | ''' 76 | v = torch.reshape(v, v.shape[:-2] + (3*v.shape[-2],)) 77 | return torch.cat([s, v], -1) 78 | 79 | class Dropout(nn.Module): 80 | ''' 81 | Combined dropout for tuples (s, V). 82 | Takes tuples (s, V) as input and as output. 83 | ''' 84 | def __init__(self, drop_rate): 85 | super(Dropout, self).__init__() 86 | self.sdropout = nn.Dropout(drop_rate) 87 | self.vdropout = _VDropout(drop_rate) 88 | 89 | def forward(self, x): 90 | ''' 91 | :param x: tuple (s, V) of `torch.Tensor`, 92 | or single `torch.Tensor` 93 | (will be assumed to be scalar channels) 94 | ''' 95 | if type(x) is torch.Tensor: 96 | return self.sdropout(x) 97 | s, v = x 98 | return self.sdropout(s), self.vdropout(v) 99 | 100 | class LayerNorm(nn.Module): 101 | ''' 102 | Combined LayerNorm for tuples (s, V). 103 | Takes tuples (s, V) as input and as output. 104 | ''' 105 | def __init__(self, dims): 106 | super(LayerNorm, self).__init__() 107 | self.s, self.v = dims 108 | self.scalar_norm = nn.LayerNorm(self.s) 109 | 110 | def forward(self, x): 111 | ''' 112 | :param x: tuple (s, V) of `torch.Tensor`, 113 | or single `torch.Tensor` 114 | (will be assumed to be scalar channels) 115 | ''' 116 | if not self.v: 117 | return self.scalar_norm(x) 118 | s, v = x 119 | vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False) 120 | vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True)) 121 | return self.scalar_norm(s), v / vn 122 | -------------------------------------------------------------------------------- /src/models/gvp/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from . 
import GVP, GVPConvLayer, LayerNorm, tuple_index 6 | from torch.distributions import Categorical 7 | from torch_scatter import scatter_mean 8 | import torch.utils.checkpoint as checkpoint 9 | import sys 10 | from torch_geometric.nn import aggr 11 | 12 | class SSEmbGNN(torch.nn.Module): 13 | ''' 14 | GVP-GNN for structure-conditioned autoregressive 15 | protein design as described in manuscript. 16 | 17 | Takes in protein structure graphs of type `torch_geometric.data.Data` 18 | or `torch_geometric.data.Batch` and returns a categorical distribution 19 | over 20 amino acids at each position in a `torch.Tensor` of 20 | shape [n_nodes, 20]. 21 | 22 | Should be used with `gvp.data.ProteinGraphDataset`, or with generators 23 | of `torch_geometric.data.Batch` objects with the same attributes. 24 | 25 | The standard forward pass requires sequence information as input 26 | and should be used for training or evaluating likelihood. 27 | For sampling or design, use `self.sample`. 28 | 29 | :param node_in_dim: node dimensions in input graph, should be 30 | (6, 3) if using original features 31 | :param node_h_dim: node dimensions to use in GVP-GNN layers 32 | :param edge_in_dim: edge dimensions in input graph, should be 33 | (32, 1) if using original features 34 | :param edge_h_dim: edge dimensions to embed to before use 35 | in GVP-GNN layers 36 | :param num_layers: number of GVP-GNN layers in each of the encoder 37 | and decoder modules 38 | :param drop_rate: rate to use in all dropout layers 39 | ''' 40 | def __init__(self, node_in_dim, node_h_dim, 41 | edge_in_dim, edge_h_dim, 42 | num_layers=4, drop_rate=0.0, vector_gate=True): 43 | 44 | super(SSEmbGNN, self).__init__() 45 | 46 | # Get correct dimensions 47 | self.W_v = nn.Sequential( 48 | GVP(node_in_dim, node_h_dim, activations=(None, None), vector_gate=vector_gate), 49 | LayerNorm(node_h_dim) 50 | ) 51 | self.W_e = nn.Sequential( 52 | GVP(edge_in_dim, edge_h_dim, activations=(None, None), vector_gate=vector_gate), 53 | 
LayerNorm(edge_h_dim) 54 | ) 55 | 56 | # Encode 57 | self.encoder_layers = nn.ModuleList( 58 | GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate, vector_gate=vector_gate) 59 | for _ in range(num_layers)) 60 | 61 | self.W_S = nn.Embedding(21, 21) 62 | 63 | self.W_M = nn.Sequential( 64 | nn.Linear(768, node_h_dim[0]), 65 | ) 66 | 67 | self.W_decoder_in = nn.Sequential( 68 | nn.Linear(node_h_dim[0]*2, node_h_dim[0]*2), 69 | nn.LayerNorm(node_h_dim[0]*2), 70 | nn.ReLU(), 71 | nn.Dropout(drop_rate), 72 | nn.Linear(node_h_dim[0]*2, node_h_dim[0]), 73 | nn.LayerNorm(node_h_dim[0]), 74 | ) 75 | 76 | # Decode 77 | node_h_dim = (node_h_dim[0], node_h_dim[1]) 78 | edge_h_dim = (edge_h_dim[0] + 21, edge_h_dim[1]) 79 | 80 | self.decoder_layers = nn.ModuleList( 81 | GVPConvLayer(node_h_dim, edge_h_dim, 82 | drop_rate=drop_rate, vector_gate=vector_gate) 83 | for _ in range(num_layers)) 84 | 85 | # Out 86 | self.W_out = GVP(node_h_dim, (20, 0), activations=(None, None), vector_gate=vector_gate) 87 | 88 | def forward(self, h_V, edge_index, h_E, msa_emb, seq, get_emb=False): 89 | ''' 90 | Forward pass to be used at train-time, or evaluating likelihood. 
91 | 92 | :param h_V: tuple (s, V) of node embeddings 93 | :param edge_index: `torch.Tensor` of shape [2, num_edges] 94 | :param h_E: tuple (s, V) of edge embeddings 95 | :param seq: int `torch.Tensor` of shape [num_nodes] 96 | ''' 97 | # Run through GVP to get correct hidden dimensions 98 | h_V = self.W_v(h_V) 99 | h_E = self.W_e(h_E) 100 | 101 | # Message passing 102 | # Encoding 103 | for layer in self.encoder_layers: 104 | h_V = layer(h_V, edge_index, h_E) 105 | 106 | # Add sequence info 107 | h_S = self.W_S(seq) 108 | h_S = h_S[edge_index[0]] 109 | h_E = (torch.cat([h_E[0], h_S], dim=-1), h_E[1]) 110 | 111 | h_M = self.W_M(msa_emb) 112 | h_V = (self.W_decoder_in(torch.cat([h_V[0], h_M], dim=-1)), h_V[1]) 113 | 114 | # Decoding 115 | for layer in self.decoder_layers: 116 | h_V = layer(h_V, edge_index, h_E) 117 | 118 | # Out 119 | if get_emb == True: 120 | return h_V[0] 121 | else: 122 | logits = self.W_out(h_V) 123 | return logits 124 | -------------------------------------------------------------------------------- /src/run_test_proteingym.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import subprocess 3 | import glob 4 | import re 5 | import torch 6 | import models.gvp.data, models.gvp.models 7 | import json 8 | import torch_geometric 9 | import esm 10 | import pandas as pd 11 | import random 12 | import torch.multiprocessing 13 | 14 | torch.multiprocessing.set_sharing_strategy("file_system") 15 | from models.msa_transformer.model import MSATransformer 16 | from models.gvp.models import SSEmbGNN 17 | from helpers import ( 18 | read_msa, 19 | loop_pred, 20 | ) 21 | from visualization import plot_proteingym 22 | import pdb_parser_scripts.parse_pdbs as parse_pdbs 23 | import torch.utils.data 24 | from collections import OrderedDict 25 | from ast import literal_eval 26 | import subprocess 27 | import pickle 28 | 29 | def test(run_name, epoch, msa_row_attn_mask=True, device=None): 30 | # Load data and dict 
of variant positions 31 | with open(f"../data/test/proteingym/data_with_msas.pkl", "rb") as fp: 32 | data = pickle.load(fp) 33 | 34 | with open(f"../data/test/proteingym/variant_pos_dict.pkl", "rb") as fp: 35 | variant_pos_dict = pickle.load(fp) 36 | 37 | # Load DMS data 38 | df_dms = pd.read_csv("../data/test/proteingym/exp/dms.csv") 39 | 40 | # Convert to graph data sets 41 | testset = models.gvp.data.ProteinGraphData(data) 42 | letter_to_num = testset.letter_to_num 43 | 44 | # Load MSA Transformer 45 | _, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S() 46 | model_msa = MSATransformer(msa_alphabet) 47 | model_msa = model_msa.to(device) 48 | msa_batch_converter = msa_alphabet.get_batch_converter() 49 | 50 | model_dict = OrderedDict() 51 | state_dict_msa = torch.load( 52 | f"../output/train/models/msa_transformer/{run_name}_msa_transformer_{epoch}.pt" 53 | ) 54 | pattern = re.compile("module.") 55 | for k, v in state_dict_msa.items(): 56 | if re.search("module", k): 57 | model_dict[re.sub(pattern, "", k)] = v 58 | else: 59 | model_dict = state_dict_msa 60 | model_msa.load_state_dict(model_dict) 61 | 62 | # Load GVP 63 | node_dim = (256, 64) 64 | edge_dim = (32, 1) 65 | model_gvp = SSEmbGNN((6, 3), node_dim, (32, 1), edge_dim) 66 | model_gvp = model_gvp.to(device) 67 | 68 | model_dict = OrderedDict() 69 | state_dict_gvp = torch.load(f"../output/train/models/gvp/{run_name}_gvp_{epoch}.pt") 70 | pattern = re.compile("module.") 71 | for k, v in state_dict_gvp.items(): 72 | if k.startswith("module"): 73 | model_dict[k[7:]] = v 74 | else: 75 | model_dict = state_dict_gvp 76 | model_gvp.load_state_dict(model_dict) 77 | 78 | # Initialize data loader 79 | test_loader = torch_geometric.loader.DataLoader( 80 | testset, batch_size=1, shuffle=False 81 | ) 82 | 83 | # Call test 84 | model_msa.eval() 85 | model_gvp.eval() 86 | 87 | with torch.no_grad(): 88 | pred_list, acc_mean = loop_pred( 89 | model_msa, 90 | model_gvp, 91 | msa_batch_converter, 92 | test_loader, 93 | 
variant_pos_dict, 94 | data, 95 | letter_to_num, 96 | msa_row_attn_mask=msa_row_attn_mask, 97 | device=device, 98 | ) 99 | 100 | # Transform results into df 101 | df_ml = pd.DataFrame(pred_list, columns=["dms_id", "variant_pos", "score_ml_pos"]) 102 | 103 | # Save 104 | df_ml.to_csv(f"../output/proteingym/df_ml_{run_name}_{epoch}.csv", index=False) 105 | 106 | # Load 107 | df_ml = pd.read_csv( 108 | f"../output/proteingym/df_ml_{run_name}_{epoch}.csv", 109 | converters=dict(score_ml_pos=literal_eval), 110 | ) 111 | 112 | # Compute score_ml from nlls 113 | dms_variant_list = df_dms.values.tolist() 114 | for i, row in enumerate(dms_variant_list): 115 | dms_id = row[0] 116 | print( 117 | f"Computing score for assay {dms_id} variant: {i+1}/{len(dms_variant_list)}" 118 | ) 119 | variant_set = row[1].split(":") 120 | score_ml = 0.0 121 | 122 | for variant in variant_set: 123 | wt = letter_to_num[variant[0]] 124 | pos = int(re.findall(r"\d+", variant)[0]) 125 | mt = letter_to_num[variant[-1]] 126 | score_ml_pos = df_ml[ 127 | (df_ml["dms_id"] == dms_id) & (df_ml["variant_pos"] == pos) 128 | ]["score_ml_pos"].values[0] 129 | score_ml += float(score_ml_pos[mt]) 130 | dms_variant_list[i].append(score_ml) 131 | df_total = pd.DataFrame( 132 | dms_variant_list, columns=["dms_id", "variant_set", "score_dms", "score_ml"] 133 | ) 134 | 135 | # Save 136 | df_total.to_csv(f"../output/proteingym/df_total_{run_name}_{epoch}.csv", index=False) 137 | 138 | # Load 139 | df_total = pd.read_csv(f"../output/proteingym/df_total_{run_name}_{epoch}.csv") 140 | 141 | # Compute correlations 142 | plot_proteingym(df_total, run_name, epoch) 143 | -------------------------------------------------------------------------------- /src/run_test_mave.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import subprocess 3 | import glob 4 | import re 5 | import torch 6 | import models.gvp.data, models.gvp.models 7 | import json 8 | import torch_geometric 
9 | import esm 10 | import pandas as pd 11 | import random 12 | import torch.multiprocessing 13 | import pickle 14 | 15 | torch.multiprocessing.set_sharing_strategy("file_system") 16 | from models.msa_transformer.model import MSATransformer 17 | from models.gvp.models import SSEmbGNN 18 | from statistics import mean 19 | from helpers import ( 20 | read_msa, 21 | mave_val_pdb_to_prot, 22 | loop_pred, 23 | save_df_to_prism, 24 | get_prism_corr, 25 | get_prism_corr_all, 26 | ) 27 | import pdb_parser_scripts.parse_pdbs as parse_pdbs 28 | import torch.utils.data 29 | from collections import OrderedDict 30 | from ast import literal_eval 31 | import subprocess 32 | import shutil 33 | 34 | def test(run_name, epoch, msa_row_attn_mask=True, get_only_ssemb_metrics=True, device=None): 35 | # Load data and dict of variant positions 36 | with open(f"../data/test/mave_val/data_with_msas.pkl", "rb") as fp: 37 | data = pickle.load(fp) 38 | 39 | with open(f"../data/test/mave_val/variant_pos_dict.pkl", "rb") as fp: 40 | variant_pos_dict = pickle.load(fp) 41 | 42 | # Convert to graph data sets 43 | testset = models.gvp.data.ProteinGraphData(data) 44 | letter_to_num = testset.letter_to_num 45 | 46 | # Load MSA Transformer 47 | _, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S() 48 | model_msa = MSATransformer(msa_alphabet) 49 | model_msa = model_msa.to(device) 50 | msa_batch_converter = msa_alphabet.get_batch_converter() 51 | 52 | model_dict = OrderedDict() 53 | state_dict_msa = torch.load( 54 | f"../output/train/models/msa_transformer/{run_name}_msa_transformer_{epoch}.pt" 55 | ) 56 | pattern = re.compile("module.") 57 | for k, v in state_dict_msa.items(): 58 | if re.search("module", k): 59 | model_dict[re.sub(pattern, "", k)] = v 60 | else: 61 | model_dict = state_dict_msa 62 | model_msa.load_state_dict(model_dict) 63 | 64 | # Load GVP 65 | node_dim = (256, 64) 66 | edge_dim = (32, 1) 67 | model_gvp = SSEmbGNN((6, 3), node_dim, (32, 1), edge_dim) 68 | model_gvp = 
model_gvp.to(device) 69 | 70 | model_dict = OrderedDict() 71 | state_dict_gvp = torch.load(f"../output/train/models/gvp/{run_name}_gvp_{epoch}.pt") 72 | pattern = re.compile("module.") 73 | for k, v in state_dict_gvp.items(): 74 | if k.startswith("module"): 75 | model_dict[k[7:]] = v 76 | else: 77 | model_dict = state_dict_gvp 78 | model_gvp.load_state_dict(model_dict) 79 | 80 | # Initialize data loader 81 | test_loader = torch_geometric.loader.DataLoader( 82 | testset, batch_size=1, shuffle=False 83 | ) 84 | 85 | # Call test 86 | model_msa.eval() 87 | model_gvp.eval() 88 | 89 | with torch.no_grad(): 90 | pred_list, acc_mean = loop_pred( 91 | model_msa, 92 | model_gvp, 93 | msa_batch_converter, 94 | test_loader, 95 | variant_pos_dict, 96 | data, 97 | letter_to_num, 98 | msa_row_attn_mask=msa_row_attn_mask, 99 | device=device, 100 | ) 101 | 102 | # Transform results into df 103 | df_ml = pd.DataFrame(pred_list, columns=["dms_id", "variant_pos", "score_ml_pos"]) 104 | 105 | # Save 106 | df_ml.to_csv(f"../output/mave_val/df_ml_{run_name}.csv", index=False) 107 | 108 | # Load 109 | df_ml = pd.read_csv( 110 | f"../output/mave_val/df_ml_{run_name}.csv", 111 | converters=dict(score_ml_pos=literal_eval), 112 | ) 113 | 114 | # Compute score_ml from nlls 115 | pred_list_scores = [] 116 | mt_list = [x for x in sorted(letter_to_num, key=letter_to_num.get)][:-1] 117 | 118 | for entry in data: 119 | dms_id = entry["name"] 120 | df_dms_id = df_ml[df_ml["dms_id"] == dms_id] 121 | 122 | wt = [[wt] * 20 for wt in entry["seq"]] 123 | pos = [[pos] * 20 for pos in list(df_dms_id["variant_pos"])] 124 | pos = [item for sublist in pos for item in sublist] 125 | mt = mt_list * len(wt) 126 | wt = [item for sublist in wt for item in sublist] 127 | score_ml = [ 128 | item for sublist in list(df_dms_id["score_ml_pos"]) for item in sublist 129 | ] 130 | 131 | rows = [ 132 | [dms_id, wt[i] + str(pos[i]) + mt[i], score_ml[i]] for i in range(len(mt)) 133 | ] 134 | pred_list_scores += rows 135 | 
136 | # Transform results into df 137 | df_ml_scores = pd.DataFrame( 138 | pred_list_scores, columns=["dms_id", "variant", "score_ml"] 139 | ) 140 | 141 | # Save 142 | df_ml_scores.to_csv(f"../output/mave_val/df_ml_scores_{run_name}.csv", index=False) 143 | 144 | # Load 145 | df_ml_scores = pd.read_csv(f"../output/mave_val/df_ml_scores_{run_name}.csv") 146 | 147 | # Save results to PRISM format 148 | for dms_id in df_ml_scores["dms_id"].unique(): 149 | df_dms = df_ml_scores[df_ml_scores["dms_id"] == dms_id] 150 | save_df_to_prism(df_dms, run_name, dms_id) 151 | 152 | # Compute metrics 153 | if get_only_ssemb_metrics == True: 154 | corrs = [] 155 | for dms_id in df_ml_scores["dms_id"].unique(): 156 | corr = get_prism_corr(dms_id, run_name) 157 | corrs.append(corr) 158 | return mean(corrs), acc_mean 159 | else: 160 | corrs_ssemb = [] 161 | corrs_gemme = [] 162 | corrs_rosetta = [] 163 | for dms_id in df_ml_scores["dms_id"].unique(): 164 | corrs = get_prism_corr_all(dms_id, run_name) 165 | corrs_ssemb.append(corrs[0]) 166 | corrs_gemme.append(corrs[1]) 167 | corrs_rosetta.append(corrs[2]) 168 | print(f"SSEmb: Mean MAVE spearman correlation: {mean(corrs_ssemb):.3f}") 169 | print(f"GEMME: Mean MAVE spearman correlation: {mean(corrs_gemme):.3f}") 170 | print(f"Rosetta: Mean MAVE spearman correlation: {mean(corrs_rosetta):.3f}") 171 | -------------------------------------------------------------------------------- /src/models/msa_transformer/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .modules import ( 10 | AxialTransformerLayer, 11 | LearnedPositionalEmbedding, 12 | RobertaLMHead, 13 | ESM1bLayerNorm, 14 | ContactPredictionHead, 15 | ) 16 | 17 | from .axial_attention import RowSelfAttention, ColumnSelfAttention 18 | from .__init__ import LayerNorm, tuple_index 19 | 20 | class MSATransformer(nn.Module): 21 | def __init__(self): 22 | super().__init__() 23 | self.num_layers = 12 24 | self.embed_dim = 768 25 | self.logit_bias = True 26 | self.ffn_embed_dim = 3072 27 | self.attention_heads = 12 28 | self.dropout = 0.1 29 | self.attention_dropout = 0.1 30 | self.activation_dropout = 0.1 31 | self.max_tokens_per_msa = 2 ** 22 32 | self.max_positions = 1024 33 | self.embed_positions_msa = False 34 | self.alphabet_size = 33 35 | self.padding_idx = 1 36 | self.mask_idx = 32 37 | self.cls_idx = 0 38 | self.eos_idx = 2 39 | self.prepend_bos = True 40 | self.append_eos = False 41 | 42 | self.embed_tokens = nn.Embedding( 43 | self.alphabet_size, self.embed_dim, padding_idx=self.padding_idx 44 | ) 45 | 46 | self.msa_position_embedding = nn.Parameter( 47 | 0.01 * torch.randn(1, 1024, 1, self.embed_dim), 48 | requires_grad=True, 49 | ) 50 | 51 | self.dropout_module = nn.Dropout(self.dropout) 52 | self.layers = nn.ModuleList( 53 | [ 54 | AxialTransformerLayer( 55 | self.embed_dim, 56 | self.ffn_embed_dim, 57 | self.attention_heads, 58 | self.dropout, 59 | self.attention_dropout, 60 | self.activation_dropout, 61 | self.max_tokens_per_msa, 62 | ) 63 | for _ in range(self.num_layers) 64 | ] 65 | ) 66 | 67 | self.contact_head = ContactPredictionHead( 68 | self.num_layers * self.attention_heads, 69 | self.prepend_bos, 70 | self.append_eos, 71 | eos_idx=self.eos_idx, 72 | ) 73 | self.embed_positions = LearnedPositionalEmbedding( 74 | self.max_positions, 75 | self.embed_dim, 76 | self.padding_idx, 77 | ) 78 | self.emb_layer_norm_before = ESM1bLayerNorm(self.embed_dim) 79 | self.emb_layer_norm_after = 
ESM1bLayerNorm(self.embed_dim) 80 | self.lm_head = RobertaLMHead( 81 | embed_dim=self.embed_dim, 82 | output_dim=self.alphabet_size, 83 | weight=self.embed_tokens.weight, 84 | ) 85 | 86 | def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False, self_row_attn_mask=None): 87 | if return_contacts: 88 | need_head_weights = True 89 | 90 | assert tokens.ndim == 3 91 | batch_size, num_alignments, seqlen = tokens.size() 92 | padding_mask = tokens.eq(self.padding_idx) # B, R, C 93 | if not padding_mask.any(): 94 | padding_mask = None 95 | 96 | x = self.embed_tokens(tokens) 97 | x += self.embed_positions(tokens.view(batch_size * num_alignments, seqlen)).view(x.size()) 98 | if self.msa_position_embedding is not None: 99 | if x.size(1) > 1024: 100 | raise RuntimeError( 101 | "Using model with MSA position embedding trained on maximum MSA " 102 | f"depth of 1024, but received {x.size(1)} alignments." 103 | ) 104 | x += self.msa_position_embedding[:, :num_alignments] 105 | 106 | x = self.emb_layer_norm_before(x) 107 | 108 | x = self.dropout_module(x) 109 | 110 | if padding_mask is not None: 111 | x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) 112 | 113 | repr_layers = set(repr_layers) 114 | hidden_representations = {} 115 | if 0 in repr_layers: 116 | hidden_representations[0] = x 117 | 118 | if need_head_weights: 119 | row_attn_weights = [] 120 | col_attn_weights = [] 121 | 122 | # Mask row attention from tokens 123 | if self_row_attn_mask is not None: 124 | self_row_attn_mask = nn.functional.pad(self_row_attn_mask, (1, 0, 1, 0), 125 | value=True) 126 | 127 | # B x R x C x D -> R x C x B x D 128 | x = x.permute(1, 2, 0, 3) 129 | 130 | for layer_idx, layer in enumerate(self.layers): 131 | x = layer( 132 | x, 133 | self_attn_padding_mask=padding_mask, 134 | self_row_attn_mask=self_row_attn_mask, 135 | need_head_weights=need_head_weights, 136 | ) 137 | if need_head_weights: 138 | x, col_attn, row_attn = x 139 | # H x C x B x R x R -> B x H x C 
x R x R 140 | col_attn_weights.append(col_attn.permute(2, 0, 1, 3, 4)) 141 | # H x B x C x C -> B x H x C x C 142 | row_attn_weights.append(row_attn.permute(1, 0, 2, 3)) 143 | if (layer_idx + 1) in repr_layers: 144 | hidden_representations[layer_idx + 1] = x.permute(2, 0, 1, 3) 145 | 146 | x = self.emb_layer_norm_after(x) 147 | x = x.permute(2, 0, 1, 3) # R x C x B x D -> B x R x C x D 148 | 149 | # last hidden representation should have layer norm applied 150 | if (layer_idx + 1) in repr_layers: 151 | hidden_representations[layer_idx + 1] = x 152 | x = self.lm_head(x) 153 | 154 | result = {"logits": x, "representations": hidden_representations} 155 | if need_head_weights: 156 | # col_attentions: B x L x H x C x R x R 157 | col_attentions = torch.stack(col_attn_weights, 1) 158 | # row_attentions: B x L x H x C x C 159 | row_attentions = torch.stack(row_attn_weights, 1) 160 | result["col_attentions"] = col_attentions 161 | result["row_attentions"] = row_attentions 162 | if return_contacts: 163 | contacts = self.contact_head(tokens, row_attentions) 164 | result["contacts"] = contacts 165 | 166 | return result 167 | 168 | def predict_contacts(self, tokens): 169 | return self(tokens, return_contacts=True)["contacts"] 170 | 171 | def max_tokens_per_msa_(self, value: int) -> None: 172 | """The MSA Transformer automatically batches attention computations when 173 | gradients are disabled to allow you to pass in larger MSAs at test time than 174 | you can fit in GPU memory. By default this occurs when more than 2^14 tokens 175 | are passed in the input MSA. You can set this value to infinity to disable 176 | this behavior. 
177 | """ 178 | for module in self.modules(): 179 | if isinstance(module, (RowSelfAttention, ColumnSelfAttention)): 180 | module.max_tokens_per_msa = value 181 | -------------------------------------------------------------------------------- /src/pdb_parser_scripts/clean_pdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | import sys 5 | import tempfile 6 | from io import BytesIO, StringIO 7 | 8 | import Bio.PDB 9 | import Bio.PDB.Polypeptide 10 | import Bio.SeqIO 11 | import pdbfixer 12 | import simtk 13 | import openmm 14 | import openmm.app 15 | 16 | 17 | PDBIO = Bio.PDB.PDBIO() 18 | PDB_PARSER = Bio.PDB.PDBParser(PERMISSIVE=0) 19 | 20 | 21 | class NonHetSelector(Bio.PDB.Select): 22 | """ Remove HET atoms and choose first conformation of disordered atoms""" 23 | 24 | def accept_residue(self, residue): 25 | norm_res_bool = residue.get_resname() in [ 26 | pdbfixer.pdbfixer.substitutions[key] 27 | for key in pdbfixer.pdbfixer.substitutions 28 | ] 29 | abnorm_res_bool = residue.get_resname() in [ 30 | key for key in pdbfixer.pdbfixer.substitutions 31 | ] 32 | return ( 33 | norm_res_bool or abnorm_res_bool 34 | ) # Accept abnorm since they are converted later 35 | 36 | def accept_atom(self, atom): 37 | return ( 38 | not atom.is_disordered() 39 | or atom.get_altloc() == "A" 40 | or atom.get_altloc() == "1" 41 | ) and atom.id[0] in ["C", "H", "N", "O", "S", "P"] 42 | 43 | 44 | class PDBFixerResIdentifiabilityIssue(Exception): 45 | pass 46 | 47 | 48 | def _step_3_pdbfixer(first_model, temp3): 49 | for chain in first_model: 50 | for res in chain: 51 | for atom in res: 52 | atom.set_altloc(" ") 53 | PDBIO.set_structure(first_model) 54 | PDBIO.save(temp3) 55 | temp3.flush() 56 | 57 | # Use PDBFixer to fix common PDB errors 58 | fixer = pdbfixer.PDBFixer(temp3.name) 59 | fixer.findMissingResidues() 60 | fixer.findNonstandardResidues() 61 | fixer.replaceNonstandardResidues() 62 | 
fixer.findMissingAtoms() 63 | fixer.addMissingAtoms() 64 | fixer.addMissingHydrogens(7.0) 65 | return temp3, fixer 66 | 67 | 68 | def _step_4_fix_numbering(fixer, temp3, temp4): 69 | simtk.openmm.app.PDBFile.writeFile( 70 | fixer.topology, fixer.positions, temp4, keepIds=False 71 | ) 72 | temp4.flush() 73 | # Fix IDs manually since pdbfixer does not preserve insertion codes 74 | structure_before = PDB_PARSER.get_structure(temp3.name, temp3.name) 75 | structure_after = PDB_PARSER.get_structure(temp4.name, temp4.name) 76 | residues_before = [] 77 | for chain in structure_before[0]: 78 | residues_before.append(chain.get_list()) 79 | residues_after = [] 80 | for chain in structure_after[0]: 81 | residues_after.append(chain.get_list()) 82 | chain_counter = "" 83 | for i, chain in enumerate(structure_before[0]): 84 | try: 85 | if ( 86 | structure_after[0].get_list()[i].id 87 | != structure_before[0].get_list()[i].id 88 | ): 89 | try: 90 | # HACK BECAUSE OF https://github.com/biopython/biopython/issues/1551 91 | # Essentially, a new change in biopython prevents you from changing the 92 | # id to an already existing id which broke this initial script. 93 | # Therefore, we now change the ids to "change_counter" which will never look 94 | # like a canonical chainid. 95 | structure_after[0][ 96 | structure_before[0].get_list()[i].id 97 | ].id = chain_counter 98 | chain_counter += "KK" 99 | except KeyError: 100 | pass 101 | structure_after[0].get_list()[i].id = ( 102 | structure_before[0].get_list()[i].id 103 | ) 104 | if len(residues_before[i]) != len(residues_after[i]): 105 | raise PDBFixerResIdentifiabilityIssue() 106 | 107 | # When exceeding chainid Z, pdbfixer has discarded it, whereas biopython has not. 108 | # For simplicity, we just discard it as well and pretend it does not exist. 109 | # This is a very rare instance and will likely never be a problem unless you 110 | # are extremely unlucky to work with huge proteins where you care about the 111 | # truncation. 
112 | except IndexError: 113 | continue 114 | 115 | counter = 99999 # A large residue number that will never exist in a pdb. 116 | for res1, res2 in zip(residues_before[i], residues_after[i]): 117 | assert ( 118 | res1.get_resname().strip() == res2.get_resname().strip() 119 | or pdbfixer.pdbfixer.substitutions[res1.get_resname()].strip() 120 | == res2.get_resname().strip() 121 | ) 122 | if res2.id != res1.id: 123 | try: 124 | # Similar issue as previous hack https://github.com/biopython/biopython/issues/1551 125 | structure_after[0][chain.get_id()][res1.id].id = ( 126 | " ", 127 | counter, 128 | " ", 129 | ) 130 | except KeyError: 131 | pass 132 | res2.id = res1.id 133 | counter += 1 134 | 135 | return structure_after 136 | 137 | 138 | def clean_pdb(pdb_input_filename: str, out_dir: str): 139 | """ 140 | Function to clean pdbs using pdbfixer. 141 | 142 | Parameters 143 | ---------- 144 | pdb_input_filename: str 145 | PDB filename 146 | out_dir: str 147 | Output directory. 148 | """ 149 | pdbid = pdb_input_filename.split("/")[-1].split(".pdb")[0] 150 | 151 | with tempfile.NamedTemporaryFile(mode="wt", delete=True) as temp1: 152 | first_model = PDB_PARSER.get_structure(pdbid, pdb_input_filename)[0] 153 | 154 | # Step 1: NonHetSelector filter 155 | with tempfile.NamedTemporaryFile(mode="wt", delete=True) as temp2: 156 | PDBIO.set_structure(first_model) 157 | PDBIO.save(temp2, select=NonHetSelector()) 158 | temp2.flush() 159 | first_model = PDB_PARSER.get_structure(temp2.name, temp2.name)[0] 160 | 161 | # Step 2: Replace altloc chars to " " and use pdbfixer 162 | with tempfile.NamedTemporaryFile(mode="wt", delete=True) as temp3: 163 | temp_3, fixer = _step_3_pdbfixer(first_model, temp3) 164 | 165 | # Step 3: Correct for pdbfixer not preserving insertion codes 166 | with tempfile.NamedTemporaryFile(mode="wt", delete=True) as temp4: 167 | structure_after = _step_4_fix_numbering(fixer, temp3, temp4) 168 | with open( 169 | os.path.join(out_dir, pdbid + ".pdb"), "w" 170 | ) 
as outpdb: 171 | PDBIO.set_structure(structure_after[0]) 172 | PDBIO.save(outpdb) 173 | 174 | if __name__ == "__main__": 175 | # Argument Parser 176 | parser = argparse.ArgumentParser() 177 | parser.add_argument("--pdb_file_in", type=str) 178 | parser.add_argument("--out_dir", type=str) 179 | 180 | # Parse arguments 181 | args_dict = vars(parser.parse_args()) 182 | pdb_input_filename = args_dict["pdb_file_in"] 183 | out_dir = args_dict["out_dir"] 184 | 185 | # Clean 186 | clean_pdb(pdb_input_filename, out_dir) 187 | -------------------------------------------------------------------------------- /src/run_test_rocklin.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import glob 3 | import re 4 | import torch 5 | import models.gvp.data, models.gvp.models 6 | import json 7 | import torch_geometric 8 | import esm 9 | import pandas as pd 10 | import random 11 | import torch.multiprocessing 12 | 13 | torch.multiprocessing.set_sharing_strategy("file_system") 14 | from models.msa_transformer.model import MSATransformer 15 | from models.gvp.models import SSEmbGNN 16 | from helpers import ( 17 | read_msa, 18 | loop_pred, 19 | ) 20 | from visualization import plot_rocklin 21 | import torch.utils.data 22 | from collections import OrderedDict 23 | import shutil 24 | import pdb_parser_scripts.parse_pdbs as parse_pdbs 25 | from collections import OrderedDict 26 | from ast import literal_eval 27 | import subprocess 28 | 29 | 30 | def test(run_name, epoch, num_ensemble=1, device=None): 31 | # Download raw data 32 | subprocess.run( 33 | ["wget", "https://zenodo.org/record/7992926/files/AlphaFold_model_PDBs.zip"] 34 | ) 35 | subprocess.run( 36 | ["unzip", "AlphaFold_model_PDBs.zip", "-d", "../data/test/rocklin/raw/"] 37 | ) 38 | subprocess.run(["rm", "AlphaFold_model_PDBs.zip"]) 39 | 40 | subprocess.run( 41 | [ 42 | "wget", 43 | "https://zenodo.org/record/7992926/files/Processed_K50_dG_datasets.zip", 44 | ] 45 | ) 46 | 
subprocess.run( 47 | ["unzip", "Processed_K50_dG_datasets.zip", "-d", "../data/test/rocklin/raw/"] 48 | ) 49 | subprocess.run(["rm", "Processed_K50_dG_datasets.zip"]) 50 | 51 | # Load data 52 | df = pd.read_csv( 53 | "../data/test/rocklin/raw/Processed_K50_dG_datasets/Tsuboyama2023_Dataset2_Dataset3_20230416.csv" 54 | ) 55 | 56 | # Use only Dataset 3 with well-defined ddG's 57 | df = df[df["ddG_ML"] != "-"] 58 | 59 | # Switch sign of experimental ddG's 60 | df["ddG_ML"] = -pd.to_numeric(df["ddG_ML"]) 61 | 62 | # Use only non-synonomous substitutions 63 | df = df[~df["mut_type"].str.startswith("ins")] 64 | df = df[~df["mut_type"].str.startswith("del")] 65 | df = df[df["mut_type"] != "wt"] 66 | 67 | # Change pdb names to align with structure names 68 | df["WT_name"] = df["WT_name"].str.replace("|", ":") 69 | 70 | # Move structures 71 | structures = list(df["WT_name"].unique()) 72 | structures_not_available = [] 73 | 74 | for structure in structures: 75 | try: 76 | shutil.copy( 77 | f"../data/test/rocklin/raw/AlphaFold_model_PDBs/{structure}", 78 | f"../data/test/rocklin/structure/raw/{structure}", 79 | ) 80 | except: 81 | structures_not_available.append(structure) 82 | 83 | df = df[~df["WT_name"].isin(structures_not_available)] 84 | print( 85 | f"Number of Rocklin assays without available AF2 structures: {len(structures_not_available)}" 86 | ) 87 | 88 | # Save ddG data 89 | df_ddg = df[["WT_name", "mut_type", "ddG_ML"]].reset_index(drop=True) 90 | df_ddg = df_ddg.rename( 91 | columns={"WT_name": "pdb_id", "mut_type": "variant", "ddG_ML": "score_exp"} 92 | ) 93 | df_ddg["pdb_id"] = df_ddg["pdb_id"].str[:-4] 94 | df_ddg.to_csv("../data/test/rocklin/exp/ddg.csv", index=False) 95 | 96 | ## Pre-process PDBs 97 | pdb_dir = "../data/test/rocklin/structure/" 98 | subprocess.run( 99 | [ 100 | "pdb_parser_scripts/clean_pdbs.sh", 101 | str(pdb_dir), 102 | ] 103 | ) 104 | parse_pdbs.parse(pdb_dir) 105 | 106 | # Load structure data 107 | with 
open(f"../data/test/rocklin/structure/coords.json") as json_file: 108 | data = json.load(json_file) 109 | json_file.close() 110 | 111 | # Compute MSAs 112 | sys.path += [":/projects/prism/people/skr526/mmseqs/bin"] 113 | subprocess.run( 114 | [ 115 | "colabfold_search", 116 | f"{pdb_dir}/seqs.fasta", 117 | "/projects/prism/people/skr526/databases", 118 | "../data/test/rocklin/msa/", 119 | ] 120 | ) 121 | subprocess.run(["python", "merge_and_sort_msas.py", "../data/test/rocklin/msa"]) 122 | 123 | # Load MSA data 124 | msa_filenames = sorted(glob.glob(f"../data/test/rocklin/msa/*.a3m")) 125 | mave_msa_sub = {} 126 | for i, f in enumerate(msa_filenames): 127 | # name = f.split("/")[-1].split(".")[0] 128 | name = f.split("/")[-1].split(".")[0][:-2] 129 | mave_msa_sub[name] = [] 130 | for j in range(num_ensemble): 131 | msa = read_msa(f) 132 | msa_sub = [msa[0]] 133 | k = min(len(msa) - 1, 16 - 1) 134 | msa_sub += [msa[i] for i in sorted(random.sample(range(1, len(msa)), k))] 135 | mave_msa_sub[name].append(msa_sub) 136 | 137 | # Add MSAs to data 138 | for entry in data: 139 | entry["msa"] = mave_msa_sub[entry["name"]] 140 | 141 | # Convert to graph data sets 142 | testset = models.gvp.data.ProteinGraphData(data) 143 | letter_to_num = testset.letter_to_num 144 | 145 | # Make variant pos dict 146 | variant_pos_dict = {} 147 | for pdb_id in df_ddg["pdb_id"].unique(): 148 | df_ddg_pdb = df_ddg[df_ddg["pdb_id"] == pdb_id] 149 | variant_wtpos_list = [ 150 | [x[:-1] for x in x.split(":")] for x in df_ddg_pdb["variant"].tolist() 151 | ] 152 | variant_wtpos_list = list( 153 | OrderedDict.fromkeys( 154 | [item for sublist in variant_wtpos_list for item in sublist] 155 | ) 156 | ) # Remove duplicates 157 | variant_pos_dict[pdb_id] = variant_wtpos_list 158 | 159 | # Load MSA Transformer 160 | _, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S() 161 | msa_batch_converter = msa_alphabet.get_batch_converter() 162 | model_msa = MSATransformer(msa_alphabet) 163 | model_msa = 
model_msa.to(device) 164 | 165 | model_dict = OrderedDict() 166 | state_dict_msa = torch.load( 167 | f"../output/train/models/msa_transformer/{run_name}_msa_transformer_{epoch}.pt" 168 | ) 169 | pattern = re.compile("module.") 170 | for k, v in state_dict_msa.items(): 171 | if re.search("module", k): 172 | model_dict[re.sub(pattern, "", k)] = v 173 | else: 174 | model_dict = state_dict_msa 175 | model_msa.load_state_dict(model_dict) 176 | 177 | # Load GVP 178 | node_dim = (256, 64) 179 | edge_dim = (32, 1) 180 | model_gvp = SSEmbGNN((6, 3), node_dim, (32, 1), edge_dim) 181 | model_gvp = model_gvp.to(device) 182 | 183 | model_dict = OrderedDict() 184 | state_dict_gvp = torch.load(f"../output/train/models/gvp/{run_name}_gvp_{epoch}.pt") 185 | pattern = re.compile("module.") 186 | for k, v in state_dict_gvp.items(): 187 | if k.startswith("module"): 188 | model_dict[k[7:]] = v 189 | else: 190 | model_dict = state_dict_gvp 191 | model_gvp.load_state_dict(model_dict) 192 | 193 | # Initialize data loader 194 | test_loader = torch_geometric.loader.DataLoader( 195 | testset, batch_size=1, shuffle=False 196 | ) 197 | 198 | # Call test 199 | model_msa.eval() 200 | model_gvp.eval() 201 | 202 | with torch.no_grad(): 203 | pred_list, acc_mean = loop_pred( 204 | model_msa, 205 | model_gvp, 206 | msa_batch_converter, 207 | test_loader, 208 | variant_pos_dict, 209 | data, 210 | letter_to_num, 211 | device=device, 212 | ) 213 | 214 | # Transform results into df 215 | df_ml = pd.DataFrame(pred_list, columns=["pdb_id", "variant_pos", "score_ml_pos"]) 216 | 217 | # Save 218 | df_ml.to_csv(f"../output/rocklin/df_ml_{run_name}.csv", index=False) 219 | 220 | # Load 221 | df_ml = pd.read_csv( 222 | f"../output/rocklin/df_ml_{run_name}.csv", 223 | converters=dict(score_ml_pos=literal_eval), 224 | ) 225 | 226 | # Compute score_ml from nlls 227 | # OBS: We cannot vectorize this part since we have an unknown number of mutations per position :( 228 | pdb_variant_list = df_ddg.values.tolist() 
229 | for i, row in enumerate(pdb_variant_list): 230 | pdb_id = row[0] 231 | print( 232 | f"Computing score for protein {pdb_id} variant: {i+1}/{len(pdb_variant_list)}" 233 | ) 234 | variant_set = row[1].split(":") 235 | 236 | score_ml = 0.0 237 | for variant in variant_set: 238 | wt = letter_to_num[variant[0]] 239 | pos = int(re.findall(r"\d+", variant)[0]) 240 | mt = letter_to_num[variant[-1]] 241 | score_ml_pos = df_ml[ 242 | (df_ml["pdb_id"].str.startswith(pdb_id)) & (df_ml["variant_pos"] == pos) 243 | ]["score_ml_pos"].values[0] 244 | score_ml += float(score_ml_pos[mt]) 245 | pdb_variant_list[i].append(score_ml) 246 | 247 | # Convert to df 248 | df_total = pd.DataFrame( 249 | pdb_variant_list, columns=["DMS_id", "variant_set", "score_exp", "score_ml"] 250 | ) 251 | 252 | # Save 253 | df_total.to_csv(f"../output/rocklin/df_total_{run_name}.csv", index=False) 254 | 255 | # Load 256 | df_total = pd.read_csv(f"../output/rocklin/df_total_{run_name}.csv") 257 | 258 | # Compute correlations 259 | plot_rocklin(df_total) 260 | -------------------------------------------------------------------------------- /src/models/msa_transformer/axial_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
import math
import torch
import torch.nn as nn


class RowSelfAttention(nn.Module):
    """Compute self-attention over rows of a 2D input.

    Inputs are laid out as (rows, cols, batch, embed).  Query/key products are
    summed over the row axis (see the einsum in ``compute_attention_weights``),
    so one attention map of layout ``self.attn_shape`` = (heads, batch, i, j)
    over columns is shared by every row.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        # einsum output layout of the attention map: (heads, batch, i, j)
        self.attn_shape = "hnij"

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def align_scaling(self, q):
        """Return the query scale, further divided by sqrt(num_rows).

        The extra factor compensates for the logits being summed over rows.
        """
        num_rows = q.size(0)
        return self.scaling / math.sqrt(num_rows)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Chunked forward that bounds memory for large MSAs.

        Logits are accumulated over chunks of at most
        ``max_tokens_per_msa // num_cols`` rows, then one softmax is taken and
        applied chunk by chunk.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_rows = max(1, self.max_tokens_per_msa // num_cols)
        attns = 0
        scaling = self.align_scaling(x)
        for start in range(0, num_rows, max_rows):
            attn_weights = self.compute_attention_weights(
                x[start : start + max_rows],
                scaling,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
                if self_attn_padding_mask is not None
                else None,
            )
            attns += attn_weights
        attn_probs = attns.softmax(-1)
        attn_probs = self.dropout_module(attn_probs)

        outputs = []
        for start in range(0, num_rows, max_rows):
            output = self.compute_attention_update(x[start : start + max_rows], attn_probs)
            outputs.append(output)

        output = torch.cat(outputs, 0)
        return output, attn_probs

    def compute_attention_weights(
        self,
        x,
        scaling: float,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return row-summed attention logits with layout ``self.attn_shape``.

        Args:
            x: tensor of shape (rows, cols, batch, embed).
            scaling: multiplier applied to the queries.
            self_attn_mask: boolean mask broadcast over (heads, batch); True
                entries are filled with -10000.  Presumably the row-attention
                ``dist_mask`` built by ``ProteinGraphData`` — confirm with caller.
            self_attn_padding_mask: boolean (batch, rows, cols) padding mask.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()
        q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        q *= scaling
        if self_attn_padding_mask is not None:
            # Zero out any padded aligned positions - this is important since
            # we take a sum across the alignment axis.
            q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)

        attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)

        if self_attn_mask is not None:
            # Broadcast the (i, j) mask over heads and batch.
            attn_weights = attn_weights.masked_fill(
                self_attn_mask.unsqueeze(0).unsqueeze(0),
                -10000,
            )

        if self_attn_padding_mask is not None:
            # Padding of the first row marks which key columns are padding.
            attn_weights = attn_weights.masked_fill(
                self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
                -10000,
            )

        return attn_weights

    def compute_attention_update(
        self,
        x,
        attn_probs,
    ):
        """Apply the shared attention map to the values of every row in ``x``."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
        context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
        output = self.out_proj(context)
        return output

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return ``(output, attn_probs)`` for ``x`` of shape (rows, cols, batch, embed).

        The chunked path is used only when gradients are disabled and the MSA
        holds more than ``max_tokens_per_msa`` tokens.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
        else:
            scaling = self.align_scaling(x)
            attn_weights = self.compute_attention_weights(
                x, scaling, self_attn_mask, self_attn_padding_mask
            )
            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            output = self.compute_attention_update(x, attn_probs)
            return output, attn_probs


class ColumnSelfAttention(nn.Module):
    """Compute self-attention over columns of a 2D input.

    Within each column, every row attends to every other row; the attention
    map has layout (heads, cols, batch, i, j) over rows.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Run attention on chunks of at most ``max_tokens_per_msa // num_rows``
        columns and concatenate outputs and attention maps along the column axis."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_cols = max(1, self.max_tokens_per_msa // num_rows)
        outputs = []
        attns = []
        for start in range(0, num_cols, max_cols):
            output, attn = self(
                x[:, start : start + max_cols],
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols]
                if self_attn_padding_mask is not None
                else None,
            )
            outputs.append(output)
            attns.append(attn)
        output = torch.cat(outputs, 1)
        attns = torch.cat(attns, 1)
        return output, attns

    def compute_attention_update(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return ``(output, attn_probs)``; raises NotImplementedError if a
        ``self_attn_mask`` is given for a multi-row input."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if num_rows == 1:
            # if there is only 1 position, this is equivalent and doesn't break with padding
            attn_probs = torch.ones(
                self.num_heads,
                num_cols,
                batch_size,
                num_rows,
                num_rows,
                device=x.device,
                dtype=x.dtype,
            )
            output = self.out_proj(self.v_proj(x))
        else:
            q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
            q *= self.scaling

            attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)

            if self_attn_mask is not None:
                raise NotImplementedError
            if self_attn_padding_mask is not None:
                # (batch, rows, cols) -> (1, cols, batch, 1, rows): mask padded key rows.
                attn_weights = attn_weights.masked_fill(
                    self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
                    -10000,
                )

            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
            context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
            output = self.out_proj(context)
        return output, attn_probs

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return ``(output, attn_probs)``; chunks over columns when gradients
        are disabled and the MSA exceeds ``max_tokens_per_msa`` tokens."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
            return self._batched_forward(
                x,
                self_attn_mask,
                self_attn_padding_mask,
            )
        else:
            return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)


# ------------------------------------------------------------------------------
# /src/run_test_clinvar.py
# ------------------------------------------------------------------------------
import os
import sys
import glob
import re
import pickle
import torch
import models.gvp.data, models.gvp.models
import json
import torch_geometric
import esm
import pandas as pd
import random
import torch.multiprocessing
import numpy as np

torch.multiprocessing.set_sharing_strategy("file_system")
from models.msa_transformer.model import MSATransformer
from models.gvp.models import SSEmbGNN
from helpers import (
    read_msa,
    loop_pred,
    compute_auc_group,
)
from visualization import plot_aucroc
import torch.utils.data
from collections import OrderedDict
import shutil
import pdb_parser_scripts.parse_pdbs as parse_pdbs
from collections import OrderedDict
from ast import literal_eval
import subprocess
from sklearn.metrics import auc, roc_curve, roc_auc_score

# Site-specific locations used by the MSA search step.
_MMSEQS_BIN_DIR = "/projects/prism/people/skr526/mmseqs/bin"
_COLABFOLD_DB_DIR = "/projects/prism/people/skr526/databases"


def _strip_module_prefix(state_dict):
    """Return ``state_dict`` with a leading ``module.`` removed from its keys.

    Checkpoints saved from a (Distributed)DataParallel wrapper carry that
    prefix; the bare models built here do not.  A dict with no prefixed keys
    is returned unchanged.
    """
    prefix = "module."
    if not any(k.startswith(prefix) for k in state_dict):
        return state_dict
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


def _sample_msa_ensemble(msa, num_ensemble, max_rows=16):
    """Draw ``num_ensemble`` random sub-MSAs of at most ``max_rows`` rows.

    Each sub-MSA starts with the query (row 0), followed by up to
    ``max_rows - 1`` other rows sampled without replacement, kept in order.
    """
    ensemble = []
    for _ in range(num_ensemble):
        k = min(len(msa) - 1, max_rows - 1)
        picked = sorted(random.sample(range(1, len(msa)), k))
        ensemble.append([msa[0]] + [msa[idx] for idx in picked])
    return ensemble


def _build_variant_pos_dict(df_clinvar):
    """Map each protein (name up to the first ".") to its variant positions.

    Variants are ":"-separated substitutions such as ``A12G:L40P``; the mutant
    letter is dropped, giving wild-type+position tokens (``A12``) de-duplicated
    in first-seen order.
    """
    variant_pos_dict = {}
    for prot_name in df_clinvar["prot_name"].unique():
        variants = df_clinvar.loc[df_clinvar["prot_name"] == prot_name, "variant"]
        wtpos = [sub[:-1] for v in variants for sub in v.split(":")]
        variant_pos_dict[prot_name.split(".")[0]] = list(OrderedDict.fromkeys(wtpos))
    return variant_pos_dict


def test(run_name, epoch, num_ensemble=5, msa_row_attn_mask=True, device=None):
    """Score ClinVar substitutions with a trained SSEmb model and report AUC.

    Downloads the ProteinGym clinical substitution archive, prepares
    structures and MSAs, runs the MSA Transformer + GVP checkpoints for
    ``run_name``/``epoch``, writes per-position and per-variant scores to
    ``../output/clinvar/`` and prints the mean of ``compute_auc_group`` over
    proteins.

    Args:
        run_name: checkpoint prefix under ``../output/train/models``.
        epoch: checkpoint epoch suffix.
        num_ensemble: number of random sub-MSAs sampled per protein.
        msa_row_attn_mask: forwarded to ``loop_pred``.
        device: torch device for the models and checkpoint loading.
    """
    # Download raw data
    subprocess.run(["wget", "https://marks.hms.harvard.edu/proteingym/clinical_ProteinGym_substitutions.zip"])
    subprocess.run(["unzip", "clinical_ProteinGym_substitutions.zip", "-d", "../data/test/clinvar/raw"])
    subprocess.run(["rm", "clinical_ProteinGym_substitutions.zip"])

    # Load ClinVar data
    print("Processing ClinVar data")
    clinvar_files = glob.glob("../data/test/clinvar/raw/*.csv")
    df_clinvar = pd.concat([pd.read_csv(f) for f in clinvar_files], ignore_index=True)
    df_clinvar = df_clinvar[["protein", "protein_sequence", "mutant", "DMS_bin_score"]]
    df_clinvar = df_clinvar.rename(columns={"protein": "prot_name",
                                            "protein_sequence": "seq",
                                            "mutant": "variant",
                                            "DMS_bin_score": "label",
                                            }
                                   )

    # Save one FASTA record per unique (prot_name, seq) pair
    df_uniqueseqs = df_clinvar[["prot_name", "seq"]].drop_duplicates()
    with open("../data/test/clinvar/seqs.fasta", "w") as fh:
        for prot_name, seq in df_uniqueseqs.itertuples(index=False, name=None):
            fh.write(f">{prot_name}\n{seq}\n")

    ## Pre-process PDBs
    pdb_dir = "../data/test/clinvar/structure/"
    subprocess.run(
        [
            "pdb_parser_scripts/clean_pdbs.sh",
            str(pdb_dir),
        ]
    )
    parse_pdbs.parse(pdb_dir)

    # Load structure data
    with open("../data/test/clinvar/structure/coords.json") as json_file:
        data = json.load(json_file)

    # Compute MSAs.  The mmseqs binaries must be on the *child's* PATH;
    # sys.path only affects Python imports in this process.
    env = dict(os.environ)
    env["PATH"] = _MMSEQS_BIN_DIR + os.pathsep + env.get("PATH", "")
    subprocess.run(
        [
            "colabfold_search",
            "../data/test/clinvar/seqs.fasta",
            _COLABFOLD_DB_DIR,
            "../data/test/clinvar/msa/",
        ],
        env=env,
    )
    subprocess.run(["python", "merge_and_sort_msas.py", "../data/test/clinvar/msa"])

    # Load MSA data: read each file once, then sample the ensemble from it
    mave_msa_sub = {}
    for f in sorted(glob.glob("../data/test/clinvar/msa/*.a3m")):
        name = ".".join(f.split("/")[-1].split(".")[:-1])
        mave_msa_sub[name] = _sample_msa_ensemble(read_msa(f), num_ensemble)

    # Add MSAs to data
    for entry in data:
        entry["msa"] = mave_msa_sub[entry["name"]]

    # Convert to graph data sets
    testset = models.gvp.data.ProteinGraphData(data)
    letter_to_num = testset.letter_to_num

    # Make variant pos dict
    variant_pos_dict = _build_variant_pos_dict(df_clinvar)

    # Save data and dict of variant positions
    with open("../data/test/clinvar/data_with_msas.pkl", "wb") as fp:
        pickle.dump(data, fp)

    with open("../data/test/clinvar/variant_pos_dict.pkl", "wb") as fp:
        pickle.dump(variant_pos_dict, fp)

    # Load MSA Transformer
    _, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    msa_batch_converter = msa_alphabet.get_batch_converter()
    model_msa = MSATransformer(msa_alphabet).to(device)
    state_dict_msa = torch.load(
        f"../output/train/models/msa_transformer/{run_name}_msa_transformer_{epoch}.pt",
        map_location=device,
    )
    model_msa.load_state_dict(_strip_module_prefix(state_dict_msa))

    # Load GVP
    node_dim = (256, 64)
    edge_dim = (32, 1)
    model_gvp = SSEmbGNN((6, 3), node_dim, (32, 1), edge_dim).to(device)
    state_dict_gvp = torch.load(
        f"../output/train/models/gvp/{run_name}_gvp_{epoch}.pt",
        map_location=device,
    )
    model_gvp.load_state_dict(_strip_module_prefix(state_dict_gvp))

    # Initialize data loader
    test_loader = torch_geometric.loader.DataLoader(
        testset, batch_size=1, shuffle=False
    )

    # Call test
    model_msa.eval()
    model_gvp.eval()

    with torch.no_grad():
        pred_list, _ = loop_pred(
            model_msa,
            model_gvp,
            msa_batch_converter,
            test_loader,
            variant_pos_dict,
            data,
            letter_to_num,
            msa_row_attn_mask=msa_row_attn_mask,
            device=device,
        )

    # Transform results into df, save, and reload (score lists round-trip via literal_eval)
    df_ml = pd.DataFrame(pred_list, columns=["prot_name", "variant_pos", "score_ml_pos"])
    df_ml.to_csv(f"../output/clinvar/df_ml_{run_name}.csv", index=False)
    df_ml = pd.read_csv(
        f"../output/clinvar/df_ml_{run_name}.csv",
        converters=dict(score_ml_pos=literal_eval),
    )

    # Index per-position scores by (protein, position); the first row wins,
    # matching a first-match lookup, without rescanning df_ml per variant.
    score_lookup = {}
    for prot, pos, scores in df_ml.itertuples(index=False, name=None):
        score_lookup.setdefault((prot, int(pos)), scores)

    # Compute score_ml as the sum of per-position scores of the mutant residues
    clinvar_variant_list = df_clinvar.values.tolist()
    for i, row in enumerate(clinvar_variant_list):
        prot_name = row[0].split(".")[0]
        print(
            f"Computing score for assay {prot_name} variant: {i+1}/{len(clinvar_variant_list)}"
        )
        score_ml = 0.0
        for variant in row[2].split(":"):
            pos = int(re.findall(r"\d+", variant)[0])
            mt = letter_to_num[variant[-1]]
            score_ml += float(score_lookup[(prot_name, pos)][mt])
        row.append(score_ml)
    df_total = pd.DataFrame(
        clinvar_variant_list, columns=["prot_name", "seq", "variant_set", "label", "score_ml"]
    )
    df_total = df_total[["prot_name", "variant_set", "label", "score_ml"]]

    # Save
    df_total.to_csv(f"../output/clinvar/df_total_{run_name}.csv", index=False)

    # Load
    df_total = pd.read_csv(f"../output/clinvar/df_total_{run_name}.csv")

    # Compute AUC: labels containing "path" (e.g. pathogenic) are positives
    df_total["label_bin"] = [1 if "path" in x else 0 for x in df_total['label'].str.lower()]
    df_total["score"] = -df_total["score_ml"]
    prot_level_auc = df_total.groupby('prot_name').apply(compute_auc_group)
    auc_mean = prot_level_auc.mean(skipna=True)
    print(f"SSEmb avg. AUC is: {auc_mean}")


# ------------------------------------------------------------------------------
# /src/models/gvp/data.py
# ------------------------------------------------------------------------------
import json
import numpy as np
import tqdm, random
import torch, math
import torch.utils.data as data
import torch.nn.functional as F
import torch_geometric
import torch_cluster
import esm
from scipy.spatial import distance_matrix


def _normalize(tensor, dim=-1):
    '''
    Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
    '''
    return torch.nan_to_num(
        torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))


def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
    '''
    From https://github.com/jingraham/neurips19-graph-protein-design

    Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
    That is, if `D` has shape [...dims], then the returned tensor will have
    shape [...dims, D_count].
    '''
    D_mu = torch.linspace(D_min, D_max, D_count, device=device)
    D_mu = D_mu.view([1, -1])
    D_sigma = (D_max - D_min) / D_count
    D_expand = torch.unsqueeze(D, -1)

    RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
    return RBF


class CATHDataset:
    '''
    Loader and container class for the CATH 4.2 dataset downloaded
    from http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/.

    Has attributes `self.total`, and after `split()` also `self.train`,
    `self.val`, `self.test`, each of which are JSON/dictionary-type datasets
    as described in README.md.

    :param path: path to chain_set.jsonl
    :param splits_path: path to chain_set_splits.json or equivalent.
    :param minlen: stored as `self.minlen`; not applied by this class.
    :param maxlen: stored as `self.maxlen`; not applied by this class.
    '''
    def __init__(self, path, splits_path, minlen=0, maxlen=10000):

        self.splits_path = splits_path
        self.total = []
        self.maxlen = maxlen
        self.minlen = minlen

        with open(path) as f:
            lines = f.readlines()

        for line in tqdm.tqdm(lines):
            entry = json.loads(line)
            coords = entry['coords']

            # Per-atom coordinate lists -> one (N, CA, C, O) tuple per residue
            entry['coords'] = list(zip(
                coords['N'], coords['CA'], coords['C'], coords['O']
            ))

            self.total.append(entry)

    def split(self):
        '''Partition `self.total` into train/val/test by entry name.'''
        with open(self.splits_path) as f:
            dataset_splits = json.load(f)
        # Sets give O(1) membership tests instead of a list scan per entry.
        train_names = set(dataset_splits['train'])
        val_names = set(dataset_splits['validation'])
        test_names = set(dataset_splits['test'])

        self.train, self.val, self.test = [], [], []

        for entry in self.total:
            name = entry["name"]
            if name in train_names:
                self.train.append(entry)
            elif name in val_names:
                self.val.append(entry)
            elif name in test_names:
                self.test.append(entry)


class BatchSampler(data.Sampler):
    '''
    Batch sampler that samples a single protein at a time
    '''
    def __init__(self, node_counts, shuffle=True):
        self.indices = [i for i in range(len(node_counts))]
        self.shuffle = shuffle
        self._form_batches()

    def _form_batches(self):
        self.batches = []
        if self.shuffle: random.shuffle(self.indices)
        indices = self.indices
        for idx in indices:
            self.batches.append([idx])

    def __len__(self):
        if not self.batches: self._form_batches()
        return len(self.batches)

    def __iter__(self):
        if not self.batches: self._form_batches()
        for batch in self.batches: yield batch


class ProteinGraphData(data.Dataset):
    '''
    A map-syle `torch.utils.data.Dataset` which transforms JSON/dictionary-style
    protein structures into featurized protein graphs as described in the
    manuscript.
    '''
    def __init__(self, data_list, num_positional_embeddings=16,
                 top_k=20, dist_cutoff=12, num_rbf=16, device="cpu"):
        super(ProteinGraphData, self).__init__()

        self.data_list = data_list
        self.dist_cutoff = dist_cutoff
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings
        self.top_k = top_k
        self.device = device
        self.node_counts = [len(e['seq']) for e in data_list]

        # NOTE(review): the "" key (index 20) looks like an angle-bracket token
        # lost during text extraction of this copy — confirm against the repo.
        self.letter_to_num = {'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9,
                              'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8,
                              'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19,
                              'N': 2, 'Y': 18, 'M': 12, "": 20}
        self.num_to_letter = {v: k for k, v in self.letter_to_num.items()}

    def __len__(self): return len(self.data_list)

    def __getitem__(self, i): return self._featurize_as_graph(self.data_list[i])

    def _featurize_as_graph(self, protein):
        '''Build a `torch_geometric.data.Data` graph for one protein entry.

        Expects `name`, `seq` and `coords` (per residue N, CA, C, O); optional
        `msa`, `label` and `emb` entries are passed through (None if absent).
        '''
        name = protein['name']
        with torch.no_grad():
            coords = torch.tensor(np.array(protein["coords"]).astype(np.float32), device=self.device)
            seq = torch.as_tensor([self.letter_to_num[a] for a in protein['seq']],
                                  device=self.device, dtype=torch.long)
            n_res = seq.size(0)

            # Mask positions without coords
            mask = torch.isfinite(coords.sum(dim=(1, 2)))
            coords[~mask] = np.inf

            # Compute edges in graph
            X_ca = coords[:, 1]
            edge_index = torch_cluster.knn_graph(X_ca, k=self.top_k)

            # Distance mask for MSA row attention: mark kNN neighbours and the
            # diagonal, then invert so True marks the non-contact pairs to block.
            # Allocated on self.device so indexing with edge_index works off-CPU.
            dist_mask = torch.zeros((n_res, n_res), dtype=torch.bool, device=self.device)
            dist_mask[edge_index[1, :], edge_index[0, :]] = True
            diag = torch.arange(n_res, device=self.device)
            dist_mask[diag, diag] = True
            dist_mask = ~dist_mask

            # Compute graph node and edge features
            pos_embeddings = self._positional_embeddings(edge_index)
            E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]
            rbf = _rbf(E_vectors.norm(dim=-1), D_count=self.num_rbf, device=self.device)

            dihedrals = self._dihedrals(coords)
            orientations = self._orientations(X_ca)
            sidechains = self._sidechains(coords)

            node_s = dihedrals
            node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)
            edge_s = torch.cat([rbf, pos_embeddings], dim=-1)
            edge_v = _normalize(E_vectors).unsqueeze(-2)

            node_s, node_v, edge_s, edge_v = map(torch.nan_to_num,
                    (node_s, node_v, edge_s, edge_v))

            graph = torch_geometric.data.Data(x=X_ca, seq=seq, name=name,
                                              node_s=node_s, node_v=node_v,
                                              edge_s=edge_s, edge_v=edge_v,
                                              edge_index=edge_index,
                                              mask=mask,
                                              msa=protein.get("msa"),
                                              dist_mask=dist_mask,
                                              label=protein.get("label"),
                                              emb=protein.get("emb"),
                                              )

        return graph

    def _dihedrals(self, X, eps=1e-7):
        # From https://github.com/jingraham/neurips19-graph-protein-design

        X = torch.reshape(X[:, :3], [3*X.shape[0], 3])
        dX = X[1:] - X[:-1]
        U = _normalize(dX, dim=-1)
        u_2 = U[:-2]
        u_1 = U[1:-1]
        u_0 = U[2:]

        # Backbone normals.  dim=-1 is explicit: without it torch.cross uses
        # the first size-3 axis, which is the wrong one when there are 3 rows.
        n_2 = _normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = _normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Angle between normals
        cosD = torch.sum(n_2 * n_1, -1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)

        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, [1, 2])
        D = torch.reshape(D, [-1, 3])
        # Lift angle representations to the circle
        D_features = torch.cat([torch.cos(D), torch.sin(D)], 1)
        return D_features

    def _positional_embeddings(self, edge_index,
                               num_embeddings=None,
                               period_range=(2, 1000)):
        # From https://github.com/jingraham/neurips19-graph-protein-design
        # `period_range` is accepted for compatibility but not used.
        num_embeddings = num_embeddings or self.num_positional_embeddings
        d = edge_index[0] - edge_index[1]

        frequency = torch.exp(
            torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=self.device)
            * -(np.log(10000.0) / num_embeddings)
        )
        angles = d.unsqueeze(-1) * frequency
        E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
        return E

    def _orientations(self, X):
        forward = _normalize(X[1:] - X[:-1])
        backward = _normalize(X[:-1] - X[1:])
        forward = F.pad(forward, [0, 0, 0, 1])
        backward = F.pad(backward, [0, 0, 1, 0])
        return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)

    def _sidechains(self, X):
        n, origin, c = X[:, 0], X[:, 1], X[:, 2]
        c, n = _normalize(c - origin), _normalize(n - origin)
        bisector = _normalize(c + n)
        # dim=-1 explicit so a 3-residue protein does not cross along residues.
        perp = _normalize(torch.cross(c, n, dim=-1))
        vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
        return vec
# --------------------------------------------------------------------------------
# /src/models/msa_transformer/data.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import os
from typing import Sequence, Tuple, List, Union, Optional
import pickle
import re
import shutil
import torch
from pathlib import Path
from constants import proteinseq_toks

# An MSA given as a sequence of (label, aligned sequence string) pairs.
RawMSA = Sequence[Tuple[str, str]]


class FastaBatchedDataset(object):
    """(label, sequence) records held in two parallel lists."""

    def __init__(self, sequence_labels, sequence_strs):
        self.sequence_labels = list(sequence_labels)
        self.sequence_strs = list(sequence_strs)

    @classmethod
    def from_file(cls, fasta_file):
        """Build a dataset from a FASTA file.

        Sequence lines following a ``>`` header are stripped and concatenated.
        A header with no text after ``>`` is labelled
        ``seqnum<zero-padded line index>``.

        Raises:
            AssertionError: if two records share the same label.
        """
        sequence_labels, sequence_strs = [], []
        cur_seq_label = None
        buf = []

        def _flush_current_seq():
            nonlocal cur_seq_label, buf
            # NOTE(review): buf is not cleared when no label is pending, so any
            # sequence lines that appear before the first header end up
            # prepended to the first record's sequence.
            if cur_seq_label is None:
                return
            sequence_labels.append(cur_seq_label)
            sequence_strs.append("".join(buf))
            cur_seq_label = None
            buf = []

        with open(fasta_file, "r") as infile:
            for line_idx, line in enumerate(infile):
                if line.startswith(">"):  # label line
                    _flush_current_seq()
                    line = line[1:].strip()
                    if len(line) > 0:
                        cur_seq_label = line
                    else:
                        cur_seq_label = f"seqnum{line_idx:09d}"
                else:  # sequence line
                    buf.append(line.strip())

        _flush_current_seq()

        assert len(set(sequence_labels)) == len(
            sequence_labels
        ), "Found duplicate sequence labels"

        return cls(sequence_labels, sequence_strs)

    def __len__(self):
        return len(self.sequence_labels)

    def __getitem__(self, idx):
        return self.sequence_labels[idx], self.sequence_strs[idx]

    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
        """Group record indices into batches under a padded-token budget.

        Records are visited shortest first. A batch is closed when adding the
        next record would make ``longest_len * batch_size`` exceed
        ``toks_per_batch``, where every length includes ``extra_toks_per_seq``.
        A record that alone exceeds the budget still forms its own batch.

        Returns:
            list of lists of indices into this dataset.
        """
        sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
        sizes.sort()
        batches = []
        buf = []
        max_len = 0

        def _flush_current_buf():
            nonlocal max_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            max_len = 0

        for sz, i in sizes:
            sz += extra_toks_per_seq
            if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
                _flush_current_buf()
            max_len = max(max_len, sz)
            buf.append(i)

        _flush_current_buf()
        return batches


class Alphabet(object):
    """Token vocabulary mapping token strings to integer ids.

    The id layout is: ``prepend_toks``, then ``standard_toks``, then
    ``<null_N>`` filler tokens until the count is a multiple of 8, then
    ``append_toks``. The special tokens ``<unk>``, ``<pad>``, ``<cls>``,
    ``<mask>`` and ``<eos>`` must each appear in that layout so that every
    special index is distinct. ``<unk>`` is looked up strictly (``KeyError``
    if absent); the others fall back to the ``<unk>`` index.
    """

    def __init__(
        self,
        standard_toks: Sequence[str],
        prepend_toks: Sequence[str] = ("<null_0>", "<pad>", "<eos>", "<unk>"),
        append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"),
        prepend_bos: bool = True,
        append_eos: bool = False,
        use_msa: bool = False,
    ):
        self.standard_toks = list(standard_toks)
        self.prepend_toks = list(prepend_toks)
        self.append_toks = list(append_toks)
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        self.use_msa = use_msa

        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.standard_toks)
        # Pad the vocabulary with distinct filler tokens up to a multiple of 8.
        for i in range((8 - (len(self.all_toks) % 8)) % 8):
            self.all_toks.append(f"<null_{i  + 1}>")
        self.all_toks.extend(self.append_toks)

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}

        self.unk_idx = self.tok_to_idx["<unk>"]
        self.padding_idx = self.get_idx("<pad>")
        self.cls_idx = self.get_idx("<cls>")
        self.mask_idx = self.get_idx("<mask>")
        self.eos_idx = self.get_idx("<eos>")
        self.all_special_tokens = ['<eos>', '<unk>', '<pad>', '<cls>', '<mask>']
        self.unique_no_split_tokens = self.all_toks

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        """Return the id of ``tok``, or ``self.unk_idx`` if it is not in the vocabulary."""
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        """Return the token string with id ``ind``."""
        return self.all_toks[ind]

    def to_dict(self):
        """Return a copy of the token -> id mapping."""
        return self.tok_to_idx.copy()

    def get_batch_converter(self, truncation_seq_length: Optional[int] = None):
        """Return an ``MSABatchConverter`` if ``use_msa`` is set, else a ``BatchConverter``."""
        if self.use_msa:
            return MSABatchConverter(self, truncation_seq_length)
        else:
            return BatchConverter(self, truncation_seq_length)

    @classmethod
    def from_architecture(cls, name: str) -> "Alphabet":
        """Build the alphabet used by a named model architecture.

        Raises:
            ValueError: if ``name`` is not a known architecture.
        """
        if name in ("ESM-1", "protein_bert_base"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks: Tuple[str, ...] = ("<null_0>", "<pad>", "<eos>", "<unk>")
            append_toks: Tuple[str, ...] = ("<cls>", "<mask>", "<sep>")
            prepend_bos = True
            append_eos = False
            use_msa = False
        elif name in ("ESM-1b", "roberta_large"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>",)
            prepend_bos = True
            append_eos = True
            use_msa = False
        elif name in ("MSA Transformer", "msa_transformer"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>",)
            prepend_bos = True
            append_eos = False
            use_msa = True
        elif "invariant_gvp" in name.lower():
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<null_0>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>", "<cath>", "<af2>")
            prepend_bos = True
            append_eos = False
            use_msa = False
        else:
            raise ValueError("Unknown architecture selected")
        return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa)

    def _tokenize(self, text) -> List[str]:
        """Split ``text`` on whitespace."""
        return text.split()

    def tokenize(self, text, **kwargs) -> List[str]:
        """
        Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
        Converts a string into a sequence of tokens, using the tokenizer.

        Every vocabulary token found in ``text`` is kept intact as its own
        piece; any remaining text is split on whitespace.

        Args:
            text (:obj:`str`):
                The sequence to be encoded.

        Returns:
            :obj:`List[str]`: The list of tokens.
        """

        def split_on_token(tok, text):
            result = []
            split_text = text.split(tok)
            for i, sub_text in enumerate(split_text):
                # AddedToken can control whitespace stripping around them.
                # We use them for GPT2 and Roberta to have different behavior depending on the special token
                # Cf. https://github.com/huggingface/transformers/pull/2778
                # and https://github.com/huggingface/transformers/issues/3788
                # We strip left and right by default
                if i < len(split_text) - 1:
                    sub_text = sub_text.rstrip()
                if i > 0:
                    sub_text = sub_text.lstrip()

                if i == 0 and not sub_text:
                    result.append(tok)
                elif i == len(split_text) - 1:
                    if sub_text:
                        result.append(sub_text)
                    else:
                        pass
                else:
                    if sub_text:
                        result.append(sub_text)
                    result.append(tok)
            return result

        def split_on_tokens(tok_list, text):
            if not text.strip():
                return []

            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    # Pieces that are already whole tokens are not split again.
                    if sub_text not in self.unique_no_split_tokens:
                        tokenized_text.extend(split_on_token(tok, sub_text))
                    else:
                        tokenized_text.append(sub_text)
                text_list = tokenized_text

            return list(
                itertools.chain.from_iterable(
                    (
                        self._tokenize(token)
                        if token not in self.unique_no_split_tokens
                        else [token]
                        for token in tokenized_text
                    )
                )
            )

        no_split_token = self.unique_no_split_tokens
        tokenized_text = split_on_tokens(no_split_token, text)
        return tokenized_text

    def encode(self, text):
        """Tokenize ``text`` and map each piece to its id.

        Raises:
            KeyError: if a resulting piece is not in the vocabulary.
        """
        return [self.tok_to_idx[tok] for tok in self.tokenize(text)]


class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.

    Returns ``(labels, strs, tokens)``. ``tokens`` is an int64 tensor of shape
    ``(batch, max_len + prepend_bos + append_eos)`` filled with
    ``padding_idx``. When ``truncation_seq_length`` is set, each encoded
    sequence is cut to that many tokens before BOS/EOS are added.
    """

    def __init__(self, alphabet, truncation_seq_length: Optional[int] = None):
        self.alphabet = alphabet
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
        # RoBERTa uses an eos token, while ESM-1 does not.
        batch_size = len(raw_batch)
        batch_labels, seq_str_list = zip(*raw_batch)
        seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]
        if self.truncation_seq_length:
            seq_encoded_list = [seq_str[:self.truncation_seq_length] for seq_str in seq_encoded_list]
        max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
        tokens = torch.empty(
            (
                batch_size,
                max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.alphabet.padding_idx)
        labels = []
        strs = []

        for i, (label, seq_str, seq_encoded) in enumerate(
            zip(batch_labels, seq_str_list, seq_encoded_list)
        ):
            labels.append(label)
            strs.append(seq_str)
            if self.alphabet.prepend_bos:
                tokens[i, 0] = self.alphabet.cls_idx
            seq = torch.tensor(seq_encoded, dtype=torch.int64)
            tokens[
                i,
                int(self.alphabet.prepend_bos) : len(seq_encoded)
                + int(self.alphabet.prepend_bos),
            ] = seq
            if self.alphabet.append_eos:
                tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx

        return labels, strs, tokens


class MSABatchConverter(BatchConverter):
    """Batch converter for one MSA or a batch of MSAs.

    Returns ``(labels, strs, tokens)``. ``tokens`` has shape
    ``(batch, max_alignments, max_seqlen + prepend_bos + append_eos)``, where
    ``max_seqlen`` is the raw string length of the longest MSA's first
    sequence. Unused positions hold ``padding_idx``.
    """

    def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
        if isinstance(inputs[0][0], str):
            # Input is a single MSA
            raw_batch: Sequence[RawMSA] = [inputs]  # type: ignore
        else:
            raw_batch = inputs  # type: ignore

        batch_size = len(raw_batch)
        max_alignments = max(len(msa) for msa in raw_batch)
        max_seqlen = max(len(msa[0][1]) for msa in raw_batch)

        tokens = torch.empty(
            (
                batch_size,
                max_alignments,
                max_seqlen + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.alphabet.padding_idx)
        labels = []
        strs = []

        for i, msa in enumerate(raw_batch):
            msa_seqlens = set(len(seq) for _, seq in msa)
            if not len(msa_seqlens) == 1:
                raise RuntimeError(
                    "Received unaligned sequences for input to MSA, all sequence "
                    "lengths must be equal."
                )
            msa_labels, msa_strs, msa_tokens = super().__call__(msa)
            labels.append(msa_labels)
            strs.append(msa_strs)
            tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens

        return labels, strs, tokens


def read_fasta(
    path,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    """Yield ``(description, sequence)`` pairs from the FASTA/A3M file at ``path``.

    See ``read_alignment_lines`` for the meaning of the keyword options.
    """
    with open(path, "r") as f:
        for result in read_alignment_lines(
            f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper
        ):
            yield result


def read_alignment_lines(
    lines,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    """Yield ``(description, sequence)`` pairs from an iterable of FASTA lines.

    Args:
        keep_gaps: if False, remove ``-`` characters.
        keep_insertions: if False, remove lowercase letters.
        to_upper: if True, upper-case the resulting sequence.

    Raises:
        AssertionError: if a sequence line precedes the first header, or if
            ``lines`` contains no header at all.
    """
    seq = desc = None

    def parse(s):
        if not keep_gaps:
            s = re.sub("-", "", s)
        if not keep_insertions:
            s = re.sub("[a-z]", "", s)
        return s.upper() if to_upper else s

    for line in lines:
        # Line may be empty if seq % file_line_width == 0
        if len(line) > 0 and line[0] == ">":
            if seq is not None:
                yield desc, parse(seq)
            desc = line.strip().lstrip(">")
            seq = ""
        else:
            assert isinstance(seq, str)
            seq += line.strip()
    assert isinstance(seq, str) and isinstance(desc, str)
    yield desc, parse(seq)
# --------------------------------------------------------------------------------
# /src/models/msa_transformer/modules.py:
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .multihead_attention import MultiheadAttention  # noqa
from .axial_attention import ColumnSelfAttention, RowSelfAttention

def gelu(x):
    """Implementation of the gelu activation function.

    For information: OpenAI GPT's gelu is slightly different
    (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    return x + x.transpose(-1, -2)


def apc(x):
    "Perform average product correction (APC) over the final two dimensions, used for contact prediction."
    # Row sums, column sums and total sum over the last two axes.
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)

    # Subtract the expected background row_sum * col_sum / total.
    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    normalized = x - avg
    return normalized


class ESM1LayerNorm(nn.Module):
    """LayerNorm over the trailing ``len(hidden_size)`` dimensions, with eps inside the sqrt."""

    def __init__(self, hidden_size, eps=1e-12, affine=True):
        """Construct a layernorm layer in the TF style (eps inside the sqrt)."""
        super().__init__()
        self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
        self.eps = eps
        self.affine = bool(affine)
        if self.affine:
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
        else:
            self.weight, self.bias = None, None

    def forward(self, x):
        # Negative axes covering the last len(hidden_size) dimensions.
        dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
        means = x.mean(dims, keepdim=True)
        x_zeromean = x - means
        variances = x_zeromean.pow(2).mean(dims, keepdim=True)
        x = x_zeromean / torch.sqrt(variances + self.eps)
        if self.affine:
            x = (self.weight * x) + self.bias
        return x


# ESM1bLayerNorm is apex's FusedLayerNorm when apex is installed, otherwise
# torch.nn.LayerNorm.
try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    class ESM1bLayerNorm(_FusedLayerNorm):
        @torch.jit.unused
        def forward(self, x):
            if not x.is_cuda:
                return super().forward(x)
            else:
                # Run the fused kernel on the device that holds x.
                with torch.cuda.device(x.device):
                    return super().forward(x)

except ImportError:
    from torch.nn import LayerNorm as ESM1bLayerNorm


class TransformerLayer(nn.Module):
    """Transformer layer block.

    Pre-LayerNorm residual block: self-attention followed by a two-layer
    GELU feed-forward network. ``forward`` returns ``(x, attn)``.
    """

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        add_bias_kv=True,
        use_esm1b_layer_norm=False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings
        self._init_submodules(add_bias_kv, use_esm1b_layer_norm)

    def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm):
        BertLayerNorm = ESM1bLayerNorm if use_esm1b_layer_norm else ESM1LayerNorm

        self.self_attn = MultiheadAttention(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )
        self.self_attn_layer_norm = BertLayerNorm(self.embed_dim)

        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = BertLayerNorm(self.embed_dim)

    def forward(
        self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False
    ):
        # Attention sub-block (norm -> attention -> residual add).
        residual = x
        x = self.self_attn_layer_norm(x)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = residual + x

        # Feed-forward sub-block (norm -> fc1 -> gelu -> fc2 -> residual add).
        residual = x
        x = self.final_layer_norm(x)
        x = gelu(self.fc1(x))
        x = self.fc2(x)
        x = residual + x

        return x, attn


class AxialTransformerLayer(nn.Module):
    """Implements an Axial MSA Transformer block.

    NOTE(review): ``attention_dropout`` is accepted but not used; the attention
    modules receive ``dropout`` instead.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ) -> None:
        super().__init__()

        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout_prob = dropout

        row_self_attention = RowSelfAttention(
            embedding_dim,
            num_attention_heads,
            dropout=dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        column_self_attention = ColumnSelfAttention(
            embedding_dim,
            num_attention_heads,
            dropout=dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        feed_forward_layer = FeedForwardNetwork(
            embedding_dim,
            ffn_embedding_dim,
            activation_dropout=activation_dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        self.row_self_attention = self.build_residual(row_self_attention)
        self.column_self_attention = self.build_residual(column_self_attention)
        self.feed_forward_layer = self.build_residual(feed_forward_layer)

    def build_residual(self, layer: nn.Module):
        """Wrap ``layer`` in a pre-LayerNorm + dropout + residual block."""
        return NormalizedResidualBlock(
            layer,
            self.embedding_dim,
            self.dropout_prob,
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_row_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ):
        """
        Apply row self-attention, then column self-attention, then the
        feed-forward network. Each is wrapped in a NormalizedResidualBlock,
        so LayerNorm is applied before each sub-layer.

        ``self_row_attn_mask`` is passed to the row attention and
        ``self_attn_mask`` to the column attention. Returns ``x``, or
        ``(x, column_attn, row_attn)`` when ``need_head_weights`` is True.
        """
        x, row_attn = self.row_self_attention(
            x,
            self_attn_mask=self_row_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x, column_attn = self.column_self_attention(
            x,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x = self.feed_forward_layer(x)
        if need_head_weights:
            return x, column_attn, row_attn
        else:
            return x

class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        if padding_idx is not None:
            num_embeddings_ = num_embeddings + padding_idx + 1
        else:
            num_embeddings_ = num_embeddings
        super().__init__(num_embeddings_, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Input is expected to be of size [bsz x seqlen].

        Raises:
            ValueError: if seqlen exceeds ``max_positions``.
        """
        if input.size(1) > self.max_positions:
            raise ValueError(
                f"Sequence length {input.size(1)} above maximum "
                f" sequence length of {self.max_positions}"
            )
        # Non-padding tokens get positions padding_idx+1, padding_idx+2, ...
        # counted along dim 1. Padding tokens get padding_idx.
        mask = input.ne(self.padding_idx).int()
        positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )


class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed sin/cos positional embeddings for a [bsz x seqlen] token tensor.

    The table is cached in ``self.weights`` and rebuilt when a longer input
    arrives. The row at ``padding_idx`` is zeroed. NOTE(review): ``learned``
    is accepted but not used.
    """

    def __init__(self, embed_dim, padding_idx, learned=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.padding_idx = padding_idx
        # Dtype/device reference for the cached table (follows module .to()).
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.weights = None

    def forward(self, x):
        bsz, seq_len = x.shape
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            self.weights = self.get_embedding(max_pos)
        self.weights = self.weights.type_as(self._float_tensor)

        positions = self.make_positions(x)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def make_positions(self, x):
        # Position j maps to padding_idx + 1 + j; padding tokens map to padding_idx.
        mask = x.ne(self.padding_idx)
        range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
        positions = range_buf.expand_as(x)
        return positions * mask.long() + self.padding_idx * (1 - mask.long())

    def get_embedding(self, num_embeddings):
        # Returns a (num_embeddings, embed_dim) table: sin half then cos half.
        half_dim = self.embed_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if self.embed_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if self.padding_idx is not None:
            emb[self.padding_idx, :] = 0
        return emb


class RobertaLMHead(nn.Module):
    """Head for masked language modeling.

    dense -> gelu -> layer norm -> linear projection with ``weight`` plus a
    learned bias. ``weight`` is supplied by the caller (presumably tied to the
    token embedding matrix; confirm at the call site).
    """

    def __init__(self, embed_dim, output_dim, weight):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = ESM1bLayerNorm(embed_dim)
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        x = self.dense(features)
        x = gelu(x)
        x = self.layer_norm(x)
        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight) + self.bias
        return x


class ContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features

    ``in_features`` must equal layers * heads of the incoming attention maps.

    Raises:
        ValueError: if ``append_eos`` is set but ``eos_idx`` is None.
    """

    def __init__(
        self,
        in_features: int,
        prepend_bos: bool,
        append_eos: bool,
        bias=True,
        eos_idx: Optional[int] = None,
    ):
        super().__init__()
        self.in_features = in_features
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        if append_eos and eos_idx is None:
            raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        # attentions: (batch, layers, heads, seqlen, seqlen), per the unpacking
        # below. Returns per-pair probabilities of shape (batch, L, L), where L
        # is seqlen with the BOS/EOS positions removed.
        # remove eos token attentions
        if self.append_eos:
            eos_mask = tokens.ne(self.eos_idx).to(attentions)
            eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
            attentions = attentions * eos_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        if self.prepend_bos:
            attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: B x C x T x T
        attentions = attentions.to(
            self.regression.weight.device
        )  # attentions always float32, may need to convert to float16
        attentions = apc(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))


class NormalizedResidualBlock(nn.Module):
    """Pre-LayerNorm residual wrapper: ``x + dropout(layer(layer_norm(x)))``.

    If the wrapped layer returns a tuple, its first element is the update
    and the remaining elements are returned after the new ``x``.
    """

    def __init__(
        self,
        layer: nn.Module,
        embedding_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim

        self.layer = layer
        self.dropout_module = nn.Dropout(
            dropout,
        )
        self.layer_norm = ESM1bLayerNorm(self.embedding_dim)

    def forward(self, x, *args, **kwargs):
        residual = x
        x = self.layer_norm(x)
        outputs = self.layer(x, *args, **kwargs)
        if isinstance(outputs, tuple):
            x, *out = outputs
        else:
            x = outputs
            out = None

        x = self.dropout_module(x)
        x = residual + x

        if out is not None:
            return (x,) + tuple(out)
        else:
            return x


class FeedForwardNetwork(nn.Module):
    """fc1 -> GELU -> dropout -> fc2.

    NOTE(review): ``max_tokens_per_msa`` is stored but not used in ``forward``.
    """

    def __init__(
        self,
        embedding_dim: int,
        ffn_embedding_dim: int,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        self.max_tokens_per_msa = max_tokens_per_msa
        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(
            activation_dropout,
        )
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)

    def forward(self, x):
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        return x
# --------------------------------------------------------------------------------
# /src/models/gvp/__init__.py
# --------------------------------------------------------------------------------
import torch, functools
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_add
import torch.utils.checkpoint as checkpoint

def tuple_sum(*args):
    '''
    Sums any number of tuples (s, V) elementwise.
    '''
    return tuple(map(sum, zip(*args)))

def tuple_cat(*args, dim=-1):
    '''
    Concatenates any number of tuples (s, V) elementwise.

    :param dim: dimension along which to concatenate when viewed
                as the `dim` index for the scalar-channel tensors.
                This means that `dim=-1` will be applied as
                `dim=-2` for the vector-channel tensors.
    '''
    # Convert to a non-negative axis of the scalar tensor, so that the same
    # index refers to the channel axis of the (one rank higher) vector tensor.
    dim %= len(args[0][0].shape)
    s_args, v_args = list(zip(*args))
    return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim)

def tuple_index(x, idx):
    '''
    Indexes into a tuple (s, V) along the first dimension.

    :param idx: any object which can be used to index into a `torch.Tensor`
    '''
    return x[0][idx], x[1][idx]

def randn(n, dims, device="cpu"):
    '''
    Returns random tuples (s, V) drawn elementwise from a normal distribution.

    :param n: number of data points
    :param dims: tuple of dimensions (n_scalar, n_vector)

    :return: (s, V) with s.shape = (n, n_scalar) and
             V.shape = (n, n_vector, 3)
    '''
    return torch.randn(n, dims[0], device=device), \
            torch.randn(n, dims[1], 3, device=device)

def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-4, sqrt=True):
    '''
    L2 norm of tensor along `axis`, kept away from zero so the gradient
    of the square root stays finite.

    The squared norm is clamped to at least `eps` before the square root,
    so with `sqrt=True` the result is at least sqrt(eps).

    :param sqrt: if `False`, returns the (clamped) square of the L2 norm
    '''
    out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
    return torch.sqrt(out) if sqrt else out

def _split(x, nv):
    '''
    Splits a merged representation of (s, V) back into a tuple.
    Should be used only with `_merge(s, V)` and only if the tuple
    representation cannot be used.

    :param x: the `torch.Tensor` returned from `_merge`
    :param nv: the number of vector channels in the input to `_merge`
    '''
    v = torch.reshape(x[..., -3*nv:], x.shape[:-1] + (nv, 3))
    s = x[..., :-3*nv]
    return s, v

def _merge(s, v):
    '''
    Merges a tuple (s, V) into a single `torch.Tensor`, where the
    vector channels are flattened and appended to the scalar channels.
    Should be used only if the tuple representation cannot be used.
    Use `_split(x, nv)` to reverse.
    '''
    v = torch.reshape(v, v.shape[:-2] + (3*v.shape[-2],))
    return torch.cat([s, v], -1)

class GVP(nn.Module):
    '''
    Geometric Vector Perceptron. See manuscript and README.md
    for more details.

    :param in_dims: tuple (n_scalar, n_vector)
    :param out_dims: tuple (n_scalar, n_vector)
    :param h_dim: intermediate number of vector channels, optional
    :param activations: tuple of functions (scalar_act, vector_act)
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating if `True`)
    '''
    def __init__(self, in_dims, out_dims, h_dim=None,
                 activations=(F.relu, torch.sigmoid), vector_gate=False):
        super(GVP, self).__init__()

        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.vector_gate = vector_gate
        if self.vi:
            self.h_dim = h_dim or max(self.vi, self.vo)
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            # Scalar output sees the scalar input plus the norms of the
            # h_dim hidden vector channels.
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
                if self.vector_gate: self.wsv = nn.Linear(self.so, self.vo)
        else:
            self.ws = nn.Linear(self.si, self.so)

        self.scalar_act, self.vector_act = activations
        # Empty parameter used only to find the module's current device.
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or (if vectors_in is 0), a single `torch.Tensor`
        :return: tuple (s, V) of `torch.Tensor`,
                 or (if vectors_out is 0), a single `torch.Tensor`
        '''
        if self.vi:
            s, v = x
            # Move the xyz axis before the channel axis so nn.Linear mixes channels.
            v = torch.transpose(v, -1, -2)
            vh = self.wh(v)
            vn = _norm_no_nan(vh, axis=-2)
            s = self.ws(torch.cat([s, vn], -1))
            if self.vo:
                v = self.wv(vh)
                v = torch.transpose(v, -1, -2)
                if self.vector_gate:
                    if self.vector_act:
                        gate = self.wsv(self.vector_act(s))
                    else:
                        gate = self.wsv(s)
                    v = v * torch.sigmoid(gate).unsqueeze(-1)
                elif self.vector_act:
                    v = v * self.vector_act(
                        _norm_no_nan(v, axis=-1, keepdims=True))
        else:
            s = self.ws(x)
            if self.vo:
                # No vector input: vector output channels are all zeros.
                v = torch.zeros(s.shape[0], self.vo, 3,
                                device=self.dummy_param.device)
        if self.scalar_act:
            s = self.scalar_act(s)

        return (s, v) if self.vo else s

class _VDropout(nn.Module):
    '''
    Vector channel dropout where the elements of each
    vector channel are dropped together.
    '''
    def __init__(self, drop_rate):
        super(_VDropout, self).__init__()
        self.drop_rate = drop_rate
        # Empty parameter used only to find the module's current device.
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        '''
        :param x: `torch.Tensor` corresponding to vector channels
        '''
        device = self.dummy_param.device
        if not self.training:
            return x
        # One Bernoulli draw per vector (all but the last axis), broadcast over
        # the last axis, then rescaled by 1 / (1 - drop_rate).
        mask = torch.bernoulli(
            (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
        ).unsqueeze(-1)
        x = mask * x / (1 - self.drop_rate)
        return x


class Dropout(nn.Module):
    '''
    Combined dropout for tuples (s, V).
    Takes tuples (s, V) as input and as output.
    '''
    def __init__(self, drop_rate):
        super(Dropout, self).__init__()
        self.sdropout = nn.Dropout(drop_rate)
        self.vdropout = _VDropout(drop_rate)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        '''
        if type(x) is torch.Tensor:
            return self.sdropout(x)
        s, v = x
        return self.sdropout(s), self.vdropout(v)


class LayerNorm(nn.Module):
    '''
    Combined LayerNorm for tuples (s, V).
    Takes tuples (s, V) as input and as output.

    Scalars use nn.LayerNorm. Vectors are divided by the root-mean-square
    of their (clamped) L2 norms taken over the vector-channel axis.
    '''
    def __init__(self, dims):
        super(LayerNorm, self).__init__()
        self.s, self.v = dims
        self.scalar_norm = nn.LayerNorm(self.s)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        '''
        if not self.v:
            return self.scalar_norm(x)
        s, v = x
        vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False)
        vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True))
        return self.scalar_norm(s), v / vn

class GVPConv(MessagePassing):
    '''
    Graph convolution / message passing with Geometric Vector Perceptrons.
    Takes in a graph with node and edge embeddings,
    and returns new node embeddings.

    This does NOT do residual updates and pointwise feedforward layers
    ---see `GVPConvLayer`.

    :param in_dims: input node embedding dimensions (n_scalar, n_vector)
    :param out_dims: output node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_layers: number of GVPs in the message function
    :param module_list: preconstructed message function, overrides n_layers
    :param aggr: should be "add" if some incoming edges are masked, as in
                 a masked autoregressive decoder architecture, otherwise "mean"
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating if `True`)
    '''
    def __init__(self, in_dims, out_dims, edge_dims,
                 n_layers=3, module_list=None, aggr="mean",
                 activations=(F.leaky_relu, torch.sigmoid), vector_gate=False):
        super(GVPConv, self).__init__(aggr=aggr)
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.se, self.ve = edge_dims

        GVP_ = functools.partial(GVP,
                activations=activations, vector_gate=vector_gate)

        module_list = module_list or []
        if not module_list:
            # Message input is (source node, edge, target node) concatenated,
            # hence 2*node dims + edge dims. The last GVP has no activations.
            if n_layers == 1:
                module_list.append(
                    GVP_((2*self.si + self.se, 2*self.vi + self.ve),
                        (self.so, self.vo), activations=(None, None)))
            else:
                module_list.append(
                    GVP_((2*self.si + self.se, 2*self.vi + self.ve), out_dims)
                )
                for i in range(n_layers - 2):
                    module_list.append(GVP_(out_dims, out_dims))
                module_list.append(GVP_(out_dims, out_dims,
                                       activations=(None, None)))
        self.message_func = nn.Sequential(*module_list)

    def forward(self, x, edge_index, edge_attr):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        '''
        x_s, x_v = x
        # Vectors are flattened to (n_nodes, 3 * n_vector) for propagate();
        # message() restores the (n_vector, 3) shape.
        message = self.propagate(edge_index,
                    s=x_s, v=x_v.reshape(x_v.shape[0], 3*x_v.shape[1]),
                    edge_attr=edge_attr)
        return _split(message, self.vo)

    def message(self, s_i, v_i, s_j, v_j, edge_attr):
        v_j = v_j.view(v_j.shape[0], v_j.shape[1]//3, 3)
        v_i = v_i.view(v_i.shape[0], v_i.shape[1]//3, 3)
        message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
        message = self.message_func(message)
        # Merged back into one tensor so MessagePassing can aggregate it.
        return _merge(*message)


class GVPConvLayer(nn.Module):
    '''
    Full graph convolution / message passing layer with
    Geometric Vector Perceptrons. Residually updates node embeddings with
    aggregated incoming messages, applies a pointwise feedforward
    network to node embeddings, and returns updated node embeddings.

    To only compute the aggregated messages, see `GVPConv`.

    :param node_dims: node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_message: number of GVPs to use in message function
    :param n_feedforward: number of GVPs to use in feedforward function
    :param drop_rate: drop probability in all dropout layers
    :param autoregressive: if `True`, this `GVPConvLayer` will be used
           with a different set of input node embeddings for messages
           where src >= dst
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating if `True`)
    :param aggr: message aggregation used when `autoregressive` is False
                 (autoregressive layers always use "add")
    '''
    def __init__(self, node_dims, edge_dims,
                 n_message=3, n_feedforward=2, drop_rate=.1,
                 autoregressive=False,
                 activations=(F.leaky_relu, torch.sigmoid), vector_gate=True, aggr="mean"):

        super(GVPConvLayer, self).__init__()
        self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message,
                           aggr="add" if autoregressive else aggr,
                           activations=activations, vector_gate=vector_gate)
        GVP_ = functools.partial(GVP,
                activations=activations, vector_gate=vector_gate)
        self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)])
        self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])

        ff_func = []
        if n_feedforward == 1:
            ff_func.append(GVP_(node_dims, node_dims, activations=(None, None)))
        else:
            # Hidden width: 4x scalar channels, 2x vector channels.
            hid_dims = 4*node_dims[0], 2*node_dims[1]
            ff_func.append(GVP_(node_dims, hid_dims))
            for i in range(n_feedforward-2):
                ff_func.append(GVP_(hid_dims, hid_dims))
            ff_func.append(GVP_(hid_dims, node_dims, activations=(None, None)))
        self.ff_func = nn.Sequential(*ff_func)

    def forward(self, x, edge_index, edge_attr,
                autoregressive_x=None, node_mask=None):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
                If not `None`, will be used as src node embeddings
                for forming messages where src >= dst. The current node
                embeddings `x` will still be the base of the update and the
                pointwise feedforward.
        :param node_mask: array of type `bool` to index into the first
                dim of node embeddings (s, V). If not `None`, only
                these nodes will be updated. The updated rows are written
                back in place into the tensors of `x`, which are also returned.
        '''

        if autoregressive_x is not None:
            src, dst = edge_index
            mask = src < dst
            edge_index_forward = edge_index[:, mask]
            edge_index_backward = edge_index[:, ~mask]
            edge_attr_forward = tuple_index(edge_attr, mask)
            edge_attr_backward = tuple_index(edge_attr, ~mask)

            dh = tuple_sum(
                self.conv(x, edge_index_forward, edge_attr_forward),
                self.conv(autoregressive_x, edge_index_backward, edge_attr_backward)
            )

            # The conv sums messages (aggr="add"); divide by each node's
            # incoming-edge count (at least 1) to get a mean.
            count = scatter_add(torch.ones_like(dst), dst,
                        dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)

            dh = dh[0] / count, dh[1] / count.unsqueeze(-1)

        else:
            dh = self.conv(x, edge_index, edge_attr)

        if node_mask is not None:
            x_ = x
            x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)

        # Residual message update, then residual feed-forward update.
        x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))

        dh = self.ff_func(x)
        x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))

        if node_mask is not None:
            x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
            x = x_
        return x
-------------------------------------------------------------------------------- /data/test/proteingym/all_models_substitutions_Spearman_Uniprot_level.csv: -------------------------------------------------------------------------------- 1 | UniProt_ID,Tranception_L_no_retrieval,Tranception_S_retrieval,Tranception_M_retrieval,Tranception_L_retrieval,EVE_single,EVE_ensemble,MSA_Transformer_single,MSA_Transformer_ensemble,ESM1v_single,ESM1v_ensemble,Wavenet,DeepSequence_single,DeepSequence_ensemble,Site_Independent,EVmutation,RITA_s,RITA_m,RITA_l,RITA_xl,RITA_ensemble,Progen2_small,Progen2_medium,Progen2_base,Progen2_large,Progen2_xlarge,Progen2_ensemble,Unirep,Unirep_evotune,GEMME,VESPA,VESPAl,ESM1b,ProtGPT2,TranceptEVE_L,Neff_L_category,Taxon 2 | A0A140D2T1_ZIKV,0.272,0.362,0.366,0.351,0.394,0.405,0.475,0.454,-0.048,0.015,0.216,0.131,0.103,0.383,0.354,0.361,0.309,0.317,0.305,0.346,0.329,0.342,0.328,0.312,0.293,0.346,-0.133,0.062,0.444,0.319,0.296,-0.001,0.005,0.379,medium,Virus 3 | A0A192B1T2_9HIV1,0.514,0.509,0.503,0.513,0.509,0.516,0.514,0.514,0.492,0.516,0.465,0.413,0.432,0.481,0.407,0.496,0.507,0.509,0.505,0.519,0.497,0.501,0.463,0.49,0.484,0.508,0.0,0.513,0.504,0.541,0.507,0.456,0.327,0.528,medium,Virus 4 | A0A1I9GEU1_NEIME,0.099,0.031,0.041,0.057,0.053,0.054,0.094,0.102,0.068,0.068,0.067,0.107,0.098,-0.011,0.044,-0.01,0.047,0.071,0.088,0.066,0.05,0.088,0.08,0.089,0.095,0.087,-0.024,0.084,0.05,0.046,0.036,0.04,0.03,0.075,medium,Prokaryote 5 | A0A2Z5U3Z0_9INFA,0.521,0.512,0.541,0.539,0.526,0.534,0.33,0.522,0.489,0.53,0.436,0.469,0.498,0.479,0.49,0.458,0.518,0.502,0.526,0.528,0.389,0.49,0.498,0.471,0.49,0.518,0.015,0.464,0.528,0.466,0.388,0.14,0.08,0.566,medium,Virus 6 | A4D664_9INFA,0.404,0.366,0.38,0.393,0.409,0.408,0.333,0.336,0.026,0.033,0.26,0.406,0.403,0.411,0.34,0.329,0.386,0.404,0.398,0.398,0.064,0.26,0.285,0.244,0.33,0.292,0.029,0.364,0.435,0.318,0.296,0.044,0.055,0.461,medium,Virus 7 | 
A4GRB6_PSEAI,0.629,0.557,0.588,0.652,0.623,0.641,0.695,0.67,0.647,0.668,0.567,0.657,0.664,0.317,0.492,0.409,0.533,0.564,0.624,0.576,0.526,0.622,0.639,0.638,0.705,0.679,0.344,0.529,0.707,0.742,0.673,0.68,0.242,0.679,high,Prokaryote 8 | A4_HUMAN,0.363,0.436,0.366,0.452,0.305,0.309,0.393,0.392,0.309,0.398,0.245,0.446,0.421,0.388,0.379,0.322,0.276,0.311,0.299,0.308,0.4,0.292,0.314,0.307,0.307,0.33,0.346,0.157,0.468,0.242,0.173,0.294,0.527,0.41,low,Human 9 | AACC1_PSEAI,0.417,0.427,0.418,0.466,0.487,0.491,0.496,0.484,0.483,0.489,0.391,0.331,0.414,0.282,0.474,0.206,0.232,0.276,0.306,0.26,0.273,0.428,0.407,0.413,0.438,0.424,0.196,0.191,0.473,0.507,0.455,0.413,0.021,0.495,high,Prokaryote 10 | ADRB2_HUMAN,0.506,0.528,0.541,0.533,0.497,0.517,0.435,0.5,0.526,0.539,0.521,0.506,0.514,0.331,0.423,0.509,0.524,0.517,0.493,0.528,0.524,0.531,0.538,0.533,0.502,0.54,0.463,0.5,0.546,0.5,0.414,0.534,0.298,0.537,medium,Human 11 | AMIE_PSEAE,0.405,0.585,0.479,0.438,0.477,0.464,0.592,0.569,0.613,0.662,0.506,0.496,0.506,0.246,0.443,0.555,0.525,0.532,0.542,0.57,0.566,0.561,0.534,0.559,0.508,0.584,0.093,0.436,0.632,0.583,0.528,0.57,0.265,0.481,high,Prokaryote 12 | B3VI55_LIPST,0.555,0.37,0.404,0.527,0.481,0.498,0.579,0.56,0.537,0.564,0.499,0.485,0.494,0.28,0.426,0.276,0.396,0.462,0.495,0.443,0.365,0.492,0.525,0.479,0.552,0.531,0.104,0.465,0.544,0.537,0.46,0.444,0.139,0.54,medium,Eukaryote 13 | BLAT_ECOLX,0.459,0.594,0.595,0.596,0.658,0.672,0.645,0.646,0.622,0.654,0.602,0.668,0.683,0.429,0.652,0.518,0.512,0.502,0.492,0.564,0.539,0.594,0.598,0.534,0.429,0.612,0.109,0.418,0.609,0.691,0.673,0.646,0.148,0.681,high,Prokaryote 14 | BRCA1_HUMAN,0.409,0.456,0.456,0.514,0.424,0.466,0.472,0.501,0.404,0.447,0.296,0.441,0.432,0.501,0.385,0.145,0.379,0.405,0.361,0.417,0.305,0.473,0.457,0.442,0.423,0.471,0.125,0.189,0.413,0.539,0.482,0.494,0.22,0.533,low,Human 15 | 
C6KNH7_9INFA,0.392,0.412,0.419,0.433,0.433,0.436,0.4,0.403,0.428,0.492,0.343,0.351,0.357,0.393,0.371,0.396,0.363,0.378,0.349,0.391,0.254,0.454,0.428,0.465,0.37,0.473,-0.025,0.451,0.48,0.457,0.36,0.06,0.11,0.452,medium,Virus 16 | CALM1_HUMAN,0.291,0.218,0.242,0.266,0.238,0.236,0.237,0.251,0.235,0.264,0.229,0.238,0.233,0.175,0.233,0.173,0.227,0.242,0.254,0.233,0.232,0.276,0.279,0.3,0.313,0.301,0.174,0.187,0.228,0.21,0.146,0.25,0.087,0.246,high,Human 17 | CAPSD_AAV2S,0.492,0.38,0.386,0.473,0.346,0.33,0.35,0.368,0.196,0.199,0.258,0.317,0.372,0.407,0.345,0.191,0.25,0.269,0.279,0.278,0.203,0.198,0.266,0.204,0.399,0.261,0.374,0.424,0.511,0.183,0.177,0.183,0.132,0.426,low,Virus 18 | CCDB_ECOLI,0.367,0.461,0.453,0.495,0.522,0.53,0.406,0.461,0.393,0.458,0.488,0.522,0.544,0.395,0.496,0.03,-0.0,-0.094,0.156,0.046,-0.123,0.074,-0.029,0.054,0.45,0.107,-0.02,0.24,0.485,0.548,0.532,0.457,0.118,0.535,high,Prokaryote 19 | CP2C9_HUMAN,0.57,0.642,0.638,0.631,0.606,0.622,0.58,0.594,0.615,0.644,0.632,0.61,0.628,0.516,0.581,0.57,0.57,0.603,0.566,0.621,0.594,0.586,0.591,0.582,0.577,0.618,0.556,0.585,0.619,0.588,0.528,0.516,0.216,0.641,high,Human 20 | DLG4_HUMAN,0.58,0.653,0.7,0.667,0.609,0.616,0.602,0.614,0.55,0.608,0.65,0.607,0.577,0.679,0.584,0.574,0.583,0.569,0.538,0.581,0.605,0.607,0.564,0.562,0.495,0.594,0.723,0.629,0.616,0.634,0.621,0.52,0.511,0.636,low,Human 21 | DLG4_RAT,0.307,0.47,0.478,0.443,0.526,0.539,0.488,0.496,0.557,0.588,0.444,0.487,0.491,0.486,0.442,0.378,0.389,0.373,0.371,0.4,0.407,0.42,0.407,0.367,0.395,0.425,0.49,0.437,0.509,0.556,0.537,0.469,0.246,0.54,low,Eukaryote 22 | DYR_ECOLI,0.347,0.407,0.415,0.423,0.476,0.474,0.487,0.49,0.416,0.436,0.485,0.469,0.472,0.384,0.484,0.199,0.36,0.266,0.31,0.316,0.342,0.436,0.476,0.441,0.418,0.462,-0.018,0.321,0.442,0.451,0.408,0.469,0.1,0.481,medium,Prokaryote 23 | 
ENV_HV1B9,0.404,0.392,0.396,0.407,0.388,0.377,0.38,0.375,0.415,0.389,0.388,0.238,0.367,0.369,0.397,0.38,0.358,0.408,0.419,0.4,0.263,0.394,0.401,0.391,0.374,0.375,0.056,0.353,0.37,0.355,0.322,0.359,0.253,0.392,medium,Virus 24 | ENV_HV1BR,0.358,0.359,0.366,0.362,0.339,0.345,0.343,0.345,0.32,0.336,0.337,0.322,0.323,0.338,0.303,0.35,0.358,0.371,0.364,0.374,0.345,0.358,0.358,0.354,0.362,0.378,-0.001,0.322,0.348,0.325,0.29,0.298,0.189,0.365,medium,Virus 25 | ESTA_BACSU,0.261,0.31,0.329,0.326,0.387,0.387,0.435,0.431,0.305,0.335,0.32,0.389,0.415,0.258,0.399,0.121,0.196,0.268,0.286,0.261,0.27,0.254,0.311,0.272,0.379,0.328,0.187,0.315,0.403,0.434,0.406,0.336,0.075,0.397,high,Prokaryote 26 | F7YBW8_MESOW,0.433,0.062,0.011,0.425,0.428,0.43,0.375,0.413,0.382,0.393,0.391,0.395,0.44,0.06,0.395,-0.075,-0.122,-0.104,-0.005,-0.081,0.098,0.125,-0.021,0.298,0.403,0.22,0.016,0.32,0.374,0.461,0.441,0.437,0.037,0.444,high,Prokaryote 27 | GAL4_YEAST,0.325,0.56,0.555,0.557,0.486,0.532,0.602,0.58,0.458,0.462,0.507,0.507,0.567,0.237,0.403,0.319,0.355,0.383,0.372,0.37,0.364,0.424,0.493,0.451,0.579,0.504,0.337,-0.004,0.629,0.651,0.563,0.623,0.303,0.503,medium,Eukaryote 28 | GCN4_YEAST,0.265,0.258,0.257,0.278,0.242,0.241,0.249,0.25,0.272,0.284,0.186,0.245,0.248,0.253,0.25,0.175,0.189,0.168,0.167,0.179,0.158,0.118,0.081,0.134,0.164,0.144,0.182,0.221,0.222,0.255,0.247,0.238,0.025,0.274,low,Eukaryote 29 | GFP_AEQVI,0.629,0.649,0.651,0.672,0.679,0.679,0.648,0.657,0.098,0.103,0.598,0.672,0.673,0.649,0.644,0.078,0.105,0.181,0.098,0.123,0.046,0.188,0.298,0.642,0.647,0.446,0.049,0.635,0.679,0.607,0.615,0.524,0.074,0.706,low,Eukaryote 30 | GRB2_HUMAN,0.404,0.517,0.506,0.436,0.54,0.546,0.471,0.486,0.455,0.515,0.534,0.507,0.538,0.405,0.521,0.544,0.516,0.514,0.453,0.525,0.532,0.509,0.446,0.524,0.448,0.515,0.486,0.45,0.51,0.428,0.422,0.533,0.464,0.532,medium,Human 31 | 
HIS7_YEAST,0.585,0.488,0.496,0.616,0.533,0.531,0.508,0.506,0.411,0.477,0.454,0.558,0.557,0.472,0.542,0.325,0.402,0.433,0.478,0.443,0.404,0.485,0.464,0.516,0.539,0.518,0.146,0.273,0.522,0.4,0.374,0.472,0.143,0.582,medium,Eukaryote 32 | HSP82_YEAST,0.414,0.433,0.444,0.437,0.448,0.447,0.409,0.41,0.449,0.466,0.412,0.452,0.46,0.396,0.39,0.404,0.413,0.396,0.42,0.427,0.428,0.422,0.442,0.385,0.432,0.441,0.277,0.326,0.459,0.474,0.461,0.414,0.298,0.458,medium,Eukaryote 33 | I6TAH8_I68A0,0.337,0.329,0.354,0.348,0.364,0.361,0.303,0.324,0.018,0.014,0.212,0.268,0.264,0.347,0.317,0.308,0.328,0.374,0.377,0.365,0.004,0.011,0.097,0.002,0.302,0.127,-0.005,0.297,0.353,0.25,0.253,0.013,0.171,0.401,medium,Virus 34 | IF1_ECOLI,0.527,0.463,0.47,0.495,0.527,0.537,0.234,0.255,0.54,0.565,0.49,0.541,0.539,0.328,0.499,0.361,0.438,0.366,0.411,0.412,0.451,0.458,0.482,0.443,0.459,0.484,0.177,0.387,0.408,0.52,0.463,0.534,0.238,0.539,high,Prokaryote 35 | KCNH2_HUMAN,0.511,0.499,0.554,0.538,0.211,0.21,0.356,0.33,0.216,0.234,0.382,0.298,0.292,0.449,0.421,0.468,0.513,0.495,0.463,0.503,0.502,0.491,0.455,0.477,0.481,0.5,0.454,0.151,0.438,0.433,0.318,0.307,0.344,0.509,medium,Human 36 | KKA2_KLEPN,0.586,0.445,0.524,0.588,0.599,0.603,0.569,0.594,0.597,0.621,0.498,0.437,0.622,0.25,0.53,0.285,0.42,0.518,0.54,0.503,0.353,0.578,0.577,0.568,0.638,0.607,0.199,0.423,0.647,0.637,0.586,0.566,0.148,0.63,high,Prokaryote 37 | MK01_HUMAN,0.004,0.193,0.119,0.093,0.223,0.227,0.144,0.139,0.167,0.182,0.184,0.237,0.241,0.176,0.198,0.209,0.117,0.069,0.016,0.11,0.183,0.089,0.057,0.076,-0.067,0.077,0.209,0.163,0.234,0.176,0.222,0.033,0.106,0.202,medium,Human 38 | MSH2_HUMAN,0.281,0.35,0.388,0.346,0.383,0.39,0.39,0.395,0.38,0.4,0.364,0.367,0.376,0.352,0.399,0.296,0.31,0.264,0.257,0.313,0.326,0.317,0.32,0.326,0.301,0.343,0.204,0.339,0.375,0.351,0.324,0.349,0.208,0.383,medium,Human 39 | 
MTH3_HAEAE,0.665,0.422,0.496,0.638,0.696,0.704,0.671,0.68,0.692,0.708,0.658,0.709,0.718,0.371,0.612,0.311,0.454,0.6,0.662,0.593,0.488,0.663,0.711,0.684,0.727,0.707,0.314,0.644,0.672,0.697,0.652,0.581,0.291,0.714,medium,Prokaryote 40 | NCAP_I34A1,0.415,0.39,0.402,0.424,0.363,0.364,0.338,0.358,0.019,0.02,0.269,0.334,0.333,0.364,0.328,0.352,0.382,0.408,0.413,0.409,0.018,0.042,0.108,0.03,0.352,0.176,0.002,0.335,0.351,0.27,0.279,0.015,0.126,0.441,medium,Virus 41 | NRAM_I33A0,0.551,0.592,0.615,0.621,0.584,0.584,0.519,0.628,0.162,0.448,0.343,0.501,0.49,0.569,0.565,0.583,0.633,0.584,0.571,0.621,0.047,0.53,0.627,0.462,0.654,0.477,0.035,0.39,0.643,0.404,0.441,-0.076,-0.169,0.632,low,Virus 42 | NUD15_HUMAN,0.575,0.433,0.456,0.604,0.591,0.594,0.643,0.671,0.6,0.645,0.501,0.564,0.596,0.271,0.453,0.301,0.454,0.546,0.518,0.545,0.412,0.579,0.577,0.549,0.542,0.604,-0.005,0.389,0.605,0.623,0.553,0.603,0.147,0.635,high,Human 43 | P53_HUMAN,0.374,0.402,0.487,0.404,0.431,0.429,0.303,0.303,0.467,0.52,0.09,0.322,0.338,0.403,0.427,0.339,0.483,0.479,0.44,0.48,0.416,0.498,0.511,0.544,0.327,0.508,-0.106,0.296,0.472,0.469,0.411,0.5,0.215,0.411,low,Human 44 | P53_HUMAN_Kotler,0.42,0.611,0.584,0.564,0.555,0.576,0.582,0.579,0.538,0.622,0.432,0.496,0.556,0.629,0.588,0.429,0.461,0.478,0.463,0.518,0.47,0.48,0.465,0.47,0.493,0.518,0.109,0.498,0.603,0.631,0.585,0.6,0.039,0.581,low,Human 45 | P84126_THETH,0.536,0.515,0.526,0.547,0.564,0.578,0.637,0.636,0.552,0.588,0.618,0.603,0.617,0.508,0.58,0.419,0.507,0.478,0.558,0.507,0.521,0.584,0.589,0.571,0.653,0.614,0.364,0.526,0.521,0.553,0.469,0.569,0.359,0.581,medium,Prokaryote 46 | PABP_YEAST,0.64,0.689,0.692,0.688,0.654,0.648,0.662,0.655,0.662,0.678,0.541,0.541,0.55,0.663,0.617,0.638,0.665,0.666,0.692,0.688,0.638,0.698,0.7,0.676,0.666,0.704,0.474,0.569,0.675,0.703,0.635,0.688,0.261,0.684,medium,Eukaryote 47 | 
PA_I34A1,0.541,0.546,0.561,0.572,0.539,0.543,0.383,0.304,0.054,0.101,0.325,0.499,0.508,0.518,0.519,0.456,0.493,0.533,0.538,0.528,0.219,0.408,0.444,0.429,0.438,0.451,0.041,0.358,0.586,0.384,0.374,0.037,0.107,0.584,medium,Virus 48 | POLG_CXB3N,0.355,0.342,0.385,0.413,0.46,0.473,0.486,0.5,-0.059,0.042,0.356,0.377,0.411,0.423,0.39,0.339,0.39,0.381,0.377,0.395,0.138,0.388,0.383,0.369,0.393,0.392,-0.036,0.336,0.495,0.386,0.319,0.292,0.007,0.458,medium,Virus 49 | POLG_HCVJF,0.522,0.515,0.547,0.578,0.605,0.614,0.608,0.591,0.637,0.635,0.26,0.41,0.413,0.605,0.547,0.4,0.443,0.452,0.492,0.487,0.422,0.475,0.32,0.41,0.517,0.476,-0.039,0.196,0.637,0.587,0.485,0.178,0.182,0.56,medium,Virus 50 | PTEN_HUMAN,0.321,0.441,0.466,0.415,0.466,0.474,0.455,0.444,0.439,0.477,0.496,0.44,0.453,0.426,0.448,0.263,0.436,0.368,0.335,0.428,0.34,0.314,0.3,0.347,0.261,0.355,0.165,0.352,0.512,0.438,0.432,0.462,0.092,0.489,medium,Human 51 | Q2N0S5_9HIV1,0.406,0.52,0.502,0.501,0.495,0.502,0.504,0.514,0.509,0.537,0.403,0.352,0.393,0.493,0.379,0.518,0.403,0.39,0.337,0.444,0.517,0.401,0.394,0.436,0.354,0.459,0.003,0.437,0.516,0.478,0.483,0.47,0.291,0.513,medium,Virus 52 | Q59976_STRSQ,0.634,0.616,0.653,0.657,0.654,0.662,0.685,0.655,0.519,0.543,0.649,0.634,0.643,0.475,0.593,0.598,0.651,0.652,0.663,0.671,0.622,0.662,0.677,0.673,0.68,0.691,0.363,0.535,0.685,0.633,0.575,0.588,0.304,0.675,medium,Prokaryote 53 | R1AB_SARS2,0.216,0.35,0.399,0.401,0.6,0.605,-0.037,-0.037,-0.03,-0.04,0.266,0.212,0.227,0.577,0.561,0.214,0.259,0.274,0.289,0.272,0.242,0.236,0.203,0.21,0.224,0.241,-0.049,0.292,0.586,0.507,0.408,0.103,-0.056,0.565,medium,Virus 54 | RASH_HUMAN,0.377,0.454,0.478,0.45,0.466,0.48,0.408,0.413,0.36,0.405,0.514,0.444,0.476,0.447,0.436,0.437,0.414,0.42,0.396,0.432,0.433,0.403,0.401,0.374,0.305,0.412,0.313,0.353,0.436,0.338,0.319,0.318,0.18,0.487,high,Human 55 | 
REV_HV1H2,0.24,0.245,0.269,0.236,0.216,0.216,0.246,0.251,0.245,0.267,0.17,0.221,0.227,0.206,0.159,0.216,0.259,0.238,0.24,0.252,0.29,0.294,0.16,0.253,0.255,0.305,0.038,0.316,0.293,0.35,0.353,0.128,0.06,0.235,medium,Virus 56 | RL401_YEAST,0.382,0.43,0.479,0.418,0.38,0.407,0.396,0.394,0.277,0.316,0.437,0.38,0.425,0.322,0.352,0.392,0.51,0.487,0.431,0.485,0.49,0.434,0.436,0.418,0.395,0.467,0.123,0.373,0.366,0.426,0.436,0.191,0.14,0.42,medium,Eukaryote 57 | SC6A4_HUMAN,0.491,0.53,0.542,0.53,0.504,0.522,0.552,0.568,0.531,0.542,0.555,0.423,0.433,0.387,0.456,0.498,0.511,0.5,0.496,0.519,0.512,0.52,0.518,0.525,0.511,0.534,0.352,0.507,0.55,0.463,0.36,0.545,0.342,0.545,medium,Human 58 | SCN5A_HUMAN,0.069,0.095,0.086,0.086,0.153,0.158,0.152,0.135,0.217,0.135,0.126,0.162,0.162,0.13,0.131,0.106,0.143,0.151,0.127,0.129,0.124,0.106,0.093,0.107,0.167,0.124,0.13,-0.014,0.144,0.186,0.181,0.141,0.095,0.152,medium,Human 59 | SPG1_STRSG,0.279,0.243,0.204,0.289,0.247,0.272,0.142,0.245,0.237,0.192,0.278,-0.004,0.006,0.239,0.282,0.252,0.216,0.214,0.208,0.227,0.232,0.222,0.218,0.239,0.354,0.259,-0.041,0.017,0.287,0.478,0.494,0.305,0.111,0.29,low,Prokaryote 60 | SPIKE_SARS2,0.369,0.348,0.343,0.342,0.347,0.351,0.472,0.486,-0.044,-0.018,0.346,0.114,0.215,0.179,0.26,0.311,0.375,0.366,0.376,0.377,0.382,0.357,0.318,0.388,0.324,0.382,-0.031,0.36,0.236,0.318,0.352,0.024,0.174,0.408,medium,Virus 61 | SRC_HUMAN,0.348,0.502,0.502,0.492,0.496,0.507,0.258,0.37,0.561,0.585,0.484,0.465,0.465,0.526,0.508,0.439,0.413,0.421,0.373,0.439,0.437,0.46,0.438,0.422,0.33,0.442,0.532,0.44,0.549,0.578,0.574,0.469,0.431,0.516,medium,Human 62 | SUMO1_HUMAN,0.329,0.405,0.511,0.41,0.48,0.478,0.462,0.494,0.467,0.51,0.517,0.419,0.438,0.369,0.373,0.217,0.414,0.443,0.43,0.432,0.468,0.386,0.46,0.431,0.333,0.453,0.13,0.425,0.442,0.46,0.445,0.433,0.212,0.463,high,Human 63 | 
SYUA_HUMAN,0.181,0.124,0.176,0.159,0.131,0.139,0.1,0.152,0.242,0.233,0.22,0.119,0.132,0.103,0.111,0.138,0.203,0.155,0.141,0.167,0.086,0.151,0.167,0.105,0.136,0.141,0.131,0.212,0.222,0.205,0.189,0.234,0.005,0.156,medium,Human 64 | TADBP_HUMAN,0.121,0.142,0.185,0.121,0.08,0.08,0.058,0.041,0.051,0.048,-0.075,0.1,0.096,0.097,0.061,0.158,0.063,-0.011,-0.006,0.038,0.206,0.082,0.006,0.023,-0.014,0.047,0.291,-0.018,0.001,0.088,0.112,0.013,-0.051,0.109,low,Human 65 | TAT_HV1BR,0.203,0.379,0.263,0.237,0.319,0.3,0.303,0.328,0.343,0.342,0.203,0.255,0.268,0.293,0.201,0.379,0.364,0.396,0.398,0.397,0.394,0.269,0.136,0.246,0.223,0.279,-0.09,0.397,0.393,0.405,0.387,0.185,0.283,0.268,high,Virus 66 | TPK1_HUMAN,0.301,0.232,0.252,0.305,0.23,0.229,0.245,0.276,0.27,0.318,0.249,0.219,0.228,0.217,0.236,0.088,0.142,0.228,0.272,0.212,0.115,0.267,0.273,0.253,0.306,0.289,0.067,0.253,0.235,0.253,0.188,0.285,0.117,0.277,medium,Human 67 | TPMT_HUMAN,0.411,0.478,0.513,0.478,0.499,0.513,0.471,0.476,0.517,0.547,0.511,0.489,0.509,0.372,0.456,0.333,0.42,0.481,0.496,0.478,0.447,0.506,0.481,0.466,0.424,0.508,0.241,0.445,0.554,0.542,0.489,0.546,0.334,0.515,medium,Human 68 | TPOR_HUMAN,0.419,0.429,0.408,0.451,0.288,0.296,0.504,0.482,0.368,0.37,0.362,0.283,0.263,0.376,0.365,0.25,0.338,0.313,0.377,0.327,0.367,0.2,0.471,0.451,0.34,0.385,0.37,0.455,0.412,0.436,0.274,0.407,0.395,0.437,low,Human 69 | TRPC_SACS2,0.558,0.575,0.598,0.591,0.575,0.585,0.635,0.66,0.611,0.643,0.592,0.558,0.574,0.575,0.606,0.421,0.564,0.508,0.575,0.575,0.498,0.499,0.558,0.521,0.524,0.563,0.173,0.499,0.531,0.59,0.52,0.621,0.221,0.601,medium,Prokaryote 70 | TRPC_THEMA,0.448,0.39,0.442,0.438,0.416,0.417,0.479,0.477,0.478,0.506,0.43,0.394,0.411,0.405,0.45,0.336,0.399,0.395,0.455,0.416,0.437,0.421,0.369,0.387,0.478,0.433,0.395,0.434,0.369,0.427,0.365,0.439,0.14,0.442,medium,Prokaryote 71 | 
UBC9_HUMAN,0.428,0.327,0.452,0.475,0.508,0.522,0.496,0.519,0.477,0.509,0.557,0.521,0.531,0.371,0.496,0.22,0.405,0.445,0.419,0.43,0.42,0.479,0.471,0.449,0.442,0.484,-0.042,0.403,0.478,0.454,0.432,0.42,0.071,0.54,medium,Human 72 | UBE4B_MOUSE,0.262,0.419,0.351,0.39,0.463,0.47,0.344,0.364,0.447,0.471,0.383,0.454,0.468,0.412,0.417,0.122,0.324,0.337,0.3,0.314,0.444,0.376,0.395,0.376,0.289,0.408,0.078,0.089,0.435,0.437,0.396,0.361,0.091,0.472,low,Eukaryote 73 | VKOR1_HUMAN,0.438,0.4,0.416,0.475,0.425,0.432,0.446,0.462,0.437,0.461,0.44,0.4,0.416,0.369,0.392,0.168,0.186,0.334,0.37,0.314,0.25,0.428,0.424,0.398,0.414,0.43,0.122,0.391,0.434,0.407,0.338,0.428,0.134,0.487,medium,Human 74 | YAP1_HUMAN,0.217,0.402,0.326,0.359,0.449,0.455,0.07,0.053,0.28,0.285,0.189,0.464,0.46,0.428,0.321,0.18,0.176,0.158,0.167,0.176,0.299,0.226,0.257,0.207,0.15,0.239,0.329,0.289,0.334,0.425,0.451,0.333,0.139,0.439,low,Human 75 | ,0.401,0.419,0.43,0.446,0.443,0.449,0.421,0.432,0.372,0.401,0.391,0.404,0.421,0.375,0.413,0.321,0.366,0.375,0.38,0.388,0.341,0.383,0.383,0.387,0.402,0.413,0.166,0.348,0.459,0.444,0.408,0.358,0.175,0.472,, 76 | -------------------------------------------------------------------------------- /src/run_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | import torch 5 | import models.gvp.data, models.gvp.models 6 | import json 7 | import os 8 | import numpy as np 9 | import torch_geometric 10 | from functools import partial 11 | import esm 12 | import random 13 | 14 | print = partial(print, flush=True) 15 | import torch.multiprocessing 16 | 17 | torch.multiprocessing.set_sharing_strategy("file_system") 18 | from models.msa_transformer.model import MSATransformer 19 | from models.gvp.models import SSEmbGNN 20 | import pickle 21 | from visualization import ( 22 | plot_mave_corr_vs_depth, 23 | ) 24 | from helpers import ( 25 | read_msa, 26 | loop_trainval, 27 | prepare_mave_val, 28 | 
prepare_proteingym, 29 | prepare_proteingym_bad, 30 | prepare_proteingym_default, 31 | ) 32 | import run_test_mave, run_test_proteingym, run_test_proteingym_default, run_test_proteingym_bad, run_test_proteingym_good, run_test_rocklin, run_pipeline_scannet, run_test_clinvar 33 | import torch.utils.data 34 | import torch.multiprocessing as mp 35 | import torch.distributed as dist 36 | from torch.utils.data.distributed import DistributedSampler 37 | from torch.nn.parallel import DistributedDataParallel as DDP 38 | 39 | #DEVICES = [3, 4, 5, 6, 7, 8, 9] 40 | DEVICES = [1] 41 | os.environ["CUDA_VISIBLE_DEVICES"] = ", ".join([str(x) for x in DEVICES]) 42 | 43 | def setup(rank, world_size): 44 | os.environ["MASTER_ADDR"] = "localhost" 45 | os.environ["MASTER_PORT"] = "1111" 46 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 47 | 48 | def prepare( 49 | dataset, 50 | rank, 51 | world_size, 52 | batch_size=1, 53 | pin_memory=False, 54 | num_workers=0, 55 | train=False, 56 | ): 57 | sampler = DistributedSampler( 58 | dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True 59 | ) 60 | if train == True: 61 | dataloader = torch_geometric.loader.DataLoader( 62 | dataset, 63 | batch_size=batch_size, 64 | pin_memory=pin_memory, 65 | num_workers=num_workers, 66 | drop_last=False, 67 | shuffle=False, 68 | sampler=sampler, 69 | ) 70 | else: 71 | dataloader = torch_geometric.loader.DataLoader( 72 | dataset, 73 | batch_size=batch_size, 74 | pin_memory=pin_memory, 75 | num_workers=num_workers, 76 | drop_last=False, 77 | shuffle=False, 78 | sampler=None, 79 | ) 80 | return dataloader 81 | 82 | 83 | def cleanup(): 84 | dist.destroy_process_group() 85 | 86 | 87 | def main(rank, world_size): 88 | # Setup the process groups 89 | torch.cuda.set_device(rank) 90 | torch.cuda.empty_cache() 91 | setup(rank, world_size) 92 | 93 | # Set fixed seed 94 | seed = 1 95 | torch.manual_seed(seed) 96 | np.random.seed(seed) 97 | random.seed(seed) 98 | 
torch.cuda.manual_seed_all(seed) 99 | torch.backends.cudnn.deterministic = True 100 | torch.backends.cudnn.benchmark = False 101 | torch.backends.cudnn.enabled = False 102 | 103 | # Print name of run 104 | run_name = "final_cath" 105 | 106 | # Set initial parameters 107 | EPOCHS = 200 108 | EPOCH_FINETUNE_MSA = 100 109 | VAL_INTERVAL = 10 110 | BATCH_PROTS = 128 // len(DEVICES) 111 | LR_LIST = [1e-3, 1e-6] 112 | PATIENCE = 1 113 | 114 | ## Load CATH data 115 | print("Preparing CATH data") 116 | pdb_dir_cath = "../data/train/cath" 117 | subprocess.run([f"bash {pdb_dir_cath}/getCATH.sh"], shell=True) 118 | subprocess.run([f"mv chain_set.jsonl {pdb_dir_cath}/chain_set.json"], shell=True) 119 | subprocess.run([f"mv chain_set_splits.json {pdb_dir_cath}/chain_set_splits.json"], shell=True) 120 | cath = models.gvp.data.CATHDataset( 121 | path=f"{pdb_dir_cath}/chain_set.json", 122 | splits_path=f"{pdb_dir_cath}/chain_set_splits.json", 123 | ) 124 | 125 | # Compute MSAs 126 | # TO DO: Add code example to extract sequences from CATH data set 127 | # to file: f"{pdb_dir_cath}/seqs.fasta" 128 | sys.path += [":/projects/prism/people/skr526/mmseqs/bin"] 129 | subprocess.run( 130 | [ 131 | "colabfold_search", 132 | f"{pdb_dir_cath}/seqs.fasta", 133 | "/projects/prism/people/skr526/databases", 134 | "../data/train/cath/msa/", 135 | ] 136 | ) 137 | subprocess.run(["python", "merge_and_sort_msas.py", "../data/train/cath/msa"]) 138 | 139 | # Add MSAs 140 | for i, entry in enumerate(cath.total): 141 | print(f"Adding CATH MSAs: {i+1}/{len(cath.total)}") 142 | entry["msa"] = read_msa(f"{pdb_dir_cath}/msa/{entry['name']}.a3m") 143 | 144 | # Checkpoint - save and load 145 | with open(f"{pdb_dir_cath}/data_with_msas.pkl", "wb") as fp: # Pickling 146 | pickle.dump(cath.total, fp) 147 | 148 | with open(f"{pdb_dir_cath}/data_with_msas.pkl", "rb") as fp: # Unpickling 149 | cath.total = pickle.load(fp) 150 | 151 | ## Filter data 152 | # Only keep entries where MSA and structucture sequence 
lengths match 153 | data = [ 154 | entry for entry in cath.total if len(entry["seq"]) == len(entry["msa"][0][1]) 155 | ] 156 | 157 | # Filter: Only keep entries without X in sequence 158 | data = [entry for entry in cath.total if "X" not in entry["seq"]] 159 | 160 | # Save all training and validation sequences in a fasta file to check homology 161 | cath.split() 162 | with open(f"../data/test/mave_val/structure/coords.json") as json_file: 163 | data_mave_val = json.load(json_file) 164 | 165 | with open(f"../data/test/proteingym/structure/coords.json") as json_file: 166 | data_proteingym = json.load(json_file) 167 | 168 | fh = open(f"../data/train/cath/seqs_cath.fasta", "w") 169 | for entry in cath.train: 170 | fh.write(f">{entry['name']}\n") 171 | fh.write(f"{entry['seq']}\n") 172 | 173 | for entry in cath.val: 174 | fh.write(f">{entry['name']}\n") 175 | fh.write(f"{entry['seq']}\n") 176 | 177 | for entry in data_mave_val: 178 | fh.write(f">{entry['name']}\n") 179 | fh.write(f"{entry['seq']}\n") 180 | 181 | for entry in data_proteingym: 182 | fh.write(f">{entry['name']}\n") 183 | fh.write(f"{entry['seq']}\n") 184 | fh.close() 185 | 186 | # Compute clusters of 95% sequence similarities between all training, validation and test proteins 187 | subprocess.run( 188 | [ 189 | "cd-hit", 190 | "-i", 191 | "../data/train/cath/seqs_cath.fasta", 192 | "-o", 193 | "../data/train/cath/seqs_cath_homology.fasta", 194 | "-c", 195 | "0.95", 196 | "-n", 197 | "5", 198 | "-d", 199 | "999", 200 | ] 201 | ) 202 | 203 | # Remove proteins from training data that has high sequence similarity with validation or test proteins 204 | val_prot_names = [entry["name"] for entry in cath.val] 205 | val_mave_prot_names = [entry["name"] for entry in data_mave_val] 206 | test_prot_names = [entry["name"] for entry in data_proteingym] 207 | valtest_prot_names = val_prot_names + val_mave_prot_names + test_prot_names 208 | 209 | fh = open("../data/train/cath/seqs_cath_homology.fasta.clstr", "r") 210 | 
cluster_dict = {} 211 | remove_list = [] 212 | for line in fh.readlines(): 213 | if line.startswith(">Cluster"): 214 | cluster_name = line 215 | cluster_dict[cluster_name] = [] 216 | else: 217 | cluster_dict[cluster_name].append(line.split(">")[1].split("...")[0]) 218 | 219 | for cluster_name, prot_names in cluster_dict.items(): 220 | if len(prot_names) > 1 and any( 221 | valtest_prot_name in prot_names for valtest_prot_name in valtest_prot_names 222 | ): 223 | remove_list += prot_names 224 | remove_list = [ 225 | prot_name for prot_name in remove_list if prot_name not in valtest_prot_names 226 | ] 227 | cath.train = [entry for entry in cath.train if entry["name"] not in remove_list] 228 | 229 | # Checkpoint - save and load 230 | with open( 231 | f"{pdb_dir_cath}/data_with_msas_filtered_train.pkl", "wb" 232 | ) as fp: # Pickling 233 | pickle.dump(cath.train, fp) 234 | with open( 235 | f"{pdb_dir_cath}/data_with_msas_filtered_val.pkl", "wb" 236 | ) as fp: # Pickling 237 | pickle.dump(cath.val, fp) 238 | 239 | # Prepare MAVE validation and ProteinGym test data 240 | prepare_mave_val() 241 | prepare_proteingym() 242 | prepare_proteingym_bad() 243 | prepare_proteingym_default() 244 | 245 | with open( 246 | f"{pdb_dir_cath}/data_with_msas_filtered_train.pkl", "rb" 247 | ) as fp: # Unpickling 248 | cath.train = pickle.load(fp) 249 | with open( 250 | f"{pdb_dir_cath}/data_with_msas_filtered_val.pkl", "rb" 251 | ) as fp: # Unpickling 252 | cath.val = pickle.load(fp) 253 | 254 | # Convert to graph data sets 255 | trainset = models.gvp.data.ProteinGraphData(cath.train) 256 | valset = models.gvp.data.ProteinGraphData(cath.val) 257 | 258 | ## Load and initialize MSA Transformer 259 | model_msa_pre, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S() 260 | torch.save( 261 | model_msa_pre.state_dict(), 262 | f"../output/train/models/msa_transformer/pretrained.pt", 263 | ) 264 | msa_batch_converter = msa_alphabet.get_batch_converter() 265 | model_msa = MSATransformer() 266 | 
model_msa.load_state_dict( 267 | torch.load(f"../output/train/models/msa_transformer/pretrained.pt") 268 | ) 269 | model_msa.to(rank) 270 | 271 | # Freeze MSA Transformer 272 | for param in model_msa.parameters(): 273 | param.requires_grad = False 274 | 275 | # Load and initialize GVP 276 | node_dim = (256, 64) 277 | edge_dim = (32, 1) 278 | model_gvp = SSEmbGNN((6, 3), node_dim, (32, 1), edge_dim) 279 | model_gvp.to(rank) 280 | model_gvp = DDP( 281 | model_gvp, 282 | device_ids=[rank], 283 | output_device=rank, 284 | find_unused_parameters=True, 285 | static_graph=False, 286 | ) 287 | 288 | # Initialize training modules 289 | train_loader = prepare(trainset, rank, world_size, train=True) 290 | val_loader = prepare(valset, rank, world_size) 291 | optimizer = torch.optim.Adam(model_gvp.parameters(), lr=LR_LIST[0]) 292 | scaler = torch.cuda.amp.GradScaler() 293 | best_epoch, best_corr_mave = None, 0 294 | patience_counter = 0 295 | 296 | # Initialize lists for monitoring loss 297 | epoch_list = [] 298 | loss_train_list, loss_val_list = [], [] 299 | acc_train_list, acc_val_list = [], [] 300 | corr_mave_list, acc_mave_list = [], [] 301 | 302 | for epoch in range(EPOCHS): 303 | # Check if we should fine-tune MSA Transformer row attention 304 | if epoch == EPOCH_FINETUNE_MSA: 305 | for param in model_msa.named_parameters(): 306 | if "row_self_attention" in param[0]: 307 | param[1].requires_grad = True 308 | model_msa = DDP( 309 | model_msa, 310 | device_ids=[rank], 311 | output_device=rank, 312 | find_unused_parameters=True, 313 | static_graph=False, 314 | ) 315 | optimizer.add_param_group( 316 | {"params": model_msa.parameters(), "lr": LR_LIST[1]} 317 | ) 318 | BATCH_PROTS = 2048 // len(DEVICES) 319 | 320 | # If we are using DistributedSampler, we need to tell it which epoch this is 321 | train_loader.sampler.set_epoch(epoch) 322 | 323 | # Train loop 324 | model_msa.train() 325 | model_gvp.train() 326 | 327 | loss_train, acc_train = loop_trainval( 328 | model_msa, 329 | 
model_gvp, 330 | msa_batch_converter, 331 | train_loader, 332 | BATCH_PROTS, 333 | epoch, 334 | rank, 335 | EPOCH_FINETUNE_MSA, 336 | optimizer=optimizer, 337 | scaler=scaler, 338 | ) 339 | ## Gather and save training metrics for epoch 340 | # OBS: This cannot be placed within validation loop or we get hangs 341 | loss_train = loss_train.type(torch.float32) 342 | loss_train_all_gather = [torch.zeros(1, device=rank) for _ in range(world_size)] 343 | dist.all_gather(loss_train_all_gather, loss_train) 344 | 345 | # Validation loop 346 | if rank == 0: 347 | if epoch % VAL_INTERVAL == 0: 348 | # Save model 349 | path_msa = f"../output/train/models/msa_transformer/{run_name}_msa_transformer_{epoch}.pt" 350 | path_gvp = f"../output/train/models/gvp/{run_name}_gvp_{epoch}.pt" 351 | path_optimizer = ( 352 | f"../output/train/models/optimizer/{run_name}_adam_{epoch}.pt" 353 | ) 354 | torch.save(model_msa.state_dict(), path_msa) 355 | torch.save(model_gvp.state_dict(), path_gvp) 356 | torch.save(optimizer.state_dict(), path_optimizer) 357 | 358 | with torch.no_grad(): 359 | # Do training validation 360 | model_msa.eval() 361 | model_gvp.eval() 362 | 363 | loss_val, acc_val = loop_trainval( 364 | model_msa, 365 | model_gvp, 366 | msa_batch_converter, 367 | val_loader, 368 | BATCH_PROTS, 369 | epoch, 370 | rank, 371 | EPOCH_FINETUNE_MSA, 372 | ) 373 | 374 | if epoch >= EPOCH_FINETUNE_MSA: 375 | # Do validation on MAVE set 376 | corr_mave, acc_mave = run_test_mave.test( 377 | run_name, 378 | epoch, 379 | device=rank, 380 | ) 381 | 382 | # Update patience 383 | if corr_mave > best_corr_mave: 384 | best_epoch = epoch 385 | patience_counter = 0 386 | else: 387 | patience_counter += 1 388 | else: 389 | corr_mave, acc_mave = 0.0, 0.0 390 | 391 | # Save validation results 392 | epoch_list.append(epoch) 393 | loss_train_list.append( 394 | torch.mean(torch.stack(loss_train_all_gather)).to("cpu").item() 395 | ) 396 | loss_val_list.append(loss_val.to("cpu").item()) 397 | 
acc_val_list.append(acc_val) 398 | corr_mave_list.append(corr_mave) 399 | acc_mave_list.append(corr_mave) 400 | 401 | metrics = { 402 | "epoch": epoch_list, 403 | "loss_train": loss_train_list, 404 | "loss_val": loss_val_list, 405 | "acc_val": acc_val_list, 406 | "corr_mave": corr_mave_list, 407 | "acc_mave": acc_mave_list, 408 | } 409 | with open(f"../output/train/metrics/{run_name}_metrics", "wb") as f: 410 | pickle.dump(metrics, f) 411 | 412 | if patience_counter == PATIENCE: 413 | break 414 | 415 | # Create barrier after each epoch 416 | dist.barrier() 417 | 418 | # Clean up 419 | cleanup() 420 | 421 | # MAVE val set 422 | print("Starting MAVE val predictions") 423 | run_test_mave.test(run_name, best_epoch, get_only_ssemb_metrics=False, device=rank) 424 | plot_mave_corr_vs_depth() 425 | print("Finished MAVE val predictions") 426 | 427 | # ProteinGym test set 428 | print("Starting ProteinGym test") 429 | run_test_proteingym.test(run_name, best_epoch, device=rank) 430 | print("Finished ProteinGym test") 431 | 432 | # Rocklin test set 433 | print("Starting Rocklin test") 434 | run_test_rocklin.test(run_name, best_epoch, num_ensemble=5, device=rank) 435 | print("Finished Rocklin test") 436 | 437 | # ScanNet test set 438 | print("Starting ScanNet test") 439 | run_pipeline_scannet.run(run_name, best_epoch, device=rank) 440 | print("Finished ScanNet test") 441 | 442 | # ClinVar test set 443 | print("Starting ClinVar test") 444 | run_test_clinvar.test(run_name, best_epoch, num_ensemble=5, device=rank) 445 | print("Finished ClinVar test") 446 | 447 | if __name__ == "__main__": 448 | world_size = len(DEVICES) 449 | mp.spawn( 450 | main, 451 | args=(world_size,), 452 | nprocs=world_size, 453 | ) 454 | -------------------------------------------------------------------------------- /src/run_pipeline_scannet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | import re 5 | import torch 6 
import models.gvp.data, models.gvp.models
import json
import numpy as np
import torch_geometric
import esm
import pandas as pd
import random
import torch.multiprocessing

torch.multiprocessing.set_sharing_strategy("file_system")
from models.msa_transformer.model import MSATransformer
from models.gvp.models import SSEmbGNN
from models.scannet.model import TransformerModel
import pickle
from helpers import (
    read_msa,
    loop_getemb,
    scannet_collate_fn,
    loop_scannet_trainval,
    loop_scannet_test,
)
from visualization import plot_precision_recall
import pdb_parser_scripts.parse_pdbs as parse_pdbs
import torch.utils.data
from sklearn.metrics import auc
from collections import OrderedDict
from Bio.PDB import PDBList
from Bio.PDB import PDBParser
from Bio.PDB.PDBIO import PDBIO
import io
from contextlib import redirect_stdout
import pymol2
import time

# Remote location of the ScanNet PPBS benchmark files.
_SCANNET_URL = "https://raw.githubusercontent.com/jertubiana/ScanNet/main/datasets/PPBS"
# Local root for all ScanNet inputs (relative paths assume cwd is src/).
_SCANNET_DIR = "../data/test/scannet"
# Label-file suffixes (labels_<split>.txt), in download/concatenation order.
_LABEL_SPLITS = (
    "train",
    "validation_70",
    "validation_homology",
    "validation_topology",
    "validation_none",
    "test_70",
    "test_homology",
    "test_topology",
    "test_none",
)
# Directory containing the mmseqs binary that colabfold_search invokes.
_MMSEQS_BIN = "/projects/prism/people/skr526/mmseqs/bin"
# Key prefix added by DistributedDataParallel when saving a wrapped model.
_DDP_PREFIX = "module."


def _wget(url, dest_dir):
    """Download *url* into *dest_dir* with ``wget -c`` (resumes partial files)."""
    subprocess.run(["wget", "-c", url, "-P", dest_dir])


def _download_pdbs(pdb_list, raw_dir):
    """Fetch every PDB id in *pdb_list* into *raw_dir* as ``{pdbid}.pdb``.

    Entries already present are skipped. The legacy PDB format is tried
    first; if Biopython reports the structure does not exist in that format,
    the mmCIF file is fetched and converted to PDB with PyMOL.
    """
    pdbl = PDBList()
    for i, pdbid in enumerate(pdb_list):
        print(f"{i+1}/{len(pdb_list)}")
        print(pdbid)
        pdb_path = f"{raw_dir}/{pdbid}.pdb"
        if os.path.exists(pdb_path):
            print("PDB file already downloaded")
            continue

        # Use a fresh buffer per entry: a shared buffer accumulates output,
        # so a single missing structure would send every later entry down
        # the mmCIF fallback path.
        buf = io.StringIO()
        with redirect_stdout(buf):
            pdbl.retrieve_pdb_file(pdbid, pdir=raw_dir, file_format="pdb")
        out = buf.getvalue()

        # Substring match covers both the "doesn't exist" and the older
        # "doesn't exists" wording printed by Biopython.
        if "Desired structure doesn't exist" in out:
            try:
                pdbl.retrieve_pdb_file(pdbid, pdir=raw_dir, file_format="mmCif")
                with pymol2.PyMOL() as pymol:
                    pymol.cmd.load(f"{raw_dir}/{pdbid}.cif", "my_protein")
                    pymol.cmd.save(pdb_path, selection="my_protein")
            except Exception:
                print("Protein does not exist as either PDB or mmCIF file")
        else:
            # Biopython saves legacy-format entries as pdb<id>.ent.
            subprocess.run(["mv", f"{raw_dir}/pdb{pdbid}.ent", pdb_path])
        # Throttle requests to the PDB server.
        time.sleep(1)


def _split_chains(pdb_list, chain_list, raw_dir, out_dir, log_path):
    """Write chain ``chain_list[i]`` of ``pdb_list[i]`` to ``{out_dir}/{pdbid}_{chain}.pdb``.

    Entries that cannot be parsed (e.g. the download failed) are recorded in
    *log_path* instead of aborting the run.
    """
    parser = PDBParser()
    pdb_io = PDBIO()
    with open(log_path, "w") as log:
        for pdbid, chain_id in zip(pdb_list, chain_list):
            try:
                structure = parser.get_structure(pdbid, f"{raw_dir}/{pdbid}.pdb")
                for chain in structure.get_chains():
                    if chain.get_id() == chain_id:
                        pdb_io.set_structure(chain)
                        pdb_io.save(f"{out_dir}/{pdbid}_{chain_id}.pdb")
            except Exception:
                log.write(f"{pdbid}_{chain_id} not available\n")


def _strip_ddp_prefix(state_dict):
    """Return *state_dict* with a leading ``module.`` removed from each key.

    Checkpoints saved from a DDP-wrapped model carry that prefix; plain
    checkpoints are returned with keys unchanged.
    """
    return OrderedDict(
        (k[len(_DDP_PREFIX):] if k.startswith(_DDP_PREFIX) else k, v)
        for k, v in state_dict.items()
    )


def _read_labels(path):
    """Parse a concatenated ScanNet label file.

    Header lines (``>`` + id) open a new entry keyed ``<first 4 chars>_<last char>``;
    each following non-blank line is split on single spaces and appended
    as ``[chain_id, pos, aa, label]``.
    """
    label_dict = {}
    prot = None
    with open(path, "r") as f:
        for raw in f:
            if raw.startswith(">"):
                name = raw[1:].strip()
                prot = name[:4] + "_" + name[-1]
                label_dict[prot] = []
            elif raw.strip():
                label_dict[prot].append(raw.strip().split(" "))
    return label_dict


def run(run_name, epoch, device=None):
    """Run the ScanNet (PPBS) binding-site benchmark on SSEmb embeddings.

    Downloads the PPBS table, labels and structures, builds MSAs, embeds
    every chain with the MSA Transformer + GVP checkpoints of *run_name* at
    *epoch*, trains a small Transformer on the embeddings and writes test
    AUCPR values plus a precision-recall plot under ../output/scannet.

    Args:
        run_name: prefix of the SSEmb checkpoints under ../output/train/models
            and of every output file written here.
        epoch: epoch index of the SSEmb checkpoints to load.
        device: device (or rank) used for the models and label tensors.
    """
    # Download raw data
    _wget(f"{_SCANNET_URL}/table.csv", f"{_SCANNET_DIR}/")
    for split in _LABEL_SPLITS:
        _wget(f"{_SCANNET_URL}/labels_{split}.txt", f"{_SCANNET_DIR}/labels")

    # Download PDBs
    df = pd.read_csv(f"{_SCANNET_DIR}/table.csv")
    pdb_list_all = df["PDB ID"].unique()
    pdb_list = [x[:4] for x in pdb_list_all]
    chain_list = [x[-1] for x in pdb_list_all]
    _download_pdbs(pdb_list, f"{_SCANNET_DIR}/raw")
    # PDBList creates an "obsolete" directory in the working directory.
    subprocess.run(["rm", "-r", "obsolete"])

    # Split into chains
    _split_chains(
        pdb_list,
        chain_list,
        f"{_SCANNET_DIR}/raw",
        f"{_SCANNET_DIR}/structure/raw",
        f"{_SCANNET_DIR}/scannet_download.log",
    )

    # Pre-process PDBs
    pdb_dir = f"{_SCANNET_DIR}/structure"
    subprocess.call(["pdb_parser_scripts/clean_pdbs.sh", str(pdb_dir)])
    parse_pdbs.parse(pdb_dir)

    # Compute MSAs. mmseqs must be on the PATH of the child process;
    # appending to sys.path (Python's import path) has no effect on it.
    env = dict(os.environ)
    env["PATH"] = os.pathsep.join(p for p in (env.get("PATH"), _MMSEQS_BIN) if p)
    subprocess.run(
        [
            "colabfold_search",
            f"{pdb_dir}/seqs.fasta",
            "/projects/prism/people/skr526/databases",
            "../data/test/scannet/msa/",
        ],
        env=env,
    )
    subprocess.run(["python", "merge_and_sort_msas.py", "../data/test/scannet/msa"])

    # Load structure data
    with open(f"{pdb_dir}/coords.json") as json_file:
        data_raw = json.load(json_file)

    # Only keep entries with sequence lengths <= 1024 (one slot reserved for
    # the added token)
    data = [entry for entry in data_raw if len(entry["seq"]) <= 1024 - 1]

    # Add MSAs to data: the query row plus up to 255 randomly sampled rows,
    # kept in their original order
    for i, entry in enumerate(data):
        print(f"{i+1}/{len(data)}")
        msa = read_msa(f'{_SCANNET_DIR}/msa/{entry["name"]}.a3m')
        k = min(len(msa) - 1, 256 - 1)
        rows = sorted(random.sample(range(1, len(msa)), k))
        entry["msa"] = [msa[0]] + [msa[j] for j in rows]

    # Checkpoint: later steps can be resumed from this pickle
    with open("../output/scannet/data_scannet.pickle", "wb") as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

    with open("../output/scannet/data_scannet.pickle", "rb") as handle:
        data = pickle.load(handle)

    # Load MSA Transformer
    _, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    model_msa = MSATransformer(msa_alphabet).to(device)
    msa_batch_converter = msa_alphabet.get_batch_converter()
    state_dict_msa = torch.load(
        f"../output/train/models/msa_transformer/{run_name}_msa_transformer_{epoch}.pt"
    )
    model_msa.load_state_dict(_strip_ddp_prefix(state_dict_msa))

    # Load GVP
    node_dim = (256, 64)
    edge_dim = (32, 1)
    model_gvp = SSEmbGNN((6, 3), node_dim, (32, 1), edge_dim).to(device)
    state_dict_gvp = torch.load(f"../output/train/models/gvp/{run_name}_gvp_{epoch}.pt")
    model_gvp.load_state_dict(_strip_ddp_prefix(state_dict_gvp))

    # Convert to graph data set and init data loader
    allset = models.gvp.data.ProteinGraphData(data)
    data_loader = torch_geometric.loader.DataLoader(allset, batch_size=1, shuffle=False)

    # Add frozen embeddings to data
    with torch.no_grad():
        emb_dict = loop_getemb(
            model_msa,
            model_gvp,
            msa_batch_converter,
            data_loader,
            device=device,
        )
    for entry in data:
        entry["emb"] = emb_dict[entry["name"]]

    with open("../output/scannet/data_scannet_emb.pickle", "wb") as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

    with open("../output/scannet/data_scannet_emb.pickle", "rb") as handle:
        data = pickle.load(handle)

    # Add sample weights to data
    df = pd.read_csv(f"{_SCANNET_DIR}/table.csv")
    known_ids = set(df["PDB ID"])
    for entry in data:
        if entry["name"] in known_ids:
            entry["prot_weight"] = df[df["PDB ID"] == entry["name"]][
                "Sample weight"
            ].item()

    # Concat label files
    labels_all_path = f"{_SCANNET_DIR}/labels/labels_all.txt"
    with open(labels_all_path, "w") as outfile:
        for split in _LABEL_SPLITS:
            with open(f"{_SCANNET_DIR}/labels/labels_{split}.txt") as infile:
                outfile.writelines(infile)

    # Parse labels: {"<pdbid>_<chain>": [[chain_id, pos, aa, label], ...]}
    label_dict = _read_labels(labels_all_path)

    # Attach labels; keep only entries whose label sequence matches the
    # structure sequence
    n_no_labels = 0
    n_seq_mismatch = 0
    data_clean = []
    for entry in data:
        labels = label_dict.get(entry["name"])
        entry["label"] = None
        if labels is None:
            n_no_labels += 1
            continue
        label_seq = "".join(x[2] for x in labels)
        if label_seq == entry["seq"]:
            entry["label"] = torch.tensor([int(x[3]) for x in labels], device=device)
            data_clean.append(entry)
        else:
            n_seq_mismatch += 1

    print(
        f"Number of PDBs where we have structure data but no label data: {n_no_labels}/{len(data)}"
    )
    print(
        f"Number of PDBs where the label seq and the structure seq doesn't match: {n_seq_mismatch}/{len(data)}"
    )
    print(f"Number of PDBs in cleaned data set: {len(data_clean)}")

    # Set parameters for downstream model learning
    EPOCHS = 40
    VAL_INTERVAL = 1
    BATCH_PROTS = 10
    LR = 1e-4

    # Split into train/val (id sets built once instead of per entry)
    train_ids = set(df.loc[df["Set"] == "Train", "PDB ID"])
    val_ids = set(df.loc[df["Set"] == r"Validation (70\%)", "PDB ID"])
    data_train = [x for x in data_clean if x["name"] in train_ids]
    data_val = [x for x in data_clean if x["name"] in val_ids]

    # Load Transformer model
    model_transformer = TransformerModel(
        ntoken=1, nhead=4, d_hid=256, nlayers=4, dropout=0.1
    ).to(device)

    # Init optimizer
    optimizer = torch.optim.Adam(model_transformer.parameters(), lr=LR)

    # Initialize data loaders
    train_loader = torch.utils.data.DataLoader(
        data_train, batch_size=1, collate_fn=scannet_collate_fn, shuffle=True
    )  # OBS: Use grad accumulation if bs > 1
    val_loader = torch.utils.data.DataLoader(
        data_val, batch_size=1, collate_fn=scannet_collate_fn, shuffle=False
    )

    # Initialize lists for monitoring loss
    epoch_list = []
    loss_train_list, loss_val_list = [], []
    acc_val_list = []
    best_epoch, best_loss_val = None, np.inf

    # Begin training and validation loop. `ep` is deliberately distinct from
    # the `epoch` argument, which names the SSEmb checkpoint.
    for ep in range(EPOCHS):
        model_transformer.train()
        loss_train, _ = loop_scannet_trainval(
            model_transformer,
            train_loader,
            device=device,
            optimizer=optimizer,
            batch_prots=BATCH_PROTS,
        )

        # Save model
        torch.save(
            model_transformer.state_dict(),
            f"../output/scannet/models/transformer/{run_name}_transformer_{ep}.pt",
        )
        torch.save(
            optimizer.state_dict(),
            f"../output/scannet/models/optimizer/{run_name}_adam_{ep}.pt",
        )

        if ep % VAL_INTERVAL == 0:
            with torch.no_grad():
                model_transformer.eval()
                loss_val, acc_val = loop_scannet_trainval(
                    model_transformer,
                    val_loader,
                    device=device,
                )

            if loss_val < best_loss_val:
                best_epoch, best_loss_val = ep, loss_val

            # Save validation results
            epoch_list.append(ep)
            loss_train_list.append(loss_train.to("cpu").item())
            loss_val_list.append(loss_val.to("cpu").item())
            acc_val_list.append(acc_val)

            metrics = {
                "epoch": epoch_list,
                "loss_train": loss_train_list,
                "loss_val": loss_val_list,
                "acc_val": acc_val_list,
            }
            with open(f"../output/scannet/metrics/{run_name}_metrics", "wb") as f:
                pickle.dump(metrics, f)

    # Test
    #best_epoch = 26 # OBS: Uncomment this line to use weights from paper
    print(f"Testing model! Using best model from epoch: {best_epoch}")
    model_transformer.load_state_dict(
        torch.load(
            f"../output/scannet/models/transformer/{run_name}_transformer_{best_epoch}.pt"
        )
    )

    test_sets = [
        [r"Test (70\%)"],
        ["Test (Homology)"],
        ["Test (Topology)"],
        ["Test (None)"],
        [r"Test (70\%)", "Test (Homology)", "Test (Topology)", "Test (None)"],
    ]

    precision_list = []
    recall_list = []
    with open(f"../output/scannet/{run_name}_test_results.txt", "w") as outfile:
        for test_set in test_sets:
            print(f"Computing predictions for: {' & '.join(test_set)}")
            test_ids = set(df.loc[df["Set"].isin(test_set), "PDB ID"])
            data_test = [x for x in data_clean if x["name"] in test_ids]
            test_loader = torch.utils.data.DataLoader(
                data_test, batch_size=1, collate_fn=scannet_collate_fn, shuffle=False
            )

            with torch.no_grad():
                model_transformer.eval()
                precision, recall = loop_scannet_test(
                    model_transformer,
                    test_loader,
                    device=device,
                )
            precision_list.append(precision)
            recall_list.append(recall)
            auc_precision_recall = auc(recall, precision)
            outfile.write(
                f"Test AUCPR for {' & '.join(test_set)} is: {auc_precision_recall}\n"
            )

    # Plot test results
    plot_precision_recall(recall_list, precision_list)


# --------------------------------------------------------------------------------
# /src/models/msa_transformer/multihead_attention.py:
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter
import uuid


def utils_softmax(x, dim: int, onnx_trace: bool = False):
    """Softmax of *x* along *dim*, computed in float32."""
    if onnx_trace:
        return F.softmax(x.float(), dim=dim)
    else:
        return F.softmax(x, dim=dim, dtype=torch.float32)


class FairseqIncrementalState(object):
    """Mixin that namespaces a module's entries in a shared incremental-state dict.

    Each instance gets a random UUID; its keys are stored as ``"<uuid>.<key>"``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module."""
        full_key = self._get_full_incremental_state_key(key)
        if incremental_state is None or full_key not in incremental_state:
            return None
        return incremental_state[full_key]

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module."""
        if incremental_state is not None:
            full_key = self._get_full_incremental_state_key(key)
            incremental_state[full_key] = value
        return incremental_state


def with_incremental_state(cls):
    """Class decorator: make FairseqIncrementalState the first base of *cls*."""
    cls.__bases__ = (FairseqIncrementalState,) + tuple(
        b for b in cls.__bases__ if b != FairseqIncrementalState
    )
    return cls


@with_incremental_state
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        self_attention: bool = False,
        encoder_decoder_attention: bool = False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

        self.onnx_trace = False
        self.rot_emb = None
        if use_rotary_embeddings:
            # RotaryEmbedding is not part of this package and was previously
            # referenced without an import (NameError). Import it lazily from
            # fair-esm (assumed to provide esm.rotary_embedding) so the
            # default, rotary-free path does not depend on it.
            from esm.rotary_embedding import RotaryEmbedding

            self.rot_emb = RotaryEmbedding(dim=self.head_dim)

        # Use PyTorch's fused implementation when this torch build provides it.
        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        if (
            self.rot_emb is None
            and self.enable_torch_version
            and not self.onnx_trace
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            and not need_head_weights
        ):
            assert key is not None and value is not None
            # With bias=False the projections have no bias tensors; pass None
            # instead of concatenating Nones (which raises).
            in_proj_bias: Optional[Tensor] = None
            if self.q_proj.bias is not None:
                in_proj_bias = torch.cat(
                    (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)
                )
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        if self.rot_emb is not None:
            q, k = self.rot_emb(q, k)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = MultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(
            attn_weights_float.type_as(attn_weights),
            p=self.dropout,
            training=self.training,
        )
        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).type_as(attn).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - prev_key_padding_mask.size(1)),
                device=prev_key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), filler.float()], dim=1
            )
        elif key_padding_mask is not None:
            filler = torch.zeros(
                (batch_size, src_len - key_padding_mask.size(1)),
                device=key_padding_mask.device,
            )
            new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @torch.jit.export
    def reorder_incremental_state(
        self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention and input_buffer_k.size(0) == new_order.size(
                        0
                    ):
                        break
                    input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    @staticmethod
    def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
        """Hook for sparse attention variants; identity here.

        A staticmethod so it can be called on the class or on an instance.
        """
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        """Split legacy fused ``in_proj_weight``/``in_proj_bias`` into q/k/v entries in place."""
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = state_dict[k].shape[0] // 3
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

                keys_to_remove.append(k)

                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    # Size the split from the bias tensor itself.
                    dim = state_dict[k_bias].shape[0] // 3
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                    keys_to_remove.append(prefix + "in_proj_bias")

        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value