├── output
├── clinvar
│ └── .gitkeep
├── rocklin
│ └── .gitkeep
├── mave_val
│ └── .gitkeep
├── proteingym
│ └── .gitkeep
├── scannet
│ ├── metrics
│ │ └── .gitkeep
│ └── models
│ │ ├── optimizer
│ │ └── .gitkeep
│ │ └── transformer
│ │ └── .gitkeep
└── train
│ ├── metrics
│ └── .gitkeep
│ └── models
│ ├── gvp
│ └── .gitkeep
│ ├── optimizer
│ └── .gitkeep
│ └── msa_transformer
│ └── .gitkeep
├── data
├── train
│ └── cath
│ │ ├── msa
│ │ └── .gitkeep
│ │ ├── getCATH.sh
│ │ └── CATH_get_pdbids.py
└── test
│ ├── clinvar
│ ├── msa
│ │ └── .gitkeep
│ ├── raw
│ │ └── .gitkeep
│ └── structure
│ │ ├── raw
│ │ └── .gitkeep
│ │ └── cleaned
│ │ └── .gitkeep
│ ├── mave_val
│ ├── msa
│ │ └── .gitkeep
│ └── structure
│ │ └── cleaned
│ │ └── .gitkeep
│ ├── proteingym
│ ├── exp
│ │ └── .gitkeep
│ ├── msa
│ │ └── .gitkeep
│ ├── raw
│ │ └── .gitkeep
│ ├── structure
│ │ ├── raw
│ │ │ └── .gitkeep
│ │ └── cleaned
│ │ │ └── .gitkeep
│ ├── acknowledgements.txt
│ └── all_models_substitutions_Spearman_Uniprot_level.csv
│ ├── rocklin
│ ├── exp
│ │ └── .gitkeep
│ ├── msa
│ │ └── .gitkeep
│ ├── raw
│ │ └── .gitkeep
│ └── structure
│ │ ├── raw
│ │ └── .gitkeep
│ │ └── cleaned
│ │ └── .gitkeep
│ └── scannet
│ ├── labels
│ └── .gitkeep
│ ├── msa
│ └── .gitkeep
│ ├── raw
│ └── .gitkeep
│ └── structure
│ ├── raw
│ └── .gitkeep
│ └── cleaned
│ └── .gitkeep
├── src
├── models
│ ├── msa_transformer
│ │ ├── version.py
│ │ ├── constants.py
│ │ ├── acknowledgements.txt
│ │ ├── __init__.py
│ │ ├── model.py
│ │ ├── axial_attention.py
│ │ ├── data.py
│ │ ├── modules.py
│ │ └── multihead_attention.py
│ ├── gvp
│ │ ├── acknowledgements.txt
│ │ ├── models.py
│ │ ├── data.py
│ │ └── __init__.py
│ └── scannet
│ │ └── model.py
├── pdb_parser_scripts
│ ├── clean_pdbs.sh
│ ├── parse_pdbs.py
│ └── clean_pdb.py
├── merge_and_sort_msas.py
├── run_test_proteingym.py
├── run_test_mave.py
├── run_test_rocklin.py
├── run_test_clinvar.py
├── run_pipeline.py
└── run_pipeline_scannet.py
├── LICENSE
└── README.md
/output/clinvar/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/output/rocklin/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/train/cath/msa/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/output/mave_val/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/output/proteingym/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/clinvar/msa/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/clinvar/raw/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/mave_val/msa/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/proteingym/exp/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/proteingym/msa/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/proteingym/raw/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/rocklin/exp/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/rocklin/msa/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/rocklin/raw/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/scannet/labels/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/scannet/msa/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/scannet/raw/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/output/scannet/metrics/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/output/train/metrics/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/output/train/models/gvp/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/output/train/models/optimizer/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/clinvar/structure/raw/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/proteingym/structure/raw/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/rocklin/structure/raw/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/scannet/structure/raw/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/output/scannet/models/optimizer/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/output/scannet/models/transformer/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/clinvar/structure/cleaned/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/mave_val/structure/cleaned/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/proteingym/structure/cleaned/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/rocklin/structure/cleaned/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/test/scannet/structure/cleaned/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/output/train/models/msa_transformer/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/train/cath/getCATH.sh:
--------------------------------------------------------------------------------
1 | wget http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/chain_set.jsonl
2 | wget http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/chain_set_splits.json
3 |
--------------------------------------------------------------------------------
/src/models/msa_transformer/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
# Version string of this package.
version = "2.0.1"
7 |
--------------------------------------------------------------------------------
/src/models/msa_transformer/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
# fmt: off
# Sequence token alphabet: 25 upper-case letters followed by '.' and '-'.
# NOTE(review): presumably '.' and '-' are the alignment gap/insertion
# symbols -- confirm against the code that consumes this table.
proteinseq_toks = {
    'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
}
# fmt: on
11 |
--------------------------------------------------------------------------------
/src/models/msa_transformer/acknowledgements.txt:
--------------------------------------------------------------------------------
1 | Large parts of the code presented in this directory have been taken from:
2 | https://github.com/facebookresearch/esm/tree/main/esm
3 |
4 | @article{rao2021msa,
5 | author = {Rao, Roshan and Liu, Jason and Verkuil, Robert and Meier, Joshua and Canny, John F. and Abbeel, Pieter and Sercu, Tom and Rives, Alexander},
6 | title={MSA Transformer},
7 | year={2021},
8 | doi={10.1101/2021.02.12.430858},
9 | url={https://www.biorxiv.org/content/10.1101/2021.02.12.430858v1},
10 | journal={bioRxiv}
11 | }
12 |
--------------------------------------------------------------------------------
/data/test/proteingym/acknowledgements.txt:
--------------------------------------------------------------------------------
1 | The ProteinGym data files have been taken from:
2 | https://github.com/OATML-Markslab/ProteinGym
3 |
4 | @article{notin2023a,
5 | title={Proteingym: Large-scale benchmarks for protein design and fitness prediction},
6 | author={Notin, Pascal and Kollasch, Aaron W and Ritter, Daniel and van Niekerk, Lood and Paul, Steffanie and Spinner, Hansen and Rollins, Nathan and Shaw, Ada and Weitzman, Ruben and Frazer, Jonathan and others},
7 | journal={bioRxiv},
8 | pages={2023--12},
9 | year={2023},
10 | publisher={Cold Spring Harbor Laboratory}
11 | }
12 |
--------------------------------------------------------------------------------
/src/pdb_parser_scripts/clean_pdbs.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# Clean every raw PDB file of a data set with clean_pdb.py.
#
# Usage: clean_pdbs.sh <pdb_dir>
#   reads   <pdb_dir>/raw/*.pdb
#   writes  <pdb_dir>/cleaned/
#
# clean_pdb.py is looked up in ./pdb_parser_scripts/ relative to the current
# working directory, so the script must be started from the directory that
# contains pdb_parser_scripts/.

# Settings
counter=1
dir="$(pwd)/pdb_parser_scripts/"
pdb_dir="$1"

# Without an argument the paths below would resolve to /raw and /cleaned.
if [ -z "$pdb_dir" ]; then
    echo "Usage: $0 <pdb_dir>" >&2
    exit 1
fi

# Collect the input files in an array; nullglob makes an empty raw/ directory
# yield zero files instead of the literal, unexpanded pattern.
shopt -s nullglob
pdbs=("$pdb_dir"/raw/*.pdb)
shopt -u nullglob
n_pdbs=${#pdbs[@]}

# Create data directories
mkdir -p "$pdb_dir/cleaned"

# Clean pdbs
for pdb in "${pdbs[@]}"; do
    # Append '&> /dev/null' to the python call to silence clean_pdb.py.
    # A non-zero exit status only skips the current file.
    if python "$dir/clean_pdb.py" --pdb_file_in "$pdb" \
                                  --out_dir "$pdb_dir/cleaned/"; then
        echo "Successfully cleaned $pdb. $counter/$n_pdbs."
    else
        echo "Error when cleaning $pdb. Skipping.." >&2
    fi
    counter=$((counter+1))
done
29 |
--------------------------------------------------------------------------------
/data/train/cath/CATH_get_pdbids.py:
--------------------------------------------------------------------------------
"""Write the IDs of the CATH train/validation/test splits to text files.

Reads ``chain_set_splits.json`` from the current working directory and writes
``CATH_pdbids_train.txt``, ``CATH_pdbids_val.txt`` and ``CATH_pdbids_test.txt``
with one upper-cased ID per line.
"""
import json


def _strip_suffix(entries):
    """Return each entry without its last two characters, upper-cased."""
    # NOTE(review): presumably entries look like "1abc.A" so that dropping two
    # characters removes the chain suffix -- confirm against the JSON file.
    return [x[:-2].upper() for x in entries]


def _write_ids(path, ids):
    """Write one ID per line to *path*, closing the file even on error."""
    with open(path, "w") as tf:
        tf.writelines(element + "\n" for element in ids)


# Opening JSON file
with open('chain_set_splits.json') as f:
    data = json.load(f)

# Split and process
pdbids_train = _strip_suffix(data["train"])
pdbids_val = _strip_suffix(data["validation"])
pdbids_test = _strip_suffix(data["test"])

# Write
_write_ids("CATH_pdbids_train.txt", pdbids_train)
_write_ids("CATH_pdbids_val.txt", pdbids_val)
_write_ids("CATH_pdbids_test.txt", pdbids_test)
--------------------------------------------------------------------------------
/src/models/gvp/acknowledgements.txt:
--------------------------------------------------------------------------------
1 | Large parts of the code presented in this directory have been taken from:
2 | https://github.com/drorlab/gvp-pytorch/tree/main
3 |
4 | Citations:
5 | @inproceedings{
6 | jing2021learning,
7 | title={Learning from Protein Structure with Geometric Vector Perceptrons},
8 | author={Bowen Jing and Stephan Eismann and Patricia Suriana and Raphael John Lamarre Townshend and Ron Dror},
9 | booktitle={International Conference on Learning Representations},
10 | year={2021},
11 | url={https://openreview.net/forum?id=1YLJDvSx6J4}
12 | }
13 |
14 | @article{jing2021equivariant,
15 | title={Equivariant Graph Neural Networks for 3D Macromolecular Structure},
16 | author={Jing, Bowen and Eismann, Stephan and Soni, Pratham N and Dror, Ron O},
17 | journal={arXiv preprint arXiv:2106.03843},
18 | year={2021}
19 | }
20 |
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Linderstrøm-Lang Centre for Protein Science, University of Copenhagen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/models/scannet/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from tempfile import TemporaryDirectory
4 | from typing import Tuple
5 |
6 | import torch
7 | from torch import nn, Tensor
8 | from torch.nn import TransformerEncoder, TransformerEncoderLayer
9 | from torch.utils.data import dataset
10 |
class TransformerModel(nn.Module):
    """Transformer encoder followed by a linear layer with one output feature.

    The input first passes through ``PositionalEncoding``, then through a
    stack of ``nlayers`` ``TransformerEncoderLayer`` modules, and finally
    through ``nn.Linear(d_hid, 1)``.

    :param ntoken: accepted for interface compatibility; not used by the model
    :param nhead: number of attention heads in every encoder layer
    :param d_hid: feature size of the input; also used as the feed-forward
                  width of every encoder layer
    :param nlayers: number of stacked encoder layers
    :param dropout: dropout rate for the positional encoding and the encoder
    """

    def __init__(self, ntoken: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_hid, dropout)
        layer = TransformerEncoderLayer(
            d_model=d_hid, nhead=nhead, dim_feedforward=d_hid, dropout=dropout
        )
        self.transformer_encoder = TransformerEncoder(layer, num_layers=nlayers)
        self.linear = nn.Linear(d_hid, 1)

        self.init_weights()

    def init_weights(self) -> None:
        """Zero the output bias and draw the output weights from U(-0.1, 0.1)."""
        bound = 0.1
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-bound, bound)

    def forward(self, src: Tensor, src_mask: Tensor = None) -> Tensor:
        """Encode ``src`` and project every position to a single value.

        :param src: input tensor whose first dimension indexes positions
                    (assumes layout (seq_len, batch, d_hid) -- TODO confirm)
        :param src_mask: optional mask forwarded to the transformer encoder
        :return: tensor shaped like ``src`` with a last dimension of size 1
        """
        encoded = self.transformer_encoder(self.pos_encoder(src), src_mask)
        return self.linear(encoded)
33 |
class PositionalEncoding(nn.Module):
    """Adds fixed sine/cosine positional encodings to the input, then dropout.

    The table ``pe`` has shape ``(max_len, 1, d_model)`` and is registered as
    a buffer. Even feature indices hold sines and odd indices hold cosines.

    :param d_model: feature size of the input (even or odd)
    :param dropout: dropout rate applied after adding the encoding
    :param max_len: number of positions for which encodings are precomputed
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1024):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine slot fewer than sine slots, so
        # trim div_term; for even d_model the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """Add the encodings of the first ``x.size(0)`` positions to ``x``.

        :param x: tensor whose first dimension indexes positions and whose
                  last dimension has size ``d_model``
        :return: ``x`` plus positional encodings, after dropout
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
49 |
50 |
--------------------------------------------------------------------------------
/src/merge_and_sort_msas.py:
--------------------------------------------------------------------------------
1 | import string
2 | from Bio import SeqIO
3 | from typing import List, Tuple
4 | import numpy as np
5 | import glob
6 | import sys
7 | import subprocess
8 |
9 | # Initialize
# Translation table that deletes lowercase ASCII letters, "." and "*".
deletekeys = {char: None for char in string.ascii_lowercase + ".*"}
translation = str.maketrans(deletekeys)
14 |
15 |
def remove_insertions(sequence: str):
    """Returns `sequence` with every character listed in the module-level
    `translation` table deleted. Needed to load aligned sequences in an MSA."""
    stripped = sequence.translate(translation)
    return stripped
19 |
20 |
def read_msa(filename: str) -> List[Tuple[str, str]]:
    """Reads all sequences from a FASTA-formatted MSA file, automatically removes insertions.

    Returns one (description, sequence) tuple per record, in file order.
    """
    entries = []
    for record in SeqIO.parse(filename, "fasta"):
        aligned = remove_insertions(str(record.seq))
        entries.append((record.description, aligned))
    return entries
27 |
28 |
def hamming_distance(string1, string2):
    """Counts the positions at which the two sequences differ.

    Positions are paired with `zip`, so any surplus tail of the longer
    sequence is ignored.
    """
    mismatches = 0
    for left, right in zip(string1, string2):
        if left != right:
            mismatches += 1
    return mismatches
31 |
32 |
# Re-write every *.a3m file of the directory given as first command-line
# argument: insertions are removed, the first sequence (the query) is kept
# first, and the remaining sequences are sorted by increasing Hamming distance
# to the query. Results go to "<msa_dir>_tmp", which replaces <msa_dir> at the
# end.

# Initialize
msa_dir = sys.argv[1]
subprocess.run(["mkdir", "-p", f"{msa_dir}_tmp"], check=True)

# Load MSA files
msa_files = sorted(glob.glob(f"{msa_dir}/*.a3m"))

# Loop
for i, _file in enumerate(msa_files):
    print(f"Processing MSA: {i+1}/{len(msa_files)}")

    msa = read_msa(_file)

    # Entries are (description, sequence) tuples; the first one is the query.
    query = msa[0]
    seqs = msa[1:]
    ham_dists = np.zeros(len(seqs))

    for j, seq in enumerate(seqs):
        # Compare the sequences themselves: query and seq are both 2-tuples,
        # so comparing their tuple lengths would always succeed.
        if len(query[1]) != len(seq[1]):
            raise ValueError(
                f"Sequence '{seq[0]}' in {_file} has length {len(seq[1])}, "
                f"expected {len(query[1])}"
            )
        ham_dists[j] = hamming_distance(query[1], seq[1])

    # Rank indices
    rank_indices = np.argsort(ham_dists)

    # Remove query duplicates
    # NOTE(review): only the first sequence at distance 0 is dropped; further
    # exact copies of the query are kept -- confirm this is intended.
    if 0 in ham_dists:
        query_idx = np.argwhere(ham_dists == 0)[0]
        rank_indices = np.delete(
            rank_indices, np.argwhere(np.isin(rank_indices, query_idx))
        )

    # Construct new sorted MSA
    seqs_new = [seqs[idx] for idx in rank_indices]

    # Write to new file; the context manager closes it even if a write fails.
    with open(f"{msa_dir}_tmp/{query[0]}.a3m", "w") as outfile:
        outfile.write(f">{query[0]}\n")
        outfile.write(f"{query[1]}\n")

        for seq in seqs_new:
            outfile.write(f">{seq[0]}\n")
            outfile.write(f"{seq[1]}\n")

# Replace the original directory with the sorted copy. check=True stops the
# script if removing the original fails, so the move never runs against a
# directory that still exists.
subprocess.run(["rm", "-r", f"{msa_dir}"], check=True)
subprocess.run(["mv", f"{msa_dir}_tmp", msa_dir], check=True)
84 |
--------------------------------------------------------------------------------
/src/pdb_parser_scripts/parse_pdbs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import enum
3 | import os
4 | import sys
5 | import json
6 | import glob
7 | import Bio
8 | import Bio.PDB
9 | from Bio import SeqIO
10 | from Bio.SeqRecord import SeqRecord
11 | import pickle
12 |
def parse(pdb_dir):
    """Extract sequences and backbone coordinates from cleaned PDB files.

    Reads every ``{pdb_dir}/cleaned/*.pdb`` file (in sorted order), uses the
    first model of each structure and writes two files into ``pdb_dir``:

    * ``seqs.fasta``: one record per chain, named after the PDB file
    * ``coords.json``: a list with one dict per chain holding ``name``,
      ``coords`` and ``seq``

    A residue is included only if its N, CA, C and O atoms occur exactly in
    that order and the residue is not a hetero residue (id flag starting with
    ``"H_"``). Coordinates are stored as strings.

    NOTE(review): every chain of a file is stored under the same name, so a
    multi-chain file yields several entries with identical names.

    :param pdb_dir: directory containing the ``cleaned`` sub-directory
    """
    backbone_atom_names = ["N", "CA", "C", "O"]

    # Load PDBS
    pdb_filenames = sorted(glob.glob(f"{pdb_dir}/cleaned/*.pdb"))
    pdb_parser = Bio.PDB.PDBParser()

    # Initialize list of pdb dicts
    pdb_dict_list = []

    # Create fasta file; the context manager closes it even if parsing fails
    with open(f"{pdb_dir}/seqs.fasta", "w") as fh:

        # Loop over proteins
        for pdb_filename in pdb_filenames:

            # Parse structure with Biopython
            pdb_id = os.path.basename(pdb_filename)[:-4]  # strip ".pdb"
            structure = pdb_parser.get_structure(pdb_id, pdb_filename)
            first_model = structure.get_list()[0]
            first_model.child_list = sorted(first_model.child_list)  # Sort chains alphabetically

            # Iterate over chain,residue,atoms and extract features
            for chain in first_model:
                seq = []
                coords = []

                for residue in chain:
                    atom_names = []
                    backbone_coords = []

                    for atom in residue:
                        # Extract atom features
                        if atom.name in backbone_atom_names:
                            atom_names.append(atom.name)
                            backbone_coords.append([str(x) for x in atom.coord])

                    # Check that all backbone atoms are present, in order,
                    # and that the residue is not a HETATM record
                    if (atom_names == backbone_atom_names
                            and not residue._id[0].startswith("H_")):
                        # Add coordinates
                        coords.append(backbone_coords)

                        # Add residue to sequence
                        seq.append(Bio.PDB.Polypeptide.three_to_one(residue.resname))

                # Save coords+seq to dict
                pdb_dict = {
                    "name": pdb_id,
                    "coords": coords,
                    "seq": "".join(seq),
                }
                pdb_dict_list.append(pdb_dict)

                # Output seq to fasta
                fh.write(f">{pdb_dict['name']}\n")
                fh.write(pdb_dict["seq"])
                fh.write("\n")

    # Save total coord dict
    with open(f'{pdb_dir}/coords.json', 'w') as fp:
        json.dump(pdb_dict_list, fp)
78 |
if __name__ == '__main__':
    # The module defines parse(pdb_dir); the previously called parse_pdbs()
    # does not exist. Take the PDB directory from the first CLI argument.
    parse(sys.argv[1])
81 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A joint embedding of protein sequence and structure enables robust variant effect predictions
2 |
3 | ## Introduction
4 | This repository contains scripts and data to repeat the analyses in Blaabjerg et al.:
5 | [*"A joint embedding of protein sequence and structure enables robust variant effect predictions"*](https://www.biorxiv.org/content/10.1101/2023.12.14.571755v1).
6 |
7 | ## Execution
8 | Execute the pipeline using `src/run_pipeline.py`.
9 | This main script will call other scripts in the `src` directory to train, validate and test the SSEmb model as described in the paper.
10 |
11 | ## Requirements
12 | The code has been developed and tested in a Unix environment using the following packages:
13 | * `python==3.7.16`
14 | * `pytorch==1.13.1`
15 | * `pyg==2.2.0`
16 | * `pytorch-scatter==2.1.0`
17 | * `pytorch-cluster==1.6.0`
18 | * `fair-esm==2.0.0`
19 | * `numpy==1.21.6`
20 | * `pandas==1.3.5`
21 | * `biopython==1.79`
22 | * `openmm==7.6.0`
23 | * `pdbfixer==1.8.1`
24 | * `scipy==1.7.3`
25 | * `scikit-learn==1.0.2`
26 | * `tqdm==4.64.1`
27 | * `pytz==2022.7`
28 | * `matplotlib==3.2.2`
29 | * `mpl-scatter-density==0.7`
30 |
31 | ## Downloads
32 | Data related to the paper can be downloaded here: [https://zenodo.org/records/12798019](https://zenodo.org/records/12798019).
33 | The `data` directory contains the following subdirectories:
34 | * `train`
35 | * `model_weights`: Final weights for the SSEmb-MSATransformer and SSEmb-GVPGNN modules.
36 | * `optimizer_weights`: Parameters for the optimizer at time of early-stopping.
37 | * `msa`: MSAs for the proteins in the training set.
38 | * `mave_val`:
39 | * `msa`: MSAs for the proteins in the MAVE validation set.
40 | * `rocklin`:
41 | * `msa`: MSAs for the proteins in the mega-scale stability change test set.
42 | * `proteingym`:
43 | * `structure`: AlphaFold-2 generated structures used for the ProteinGym test set.
44 | * `msa`: MSAs for the proteins in the ProteinGym test set.
45 | * `scannet`:
46 | * `model_weights`: Final weights for the SSEmb downstream model trained on the ScanNet data set.
47 | * `optimizer_weights`: Parameters for the optimizer at time of early-stopping.
48 | * `msa`: MSAs for the proteins in the ScanNet data set.
49 | * `clinvar`:
50 | * `structure`: AlphaFold-2 generated structures used for the ClinVar test set.
51 | * `msa`: MSAs for the proteins in the ClinVar test set.
52 |
53 | A copy of this repository can be found on Zenodo here: [https://zenodo.org/doi/10.5281/zenodo.13765792](https://zenodo.org/doi/10.5281/zenodo.13765792).
54 |
55 | ## SSEmbLab webserver
56 | We have created an online Colab-based webserver for making SSEmb predictions called SSEmbLab. The webserver can be accessed [here](https://colab.research.google.com/github/KULL-Centre/_2023_Blaabjerg_SSEmb/blob/main/SSEmbLab.ipynb).
57 |
58 | ## License
59 | Source code and model weights are licensed under the MIT License.
60 |
61 | ## Acknowledgements
62 | We thank Milot Mirdita and the rest of the ColabFold Search team for help in setting up the Colab SSEmb webserver.
63 |
64 | Code for the original MSA Transformer was developed by the ESM team at Meta Research:
65 | [https://github.com/facebookresearch/esm](https://github.com/facebookresearch/esm).
66 |
67 | Code for the original GVP-GNN was developed by Jing et al:
68 | [https://github.com/drorlab/gvp-pytorch](https://github.com/drorlab/gvp-pytorch).
69 |
70 | ## Citation
71 | Please cite:
72 |
73 | *Lasse M. Blaabjerg, Nicolas Jonsson, Wouter Boomsma, Amelie Stein, Kresten Lindorff-Larsen (2023). A joint embedding of protein sequence and structure enables robust variant effect predictions. bioRxiv, 2023.12.*
74 |
75 | ```
76 | @article {Blaabjerg2023.12.14.571755,
77 | author = {Lasse M. Blaabjerg and Nicolas Jonsson and Wouter Boomsma and Amelie Stein and Kresten Lindorff-Larsen},
78 | title = {A joint embedding of protein sequence and structure enables robust variant effect predictions},
79 | elocation-id = {2023.12.14.571755},
80 | year = {2023},
81 | doi = {10.1101/2023.12.14.571755},
82 | URL = {https://www.biorxiv.org/content/early/2023/12/16/2023.12.14.571755},
83 | eprint = {https://www.biorxiv.org/content/early/2023/12/16/2023.12.14.571755.full.pdf},
84 | journal = {bioRxiv}
85 | }
86 | ```
87 |
--------------------------------------------------------------------------------
/src/models/msa_transformer/__init__.py:
--------------------------------------------------------------------------------
1 | import torch, functools
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch_geometric.nn import MessagePassing
5 | from torch_scatter import scatter_add
6 |
def tuple_sum(*args):
    '''
    Adds any number of tuples (s, V) position by position and returns
    the result as a tuple.
    '''
    return tuple(sum(components) for components in zip(*args))
12 |
def tuple_cat(*args, dim=-1):
    '''
    Joins any number of tuples (s, V) by concatenating all scalar
    tensors and all vector tensors separately.

    :param dim: concatenation dimension, interpreted relative to the
                scalar-channel tensors. A negative value is first
                converted to a non-negative index using the rank of
                the first scalar tensor, so `dim=-1` acts as
                `dim=-2` on the vector-channel tensors.
    '''
    scalar_rank = len(args[0][0].shape)
    dim = dim % scalar_rank
    scalars, vectors = zip(*args)
    joined_s = torch.cat(scalars, dim=dim)
    joined_v = torch.cat(vectors, dim=dim)
    return joined_s, joined_v
25 |
def tuple_index(x, idx):
    '''
    Applies the same index to both members of a tuple (s, V).

    :param idx: any object which can be used to index into a `torch.Tensor`
    '''
    scalars = x[0]
    vectors = x[1]
    return scalars[idx], vectors[idx]
33 |
def randn(n, dims, device="cpu"):
    '''
    Draws a random tuple (s, V) with standard-normal entries.
    The scalar tensor is sampled before the vector tensor.

    :param n: number of data points
    :param dims: tuple of dimensions (n_scalar, n_vector)

    :return: (s, V) with s.shape = (n, n_scalar) and
             V.shape = (n, n_vector, 3)
    '''
    n_scalar, n_vector = dims[0], dims[1]
    s = torch.randn(n, n_scalar, device=device)
    v = torch.randn(n, n_vector, 3, device=device)
    return s, v
46 |
def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
    '''
    L2 norm of a tensor along `axis`. The squared norm is clamped from
    below at `eps` before the square root is taken.

    :param sqrt: if `False`, returns the (clamped) square of the L2 norm
    '''
    squared = torch.sum(torch.square(x), axis, keepdims)
    squared = torch.clamp(squared, min=eps)
    if sqrt:
        return torch.sqrt(squared)
    return squared
55 |
def _split(x, nv):
    '''
    Splits a merged representation of (s, V) back into a tuple.
    Should be used only with `_merge(s, V)` and only if the tuple
    representation cannot be used.

    The last `3*nv` entries of the final dimension become the vector
    channels; everything before them are the scalar channels. `nv=0`
    yields all entries as scalars and an empty vector tensor.

    :param x: the `torch.Tensor` returned from `_merge`
    :param nv: the number of vector channels in the input to `_merge`

    :return: (s, V) with V.shape = x.shape[:-1] + (nv, 3)
    '''
    # Use an explicit split point: a negative slice start of `-3*nv`
    # degenerates to 0 for nv == 0 and would select the whole tensor.
    n_scalar = x.shape[-1] - 3*nv
    v = torch.reshape(x[..., n_scalar:], x.shape[:-1] + (nv, 3))
    s = x[..., :n_scalar]
    return s, v
68 |
def _merge(s, v):
    '''
    Packs a tuple (s, V) into one `torch.Tensor`: the last two
    dimensions of V are flattened into one of size 3 * n_vector and
    appended to the scalar channels along the last dimension.
    Should be used only if the tuple representation cannot be used.
    Use `_split(x, nv)` to reverse.
    '''
    leading_shape = v.shape[:-2]
    n_vector = v.shape[-2]
    flat_v = torch.reshape(v, (*leading_shape, 3*n_vector))
    return torch.cat([s, flat_v], -1)
78 |
class Dropout(nn.Module):
    '''
    Dropout for tuples (s, V): `nn.Dropout` is applied to the scalar
    channels and `_VDropout` to the vector channels, both with the
    same drop rate.
    Takes tuples (s, V) as input and as output.
    '''
    def __init__(self, drop_rate):
        super().__init__()
        self.sdropout = nn.Dropout(drop_rate)
        self.vdropout = _VDropout(drop_rate)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        '''
        # Exact type check: only a plain tensor takes the scalar-only path.
        if type(x) is not torch.Tensor:
            scalars, vectors = x
            return self.sdropout(scalars), self.vdropout(vectors)
        return self.sdropout(x)
99 |
class LayerNorm(nn.Module):
    '''
    LayerNorm for tuples (s, V): the scalar channels go through
    `nn.LayerNorm`; the vector channels are divided by the root of
    their mean squared L2 norm (mean taken over the vector channels).
    Takes tuples (s, V) as input and as output.
    '''
    def __init__(self, dims):
        super().__init__()
        self.s, self.v = dims
        self.scalar_norm = nn.LayerNorm(self.s)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        '''
        # Without vector channels the input is a bare scalar tensor.
        if not self.v:
            return self.scalar_norm(x)
        scalars, vectors = x
        sq_norms = _norm_no_nan(vectors, axis=-1, keepdims=True, sqrt=False)
        rms = torch.sqrt(torch.mean(sq_norms, dim=-2, keepdim=True))
        return self.scalar_norm(scalars), vectors / rms
122 |
--------------------------------------------------------------------------------
/src/models/gvp/models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from . import GVP, GVPConvLayer, LayerNorm, tuple_index
6 | from torch.distributions import Categorical
7 | from torch_scatter import scatter_mean
8 | import torch.utils.checkpoint as checkpoint
9 | import sys
10 | from torch_geometric.nn import aggr
11 |
class SSEmbGNN(torch.nn.Module):
    '''
    GVP-GNN that combines a protein structure graph with per-residue
    MSA embeddings and sequence information.

    The network embeds node and edge features with GVPs, runs a stack of
    encoder `GVPConvLayer`s, concatenates an embedding of the sequence
    onto the scalar edge features and a projection of the MSA embedding
    onto the scalar node features, runs a stack of decoder
    `GVPConvLayer`s, and finally maps every node to 20 output logits.

    Should be used with `gvp.data.ProteinGraphDataset`, or with generators
    of `torch_geometric.data.Batch` objects with the same attributes.

    :param node_in_dim: node dimensions in input graph, should be
                        (6, 3) if using original features
    :param node_h_dim: node dimensions to use in GVP-GNN layers
    :param edge_in_dim: edge dimensions in input graph, should be
                        (32, 1) if using original features
    :param edge_h_dim: edge dimensions to embed to before use
                       in GVP-GNN layers
    :param num_layers: number of GVP-GNN layers in each of the encoder
                       and decoder modules
    :param drop_rate: rate to use in all dropout layers
    :param vector_gate: forwarded to every `GVP` / `GVPConvLayer`
    :param msa_emb_dim: width of the incoming MSA embedding; the default
                        768 is the value that was previously hard-coded
    '''
    def __init__(self, node_in_dim, node_h_dim,
                 edge_in_dim, edge_h_dim,
                 num_layers=4, drop_rate=0.0, vector_gate=True,
                 msa_emb_dim=768):

        super(SSEmbGNN, self).__init__()

        # Get correct dimensions
        self.W_v = nn.Sequential(
            GVP(node_in_dim, node_h_dim, activations=(None, None), vector_gate=vector_gate),
            LayerNorm(node_h_dim)
        )
        self.W_e = nn.Sequential(
            GVP(edge_in_dim, edge_h_dim, activations=(None, None), vector_gate=vector_gate),
            LayerNorm(edge_h_dim)
        )

        # Encode
        self.encoder_layers = nn.ModuleList(
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate, vector_gate=vector_gate)
            for _ in range(num_layers))

        # Sequence token embedding: 21 tokens -> 21 features.
        # NOTE(review): presumably 20 amino acids plus one extra token
        # (mask/unknown) -- confirm against the data pipeline.
        self.W_S = nn.Embedding(21, 21)

        # Project the MSA embedding to the scalar node width
        self.W_M = nn.Sequential(
            nn.Linear(msa_emb_dim, node_h_dim[0]),
        )

        # Fuse [structure scalars | projected MSA] back to node_h_dim[0]
        self.W_decoder_in = nn.Sequential(
            nn.Linear(node_h_dim[0]*2, node_h_dim[0]*2),
            nn.LayerNorm(node_h_dim[0]*2),
            nn.ReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(node_h_dim[0]*2, node_h_dim[0]),
            nn.LayerNorm(node_h_dim[0]),
        )

        # Decode
        node_h_dim = (node_h_dim[0], node_h_dim[1])
        # Decoder edges carry the 21 sequence-embedding features as well
        edge_h_dim = (edge_h_dim[0] + 21, edge_h_dim[1])

        self.decoder_layers = nn.ModuleList(
            GVPConvLayer(node_h_dim, edge_h_dim,
                         drop_rate=drop_rate, vector_gate=vector_gate)
            for _ in range(num_layers))

        # Out
        self.W_out = GVP(node_h_dim, (20, 0), activations=(None, None), vector_gate=vector_gate)

    def forward(self, h_V, edge_index, h_E, msa_emb, seq, get_emb=False):
        '''
        Forward pass to be used at train-time, or evaluating likelihood.

        :param h_V: tuple (s, V) of node embeddings
        :param edge_index: `torch.Tensor` of shape [2, num_edges]
        :param h_E: tuple (s, V) of edge embeddings
        :param msa_emb: per-node MSA embedding `torch.Tensor` whose last
                        dimension is `msa_emb_dim`
        :param seq: int `torch.Tensor` of shape [num_nodes]
        :param get_emb: if truthy, return the scalar node features of the
                        last decoder layer instead of the logits
        :return: output of `W_out` (20 scalar channels per node), or the
                 scalar node embeddings when `get_emb` is set
        '''
        # Run through GVP to get correct hidden dimensions
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)

        # Message passing
        # Encoding
        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)

        # Add sequence info: every edge receives the sequence embedding
        # of the node indexed by edge_index[0]
        h_S = self.W_S(seq)
        h_S = h_S[edge_index[0]]
        h_E = (torch.cat([h_E[0], h_S], dim=-1), h_E[1])

        # Add MSA info to the scalar node channels; vectors are untouched
        h_M = self.W_M(msa_emb)
        h_V = (self.W_decoder_in(torch.cat([h_V[0], h_M], dim=-1)), h_V[1])

        # Decoding
        for layer in self.decoder_layers:
            h_V = layer(h_V, edge_index, h_E)

        # Out
        if get_emb:
            return h_V[0]
        return self.W_out(h_V)
124 |
--------------------------------------------------------------------------------
/src/run_test_proteingym.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import subprocess
3 | import glob
4 | import re
5 | import torch
6 | import models.gvp.data, models.gvp.models
7 | import json
8 | import torch_geometric
9 | import esm
10 | import pandas as pd
11 | import random
12 | import torch.multiprocessing
13 |
14 | torch.multiprocessing.set_sharing_strategy("file_system")
15 | from models.msa_transformer.model import MSATransformer
16 | from models.gvp.models import SSEmbGNN
17 | from helpers import (
18 | read_msa,
19 | loop_pred,
20 | )
21 | from visualization import plot_proteingym
22 | import pdb_parser_scripts.parse_pdbs as parse_pdbs
23 | import torch.utils.data
24 | from collections import OrderedDict
25 | from ast import literal_eval
26 | import subprocess
27 | import pickle
28 |
def _strip_data_parallel_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with a leading `prefix` removed from its keys."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def test(run_name, epoch, msa_row_attn_mask=True, device=None):
    """
    Evaluate a trained MSA Transformer + SSEmbGNN pair on ProteinGym.

    Loads the pickled test data, restores both checkpoints of
    `run_name`/`epoch`, runs `loop_pred`, writes the per-position and
    per-variant score tables to `../output/proteingym/` and finally calls
    `plot_proteingym` on the per-variant table.

    :param run_name: run identifier used in checkpoint and output names
    :param epoch: epoch identifier used in checkpoint and output names
    :param msa_row_attn_mask: forwarded to `loop_pred`
    :param device: device the models are moved to; also used as
                   `map_location` when loading the checkpoints
    :raises KeyError: if a DMS variant position has no prediction
    """
    # Load data and dict of variant positions
    # NOTE(review): pickle.load can execute arbitrary code; these files
    # must come from a trusted pre-processing step.
    with open("../data/test/proteingym/data_with_msas.pkl", "rb") as fp:
        data = pickle.load(fp)

    with open("../data/test/proteingym/variant_pos_dict.pkl", "rb") as fp:
        variant_pos_dict = pickle.load(fp)

    # Load DMS data
    df_dms = pd.read_csv("../data/test/proteingym/exp/dms.csv")

    # Convert to graph data sets
    testset = models.gvp.data.ProteinGraphData(data)
    letter_to_num = testset.letter_to_num

    # Load MSA Transformer
    _, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    model_msa = MSATransformer(msa_alphabet)
    model_msa = model_msa.to(device)
    msa_batch_converter = msa_alphabet.get_batch_converter()

    # map_location lets checkpoints saved on GPU load on the requested
    # device (None keeps torch's default behaviour). Only a leading
    # "module." prefix is stripped from the keys.
    state_dict_msa = torch.load(
        f"../output/train/models/msa_transformer/{run_name}_msa_transformer_{epoch}.pt",
        map_location=device,
    )
    model_msa.load_state_dict(_strip_data_parallel_prefix(state_dict_msa))

    # Load GVP
    node_dim = (256, 64)
    edge_dim = (32, 1)
    model_gvp = SSEmbGNN((6, 3), node_dim, (32, 1), edge_dim)
    model_gvp = model_gvp.to(device)

    state_dict_gvp = torch.load(
        f"../output/train/models/gvp/{run_name}_gvp_{epoch}.pt",
        map_location=device,
    )
    model_gvp.load_state_dict(_strip_data_parallel_prefix(state_dict_gvp))

    # Initialize data loader
    test_loader = torch_geometric.loader.DataLoader(
        testset, batch_size=1, shuffle=False
    )

    # Call test
    model_msa.eval()
    model_gvp.eval()

    with torch.no_grad():
        pred_list, _ = loop_pred(
            model_msa,
            model_gvp,
            msa_batch_converter,
            test_loader,
            variant_pos_dict,
            data,
            letter_to_num,
            msa_row_attn_mask=msa_row_attn_mask,
            device=device,
        )

    # Transform results into df
    df_ml = pd.DataFrame(pred_list, columns=["dms_id", "variant_pos", "score_ml_pos"])

    # Save
    df_ml.to_csv(f"../output/proteingym/df_ml_{run_name}_{epoch}.csv", index=False)

    # Load (round trip through CSV; the score lists are parsed back
    # from their string representation)
    df_ml = pd.read_csv(
        f"../output/proteingym/df_ml_{run_name}_{epoch}.csv",
        converters=dict(score_ml_pos=literal_eval),
    )

    # Index the per-position scores once so each variant is a dict lookup
    # instead of a scan of df_ml. setdefault keeps the first row for a
    # duplicated (dms_id, position) pair, as `.values[0]` did.
    score_lookup = {}
    for ml_dms_id, ml_pos, ml_scores in zip(
        df_ml["dms_id"], df_ml["variant_pos"], df_ml["score_ml_pos"]
    ):
        score_lookup.setdefault((ml_dms_id, ml_pos), ml_scores)

    # Compute score_ml from nlls
    dms_variant_list = df_dms.values.tolist()
    n_variants = len(dms_variant_list)
    for i, row in enumerate(dms_variant_list):
        dms_id = row[0]
        print(
            f"Computing score for assay {dms_id} variant: {i+1}/{n_variants}"
        )
        score_ml = 0.0

        # A variant set is ":"-separated; single-substitution scores are summed
        for variant in row[1].split(":"):
            pos = int(re.findall(r"\d+", variant)[0])
            mt = letter_to_num[variant[-1]]
            score_ml += float(score_lookup[(dms_id, pos)][mt])
        row.append(score_ml)

    df_total = pd.DataFrame(
        dms_variant_list, columns=["dms_id", "variant_set", "score_dms", "score_ml"]
    )

    # Save
    df_total.to_csv(f"../output/proteingym/df_total_{run_name}_{epoch}.csv", index=False)

    # Load
    df_total = pd.read_csv(f"../output/proteingym/df_total_{run_name}_{epoch}.csv")

    # Compute correlations
    plot_proteingym(df_total, run_name, epoch)
143 |
--------------------------------------------------------------------------------
/src/run_test_mave.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import subprocess
3 | import glob
4 | import re
5 | import torch
6 | import models.gvp.data, models.gvp.models
7 | import json
8 | import torch_geometric
9 | import esm
10 | import pandas as pd
11 | import random
12 | import torch.multiprocessing
13 | import pickle
14 |
15 | torch.multiprocessing.set_sharing_strategy("file_system")
16 | from models.msa_transformer.model import MSATransformer
17 | from models.gvp.models import SSEmbGNN
18 | from statistics import mean
19 | from helpers import (
20 | read_msa,
21 | mave_val_pdb_to_prot,
22 | loop_pred,
23 | save_df_to_prism,
24 | get_prism_corr,
25 | get_prism_corr_all,
26 | )
27 | import pdb_parser_scripts.parse_pdbs as parse_pdbs
28 | import torch.utils.data
29 | from collections import OrderedDict
30 | from ast import literal_eval
31 | import subprocess
32 | import shutil
33 |
def _strip_data_parallel_prefix(state_dict, prefix="module."):
    """Return a copy of `state_dict` with a leading `prefix` removed from its keys."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def test(run_name, epoch, msa_row_attn_mask=True, get_only_ssemb_metrics=True, device=None):
    """
    Evaluate a trained MSA Transformer + SSEmbGNN pair on the MAVE
    validation set.

    Loads the pickled validation data, restores both checkpoints of
    `run_name`/`epoch`, runs `loop_pred`, writes score tables to
    `../output/mave_val/`, exports every assay with `save_df_to_prism`
    and computes correlations.

    :param run_name: run identifier used in checkpoint and output names
    :param epoch: epoch identifier used in the checkpoint names
    :param msa_row_attn_mask: forwarded to `loop_pred`
    :param get_only_ssemb_metrics: if truthy, compute only
        `get_prism_corr` per assay; otherwise call `get_prism_corr_all`
        and print the mean correlations for SSEmb, GEMME and Rosetta
    :param device: device the models are moved to; also used as
                   `map_location` when loading the checkpoints
    :return: `(mean correlation, acc_mean)` when
        `get_only_ssemb_metrics` is truthy, otherwise `None`
    """
    # Load data and dict of variant positions
    # NOTE(review): pickle.load can execute arbitrary code; these files
    # must come from a trusted pre-processing step.
    with open("../data/test/mave_val/data_with_msas.pkl", "rb") as fp:
        data = pickle.load(fp)

    with open("../data/test/mave_val/variant_pos_dict.pkl", "rb") as fp:
        variant_pos_dict = pickle.load(fp)

    # Convert to graph data sets
    testset = models.gvp.data.ProteinGraphData(data)
    letter_to_num = testset.letter_to_num

    # Load MSA Transformer
    _, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    model_msa = MSATransformer(msa_alphabet)
    model_msa = model_msa.to(device)
    msa_batch_converter = msa_alphabet.get_batch_converter()

    # map_location lets checkpoints saved on GPU load on the requested
    # device (None keeps torch's default behaviour). Only a leading
    # "module." prefix is stripped from the keys.
    state_dict_msa = torch.load(
        f"../output/train/models/msa_transformer/{run_name}_msa_transformer_{epoch}.pt",
        map_location=device,
    )
    model_msa.load_state_dict(_strip_data_parallel_prefix(state_dict_msa))

    # Load GVP
    node_dim = (256, 64)
    edge_dim = (32, 1)
    model_gvp = SSEmbGNN((6, 3), node_dim, (32, 1), edge_dim)
    model_gvp = model_gvp.to(device)

    state_dict_gvp = torch.load(
        f"../output/train/models/gvp/{run_name}_gvp_{epoch}.pt",
        map_location=device,
    )
    model_gvp.load_state_dict(_strip_data_parallel_prefix(state_dict_gvp))

    # Initialize data loader
    test_loader = torch_geometric.loader.DataLoader(
        testset, batch_size=1, shuffle=False
    )

    # Call test
    model_msa.eval()
    model_gvp.eval()

    with torch.no_grad():
        pred_list, acc_mean = loop_pred(
            model_msa,
            model_gvp,
            msa_batch_converter,
            test_loader,
            variant_pos_dict,
            data,
            letter_to_num,
            msa_row_attn_mask=msa_row_attn_mask,
            device=device,
        )

    # Transform results into df
    df_ml = pd.DataFrame(pred_list, columns=["dms_id", "variant_pos", "score_ml_pos"])

    # Save
    df_ml.to_csv(f"../output/mave_val/df_ml_{run_name}.csv", index=False)

    # Load (round trip through CSV; the score lists are parsed back
    # from their string representation)
    df_ml = pd.read_csv(
        f"../output/mave_val/df_ml_{run_name}.csv",
        converters=dict(score_ml_pos=literal_eval),
    )

    # Compute score_ml from nlls
    pred_list_scores = []
    # Letters ordered by their numeric code, with the last one dropped.
    # NOTE(review): presumably the 20 amino acids with a trailing
    # unknown/mask token removed -- confirm against `letter_to_num`.
    mt_list = sorted(letter_to_num, key=letter_to_num.get)[:-1]

    for entry in data:
        dms_id = entry["name"]
        df_dms_id = df_ml[df_ml["dms_id"] == dms_id]

        # Each wild-type residue and each position is repeated 20 times
        # so that index i lines up with one (position, mutant) pair.
        wt = [residue for residue in entry["seq"] for _ in range(20)]
        pos = [p for p in list(df_dms_id["variant_pos"]) for _ in range(20)]
        mt = mt_list * len(entry["seq"])
        score_ml = [
            score for scores in list(df_dms_id["score_ml_pos"]) for score in scores
        ]

        rows = [
            [dms_id, wt[i] + str(pos[i]) + mt[i], score_ml[i]] for i in range(len(mt))
        ]
        pred_list_scores += rows

    # Transform results into df
    df_ml_scores = pd.DataFrame(
        pred_list_scores, columns=["dms_id", "variant", "score_ml"]
    )

    # Save
    df_ml_scores.to_csv(f"../output/mave_val/df_ml_scores_{run_name}.csv", index=False)

    # Load
    df_ml_scores = pd.read_csv(f"../output/mave_val/df_ml_scores_{run_name}.csv")

    # Save results to PRISM format
    dms_ids = df_ml_scores["dms_id"].unique()
    for dms_id in dms_ids:
        df_dms = df_ml_scores[df_ml_scores["dms_id"] == dms_id]
        save_df_to_prism(df_dms, run_name, dms_id)

    # Compute metrics
    if get_only_ssemb_metrics:
        corrs = [get_prism_corr(dms_id, run_name) for dms_id in dms_ids]
        return mean(corrs), acc_mean

    corrs_ssemb = []
    corrs_gemme = []
    corrs_rosetta = []
    for dms_id in dms_ids:
        corrs = get_prism_corr_all(dms_id, run_name)
        corrs_ssemb.append(corrs[0])
        corrs_gemme.append(corrs[1])
        corrs_rosetta.append(corrs[2])
    print(f"SSEmb: Mean MAVE spearman correlation: {mean(corrs_ssemb):.3f}")
    print(f"GEMME: Mean MAVE spearman correlation: {mean(corrs_gemme):.3f}")
    print(f"Rosetta: Mean MAVE spearman correlation: {mean(corrs_rosetta):.3f}")
171 |
--------------------------------------------------------------------------------
/src/models/msa_transformer/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from .modules import (
10 | AxialTransformerLayer,
11 | LearnedPositionalEmbedding,
12 | RobertaLMHead,
13 | ESM1bLayerNorm,
14 | ContactPredictionHead,
15 | )
16 |
17 | from .axial_attention import RowSelfAttention, ColumnSelfAttention
18 | from .__init__ import LayerNorm, tuple_index
19 |
class MSATransformer(nn.Module):
    """
    MSA Transformer with fixed architecture hyper-parameters
    (12 layers, embedding width 768, 12 attention heads).

    :param alphabet: optional tokenizer alphabet. When given, the
        vocabulary size, special-token indices and BOS/EOS flags are read
        from it; when omitted the built-in defaults are used (33 tokens,
        padding 1, mask 32, cls 0, eos 2, prepend BOS, no EOS).
        NOTE(review): the defaults presumably mirror the ESM MSA
        alphabet -- confirm.
    """

    def __init__(self, alphabet=None):
        super().__init__()
        self.num_layers = 12
        self.embed_dim = 768
        self.logit_bias = True
        self.ffn_embed_dim = 3072
        self.attention_heads = 12
        self.dropout = 0.1
        self.attention_dropout = 0.1
        self.activation_dropout = 0.1
        self.max_tokens_per_msa = 2 ** 22
        self.max_positions = 1024
        self.embed_positions_msa = False

        # Callers construct the model as `MSATransformer(msa_alphabet)`,
        # so the alphabet is accepted; it stays optional so that
        # `MSATransformer()` keeps working.
        if alphabet is None:
            self.alphabet_size = 33
            self.padding_idx = 1
            self.mask_idx = 32
            self.cls_idx = 0
            self.eos_idx = 2
            self.prepend_bos = True
            self.append_eos = False
        else:
            self.alphabet_size = len(alphabet)
            self.padding_idx = alphabet.padding_idx
            self.mask_idx = alphabet.mask_idx
            self.cls_idx = alphabet.cls_idx
            self.eos_idx = alphabet.eos_idx
            self.prepend_bos = alphabet.prepend_bos
            self.append_eos = alphabet.append_eos

        self.embed_tokens = nn.Embedding(
            self.alphabet_size, self.embed_dim, padding_idx=self.padding_idx
        )

        # Learned embedding over the alignment (row) axis, up to 1024 rows
        self.msa_position_embedding = nn.Parameter(
            0.01 * torch.randn(1, 1024, 1, self.embed_dim),
            requires_grad=True,
        )

        self.dropout_module = nn.Dropout(self.dropout)
        self.layers = nn.ModuleList(
            [
                AxialTransformerLayer(
                    self.embed_dim,
                    self.ffn_embed_dim,
                    self.attention_heads,
                    self.dropout,
                    self.attention_dropout,
                    self.activation_dropout,
                    self.max_tokens_per_msa,
                )
                for _ in range(self.num_layers)
            ]
        )

        self.contact_head = ContactPredictionHead(
            self.num_layers * self.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )
        self.embed_positions = LearnedPositionalEmbedding(
            self.max_positions,
            self.embed_dim,
            self.padding_idx,
        )
        self.emb_layer_norm_before = ESM1bLayerNorm(self.embed_dim)
        self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)
        # The LM head shares its projection weight with the token embedding
        self.lm_head = RobertaLMHead(
            embed_dim=self.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def forward(self, tokens, repr_layers=(), need_head_weights=False, return_contacts=False, self_row_attn_mask=None):
        """
        Run the model on a batch of tokenized MSAs.

        :param tokens: integer tensor of shape (batch, alignments, seqlen)
        :param repr_layers: iterable of layer indices (0 = embeddings,
            k = output of layer k) whose hidden states are returned
        :param need_head_weights: also return row/column attention maps
        :param return_contacts: also return contact predictions
            (implies `need_head_weights`)
        :param self_row_attn_mask: optional mask forwarded to every layer
            after being padded by one leading entry along its last two
            dimensions
        :return: dict with "logits" and "representations", plus
            "col_attentions"/"row_attentions" when head weights are
            requested and "contacts" when `return_contacts` is set
        :raises RuntimeError: if the MSA has more rows than the MSA
            position embedding supports
        """
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 3
        batch_size, num_alignments, seqlen = tokens.size()
        padding_mask = tokens.eq(self.padding_idx)  # B, R, C
        if not padding_mask.any():
            padding_mask = None

        x = self.embed_tokens(tokens)
        x += self.embed_positions(tokens.view(batch_size * num_alignments, seqlen)).view(x.size())
        if self.msa_position_embedding is not None:
            # Depth limit is read from the parameter itself (1024 rows)
            max_depth = self.msa_position_embedding.size(1)
            if x.size(1) > max_depth:
                raise RuntimeError(
                    "Using model with MSA position embedding trained on maximum MSA "
                    f"depth of {max_depth}, but received {x.size(1)} alignments."
                )
            x += self.msa_position_embedding[:, :num_alignments]

        x = self.emb_layer_norm_before(x)

        x = self.dropout_module(x)

        if padding_mask is not None:
            # Zero the embeddings at padded positions
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            row_attn_weights = []
            col_attn_weights = []

        # Mask row attention from tokens
        if self_row_attn_mask is not None:
            # Prepend one True entry along each of the last two dims.
            # NOTE(review): presumably accounts for the prepended BOS
            # token -- confirm against the caller.
            self_row_attn_mask = nn.functional.pad(self_row_attn_mask, (1, 0, 1, 0),
                                                   value=True)

        # B x R x C x D -> R x C x B x D
        x = x.permute(1, 2, 0, 3)

        for layer_idx, layer in enumerate(self.layers):
            x = layer(
                x,
                self_attn_padding_mask=padding_mask,
                self_row_attn_mask=self_row_attn_mask,
                need_head_weights=need_head_weights,
            )
            if need_head_weights:
                x, col_attn, row_attn = x
                # H x C x B x R x R -> B x H x C x R x R
                col_attn_weights.append(col_attn.permute(2, 0, 1, 3, 4))
                # H x B x C x C -> B x H x C x C
                row_attn_weights.append(row_attn.permute(1, 0, 2, 3))
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.permute(2, 0, 1, 3)

        x = self.emb_layer_norm_after(x)
        x = x.permute(2, 0, 1, 3)  # R x C x B x D -> B x R x C x D

        # last hidden representation should have layer norm applied
        if (layer_idx + 1) in repr_layers:
            hidden_representations[layer_idx + 1] = x
        x = self.lm_head(x)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # col_attentions: B x L x H x C x R x R
            col_attentions = torch.stack(col_attn_weights, 1)
            # row_attentions: B x L x H x C x C
            row_attentions = torch.stack(row_attn_weights, 1)
            result["col_attentions"] = col_attentions
            result["row_attentions"] = row_attentions
            if return_contacts:
                contacts = self.contact_head(tokens, row_attentions)
                result["contacts"] = contacts

        return result

    def predict_contacts(self, tokens):
        """Return only the "contacts" entry of a forward pass on `tokens`."""
        return self(tokens, return_contacts=True)["contacts"]

    def max_tokens_per_msa_(self, value: int) -> None:
        """The MSA Transformer automatically batches attention computations when
        gradients are disabled to allow you to pass in larger MSAs at test time than
        you can fit in GPU memory. The threshold this class configures by default is
        2**22 tokens (`self.max_tokens_per_msa`). You can set this value to infinity
        to disable this behavior.

        This method overwrites `max_tokens_per_msa` on every row and column
        self-attention module of the model.
        """
        for module in self.modules():
            if isinstance(module, (RowSelfAttention, ColumnSelfAttention)):
                module.max_tokens_per_msa = value
181 |
--------------------------------------------------------------------------------
/src/pdb_parser_scripts/clean_pdb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import subprocess
4 | import sys
5 | import tempfile
6 | from io import BytesIO, StringIO
7 |
8 | import Bio.PDB
9 | import Bio.PDB.Polypeptide
10 | import Bio.SeqIO
11 | import pdbfixer
12 | import simtk
13 | import openmm
14 | import openmm.app
15 |
16 |
# Module-level Biopython helpers shared by the functions in this file.
# PDBIO writes structures to file handles; PDB_PARSER reads PDB files.
# NOTE(review): PERMISSIVE=0 presumably selects strict parsing (errors
# instead of warnings) -- confirm against the Biopython documentation.
PDBIO = Bio.PDB.PDBIO()
PDB_PARSER = Bio.PDB.PDBParser(PERMISSIVE=0)
19 |
20 |
class NonHetSelector(Bio.PDB.Select):
    """Output filter that drops HET residues and keeps one conformation
    of disordered atoms.

    Residues are kept when their name appears in pdbfixer's substitution
    table, either as a non-standard name (key) or as the standard name it
    maps to (value). Atoms are kept when they are not disordered or carry
    altloc "A"/"1", and their id starts with one of C, H, N, O, S, P.
    """

    def accept_residue(self, residue):
        substitutions = pdbfixer.pdbfixer.substitutions
        resname = residue.get_resname()
        is_standard = resname in substitutions.values()
        # Non-standard residues are accepted too: they are converted later
        is_nonstandard = resname in substitutions
        return is_standard or is_nonstandard

    def accept_atom(self, atom):
        keeps_conformation = (
            not atom.is_disordered() or atom.get_altloc() in ("A", "1")
        )
        return keeps_conformation and atom.id[0] in ("C", "H", "N", "O", "S", "P")
42 |
43 |
class PDBFixerResIdentifiabilityIssue(Exception):
    """Raised when a chain's residue count differs before and after PDBFixer,
    so residues can no longer be matched one-to-one for renumbering."""
46 |
47 |
def _step_3_pdbfixer(first_model, temp3):
    """Clear altloc flags, write the model to `temp3` and repair it with PDBFixer.

    Parameters
    ----------
    first_model:
        Biopython model whose atoms are modified in place.
    temp3:
        Open, named, writable text file that receives the model.

    Returns
    -------
    tuple
        `(temp3, fixer)` where `fixer` is the `pdbfixer.PDBFixer` built
        from `temp3.name` after all repair steps have run.
    """
    # Blank the alternate-location indicator on every atom
    all_atoms = (
        atom for chain in first_model for residue in chain for atom in residue
    )
    for atom in all_atoms:
        atom.set_altloc(" ")

    PDBIO.set_structure(first_model)
    PDBIO.save(temp3)
    # Flush so PDBFixer sees the complete file when it opens it by name
    temp3.flush()

    # Use PDBFixer to fix common PDB errors
    fixer = pdbfixer.PDBFixer(temp3.name)
    repair_steps = (
        fixer.findMissingResidues,
        fixer.findNonstandardResidues,
        fixer.replaceNonstandardResidues,
        fixer.findMissingAtoms,
        fixer.addMissingAtoms,
    )
    for run_step in repair_steps:
        run_step()
    fixer.addMissingHydrogens(7.0)
    return temp3, fixer
66 |
67 |
def _step_4_fix_numbering(fixer, temp3, temp4):
    """Write the fixed structure to `temp4` and restore the original chain/residue ids.

    PDBFixer output is written with `keepIds=False`, so chain ids and
    residue ids (including insertion codes) are copied back from the
    pre-fix structure stored in `temp3`.

    Parameters
    ----------
    fixer:
        `pdbfixer.PDBFixer` holding the repaired topology and positions.
    temp3:
        Named file containing the structure as it was before PDBFixer.
    temp4:
        Open, named, writable text file that receives the PDBFixer output.

    Returns
    -------
    Biopython structure parsed from `temp4` with ids restored.

    Raises
    ------
    PDBFixerResIdentifiabilityIssue
        If a chain has a different number of residues before and after.
    """
    # Use the `openmm` package this file already imports; the legacy
    # `simtk.openmm` submodule is not guaranteed to be loaded by a bare
    # `import simtk`.
    openmm.app.PDBFile.writeFile(
        fixer.topology, fixer.positions, temp4, keepIds=False
    )
    temp4.flush()
    # Fix IDs manually since pdbfixer does not preserve insertion codes
    structure_before = PDB_PARSER.get_structure(temp3.name, temp3.name)
    structure_after = PDB_PARSER.get_structure(temp4.name, temp4.name)
    residues_before = [chain.get_list() for chain in structure_before[0]]
    residues_after = [chain.get_list() for chain in structure_after[0]]
    chain_counter = ""
    for i, chain in enumerate(structure_before[0]):
        try:
            if structure_after[0].get_list()[i].id != chain.id:
                try:
                    # HACK BECAUSE OF https://github.com/biopython/biopython/issues/1551
                    # Essentially, a new change in biopython prevents you from changing the
                    # id to an already existing id which broke this initial script.
                    # Therefore, we first rename the chain that currently holds the wanted
                    # id to "chain_counter", which will never look like a canonical chainid.
                    structure_after[0][chain.id].id = chain_counter
                    chain_counter += "KK"
                except KeyError:
                    # No other chain holds the wanted id; nothing to move away
                    pass
                structure_after[0].get_list()[i].id = chain.id
            if len(residues_before[i]) != len(residues_after[i]):
                raise PDBFixerResIdentifiabilityIssue()

        # When exceeding chainid Z, pdbfixer has discarded it, whereas biopython has not.
        # For simplicity, we just discard it as well and pretend it does not exist.
        # This is a very rare instance and will likely never be a problem unless you
        # are extremely unlucky to work with huge proteins where you care about the
        # truncation.
        except IndexError:
            continue

        counter = 99999  # A large residue number that will never exist in a pdb.
        for res1, res2 in zip(residues_before[i], residues_after[i]):
            # Residue names must agree, either directly or through
            # pdbfixer's non-standard -> standard substitution table
            assert (
                res1.get_resname().strip() == res2.get_resname().strip()
                or pdbfixer.pdbfixer.substitutions[res1.get_resname()].strip()
                == res2.get_resname().strip()
            )
            if res2.id != res1.id:
                try:
                    # Similar issue as previous hack https://github.com/biopython/biopython/issues/1551
                    # Move any residue already holding the wanted id out of the way
                    structure_after[0][chain.get_id()][res1.id].id = (
                        " ",
                        counter,
                        " ",
                    )
                except KeyError:
                    pass
                res2.id = res1.id
                counter += 1

    return structure_after
136 |
137 |
def clean_pdb(pdb_input_filename: str, out_dir: str):
    """
    Function to clean pdbs using pdbfixer.

    The first model of the input file is filtered with `NonHetSelector`,
    repaired by `_step_3_pdbfixer`, renumbered by `_step_4_fix_numbering`
    and written to `<out_dir>/<pdbid>.pdb`, where `pdbid` is the input
    file name up to its first ".pdb".

    Parameters
    ----------
    pdb_input_filename: str
        PDB filename
    out_dir: str
        Output directory.
    """
    pdbid = pdb_input_filename.split("/")[-1].split(".pdb")[0]
    first_model = PDB_PARSER.get_structure(pdbid, pdb_input_filename)[0]

    # Step 1: NonHetSelector filter (write filtered model, parse it back)
    with tempfile.NamedTemporaryFile(mode="wt", delete=True) as filtered_file:
        PDBIO.set_structure(first_model)
        PDBIO.save(filtered_file, select=NonHetSelector())
        filtered_file.flush()
        first_model = PDB_PARSER.get_structure(
            filtered_file.name, filtered_file.name
        )[0]

    # The pre-fix file must stay open while step 3 runs, because the
    # renumbering re-reads it by name.
    with tempfile.NamedTemporaryFile(mode="wt", delete=True) as prefix_file, \
            tempfile.NamedTemporaryFile(mode="wt", delete=True) as fixed_file:
        # Step 2: Replace altloc chars to " " and use pdbfixer
        _, fixer = _step_3_pdbfixer(first_model, prefix_file)

        # Step 3: Correct for pdbfixer not preserving insertion codes
        structure_after = _step_4_fix_numbering(fixer, prefix_file, fixed_file)

    # Step 4: Write the first model of the cleaned structure
    with open(os.path.join(out_dir, pdbid + ".pdb"), "w") as outpdb:
        PDBIO.set_structure(structure_after[0])
        PDBIO.save(outpdb)
173 |
if __name__ == "__main__":
    # Argument Parser
    parser = argparse.ArgumentParser(
        description="Clean a single PDB file with pdbfixer."
    )
    # Both options are required: without them clean_pdb would receive
    # None and fail with an unhelpful AttributeError.
    parser.add_argument(
        "--pdb_file_in", type=str, required=True, help="Path of the PDB file to clean."
    )
    parser.add_argument(
        "--out_dir", type=str, required=True, help="Directory for the cleaned PDB file."
    )

    # Parse arguments
    args = parser.parse_args()

    # Clean
    clean_pdb(args.pdb_file_in, args.out_dir)
187 |
--------------------------------------------------------------------------------
/src/run_test_rocklin.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import glob
3 | import re
4 | import torch
5 | import models.gvp.data, models.gvp.models
6 | import json
7 | import torch_geometric
8 | import esm
9 | import pandas as pd
10 | import random
11 | import torch.multiprocessing
12 |
13 | torch.multiprocessing.set_sharing_strategy("file_system")
14 | from models.msa_transformer.model import MSATransformer
15 | from models.gvp.models import SSEmbGNN
16 | from helpers import (
17 | read_msa,
18 | loop_pred,
19 | )
20 | from visualization import plot_rocklin
21 | import torch.utils.data
22 | from collections import OrderedDict
23 | import shutil
24 | import pdb_parser_scripts.parse_pdbs as parse_pdbs
25 | from collections import OrderedDict
26 | from ast import literal_eval
27 | import subprocess
28 |
29 |
30 | def test(run_name, epoch, num_ensemble=1, device=None):
31 | # Download raw data
32 | subprocess.run(
33 | ["wget", "https://zenodo.org/record/7992926/files/AlphaFold_model_PDBs.zip"]
34 | )
35 | subprocess.run(
36 | ["unzip", "AlphaFold_model_PDBs.zip", "-d", "../data/test/rocklin/raw/"]
37 | )
38 | subprocess.run(["rm", "AlphaFold_model_PDBs.zip"])
39 |
40 | subprocess.run(
41 | [
42 | "wget",
43 | "https://zenodo.org/record/7992926/files/Processed_K50_dG_datasets.zip",
44 | ]
45 | )
46 | subprocess.run(
47 | ["unzip", "Processed_K50_dG_datasets.zip", "-d", "../data/test/rocklin/raw/"]
48 | )
49 | subprocess.run(["rm", "Processed_K50_dG_datasets.zip"])
50 |
51 | # Load data
52 | df = pd.read_csv(
53 | "../data/test/rocklin/raw/Processed_K50_dG_datasets/Tsuboyama2023_Dataset2_Dataset3_20230416.csv"
54 | )
55 |
56 | # Use only Dataset 3 with well-defined ddG's
57 | df = df[df["ddG_ML"] != "-"]
58 |
59 | # Switch sign of experimental ddG's
60 | df["ddG_ML"] = -pd.to_numeric(df["ddG_ML"])
61 |
62 | # Use only non-synonomous substitutions
63 | df = df[~df["mut_type"].str.startswith("ins")]
64 | df = df[~df["mut_type"].str.startswith("del")]
65 | df = df[df["mut_type"] != "wt"]
66 |
67 | # Change pdb names to align with structure names
68 | df["WT_name"] = df["WT_name"].str.replace("|", ":")
69 |
70 | # Move structures
71 | structures = list(df["WT_name"].unique())
72 | structures_not_available = []
73 |
74 | for structure in structures:
75 | try:
76 | shutil.copy(
77 | f"../data/test/rocklin/raw/AlphaFold_model_PDBs/{structure}",
78 | f"../data/test/rocklin/structure/raw/{structure}",
79 | )
80 | except:
81 | structures_not_available.append(structure)
82 |
83 | df = df[~df["WT_name"].isin(structures_not_available)]
84 | print(
85 | f"Number of Rocklin assays without available AF2 structures: {len(structures_not_available)}"
86 | )
87 |
88 | # Save ddG data
89 | df_ddg = df[["WT_name", "mut_type", "ddG_ML"]].reset_index(drop=True)
90 | df_ddg = df_ddg.rename(
91 | columns={"WT_name": "pdb_id", "mut_type": "variant", "ddG_ML": "score_exp"}
92 | )
93 | df_ddg["pdb_id"] = df_ddg["pdb_id"].str[:-4]
94 | df_ddg.to_csv("../data/test/rocklin/exp/ddg.csv", index=False)
95 |
96 | ## Pre-process PDBs
97 | pdb_dir = "../data/test/rocklin/structure/"
98 | subprocess.run(
99 | [
100 | "pdb_parser_scripts/clean_pdbs.sh",
101 | str(pdb_dir),
102 | ]
103 | )
104 | parse_pdbs.parse(pdb_dir)
105 |
106 | # Load structure data
107 | with open(f"../data/test/rocklin/structure/coords.json") as json_file:
108 | data = json.load(json_file)
109 | json_file.close()
110 |
111 | # Compute MSAs
112 | sys.path += [":/projects/prism/people/skr526/mmseqs/bin"]
113 | subprocess.run(
114 | [
115 | "colabfold_search",
116 | f"{pdb_dir}/seqs.fasta",
117 | "/projects/prism/people/skr526/databases",
118 | "../data/test/rocklin/msa/",
119 | ]
120 | )
121 | subprocess.run(["python", "merge_and_sort_msas.py", "../data/test/rocklin/msa"])
122 |
123 | # Load MSA data
124 | msa_filenames = sorted(glob.glob(f"../data/test/rocklin/msa/*.a3m"))
125 | mave_msa_sub = {}
126 | for i, f in enumerate(msa_filenames):
127 | # name = f.split("/")[-1].split(".")[0]
128 | name = f.split("/")[-1].split(".")[0][:-2]
129 | mave_msa_sub[name] = []
130 | for j in range(num_ensemble):
131 | msa = read_msa(f)
132 | msa_sub = [msa[0]]
133 | k = min(len(msa) - 1, 16 - 1)
134 | msa_sub += [msa[i] for i in sorted(random.sample(range(1, len(msa)), k))]
135 | mave_msa_sub[name].append(msa_sub)
136 |
137 | # Add MSAs to data
138 | for entry in data:
139 | entry["msa"] = mave_msa_sub[entry["name"]]
140 |
141 | # Convert to graph data sets
142 | testset = models.gvp.data.ProteinGraphData(data)
143 | letter_to_num = testset.letter_to_num
144 |
145 | # Make variant pos dict
146 | variant_pos_dict = {}
147 | for pdb_id in df_ddg["pdb_id"].unique():
148 | df_ddg_pdb = df_ddg[df_ddg["pdb_id"] == pdb_id]
149 | variant_wtpos_list = [
150 | [x[:-1] for x in x.split(":")] for x in df_ddg_pdb["variant"].tolist()
151 | ]
152 | variant_wtpos_list = list(
153 | OrderedDict.fromkeys(
154 | [item for sublist in variant_wtpos_list for item in sublist]
155 | )
156 | ) # Remove duplicates
157 | variant_pos_dict[pdb_id] = variant_wtpos_list
158 |
159 | # Load MSA Transformer
160 | _, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
161 | msa_batch_converter = msa_alphabet.get_batch_converter()
162 | model_msa = MSATransformer(msa_alphabet)
163 | model_msa = model_msa.to(device)
164 |
165 | model_dict = OrderedDict()
166 | state_dict_msa = torch.load(
167 | f"../output/train/models/msa_transformer/{run_name}_msa_transformer_{epoch}.pt"
168 | )
169 | pattern = re.compile("module.")
170 | for k, v in state_dict_msa.items():
171 | if re.search("module", k):
172 | model_dict[re.sub(pattern, "", k)] = v
173 | else:
174 | model_dict = state_dict_msa
175 | model_msa.load_state_dict(model_dict)
176 |
177 | # Load GVP
178 | node_dim = (256, 64)
179 | edge_dim = (32, 1)
180 | model_gvp = SSEmbGNN((6, 3), node_dim, (32, 1), edge_dim)
181 | model_gvp = model_gvp.to(device)
182 |
183 | model_dict = OrderedDict()
184 | state_dict_gvp = torch.load(f"../output/train/models/gvp/{run_name}_gvp_{epoch}.pt")
185 | pattern = re.compile("module.")
186 | for k, v in state_dict_gvp.items():
187 | if k.startswith("module"):
188 | model_dict[k[7:]] = v
189 | else:
190 | model_dict = state_dict_gvp
191 | model_gvp.load_state_dict(model_dict)
192 |
193 | # Initialize data loader
194 | test_loader = torch_geometric.loader.DataLoader(
195 | testset, batch_size=1, shuffle=False
196 | )
197 |
198 | # Call test
199 | model_msa.eval()
200 | model_gvp.eval()
201 |
202 | with torch.no_grad():
203 | pred_list, acc_mean = loop_pred(
204 | model_msa,
205 | model_gvp,
206 | msa_batch_converter,
207 | test_loader,
208 | variant_pos_dict,
209 | data,
210 | letter_to_num,
211 | device=device,
212 | )
213 |
214 | # Transform results into df
215 | df_ml = pd.DataFrame(pred_list, columns=["pdb_id", "variant_pos", "score_ml_pos"])
216 |
217 | # Save
218 | df_ml.to_csv(f"../output/rocklin/df_ml_{run_name}.csv", index=False)
219 |
220 | # Load
221 | df_ml = pd.read_csv(
222 | f"../output/rocklin/df_ml_{run_name}.csv",
223 | converters=dict(score_ml_pos=literal_eval),
224 | )
225 |
226 | # Compute score_ml from nlls
227 | # OBS: We cannot vectorize this part since we have an unknown number of mutations per position :(
228 | pdb_variant_list = df_ddg.values.tolist()
229 | for i, row in enumerate(pdb_variant_list):
230 | pdb_id = row[0]
231 | print(
232 | f"Computing score for protein {pdb_id} variant: {i+1}/{len(pdb_variant_list)}"
233 | )
234 | variant_set = row[1].split(":")
235 |
236 | score_ml = 0.0
237 | for variant in variant_set:
238 | wt = letter_to_num[variant[0]]
239 | pos = int(re.findall(r"\d+", variant)[0])
240 | mt = letter_to_num[variant[-1]]
241 | score_ml_pos = df_ml[
242 | (df_ml["pdb_id"].str.startswith(pdb_id)) & (df_ml["variant_pos"] == pos)
243 | ]["score_ml_pos"].values[0]
244 | score_ml += float(score_ml_pos[mt])
245 | pdb_variant_list[i].append(score_ml)
246 |
247 | # Convert to df
248 | df_total = pd.DataFrame(
249 | pdb_variant_list, columns=["DMS_id", "variant_set", "score_exp", "score_ml"]
250 | )
251 |
252 | # Save
253 | df_total.to_csv(f"../output/rocklin/df_total_{run_name}.csv", index=False)
254 |
255 | # Load
256 | df_total = pd.read_csv(f"../output/rocklin/df_total_{run_name}.csv")
257 |
258 | # Compute correlations
259 | plot_rocklin(df_total)
260 |
--------------------------------------------------------------------------------
/src/models/msa_transformer/axial_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import math
7 | import torch
8 | import torch.nn as nn
9 |
10 |
class RowSelfAttention(nn.Module):
    """Self-attention along the rows of a 2D input, with one attention map shared by all rows.

    The input is unpacked as ``(num_rows, num_cols, batch_size, embed_dim)``.
    Query/key products are summed over the row axis (see the einsum in
    ``compute_attention_weights``), so a single ``(heads, batch, cols, cols)``
    map is produced and applied to every row.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        # einsum subscript of the attention map: heads, batch, query col, key col
        self.attn_shape = "hnij"

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def align_scaling(self, q):
        """Return the per-head scaling divided by sqrt(number of rows in ``q``)."""
        return self.scaling / math.sqrt(q.size(0))

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Memory-bounded forward pass that processes the rows in chunks.

        Attention logits are accumulated chunk by chunk, normalised once, and
        the resulting probabilities are then applied to each chunk of rows.
        Returns ``(output, attn_probs)`` like ``forward``.
        """
        num_rows, num_cols, _, _ = x.size()
        rows_per_chunk = max(1, self.max_tokens_per_msa // num_cols)
        chunk_starts = range(0, num_rows, rows_per_chunk)
        # The scaling uses the total row count, not the chunk size.
        scaling = self.align_scaling(x)

        logits = 0
        for lo in chunk_starts:
            hi = lo + rows_per_chunk
            if self_attn_padding_mask is None:
                padding_chunk = None
            else:
                padding_chunk = self_attn_padding_mask[:, lo:hi]
            logits += self.compute_attention_weights(
                x[lo:hi],
                scaling,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=padding_chunk,
            )
        attn_probs = self.dropout_module(logits.softmax(-1))

        updates = [
            self.compute_attention_update(x[lo : lo + rows_per_chunk], attn_probs)
            for lo in chunk_starts
        ]
        return torch.cat(updates, 0), attn_probs

    def compute_attention_weights(
        self,
        x,
        scaling: float,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return un-normalised attention logits summed over the rows of ``x``.

        ``self_attn_mask`` is broadcast over heads and batch; positions where
        it is true are filled with -10000. ``self_attn_padding_mask`` is
        indexed as ``(batch, rows, cols)``; padded queries are zeroed and the
        padded key columns of the first row are filled with -10000.
        """
        num_rows, num_cols, batch_size, _ = x.size()
        per_head = (num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        q = self.q_proj(x).view(*per_head)
        k = self.k_proj(x).view(*per_head)
        q *= scaling
        has_padding = self_attn_padding_mask is not None
        if has_padding:
            # Zero out any padded aligned positions - this is important since
            # we take a sum across the alignment axis.
            keep = 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)
            q *= keep

        attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)

        if self_attn_mask is not None:
            blocked = self_attn_mask.unsqueeze(0).unsqueeze(0)
            attn_weights = attn_weights.masked_fill(blocked, -10000)

        if has_padding:
            padded_keys = self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(padded_keys, -10000)

        return attn_weights

    def compute_attention_update(
        self,
        x,
        attn_probs,
    ):
        """Apply ``attn_probs`` to the value projection of ``x`` and project the result."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        values = self.v_proj(x).view(
            num_rows, num_cols, batch_size, self.num_heads, self.head_dim
        )
        mixed = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, values)
        merged = mixed.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
        return self.out_proj(merged)

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return ``(output, attn_probs)`` for ``x``.

        The chunked path is used only when gradients are disabled and the
        input holds more than ``max_tokens_per_msa`` tokens.
        """
        num_rows, num_cols, _, _ = x.size()
        over_budget = num_rows * num_cols > self.max_tokens_per_msa
        if over_budget and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)

        attn_weights = self.compute_attention_weights(
            x, self.align_scaling(x), self_attn_mask, self_attn_padding_mask
        )
        attn_probs = self.dropout_module(attn_weights.softmax(-1))
        return self.compute_attention_update(x, attn_probs), attn_probs
135 |
136 |
class ColumnSelfAttention(nn.Module):
    """Compute self-attention over columns of a 2D input.

    The input is unpacked as ``(num_rows, num_cols, batch_size, embed_dim)``.
    Every column is attended independently: positions in a column attend to
    the other rows of that same column.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Memory-bounded forward pass that processes the columns in chunks.

        Columns do not interact, so each chunk is handled on its own and the
        pieces are concatenated along the column axis.
        Returns ``(output, attn_probs)`` like ``forward``.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_cols = max(1, self.max_tokens_per_msa // num_rows)
        outputs = []
        attns = []
        for start in range(0, num_cols, max_cols):
            stop = start + max_cols
            if self_attn_padding_mask is not None:
                padding_chunk = self_attn_padding_mask[:, :, start:stop]
            else:
                padding_chunk = None
            # Call the computation directly instead of re-entering forward():
            # when num_rows alone exceeds max_tokens_per_msa a one-column
            # chunk is still "too large", and going through forward() would
            # recurse forever.
            output, attn = self.compute_attention_update(
                x[:, start:stop],
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=padding_chunk,
            )
            outputs.append(output)
            attns.append(attn)
        output = torch.cat(outputs, 1)
        attns = torch.cat(attns, 1)
        return output, attns

    def compute_attention_update(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Run column attention on ``x`` and return ``(output, attn_probs)``.

        ``attn_probs`` has layout ``(heads, cols, batch, rows, rows)``.
        ``self_attn_padding_mask`` is indexed as ``(batch, rows, cols)``.

        Raises:
            NotImplementedError: if ``self_attn_mask`` is given and the input
                has more than one row.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if num_rows == 1:
            # if there is only 1 position, this is equivalent and doesn't break with padding
            attn_probs = torch.ones(
                self.num_heads,
                num_cols,
                batch_size,
                num_rows,
                num_rows,
                device=x.device,
                dtype=x.dtype,
            )
            output = self.out_proj(self.v_proj(x))
            return output, attn_probs

        head_shape = (num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        q = self.q_proj(x).view(*head_shape)
        k = self.k_proj(x).view(*head_shape)
        v = self.v_proj(x).view(*head_shape)
        q *= self.scaling

        attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)

        if self_attn_mask is not None:
            raise NotImplementedError
        if self_attn_padding_mask is not None:
            # Broadcast the padded key rows over heads and query rows.
            attn_weights = attn_weights.masked_fill(
                self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
                -10000,
            )

        attn_probs = attn_weights.softmax(-1)
        attn_probs = self.dropout_module(attn_probs)
        context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
        context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
        output = self.out_proj(context)
        return output, attn_probs

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return ``(output, attn_probs)`` for ``x``.

        The chunked path is used only when gradients are disabled and the
        input holds more than ``max_tokens_per_msa`` tokens.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
            return self._batched_forward(
                x,
                self_attn_mask,
                self_attn_padding_mask,
            )
        return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)
--------------------------------------------------------------------------------
/src/run_test_clinvar.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import glob
3 | import re
4 | import torch
5 | import models.gvp.data, models.gvp.models
6 | import json
7 | import torch_geometric
8 | import esm
9 | import pandas as pd
10 | import random
11 | import torch.multiprocessing
12 | import numpy as np
13 |
14 | torch.multiprocessing.set_sharing_strategy("file_system")
15 | from models.msa_transformer.model import MSATransformer
16 | from models.gvp.models import SSEmbGNN
17 | from helpers import (
18 | read_msa,
19 | loop_pred,
20 | compute_auc_group,
21 | )
22 | from visualization import plot_aucroc
23 | import torch.utils.data
24 | from collections import OrderedDict
25 | import shutil
26 | import pdb_parser_scripts.parse_pdbs as parse_pdbs
27 | from collections import OrderedDict
28 | from ast import literal_eval
29 | import subprocess
30 | from sklearn.metrics import auc, roc_curve, roc_auc_score
31 |
def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from each key.

    Only the prefix is removed; ``module.`` occurring later in a key is left
    untouched, and keys without the prefix are copied unchanged.
    """
    prefix = "module."
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def test(run_name, epoch, num_ensemble=5, msa_row_attn_mask=True, device=None):
    """Run the ClinVar benchmark for one trained checkpoint.

    Downloads the clinical substitutions archive, prepares sequences,
    structures and MSAs, scores every variant with the MSA Transformer + GVP
    models saved under ``run_name``/``epoch``, writes the intermediate and
    final tables to ``../output/clinvar/`` and prints the mean per-protein AUC.

    Args:
        run_name: Run identifier used in checkpoint and output file names.
        epoch: Epoch identifier used in checkpoint file names.
        num_ensemble: Number of MSA subsamples drawn per protein.
        msa_row_attn_mask: Forwarded unchanged to ``loop_pred``.
        device: Device the models are moved to and that ``loop_pred`` receives.

    Returns:
        None. Results are written to disk and printed.
    """
    # Imported locally because the module's import block does not provide them.
    import os
    import pickle

    # Download raw data
    subprocess.run(["wget","https://marks.hms.harvard.edu/proteingym/clinical_ProteinGym_substitutions.zip"])
    subprocess.run(["unzip","clinical_ProteinGym_substitutions.zip","-d","../data/test/clinvar/raw"])
    subprocess.run(["rm","clinical_ProteinGym_substitutions.zip"])

    # Load ClinVar data
    print("Processing ClinVar data")
    clinvar_files = glob.glob("../data/test/clinvar/raw/*.csv")
    df_clinvar = pd.concat(
        [pd.read_csv(f) for f in clinvar_files], ignore_index=True
    )
    df_clinvar = df_clinvar[["protein", "protein_sequence", "mutant", "DMS_bin_score"]]
    df_clinvar = df_clinvar.rename(
        columns={
            "protein": "prot_name",
            "protein_sequence": "seq",
            "mutant": "variant",
            "DMS_bin_score": "label",
        }
    )

    # Save seqs to fasta.
    # Use the concatenated, renamed table: the per-file frame from the loading
    # step only covers the last CSV and still has the raw column names.
    df_uniqueseqs = df_clinvar[["prot_name", "seq"]].drop_duplicates()
    with open("../data/test/clinvar/seqs.fasta", "w") as fh:
        for _, row in df_uniqueseqs.iterrows():
            fh.write(f">{row['prot_name']}\n")
            fh.write(f"{row['seq']}")
            fh.write("\n")

    ## Pre-process PDBs
    pdb_dir = "../data/test/clinvar/structure/"
    subprocess.run(
        [
            "pdb_parser_scripts/clean_pdbs.sh",
            str(pdb_dir),
        ]
    )
    parse_pdbs.parse(pdb_dir)

    # Load structure data
    with open("../data/test/clinvar/structure/coords.json") as json_file:
        data = json.load(json_file)

    # Compute MSAs
    mmseqs_bin = "/projects/prism/people/skr526/mmseqs/bin"
    sys.path += [mmseqs_bin]
    # sys.path only affects Python imports; child processes locate executables
    # through PATH, so expose the directory there as well.
    path_entries = os.environ.get("PATH", "").split(os.pathsep)
    if mmseqs_bin not in path_entries:
        os.environ["PATH"] = os.pathsep.join([*path_entries, mmseqs_bin])

    # NOTE(review): the search is launched twice (once through a shell with an
    # environment switch, once directly). Presumably only one is needed -
    # confirm which matches the deployment before removing the other.
    subprocess.run("source activate mmseqs2 && colabfold_search ../data/test/clinvar/seqs.fasta /projects/prism/people/skr526/databases ../data/test/clinvar/msa/ && source activate struc-seq", shell=True)

    subprocess.run(
        [
            "colabfold_search",
            "../data/test/clinvar/seqs.fasta",
            "/projects/prism/people/skr526/databases",
            "../data/test/clinvar/msa/",
        ]
    )
    subprocess.run(["python", "merge_and_sort_msas.py", "../data/test/clinvar/msa"])

    # Load MSA data: for each alignment draw num_ensemble subsamples made of
    # the first sequence plus up to 15 randomly chosen other rows.
    msa_filenames = sorted(glob.glob("../data/test/clinvar/msa/*.a3m"))
    mave_msa_sub = {}
    for f in msa_filenames:
        name = ".".join(f.split("/")[-1].split(".")[:-1])
        mave_msa_sub[name] = []
        for _ in range(num_ensemble):
            msa = read_msa(f)
            msa_sub = [msa[0]]
            k = min(len(msa) - 1, 16 - 1)
            msa_sub += [msa[idx] for idx in sorted(random.sample(range(1, len(msa)), k))]
            mave_msa_sub[name].append(msa_sub)

    # Add MSAs to data
    for entry in data:
        entry["msa"] = mave_msa_sub[entry["name"]]

    # Convert to graph data sets
    testset = models.gvp.data.ProteinGraphData(data)
    letter_to_num = testset.letter_to_num

    # Make variant pos dict: per protein, the unique "<wt><pos>" strings in
    # first-seen order (the trailing mutant letter is dropped).
    variant_pos_dict = {}
    for prot_name in df_clinvar["prot_name"].unique():
        df_clinvar_prot = df_clinvar[df_clinvar["prot_name"] == prot_name]
        wtpos = [
            mutation[:-1]
            for variant in df_clinvar_prot["variant"].tolist()
            for mutation in variant.split(":")
        ]
        variant_pos_dict[prot_name.split(".")[0]] = list(OrderedDict.fromkeys(wtpos))

    # Save data and dict of variant positions
    with open("../data/test/clinvar/data_with_msas.pkl", "wb") as fp:
        pickle.dump(data, fp)

    with open("../data/test/clinvar/variant_pos_dict.pkl", "wb") as fp:
        pickle.dump(variant_pos_dict, fp)

    # Load MSA Transformer
    _, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    msa_batch_converter = msa_alphabet.get_batch_converter()
    model_msa = MSATransformer(msa_alphabet)
    model_msa = model_msa.to(device)

    state_dict_msa = torch.load(
        f"../output/train/models/msa_transformer/{run_name}_msa_transformer_{epoch}.pt"
    )
    model_msa.load_state_dict(_strip_module_prefix(state_dict_msa))

    # Load GVP
    node_dim = (256, 64)
    edge_dim = (32, 1)
    model_gvp = SSEmbGNN((6, 3), node_dim, (32, 1), edge_dim)
    model_gvp = model_gvp.to(device)

    state_dict_gvp = torch.load(f"../output/train/models/gvp/{run_name}_gvp_{epoch}.pt")
    model_gvp.load_state_dict(_strip_module_prefix(state_dict_gvp))

    # Initialize data loader
    test_loader = torch_geometric.loader.DataLoader(
        testset, batch_size=1, shuffle=False
    )

    # Call test
    model_msa.eval()
    model_gvp.eval()

    with torch.no_grad():
        pred_list, acc_mean = loop_pred(
            model_msa,
            model_gvp,
            msa_batch_converter,
            test_loader,
            variant_pos_dict,
            data,
            letter_to_num,
            msa_row_attn_mask=msa_row_attn_mask,
            device=device,
        )

    # Transform results into df
    df_ml = pd.DataFrame(pred_list, columns=["prot_name", "variant_pos", "score_ml_pos"])

    # Save
    df_ml.to_csv(f"../output/clinvar/df_ml_{run_name}.csv", index=False)

    # Load
    df_ml = pd.read_csv(
        f"../output/clinvar/df_ml_{run_name}.csv",
        converters=dict(score_ml_pos=literal_eval),
    )

    # Index the per-position scores once so that each variant is a dictionary
    # lookup instead of a scan over the whole table. setdefault keeps the
    # first row for a (protein, position) pair.
    score_by_site = {}
    for name, site, scores in zip(
        df_ml["prot_name"], df_ml["variant_pos"], df_ml["score_ml_pos"]
    ):
        score_by_site.setdefault((name, site), scores)

    # Compute score_ml from nlls
    clinvar_variant_list = df_clinvar.values.tolist()
    for i, row in enumerate(clinvar_variant_list):
        prot_name = row[0].split(".")[0]
        print(
            f"Computing score for assay {prot_name} variant: {i+1}/{len(clinvar_variant_list)}"
        )
        score_ml = 0.0
        # A variant may combine several substitutions separated by ":";
        # their per-position scores are summed.
        for variant in row[2].split(":"):
            pos = int(re.findall(r"\d+", variant)[0])
            mt = letter_to_num[variant[-1]]
            score_ml += float(score_by_site[(prot_name, pos)][mt])
        row.append(score_ml)

    df_total = pd.DataFrame(
        clinvar_variant_list, columns=["prot_name", "seq", "variant_set", "label", "score_ml"]
    )
    df_total = df_total[["prot_name", "variant_set", "label", "score_ml"]]

    # Save
    df_total.to_csv(f"../output/clinvar/df_total_{run_name}.csv", index=False)

    # Load
    df_total = pd.read_csv(f"../output/clinvar/df_total_{run_name}.csv")

    # Compute AUC
    df_total["label_bin"] = [1 if "path" in x else 0 for x in df_total['label'].str.lower()]
    # The sign is flipped before the AUC is computed.
    df_total["score"] = -df_total["score_ml"]
    prot_level_auc = df_total.groupby('prot_name').apply(compute_auc_group)
    auc_mean = prot_level_auc.mean(skipna=True)
    print(f"SSEmb avg. AUC is: {auc_mean}")
252 |
--------------------------------------------------------------------------------
/src/models/gvp/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import tqdm, random
4 | import torch, math
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 | import torch_geometric
8 | import torch_cluster
9 | import esm
10 | from scipy.spatial import distance_matrix
11 |
12 | def _normalize(tensor, dim=-1):
13 | '''
14 | Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
15 | '''
16 | return torch.nan_to_num(
17 | torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))
18 |
19 |
20 | def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'):
21 | '''
22 | From https://github.com/jingraham/neurips19-graph-protein-design
23 |
24 | Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
25 | That is, if `D` has shape [...dims], then the returned tensor will have
26 | shape [...dims, D_count].
27 | '''
28 | D_mu = torch.linspace(D_min, D_max, D_count, device=device)
29 | D_mu = D_mu.view([1, -1])
30 | D_sigma = (D_max - D_min) / D_count
31 | D_expand = torch.unsqueeze(D, -1)
32 |
33 | RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2)
34 | return RBF
35 |
36 |
class CATHDataset:
    '''
    Loader and container class for the CATH 4.2 dataset downloaded
    from http://people.csail.mit.edu/ingraham/graph-protein-design/data/cath/.

    After construction `self.total` holds every entry. `self.train`,
    `self.val` and `self.test` are created by calling `split()`; each is a
    list of JSON/dictionary-type entries as described in README.md.

    :param path: path to chain_set.jsonl
    :param splits_path: path to chain_set_splits.json or equivalent.
    :param minlen: stored on the instance; not applied in this class.
    :param maxlen: stored on the instance; not applied in this class.
    '''
    def __init__(self, path, splits_path, minlen=0, maxlen=10000):

        self.splits_path = splits_path
        self.total = []
        self.maxlen = maxlen
        self.minlen = minlen

        with open(path) as f:
            lines = f.readlines()

        for line in tqdm.tqdm(lines):
            entry = json.loads(line)
            coords = entry['coords']

            # Regroup the per-atom coordinate lists into one (N, CA, C, O)
            # tuple per position.
            entry['coords'] = list(zip(
                coords['N'], coords['CA'], coords['C'], coords['O']
            ))

            self.total.append(entry)

    def split(self):
        '''
        Populates `self.train`, `self.val` and `self.test` from `self.total`
        using the name lists stored in the splits file.

        An entry whose name appears in several lists goes to the first match
        in the order train, validation, test; entries named in none of the
        lists are left out.
        '''
        with open(self.splits_path) as f:
            dataset_splits = json.load(f)

        # Sets make each membership test O(1) instead of scanning a list.
        train_names = set(dataset_splits['train'])
        val_names = set(dataset_splits['validation'])
        test_names = set(dataset_splits['test'])

        self.train, self.val, self.test = [], [], []

        for entry in self.total:
            name = entry["name"]
            if name in train_names:
                self.train.append(entry)
            elif name in val_names:
                self.val.append(entry)
            elif name in test_names:
                self.test.append(entry)
84 |
class BatchSampler(data.Sampler):
    '''
    Sampler whose batches each contain exactly one protein index.

    :param node_counts: sequence with one item per protein; only its length is used.
    :param shuffle: if `True`, the index order is shuffled when batches are formed.
    '''
    def __init__(self, node_counts, shuffle=True):
        self.indices = list(range(len(node_counts)))
        self.shuffle = shuffle
        self._form_batches()

    def _form_batches(self):
        # Shuffling happens in place, so `self.indices` keeps the shuffled order.
        if self.shuffle:
            random.shuffle(self.indices)
        self.batches = [[idx] for idx in self.indices]

    def __len__(self):
        if not self.batches:
            self._form_batches()
        return len(self.batches)

    def __iter__(self):
        if not self.batches:
            self._form_batches()
        yield from self.batches
108 |
class ProteinGraphData(data.Dataset):
    '''
    A map-style `torch.utils.data.Dataset` which transforms JSON/dictionary-style
    protein structures into featurized protein graphs as described in the
    manuscript.

    Each item of `data_list` is a dict read with the keys `name`, `seq` and
    `coords`, plus the optional keys `msa`, `label` and `emb`, which are
    passed through unchanged (or set to `None` when absent).

    :param data_list: list of protein dictionaries
    :param num_positional_embeddings: width of the edge positional encoding
    :param top_k: number of neighbours per node in the kNN graph
    :param dist_cutoff: stored on the instance; not used in this class
    :param num_rbf: number of radial basis functions for edge distances
    :param device: device on which the feature tensors are created
    '''
    def __init__(self, data_list, num_positional_embeddings=16,
                 top_k=20, dist_cutoff=12, num_rbf=16, device="cpu"):
        super(ProteinGraphData, self).__init__()

        self.data_list = data_list
        self.dist_cutoff = dist_cutoff
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings
        self.top_k = top_k
        self.device = device
        self.node_counts = [len(e['seq']) for e in data_list]

        # NOTE(review): the empty-string key mapped to 20 looks like a special
        # token whose text was lost - confirm against the upstream source.
        self.letter_to_num = {'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9,
                              'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8,
                              'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19,
                              'N': 2, 'Y': 18, 'M': 12, "": 20}
        self.num_to_letter = {v:k for k, v in self.letter_to_num.items()}

    def __len__(self): return len(self.data_list)

    def __getitem__(self, i): return self._featurize_as_graph(self.data_list[i])

    def _featurize_as_graph(self, protein):
        '''
        Builds the `torch_geometric.data.Data` graph for one protein dictionary.
        '''
        name = protein['name']
        with torch.no_grad():
            coords = torch.tensor(np.array(protein["coords"]).astype(np.float32), device=self.device)
            seq = torch.as_tensor([self.letter_to_num[a] for a in protein['seq']],
                                  device=self.device, dtype=torch.long)
            num_res = seq.size(0)

            # Mask positions without coords
            mask = torch.isfinite(coords.sum(dim=(1,2)))
            coords[~mask] = np.inf

            # Compute edges in graph
            X_ca = coords[:, 1]
            edge_index = torch_cluster.knn_graph(X_ca, k=self.top_k)

            # Compute distance mask for row attention.
            # The mask lives on the same device as the indices used to fill it.
            dist_mask = torch.zeros((num_res, num_res), dtype=torch.bool,
                                    device=edge_index.device)
            dist_mask[edge_index[1], edge_index[0]] = True
            diagonal = torch.arange(num_res, device=edge_index.device)
            dist_mask[diagonal, diagonal] = True
            dist_mask = ~dist_mask # Reverse so that non-contacts are set to zero in MSA row attention map

            # Compute graph node and edge features
            pos_embeddings = self._positional_embeddings(edge_index)
            E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]]
            rbf = _rbf(E_vectors.norm(dim=-1), D_count=self.num_rbf, device=self.device)

            dihedrals = self._dihedrals(coords)
            orientations = self._orientations(X_ca)
            sidechains = self._sidechains(coords)

            node_s = dihedrals
            node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)
            edge_s = torch.cat([rbf, pos_embeddings], dim=-1)
            edge_v = _normalize(E_vectors).unsqueeze(-2)

            node_s, node_v, edge_s, edge_v = map(torch.nan_to_num,
                    (node_s, node_v, edge_s, edge_v))

        # Optional fields are forwarded as-is, or as None when missing.
        msa = protein["msa"] if "msa" in protein else None
        label = protein["label"] if "label" in protein else None
        emb = protein["emb"] if "emb" in protein else None

        data = torch_geometric.data.Data(x=X_ca, seq=seq, name=name,
                                         node_s=node_s, node_v=node_v,
                                         edge_s=edge_s, edge_v=edge_v,
                                         edge_index=edge_index,
                                         mask=mask,
                                         msa=msa,
                                         dist_mask=dist_mask,
                                         label=label,
                                         emb=emb,
                                         )

        return data

    def _dihedrals(self, X, eps=1e-7):
        # From https://github.com/jingraham/neurips19-graph-protein-design
        # Returns cos/sin of three backbone dihedral angles per position.

        # Flatten the first three atoms of every position into one chain.
        X = torch.reshape(X[:, :3], [3*X.shape[0], 3])
        dX = X[1:] - X[:-1]
        U = _normalize(dX, dim=-1)
        u_2 = U[:-2]
        u_1 = U[1:-1]
        u_0 = U[2:]

        # Backbone normals. dim is given explicitly: without it torch.cross
        # picks the first axis of size 3, which is the wrong axis whenever the
        # leading dimension happens to be 3.
        n_2 = _normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = _normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Angle between normals
        cosD = torch.sum(n_2 * n_1, -1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)

        # This scheme will remove phi[0], psi[-1], omega[-1]
        D = F.pad(D, [1, 2])
        D = torch.reshape(D, [-1, 3])
        # Lift angle representations to the circle
        D_features = torch.cat([torch.cos(D), torch.sin(D)], 1)
        return D_features


    def _positional_embeddings(self, edge_index,
                               num_embeddings=None,
                               period_range=(2, 1000)):
        # From https://github.com/jingraham/neurips19-graph-protein-design
        # Sinusoidal encoding of the index offset between the two ends of each
        # edge. `period_range` is accepted but not used.
        num_embeddings = num_embeddings or self.num_positional_embeddings
        d = edge_index[0] - edge_index[1]

        frequency = torch.exp(
            torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=self.device)
            * -(np.log(10000.0) / num_embeddings)
        )
        angles = d.unsqueeze(-1) * frequency
        E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
        return E

    def _orientations(self, X):
        # Unit vectors towards the next and the previous position; the chain
        # ends are zero-padded.
        forward = _normalize(X[1:] - X[:-1])
        backward = _normalize(X[:-1] - X[1:])
        forward = F.pad(forward, [0, 0, 0, 1])
        backward = F.pad(backward, [0, 0, 1, 0])
        return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)

    def _sidechains(self, X):
        # Direction built from the bisector of the two bonds around atom 1
        # and the normal of the plane they span.
        n, origin, c = X[:, 0], X[:, 1], X[:, 2]
        c, n = _normalize(c - origin), _normalize(n - origin)
        bisector = _normalize(c + n)
        # Explicit dim for the same reason as in _dihedrals.
        perp = _normalize(torch.cross(c, n, dim=-1))
        vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
        return vec
259 |
--------------------------------------------------------------------------------
/src/models/msa_transformer/data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import itertools
7 | import os
8 | from typing import Sequence, Tuple, List, Union
9 | import pickle
10 | import re
11 | import shutil
12 | import torch
13 | from pathlib import Path
14 | from constants import proteinseq_toks
15 |
16 | RawMSA = Sequence[Tuple[str, str]]
17 |
18 |
class FastaBatchedDataset(object):
    """In-memory list of labelled sequences with token-budgeted batching.

    Args:
        sequence_labels: iterable of record labels.
        sequence_strs: iterable of sequence strings, parallel to the labels.
    """

    def __init__(self, sequence_labels, sequence_strs):
        self.sequence_labels = list(sequence_labels)
        self.sequence_strs = list(sequence_strs)

    @classmethod
    def from_file(cls, fasta_file):
        """Build a dataset from a FASTA file.

        Lines starting with ``>`` open a new record; the following lines are
        stripped and concatenated into its sequence. A header with no text
        gets the label ``seqnum<line index>``. Lines that appear before the
        first header belong to no record and are discarded.

        Raises:
            AssertionError: if two records share the same label.
        """
        sequence_labels, sequence_strs = [], []
        cur_seq_label = None
        buf = []

        def _flush_current_seq():
            nonlocal cur_seq_label, buf
            if cur_seq_label is not None:
                sequence_labels.append(cur_seq_label)
                sequence_strs.append("".join(buf))
            # Reset unconditionally so that text collected before the first
            # header cannot leak into the first record.
            cur_seq_label = None
            buf = []

        with open(fasta_file, "r") as infile:
            for line_idx, line in enumerate(infile):
                if line.startswith(">"):  # label line
                    _flush_current_seq()
                    label = line[1:].strip()
                    cur_seq_label = label if label else f"seqnum{line_idx:09d}"
                else:  # sequence line
                    buf.append(line.strip())

        _flush_current_seq()

        assert len(set(sequence_labels)) == len(
            sequence_labels
        ), "Found duplicate sequence labels"

        return cls(sequence_labels, sequence_strs)

    def __len__(self):
        return len(self.sequence_labels)

    def __getitem__(self, idx):
        return self.sequence_labels[idx], self.sequence_strs[idx]

    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
        """Group sequence indices into batches under a padded-token budget.

        Sequences are visited from shortest to longest. A batch is closed as
        soon as adding the next sequence would make
        ``longest length * batch size`` exceed ``toks_per_batch``;
        ``extra_toks_per_seq`` is added to every length first. A sequence that
        alone exceeds the budget still gets a batch of its own.

        Returns:
            A list of batches, each a list of indices into the dataset.
        """
        sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
        sizes.sort()
        batches = []
        buf = []
        max_len = 0

        def _flush_current_buf():
            nonlocal max_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            max_len = 0

        for sz, i in sizes:
            sz += extra_toks_per_seq
            if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
                _flush_current_buf()
            max_len = max(max_len, sz)
            buf.append(i)

        _flush_current_buf()
        return batches
89 |
90 |
class Alphabet(object):
    """Token vocabulary mapping token strings to integer indices and back.

    The vocabulary is built as ``prepend_toks + standard_toks``, padded with
    ``<null_k>`` placeholders up to a multiple of 8, followed by
    ``append_toks``.

    NOTE(review): in the copy this was edited from, every special-token
    literal was an empty string (apparently ``<...>`` tokens stripped by a
    text extraction step), which made ``tokenize`` call ``str.split("")``
    and collapsed all special indices. The literals below were restored to
    the upstream ESM values -- confirm against the repository.

    Args:
        standard_toks: the regular (non-special) tokens.
        prepend_toks: special tokens placed before the standard tokens.
        append_toks: special tokens placed after the padding placeholders.
        prepend_bos: stored flag read by the batch converters.
        append_eos: stored flag read by the batch converters.
        use_msa: selects ``MSABatchConverter`` in ``get_batch_converter``.
    """

    def __init__(
        self,
        standard_toks: Sequence[str],
        prepend_toks: Sequence[str] = ("<null_0>", "<pad>", "<eos>", "<unk>"),
        append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"),
        prepend_bos: bool = True,
        append_eos: bool = False,
        use_msa: bool = False,
    ):
        self.standard_toks = list(standard_toks)
        self.prepend_toks = list(prepend_toks)
        self.append_toks = list(append_toks)
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        self.use_msa = use_msa

        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.standard_toks)
        # Pad with placeholder tokens so the appended specials start at a
        # multiple of 8.
        for i in range((8 - (len(self.all_toks) % 8)) % 8):
            self.all_toks.append(f"<null_{i + 1}>")
        self.all_toks.extend(self.append_toks)

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}

        # "<unk>" must exist (direct lookup); the others fall back to unk_idx.
        self.unk_idx = self.tok_to_idx["<unk>"]
        self.padding_idx = self.get_idx("<pad>")
        self.cls_idx = self.get_idx("<cls>")
        self.mask_idx = self.get_idx("<mask>")
        self.eos_idx = self.get_idx("<eos>")
        self.all_special_tokens = ['<eos>', '<unk>', '<pad>', '<cls>', '<mask>']
        self.unique_no_split_tokens = self.all_toks

    def __len__(self):
        """Return the vocabulary size, including placeholder tokens."""
        return len(self.all_toks)

    def get_idx(self, tok):
        """Return the index of ``tok``, or ``unk_idx`` if it is unknown."""
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        """Return the token string stored at index ``ind``."""
        return self.all_toks[ind]

    def to_dict(self):
        """Return a copy of the token-to-index mapping."""
        return self.tok_to_idx.copy()

    def get_batch_converter(self, truncation_seq_length: int = None):
        """Return the batch converter matching ``use_msa`` for this alphabet."""
        if self.use_msa:
            return MSABatchConverter(self, truncation_seq_length)
        else:
            return BatchConverter(self, truncation_seq_length)

    @classmethod
    def from_architecture(cls, name: str) -> "Alphabet":
        """Create the alphabet configured for the named architecture.

        Raises:
            ValueError: if ``name`` matches none of the known architectures.
        """
        if name in ("ESM-1", "protein_bert_base"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks: Tuple[str, ...] = ("<null_0>", "<pad>", "<eos>", "<unk>")
            append_toks: Tuple[str, ...] = ("<cls>", "<mask>", "<sep>")
            prepend_bos = True
            append_eos = False
            use_msa = False
        elif name in ("ESM-1b", "roberta_large"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>",)
            prepend_bos = True
            append_eos = True
            use_msa = False
        elif name in ("MSA Transformer", "msa_transformer"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>",)
            prepend_bos = True
            append_eos = False
            use_msa = True
        elif "invariant_gvp" in name.lower():
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<null_0>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>", "<cath>", "<af2>")
            prepend_bos = True
            append_eos = False
            use_msa = False
        else:
            raise ValueError("Unknown architecture selected")
        return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa)

    def _tokenize(self, text) -> List[str]:
        """Split ``text`` on whitespace."""
        return text.split()

    def tokenize(self, text, **kwargs) -> List[str]:
        """
        Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
        Converts a string in a sequence of tokens, using the tokenizer.

        The text is split successively on every vocabulary token (in
        vocabulary order); pieces that are not themselves vocabulary tokens
        are finally split on whitespace.

        Args:
            text (:obj:`str`):
                The sequence to be encoded.

        Returns:
            :obj:`List[str]`: The list of tokens.
        """
        # Set for O(1) membership; the list keeps the splitting order.
        no_split = set(self.unique_no_split_tokens)

        def split_on_token(tok, text):
            result = []
            split_text = text.split(tok)
            last = len(split_text) - 1
            for i, sub_text in enumerate(split_text):
                # Whitespace adjacent to a matched token is stripped.
                if i < last:
                    sub_text = sub_text.rstrip()
                if i > 0:
                    sub_text = sub_text.lstrip()

                if i == 0 and not sub_text:
                    result.append(tok)
                elif i == last:
                    if sub_text:
                        result.append(sub_text)
                else:
                    if sub_text:
                        result.append(sub_text)
                    result.append(tok)
            return result

        def split_on_tokens(tok_list, text):
            if not text.strip():
                return []

            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in no_split:
                        tokenized_text.extend(split_on_token(tok, sub_text))
                    else:
                        tokenized_text.append(sub_text)
                text_list = tokenized_text

            return list(
                itertools.chain.from_iterable(
                    (
                        self._tokenize(token)
                        if token not in no_split
                        else [token]
                        for token in tokenized_text
                    )
                )
            )

        return split_on_tokens(self.unique_no_split_tokens, text)

    def encode(self, text):
        """Tokenize ``text`` and map each token to its index.

        Raises:
            KeyError: if a produced token is not in the vocabulary.
        """
        return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
251 |
252 |
class BatchConverter(object):
    """Callable that turns a raw batch of ``(label, sequence string)`` pairs
    into ``(labels, strings, token tensor)``.

    The token tensor is filled with the alphabet's padding index; each row
    holds an optional leading cls token, the encoded sequence, and an
    optional trailing eos token, as dictated by the alphabet's
    ``prepend_bos`` / ``append_eos`` flags.
    """

    def __init__(self, alphabet, truncation_seq_length: int = None):
        self.alphabet = alphabet
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
        alphabet = self.alphabet
        batch_size = len(raw_batch)
        batch_labels, seq_str_list = zip(*raw_batch)

        encoded = [alphabet.encode(seq_str) for seq_str in seq_str_list]
        if self.truncation_seq_length:
            # Truncation applies to the encoded ids only; the returned
            # strings are left untouched.
            encoded = [ids[: self.truncation_seq_length] for ids in encoded]
        longest = max(len(ids) for ids in encoded)

        bos_offset = int(alphabet.prepend_bos)
        width = longest + bos_offset + int(alphabet.append_eos)
        tokens = torch.empty((batch_size, width), dtype=torch.int64)
        tokens.fill_(alphabet.padding_idx)

        for row, ids in enumerate(encoded):
            if alphabet.prepend_bos:
                tokens[row, 0] = alphabet.cls_idx
            end = bos_offset + len(ids)
            tokens[row, bos_offset:end] = torch.tensor(ids, dtype=torch.int64)
            if alphabet.append_eos:
                tokens[row, end] = alphabet.eos_idx

        return list(batch_labels), list(seq_str_list), tokens
298 |
299 |
class MSABatchConverter(BatchConverter):
    """Batch converter for multiple sequence alignments.

    Accepts either one MSA (a sequence of ``(label, sequence)`` pairs) or a
    batch of MSAs, converts each MSA with ``BatchConverter.__call__`` and
    stacks the results into one padded 3-D token tensor
    ``(n_msas, max alignments, max sequence length + bos/eos slots)``.
    """

    def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
        # In a single MSA, inputs[0] is a (label, sequence) pair, so
        # inputs[0][0] is a string; wrap it into a batch of one.
        if isinstance(inputs[0][0], str):
            raw_batch: Sequence[RawMSA] = [inputs]  # type: ignore
        else:
            raw_batch = inputs  # type: ignore

        alphabet = self.alphabet
        n_msas = len(raw_batch)
        depth = max(len(msa) for msa in raw_batch)
        # Length of the first sequence of each MSA (all rows must match).
        seqlen = max(len(msa[0][1]) for msa in raw_batch)
        width = seqlen + int(alphabet.prepend_bos) + int(alphabet.append_eos)

        tokens = torch.empty((n_msas, depth, width), dtype=torch.int64)
        tokens.fill_(alphabet.padding_idx)

        labels, strs = [], []
        for i, msa in enumerate(raw_batch):
            distinct_lengths = {len(seq) for _, seq in msa}
            if len(distinct_lengths) != 1:
                raise RuntimeError(
                    "Received unaligned sequences for input to MSA, all sequence "
                    "lengths must be equal."
                )
            msa_labels, msa_strs, msa_tokens = super().__call__(msa)
            labels.append(msa_labels)
            strs.append(msa_strs)
            rows, cols = msa_tokens.size(0), msa_tokens.size(1)
            tokens[i, :rows, :cols] = msa_tokens

        return labels, strs, tokens
337 |
338 |
def read_fasta(
    path,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    """Lazily yield the records of the FASTA/alignment file at ``path``.

    Opens the file and forwards everything ``read_alignment_lines`` yields
    for its lines; the three flags are passed through unchanged. The file
    stays open while the generator is being consumed.
    """
    options = {
        "keep_gaps": keep_gaps,
        "keep_insertions": keep_insertions,
        "to_upper": to_upper,
    }
    with open(path, "r") as handle:
        yield from read_alignment_lines(handle, **options)
350 |
351 |
def read_alignment_lines(
    lines,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    """Yield ``(description, sequence)`` pairs from FASTA-style lines.

    A line starting with ``>`` opens a record; its description is the
    stripped line with leading ``>`` characters removed. All following
    lines up to the next header are stripped and concatenated.

    Empty input yields nothing and blank lines before the first header are
    skipped (both previously tripped a bare ``assert``). Non-blank data
    before the first header is rejected explicitly.

    Args:
        lines: iterable of text lines.
        keep_gaps: if false, remove ``-`` characters from each sequence.
        keep_insertions: if false, remove lowercase ``a-z`` characters.
        to_upper: if true, upper-case each sequence (after the removals).

    Raises:
        ValueError: if non-blank text appears before the first ``>`` header.
    """

    def parse(s):
        if not keep_gaps:
            s = s.replace("-", "")
        if not keep_insertions:
            s = re.sub("[a-z]", "", s)
        return s.upper() if to_upper else s

    desc = None
    chunks = []
    for line in lines:
        if line.startswith(">"):
            if desc is not None:
                yield desc, parse("".join(chunks))
            desc = line.strip().lstrip(">")
            chunks = []
        elif desc is not None:
            # Line may be empty if seq % file_line_width == 0
            chunks.append(line.strip())
        elif line.strip():
            raise ValueError("Sequence data found before the first '>' header line")

    if desc is not None:
        yield desc, parse("".join(chunks))
379 |
--------------------------------------------------------------------------------
/src/models/msa_transformer/modules.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import math
7 | from typing import Tuple, Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from .multihead_attention import MultiheadAttention # noqa
14 | from .axial_attention import ColumnSelfAttention, RowSelfAttention
15 |
def gelu(x):
    """Apply the erf-based GELU activation elementwise.

    Computes ``x * 0.5 * (1 + erf(x / sqrt(2)))``. The tanh approximation
    used by OpenAI GPT,
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``,
    gives slightly different results and is not what this function does.
    """
    half_x = x * 0.5
    cdf_term = 1.0 + torch.erf(x / math.sqrt(2.0))
    return half_x * cdf_term
24 |
25 |
def symmetrize(x):
    """Return ``x`` plus its transpose over the last two dimensions.

    Used for contact prediction; the result is symmetric in its final two
    dimensions.
    """
    transposed = x.transpose(-2, -1)
    return x + transposed
29 |
30 |
def apc(x):
    """Apply average product correction over the last two dimensions.

    Subtracts ``row_sum * col_sum / total_sum`` from every entry, where the
    sums are taken over the final two dimensions. Used for contact
    prediction.
    """
    row_sums = x.sum(-1, keepdims=True)
    col_sums = x.sum(-2, keepdims=True)
    total = x.sum((-1, -2), keepdims=True)

    expected = row_sums * col_sums
    expected.div_(total)  # in-place to reduce memory
    return x - expected
41 |
42 |
class ESM1LayerNorm(nn.Module):
    """Layer normalization in the TF style (eps inside the sqrt).

    Normalizes over the trailing ``len(hidden_size)`` dimensions and, when
    ``affine`` is true, applies a learned elementwise scale and shift.
    """

    def __init__(self, hidden_size, eps=1e-12, affine=True):
        """Construct a layernorm layer in the TF style (eps inside the sqrt)."""
        super().__init__()
        if isinstance(hidden_size, int):
            self.hidden_size = (hidden_size,)
        else:
            self.hidden_size = tuple(hidden_size)
        self.eps = eps
        self.affine = bool(affine)
        if not self.affine:
            self.weight, self.bias = None, None
        else:
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        # Trailing dimensions to normalize over: -1, -2, ..., -len(hidden_size).
        reduce_dims = tuple(range(-1, -len(self.hidden_size) - 1, -1))
        centered = x - x.mean(reduce_dims, keepdim=True)
        variance = centered.pow(2).mean(reduce_dims, keepdim=True)
        normed = centered / torch.sqrt(variance + self.eps)
        if self.affine:
            normed = (self.weight * normed) + self.bias
        return normed
65 |
66 |
# ESM1bLayerNorm is apex's FusedLayerNorm when apex can be imported, and
# torch.nn.LayerNorm otherwise.
try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm
except ImportError:
    from torch.nn import LayerNorm as ESM1bLayerNorm
else:

    class ESM1bLayerNorm(_FusedLayerNorm):
        """FusedLayerNorm that runs CUDA inputs under the input's device context."""

        @torch.jit.unused
        def forward(self, x):
            if x.is_cuda:
                # Make the input's device current for the fused call.
                with torch.cuda.device(x.device):
                    return super().forward(x)
            return super().forward(x)
81 |
82 |
class TransformerLayer(nn.Module):
    """Transformer layer block.

    Each sub-layer (self-attention, then a two-layer feed-forward network
    with GELU) is preceded by a layer norm and wrapped in a residual
    connection.
    """

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        add_bias_kv=True,
        use_esm1b_layer_norm=False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings
        self._init_submodules(add_bias_kv, use_esm1b_layer_norm)

    def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm):
        """Create the attention, feed-forward and layer-norm submodules."""
        if use_esm1b_layer_norm:
            norm_cls = ESM1bLayerNorm
        else:
            norm_cls = ESM1LayerNorm

        self.self_attn = MultiheadAttention(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )
        self.self_attn_layer_norm = norm_cls(self.embed_dim)

        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = norm_cls(self.embed_dim)

    def forward(
        self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False
    ):
        """Run the layer and return ``(output, attention weights)``."""
        # Self-attention sub-layer with residual connection.
        normed = self.self_attn_layer_norm(x)
        attn_out, attn = self.self_attn(
            query=normed,
            key=normed,
            value=normed,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = x + attn_out

        # Feed-forward sub-layer with residual connection.
        hidden = gelu(self.fc1(self.final_layer_norm(x)))
        x = x + self.fc2(hidden)

        return x, attn
142 |
143 |
class AxialTransformerLayer(nn.Module):
    """Implements an Axial MSA Transformer block.

    Row self-attention, column self-attention and a feed-forward network
    are applied in that order, each wrapped in a ``NormalizedResidualBlock``.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ) -> None:
        super().__init__()

        # Read by build_residual below.
        self.embedding_dim = embedding_dim
        self.dropout_prob = dropout

        # NOTE(review): `attention_dropout` is accepted but not used here;
        # both attention modules receive `dropout`.
        attention_kwargs = dict(
            dropout=dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )
        row_attention = RowSelfAttention(
            embedding_dim, num_attention_heads, **attention_kwargs
        )
        column_attention = ColumnSelfAttention(
            embedding_dim, num_attention_heads, **attention_kwargs
        )
        feed_forward = FeedForwardNetwork(
            embedding_dim,
            ffn_embedding_dim,
            activation_dropout=activation_dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        self.row_self_attention = self.build_residual(row_attention)
        self.column_self_attention = self.build_residual(column_attention)
        self.feed_forward_layer = self.build_residual(feed_forward)

    def build_residual(self, layer: nn.Module):
        """Wrap ``layer`` in a ``NormalizedResidualBlock`` using this block's
        embedding size and dropout probability."""
        return NormalizedResidualBlock(
            layer,
            self.embedding_dim,
            self.dropout_prob,
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_row_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ):
        """
        Apply row attention, column attention and the feed-forward network.

        ``self_row_attn_mask`` is passed to the row attention and
        ``self_attn_mask`` to the column attention; both receive
        ``self_attn_padding_mask``.

        Returns ``(x, column_attn, row_attn)`` when ``need_head_weights`` is
        true, otherwise just ``x``.
        """
        x, row_attn = self.row_self_attention(
            x,
            self_attn_mask=self_row_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x, column_attn = self.column_self_attention(
            x,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x = self.feed_forward_layer(x)
        if not need_head_weights:
            return x
        return x, column_attn, row_attn
222 |
class LearnedPositionalEmbedding(nn.Embedding):
    """
    Learned positional embeddings up to a fixed maximum size.

    When ``padding_idx`` is given, the table holds
    ``num_embeddings + padding_idx + 1`` rows, so positions are offset past
    the padding index. Padding tokens in the input map to the padding row.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        table_size = num_embeddings
        if padding_idx is not None:
            table_size = num_embeddings + padding_idx + 1
        super().__init__(table_size, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Input is expected to be of size [bsz x seqlen]."""
        if input.size(1) > self.max_positions:
            raise ValueError(
                f"Sequence length {input.size(1)} above maximum "
                f" sequence length of {self.max_positions}"
            )
        not_pad = input.ne(self.padding_idx).int()
        # Running count of non-padding tokens, zeroed at padding slots, then
        # shifted so padding lands exactly on padding_idx.
        running = torch.cumsum(not_pad, dim=1).type_as(not_pad)
        positions = (running * not_pad).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
257 |
258 |
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed sine/cosine positional embeddings.

    The table is built lazily and rebuilt whenever a longer input arrives.
    Non-padding tokens get positions ``padding_idx + 1, padding_idx + 2, ...``
    by column; padding tokens map to row ``padding_idx``, which is all zeros.
    The ``learned`` argument is accepted but not used.
    """

    def __init__(self, embed_dim, padding_idx, learned=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.padding_idx = padding_idx
        # Buffer used only as a dtype/device reference for the table.
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.weights = None

    def forward(self, x):
        bsz, seq_len = x.shape
        rows_needed = self.padding_idx + 1 + seq_len
        table_too_small = self.weights is None or rows_needed > self.weights.size(0)
        if table_too_small:
            self.weights = self.get_embedding(rows_needed)
        self.weights = self.weights.type_as(self._float_tensor)

        flat_positions = self.make_positions(x).view(-1)
        gathered = self.weights.index_select(0, flat_positions)
        return gathered.view(bsz, seq_len, -1).detach()

    def make_positions(self, x):
        """Return the table row index for every token of ``x``."""
        not_pad = x.ne(self.padding_idx)
        columns = torch.arange(x.size(1), device=x.device).expand_as(x)
        positions = (columns + self.padding_idx + 1).expand_as(x)
        keep = not_pad.long()
        return positions * keep + self.padding_idx * (1 - keep)

    def get_embedding(self, num_embeddings):
        """Build a ``(num_embeddings, embed_dim)`` sinusoidal table."""
        half_dim = self.embed_dim // 2
        log_step = math.log(10000) / (half_dim - 1)
        inv_freq = torch.exp(torch.arange(half_dim, dtype=torch.float) * -log_step)
        angles = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * inv_freq.unsqueeze(0)
        table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1).view(
            num_embeddings, -1
        )
        if self.embed_dim % 2 == 1:
            # zero pad
            table = torch.cat([table, torch.zeros(num_embeddings, 1)], dim=1)
        if self.padding_idx is not None:
            table[self.padding_idx, :] = 0
        return table
295 |
296 |
class RobertaLMHead(nn.Module):
    """Head for masked language modeling.

    Applies a dense layer, GELU and layer norm, then projects with the
    supplied ``weight`` matrix and adds a learned bias of size ``output_dim``.
    """

    def __init__(self, embed_dim, output_dim, weight):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = ESM1bLayerNorm(embed_dim)
        # Projection matrix provided by the caller.
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        hidden = self.layer_norm(gelu(self.dense(features)))
        # project back to size of vocabulary with bias
        logits = F.linear(hidden, self.weight) + self.bias
        return logits
314 |
315 |
class ContactPredictionHead(nn.Module):
    """Performs symmetrization, apc, and computes a logistic regression on the output features"""

    def __init__(
        self,
        in_features: int,
        prepend_bos: bool,
        append_eos: bool,
        bias=True,
        eos_idx: Optional[int] = None,
    ):
        super().__init__()
        self.in_features = in_features
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        if append_eos and eos_idx is None:
            raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        """Return sigmoid contact scores computed from attention maps.

        ``attentions`` is unpacked as
        ``(batch, layers, heads, seqlen, seqlen)``.
        """
        if self.append_eos:
            # Zero every attention entry that involves an eos token, then
            # drop the last row and column.
            not_eos = tokens.ne(self.eos_idx).to(attentions)
            pair_mask = not_eos.unsqueeze(1) * not_eos.unsqueeze(2)
            attentions = attentions * pair_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]
        if self.prepend_bos:
            # Drop the first row and column (cls token).
            attentions = attentions[..., 1:, 1:]

        batch_size, layers, heads, seqlen, _ = attentions.size()
        # Merge layers and heads into one channel dimension: B x C x T x T.
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # Move to the device holding the regression weights.
        attentions = attentions.to(self.regression.weight.device)
        features = apc(symmetrize(attentions)).permute(0, 2, 3, 1)
        logits = self.regression(features).squeeze(3)
        return self.activation(logits)
357 |
358 |
class NormalizedResidualBlock(nn.Module):
    """Wrap a layer with a leading layer norm, dropout on its output, and a
    residual connection.

    If the wrapped layer returns a tuple, only its first element takes part
    in the residual update; the remaining elements are passed through.
    """

    def __init__(
        self,
        layer: nn.Module,
        embedding_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim

        self.layer = layer
        self.dropout_module = nn.Dropout(
            dropout,
        )
        self.layer_norm = ESM1bLayerNorm(self.embedding_dim)

    def forward(self, x, *args, **kwargs):
        outputs = self.layer(self.layer_norm(x), *args, **kwargs)
        if isinstance(outputs, tuple):
            update, *extras = outputs
            return (x + self.dropout_module(update),) + tuple(extras)
        return x + self.dropout_module(outputs)
392 |
393 |
class FeedForwardNetwork(nn.Module):
    """Two-layer feed-forward network: linear, GELU, dropout, linear.

    ``max_tokens_per_msa`` is stored on the module but not used in
    ``forward``.
    """

    def __init__(
        self,
        embedding_dim: int,
        ffn_embedding_dim: int,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        self.max_tokens_per_msa = max_tokens_per_msa
        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(
            activation_dropout,
        )
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.activation_fn(hidden)
        hidden = self.activation_dropout_module(hidden)
        return self.fc2(hidden)
418 |
--------------------------------------------------------------------------------
/src/models/gvp/__init__.py:
--------------------------------------------------------------------------------
1 | import torch, functools
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch_geometric.nn import MessagePassing
5 | from torch_scatter import scatter_add
6 | import torch.utils.checkpoint as checkpoint
7 |
def tuple_sum(*args):
    '''
    Adds any number of tuples (s, V) position by position.

    :return: tuple whose k-th entry is the sum of the k-th entries of all inputs
    '''
    return tuple(sum(parts) for parts in zip(*args))
13 |
def tuple_cat(*args, dim=-1):
    '''
    Concatenates any number of tuples (s, V) elementwise.

    :param dim: dimension along which to concatenate when viewed
                as the `dim` index for the scalar-channel tensors.
                This means that `dim=-1` will be applied as
                `dim=-2` for the vector-channel tensors.
    :return: tuple (s, V) of concatenated `torch.Tensor`
    '''
    # Resolve a negative dim against the scalar tensor's rank so the same
    # non-negative index can be used for both tensors.
    scalar_rank = len(args[0][0].shape)
    dim = dim % scalar_rank
    scalars, vectors = zip(*args)
    return torch.cat(scalars, dim=dim), torch.cat(vectors, dim=dim)
26 |
def tuple_index(x, idx):
    '''
    Applies the same index to both members of a tuple (s, V).

    :param x: tuple (s, V)
    :param idx: any object which can be used to index into a `torch.Tensor`
    :return: tuple (s[idx], V[idx])
    '''
    scalars, vectors = x[0], x[1]
    return scalars[idx], vectors[idx]
34 |
def randn(n, dims, device="cpu"):
    '''
    Draws a random tuple (s, V) elementwise from a normal distribution.

    :param n: number of data points
    :param dims: tuple of dimensions (n_scalar, n_vector)
    :param device: device on which both tensors are created

    :return: (s, V) with s.shape = (n, n_scalar) and
             V.shape = (n, n_vector, 3)
    '''
    scalars = torch.randn(n, dims[0], device=device)
    vectors = torch.randn(n, dims[1], 3, device=device)
    return scalars, vectors
47 |
def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-4, sqrt=True):
    '''
    L2 norm of tensor along `axis`, with the squared norm clamped
    to be at least `eps` before the square root is taken.

    :param axis: dimension to reduce over
    :param keepdims: whether the reduced dimension is kept
    :param eps: lower bound applied to the squared norm
    :param sqrt: if `False`, returns the square of the L2 norm
    '''
    squared = torch.sum(torch.square(x), axis, keepdims)
    squared = torch.clamp(squared, min=eps)
    if sqrt:
        return torch.sqrt(squared)
    return squared
56 |
def _split(x, nv):
    '''
    Splits a merged representation of (s, V) back into a tuple.
    Should be used only with `_merge(s, V)` and only if the tuple
    representation cannot be used.

    The last `3*nv` entries of the final dimension are reshaped into
    `nv` vectors of length 3; everything before them is the scalar part.
    `nv = 0` is supported and gives an empty vector tensor (negative-index
    slicing would have returned the whole tensor as the vector part).

    :param x: the `torch.Tensor` returned from `_merge`
    :param nv: the number of vector channels in the input to `_merge`
    :return: tuple (s, V)
    '''
    # Explicit split point: `-3*nv` is `-0` (i.e. 0) when nv == 0.
    n_scalar = x.shape[-1] - 3*nv
    s = x[..., :n_scalar]
    v = torch.reshape(x[..., n_scalar:], x.shape[:-1] + (nv, 3))
    return s, v
69 |
def _merge(s, v):
    '''
    Flattens the vector channels of a tuple (s, V) and appends them to
    the scalar channels, giving a single `torch.Tensor`.
    Should be used only if the tuple representation cannot be used.
    Use `_split(x, nv)` to reverse.
    '''
    flat_shape = v.shape[:-2] + (3*v.shape[-2],)
    flat_v = v.reshape(flat_shape)
    return torch.cat([s, flat_v], -1)
79 |
class GVP(nn.Module):
    '''
    Geometric Vector Perceptron. See manuscript and README.md
    for more details.

    :param in_dims: tuple (n_scalar, n_vector)
    :param out_dims: tuple (n_scalar, n_vector)
    :param h_dim: intermediate number of vector channels, optional
                  (defaults to the larger of the input and output
                  vector channel counts)
    :param activations: tuple of functions (scalar_act, vector_act);
                        either may be `None` to skip that activation
    :param vector_gate: whether to use vector gating.
                        (vector_act will be used as sigma^+ in vector gating if `True`)
    '''
    def __init__(self, in_dims, out_dims, h_dim=None,
                 activations=(F.relu, torch.sigmoid), vector_gate=False):
        super(GVP, self).__init__()

        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.vector_gate = vector_gate
        if not self.vi:
            # No input vectors: a plain linear layer on the scalars.
            self.ws = nn.Linear(self.si, self.so)
        else:
            self.h_dim = h_dim or max(self.vi, self.vo)
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
                if self.vector_gate:
                    self.wsv = nn.Linear(self.so, self.vo)

        self.scalar_act, self.vector_act = activations
        # Empty parameter used to find the module's device in forward().
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or (if vectors_in is 0), a single `torch.Tensor`
        :return: tuple (s, V) of `torch.Tensor`,
                 or (if vectors_out is 0), a single `torch.Tensor`
        '''
        if not self.vi:
            s = self.ws(x)
            if self.vo:
                # No input vectors to transform: output vectors are zeros.
                v = torch.zeros(s.shape[0], self.vo, 3,
                                device=self.dummy_param.device)
        else:
            s, v = x
            # Put the channel axis last so nn.Linear mixes vector channels.
            hidden = self.wh(torch.transpose(v, -1, -2))
            hidden_norm = _norm_no_nan(hidden, axis=-2)
            s = self.ws(torch.cat([s, hidden_norm], -1))
            if self.vo:
                v = torch.transpose(self.wv(hidden), -1, -2)
                if self.vector_gate:
                    # Gate computed from the pre-activation scalar output.
                    gate_in = self.vector_act(s) if self.vector_act else s
                    gate = self.wsv(gate_in)
                    v = v * torch.sigmoid(gate).unsqueeze(-1)
                elif self.vector_act:
                    v = v * self.vector_act(
                        _norm_no_nan(v, axis=-1, keepdims=True))

        if self.scalar_act:
            s = self.scalar_act(s)

        return (s, v) if self.vo else s
148 |
class _VDropout(nn.Module):
    '''
    Vector channel dropout where the elements of each
    vector channel are dropped together.
    '''
    def __init__(self, drop_rate):
        super(_VDropout, self).__init__()
        self.drop_rate = drop_rate
        # Empty parameter used to find the module's device in forward().
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        '''
        :param x: `torch.Tensor` corresponding to vector channels
        :return: `x` unchanged in eval mode; in training mode, `x` with
                 whole vectors zeroed and the rest rescaled by
                 `1 / (1 - drop_rate)`
        '''
        if not self.training:
            return x
        keep_prob = 1 - self.drop_rate
        # One Bernoulli draw per vector (all dims except the last), then
        # broadcast over the vector's components.
        keep_probs = keep_prob * torch.ones(
            x.shape[:-1], device=self.dummy_param.device)
        keep = torch.bernoulli(keep_probs).unsqueeze(-1)
        return keep * x / keep_prob
171 |
172 |
class Dropout(nn.Module):
    '''
    Combined dropout for tuples (s, V).
    Takes tuples (s, V) as input and as output.
    '''
    def __init__(self, drop_rate):
        super(Dropout, self).__init__()
        self.sdropout = nn.Dropout(drop_rate)
        self.vdropout = _VDropout(drop_rate)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        :return: same structure as `x`, with scalar dropout applied to s
                 and vector dropout applied to V
        '''
        # isinstance (rather than an exact type match) so that Tensor
        # subclasses such as nn.Parameter take the scalar path instead of
        # being unpacked as an (s, V) tuple.
        if isinstance(x, torch.Tensor):
            return self.sdropout(x)
        s, v = x
        return self.sdropout(s), self.vdropout(v)
193 |
194 |
class LayerNorm(nn.Module):
    '''
    Combined LayerNorm for tuples (s, V).
    Takes tuples (s, V) as input and as output.

    Scalars go through `nn.LayerNorm`; vectors are divided by the
    root-mean-square of their (clamped) squared norms across channels.

    :param dims: tuple (n_scalar, n_vector)
    '''
    def __init__(self, dims):
        super(LayerNorm, self).__init__()
        self.s, self.v = dims
        self.scalar_norm = nn.LayerNorm(self.s)

    def forward(self, x):
        '''
        :param x: tuple (s, V) of `torch.Tensor`,
                  or single `torch.Tensor`
                  (will be assumed to be scalar channels)
        '''
        if self.v:
            s, v = x
            sq_norms = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False)
            rms = torch.sqrt(torch.mean(sq_norms, dim=-2, keepdim=True))
            return self.scalar_norm(s), v / rms
        return self.scalar_norm(x)
217 |
218 | class GVPConv(MessagePassing):
219 | '''
220 | Graph convolution / message passing with Geometric Vector Perceptrons.
221 | Takes in a graph with node and edge embeddings,
222 | and returns new node embeddings.
223 |
224 | This does NOT do residual updates and pointwise feedforward layers
225 | ---see `GVPConvLayer`.
226 |
227 | :param in_dims: input node embedding dimensions (n_scalar, n_vector)
228 | :param out_dims: output node embedding dimensions (n_scalar, n_vector)
229 | :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
230 | :param n_layers: number of GVPs in the message function
231 | :param module_list: preconstructed message function, overrides n_layers
232 | :param aggr: should be "add" if some incoming edges are masked, as in
233 | a masked autoregressive decoder architecture, otherwise "mean"
234 | :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
235 | :param vector_gate: whether to use vector gating.
236 | (vector_act will be used as sigma^+ in vector gating if `True`)
237 | '''
238 | def __init__(self, in_dims, out_dims, edge_dims,
239 | n_layers=3, module_list=None, aggr="mean",
240 | activations=(F.leaky_relu, torch.sigmoid), vector_gate=False):
241 | super(GVPConv, self).__init__(aggr=aggr)
242 | self.si, self.vi = in_dims
243 | self.so, self.vo = out_dims
244 | self.se, self.ve = edge_dims
245 |
246 | GVP_ = functools.partial(GVP,
247 | activations=activations, vector_gate=vector_gate)
248 |
249 | module_list = module_list or []
250 | if not module_list:
251 | if n_layers == 1:
252 | module_list.append(
253 | GVP_((2*self.si + self.se, 2*self.vi + self.ve),
254 | (self.so, self.vo), activations=(None, None)))
255 | else:
256 | module_list.append(
257 | GVP_((2*self.si + self.se, 2*self.vi + self.ve), out_dims)
258 | )
259 | for i in range(n_layers - 2):
260 | module_list.append(GVP_(out_dims, out_dims))
261 | module_list.append(GVP_(out_dims, out_dims,
262 | activations=(None, None)))
263 | self.message_func = nn.Sequential(*module_list)
264 |
def forward(self, x, edge_index, edge_attr):
    '''
    Compute the aggregated incoming messages for every node.

    :param x: tuple (s, V) of `torch.Tensor`
    :param edge_index: array of shape [2, n_edges]
    :param edge_attr: tuple (s, V) of `torch.Tensor`
    :return: result of `_split` applied to the aggregated messages
             with `self.vo` vector channels
    '''
    scalars, vectors = x
    # Vector channels are flattened to 2D before being handed to
    # `propagate`; `message` restores the trailing axis of size 3.
    n_nodes, n_vectors = vectors.shape[0], vectors.shape[1]
    flat_vectors = vectors.reshape(n_nodes, 3 * n_vectors)
    aggregated = self.propagate(
        edge_index, s=scalars, v=flat_vectors, edge_attr=edge_attr)
    return _split(aggregated, self.vo)
276 |
def message(self, s_i, v_i, s_j, v_j, edge_attr):
    '''
    Build the per-edge message from both endpoint embeddings and the
    edge embedding, then run it through the message function.

    NOTE(review): the `_i` / `_j` suffixes presumably follow the
    message-passing framework's target / source convention -- confirm
    against the base class.
    '''
    def unflatten(flat):
        # Undo the flattening done in `forward`: [n, 3*k] -> [n, k, 3].
        return flat.view(flat.shape[0], flat.shape[1] // 3, 3)

    vec_j = unflatten(v_j)
    vec_i = unflatten(v_i)
    edge_input = tuple_cat((s_j, vec_j), edge_attr, (s_i, vec_i))
    edge_output = self.message_func(edge_input)
    return _merge(*edge_output)
283 |
284 |
class GVPConvLayer(nn.Module):
    '''
    Complete message-passing layer built from Geometric Vector Perceptrons.

    The layer adds the aggregated incoming messages to the node embeddings
    (residual update followed by normalisation), then applies a pointwise
    feedforward network with a second residual update, and returns the
    resulting node embeddings.

    Use `GVPConv` directly if only the aggregated messages are needed.

    :param node_dims: node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_message: number of GVPs to use in message function
    :param n_feedforward: number of GVPs to use in feedforward function
    :param drop_rate: drop probability in all dropout layers
    :param autoregressive: if `True`, this `GVPConvLayer` will be used
            with a different set of input node embeddings for messages
            where src >= dst
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating.
            (vector_act will be used as sigma^+ in vector gating if `True`)
    :param aggr: aggregation handed to `GVPConv` when `autoregressive`
            is `False`; an autoregressive layer always uses "add"
    '''
    def __init__(self, node_dims, edge_dims,
                 n_message=3, n_feedforward=2, drop_rate=.1,
                 autoregressive=False,
                 activations=(F.leaky_relu, torch.sigmoid), vector_gate=True, aggr="mean"):

        super().__init__()

        # Autoregressive layers sum messages here and normalise by the
        # incoming-edge count inside `forward`.
        conv_aggr = "add" if autoregressive else aggr
        self.conv = GVPConv(node_dims, node_dims, edge_dims, n_message,
                            aggr=conv_aggr,
                            activations=activations, vector_gate=vector_gate)

        make_gvp = functools.partial(
            GVP, activations=activations, vector_gate=vector_gate)

        # One norm / dropout pair per residual update (message, feedforward).
        self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)])
        self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])

        if n_feedforward == 1:
            ff_layers = [make_gvp(node_dims, node_dims, activations=(None, None))]
        else:
            # Hidden width: 4x the scalar channels, 2x the vector channels.
            hidden_dims = 4 * node_dims[0], 2 * node_dims[1]
            ff_layers = [make_gvp(node_dims, hidden_dims)]
            ff_layers.extend(
                make_gvp(hidden_dims, hidden_dims)
                for _ in range(n_feedforward - 2))
            ff_layers.append(
                make_gvp(hidden_dims, node_dims, activations=(None, None)))
        self.ff_func = nn.Sequential(*ff_layers)

    def forward(self, x, edge_index, edge_attr,
                autoregressive_x=None, node_mask=None):
        '''
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
                If not `None`, will be used as src node embeddings
                for forming messages where src >= dst. The current node
                embeddings `x` will still be the base of the update and the
                pointwise feedforward.
        :param node_mask: array of type `bool` to index into the first
                dim of node embeddings (s, V). If not `None`, only
                these nodes will be updated.
        :return: tuple (s, V) of updated node embeddings; when `node_mask`
                is given, the input tensors are overwritten in place at the
                masked rows and returned.
        '''
        if autoregressive_x is None:
            dh = self.conv(x, edge_index, edge_attr)
        else:
            src, dst = edge_index
            forward_edges = src < dst
            backward_edges = ~forward_edges

            # Edges with src < dst read from `x`; the remaining edges
            # (src >= dst) read from `autoregressive_x`.
            forward_msgs = self.conv(
                x,
                edge_index[:, forward_edges],
                tuple_index(edge_attr, forward_edges),
            )
            backward_msgs = self.conv(
                autoregressive_x,
                edge_index[:, backward_edges],
                tuple_index(edge_attr, backward_edges),
            )
            dh = tuple_sum(forward_msgs, backward_msgs)

            # Divide the summed messages by the number of incoming edges
            # per node (at least 1 to avoid division by zero).
            in_degree = scatter_add(
                torch.ones_like(dst), dst, dim_size=dh[0].size(0)
            ).clamp(min=1).unsqueeze(-1)
            dh = dh[0] / in_degree, dh[1] / in_degree.unsqueeze(-1)

        masked = node_mask is not None
        if masked:
            # Keep a handle on the full embeddings so the updated rows can
            # be written back after the update.
            full_x = x
            x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)

        # Residual update with the aggregated messages.
        x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))

        # Residual update with the pointwise feedforward network.
        dh = self.ff_func(x)
        x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))

        if masked:
            full_x[0][node_mask] = x[0]
            full_x[1][node_mask] = x[1]
            x = full_x
        return x
381 |
--------------------------------------------------------------------------------
/data/test/proteingym/all_models_substitutions_Spearman_Uniprot_level.csv:
--------------------------------------------------------------------------------
1 | UniProt_ID,Tranception_L_no_retrieval,Tranception_S_retrieval,Tranception_M_retrieval,Tranception_L_retrieval,EVE_single,EVE_ensemble,MSA_Transformer_single,MSA_Transformer_ensemble,ESM1v_single,ESM1v_ensemble,Wavenet,DeepSequence_single,DeepSequence_ensemble,Site_Independent,EVmutation,RITA_s,RITA_m,RITA_l,RITA_xl,RITA_ensemble,Progen2_small,Progen2_medium,Progen2_base,Progen2_large,Progen2_xlarge,Progen2_ensemble,Unirep,Unirep_evotune,GEMME,VESPA,VESPAl,ESM1b,ProtGPT2,TranceptEVE_L,Neff_L_category,Taxon
2 | A0A140D2T1_ZIKV,0.272,0.362,0.366,0.351,0.394,0.405,0.475,0.454,-0.048,0.015,0.216,0.131,0.103,0.383,0.354,0.361,0.309,0.317,0.305,0.346,0.329,0.342,0.328,0.312,0.293,0.346,-0.133,0.062,0.444,0.319,0.296,-0.001,0.005,0.379,medium,Virus
3 | A0A192B1T2_9HIV1,0.514,0.509,0.503,0.513,0.509,0.516,0.514,0.514,0.492,0.516,0.465,0.413,0.432,0.481,0.407,0.496,0.507,0.509,0.505,0.519,0.497,0.501,0.463,0.49,0.484,0.508,0.0,0.513,0.504,0.541,0.507,0.456,0.327,0.528,medium,Virus
4 | A0A1I9GEU1_NEIME,0.099,0.031,0.041,0.057,0.053,0.054,0.094,0.102,0.068,0.068,0.067,0.107,0.098,-0.011,0.044,-0.01,0.047,0.071,0.088,0.066,0.05,0.088,0.08,0.089,0.095,0.087,-0.024,0.084,0.05,0.046,0.036,0.04,0.03,0.075,medium,Prokaryote
5 | A0A2Z5U3Z0_9INFA,0.521,0.512,0.541,0.539,0.526,0.534,0.33,0.522,0.489,0.53,0.436,0.469,0.498,0.479,0.49,0.458,0.518,0.502,0.526,0.528,0.389,0.49,0.498,0.471,0.49,0.518,0.015,0.464,0.528,0.466,0.388,0.14,0.08,0.566,medium,Virus
6 | A4D664_9INFA,0.404,0.366,0.38,0.393,0.409,0.408,0.333,0.336,0.026,0.033,0.26,0.406,0.403,0.411,0.34,0.329,0.386,0.404,0.398,0.398,0.064,0.26,0.285,0.244,0.33,0.292,0.029,0.364,0.435,0.318,0.296,0.044,0.055,0.461,medium,Virus
7 | A4GRB6_PSEAI,0.629,0.557,0.588,0.652,0.623,0.641,0.695,0.67,0.647,0.668,0.567,0.657,0.664,0.317,0.492,0.409,0.533,0.564,0.624,0.576,0.526,0.622,0.639,0.638,0.705,0.679,0.344,0.529,0.707,0.742,0.673,0.68,0.242,0.679,high,Prokaryote
8 | A4_HUMAN,0.363,0.436,0.366,0.452,0.305,0.309,0.393,0.392,0.309,0.398,0.245,0.446,0.421,0.388,0.379,0.322,0.276,0.311,0.299,0.308,0.4,0.292,0.314,0.307,0.307,0.33,0.346,0.157,0.468,0.242,0.173,0.294,0.527,0.41,low,Human
9 | AACC1_PSEAI,0.417,0.427,0.418,0.466,0.487,0.491,0.496,0.484,0.483,0.489,0.391,0.331,0.414,0.282,0.474,0.206,0.232,0.276,0.306,0.26,0.273,0.428,0.407,0.413,0.438,0.424,0.196,0.191,0.473,0.507,0.455,0.413,0.021,0.495,high,Prokaryote
10 | ADRB2_HUMAN,0.506,0.528,0.541,0.533,0.497,0.517,0.435,0.5,0.526,0.539,0.521,0.506,0.514,0.331,0.423,0.509,0.524,0.517,0.493,0.528,0.524,0.531,0.538,0.533,0.502,0.54,0.463,0.5,0.546,0.5,0.414,0.534,0.298,0.537,medium,Human
11 | AMIE_PSEAE,0.405,0.585,0.479,0.438,0.477,0.464,0.592,0.569,0.613,0.662,0.506,0.496,0.506,0.246,0.443,0.555,0.525,0.532,0.542,0.57,0.566,0.561,0.534,0.559,0.508,0.584,0.093,0.436,0.632,0.583,0.528,0.57,0.265,0.481,high,Prokaryote
12 | B3VI55_LIPST,0.555,0.37,0.404,0.527,0.481,0.498,0.579,0.56,0.537,0.564,0.499,0.485,0.494,0.28,0.426,0.276,0.396,0.462,0.495,0.443,0.365,0.492,0.525,0.479,0.552,0.531,0.104,0.465,0.544,0.537,0.46,0.444,0.139,0.54,medium,Eukaryote
13 | BLAT_ECOLX,0.459,0.594,0.595,0.596,0.658,0.672,0.645,0.646,0.622,0.654,0.602,0.668,0.683,0.429,0.652,0.518,0.512,0.502,0.492,0.564,0.539,0.594,0.598,0.534,0.429,0.612,0.109,0.418,0.609,0.691,0.673,0.646,0.148,0.681,high,Prokaryote
14 | BRCA1_HUMAN,0.409,0.456,0.456,0.514,0.424,0.466,0.472,0.501,0.404,0.447,0.296,0.441,0.432,0.501,0.385,0.145,0.379,0.405,0.361,0.417,0.305,0.473,0.457,0.442,0.423,0.471,0.125,0.189,0.413,0.539,0.482,0.494,0.22,0.533,low,Human
15 | C6KNH7_9INFA,0.392,0.412,0.419,0.433,0.433,0.436,0.4,0.403,0.428,0.492,0.343,0.351,0.357,0.393,0.371,0.396,0.363,0.378,0.349,0.391,0.254,0.454,0.428,0.465,0.37,0.473,-0.025,0.451,0.48,0.457,0.36,0.06,0.11,0.452,medium,Virus
16 | CALM1_HUMAN,0.291,0.218,0.242,0.266,0.238,0.236,0.237,0.251,0.235,0.264,0.229,0.238,0.233,0.175,0.233,0.173,0.227,0.242,0.254,0.233,0.232,0.276,0.279,0.3,0.313,0.301,0.174,0.187,0.228,0.21,0.146,0.25,0.087,0.246,high,Human
17 | CAPSD_AAV2S,0.492,0.38,0.386,0.473,0.346,0.33,0.35,0.368,0.196,0.199,0.258,0.317,0.372,0.407,0.345,0.191,0.25,0.269,0.279,0.278,0.203,0.198,0.266,0.204,0.399,0.261,0.374,0.424,0.511,0.183,0.177,0.183,0.132,0.426,low,Virus
18 | CCDB_ECOLI,0.367,0.461,0.453,0.495,0.522,0.53,0.406,0.461,0.393,0.458,0.488,0.522,0.544,0.395,0.496,0.03,-0.0,-0.094,0.156,0.046,-0.123,0.074,-0.029,0.054,0.45,0.107,-0.02,0.24,0.485,0.548,0.532,0.457,0.118,0.535,high,Prokaryote
19 | CP2C9_HUMAN,0.57,0.642,0.638,0.631,0.606,0.622,0.58,0.594,0.615,0.644,0.632,0.61,0.628,0.516,0.581,0.57,0.57,0.603,0.566,0.621,0.594,0.586,0.591,0.582,0.577,0.618,0.556,0.585,0.619,0.588,0.528,0.516,0.216,0.641,high,Human
20 | DLG4_HUMAN,0.58,0.653,0.7,0.667,0.609,0.616,0.602,0.614,0.55,0.608,0.65,0.607,0.577,0.679,0.584,0.574,0.583,0.569,0.538,0.581,0.605,0.607,0.564,0.562,0.495,0.594,0.723,0.629,0.616,0.634,0.621,0.52,0.511,0.636,low,Human
21 | DLG4_RAT,0.307,0.47,0.478,0.443,0.526,0.539,0.488,0.496,0.557,0.588,0.444,0.487,0.491,0.486,0.442,0.378,0.389,0.373,0.371,0.4,0.407,0.42,0.407,0.367,0.395,0.425,0.49,0.437,0.509,0.556,0.537,0.469,0.246,0.54,low,Eukaryote
22 | DYR_ECOLI,0.347,0.407,0.415,0.423,0.476,0.474,0.487,0.49,0.416,0.436,0.485,0.469,0.472,0.384,0.484,0.199,0.36,0.266,0.31,0.316,0.342,0.436,0.476,0.441,0.418,0.462,-0.018,0.321,0.442,0.451,0.408,0.469,0.1,0.481,medium,Prokaryote
23 | ENV_HV1B9,0.404,0.392,0.396,0.407,0.388,0.377,0.38,0.375,0.415,0.389,0.388,0.238,0.367,0.369,0.397,0.38,0.358,0.408,0.419,0.4,0.263,0.394,0.401,0.391,0.374,0.375,0.056,0.353,0.37,0.355,0.322,0.359,0.253,0.392,medium,Virus
24 | ENV_HV1BR,0.358,0.359,0.366,0.362,0.339,0.345,0.343,0.345,0.32,0.336,0.337,0.322,0.323,0.338,0.303,0.35,0.358,0.371,0.364,0.374,0.345,0.358,0.358,0.354,0.362,0.378,-0.001,0.322,0.348,0.325,0.29,0.298,0.189,0.365,medium,Virus
25 | ESTA_BACSU,0.261,0.31,0.329,0.326,0.387,0.387,0.435,0.431,0.305,0.335,0.32,0.389,0.415,0.258,0.399,0.121,0.196,0.268,0.286,0.261,0.27,0.254,0.311,0.272,0.379,0.328,0.187,0.315,0.403,0.434,0.406,0.336,0.075,0.397,high,Prokaryote
26 | F7YBW8_MESOW,0.433,0.062,0.011,0.425,0.428,0.43,0.375,0.413,0.382,0.393,0.391,0.395,0.44,0.06,0.395,-0.075,-0.122,-0.104,-0.005,-0.081,0.098,0.125,-0.021,0.298,0.403,0.22,0.016,0.32,0.374,0.461,0.441,0.437,0.037,0.444,high,Prokaryote
27 | GAL4_YEAST,0.325,0.56,0.555,0.557,0.486,0.532,0.602,0.58,0.458,0.462,0.507,0.507,0.567,0.237,0.403,0.319,0.355,0.383,0.372,0.37,0.364,0.424,0.493,0.451,0.579,0.504,0.337,-0.004,0.629,0.651,0.563,0.623,0.303,0.503,medium,Eukaryote
28 | GCN4_YEAST,0.265,0.258,0.257,0.278,0.242,0.241,0.249,0.25,0.272,0.284,0.186,0.245,0.248,0.253,0.25,0.175,0.189,0.168,0.167,0.179,0.158,0.118,0.081,0.134,0.164,0.144,0.182,0.221,0.222,0.255,0.247,0.238,0.025,0.274,low,Eukaryote
29 | GFP_AEQVI,0.629,0.649,0.651,0.672,0.679,0.679,0.648,0.657,0.098,0.103,0.598,0.672,0.673,0.649,0.644,0.078,0.105,0.181,0.098,0.123,0.046,0.188,0.298,0.642,0.647,0.446,0.049,0.635,0.679,0.607,0.615,0.524,0.074,0.706,low,Eukaryote
30 | GRB2_HUMAN,0.404,0.517,0.506,0.436,0.54,0.546,0.471,0.486,0.455,0.515,0.534,0.507,0.538,0.405,0.521,0.544,0.516,0.514,0.453,0.525,0.532,0.509,0.446,0.524,0.448,0.515,0.486,0.45,0.51,0.428,0.422,0.533,0.464,0.532,medium,Human
31 | HIS7_YEAST,0.585,0.488,0.496,0.616,0.533,0.531,0.508,0.506,0.411,0.477,0.454,0.558,0.557,0.472,0.542,0.325,0.402,0.433,0.478,0.443,0.404,0.485,0.464,0.516,0.539,0.518,0.146,0.273,0.522,0.4,0.374,0.472,0.143,0.582,medium,Eukaryote
32 | HSP82_YEAST,0.414,0.433,0.444,0.437,0.448,0.447,0.409,0.41,0.449,0.466,0.412,0.452,0.46,0.396,0.39,0.404,0.413,0.396,0.42,0.427,0.428,0.422,0.442,0.385,0.432,0.441,0.277,0.326,0.459,0.474,0.461,0.414,0.298,0.458,medium,Eukaryote
33 | I6TAH8_I68A0,0.337,0.329,0.354,0.348,0.364,0.361,0.303,0.324,0.018,0.014,0.212,0.268,0.264,0.347,0.317,0.308,0.328,0.374,0.377,0.365,0.004,0.011,0.097,0.002,0.302,0.127,-0.005,0.297,0.353,0.25,0.253,0.013,0.171,0.401,medium,Virus
34 | IF1_ECOLI,0.527,0.463,0.47,0.495,0.527,0.537,0.234,0.255,0.54,0.565,0.49,0.541,0.539,0.328,0.499,0.361,0.438,0.366,0.411,0.412,0.451,0.458,0.482,0.443,0.459,0.484,0.177,0.387,0.408,0.52,0.463,0.534,0.238,0.539,high,Prokaryote
35 | KCNH2_HUMAN,0.511,0.499,0.554,0.538,0.211,0.21,0.356,0.33,0.216,0.234,0.382,0.298,0.292,0.449,0.421,0.468,0.513,0.495,0.463,0.503,0.502,0.491,0.455,0.477,0.481,0.5,0.454,0.151,0.438,0.433,0.318,0.307,0.344,0.509,medium,Human
36 | KKA2_KLEPN,0.586,0.445,0.524,0.588,0.599,0.603,0.569,0.594,0.597,0.621,0.498,0.437,0.622,0.25,0.53,0.285,0.42,0.518,0.54,0.503,0.353,0.578,0.577,0.568,0.638,0.607,0.199,0.423,0.647,0.637,0.586,0.566,0.148,0.63,high,Prokaryote
37 | MK01_HUMAN,0.004,0.193,0.119,0.093,0.223,0.227,0.144,0.139,0.167,0.182,0.184,0.237,0.241,0.176,0.198,0.209,0.117,0.069,0.016,0.11,0.183,0.089,0.057,0.076,-0.067,0.077,0.209,0.163,0.234,0.176,0.222,0.033,0.106,0.202,medium,Human
38 | MSH2_HUMAN,0.281,0.35,0.388,0.346,0.383,0.39,0.39,0.395,0.38,0.4,0.364,0.367,0.376,0.352,0.399,0.296,0.31,0.264,0.257,0.313,0.326,0.317,0.32,0.326,0.301,0.343,0.204,0.339,0.375,0.351,0.324,0.349,0.208,0.383,medium,Human
39 | MTH3_HAEAE,0.665,0.422,0.496,0.638,0.696,0.704,0.671,0.68,0.692,0.708,0.658,0.709,0.718,0.371,0.612,0.311,0.454,0.6,0.662,0.593,0.488,0.663,0.711,0.684,0.727,0.707,0.314,0.644,0.672,0.697,0.652,0.581,0.291,0.714,medium,Prokaryote
40 | NCAP_I34A1,0.415,0.39,0.402,0.424,0.363,0.364,0.338,0.358,0.019,0.02,0.269,0.334,0.333,0.364,0.328,0.352,0.382,0.408,0.413,0.409,0.018,0.042,0.108,0.03,0.352,0.176,0.002,0.335,0.351,0.27,0.279,0.015,0.126,0.441,medium,Virus
41 | NRAM_I33A0,0.551,0.592,0.615,0.621,0.584,0.584,0.519,0.628,0.162,0.448,0.343,0.501,0.49,0.569,0.565,0.583,0.633,0.584,0.571,0.621,0.047,0.53,0.627,0.462,0.654,0.477,0.035,0.39,0.643,0.404,0.441,-0.076,-0.169,0.632,low,Virus
42 | NUD15_HUMAN,0.575,0.433,0.456,0.604,0.591,0.594,0.643,0.671,0.6,0.645,0.501,0.564,0.596,0.271,0.453,0.301,0.454,0.546,0.518,0.545,0.412,0.579,0.577,0.549,0.542,0.604,-0.005,0.389,0.605,0.623,0.553,0.603,0.147,0.635,high,Human
43 | P53_HUMAN,0.374,0.402,0.487,0.404,0.431,0.429,0.303,0.303,0.467,0.52,0.09,0.322,0.338,0.403,0.427,0.339,0.483,0.479,0.44,0.48,0.416,0.498,0.511,0.544,0.327,0.508,-0.106,0.296,0.472,0.469,0.411,0.5,0.215,0.411,low,Human
44 | P53_HUMAN_Kotler,0.42,0.611,0.584,0.564,0.555,0.576,0.582,0.579,0.538,0.622,0.432,0.496,0.556,0.629,0.588,0.429,0.461,0.478,0.463,0.518,0.47,0.48,0.465,0.47,0.493,0.518,0.109,0.498,0.603,0.631,0.585,0.6,0.039,0.581,low,Human
45 | P84126_THETH,0.536,0.515,0.526,0.547,0.564,0.578,0.637,0.636,0.552,0.588,0.618,0.603,0.617,0.508,0.58,0.419,0.507,0.478,0.558,0.507,0.521,0.584,0.589,0.571,0.653,0.614,0.364,0.526,0.521,0.553,0.469,0.569,0.359,0.581,medium,Prokaryote
46 | PABP_YEAST,0.64,0.689,0.692,0.688,0.654,0.648,0.662,0.655,0.662,0.678,0.541,0.541,0.55,0.663,0.617,0.638,0.665,0.666,0.692,0.688,0.638,0.698,0.7,0.676,0.666,0.704,0.474,0.569,0.675,0.703,0.635,0.688,0.261,0.684,medium,Eukaryote
47 | PA_I34A1,0.541,0.546,0.561,0.572,0.539,0.543,0.383,0.304,0.054,0.101,0.325,0.499,0.508,0.518,0.519,0.456,0.493,0.533,0.538,0.528,0.219,0.408,0.444,0.429,0.438,0.451,0.041,0.358,0.586,0.384,0.374,0.037,0.107,0.584,medium,Virus
48 | POLG_CXB3N,0.355,0.342,0.385,0.413,0.46,0.473,0.486,0.5,-0.059,0.042,0.356,0.377,0.411,0.423,0.39,0.339,0.39,0.381,0.377,0.395,0.138,0.388,0.383,0.369,0.393,0.392,-0.036,0.336,0.495,0.386,0.319,0.292,0.007,0.458,medium,Virus
49 | POLG_HCVJF,0.522,0.515,0.547,0.578,0.605,0.614,0.608,0.591,0.637,0.635,0.26,0.41,0.413,0.605,0.547,0.4,0.443,0.452,0.492,0.487,0.422,0.475,0.32,0.41,0.517,0.476,-0.039,0.196,0.637,0.587,0.485,0.178,0.182,0.56,medium,Virus
50 | PTEN_HUMAN,0.321,0.441,0.466,0.415,0.466,0.474,0.455,0.444,0.439,0.477,0.496,0.44,0.453,0.426,0.448,0.263,0.436,0.368,0.335,0.428,0.34,0.314,0.3,0.347,0.261,0.355,0.165,0.352,0.512,0.438,0.432,0.462,0.092,0.489,medium,Human
51 | Q2N0S5_9HIV1,0.406,0.52,0.502,0.501,0.495,0.502,0.504,0.514,0.509,0.537,0.403,0.352,0.393,0.493,0.379,0.518,0.403,0.39,0.337,0.444,0.517,0.401,0.394,0.436,0.354,0.459,0.003,0.437,0.516,0.478,0.483,0.47,0.291,0.513,medium,Virus
52 | Q59976_STRSQ,0.634,0.616,0.653,0.657,0.654,0.662,0.685,0.655,0.519,0.543,0.649,0.634,0.643,0.475,0.593,0.598,0.651,0.652,0.663,0.671,0.622,0.662,0.677,0.673,0.68,0.691,0.363,0.535,0.685,0.633,0.575,0.588,0.304,0.675,medium,Prokaryote
53 | R1AB_SARS2,0.216,0.35,0.399,0.401,0.6,0.605,-0.037,-0.037,-0.03,-0.04,0.266,0.212,0.227,0.577,0.561,0.214,0.259,0.274,0.289,0.272,0.242,0.236,0.203,0.21,0.224,0.241,-0.049,0.292,0.586,0.507,0.408,0.103,-0.056,0.565,medium,Virus
54 | RASH_HUMAN,0.377,0.454,0.478,0.45,0.466,0.48,0.408,0.413,0.36,0.405,0.514,0.444,0.476,0.447,0.436,0.437,0.414,0.42,0.396,0.432,0.433,0.403,0.401,0.374,0.305,0.412,0.313,0.353,0.436,0.338,0.319,0.318,0.18,0.487,high,Human
55 | REV_HV1H2,0.24,0.245,0.269,0.236,0.216,0.216,0.246,0.251,0.245,0.267,0.17,0.221,0.227,0.206,0.159,0.216,0.259,0.238,0.24,0.252,0.29,0.294,0.16,0.253,0.255,0.305,0.038,0.316,0.293,0.35,0.353,0.128,0.06,0.235,medium,Virus
56 | RL401_YEAST,0.382,0.43,0.479,0.418,0.38,0.407,0.396,0.394,0.277,0.316,0.437,0.38,0.425,0.322,0.352,0.392,0.51,0.487,0.431,0.485,0.49,0.434,0.436,0.418,0.395,0.467,0.123,0.373,0.366,0.426,0.436,0.191,0.14,0.42,medium,Eukaryote
57 | SC6A4_HUMAN,0.491,0.53,0.542,0.53,0.504,0.522,0.552,0.568,0.531,0.542,0.555,0.423,0.433,0.387,0.456,0.498,0.511,0.5,0.496,0.519,0.512,0.52,0.518,0.525,0.511,0.534,0.352,0.507,0.55,0.463,0.36,0.545,0.342,0.545,medium,Human
58 | SCN5A_HUMAN,0.069,0.095,0.086,0.086,0.153,0.158,0.152,0.135,0.217,0.135,0.126,0.162,0.162,0.13,0.131,0.106,0.143,0.151,0.127,0.129,0.124,0.106,0.093,0.107,0.167,0.124,0.13,-0.014,0.144,0.186,0.181,0.141,0.095,0.152,medium,Human
59 | SPG1_STRSG,0.279,0.243,0.204,0.289,0.247,0.272,0.142,0.245,0.237,0.192,0.278,-0.004,0.006,0.239,0.282,0.252,0.216,0.214,0.208,0.227,0.232,0.222,0.218,0.239,0.354,0.259,-0.041,0.017,0.287,0.478,0.494,0.305,0.111,0.29,low,Prokaryote
60 | SPIKE_SARS2,0.369,0.348,0.343,0.342,0.347,0.351,0.472,0.486,-0.044,-0.018,0.346,0.114,0.215,0.179,0.26,0.311,0.375,0.366,0.376,0.377,0.382,0.357,0.318,0.388,0.324,0.382,-0.031,0.36,0.236,0.318,0.352,0.024,0.174,0.408,medium,Virus
61 | SRC_HUMAN,0.348,0.502,0.502,0.492,0.496,0.507,0.258,0.37,0.561,0.585,0.484,0.465,0.465,0.526,0.508,0.439,0.413,0.421,0.373,0.439,0.437,0.46,0.438,0.422,0.33,0.442,0.532,0.44,0.549,0.578,0.574,0.469,0.431,0.516,medium,Human
62 | SUMO1_HUMAN,0.329,0.405,0.511,0.41,0.48,0.478,0.462,0.494,0.467,0.51,0.517,0.419,0.438,0.369,0.373,0.217,0.414,0.443,0.43,0.432,0.468,0.386,0.46,0.431,0.333,0.453,0.13,0.425,0.442,0.46,0.445,0.433,0.212,0.463,high,Human
63 | SYUA_HUMAN,0.181,0.124,0.176,0.159,0.131,0.139,0.1,0.152,0.242,0.233,0.22,0.119,0.132,0.103,0.111,0.138,0.203,0.155,0.141,0.167,0.086,0.151,0.167,0.105,0.136,0.141,0.131,0.212,0.222,0.205,0.189,0.234,0.005,0.156,medium,Human
64 | TADBP_HUMAN,0.121,0.142,0.185,0.121,0.08,0.08,0.058,0.041,0.051,0.048,-0.075,0.1,0.096,0.097,0.061,0.158,0.063,-0.011,-0.006,0.038,0.206,0.082,0.006,0.023,-0.014,0.047,0.291,-0.018,0.001,0.088,0.112,0.013,-0.051,0.109,low,Human
65 | TAT_HV1BR,0.203,0.379,0.263,0.237,0.319,0.3,0.303,0.328,0.343,0.342,0.203,0.255,0.268,0.293,0.201,0.379,0.364,0.396,0.398,0.397,0.394,0.269,0.136,0.246,0.223,0.279,-0.09,0.397,0.393,0.405,0.387,0.185,0.283,0.268,high,Virus
66 | TPK1_HUMAN,0.301,0.232,0.252,0.305,0.23,0.229,0.245,0.276,0.27,0.318,0.249,0.219,0.228,0.217,0.236,0.088,0.142,0.228,0.272,0.212,0.115,0.267,0.273,0.253,0.306,0.289,0.067,0.253,0.235,0.253,0.188,0.285,0.117,0.277,medium,Human
67 | TPMT_HUMAN,0.411,0.478,0.513,0.478,0.499,0.513,0.471,0.476,0.517,0.547,0.511,0.489,0.509,0.372,0.456,0.333,0.42,0.481,0.496,0.478,0.447,0.506,0.481,0.466,0.424,0.508,0.241,0.445,0.554,0.542,0.489,0.546,0.334,0.515,medium,Human
68 | TPOR_HUMAN,0.419,0.429,0.408,0.451,0.288,0.296,0.504,0.482,0.368,0.37,0.362,0.283,0.263,0.376,0.365,0.25,0.338,0.313,0.377,0.327,0.367,0.2,0.471,0.451,0.34,0.385,0.37,0.455,0.412,0.436,0.274,0.407,0.395,0.437,low,Human
69 | TRPC_SACS2,0.558,0.575,0.598,0.591,0.575,0.585,0.635,0.66,0.611,0.643,0.592,0.558,0.574,0.575,0.606,0.421,0.564,0.508,0.575,0.575,0.498,0.499,0.558,0.521,0.524,0.563,0.173,0.499,0.531,0.59,0.52,0.621,0.221,0.601,medium,Prokaryote
70 | TRPC_THEMA,0.448,0.39,0.442,0.438,0.416,0.417,0.479,0.477,0.478,0.506,0.43,0.394,0.411,0.405,0.45,0.336,0.399,0.395,0.455,0.416,0.437,0.421,0.369,0.387,0.478,0.433,0.395,0.434,0.369,0.427,0.365,0.439,0.14,0.442,medium,Prokaryote
71 | UBC9_HUMAN,0.428,0.327,0.452,0.475,0.508,0.522,0.496,0.519,0.477,0.509,0.557,0.521,0.531,0.371,0.496,0.22,0.405,0.445,0.419,0.43,0.42,0.479,0.471,0.449,0.442,0.484,-0.042,0.403,0.478,0.454,0.432,0.42,0.071,0.54,medium,Human
72 | UBE4B_MOUSE,0.262,0.419,0.351,0.39,0.463,0.47,0.344,0.364,0.447,0.471,0.383,0.454,0.468,0.412,0.417,0.122,0.324,0.337,0.3,0.314,0.444,0.376,0.395,0.376,0.289,0.408,0.078,0.089,0.435,0.437,0.396,0.361,0.091,0.472,low,Eukaryote
73 | VKOR1_HUMAN,0.438,0.4,0.416,0.475,0.425,0.432,0.446,0.462,0.437,0.461,0.44,0.4,0.416,0.369,0.392,0.168,0.186,0.334,0.37,0.314,0.25,0.428,0.424,0.398,0.414,0.43,0.122,0.391,0.434,0.407,0.338,0.428,0.134,0.487,medium,Human
74 | YAP1_HUMAN,0.217,0.402,0.326,0.359,0.449,0.455,0.07,0.053,0.28,0.285,0.189,0.464,0.46,0.428,0.321,0.18,0.176,0.158,0.167,0.176,0.299,0.226,0.257,0.207,0.15,0.239,0.329,0.289,0.334,0.425,0.451,0.333,0.139,0.439,low,Human
75 | ,0.401,0.419,0.43,0.446,0.443,0.449,0.421,0.432,0.372,0.401,0.391,0.404,0.421,0.375,0.413,0.321,0.366,0.375,0.38,0.388,0.341,0.383,0.383,0.387,0.402,0.413,0.166,0.348,0.459,0.444,0.408,0.358,0.175,0.472,,
76 |
--------------------------------------------------------------------------------
/src/run_pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import subprocess
4 | import torch
5 | import models.gvp.data, models.gvp.models
6 | import json
7 | import os
8 | import numpy as np
9 | import torch_geometric
10 | from functools import partial
11 | import esm
12 | import random
13 |
14 | print = partial(print, flush=True)
15 | import torch.multiprocessing
16 |
17 | torch.multiprocessing.set_sharing_strategy("file_system")
18 | from models.msa_transformer.model import MSATransformer
19 | from models.gvp.models import SSEmbGNN
20 | import pickle
21 | from visualization import (
22 | plot_mave_corr_vs_depth,
23 | )
24 | from helpers import (
25 | read_msa,
26 | loop_trainval,
27 | prepare_mave_val,
28 | prepare_proteingym,
29 | prepare_proteingym_bad,
30 | prepare_proteingym_default,
31 | )
32 | import run_test_mave, run_test_proteingym, run_test_proteingym_default, run_test_proteingym_bad, run_test_proteingym_good, run_test_rocklin, run_pipeline_scannet, run_test_clinvar
33 | import torch.utils.data
34 | import torch.multiprocessing as mp
35 | import torch.distributed as dist
36 | from torch.utils.data.distributed import DistributedSampler
37 | from torch.nn.parallel import DistributedDataParallel as DDP
38 |
# GPU indices exposed to this process through CUDA_VISIBLE_DEVICES.
# For a multi-GPU run list every index, e.g. [3, 4, 5, 6, 7, 8, 9].
DEVICES = [1]
os.environ["CUDA_VISIBLE_DEVICES"] = ", ".join(map(str, DEVICES))
42 |
def setup(rank, world_size):
    """Join the NCCL process group as `rank` out of `world_size` processes.

    The rendezvous address is fixed to localhost:1111 via the
    MASTER_ADDR / MASTER_PORT environment variables.
    """
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "1111"})
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
47 |
def prepare(
    dataset,
    rank,
    world_size,
    batch_size=1,
    pin_memory=False,
    num_workers=0,
    train=False,
):
    """Build a `torch_geometric` DataLoader for `dataset`.

    :param dataset: dataset to iterate over
    :param rank: index of the current process, used to shard training data
    :param world_size: total number of processes sharing the training data
    :param batch_size: number of items per batch
    :param pin_memory: forwarded to the DataLoader
    :param num_workers: forwarded to the DataLoader
    :param train: if true, the loader draws from a shuffling
        `DistributedSampler` (with `drop_last=True`) so each rank sees its
        own shard; otherwise the loader walks the full dataset in order
    :return: the configured DataLoader
    """
    # Only the training loader is sharded across ranks, so the sampler is
    # created only when it is actually going to be used.
    sampler = None
    if train:
        sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=True
        )

    # `shuffle` must stay False on the loader: shuffling, when wanted, is
    # the sampler's job.
    return torch_geometric.loader.DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=pin_memory,
        num_workers=num_workers,
        drop_last=False,
        shuffle=False,
        sampler=sampler,
    )
81 |
82 |
def cleanup():
    """Tear down the process group created by `dist.init_process_group`."""
    dist.destroy_process_group()
85 |
86 |
87 | def main(rank, world_size):
88 | # Setup the process groups
89 | torch.cuda.set_device(rank)
90 | torch.cuda.empty_cache()
91 | setup(rank, world_size)
92 |
93 | # Set fixed seed
94 | seed = 1
95 | torch.manual_seed(seed)
96 | np.random.seed(seed)
97 | random.seed(seed)
98 | torch.cuda.manual_seed_all(seed)
99 | torch.backends.cudnn.deterministic = True
100 | torch.backends.cudnn.benchmark = False
101 | torch.backends.cudnn.enabled = False
102 |
103 | # Print name of run
104 | run_name = "final_cath"
105 |
106 | # Set initial parameters
107 | EPOCHS = 200
108 | EPOCH_FINETUNE_MSA = 100
109 | VAL_INTERVAL = 10
110 | BATCH_PROTS = 128 // len(DEVICES)
111 | LR_LIST = [1e-3, 1e-6]
112 | PATIENCE = 1
113 |
114 | ## Load CATH data
115 | print("Preparing CATH data")
116 | pdb_dir_cath = "../data/train/cath"
117 | subprocess.run([f"bash {pdb_dir_cath}/getCATH.sh"], shell=True)
118 | subprocess.run([f"mv chain_set.jsonl {pdb_dir_cath}/chain_set.json"], shell=True)
119 | subprocess.run([f"mv chain_set_splits.json {pdb_dir_cath}/chain_set_splits.json"], shell=True)
120 | cath = models.gvp.data.CATHDataset(
121 | path=f"{pdb_dir_cath}/chain_set.json",
122 | splits_path=f"{pdb_dir_cath}/chain_set_splits.json",
123 | )
124 |
125 | # Compute MSAs
126 | # TO DO: Add code example to extract sequences from CATH data set
127 | # to file: f"{pdb_dir_cath}/seqs.fasta"
128 | sys.path += [":/projects/prism/people/skr526/mmseqs/bin"]
129 | subprocess.run(
130 | [
131 | "colabfold_search",
132 | f"{pdb_dir_cath}/seqs.fasta",
133 | "/projects/prism/people/skr526/databases",
134 | "../data/train/cath/msa/",
135 | ]
136 | )
137 | subprocess.run(["python", "merge_and_sort_msas.py", "../data/train/cath/msa"])
138 |
139 | # Add MSAs
140 | for i, entry in enumerate(cath.total):
141 | print(f"Adding CATH MSAs: {i+1}/{len(cath.total)}")
142 | entry["msa"] = read_msa(f"{pdb_dir_cath}/msa/{entry['name']}.a3m")
143 |
144 | # Checkpoint - save and load
145 | with open(f"{pdb_dir_cath}/data_with_msas.pkl", "wb") as fp: # Pickling
146 | pickle.dump(cath.total, fp)
147 |
148 | with open(f"{pdb_dir_cath}/data_with_msas.pkl", "rb") as fp: # Unpickling
149 | cath.total = pickle.load(fp)
150 |
151 | ## Filter data
152 | # Only keep entries where MSA and structucture sequence lengths match
153 | data = [
154 | entry for entry in cath.total if len(entry["seq"]) == len(entry["msa"][0][1])
155 | ]
156 |
157 | # Filter: Only keep entries without X in sequence
158 | data = [entry for entry in cath.total if "X" not in entry["seq"]]
159 |
160 | # Save all training and validation sequences in a fasta file to check homology
161 | cath.split()
162 | with open(f"../data/test/mave_val/structure/coords.json") as json_file:
163 | data_mave_val = json.load(json_file)
164 |
165 | with open(f"../data/test/proteingym/structure/coords.json") as json_file:
166 | data_proteingym = json.load(json_file)
167 |
168 | fh = open(f"../data/train/cath/seqs_cath.fasta", "w")
169 | for entry in cath.train:
170 | fh.write(f">{entry['name']}\n")
171 | fh.write(f"{entry['seq']}\n")
172 |
173 | for entry in cath.val:
174 | fh.write(f">{entry['name']}\n")
175 | fh.write(f"{entry['seq']}\n")
176 |
177 | for entry in data_mave_val:
178 | fh.write(f">{entry['name']}\n")
179 | fh.write(f"{entry['seq']}\n")
180 |
181 | for entry in data_proteingym:
182 | fh.write(f">{entry['name']}\n")
183 | fh.write(f"{entry['seq']}\n")
184 | fh.close()
185 |
186 | # Compute clusters of 95% sequence similarities between all training, validation and test proteins
187 | subprocess.run(
188 | [
189 | "cd-hit",
190 | "-i",
191 | "../data/train/cath/seqs_cath.fasta",
192 | "-o",
193 | "../data/train/cath/seqs_cath_homology.fasta",
194 | "-c",
195 | "0.95",
196 | "-n",
197 | "5",
198 | "-d",
199 | "999",
200 | ]
201 | )
202 |
203 | # Remove proteins from training data that has high sequence similarity with validation or test proteins
204 | val_prot_names = [entry["name"] for entry in cath.val]
205 | val_mave_prot_names = [entry["name"] for entry in data_mave_val]
206 | test_prot_names = [entry["name"] for entry in data_proteingym]
207 | valtest_prot_names = val_prot_names + val_mave_prot_names + test_prot_names
208 |
209 | fh = open("../data/train/cath/seqs_cath_homology.fasta.clstr", "r")
210 | cluster_dict = {}
211 | remove_list = []
212 | for line in fh.readlines():
213 | if line.startswith(">Cluster"):
214 | cluster_name = line
215 | cluster_dict[cluster_name] = []
216 | else:
217 | cluster_dict[cluster_name].append(line.split(">")[1].split("...")[0])
218 |
219 | for cluster_name, prot_names in cluster_dict.items():
220 | if len(prot_names) > 1 and any(
221 | valtest_prot_name in prot_names for valtest_prot_name in valtest_prot_names
222 | ):
223 | remove_list += prot_names
224 | remove_list = [
225 | prot_name for prot_name in remove_list if prot_name not in valtest_prot_names
226 | ]
227 | cath.train = [entry for entry in cath.train if entry["name"] not in remove_list]
228 |
229 | # Checkpoint - save and load
230 | with open(
231 | f"{pdb_dir_cath}/data_with_msas_filtered_train.pkl", "wb"
232 | ) as fp: # Pickling
233 | pickle.dump(cath.train, fp)
234 | with open(
235 | f"{pdb_dir_cath}/data_with_msas_filtered_val.pkl", "wb"
236 | ) as fp: # Pickling
237 | pickle.dump(cath.val, fp)
238 |
239 | # Prepare MAVE validation and ProteinGym test data
240 | prepare_mave_val()
241 | prepare_proteingym()
242 | prepare_proteingym_bad()
243 | prepare_proteingym_default()
244 |
245 | with open(
246 | f"{pdb_dir_cath}/data_with_msas_filtered_train.pkl", "rb"
247 | ) as fp: # Unpickling
248 | cath.train = pickle.load(fp)
249 | with open(
250 | f"{pdb_dir_cath}/data_with_msas_filtered_val.pkl", "rb"
251 | ) as fp: # Unpickling
252 | cath.val = pickle.load(fp)
253 |
254 | # Convert to graph data sets
255 | trainset = models.gvp.data.ProteinGraphData(cath.train)
256 | valset = models.gvp.data.ProteinGraphData(cath.val)
257 |
258 | ## Load and initialize MSA Transformer
259 | model_msa_pre, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
260 | torch.save(
261 | model_msa_pre.state_dict(),
262 | f"../output/train/models/msa_transformer/pretrained.pt",
263 | )
264 | msa_batch_converter = msa_alphabet.get_batch_converter()
265 | model_msa = MSATransformer()
266 | model_msa.load_state_dict(
267 | torch.load(f"../output/train/models/msa_transformer/pretrained.pt")
268 | )
269 | model_msa.to(rank)
270 |
271 | # Freeze MSA Transformer
272 | for param in model_msa.parameters():
273 | param.requires_grad = False
274 |
275 | # Load and initialize GVP
276 | node_dim = (256, 64)
277 | edge_dim = (32, 1)
278 | model_gvp = SSEmbGNN((6, 3), node_dim, (32, 1), edge_dim)
279 | model_gvp.to(rank)
280 | model_gvp = DDP(
281 | model_gvp,
282 | device_ids=[rank],
283 | output_device=rank,
284 | find_unused_parameters=True,
285 | static_graph=False,
286 | )
287 |
288 | # Initialize training modules
289 | train_loader = prepare(trainset, rank, world_size, train=True)
290 | val_loader = prepare(valset, rank, world_size)
291 | optimizer = torch.optim.Adam(model_gvp.parameters(), lr=LR_LIST[0])
292 | scaler = torch.cuda.amp.GradScaler()
293 | best_epoch, best_corr_mave = None, 0
294 | patience_counter = 0
295 |
296 | # Initialize lists for monitoring loss
297 | epoch_list = []
298 | loss_train_list, loss_val_list = [], []
299 | acc_train_list, acc_val_list = [], []
300 | corr_mave_list, acc_mave_list = [], []
301 |
302 | for epoch in range(EPOCHS):
303 | # Check if we should fine-tune MSA Transformer row attention
304 | if epoch == EPOCH_FINETUNE_MSA:
305 | for param in model_msa.named_parameters():
306 | if "row_self_attention" in param[0]:
307 | param[1].requires_grad = True
308 | model_msa = DDP(
309 | model_msa,
310 | device_ids=[rank],
311 | output_device=rank,
312 | find_unused_parameters=True,
313 | static_graph=False,
314 | )
315 | optimizer.add_param_group(
316 | {"params": model_msa.parameters(), "lr": LR_LIST[1]}
317 | )
318 | BATCH_PROTS = 2048 // len(DEVICES)
319 |
320 | # If we are using DistributedSampler, we need to tell it which epoch this is
321 | train_loader.sampler.set_epoch(epoch)
322 |
323 | # Train loop
324 | model_msa.train()
325 | model_gvp.train()
326 |
327 | loss_train, acc_train = loop_trainval(
328 | model_msa,
329 | model_gvp,
330 | msa_batch_converter,
331 | train_loader,
332 | BATCH_PROTS,
333 | epoch,
334 | rank,
335 | EPOCH_FINETUNE_MSA,
336 | optimizer=optimizer,
337 | scaler=scaler,
338 | )
339 | ## Gather and save training metrics for epoch
340 | # OBS: This cannot be placed within validation loop or we get hangs
341 | loss_train = loss_train.type(torch.float32)
342 | loss_train_all_gather = [torch.zeros(1, device=rank) for _ in range(world_size)]
343 | dist.all_gather(loss_train_all_gather, loss_train)
344 |
345 | # Validation loop
346 | if rank == 0:
347 | if epoch % VAL_INTERVAL == 0:
348 | # Save model
349 | path_msa = f"../output/train/models/msa_transformer/{run_name}_msa_transformer_{epoch}.pt"
350 | path_gvp = f"../output/train/models/gvp/{run_name}_gvp_{epoch}.pt"
351 | path_optimizer = (
352 | f"../output/train/models/optimizer/{run_name}_adam_{epoch}.pt"
353 | )
354 | torch.save(model_msa.state_dict(), path_msa)
355 | torch.save(model_gvp.state_dict(), path_gvp)
356 | torch.save(optimizer.state_dict(), path_optimizer)
357 |
358 | with torch.no_grad():
359 | # Do training validation
360 | model_msa.eval()
361 | model_gvp.eval()
362 |
363 | loss_val, acc_val = loop_trainval(
364 | model_msa,
365 | model_gvp,
366 | msa_batch_converter,
367 | val_loader,
368 | BATCH_PROTS,
369 | epoch,
370 | rank,
371 | EPOCH_FINETUNE_MSA,
372 | )
373 |
374 | if epoch >= EPOCH_FINETUNE_MSA:
375 | # Do validation on MAVE set
376 | corr_mave, acc_mave = run_test_mave.test(
377 | run_name,
378 | epoch,
379 | device=rank,
380 | )
381 |
382 | # Update patience
383 | if corr_mave > best_corr_mave:
384 | best_epoch = epoch
385 | patience_counter = 0
386 | else:
387 | patience_counter += 1
388 | else:
389 | corr_mave, acc_mave = 0.0, 0.0
390 |
391 | # Save validation results
392 | epoch_list.append(epoch)
393 | loss_train_list.append(
394 | torch.mean(torch.stack(loss_train_all_gather)).to("cpu").item()
395 | )
396 | loss_val_list.append(loss_val.to("cpu").item())
397 | acc_val_list.append(acc_val)
398 | corr_mave_list.append(corr_mave)
399 | acc_mave_list.append(corr_mave)
400 |
401 | metrics = {
402 | "epoch": epoch_list,
403 | "loss_train": loss_train_list,
404 | "loss_val": loss_val_list,
405 | "acc_val": acc_val_list,
406 | "corr_mave": corr_mave_list,
407 | "acc_mave": acc_mave_list,
408 | }
409 | with open(f"../output/train/metrics/{run_name}_metrics", "wb") as f:
410 | pickle.dump(metrics, f)
411 |
412 | if patience_counter == PATIENCE:
413 | break
414 |
415 | # Create barrier after each epoch
416 | dist.barrier()
417 |
418 | # Clean up
419 | cleanup()
420 |
421 | # MAVE val set
422 | print("Starting MAVE val predictions")
423 | run_test_mave.test(run_name, best_epoch, get_only_ssemb_metrics=False, device=rank)
424 | plot_mave_corr_vs_depth()
425 | print("Finished MAVE val predictions")
426 |
427 | # ProteinGym test set
428 | print("Starting ProteinGym test")
429 | run_test_proteingym.test(run_name, best_epoch, device=rank)
430 | print("Finished ProteinGym test")
431 |
432 | # Rocklin test set
433 | print("Starting Rocklin test")
434 | run_test_rocklin.test(run_name, best_epoch, num_ensemble=5, device=rank)
435 | print("Finished Rocklin test")
436 |
437 | # ScanNet test set
438 | print("Starting ScanNet test")
439 | run_pipeline_scannet.run(run_name, best_epoch, device=rank)
440 | print("Finished ScanNet test")
441 |
442 | # ClinVar test set
443 | print("Starting ClinVar test")
444 | run_test_clinvar.test(run_name, best_epoch, num_ensemble=5, device=rank)
445 | print("Finished ClinVar test")
446 |
if __name__ == "__main__":
    # Launch one worker process per entry in DEVICES; every worker runs
    # `main` and is handed the total number of workers through `args`.
    # NOTE(review): `mp` is presumably torch.multiprocessing, which would
    # pass the process index as the first argument of `main` - confirm
    # against the imports at the top of this script.
    world_size = len(DEVICES)
    mp.spawn(main, nprocs=world_size, args=(world_size,))
454 |
--------------------------------------------------------------------------------
/src/run_pipeline_scannet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import subprocess
4 | import re
5 | import torch
6 | import models.gvp.data, models.gvp.models
7 | import json
8 | import numpy as np
9 | import torch_geometric
10 | import esm
11 | import pandas as pd
12 | import random
13 | import torch.multiprocessing
14 |
15 | torch.multiprocessing.set_sharing_strategy("file_system")
16 | from models.msa_transformer.model import MSATransformer
17 | from models.gvp.models import SSEmbGNN
18 | from models.scannet.model import TransformerModel
19 | import pickle
20 | from helpers import (
21 | read_msa,
22 | loop_getemb,
23 | scannet_collate_fn,
24 | loop_scannet_trainval,
25 | loop_scannet_test,
26 | )
27 | from visualization import plot_precision_recall
28 | import pdb_parser_scripts.parse_pdbs as parse_pdbs
29 | import torch.utils.data
30 | from sklearn.metrics import auc
31 | from collections import OrderedDict
32 | from Bio.PDB import PDBList
33 | from Bio.PDB import PDBParser
34 | from Bio.PDB.PDBIO import PDBIO
35 | import io
36 | from contextlib import redirect_stdout
37 | import pymol2
38 | import time
39 |
40 |
# Location of the ScanNet PPBS dataset files and the label files fetched from it.
_SCANNET_PPBS_URL = (
    "https://raw.githubusercontent.com/jertubiana/ScanNet/main/datasets/PPBS"
)
_SCANNET_LABEL_FILES = (
    "labels_train.txt",
    "labels_validation_70.txt",
    "labels_validation_homology.txt",
    "labels_validation_topology.txt",
    "labels_validation_none.txt",
    "labels_test_70.txt",
    "labels_test_homology.txt",
    "labels_test_topology.txt",
    "labels_test_none.txt",
)


def _download_scannet_raw_data():
    """Fetch the ScanNet table and all label files with ``wget -c``."""
    subprocess.run(
        [
            "wget",
            "-c",
            f"{_SCANNET_PPBS_URL}/table.csv",
            "-P",
            "../data/test/scannet/",
        ]
    )
    for fname in _SCANNET_LABEL_FILES:
        subprocess.run(
            [
                "wget",
                "-c",
                f"{_SCANNET_PPBS_URL}/{fname}",
                "-P",
                "../data/test/scannet/labels",
            ]
        )


def _download_pdb_structures(pdb_ids):
    """Download every structure in *pdb_ids* into ``../data/test/scannet/raw``.

    A structure is first requested in PDB format; if Biopython reports that
    it does not exist, the mmCIF file is fetched and converted to PDB format
    with PyMOL. Structures whose ``.pdb`` file is already present are skipped.
    """
    pdbl = PDBList()
    for i, pdbid in enumerate(pdb_ids):
        print(f"{i+1}/{len(pdb_ids)}")
        print(pdbid)
        if os.path.exists(f"../data/test/scannet/raw/{pdbid}.pdb"):
            print("PDB file already downloaded")
            continue

        # Use a fresh buffer per structure: with a shared buffer a single
        # failed download would make every later structure look missing,
        # because the failure message would stay in the captured output.
        captured = io.StringIO()
        with redirect_stdout(captured):
            pdbl.retrieve_pdb_file(
                pdbid, pdir="../data/test/scannet/raw", file_format="pdb"
            )

        if "Desired structure doesn't exists" in captured.getvalue():
            try:
                pdbl.retrieve_pdb_file(
                    pdbid, pdir="../data/test/scannet/raw", file_format="mmCif"
                )
                with pymol2.PyMOL() as pymol:
                    pymol.cmd.load(
                        f"../data/test/scannet/raw/{pdbid}.cif", "my_protein"
                    )
                    pymol.cmd.save(
                        f"../data/test/scannet/raw/{pdbid}.pdb",
                        selection="my_protein",
                    )
            except Exception:
                # Best effort: report and continue with the next structure.
                print("Protein does not exist as either PDB or mmCIF file")
        else:
            subprocess.run(
                [
                    "mv",
                    f"../data/test/scannet/raw/pdb{pdbid}.ent",
                    f"../data/test/scannet/raw/{pdbid}.pdb",
                ]
            )
        # Pause between downloads.
        time.sleep(1)
    subprocess.run(["rm", "-r", "obsolete"])


def _strip_ddp_prefix(state_dict):
    """Return a copy of *state_dict* without a leading ``module.`` on its keys.

    Checkpoints saved from a DistributedDataParallel-wrapped model carry this
    prefix; keys without it are copied unchanged.
    """
    prefix = "module."
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


def run(run_name, epoch, device=None):
    """Run the ScanNet downstream pipeline.

    Steps, in order: download the ScanNet table, labels and PDB structures;
    split structures into chains; clean and parse them; compute MSAs; embed
    every protein with the MSA Transformer and GVP checkpoints of
    ``run_name``/``epoch``; train a Transformer model on the training split;
    pick the epoch with the lowest validation loss; write the test AUCPR
    values to ``../output/scannet/{run_name}_test_results.txt`` and plot the
    precision-recall curves.

    Args:
        run_name: Name used to locate the upstream checkpoints and to name
            all files written by this function.
        epoch: Epoch of the upstream MSA Transformer / GVP checkpoints to load.
        device: Device the models and label tensors are moved to.

    Raises:
        RuntimeError: If no epoch improved on the initial validation loss, so
            there is no checkpoint to test.
    """
    # Download raw data
    _download_scannet_raw_data()

    # Download PDBs
    parser = PDBParser()
    pdb_io = PDBIO()
    df = pd.read_csv("../data/test/scannet/table.csv")
    pdb_list_all = df["PDB ID"].unique()
    pdb_list = [x[:4] for x in pdb_list_all]
    chain_list = [x[-1] for x in pdb_list_all]
    _download_pdb_structures(pdb_list)

    # Split into chains
    with open("../data/test/scannet/scannet_download.log", "w") as log_file:
        for pdbid, chain_id in zip(pdb_list, chain_list):
            try:
                structure = parser.get_structure(
                    pdbid, f"../data/test/scannet/raw/{pdbid}.pdb"
                )
                for chain in structure.get_chains():
                    if chain.get_id() == chain_id:
                        pdb_io.set_structure(chain)
                        pdb_io.save(
                            f"../data/test/scannet/structure/raw/{pdbid}_{chain_id}.pdb"
                        )
            except Exception:
                log_file.write(f"{pdbid}_{chain_id} not available\n")

    # Pre-process PDBs
    pdb_dir = "../data/test/scannet/structure"
    subprocess.call(
        [
            "pdb_parser_scripts/clean_pdbs.sh",
            str(pdb_dir),
        ]
    )
    parse_pdbs.parse(pdb_dir)

    # Compute MSAs
    # The mmseqs directory has to be on the PATH of the child process; adding
    # it to sys.path (the Python import path) has no effect on subprocesses.
    os.environ["PATH"] = (
        os.environ.get("PATH", "")
        + os.pathsep
        + "/projects/prism/people/skr526/mmseqs/bin"
    )
    subprocess.run(
        [
            "colabfold_search",
            f"{pdb_dir}/seqs.fasta",
            "/projects/prism/people/skr526/databases",
            "../data/test/scannet/msa/",
        ]
    )
    subprocess.run(["python", "merge_and_sort_msas.py", "../data/test/scannet/msa"])

    # Load structure data
    with open("../data/test/scannet/structure/coords.json") as json_file:
        data_raw = json.load(json_file)

    # Only keep entries with sequence lengths <= 1024
    data = [
        entry for entry in data_raw if len(entry["seq"]) <= 1024 - 1
    ]  # Consider added token

    # Add MSAs to data: always keep the first sequence and add a random
    # subsample of at most 255 of the remaining ones, in their original order
    for i, entry in enumerate(data):
        print(f"{i+1}/{len(data)}")
        msa = read_msa(f'../data/test/scannet/msa/{entry["name"]}.a3m')
        k = min(len(msa) - 1, 256 - 1)
        picked = sorted(random.sample(range(1, len(msa)), k))
        entry["msa"] = [msa[0]] + [msa[idx] for idx in picked]

    # Checkpoint - save and load
    with open("../output/scannet/data_scannet.pickle", "wb") as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

    with open("../output/scannet/data_scannet.pickle", "rb") as handle:
        data = pickle.load(handle)

    # Load MSA Transformer
    _, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    model_msa = MSATransformer(msa_alphabet)
    model_msa = model_msa.to(device)
    msa_batch_converter = msa_alphabet.get_batch_converter()

    state_dict_msa = torch.load(
        f"../output/train/models/msa_transformer/{run_name}_msa_transformer_{epoch}.pt"
    )
    model_msa.load_state_dict(_strip_ddp_prefix(state_dict_msa))

    # Load GVP
    node_dim = (256, 64)
    edge_dim = (32, 1)
    model_gvp = SSEmbGNN((6, 3), node_dim, (32, 1), edge_dim)
    model_gvp = model_gvp.to(device)

    state_dict_gvp = torch.load(f"../output/train/models/gvp/{run_name}_gvp_{epoch}.pt")
    model_gvp.load_state_dict(_strip_ddp_prefix(state_dict_gvp))

    # Convert to graph data sets
    allset = models.gvp.data.ProteinGraphData(data)

    # Init data loader
    data_loader = torch_geometric.loader.DataLoader(allset, batch_size=1, shuffle=False)

    # Add frozen embeddings to data
    with torch.no_grad():
        emb_dict = loop_getemb(
            model_msa,
            model_gvp,
            msa_batch_converter,
            data_loader,
            device=device,
        )
    for entry in data:
        entry["emb"] = emb_dict[entry["name"]]

    # Checkpoint - save and load
    with open("../output/scannet/data_scannet_emb.pickle", "wb") as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

    with open("../output/scannet/data_scannet_emb.pickle", "rb") as handle:
        data = pickle.load(handle)

    # Add sample weights to data
    # (the table is re-read so this section does not depend on earlier ones)
    df = pd.read_csv("../data/test/scannet/table.csv")
    known_ids = set(df["PDB ID"].unique())
    for entry in data:
        if entry["name"] in known_ids:
            entry["prot_weight"] = df[df["PDB ID"] == entry["name"]][
                "Sample weight"
            ].item()

    # Concat label files
    filenames = [
        f"../data/test/scannet/labels/{fname}" for fname in _SCANNET_LABEL_FILES
    ]
    with open("../data/test/scannet/labels/labels_all.txt", "w") as outfile:
        for fname in filenames:
            with open(fname) as infile:
                for line in infile:
                    # Guard against a file without a trailing newline, which
                    # would glue the next file's header onto its last line
                    outfile.write(line if line.endswith("\n") else line + "\n")

    # Read labels
    label_dict = {}
    prot = None
    with open("../data/test/scannet/labels/labels_all.txt", "r") as label_file:
        for line in label_file:
            stripped = line.strip()
            if not stripped:
                continue
            if stripped.startswith(">"):
                prot = stripped[1:][:4] + "_" + stripped[1:][-1]
                label_dict[prot] = []
            else:
                label_dict[prot].append(
                    stripped.split(" ")
                )  # [chain_id, pos, aa, label]

    # Add labels to data
    n_no_label = 0
    n_seq_mismatch = 0
    data_clean = []
    for entry in data:
        labels = label_dict.get(entry["name"])
        if labels is None:
            entry["label"] = None
            n_no_label += 1
        else:
            label_seq = "".join([x[2] for x in labels])
            if label_seq == entry["seq"]:
                entry["label"] = torch.tensor(
                    [int(x[3]) for x in labels], device=device
                )
            else:
                entry["label"] = None
                n_seq_mismatch += 1

        # If no errors; keep the entry
        if entry["label"] is not None:
            data_clean.append(entry)

    print(
        f"Number of PDBs where we have structure data but no label data: {n_no_label}/{len(data)}"
    )
    print(
        f"Number of PDBs where the label seq and the structure seq doesn't match: {n_seq_mismatch}/{len(data)}"
    )
    print(f"Number of PDBs in cleaned data set: {len(data_clean)}")

    # Set parameters for downstream model learning
    EPOCHS = 40
    VAL_INTERVAL = 1
    BATCH_PROTS = 10
    LR = 1e-4

    # Split into train/val/test
    # NOTE: the backslash in "70\%" is part of the value compared against the
    # "Set" column, hence the raw strings.
    df = pd.read_csv("../data/test/scannet/table.csv")
    train_ids = set(df[df["Set"] == "Train"]["PDB ID"])
    val_ids = set(df[df["Set"] == r"Validation (70\%)"]["PDB ID"])
    data_train = [x for x in data_clean if x["name"] in train_ids]
    data_val = [x for x in data_clean if x["name"] in val_ids]

    # Load Transformer model
    model_transformer = TransformerModel(
        ntoken=1, nhead=4, d_hid=256, nlayers=4, dropout=0.1
    )
    model_transformer = model_transformer.to(device)

    # Init optimizer
    optimizer = torch.optim.Adam(model_transformer.parameters(), lr=LR)

    # Initialize data loader
    train_loader = torch.utils.data.DataLoader(
        data_train, batch_size=1, collate_fn=scannet_collate_fn, shuffle=True
    )  # OBS: Use grad accumulation if bs > 1
    val_loader = torch.utils.data.DataLoader(
        data_val, batch_size=1, collate_fn=scannet_collate_fn, shuffle=False
    )

    # Initialize lists for monitoring loss
    epoch_list = []
    loss_train_list, loss_val_list = [], []
    acc_val_list = []
    best_epoch, best_loss_val = None, np.inf

    # Begin training and validation loop
    # (`ep` rather than `epoch`, so the upstream checkpoint epoch passed in
    # by the caller is not overwritten)
    for ep in range(EPOCHS):
        # Train loop
        model_transformer.train()
        loss_train, _ = loop_scannet_trainval(
            model_transformer,
            train_loader,
            device=device,
            optimizer=optimizer,
            batch_prots=BATCH_PROTS,
        )

        # Save model
        path_transformer = (
            f"../output/scannet/models/transformer/{run_name}_transformer_{ep}.pt"
        )
        path_optimizer = f"../output/scannet/models/optimizer/{run_name}_adam_{ep}.pt"
        torch.save(model_transformer.state_dict(), path_transformer)
        torch.save(optimizer.state_dict(), path_optimizer)

        # Compute validation
        if ep % VAL_INTERVAL == 0:
            with torch.no_grad():
                model_transformer.eval()
                loss_val, acc_val = loop_scannet_trainval(
                    model_transformer,
                    val_loader,
                    device=device,
                )

            if loss_val < best_loss_val:
                best_epoch, best_loss_val = ep, loss_val

            # Save validation results
            epoch_list.append(ep)
            loss_train_list.append(loss_train.to("cpu").item())
            loss_val_list.append(loss_val.to("cpu").item())
            acc_val_list.append(acc_val)

            metrics = {
                "epoch": epoch_list,
                "loss_train": loss_train_list,
                "loss_val": loss_val_list,
                "acc_val": acc_val_list,
            }
            with open(f"../output/scannet/metrics/{run_name}_metrics", "wb") as f:
                pickle.dump(metrics, f)

    # Test
    # best_epoch = 26 # OBS: Uncomment this line to use weights from paper
    if best_epoch is None:
        raise RuntimeError(
            "No epoch improved on the initial validation loss; "
            "there is no best checkpoint to test"
        )
    print(f"Testing model! Using best model from epoch: {best_epoch}")
    model_transformer.load_state_dict(
        torch.load(
            f"../output/scannet/models/transformer/{run_name}_transformer_{best_epoch}.pt"
        )
    )

    test_sets = [
        [r"Test (70\%)"],
        ["Test (Homology)"],
        ["Test (Topology)"],
        ["Test (None)"],
        [r"Test (70\%)", "Test (Homology)", "Test (Topology)", "Test (None)"],
    ]

    precision_list = []
    recall_list = []
    with open(f"../output/scannet/{run_name}_test_results.txt", "w") as outfile:
        for test_set in test_sets:
            print(f"Computing predictions for: {' & '.join(test_set)}")
            test_ids = set(df[df["Set"].isin(test_set)]["PDB ID"])
            data_test = [x for x in data_clean if x["name"] in test_ids]
            test_loader = torch.utils.data.DataLoader(
                data_test, batch_size=1, collate_fn=scannet_collate_fn, shuffle=False
            )

            with torch.no_grad():
                model_transformer.eval()
                precision, recall = loop_scannet_test(
                    model_transformer,
                    test_loader,
                    device=device,
                )
            precision_list.append(precision)
            recall_list.append(recall)
            auc_precision_recall = auc(recall, precision)
            outfile.write(
                f"Test AUCPR for {' & '.join(test_set)} is: {auc_precision_recall}\n"
            )

    # Plot test results
    plot_precision_recall(recall_list, precision_list)
521 |
--------------------------------------------------------------------------------
/src/models/msa_transformer/multihead_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import math
7 | from typing import Dict, Optional, Tuple
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import Tensor, nn
12 | from torch.nn import Parameter
13 | import uuid
14 |
15 |
def utils_softmax(x, dim: int, onnx_trace: bool = False):
    """Return the softmax of *x* along *dim*, computed in float32.

    With ``onnx_trace`` set, the input is cast with ``x.float()`` before the
    softmax; otherwise the float32 computation is requested through the
    ``dtype`` argument of ``F.softmax``.
    """
    if not onnx_trace:
        return F.softmax(x, dim=dim, dtype=torch.float32)
    return F.softmax(x.float(), dim=dim)
21 |
22 |
class FairseqIncrementalState(object):
    """Mixin that stores per-instance entries in a shared incremental-state dict.

    Every instance gets a random UUID; keys written to the shared dict are
    prefixed with it, so several instances can use the same dict without
    overwriting each other's entries.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """Assign a fresh random identifier to this instance."""
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return *key* prefixed with this instance's identifier."""
        return f"{self._incremental_state_id}.{key}"

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module.

        Returns ``None`` when no state dict is given or when this instance
        has no entry for *key*.
        """
        if incremental_state is None:
            return None
        scoped_key = self._get_full_incremental_state_key(key)
        if scoped_key in incremental_state:
            return incremental_state[scoped_key]
        return None

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module.

        Stores *value* under this instance's scoped key (when a state dict is
        given) and returns the same ``incremental_state`` object.
        """
        if incremental_state is None:
            return incremental_state
        incremental_state[self._get_full_incremental_state_key(key)] = value
        return incremental_state
56 |
57 |
def with_incremental_state(cls):
    """Class decorator that makes ``FairseqIncrementalState`` the first base of *cls*.

    The remaining bases keep their relative order; an existing occurrence of
    ``FairseqIncrementalState`` is dropped so it is not listed twice.
    The class is modified in place and returned.
    """
    other_bases = [
        base for base in cls.__bases__ if base != FairseqIncrementalState
    ]
    cls.__bases__ = (FairseqIncrementalState, *other_bases)
    return cls
63 |
64 |
65 | @with_incremental_state
66 | class MultiheadAttention(nn.Module):
67 | """Multi-headed attention.
68 |
69 | See "Attention Is All You Need" for more details.
70 | """
71 |
def __init__(
    self,
    embed_dim,
    num_heads,
    kdim=None,
    vdim=None,
    dropout=0.0,
    bias=True,
    add_bias_kv: bool = False,
    add_zero_attn: bool = False,
    self_attention: bool = False,
    encoder_decoder_attention: bool = False,
    use_rotary_embeddings: bool = False,
):
    """Build the projection layers and configuration of the attention module.

    Args:
        embed_dim: Size of the query input and of the attention output.
        num_heads: Number of heads; must divide ``embed_dim``.
        kdim: Input size of the key projection (defaults to ``embed_dim``).
        vdim: Input size of the value projection (defaults to ``embed_dim``).
        dropout: Stored as ``self.dropout``.
        bias: Whether the four linear projections have a bias term.
        add_bias_kv: Create learnable ``bias_k`` / ``bias_v`` parameters.
        add_zero_attn: Stored as ``self.add_zero_attn``.
        self_attention: Flag stored on the module; requires ``kdim`` and
            ``vdim`` to equal ``embed_dim``.
        encoder_decoder_attention: Flag stored on the module.
        use_rotary_embeddings: Create a ``RotaryEmbedding`` for the heads.
    """
    super().__init__()
    self.embed_dim = embed_dim
    self.kdim = embed_dim if kdim is None else kdim
    self.vdim = embed_dim if vdim is None else vdim
    self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

    self.num_heads = num_heads
    self.dropout = dropout
    self.head_dim = embed_dim // num_heads
    assert (
        self.head_dim * num_heads == self.embed_dim
    ), "embed_dim must be divisible by num_heads"
    # Scale factor of 1 / sqrt(head_dim)
    self.scaling = self.head_dim**-0.5

    self.self_attention = self_attention
    self.encoder_decoder_attention = encoder_decoder_attention

    assert (
        not self.self_attention or self.qkv_same_dim
    ), "Self-attention requires query, key and value to be of the same size"

    # Input projections (key, value, query) followed by the output projection
    self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
    self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
    self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    if add_bias_kv:
        # Allocated uninitialized here; filled in by reset_parameters()
        self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
        self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
    else:
        self.bias_k = self.bias_v = None

    self.add_zero_attn = add_zero_attn

    self.reset_parameters()

    self.onnx_trace = False
    # NOTE(review): RotaryEmbedding is not among the imports visible at the
    # top of this file - confirm it is in scope before enabling this option.
    self.rot_emb = (
        RotaryEmbedding(dim=self.head_dim) if use_rotary_embeddings else None
    )

    # The fused torch implementation is only used when torch provides it
    self.enable_torch_version = hasattr(F, "multi_head_attention_forward")
133 |
def prepare_for_onnx_export_(self):
    """Mark the module as being traced for ONNX export.

    ``forward`` reads ``self.onnx_trace``: when it is set, the fused
    ``F.multi_head_attention_forward`` path is not taken.
    """
    self.onnx_trace = True
136 |
def reset_parameters(self):
    """(Re-)initialize the projection weights and the optional k/v biases.

    The key, value and query projections use Xavier-uniform initialization,
    with a gain of 1/sqrt(2) when they all share the same input size. The
    output projection uses plain Xavier-uniform weights and a zero bias;
    ``bias_k`` / ``bias_v`` use Xavier-normal when present.
    """
    # Empirically observed the convergence to be much better with
    # the scaled initialization
    gain = 1 / math.sqrt(2) if self.qkv_same_dim else 1.0
    for proj in (self.k_proj, self.v_proj, self.q_proj):
        nn.init.xavier_uniform_(proj.weight, gain=gain)

    nn.init.xavier_uniform_(self.out_proj.weight)
    if self.out_proj.bias is not None:
        nn.init.constant_(self.out_proj.bias, 0.0)

    for extra_bias in (self.bias_k, self.bias_v):
        if extra_bias is not None:
            nn.init.xavier_normal_(extra_bias)
156 |
157 | def forward(
158 | self,
159 | query,
160 | key: Optional[Tensor],
161 | value: Optional[Tensor],
162 | key_padding_mask: Optional[Tensor] = None,
163 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
164 | need_weights: bool = True,
165 | static_kv: bool = False,
166 | attn_mask: Optional[Tensor] = None,
167 | before_softmax: bool = False,
168 | need_head_weights: bool = False,
169 | ) -> Tuple[Tensor, Optional[Tensor]]:
170 | """Input shape: Time x Batch x Channel
171 |
172 | Args:
173 | key_padding_mask (ByteTensor, optional): mask to exclude
174 | keys that are pads, of shape `(batch, src_len)`, where
175 | padding elements are indicated by 1s.
176 | need_weights (bool, optional): return the attention weights,
177 | averaged over heads (default: False).
178 | attn_mask (ByteTensor, optional): typically used to
179 | implement causal attention, where the mask prevents the
180 | attention from looking forward in time (default: None).
181 | before_softmax (bool, optional): return the raw attention
182 | weights and values before the attention softmax.
183 | need_head_weights (bool, optional): return the attention
184 | weights for each head. Implies *need_weights*. Default:
185 | return the average attention weights over all heads.
186 | """
187 | if need_head_weights:
188 | need_weights = True
189 |
190 | tgt_len, bsz, embed_dim = query.size()
191 | assert embed_dim == self.embed_dim
192 | assert list(query.size()) == [tgt_len, bsz, embed_dim]
193 |
194 | if (
195 | not self.rot_emb
196 | and self.enable_torch_version
197 | and not self.onnx_trace
198 | and incremental_state is None
199 | and not static_kv
200 | # A workaround for quantization to work. Otherwise JIT compilation
201 | # treats bias in linear module as method.
202 | and not torch.jit.is_scripting()
203 | and not need_head_weights
204 | ):
205 | assert key is not None and value is not None
206 | return F.multi_head_attention_forward(
207 | query,
208 | key,
209 | value,
210 | self.embed_dim,
211 | self.num_heads,
212 | torch.empty([0]),
213 | torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
214 | self.bias_k,
215 | self.bias_v,
216 | self.add_zero_attn,
217 | self.dropout,
218 | self.out_proj.weight,
219 | self.out_proj.bias,
220 | self.training,
221 | key_padding_mask,
222 | need_weights,
223 | attn_mask,
224 | use_separate_proj_weight=True,
225 | q_proj_weight=self.q_proj.weight,
226 | k_proj_weight=self.k_proj.weight,
227 | v_proj_weight=self.v_proj.weight,
228 | )
229 | if incremental_state is not None:
230 | saved_state = self._get_input_buffer(incremental_state)
231 | if saved_state is not None and "prev_key" in saved_state:
232 | # previous time steps are cached - no need to recompute
233 | # key and value if they are static
234 | if static_kv:
235 | assert self.encoder_decoder_attention and not self.self_attention
236 | key = value = None
237 | else:
238 | saved_state = None
239 |
240 | if self.self_attention:
241 | q = self.q_proj(query)
242 | k = self.k_proj(query)
243 | v = self.v_proj(query)
244 | elif self.encoder_decoder_attention:
245 | # encoder-decoder attention
246 | q = self.q_proj(query)
247 | if key is None:
248 | assert value is None
249 | k = v = None
250 | else:
251 | k = self.k_proj(key)
252 | v = self.v_proj(key)
253 |
254 | else:
255 | assert key is not None and value is not None
256 | q = self.q_proj(query)
257 | k = self.k_proj(key)
258 | v = self.v_proj(value)
259 | q *= self.scaling
260 |
261 | if self.bias_k is not None:
262 | assert self.bias_v is not None
263 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
264 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
265 | if attn_mask is not None:
266 | attn_mask = torch.cat(
267 | [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
268 | )
269 | if key_padding_mask is not None:
270 | key_padding_mask = torch.cat(
271 | [
272 | key_padding_mask,
273 | key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
274 | ],
275 | dim=1,
276 | )
277 |
278 | q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
279 | if k is not None:
280 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
281 | if v is not None:
282 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
283 |
284 | if saved_state is not None:
285 | # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
286 | if "prev_key" in saved_state:
287 | _prev_key = saved_state["prev_key"]
288 | assert _prev_key is not None
289 | prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
290 | if static_kv:
291 | k = prev_key
292 | else:
293 | assert k is not None
294 | k = torch.cat([prev_key, k], dim=1)
295 | if "prev_value" in saved_state:
296 | _prev_value = saved_state["prev_value"]
297 | assert _prev_value is not None
298 | prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
299 | if static_kv:
300 | v = prev_value
301 | else:
302 | assert v is not None
303 | v = torch.cat([prev_value, v], dim=1)
304 | prev_key_padding_mask: Optional[Tensor] = None
305 | if "prev_key_padding_mask" in saved_state:
306 | prev_key_padding_mask = saved_state["prev_key_padding_mask"]
307 | assert k is not None and v is not None
308 | key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
309 | key_padding_mask=key_padding_mask,
310 | prev_key_padding_mask=prev_key_padding_mask,
311 | batch_size=bsz,
312 | src_len=k.size(1),
313 | static_kv=static_kv,
314 | )
315 |
316 | saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
317 | saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
318 | saved_state["prev_key_padding_mask"] = key_padding_mask
319 | # In this branch incremental_state is never None
320 | assert incremental_state is not None
321 | incremental_state = self._set_input_buffer(incremental_state, saved_state)
322 | assert k is not None
323 | src_len = k.size(1)
324 |
325 | # This is part of a workaround to get around fork/join parallelism
326 | # not supporting Optional types.
327 | if key_padding_mask is not None and key_padding_mask.dim() == 0:
328 | key_padding_mask = None
329 |
330 | if key_padding_mask is not None:
331 | assert key_padding_mask.size(0) == bsz
332 | assert key_padding_mask.size(1) == src_len
333 |
334 | if self.add_zero_attn:
335 | assert v is not None
336 | src_len += 1
337 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
338 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
339 | if attn_mask is not None:
340 | attn_mask = torch.cat(
341 | [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
342 | )
343 | if key_padding_mask is not None:
344 | key_padding_mask = torch.cat(
345 | [
346 | key_padding_mask,
347 | torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
348 | ],
349 | dim=1,
350 | )
351 |
352 | if self.rot_emb:
353 | q, k = self.rot_emb(q, k)
354 |
355 | attn_weights = torch.bmm(q, k.transpose(1, 2))
356 | attn_weights = MultiheadAttention.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
357 |
358 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
359 |
360 | if attn_mask is not None:
361 | attn_mask = attn_mask.unsqueeze(0)
362 | if self.onnx_trace:
363 | attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
364 | attn_weights += attn_mask
365 |
366 | if key_padding_mask is not None:
367 | # don't attend to padding symbols
368 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
369 | attn_weights = attn_weights.masked_fill(
370 | key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
371 | )
372 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
373 |
374 | if before_softmax:
375 | return attn_weights, v
376 |
377 | attn_weights_float = utils_softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
378 | attn_weights = attn_weights_float.type_as(attn_weights)
379 | attn_probs = F.dropout(
380 | attn_weights_float.type_as(attn_weights),
381 | p=self.dropout,
382 | training=self.training,
383 | )
384 | assert v is not None
385 | attn = torch.bmm(attn_probs, v)
386 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
387 | if self.onnx_trace and attn.size(1) == 1:
388 | # when ONNX tracing a single decoder step (sequence length == 1)
389 | # the transpose is a no-op copy before view, thus unnecessary
390 | attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
391 | else:
392 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
393 | attn = self.out_proj(attn)
394 | attn_weights: Optional[Tensor] = None
395 | if need_weights:
396 | attn_weights = attn_weights_float.view(
397 | bsz, self.num_heads, tgt_len, src_len
398 | ).type_as(attn).transpose(1, 0)
399 | if not need_head_weights:
400 | # average attention weights over heads
401 | attn_weights = attn_weights.mean(dim=0)
402 |
403 | return attn, attn_weights
404 |
@staticmethod
def _append_prev_key_padding_mask(
    key_padding_mask: Optional[Tensor],
    prev_key_padding_mask: Optional[Tensor],
    batch_size: int,
    src_len: int,
    static_kv: bool,
) -> Optional[Tensor]:
    """Merge the cached key padding mask with the mask of the current step.

    Args:
        key_padding_mask: mask for the current keys, or None.
        prev_key_padding_mask: mask restored from the saved state, or None.
        batch_size: number of rows of any zero filler that has to be built.
        src_len: total number of columns the merged mask should have.
        static_kv: when True and a previous mask exists, that mask is
            returned untouched.

    Returns:
        The previous mask itself (static case), a float tensor made of the
        previous and current masks concatenated along dim 1 (a missing side
        is replaced by zeros so the width reaches ``src_len``), or None when
        both masks are None.
    """
    # saved key padding masks have shape (bsz, seq_len)
    if prev_key_padding_mask is not None:
        if static_kv:
            return prev_key_padding_mask
        if key_padding_mask is not None:
            return torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None. Here only the previous mask exists: zero-fill on the right.
        trailing = torch.zeros(
            (batch_size, src_len - prev_key_padding_mask.size(1)),
            device=prev_key_padding_mask.device,
        )
        return torch.cat([prev_key_padding_mask.float(), trailing.float()], dim=1)

    if key_padding_mask is not None:
        # Only the current mask exists: zero-fill on the left.
        leading = torch.zeros(
            (batch_size, src_len - key_padding_mask.size(1)),
            device=key_padding_mask.device,
        )
        return torch.cat([leading.float(), key_padding_mask.float()], dim=1)

    # Neither mask is available.
    return None
440 |
@torch.jit.export
def reorder_incremental_state(
    self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order: Tensor
):
    """Reorder buffered internal state (for incremental generation).

    Every non-None tensor in the attention buffer is re-indexed along
    dim 0 with ``new_order`` and the buffer is written back through
    ``_set_input_buffer``, whose result is returned.
    """
    cached = self._get_input_buffer(incremental_state)
    if cached is None:
        return incremental_state

    for name in cached.keys():
        entry = cached[name]
        if entry is None:
            continue
        # For encoder-decoder attention, stop processing the buffer as soon as
        # an entry's leading dimension already equals len(new_order).
        same_leading_dim = self.encoder_decoder_attention and entry.size(0) == new_order.size(0)
        if same_leading_dim:
            break
        cached[name] = entry.index_select(0, new_order)

    return self._set_input_buffer(incremental_state, cached)
458 |
def _get_input_buffer(
    self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
    """Look up the ``"attn_state"`` entry via ``self.get_incremental_state``.

    Returns the looked-up buffer, or a fresh empty dict when the lookup
    yields None, so callers always receive a dict.
    """
    buffer = self.get_incremental_state(incremental_state, "attn_state")
    if buffer is None:
        # Explicitly typed so the empty dict has a known element type.
        fresh: Dict[str, Optional[Tensor]] = {}
        return fresh
    return buffer
468 |
def _set_input_buffer(
    self,
    incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
    buffer: Dict[str, Optional[Tensor]],
):
    """Hand ``buffer`` to ``self.set_incremental_state`` under ``"attn_state"``.

    Returns whatever ``set_incremental_state`` returns.
    """
    updated_state = self.set_incremental_state(incremental_state, "attn_state", buffer)
    return updated_state
475 |
@staticmethod
def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
    """Return ``attn_weights`` unchanged; no sparse masking is applied here.

    Declared as a static method because the signature has no ``self``:
    without the decorator the function only works when called through the
    class, and a call on an instance would bind the instance to
    ``attn_weights`` and fail with a ``TypeError``.

    Args:
        attn_weights: attention scores to (not) mask.
        tgt_len: unused by this implementation.
        src_len: unused by this implementation.
        bsz: unused by this implementation.

    Returns:
        The very same ``attn_weights`` object that was passed in.
    """
    return attn_weights
478 |
def upgrade_state_dict_named(self, state_dict, name):
    """Split legacy fused ``in_proj_*`` entries of ``state_dict`` in place.

    For every key ending with ``<prefix>in_proj_weight`` the tensor is cut
    into three equal slices along dim 0 and stored as
    ``<prefix>q_proj.weight``, ``<prefix>k_proj.weight`` and
    ``<prefix>v_proj.weight``. If ``<prefix>in_proj_bias`` exists it is split
    the same way into the matching ``*.bias`` entries. The fused keys are
    removed afterwards.

    Args:
        state_dict: mutable mapping of parameter names to tensors; modified
            in place.
        name: name of this module inside ``state_dict``; an empty string
            means no prefix.

    Returns:
        None.
    """
    prefix = name + "." if name != "" else ""
    fused_weight_suffix = prefix + "in_proj_weight"
    fused_bias_key = prefix + "in_proj_bias"

    items_to_add = {}
    # A set, so a key matched more than once is still deleted only once
    # (several weights can match the suffix while sharing one bias key).
    keys_to_remove = set()

    for k in state_dict.keys():
        if not k.endswith(fused_weight_suffix):
            continue

        # in_proj_weight used to be q + k + v with same dimensions
        fused_weight = state_dict[k]
        dim = fused_weight.shape[0] // 3
        items_to_add[prefix + "q_proj.weight"] = fused_weight[:dim]
        items_to_add[prefix + "k_proj.weight"] = fused_weight[dim : 2 * dim]
        items_to_add[prefix + "v_proj.weight"] = fused_weight[2 * dim :]
        keys_to_remove.add(k)

        if fused_bias_key in state_dict:
            # Size the slices from the bias tensor itself.
            fused_bias = state_dict[fused_bias_key]
            dim = fused_bias.shape[0] // 3
            items_to_add[prefix + "q_proj.bias"] = fused_bias[:dim]
            items_to_add[prefix + "k_proj.bias"] = fused_bias[dim : 2 * dim]
            items_to_add[prefix + "v_proj.bias"] = fused_bias[2 * dim :]
            keys_to_remove.add(fused_bias_key)

    for k in keys_to_remove:
        del state_dict[k]

    state_dict.update(items_to_add)
507 |
--------------------------------------------------------------------------------