├── scripts
│   ├── __init__.py
│   ├── util.py
│   ├── predict.py
│   ├── mmseqs2.py
│   └── colabfold_alphafold.py
├── .gitignore
├── LICENSE
└── README.md

--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | alphafold/
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Diego del Alamo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/scripts/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import Bio.PDB
4 | from typing import Dict, List, NoReturn
5 |
6 | from alphafold.data import pipeline
7 | from alphafold.data import templates
8 | from alphafold.data.tools import hhsearch
9 |
10 | def mk_mock_template(seq: str) -> dict:
11 |
12 |     r"""Generates mock templates that will not influence prediction
13 |     Taken from ColabFold version 62d7558c91a9809712b022faf9d91d8b183c328c
14 |
15 |     Parameters
16 |     ----------
17 |     seq: Query sequence
18 |
19 |     Returns
20 |     ----------
21 |     Dictionary with blank/empty/meaningless features
22 |
23 |     """
24 |
25 |     # Define constants
26 |     lentype = templates.residue_constants.atom_type_num
27 |     lseq = len(seq)
28 |
29 |     # Since alphafold's model requires a template input
30 |     # We create a blank example w/ zero input, confidence -1
31 |     aatypes = np.array(
32 |         templates.residue_constants.sequence_to_onehot(
33 |             "-" * lseq, templates.residue_constants.HHBLITS_AA_TO_ID
34 |         )
35 |     )
36 |
37 |     return {
38 |         "template_all_atom_positions": np.zeros((lseq, lentype, 3))[None],
39 |         "template_all_atom_masks": np.zeros((lseq, lentype))[None],
40 |         "template_sequence": ["none".encode()],
41 |         "template_aatype": aatypes[None],
42 |         "template_confidence_scores": np.full(lseq, -1)[None],
43 |         "template_domain_names": ["none".encode()],
44 |         "template_release_date": ["none".encode()],
45 |     }
46 |
47 |
48 | ###############################
49 |
50 |
51 | def mk_template(seq: str, a3m_lines: str, path: str) -> dict:
52 |
53 |     r"""Parses templates into features
54 |
55 |     Parameters
56 |     ----------
57 |     seq : Query sequence
58 |     a3m_lines : Lines from MMSeqs2 alignment
59 |     path : Path to templates fetched using MMSeqs2
60 |
61 |     Returns
62 |     ----------
63 |     Dictionary with features
64 |
65 |     """
66 |
67 |     result = hhsearch.HHSearch(
68 |         binary_path="hhsearch", databases=[f"{ path }/pdb70"]
69 |     ).query(a3m_lines)
70 |
71 |     return templates.HhsearchHitFeaturizer(
72 |         mmcif_dir=path,
73 |         max_template_date="2100-01-01",
74 |         max_hits=20,
75 |         kalign_binary_path="kalign",
76 |         release_dates_path=None,
77 |         obsolete_pdbs_path=None,
78 |     ).get_templates(query_sequence=seq, hits=pipeline.parsers.parse_hhr(result))
79 |
80 |
81 | ###############################
82 |
83 |
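# Usage sketch (mirroring how predict.py in this repo calls these helpers):
# build template features with mk_template(seq, a3m_lines, template_path).features
# for a templated run, or mk_mock_template(seq) for a template-free run, then
# pass them to setup_features(seq, a3m_lines, tfeatures_in) below.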
84 | def setup_features(seq: str, a3m_lines: list, tfeatures_in: dict) -> dict:
85 |
86 |     r"""Set up features for alphafold
87 |
88 |     Parameters
89 |     ----------
90 |     seq : Sequence (string)
91 |     a3m_lines : Sequence alignment lines
92 |     tfeatures_in : Template features
93 |
94 |     Returns
95 |     ----------
96 |     Alphafold features object
97 |
98 |     """
99 |
100 |     msa = pipeline.parsers.parse_a3m(a3m_lines)
101 |     return {
102 |         **pipeline.make_sequence_features(
103 |             sequence=seq, description="none", num_res=len(seq)
104 |         ),
105 |         **pipeline.make_msa_features(msas=[msa]),
106 |         **tfeatures_in,
107 |     }
108 |
109 |
110 | def mutate_msa(
111 |     a3m_lines: str,
112 |     pos_res: Dict[int, str],
113 | ) -> str:
114 |     r"""Mutates every position in an MSA to a residue of interest
115 |
116 |     Example usage: mutate_msa( a3m_lines, { 15: "A", 155: "A" } )
117 |     This will mutate residues 15 and 155 to alanine throughout the MSA
118 |
119 |     Parameters
120 |     ----------
121 |     a3m_lines : Sequence alignment
122 |     pos_res : Dict mapping each position to the residue to mutate to
123 |         (positions are zero-indexed)
124 |
125 |     Returns
126 |     ----------
127 |     Sequence alignment (as string)
128 |
129 |     """
130 |
131 |     for target_res in pos_res.values():
132 |         assert len(target_res) == 1
133 |
134 |     output = []
135 |
136 |     # Iterate over alignment lines
137 |     for line in a3m_lines.split("\n"):
138 |         if line.startswith(">"):
139 |             output.append(line)
140 |         elif len(line) > 1:
141 |             line = list(line)
142 |             for pos, res in pos_res.items():
143 |                 if line[pos] in "ACDEFGHIKLMNPQRSTVWY":
144 |                     line[pos] = res
145 |             output.append("".join(line))
146 |         else:
147 |             output.append(line)
148 |     return "\n".join(output)
149 |
150 |
151 | def mutate(x, y):
152 |     return mutate_msa(x, y)  # Alias for brevity
153 |
154 |
155 | def plddt_to_bfactor(filename: str, maxval: float = 100.0) -> NoReturn:
156 |     r"""Converts pLDDT values to B-factors
157 |     This equation is derived from the following publication:
158 |     "Improved protein structure refinement guided by deep learning based
159 |     accuracy estimation" by Hiranuma et al. 2021
160 |     https://doi.org/10.1038/s41467-021-21511-x
161 |
162 |     Parameters
163 |     ----------
164 |     filename : Name of PDB file
165 |     maxval : Set to 100 if using AF2 (or 1 if RoseTTAFold)
166 |
167 |     Returns
168 |     ----------
169 |     None
170 |
171 |     """
172 |     pdb = Bio.PDB.PDBParser().get_structure("TEMP", filename)
173 |     for atom in pdb.get_atoms():
174 |         rmsf = 1.5 * np.exp(4 * (0.7 - (atom.bfactor / maxval)))
175 |         atom.bfactor = (8.0 / 3.0) * (np.pi**2) * (rmsf**2)
176 |
177 |     pdbio = Bio.PDB.PDBIO()
178 |     pdbio.set_structure(pdb)
179 |     pdbio.save(filename)
180 |
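# The loop above implements RMSF = 1.5 * exp(4 * (0.7 - pLDDT / maxval)) and
# B = (8 * pi^2 / 3) * RMSF^2, i.e. the pLDDT-to-B-factor mapping from
# Hiranuma et al. 2021 cited in the docstring.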
181 | def pdb2str( pdbfile: str ) -> str:
182 |
183 |     r""" Converts PDB file to string
184 |
185 |     Credit to Sergey Ovchinnikov for writing this
186 |
187 |     Args:
188 |         pdbfile: Path of the PDB file to convert
189 |
190 |     Output:
191 |         String containing the ATOM records
192 |
193 |     """
194 |     lines = []
195 |     for line in open( pdbfile, "r" ):
196 |         if line[ :4 ] == "ATOM":
197 |             lines.append( line )
198 |     return "".join( lines )
199 |
200 | def remove_msa_for_template_aligned_regions(feature_dict):
201 |     mask = np.zeros(feature_dict['seq_length'][0], dtype=bool)
202 |     for templ in feature_dict['template_sequence']:
203 |         for i,aa in enumerate(templ.decode("utf-8")):
204 |             if aa != '-':
205 |                 mask[i] = True
206 |     feature_dict['deletion_matrix_int'][:,mask] = 0
207 |     feature_dict['msa'][:,mask] = 21
208 |     return feature_dict
209 |
210 | def remove_msa_for_custom_template_aligned_regions(feature_dict):
211 |     mask = np.zeros(feature_dict['seq_length'][0], dtype=bool)
212 |     tempseq = list(feature_dict['template_sequence'])
213 |     for i,aa in enumerate(tempseq):
214 |         if aa != '-':
215 |             mask[i] = True
216 |     feature_dict['deletion_matrix_int'][:,mask] = 0
217 |     feature_dict['msa'][:,mask] = 21
218 |     return feature_dict
219 |
220 | # Read a PDB file and return its sequence
221 | def pdb2seq(pdbfile):
222 |     from Bio import SeqIO
223 |     for record in SeqIO.parse(pdbfile, "pdb-atom"):
224 |         return record.seq
225 |
226 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Python interface to model user-defined functional and structural features of kinases and GPCRs with AlphaFold2
2 |
3 | This repository is an expansion of our previous [colabfold](https://github.com/sokrypton/ColabFold)-based work ["Sampling alternative conformational states of transporters and receptors with AlphaFold2"](https://elifesciences.org/articles/75751) by Diego del Alamo, Davide Sala, Hassane S. Mchaourab, and Jens Meiler. All of the original functionality has been kept. For a general overview, please read the README.md at https://github.com/delalamo/af2_conformations. Here, our previous workflow is extended with the aim of predicting a user-defined conformational state of a GPCR or kinase with minimal effort. We also introduce features such as using a custom PDB template or printing the predicted pTM score for each model.
4 |
5 | ### How to use the code in this repository
6 |
7 | Before importing the code contained in the `scripts/` folder, the user needs to install the AlphaFold source code and download the parameters to a directory named `params/`. Additional Python modules that must be installed include [NumPy](https://numpy.org/), [Requests](https://docs.python-requests.org/en/latest/), and [absl-py logging](https://abseil.io/docs/python/guides/logging).
8 |
9 | The scripts can be imported and used out-of-the-box to fetch multiple sequence alignments and/or templates of interest. Note that the `max_msa_clusters` and `max_extra_msa` options can be provided to reduce the size of the multiple sequence alignment. If these are not provided, the network's default values will be used. Additional options allow the number of recycles, as well as the number of loops through the recurrent Structure Module, to be specified. In addition, `ptm` can be enabled to append the predicted pTM score to each model's file name.
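As a minimal sketch of this workflow (the job name and sequence below are placeholders), fetching an MSA and predicting a single template-free model looks like:

```python
from AF2_GPCR_Kinase.scripts import mmseqs2, predict

runner = mmseqs2.MMSeqs2Runner( "myjob", "MYPROTEINSEQUENCE" )  # placeholder job name and sequence
a3m_lines, template_path = runner.run_job()                     # query the MMSeqs2 server
predict.predict_structure_no_templates( "MYPROTEINSEQUENCE", "myjob_1.pdb", a3m_lines, ptm=True )
```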
10 |
11 | ## Predicting a user-defined GPCR functional state
12 |
13 | To predict a specific activation state of a GPCR target, the pdbs list must contain one of the following strings in the first position ("Inactive", "Active", "Intermediate", "G protein", "Arrestin"). The script will retrieve templates in the annotated functional state from GPCRdb.org. Template PDBs can be excluded simply by adding their PDB IDs (without the chain ID) to the list. Example: ["G protein", "7FII"] to predict the active state of your target using G protein-bound templates while excluding 7FII. The PDB IDs used to bias the prediction can be retrieved from the log file (e.g. by typing `grep PDBS example.log`).
14 | Templates can also be randomized for each model.
15 | Below is an example that logs to example.log and predicts 50 models of LSHR using the best four G protein-bound templates, while excluding all released LSHR PDBs.
16 |
17 | ```python
18 | from AF2_GPCR_Kinase.scripts import mmseqs2
19 | from AF2_GPCR_Kinase.scripts import util  # used further below for MSA mutagenesis
20 | import logging
21 | logging.basicConfig(filename='example.log', level=logging.DEBUG) # write the log at debug level
22 |
23 | # Jobname for reference
24 | jobname = 'lshr_gprot_4t'
25 |
26 | # Amino acid sequence. Whitespace and inappropriate characters are automatically removed
27 | sequence = ("YDFLRVLIWLINILAIMGNMTVLFVLLTSRYKLTVPRFLMCNLSFADFCMGLYLLLIASVDSQTKGQYYNHAIDWQTGSGCSTAGFFTVFASELSVYTLTVITLERWHTITYAIHLDQKLRLRHAILIMLGGWLFSSLIAMLPLVGVSNYMKVSICFPMDVETTLSQVYILTILILNVVAFFIICACYIKIYFAVRNPELMATNKDTKIAKKMAILIFTDFTCMAPISFFAISAAFKVPLITVTNSKVLLVLFYPINSCANPFLYAIFTKTFQRDFFLLLSKFGCC")
28 |
29 | # State annotation followed by PDB IDs to be excluded.
30 | pdbs = ["G protein", "7FII", "7FIG", "7FIH", "7FIJ"]
31 |
32 | # Parameters
33 | max_msa_clusters = 32 # Number of sequence clusters
34 | max_extra_msa = 64 # Number of extra sequences not clustered
35 | max_recycles = 1 # Number of neural network iterations
36 | n_struct_module_repeats = 8 # Number of structure module iterations
37 | n_models = 50 # Number of models to be predicted
38 | model_id = -1 # Which AF neural network. -1 = Randomize
39 | model_params = -1 # Which AF neural network parameters. -1 = Randomize
40 | ptm = True # Print pTM value before .pdb of each model
41 | _rank = 1 # Number assigned to the first predicted model
42 | remove_msa_for_template_aligned = False # Remove the genetic information for regions already covered by templates. Copied from Heo L. et al., DOI: 10.1002/prot.26382.
43 | n_templates = 4 # Number of templates to be used
44 |
45 | # Initialize the Runner object that queries the MMSeqs2 server
46 | mmseqs2_runner = mmseqs2.MMSeqs2Runner( jobname, sequence, n_templates = n_templates )
47 |
48 | # Fetch the data and save to the appropriate directory
49 | a3m_lines, template_path = mmseqs2_runner.run_job(templates = pdbs )
50 |
51 | from AF2_GPCR_Kinase.scripts import predict
52 |
53 | for i in range( n_models ):
54 |     model_name = jobname + "_" + str(_rank) + ".pdb"
55 |
56 |     # Optionally, templates can be shuffled within the list of PDBs passing filters.
57 |     # Uncomment the line below to enable template randomization.
58 |     # template_path = mmseqs2_runner.shuffle_templates()
59 |
60 |     # Run a prediction with templates
61 |     predict.predict_structure_from_templates( sequence, model_name, a3m_lines, template_path=template_path, model_id=model_id, max_msa_clusters=max_msa_clusters, max_extra_msa=max_extra_msa, max_recycles=max_recycles, n_struct_module_repeats=n_struct_module_repeats, ptm=ptm, remove_msa_for_template_aligned=remove_msa_for_template_aligned )
62 |
63 |     # Two alternatives: 1. predict without templates, or 2. use a local PDB as the template (uncomment to enable)
64 |
65 |     # 1. Run a prediction without templates
66 |     # predict.predict_structure_no_templates( sequence, model_name, a3m_lines, model_id=model_id, max_msa_clusters=max_msa_clusters, max_extra_msa=max_extra_msa, max_recycles=max_recycles, n_struct_module_repeats=n_struct_module_repeats, ptm=ptm)
67 |
68 |     # 2. Run a prediction with a local pdb template
69 |     # predict.predict_structure_from_custom_template( sequence, model_name, a3m_lines, template_pdb="pdb_file", model_id=model_id, max_msa_clusters=max_msa_clusters, max_extra_msa=max_extra_msa, max_recycles=max_recycles, n_struct_module_repeats=n_struct_module_repeats, ptm=ptm, remove_msa_for_template_aligned=remove_msa_for_template_aligned)
70 |
71 |     _rank += 1
72 | ```
73 |
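Each output PDB stores the per-residue pLDDT in its B-factor column. If pseudo-B-factors are preferred instead, `util.plddt_to_bfactor` (imported above) converts a file in place; the file name below is a placeholder:

```python
# maxval=100.0 matches AF2's 0-100 pLDDT scale (use 1.0 for RoseTTAFold models)
util.plddt_to_bfactor( "lshr_gprot_4t_1.pdb", maxval=100.0 )
```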
74 |
75 | ## Predicting user-defined structural features of kinases
76 |
77 | Similar to predicting user-defined GPCR functional states, users can force AF2 to retrieve kinase templates matching three structural feature values from KLIFS. The allowed structural properties are: 1. DFG, 2. aC_helix, 3. Salt-bridge (KIII.17 and EαC.24). The allowed values for each property are: 1. DFG: out, in, out-like, all; 2. aC_helix: out, in, all; 3. Salt-bridge: yes, no, all. Following the example above, the format must be a three-member list in the first position of the pdbs list. Besides the sequence and jobname, the only changes needed in the script above to predict models with DFG-out templates are reported below.
78 |
79 | ```python
80 | # Structural properties of the kinase templates to be used
81 | kinase_temps = ["out", "all", "all" ] # Format is [DFG, aC_helix, salt_bridge]
82 |
83 | # Kinase annotation list followed by PDB IDs to be excluded
84 | pdbs = [kinase_temps, "6N3O", "6N3L", "6N3N", "7QQ6", "7QWK"]
85 | ```
86 | ### Introducing mutations into the MSA
87 |
88 | There is also functionality to introduce mutations (e.g. alanines) across the entire MSA to remove the evolutionary evidence for specific interactions (see [here](https://www.biorxiv.org/content/10.1101/2021.11.29.470469v1) and [here](https://twitter.com/sokrypton/status/1464748132852547591) on why you would want to do this). This can be achieved as follows:
89 |
90 | ```python
91 | # Define the mutations and introduce them into the MSA
92 | residues = [ 41,42,45,46,56,59,60,63,281,282,285,286,403,407 ]
93 | muts = { r: "A" for r in residues }
94 | mutated_msa = util.mutate_msa( a3m_lines, muts )
95 | ```
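The mutated alignment is an ordinary a3m string, so it can be passed back into any of the `predict` functions in place of the original MSA. A minimal sketch, reusing names from the examples above (the output file name is a placeholder):

```python
predict.predict_structure_no_templates( sequence, "mutant_1.pdb", mutated_msa, ptm=True )
```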
96 | ### Known issues
97 |
98 | Here is a shortlist of known problems that we are currently working on:
99 | * The MMSeqs2 server queries the PDB70, rather than the full PDB. This can cause some structures to be missed if their sequences are nearly identical to those of other PDB files.
100 | * Multimer prediction is not currently supported.
101 | * Custom MSAs are not currently supported.
102 | * Additional annotations of both GPCRs and kinases are not currently supported.
103 |
104 | If you find any other issues, please let us know in the "issues" tab above.
105 |
106 | ### Citations
107 |
108 | If the code in this repository has helped your scientific project, please consider citing our papers:
109 |
110 | ```bibtex
111 | @article {Sala2022.12.11.519936,
112 |   author = {Sala, Davide and Meiler, Jens},
113 |   title = {Biasing AlphaFold2 to predict GPCRs and Kinases with user-defined functional or structural properties},
114 |   elocation-id = {2022.12.11.519936},
115 |   year = {2022},
116 |   doi = {10.1101/2022.12.11.519936},
117 |   publisher = {Cold Spring Harbor Laboratory},
118 |   URL = {https://www.biorxiv.org/content/early/2022/12/11/2022.12.11.519936},
119 |   eprint = {https://www.biorxiv.org/content/early/2022/12/11/2022.12.11.519936.full.pdf},
120 |   journal = {bioRxiv}
121 | }
122 | @article {10.7554/eLife.75751,
123 |   article_type = {journal},
124 |   title = {Sampling alternative conformational states of transporters and receptors with AlphaFold2},
125 |   author = {del Alamo, Diego and Sala, Davide and Mchaourab, Hassane S and Meiler, Jens},
126 |   editor = {Robertson, Janice L and Swartz, Kenton J},
127 |   volume = 11,
128 |   year = 2022,
129 |   month = {mar},
130 |   pub_date = {2022-03-03},
131 |   pages = {e75751},
132 |   citation = {eLife 2022;11:e75751},
133 |   doi = {10.7554/eLife.75751},
134 |   url = {https://doi.org/10.7554/eLife.75751},
135 |   journal = {eLife},
136 |   issn = {2050-084X},
137 |   publisher = {eLife Sciences Publications, Ltd},
138 | }
139 | ```
140 |
141 |
--------------------------------------------------------------------------------
/scripts/predict.py:
--------------------------------------------------------------------------------
1 | from . import util
2 | import os
3 | import numpy as np
4 | import random
5 | import sys
6 |
7 | import alphafold
8 | from alphafold.common import protein
9 | from alphafold.model import data
10 | from alphafold.model import config
11 | from alphafold.model import model
12 |
13 | from typing import Any, List, Mapping, NoReturn
14 |
15 | from absl import logging
16 | import jax.numpy as jnp
17 | import jax
18 |
19 | def set_config(
20 |     use_templates: bool,
21 |     max_msa_clusters: int,
22 |     max_extra_msa: int,
23 |     max_recycles: int,
24 |     model_id: int,
25 |     n_struct_module_repeats: int,
26 |     n_features_in: int,
27 |     monomer: bool = True,
28 |     model_params: int = 0,
29 | ) -> model.RunModel:
30 |
31 |     r"""Generates the Runner object for AlphaFold
32 |
33 |     Parameters
34 |     ----------
35 |     use_templates : Whether templates are used
36 |     max_msa_clusters : How many sequences to use in MSA
37 |     max_extra_msa : How many extra sequences to include for summary stats
38 |     max_recycles : Number of recycling iterations
39 |     model_id : Which AF2 model to use
40 |     n_struct_module_repeats : Number of passes through structure module
41 |     n_features_in : Number of sequences in the input MSA
42 |     monomer : Predicting as a monomer (set to False if using AlphaFold-multimer)
43 |     model_params : Which AF2 model config to use
44 |
45 |     Returns
46 |     ----------
47 |     AlphaFold RunModel object
48 |
49 |     """
50 |
51 |     if model_id not in range(1, 6):
52 |         logging.warning("model_id must be between 1 and 5!")
53 |         if use_templates:
54 |             model_id = random.randint(1, 2)
55 |         else:
56 |             model_id = random.randint(1, 5)
57 |
58 |     # Match model_params to model_id
59 |     # Sometimes we don't want to do this, for example,
60 |     # to reproduce output from ColabFold (which only uses models 1 and 3)
61 |
62 |     name = f"model_{ model_params }_ptm"
63 |     if not monomer:
64 |         name = f"model_{ model_params }_multimer"
65 |
66 |     cfg = config.model_config(name)
67 |
68 |     #### Provide config settings
69 |
70 |     #### MSAs
71 |
72 |     cfg.data.eval.num_ensemble = 1
73 |     if max_msa_clusters > 0:
74 |         cfg.data.eval.max_msa_clusters = min(n_features_in, max_msa_clusters)
75 |     if max_extra_msa > 0:
76 |         cfg.data.common.max_extra_msa = max(
77 |             1, min(n_features_in - max_msa_clusters, max_extra_msa)
78 |         )
79 |
80 |     #### Recycle and number of iterations
81 |
82 |     if monomer:
83 |         cfg.data.common.num_recycle = max_recycles
84 |     cfg.model.num_recycle = max_recycles
85 |     cfg.model.heads.structure_module.num_layer = n_struct_module_repeats
86 |
87 |     #### Templates
88 |
89 |     t = use_templates  # for brevity
90 |
91 |     cfg.data.common.use_templates = use_templates
92 |     cfg.model.embeddings_and_evoformer.template.embed_torsion_angles = t
93 |     cfg.model.embeddings_and_evoformer.template.enabled = t
94 |     cfg.data.common.reduce_msa_clusters_by_max_templates = t
95 |     cfg.data.eval.subsample_templates = t
96 |
97 |     p = data.get_model_haiku_params(model_name=name, data_dir=".")
98 |
99 |     logging.debug("Prediction parameters:")
100 |     logging.debug("\tModel ID: {}".format(model_id))
101 |     logging.debug("\tUsing templates: {}".format(t))
102 |     logging.debug(
103 |         "\tMaximum MSA clusters: {}".format(cfg.data.eval.max_msa_clusters)
104 |     )
105 |     logging.debug(
106 |         "\tMaximum extra MSA clusters: {}".format(
107 |             cfg.data.common.max_extra_msa
108 |         )
109 |     )
110 |     logging.debug(
111 |         "\tNumber recycling iterations: {}".format(cfg.model.num_recycle)
112 |     )
113 |     logging.debug(
114 |         "\tNumber of structure module repeats: {}".format(
115 |             cfg.model.heads.structure_module.num_layer
116 |         )
117 |     )
118 |
119 |     return model.RunModel(cfg, p)
120 |
121 |
122 | def run_one_job(
123 |     runner: model.RunModel, features_in: dict, random_seed: int, outname: str, ptm: bool
124 | ) -> Mapping[str, Any]:
125 |     r"""Runs one AF2 job with input parameters
126 |
127 |     Parameters
128 |     ----------
129 |     runner : AlphaFold2 job runner
130 |     features_in : Input features, including MSA and templates
131 |     random_seed : Random seed
132 |     outname : Name of PDB file to write
133 |     ptm : Whether to append the pTM score to the file name
134 |     Returns
135 |     ----------
136 |     Dictionary of prediction results
137 |
138 |     """
139 |
140 |     # Do one last bit of processing
141 |     features = runner.process_features(features_in, random_seed=random_seed)
142 |
143 |     # Generate the model
144 |     result = runner.predict(features, random_seed)
145 |     pred = protein.from_prediction(features, result)
146 |
147 |     # Write to file
148 |     to_np = lambda a: np.asarray(a)
149 |     if ptm:
150 |         prefix = outname.rsplit(".",1)[0]
151 |         suffix = outname.rsplit(".",1)[-1]
152 |         outname = prefix + f"_{to_np(result['ptm']):.2f}." + suffix
153 |
154 |     to_pdb(outname, pred, result["plddt"], features_in["residue_index"])
155 |
156 |     return result
157 |
158 |
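# The three predict_* entry points below differ only in how template features
# are assembled (PDB70 hits via HHsearch, a blank mock template, or a single
# local PDB); each then funnels into set_config() and run_one_job() above.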
159 | def predict_structure_from_templates(
160 |     seq: str,
161 |     outname: str,
162 |     a3m_lines: str,
163 |     template_path: str,
164 |     model_id: int = -1,
165 |     model_params: int = -1,
166 |     random_seed: int = -1,
167 |     max_msa_clusters: int = 8,
168 |     max_extra_msa: int = 16,
169 |     max_recycles: int = 3,
170 |     n_struct_module_repeats: int = 8,
171 |     ptm: bool = False,
172 |     remove_msa_for_template_aligned: bool = False,
173 | ) -> Mapping[str, Any]:
174 |
175 |     r"""Predicts the structure.
176 |
177 |     Parameters
178 |     ----------
179 |     seq : Sequence
180 |     outname : Name of output PDB
181 |     a3m_lines : String of entire alignment
182 |     template_path : Where to locate templates
183 |     model_id : Which AF2 model to run (must be 1 or 2 for templates)
184 |     model_params : Which parameters to provide to AF2 model
185 |     random_seed : Random seed
186 |     max_msa_clusters : Number of sequences to use
187 |     max_extra_msa : Number of extra seqs for summary stats
188 |     max_recycles : Number of iterations through AF2
189 |     n_struct_module_repeats : Number of passes through structural refinement
190 |     ptm : Whether to append the pTM score to the file name
191 |     remove_msa_for_template_aligned : Whether to mask the MSA where templates align
192 |
193 |     Returns
194 |     ----------
195 |     Dictionary of prediction results
196 |
197 |     """
198 |
199 |     if random_seed == -1:
200 |         random_seed = random.randrange(sys.maxsize)
201 |
202 |     if model_id not in (1, 2):
203 |         model_id = random.randint(1, 2)
204 |
205 |     if model_params not in (1, 2):
206 |         model_params = random.randint(1, 2)
207 |
208 |     # Assemble the dictionary of input features
209 |     features_in = util.setup_features(
210 |         seq, a3m_lines, util.mk_template(seq, a3m_lines, template_path).features
211 |     )
212 |
213 |     if remove_msa_for_template_aligned:
214 |         features_in = util.remove_msa_for_template_aligned_regions(features_in)
215 |
216 |     # Run the models
217 |     model_runner = set_config(
218 |         True,
219 |         max_msa_clusters,
220 |         max_extra_msa,
221 |         max_recycles,
222 |         model_id,
223 |         n_struct_module_repeats,
224 |         len(features_in["msa"]),
225 |         model_params=model_params,
226 |     )
227 |
228 |     result = run_one_job(model_runner, features_in, random_seed, outname, ptm)
229 |
230 |     del model_runner
231 |
232 |     return result
233 |
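# Template-free variant: AlphaFold's input pipeline still expects template
# features, so util.mk_mock_template() supplies blank ones below.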
234 | def predict_structure_no_templates(
235 |     seq: str,
236 |     outname: str,
237 |     a3m_lines: str,
238 |     model_id: int = -1,
239 |     model_params: int = -1,
240 |     random_seed: int = -1,
241 |     max_msa_clusters: int = -1,
242 |     max_extra_msa: int = -1,
243 |     max_recycles: int = 3,
244 |     n_struct_module_repeats: int = 8,
245 |     ptm: bool = False,
246 | ) -> Mapping[str, Any]:
247 |
248 |     r"""Predicts the structure.
249 |
250 |     Parameters
251 |     ----------
252 |     seq : Sequence
253 |     outname : Name of output PDB
254 |     a3m_lines : String of entire alignment
255 |     model_id : Which AF2 model to run (1-5; randomized if invalid)
256 |     random_seed : Random seed
257 |     max_msa_clusters : Number of sequences to use
258 |     max_extra_msa : Number of extra seqs for summary stats
259 |     max_recycles : Number of iterations through AF2
260 |     n_struct_module_repeats : Number of passes through structural refinement
261 |     ptm : Whether to append the pTM score to the file name
262 |
263 |     Returns
264 |     ----------
265 |     Dictionary of prediction results
266 |
267 |     """
268 |
269 |     # Set AF2 model details
270 |     if model_id not in range(1, 6):
271 |         model_id = random.randint(1, 5)
272 |
273 |     if model_params not in range(1, 6):
274 |         model_params = model_id
275 |
276 |     if random_seed == -1:
277 |         random_seed = random.randrange(sys.maxsize)
278 |
279 |     features_in = util.setup_features(seq, a3m_lines, util.mk_mock_template(seq))
280 |
281 |     model_runner = set_config(
282 |         False,
283 |         max_msa_clusters,
284 |         max_extra_msa,
285 |         max_recycles,
286 |         model_id,
287 |         n_struct_module_repeats,
288 |         len(features_in["msa"]),
289 |         model_params=model_params,
290 |     )
291 |
292 |     result = run_one_job(model_runner, features_in, random_seed, outname, ptm)
293 |
294 |     del model_runner
295 |
296 |     return result
297 |
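# Custom-template variant: template features are built directly from a local
# PDB file (one-hot residue types over 22 classes, plus atom masks and
# coordinates) rather than fetched from the PDB70 search.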
298 | def predict_structure_from_custom_template(
299 |     seq: str,
300 |     outname: str,
301 |     a3m_lines: str,
302 |     template_pdb: str,
303 |     model_id: int = -1,
304 |     model_params: int = -1,
305 |     random_seed: int = -1,
306 |     max_msa_clusters: int = -1,
307 |     max_extra_msa: int = -1,
308 |     max_recycles: int = 3,
309 |     n_struct_module_repeats: int = 8,
310 |     ptm: bool = False,
311 |     remove_msa_for_template_aligned: bool = False
312 | ) -> Mapping[str, Any]:
313 |
314 |     r"""Predicts the structure.
315 |     Parameters
316 |     ----------
317 |     seq : Sequence
318 |     outname : Name of output PDB
319 |     a3m_lines : String of entire alignment
320 |     template_pdb : Name of the template PDB file (with path if not in the working directory)
321 |     model_id : Which AF2 model to run (must be 1 or 2 for templates)
322 |     model_params : Which parameters to provide to AF2 model
323 |     random_seed : Random seed
324 |     max_msa_clusters : Number of sequences to use
325 |     max_extra_msa : Number of extra seqs for summary stats
326 |     max_recycles : Number of iterations through AF2
327 |     n_struct_module_repeats : Number of passes through structural refinement
328 |     ptm : Whether to append the pTM score to the file name
329 |
330 |     Returns
331 |     ----------
332 |     Dictionary of prediction results
333 |     """
334 |
335 |     if random_seed == -1:
336 |         random_seed = random.randrange(sys.maxsize)
337 |
338 |     if model_id not in (1, 2):
339 |         model_id = random.randint(1, 2)
340 |
341 |     if model_params not in (1, 2):
342 |         model_params = random.randint(1, 2)
343 |
344 |     print( "Prediction parameters:" )
345 |     print( f"\tTemplate: { template_pdb }" )
346 |     print( f"\tMaximum number of MSA clusters: { max_msa_clusters }" )
347 |     print( f"\tMaximum number of extra MSA clusters: { max_extra_msa }" )
348 |     print( f"\tMaximum number of recycling iterations: { max_recycles }" )
349 |
350 |     pdb = protein.from_pdb_string( util.pdb2str( template_pdb ) )
351 |     tempseq = util.pdb2seq(template_pdb)
352 |     tfeatures_in = {
353 |         "template_aatype" : jax.nn.one_hot( pdb.aatype, 22 )[ : ][ None ],
354 |         "template_all_atom_masks" : pdb.atom_mask[ : ][ None ],
355 |         "template_all_atom_positions" : pdb.atom_positions[ : ][ None ],
356 |         "template_domain_names" : np.asarray( [ "None" ] ),
357 |         "template_sequence" : tempseq }
358 |
359 |
360 |     # Assemble the dictionary of input features
361 |     features_in = util.setup_features(
362 |         seq, a3m_lines, tfeatures_in)
363 |
364 |     if remove_msa_for_template_aligned:
365 |         features_in = util.remove_msa_for_custom_template_aligned_regions(features_in)
366 |
367 |     # Run the models
368 |     model_runner = set_config(
369 |         True,
370 |         max_msa_clusters,
371 |         max_extra_msa,
372 |         max_recycles,
373 |         model_id,
374 |         n_struct_module_repeats,
375 |         len(features_in["msa"]),
376 |         model_params=model_params,
377 |     )
378 |
379 |     result = run_one_job(model_runner, features_in, random_seed, outname, ptm)
380 |
381 |     del model_runner
382 |
383 |     return result
384 |
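# to_pdb() below writes the raw prediction first, then rewrites the B-factor
# column with per-residue pLDDT (forcing chain ID "A") and renames the
# temporary file over the original.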
385 | def to_pdb(
386 |     outname, pred, plddts, res_idx
387 | ) -> NoReturn:
388 |
389 |     r"""Writes unrelaxed PDB to file
390 |
391 |     Parameters
392 |     ----------
393 |     outname : Name of output PDB
394 |     pred : Prediction to write to PDB
395 |     plddts : Per-residue pLDDT values (written to the B-factor column)
396 |     res_idx : Residue indices from the feature dictionary
397 |
398 |     Returns
399 |     ----------
400 |     None
401 |
402 |     """
403 |
404 |     with open(outname, "w") as outfile:
405 |         outfile.write(protein.to_pdb(pred))
406 |
407 |     with open(f"b_{ outname }", "w") as outfile:
408 |         for line in open(outname, "r").readlines():
409 |             if line[0:6] == "ATOM  ":
410 |                 seq_id = int(line[22:26].strip()) - 1
411 |                 seq_id = np.where(res_idx == seq_id)[0][0]
412 |                 outfile.write(
413 |                     "{}A{}{:6.2f}{}".format(
414 |                         line[:21], line[22:60], plddts[seq_id], line[66:]
415 |                     )
416 |                 )
417 |
418 |     os.rename(f"b_{ outname }", outname)
419 |
--------------------------------------------------------------------------------
/scripts/mmseqs2.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import numpy as np
3 | import os
4 | import re
5 | import requests
6 | import tarfile
7 | import time
8 | import random
9 |
10 | from absl import logging
11 | from typing import List, NoReturn, Tuple
12 |
13 |
14 | class MMSeqs2Runner:
15 |
16 |     r"""Runner object
17 |
18 |     Fetches sequence alignment and templates from MMSeqs2 server
19 |     Based on the function run_mmseqs2 from ColabFold (sokrypton/ColabFold)
20 |     Version 62d7558c91a9809712b022faf9d91d8b183c328c
21 |
22 |     Relevant publications
23 |     ----------
24 |     * "Clustering huge protein sequence sets in linear time"
25 |       https://doi.org/10.1038/s41467-018-04964-5
26 |     * "MMseqs2 enables sensitive protein sequence searching for the analysis
27 |       of massive data sets"
28 |       https://doi.org/10.1038/nbt.3988
29 |
30 |     Private variables
31 |     ----------
32 |     self.job: Job ID (five-char string)
33 |     self.seq: Sequence to search
34 |     self.host_url: URL address to ping for data
35 |     self.t_url: URL address to ping for templates from PDB
36 |     self.n_templates: Number of templates to fetch (default=20)
37 |     self.path: Path to use
38 |     self.tarfile: Compressed file archive to download
39 |     """
40 |
41 |     def __init__(
42 |         self,
43 |         job: str,
44 |         seq: str,
45 |         host_url: str = "https://a3m.mmseqs.com",
46 |         t_url: str = "https://a3m-templates.mmseqs.com/template",
47 |         path_suffix: str = "env",
48 |         n_templates: int = 20,
49 |         shuffling_templates: bool = False,
50 |     ):
51 |
52 |         r"""Initialize runner object
53 |
54 |         Parameters
55 |         ----------
56 |         job : Job name
57 |         seq : Amino acid sequence
58 |         host_url : Website to ping for sequence data
59 |         t_url : Website to ping for template info
60 |         path_suffix : Suffix for path info
61 |         n_templates : Number of templates to fetch
62 |         """
63 |
64 |         # Clean up sequence
65 |         self.seq = self._cleanseq(seq.upper())
66 |
67 |         # Come up with unique job ID for MMSeqs
68 |         self.job = self._define_jobname(job)
69 |
70 |         # Save everything else
71 |         self.host_url = host_url
72 |         self.t_url = t_url
73 |         self.n_templates = n_templates
74 |         self.shuffling_templates = shuffling_templates
75 |
76 |         self.path = "_".join((self.job, path_suffix))
77 |
78 |         if not os.path.isdir(self.path):
79 |             os.system(f"mkdir { self.path }")
80 |
81 |         self.tarfile = f"{ self.path }/out.tar.gz"
82 |
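# Typical use (see the README): runner = MMSeqs2Runner(jobname, sequence),
# then a3m_lines, template_path = runner.run_job(templates=[...]).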
83 |     def _cleanseq(self, seq) -> str:
84 |
85 |         r"""Cleans the sequence to remove whitespace and noncanonical letters
86 |
87 |         Parameters
88 |         ----------
89 |         seq : Amino acid sequence (canonical amino acids only)
90 |
91 |         Returns
92 |         ----------
93 |         Cleaned-up amino acid sequence
94 |
95 |         """
96 |
97 |         if any([aa in seq for aa in "BJOUXZ"]):
98 |             logging.warning("Sequence contains non-canonical amino acids!")
99 |             logging.warning("Removing B, J, O, U, X, and Z from sequence")
100 |             seq = re.sub(r"[BJOUXZ]", "", seq)
101 |
102 |         return re.sub(r"[^A-Z]", "", "".join(seq.split()))
103 |
104 |     def _define_jobname(self, job: str) -> str:
105 |
106 |         r"""Provides a unique five-character identifier for the job name
107 |
108 |         Parameters
109 |         ----------
110 |         job : Job name
111 |
112 |         Returns
113 |         ----------
114 |         Defined job name
115 |
116 |         """
117 |
118 |         return "_".join(
119 |             (
120 |                 re.sub(r"\W+", "", "".join(job.split())),
121 |                 hashlib.sha1(self.seq.encode()).hexdigest()[:5],
122 |             )
123 |         )
124 |
125 |     def _submit(self) -> dict:
126 |
127 |         r"""Submit job to MMSeqs2 server
128 |
129 |         Parameters
130 |         ----------
131 |         None
132 |
133 |         Returns
134 |         ----------
135 |         Dictionary with the server response (job ID and status)
136 |
137 |         """
138 |
139 |         data = {"q": f">101\n{ self.seq }", "mode": "env"}
140 |
141 |         res = requests.post(f"{ self.host_url }/ticket/msa", data=data)
142 |
143 |         try:
144 |             out = res.json()
145 |
146 |         except ValueError:
147 |             out = {"status": "UNKNOWN"}
148 |
149 |         return out
150 |
151 |     def _status(self, idx: str) -> dict:
152 |
153 |         r"""Check status of job
154 |
155 |         Parameters
156 |         ----------
157 |         idx : Index assigned by MMSeqs2 server
158 |
159 |         Returns
160 |         ----------
161 |         Dictionary with the job status
162 |
163 |         """
164 |
165 |         res = requests.get(f"{ self.host_url }/ticket/{ idx }")
166 |
167 |         try:
168 |             out = res.json()
169 |
170 |         except ValueError:
171 |             out = {"status": "UNKNOWN"}
172 |
173 |         return out
174 |
175 |     def _download(self, idx: str, path: str) -> NoReturn:
176 |
177 |         r"""Download job outputs
178 |
179 |         Parameters
180 |         ----------
181 |         idx : Index assigned by MMSeqs2 server
182 |         path : Path to download data
183 |
184 |         Returns
185 |         ----------
186 |         None
187 |
188 |         """
189 |
190 |         res = requests.get(f"{ self.host_url }/result/download/{ idx }")
191 |
192 |         with open(path, "wb") as out:
193 |             out.write(res.content)
194 |
195 |     def _search_mmseqs2(self) -> NoReturn:
196 |
197 |         r"""Run the search and download results
198 |         Heavily modified from ColabFold
199 |
200 |         Parameters
201 |         ----------
202 |         None
203 |
204 |         Returns
205 |         ----------
206 |         None
207 |
208 |         """
209 |
210 |         if os.path.isfile(self.tarfile):
211 |             return
212 |
213 |         out = self._submit()
214 |
215 |         time.sleep(5 + np.random.randint(0, 5))
216 |         while out["status"] in ["UNKNOWN", "RATELIMIT"]:
217 |             # resubmit
218 |             time.sleep(5 + np.random.randint(0, 5))
219 |             out = self._submit()
220 |
221 |         logging.debug(f"ID: { out[ 'id' ] }")
222 |
223 |         while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
224 |             time.sleep(5 + np.random.randint(0, 5))
225 |             out = self._status(out["id"])
226 |
227 |         if out["status"] == "COMPLETE":
228 |             self._download(out["id"], self.tarfile)
229 |
230 |         elif out["status"] == "ERROR":
231 |             raise RuntimeError(
232 |                 " ".join(
233 |                     (
234 |                         "MMseqs2 API is giving errors.",
235 |                         "Please confirm your input is a valid protein sequence.",
236 |                         "If error persists, please try again in an hour.",
237 |                     )
238 |                 )
239 |             )
240 |
241 |     def process_templates(self, templates: List[str] = [] ) -> list:
242 |
243 |         r"""Process templates and fetch from MMSeqs2 server
244 |
245 |         Parameters
246 |         ----------
247 |         templates : List whose first element may be a GPCR state or kinase-property annotation,
248 |             followed by PDB IDs (without chain) to exclude
249 |
250 |         Returns
251 |         ----------
252 |         Directory containing templates (empty if not using
templates) 253 | 254 | """ 255 | 256 | path = f"{ self.job }_env/templates_101" 257 | if os.path.isdir(path): 258 | os.system(f"rm -r { path }") 259 | 260 | # templates = {} 261 | logging.info("\t".join(("seq", "pdb", "cid", "evalue"))) 262 | 263 | pdbs = [] 264 | check_duplicates = [] 265 | with open(f"{ self.path }/pdb70.m8", "r") as infile: 266 | 267 | for line in infile: 268 | 269 | sl = line.rstrip().split() 270 | pdb = sl[1] 271 | pdbid = pdb.split("_")[0] 272 | # GPCRdb only accepts pdb codes in uppercase (otherwise the returned request will be empty) 273 | pdbid = pdbid.upper() 274 | if templates: 275 | if templates[0] in ["Active", "Inactive", "Intermediate", "G protein", "Arrestin"] and pdbid not in check_duplicates and pdbid not in templates: 276 | activation_state = templates[0] 277 | url = "http://gpcrdb.org/services/structure/{}".format( pdbid ) 278 | r = requests.get( url ) 279 | rj = r.json() 280 | if type(rj) is dict and rj["state"] == activation_state: 281 | pdbs.append(pdb) 282 | check_duplicates.append(pdbid) 283 | elif type(rj) is dict and "signalling_protein" in rj: 284 | if rj["signalling_protein"]["type"] == activation_state: 285 | pdbs.append(pdb) 286 | check_duplicates.append(pdbid) 287 | 288 | if len(templates[0]) == 3 and pdbid not in check_duplicates and pdbid not in templates: 289 | if templates[0][0] in ["in", "out", "out-like"]: 290 | dfg = templates[0][0] 291 | elif templates[0][0] == "all": 292 | dfg = "all" 293 | else: 294 | raise RuntimeError("DFG value invalid") 295 | if templates[0][1] in ["in", "out",]: 296 | ac_helix = templates[0][1] 297 | elif templates[0][1] == "all": 298 | ac_helix = "all" 299 | else: 300 | raise RuntimeError("ac_helix value invalid") 301 | if templates[0][2] in ["yes", "no", "all"]: 302 | salt_bridge = templates[0][2] 303 | else: 304 | raise RuntimeError("salt_bridge value invalid") 305 | url = "https://klifs.net/api_v2/structures_pdb_list?pdb-codes={}".format( pdbid ) 306 | r = requests.get( url ) 307 | rj = r.json() 308 | #print(rj) 309 | if rj[0] != 400: 310 | #take kinase_ID value and search for structure_conformation 311 | structure_ID = rj[0]["structure_ID"] 312 | url = "https://klifs.net/api_v2/structure_conformation?structure_ID={}".format( structure_ID ) 313 | r = requests.get( url ) 314 | rj = r.json() 315 | #print(rj) 316 | if float(rj[0]["salt_bridge_17_24"]) > 0 and float(rj[0]["salt_bridge_17_24"]) <= 4.5: 317 | ref_sb = "yes" 318 | else: 319 | ref_sb = "no" 320 | if dfg != "all" and ac_helix != "all" and salt_bridge != "all": 321 | if rj[0]["DFG"] == dfg and rj[0]["ac_helix"] == ac_helix and salt_bridge == ref_sb: 322 | pdbs.append(pdb) 323 | check_duplicates.append(pdbid) 324 | elif dfg != "all" and ac_helix != "all" and salt_bridge == "all": 325 | if rj[0]["DFG"] == dfg and rj[0]["ac_helix"] == ac_helix: 326 | pdbs.append(pdb) 327 | check_duplicates.append(pdbid) 328 | elif dfg != "all" and ac_helix == "all" and salt_bridge != "all": 329 | if rj[0]["DFG"] == dfg and salt_bridge == ref_sb: 330 | pdbs.append(pdb) 331 | check_duplicates.append(pdbid) 332 | elif dfg != "all" and ac_helix == "all" and salt_bridge == "all": 333 | if rj[0]["DFG"] == dfg: 334 | pdbs.append(pdb) 335 | check_duplicates.append(pdbid) 336 | elif dfg == "all" and ac_helix != "all" and salt_bridge != "all": 337 | if rj[0]["ac_helix"] == ac_helix and salt_bridge == ref_sb: 338 | pdbs.append(pdb) 339 | check_duplicates.append(pdbid) 340 | elif dfg == "all" and ac_helix != "all" and salt_bridge == "all": 341 | if rj[0]["ac_helix"] == ac_helix: 
342 |                                     pdbs.append(pdb)
343 |                                     check_duplicates.append(pdbid)
344 |                             elif dfg == "all" and ac_helix == "all" and salt_bridge != "all":
345 |                                 if salt_bridge == ref_sb:
346 |                                     pdbs.append(pdb)
347 |                                     check_duplicates.append(pdbid)
348 |                             elif dfg == "all" and ac_helix == "all" and salt_bridge == "all":
349 |                                 pdbs.append(pdb)
350 |                                 check_duplicates.append(pdbid)
351 |
352 |                     elif pdb in templates:
353 |                         pdbs.append(sl[1])
354 |                 logging.info(f"{ sl[0] }\t{ sl[1] }\t{ sl[2] }\t{ sl[10] }")
355 |
356 |         # Write comma-separated pdbs to file
357 |         with open(f"{ self.path }/template_pdbs.txt", "w") as outfile:
358 |             for pdb in pdbs:
359 |                 outfile.write(f"{ pdb },")
360 |
361 |         return self.download_templates(pdbs)
362 |
363 |     def download_templates(self, pdbs) -> str:
364 |         """Download up to n_templates templates (optionally shuffled) and set up the template directory."""
365 |
366 |         path = f"{ self.job }_env/templates_101"
367 |         if os.path.isdir(path):
368 |             os.system(f"rm -r { path }")
369 |
370 |         if len(pdbs) == 0:
371 |             logging.warning("No templates found.")
372 |             return ""
373 |         else:
374 |             if not os.path.isdir(path):
375 |                 os.mkdir(path)
376 |
377 |             if len(pdbs) > 1 and self.shuffling_templates:
378 |                 random.shuffle(pdbs)
379 |
380 |             pdbs = ",".join(pdbs[: self.n_templates])
381 |
382 |             logging.info("TEMPLATE PDBS USED: " + pdbs)
383 |
384 |             os.system(f"wget -q -O - { self.t_url }/{ pdbs } |tar xzf - -C { path }/")
385 |
386 |             os.system(f"cp { path }/pdb70_a3m.ffindex { path }/pdb70_cs219.ffindex")
387 |
388 |             os.system(f"touch { path }/pdb70_cs219.ffdata")
389 |
390 |             return path
391 |
392 |     def _process_alignment(
393 |         self, a3m_files: list, templates: List[str] = []
394 |     ) -> Tuple[str, str]:
395 |
396 |         r"""Process sequence alignment
397 |         (modified from ColabFold)
398 |
399 |         Parameters
400 |         ----------
401 |         a3m_files : List of files to parse
402 |         templates : Template specification passed through to process_templates
403 |
404 |         Returns
405 |         ----------
406 |         Tuple with [0] string with alignment, and [1] path to template
407 |
408 |         """
409 |
410 |         a3m_lines = ""
411 |
412 |         for a3m_file in a3m_files:
413 |             for line in open(os.path.join(self.path, a3m_file), "r"):
414 |                 if len(line) > 0:
415 |                     a3m_lines += line.replace("\x00", "")
416 |
417 |         return a3m_lines, self.process_templates(templates)
418 |
419 |     def run_job(self, templates: List[str] = []) -> Tuple[str, str]:
420 |
421 |         r"""
422 |         Run sequence alignments using MMseqs2
423 |
424 |         Parameters
425 |         ----------
426 |         templates : Template specification (see process_templates)
427 |
428 |         Returns
429 |         ----------
430 |         Tuple with [0] string with alignment, and [1] path to template
431 |
432 |         """
433 |
434 |         self._search_mmseqs2()
435 |
436 |         a3m_files = ["uniref.a3m", "bfd.mgnify30.metaeuk30.smag30.a3m"]
437 |
438 |         # extract a3m files
439 |         if not os.path.isfile(os.path.join(self.path, a3m_files[0])):
440 |             with tarfile.open(self.tarfile) as tar_gz:
441 |                 tar_gz.extractall(self.path)
442 |
443 |         return self._process_alignment(a3m_files, templates)
444 |
445 |     def shuffle_templates(self) -> str:
446 |
447 |         r"""
448 |         Shuffle the templates saved by process_templates and re-download them
449 |
450 |         Parameters
451 |         ----------
452 |         None
453 |
454 |         Returns
455 |         ----------
456 |         Path to the template directory
457 |
458 |         """
459 |         # Read the saved template list back in
460 |         with open(f"{ self.path }/template_pdbs.txt", "r") as infile:
461 |             pdbs = infile.read().split(",")
462 |
463 |         # Remove the last element if it is empty
464 |         if pdbs[-1] == "":
465 |             pdbs.pop()
466 |         logging.debug(f"READ_LIST: { pdbs }")
467 |
468 |
if len(pdbs) > 1: 469 | self.shuffling_templates=True 470 | else: 471 | logging.warning("Impossible to shuffle with 1 template only.") 472 | 473 | return self.download_templates(pdbs) 474 | -------------------------------------------------------------------------------- /scripts/colabfold_alphafold.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from urllib import request 4 | from concurrent import futures 5 | import pickle 6 | 7 | import jax 8 | from alphafold.data.tools import jackhmmer 9 | from alphafold.data import parsers 10 | from alphafold.data import pipeline 11 | from alphafold.common import protein 12 | from alphafold.model import config 13 | from alphafold.model import model 14 | from alphafold.model import data 15 | from alphafold.model.tf import shape_placeholders 16 | 17 | import tensorflow as tf 18 | 19 | from string import ascii_uppercase 20 | 21 | import numpy as np 22 | import matplotlib.pyplot as plt 23 | 24 | import re 25 | import colabfold as cf 26 | import pairmsa 27 | 28 | try: 29 | from google.colab import files 30 | IN_COLAB = True 31 | except: 32 | IN_COLAB = False 33 | 34 | if os.getenv('COLABFOLD_PATH'): 35 | print("COLABFOLD_PATH is set to " + os.getenv('COLABFOLD_PATH')) 36 | colabfold_path = os.getenv('COLABFOLD_PATH') 37 | else: 38 | print("COLABFOLD_PATH is not set.") 39 | colabfold_path = '.' 40 | 41 | import tqdm.notebook 42 | TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]' 43 | 44 | ####################################################################################################################################### 45 | # prep_inputs 46 | ####################################################################################################################################### 47 | 48 | def prep_inputs(sequence, jobname="test", homooligomer="1", output_dir=None, clean=False, verbose=True): 49 | # process inputs 50 | sequence = str(sequence) 51 | sequence = re.sub("[^A-Z:/]", "", sequence.upper()) 52 | sequence = re.sub(":+",":",sequence) 53 | sequence = re.sub("/+","/",sequence) 54 | sequence = re.sub("^[:/]+","",sequence) 55 | sequence = re.sub("[:/]+$","",sequence) 56 | jobname = re.sub(r'\W+', '', jobname) 57 | homooligomer = str(homooligomer) 58 | homooligomer = re.sub("[:/]+",":",homooligomer) 59 | homooligomer = re.sub("^[:/]+","",homooligomer) 60 | homooligomer = re.sub("[:/]+$","",homooligomer) 61 | 62 | if len(homooligomer) == 0: homooligomer = "1" 63 | homooligomer = re.sub("[^0-9:]", "", homooligomer) 64 | 65 | # define inputs 66 | I = {"ori_sequence":sequence, 67 | "sequence":sequence.replace("/","").replace(":",""), 68 | "seqs":sequence.replace("/","").split(":"), 69 | "homooligomer":homooligomer, 70 | "homooligomers":[int(h) for h in homooligomer.split(":")], 71 | "msas":[], "deletion_matrices":[]} 72 | 73 | # adjust homooligomer option 74 | if len(I["seqs"]) != len(I["homooligomers"]): 75 | if len(I["homooligomers"]) == 1: 76 | I["homooligomers"] = [I["homooligomers"][0]] * len(I["seqs"]) 77 | else: 78 | if verbose: 79 | print("WARNING: Mismatch between number of breaks ':' in 'sequence' and 'homooligomer' definition") 80 | while len(I["seqs"]) > len(I["homooligomers"]): 81 | I["homooligomers"].append(1) 82 | I["homooligomers"] = I["homooligomers"][:len(I["seqs"])] 83 | I["homooligomer"] = ":".join([str(h) for h in I["homooligomers"]]) 84 | 85 | # define full sequence being modelled 86 | I["full_sequence"] = ''.join([s*h for s,h in 
zip(I["seqs"],I["homooligomers"])])
87 |   I["lengths"] = [len(seq) for seq in I["seqs"]]
88 |
89 |   # prediction directory
90 |   if output_dir is None:
91 |     I["output_dir"] = 'prediction_' + jobname + '_' + cf.get_hash(I["full_sequence"])[:5]
92 |   else:
93 |     I["output_dir"] = output_dir
94 |   os.makedirs(I["output_dir"], exist_ok=True)
95 |
96 |   # delete existing files in working directory
97 |   if clean:
98 |     for f in os.listdir(I["output_dir"]):
99 |       os.remove(os.path.join(I["output_dir"], f))
100 |
101 |   if verbose and len(I["full_sequence"]) > 1400:
102 |     print(f"WARNING: For a typical Google-Colab-GPU (16G) session, the max total length is ~1400 residues. You are at {len(I['full_sequence'])}!")
103 |     print("AlphaFold may crash unless you trim the protein(s) to a shorter length (see trim options below).")
104 |
105 |   if verbose:
106 |     print(f"homooligomer: {I['homooligomer']}")
107 |     print(f"total_length: {len(I['full_sequence'])}")
108 |     print(f"output_dir: {I['output_dir']}")
109 |
110 |   return I
111 |
112 | #######################################################################################################################################
113 | # prep_msa
114 | #######################################################################################################################################
115 |
116 | def run_jackhmmer(sequence, prefix, jackhmmer_binary_path='jackhmmer', verbose=True):
117 |
118 |   fasta_path = f"{prefix}.fasta"
119 |   with open(fasta_path, 'wt') as f:
120 |     f.write(f'>query\n{sequence}')
121 |
122 |   pickled_msa_path = f"{prefix}.jackhmmer.pickle"
123 |   if os.path.isfile(pickled_msa_path):
124 |     msas_dict = pickle.load(open(pickled_msa_path,"rb"))
125 |     msas, deletion_matrices, names = (msas_dict[k] for k in ['msas', 'deletion_matrices', 'names'])
126 |     full_msa = []
127 |     for msa in msas:
128 |       full_msa += msa
129 |   else:
130 |     # --- Find the closest source ---
131 |     test_url_pattern = 'https://storage.googleapis.com/alphafold-colab{:s}/latest/uniref90_2021_03.fasta.1'
132 |     ex = futures.ThreadPoolExecutor(3)
133 |     def fetch(source):
134 |       request.urlretrieve(test_url_pattern.format(source))
135 |       return source
136 |     fs = [ex.submit(fetch, source) for source in ['', '-europe', '-asia']]
137 |     source = None
138 |     for f in futures.as_completed(fs):
139 |       source = f.result()
140 |       ex.shutdown()
141 |       break
142 |
143 |     dbs = []
144 |
145 |     num_jackhmmer_chunks = {'uniref90': 59, 'smallbfd': 17, 'mgnify': 71}
146 |     total_jackhmmer_chunks = sum(num_jackhmmer_chunks.values())
147 |     disable_tqdm = not verbose
148 |     with tqdm.notebook.tqdm(total=total_jackhmmer_chunks, bar_format=TQDM_BAR_FORMAT, disable=disable_tqdm) as pbar:
149 |       def jackhmmer_chunk_callback(i):
150 |         pbar.update(n=1)
151 |
152 |       pbar.set_description('Searching uniref90')
153 |       jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
154 |           binary_path=jackhmmer_binary_path,
155 |           database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/uniref90_2021_03.fasta',
156 |           get_tblout=True,
157 |           num_streamed_chunks=num_jackhmmer_chunks['uniref90'],
158 |           streaming_callback=jackhmmer_chunk_callback,
159 |           z_value=135301051)
160 |       dbs.append(('uniref90', jackhmmer_uniref90_runner.query(fasta_path)))
161 |
162 |       pbar.set_description('Searching smallbfd')
163 |       jackhmmer_smallbfd_runner = jackhmmer.Jackhmmer(
164 |           binary_path=jackhmmer_binary_path,
165 |           database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/bfd-first_non_consensus_sequences.fasta',
166 |
get_tblout=True, 167 | num_streamed_chunks=num_jackhmmer_chunks['smallbfd'], 168 | streaming_callback=jackhmmer_chunk_callback, 169 | z_value=65984053) 170 | dbs.append(('smallbfd', jackhmmer_smallbfd_runner.query(fasta_path))) 171 | 172 | pbar.set_description('Searching mgnify') 173 | jackhmmer_mgnify_runner = jackhmmer.Jackhmmer( 174 | binary_path=jackhmmer_binary_path, 175 | database_path=f'https://storage.googleapis.com/alphafold-colab{source}/latest/mgy_clusters_2019_05.fasta', 176 | get_tblout=True, 177 | num_streamed_chunks=num_jackhmmer_chunks['mgnify'], 178 | streaming_callback=jackhmmer_chunk_callback, 179 | z_value=304820129) 180 | dbs.append(('mgnify', jackhmmer_mgnify_runner.query(fasta_path))) 181 | 182 | # --- Extract the MSAs and visualize --- 183 | # Extract the MSAs from the Stockholm files. 184 | # NB: deduplication happens later in pipeline.make_msa_features. 185 | 186 | mgnify_max_hits = 501 187 | msas = [] 188 | deletion_matrices = [] 189 | names = [] 190 | for db_name, db_results in dbs: 191 | unsorted_results = [] 192 | for i, result in enumerate(db_results): 193 | msa, deletion_matrix, target_names = parsers.parse_stockholm(result['sto']) 194 | e_values_dict = parsers.parse_e_values_from_tblout(result['tbl']) 195 | e_values = [e_values_dict[t.split('/')[0]] for t in target_names] 196 | zipped_results = zip(msa, deletion_matrix, target_names, e_values) 197 | if i != 0: 198 | # Only take query from the first chunk 199 | zipped_results = [x for x in zipped_results if x[2] != 'query'] 200 | unsorted_results.extend(zipped_results) 201 | sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[3]) 202 | db_msas, db_deletion_matrices, db_names, _ = zip(*sorted_by_evalue) 203 | if db_msas: 204 | if db_name == 'mgnify': 205 | db_msas = db_msas[:mgnify_max_hits] 206 | db_deletion_matrices = db_deletion_matrices[:mgnify_max_hits] 207 | db_names = db_names[:mgnify_max_hits] 208 | msas.append(db_msas) 209 | deletion_matrices.append(db_deletion_matrices) 210 | names.append(db_names) 211 | msa_size = len(set(db_msas)) 212 | print(f'{msa_size} Sequences Found in {db_name}') 213 | 214 | pickle.dump({"msas":msas, 215 | "deletion_matrices":deletion_matrices, 216 | "names":names}, open(pickled_msa_path,"wb")) 217 | return msas, deletion_matrices, names 218 | 219 | def prep_msa(I, msa_method="mmseqs2", add_custom_msa=False, msa_format="fas", 220 | pair_mode="unpaired", pair_cov=50, pair_qid=20, 221 | hhfilter_loc="hhfilter", reformat_loc="reformat.pl", TMP_DIR="tmp", 222 | custom_msa=None, precomputed=None, 223 | mmseqs_host_url="https://a3m.mmseqs.com", 224 | verbose=True): 225 | 226 | # make temp directory 227 | os.makedirs(TMP_DIR, exist_ok=True) 228 | 229 | # clear previous inputs 230 | I["msas"] = [] 231 | I["deletion_matrices"] = [] 232 | 233 | if add_custom_msa: 234 | if IN_COLAB: 235 | print(f"upload custom msa in '{msa_format}' format") 236 | msa_dict = files.upload() 237 | lines = msa_dict[list(msa_dict.keys())[0]].decode() 238 | input_file = os.path.join(I["output_dir"],f"upload.{msa_format}") 239 | with open(input_file,"w") as tmp_upload: 240 | tmp_upload.write(lines) 241 | else: 242 | input_file = custom_msa 243 | if input_file is None or not os.path.isfile(input_file): 244 | raise ValueError("ERROR: `custom_msa` undefined") 245 | else: 246 | # convert to a3m 247 | output_file = os.path.join(I["output_dir"],f"upload.a3m") 248 | os.system(f"{reformat_loc} {msa_format} a3m {input_file} {output_file}") 249 | 250 | # parse 251 | msa, mtx = 
parsers.parse_a3m(open(output_file,"r").read()) 252 | I["msas"].append(msa) 253 | I["deletion_matrices"].append(mtx) 254 | if len(I["msas"][0][0]) != len(I["sequence"]): 255 | raise ValueError("ERROR: the length of msa does not match input sequence") 256 | 257 | if msa_method == "precomputed": 258 | if IN_COLAB: 259 | print("upload precomputed pickled msa from previous run") 260 | uploaded_dict = files.upload() 261 | uploaded_filename = list(uploaded_dict.keys())[0] 262 | I.update(pickle.loads(uploaded_dict[uploaded_filename])) 263 | elif precomputed is None: 264 | raise ValueError("ERROR: `precomputed` undefined") 265 | else: 266 | I.update(pickle.load(open(precomputed,"rb"))) 267 | 268 | elif msa_method == "single_sequence": 269 | if len(I["msas"]) == 0: 270 | I["msas"].append([I["sequence"]]) 271 | I["deletion_matrices"].append([[0]*len(I["sequence"])]) 272 | 273 | else: 274 | _blank_seq = ["-" * L for L in I["lengths"]] 275 | _blank_mtx = [[0] * L for L in I["lengths"]] 276 | def _pad(ns,vals,mode): 277 | if mode == "seq": _blank = _blank_seq.copy() 278 | if mode == "mtx": _blank = _blank_mtx.copy() 279 | if isinstance(ns, list): 280 | for n,val in zip(ns,vals): _blank[n] = val 281 | else: _blank[ns] = vals 282 | if mode == "seq": return "".join(_blank) 283 | if mode == "mtx": return sum(_blank,[]) 284 | 285 | if len(I["seqs"]) == 1 or "unpaired" in pair_mode: 286 | # gather msas 287 | if msa_method == "mmseqs2": 288 | prefix = cf.get_hash(I["sequence"]) 289 | prefix = os.path.join(TMP_DIR,prefix) 290 | print(f"running mmseqs2") 291 | A3M_LINES = cf.run_mmseqs2(I["seqs"], prefix, use_filter=True, host_url=mmseqs_host_url) 292 | 293 | for n, seq in enumerate(I["seqs"]): 294 | # tmp directory 295 | prefix = cf.get_hash(seq) 296 | prefix = os.path.join(TMP_DIR,prefix) 297 | 298 | if msa_method == "mmseqs2": 299 | # run mmseqs2 300 | a3m_lines = A3M_LINES[n] 301 | msa, mtx = parsers.parse_a3m(a3m_lines) 302 | msas_, mtxs_ = [msa],[mtx] 303 | 304 | elif msa_method == "jackhmmer": 305 | print(f"running jackhmmer on seq_{n}") 306 | # run jackhmmer 307 | msas_, mtxs_, names_ = ([sum(x,())] for x in run_jackhmmer(seq, prefix)) 308 | 309 | # pad sequences 310 | for msa_,mtx_ in zip(msas_,mtxs_): 311 | msa,mtx = [I["sequence"]],[[0]*len(I["sequence"])] 312 | for s,m in zip(msa_,mtx_): 313 | msa.append(_pad(n,s,"seq")) 314 | mtx.append(_pad(n,m,"mtx")) 315 | 316 | I["msas"].append(msa) 317 | I["deletion_matrices"].append(mtx) 318 | 319 | # PAIR_MSA 320 | if len(I["seqs"]) > 1 and (pair_mode == "paired" or pair_mode == "unpaired+paired"): 321 | print("attempting to pair some sequences...") 322 | 323 | if msa_method == "mmseqs2": 324 | prefix = cf.get_hash(I["sequence"]) 325 | prefix = os.path.join(TMP_DIR,prefix) 326 | print(f"running mmseqs2_noenv_nofilter on all seqs") 327 | A3M_LINES = cf.run_mmseqs2(I["seqs"], prefix, use_env=False, use_filter=False, host_url=mmseqs_host_url) 328 | 329 | _data = [] 330 | for a in range(len(I["seqs"])): 331 | print(f"prepping seq_{a}") 332 | _seq = I["seqs"][a] 333 | _prefix = os.path.join(TMP_DIR,cf.get_hash(_seq)) 334 | 335 | if msa_method == "mmseqs2": 336 | a3m_lines = A3M_LINES[a] 337 | _msa, _mtx, _lab = pairmsa.parse_a3m(a3m_lines, 338 | filter_qid=pair_qid/100, 339 | filter_cov=pair_cov/100) 340 | 341 | elif msa_method == "jackhmmer": 342 | _msas, _mtxs, _names = run_jackhmmer(_seq, _prefix) 343 | _msa, _mtx, _lab = pairmsa.get_uni_jackhmmer(_msas[0], _mtxs[0], _names[0], 344 | filter_qid=pair_qid/100, 345 | filter_cov=pair_cov/100) 346 | 347 | if 
len(_msa) > 1:
348 | _data.append(pairmsa.hash_it(_msa, _lab, _mtx, call_uniprot=False))
349 | else:
350 | _data.append(None)
351 | 
352 | Ln = len(I["seqs"])
353 | O = [[None for _ in I["seqs"]] for _ in I["seqs"]]
354 | for a in range(Ln):
355 | if _data[a] is not None:
356 | for b in range(a+1,Ln):
357 | if _data[b] is not None:
358 | print(f"attempting pairwise stitch for {a} {b}")
359 | O[a][b] = pairmsa._stitch(_data[a],_data[b])
360 | _seq_a, _seq_b, _mtx_a, _mtx_b = (*O[a][b]["seq"],*O[a][b]["mtx"])
361 | 
362 | # filter to remove redundant sequences
363 | ok = []
364 | with open(f"{TMP_DIR}/tmp.fas","w") as fas_file:
365 | fas_file.writelines([f">{n}\n{a+b}\n" for n,(a,b) in enumerate(zip(_seq_a,_seq_b))])
366 | os.system(f"{hhfilter_loc} -maxseq 1000000 -i {TMP_DIR}/tmp.fas -o {TMP_DIR}/tmp.id90.fas -id 90")
367 | for line in open(f"{TMP_DIR}/tmp.id90.fas","r"):
368 | if line.startswith(">"): ok.append(int(line[1:]))
369 | 
370 | if verbose:
371 | print(f"found {len(_seq_a)} pairs ({len(ok)} after filtering)")
372 | 
373 | if len(ok) > 0:
374 | msa,mtx = [I["sequence"]],[[0]*len(I["sequence"])]
375 | for n in ok: # keep only the pairs that survived hhfilter
376 | msa.append(_pad([a,b],[_seq_a[n],_seq_b[n]],"seq"))
377 | mtx.append(_pad([a,b],[_mtx_a[n],_mtx_b[n]],"mtx"))
378 | I["msas"].append(msa)
379 | I["deletion_matrices"].append(mtx)
380 | 
381 | # save MSA as pickle
382 | pickle.dump({"msas":I["msas"],"deletion_matrices":I["deletion_matrices"]},
383 | open(os.path.join(I["output_dir"],"msa.pickle"),"wb"))
384 | return I
385 | 
386 | #######################################################################################################################################
387 | # prep_filter
388 | #######################################################################################################################################
389 | 
390 | def trim_inputs(trim, msas, deletion_matrices, ori_seq=None, inverse=False):
391 | '''
392 | Remove (or, with inverse=True, keep) the positions in `trim`: a comma-separated list of 1-indexed positions/ranges, optionally chain-prefixed (e.g. "5-10,B3-15").
393 | Returns the trimmed msas, deletion_matrices and ori_sequence, plus the chains that remain.
394 | '''
395 | if ori_seq is None: ori_seq = msas[0][0]
396 | seqs = ori_seq.replace("/","").split(":")
397 | L_ini = 0
398 | chain_idx = {}
399 | idx_chain = []
400 | for chain,seq in zip(ascii_uppercase,seqs):
401 | L = len(seq)
402 | chain_idx[chain] = dict(zip(range(L),range(L_ini,L_ini+L)))
403 | idx_chain += [f"{chain}{i+1}" for i in range(L)]
404 | L_ini += L
405 | global_idx = dict(zip(range(L_ini),range(L_ini)))
406 | 
407 | mode = "keeping" if inverse else "trimming"
408 | trim_set = []
409 | for idx in trim.split(","):
410 | 
411 | i,j = idx.split("-") if "-" in idx else (idx,"")
412 | 
413 | # set index reference frame
414 | trim_idx_i = trim_idx_j = global_idx
415 | if i != "" and i[0] in ascii_uppercase:
416 | trim_idx_i,i = chain_idx[i[0]], i[1:]
417 | if j != "" and j[0] in ascii_uppercase:
418 | trim_idx_j,j = chain_idx[j[0]], j[1:]
419 | 
420 | # set which positions to trim
421 | if "-" in idx:
422 | i = trim_idx_i[int(i)-1] if i != "" else trim_idx_i[0]
423 | j = trim_idx_j[int(j)-1] if j != "" else trim_idx_j[len(trim_idx_j) - 1]
424 | trim_set += list(range(i,j+1))
425 | print(f"{mode} positions: {idx_chain[i]}-{idx_chain[j]}")
426 | else:
427 | i = trim_idx_i[int(i)-1]
428 | trim_set.append(i)
429 | print(f"{mode} position: {idx_chain[i]}")
430 | 
431 | # deduplicate list
432 | trim_set = set(trim_set)
433 | if inverse:
434 | trim_set = set(range(L_ini)) ^ trim_set
435 | 
436 | trim_set = sorted(list(trim_set))
437 | 
438 | # trim MSA
439 | mod_msas, mod_mtxs = [],[]
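# each MSA loses the columns in trim_set; rows reduced to all gaps are
# dropped, and the deletion matrices are filtered and trimmed in step
# with their rows so the two stay aligned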
440 | for msa, mtx in zip(msas, deletion_matrices):
441 | mod_msa = np.delete([list(s) for s in msa], trim_set, 1)
442 | ok = (mod_msa != "-").sum(-1) > 0
443 | mod_msas.append(["".join(s) for s in mod_msa[ok]])
444 | mod_mtx = np.asarray(mtx)[ok]
445 | mod_mtxs.append(np.delete(mod_mtx, trim_set, 1).tolist())
446 | 
447 | # trim original sequence
448 | mod_idx = []
449 | mod_chain = []
450 | mod_ori_seq = []
451 | for n,a in enumerate(ori_seq.replace("/","").replace(":","")):
452 | if n not in trim_set:
453 | mod_ori_seq.append(a)
454 | mod_idx.append(n)
455 | mod_chain.append(idx_chain[n][0])
456 | if len(mod_idx) > 1:
457 | if mod_chain[-1] != mod_chain[-2]:
458 | mod_ori_seq[-1] = ":"
459 | mod_ori_seq.append(a)
460 | elif (mod_idx[-1] - mod_idx[-2]) > 1:
461 | mod_ori_seq[-1] = "/"
462 | mod_ori_seq.append(a)
463 | 
464 | mod_ori_seq = "".join(mod_ori_seq)
465 | chains = sorted([ascii_uppercase.index(a) for a in set(mod_chain)])
466 | return {"msas":mod_msas, "deletion_matrices":mod_mtxs,
467 | "ori_sequence":mod_ori_seq, "chains":chains}
468 | 
469 | def cov_qid_filter(msas, deletion_matrices, ori_seq=None, cov=0, qid=0): # keep sequences with coverage > cov and identity-to-query > qid (fractions, computed over the chains each sequence covers)
470 | if ori_seq is None: ori_seq = msas[0][0]
471 | seqs = ori_seq.replace("/","").split(":")
472 | ref_seq_ = np.array(list("".join(seqs)))
473 | 
474 | new_msas,new_mtxs = [],[]
475 | L = np.asarray([len(seq) for seq in seqs])
476 | Ln = np.cumsum(np.append(0,L))
477 | for msa, mtx in zip(msas, deletion_matrices):
478 | msa_ = np.asarray([list(seq) for seq in msa])
479 | 
480 | # coverage (non-gap characters)
481 | cov_ = msa_ != "-"
482 | # sequence identity to query
483 | qid_ = msa_ == ref_seq_
484 | 
485 | # split by protein (for protein complexes)
486 | cov__ = np.stack([cov_[:,Ln[i]:Ln[i+1]].sum(-1) for i in range(len(seqs))],-1)
487 | qid__ = np.stack([qid_[:,Ln[i]:Ln[i+1]].sum(-1) for i in range(len(seqs))],-1)
488 | 
489 | not_empty__ = cov__ > 0
490 | ok = []
491 | for n in range(len(msa)):
492 | m = not_empty__[n]
493 | if m.sum() > 0:
494 | q = qid__[n][m].sum() / cov__[n][m].sum()
495 | c = cov__[n][m].sum() / L[m].sum()
496 | if q > qid and c > cov:
497 | ok.append(n)
498 | 
499 | new_msas.append([msa[n] for n in ok])
500 | new_mtxs.append([mtx[n] for n in ok])
501 | return {"msas":new_msas, "deletion_matrices":new_mtxs}
502 | 
503 | def prep_filter(I, trim="", trim_inverse=False, cov=0, qid=0, verbose=True):
504 | trim = re.sub("[^0-9A-Z,-]", "", trim.upper())
505 | trim = re.sub(",+",",",trim)
506 | trim = re.sub("^[,]+","",trim)
507 | trim = re.sub("[,]+$","",trim)
508 | if trim != "" or cov > 0 or qid > 0:
509 | mod_I = dict(I)
510 | 
511 | if trim != "":
512 | mod_I.update(trim_inputs(trim, mod_I["msas"], mod_I["deletion_matrices"],
513 | mod_I["ori_sequence"], inverse=trim_inverse))
514 | 
515 | mod_I["homooligomers"] = [mod_I["homooligomers"][c] for c in mod_I["chains"]]
516 | mod_I["sequence"] = mod_I["ori_sequence"].replace("/","").replace(":","")
517 | mod_I["seqs"] = mod_I["ori_sequence"].replace("/","").split(":")
518 | mod_I["full_sequence"] = "".join([s*h for s,h in zip(mod_I["seqs"], mod_I["homooligomers"])])
519 | new_length = len(mod_I["full_sequence"])
520 | if verbose:
521 | print(f"total_length: {new_length} after trimming")
522 | 
523 | if cov > 0 or qid > 0:
524 | mod_I.update(cov_qid_filter(mod_I["msas"], mod_I["deletion_matrices"],
525 | mod_I["ori_sequence"], cov=cov/100, qid=qid/100))
526 | return mod_I
527 | else:
528 | return I
529 | 
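# Usage sketch for the filters above (all values illustrative): keep only
# residues 10-200 of chain A, then drop alignment rows with <=50% coverage
# or <=20% identity to the query:
#
#   I_filt = prep_filter(I, trim="A10-200", trim_inverse=True, cov=50, qid=20)
#
# trim_inverse=True keeps the listed positions; with the default
# trim_inverse=False they are removed instead.
530 | 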
####################################################################################################################################### 531 | # prep features 532 | ####################################################################################################################################### 533 | 534 | def prep_feats(I, clean=False): 535 | def _placeholder_template_feats(num_templates_, num_res_): 536 | return { 537 | 'template_aatype': np.zeros([num_templates_, num_res_, 22], np.float32), 538 | 'template_all_atom_masks': np.zeros([num_templates_, num_res_, 37], np.float32), 539 | 'template_all_atom_positions': np.zeros([num_templates_, num_res_, 37, 3], np.float32), 540 | 'template_domain_names': np.zeros([num_templates_], np.float32), 541 | 'template_sum_probs': np.zeros([num_templates_], np.float32), 542 | } 543 | # delete old files 544 | if clean: 545 | for f in os.listdir(I["output_dir"]): 546 | if "rank_" in f: os.remove(os.path.join(I["output_dir"], f)) 547 | 548 | if len(I["msas"]) == 0: 549 | print("WARNING: no MSA found, switching to 'single_sequence' mode") 550 | I["msas"].append([I["sequence"]]) 551 | I["deletion_matrices"].append([[0]*len(I["sequence"])]) 552 | 553 | # homooligomerize 554 | lengths = [len(seq) for seq in I["seqs"]] 555 | msas_mod, deletion_matrices_mod = cf.homooligomerize_heterooligomer(I["msas"], I["deletion_matrices"], 556 | lengths, I["homooligomers"]) 557 | # define input features 558 | num_res = len(I["full_sequence"]) 559 | feature_dict = {} 560 | feature_dict.update(pipeline.make_sequence_features(I["full_sequence"], 'test', num_res)) 561 | feature_dict.update(pipeline.make_msa_features(msas_mod, deletion_matrices=deletion_matrices_mod)) 562 | feature_dict.update(_placeholder_template_feats(0, num_res)) 563 | 564 | # set chainbreaks 565 | Ls = [] 566 | for seq,h in zip(I["ori_sequence"].split(":"), I["homooligomers"]): 567 | Ls += [len(s) for s in seq.split("/")] * h 568 | Ls_plot = [] 569 | for seq,h in zip(I["seqs"], I["homooligomers"]): 570 | Ls_plot += [len(seq)] * h 571 | 572 | feature_dict['residue_index'] = cf.chain_break(feature_dict['residue_index'], Ls) 573 | feature_dict['Ls'] = Ls_plot 574 | feature_dict['output_dir'] = I["output_dir"] 575 | return feature_dict 576 | 577 | def make_fixed_size(feat, runner): 578 | '''pad input features''' 579 | opt = runner["opt"] 580 | cfg = runner["model"].config 581 | shape_schema = {k:[None]+v for k,v in dict(cfg.data.eval.feat).items()} 582 | pad_size_map = { 583 | shape_placeholders.NUM_RES: opt["L"], 584 | shape_placeholders.NUM_MSA_SEQ: cfg.data.eval.max_msa_clusters, 585 | shape_placeholders.NUM_EXTRA_SEQ: cfg.data.common.max_extra_msa, 586 | shape_placeholders.NUM_TEMPLATES: 0, 587 | } 588 | for k, v in feat.items(): 589 | # Don't transfer this to the accelerator. 
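# every other feature is zero-padded below to the fixed sizes given by
# the shape schema and pad_size_map, so tensor shapes stay constant
# across models and seeds and XLA does not recompile for each input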
590 | if k == 'extra_cluster_assignment': 591 | continue 592 | shape = list(v.shape) 593 | schema = shape_schema[k] 594 | assert len(shape) == len(schema), ( 595 | f'Rank mismatch between shape and shape schema for {k}: ' 596 | f'{shape} vs {schema}') 597 | pad_size = [pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)] 598 | padding = [(0, p - tf.shape(v)[i]) for i, p in enumerate(pad_size)] 599 | if padding: 600 | feat[k] = tf.pad(v, padding, name=f'pad_to_fixed_{k}') 601 | feat[k].set_shape(pad_size) 602 | return {k:np.asarray(v) for k,v in feat.items()} 603 | 604 | ####################################################################################################################################### 605 | # run alphafold 606 | ####################################################################################################################################### 607 | 608 | def clear_mem(device=None): 609 | '''remove all data from device''' 610 | backend = jax.lib.xla_bridge.get_backend(device) 611 | if hasattr(backend,'live_buffers'): 612 | for buf in backend.live_buffers(): 613 | buf.delete() 614 | 615 | OPT_DEFAULT = {"N":None, "L":None, 616 | "use_ptm":True, "use_turbo":True, 617 | "max_recycles":3, "tol":0, "num_ensemble":1, 618 | "max_msa_clusters":512, "max_extra_msa":1024, 619 | "is_training":False} 620 | 621 | def prep_model_runner(opt=None, model_name="model_5", old_runner=None, params_loc='./alphafold/data'): 622 | 623 | # setup the [opt]ions 624 | if opt is None: 625 | opt = OPT_DEFAULT.copy() 626 | else: 627 | for k in OPT_DEFAULT: 628 | if k not in opt: opt[k] = OPT_DEFAULT[k] 629 | 630 | # if old_runner not defined or [opt]ions changed, start new runner 631 | if old_runner is None or old_runner["opt"] != opt: 632 | clear_mem() 633 | name = f"{model_name}_ptm" if opt["use_ptm"] else model_name 634 | cfg = config.model_config(name) 635 | 636 | if opt["use_turbo"]: 637 | if opt["N"] is None: 638 | cfg.data.eval.max_msa_clusters = opt["max_msa_clusters"] 639 | cfg.data.common.max_extra_msa = opt["max_extra_msa"] 640 | else: 641 | msa_clusters = min(opt["N"], opt["max_msa_clusters"]) 642 | cfg.data.eval.max_msa_clusters = msa_clusters 643 | cfg.data.common.max_extra_msa = max(min(opt["N"] - msa_clusters, opt["max_extra_msa"]),1) 644 | 645 | cfg.data.common.num_recycle = opt["max_recycles"] 646 | cfg.model.num_recycle = opt["max_recycles"] 647 | cfg.model.recycle_tol = opt["tol"] 648 | cfg.data.eval.num_ensemble = opt["num_ensemble"] 649 | 650 | params = data.get_model_haiku_params(name, colabfold_path + "/" + params_loc) 651 | return {"model":model.RunModel(cfg, params, is_training=opt["is_training"]), "opt":opt} 652 | else: 653 | return old_runner 654 | 655 | def run_alphafold(feature_dict, opt=None, runner=None, num_models=5, num_samples=1, subsample_msa=True, 656 | pad_feats=False, rank_by="pLDDT", show_images=True, params_loc='./alphafold/data', verbose=True): 657 | 658 | def do_subsample_msa(F, random_seed=0): 659 | '''subsample msa to avoid running out of memory''' 660 | N = len(F["msa"]) 661 | L = len(F["residue_index"]) 662 | N_ = int(3E7/L) 663 | if N > N_: 664 | if verbose: 665 | print(f"whhhaaa... 
too many sequences ({N}) subsampling to {N_}") 666 | np.random.seed(random_seed) 667 | idx = np.append(0,np.random.permutation(np.arange(1,N)))[:N_] 668 | F_ = {} 669 | F_["msa"] = F["msa"][idx] 670 | F_["deletion_matrix_int"] = F["deletion_matrix_int"][idx] 671 | F_["num_alignments"] = np.full_like(F["num_alignments"],N_) 672 | for k in F.keys(): 673 | if k not in F_: F_[k] = F[k] 674 | return F_ 675 | else: 676 | return F 677 | 678 | def parse_results(prediction_result, processed_feature_dict, r, t, num_res): 679 | '''parse results and convert to numpy arrays''' 680 | 681 | to_np = lambda a: np.asarray(a) 682 | def class_to_np(c): 683 | class dict2obj(): 684 | def __init__(self, d): 685 | for k,v in d.items(): setattr(self, k, to_np(v)) 686 | return dict2obj(c.__dict__) 687 | 688 | dist_bins = jax.numpy.append(0,prediction_result["distogram"]["bin_edges"]) 689 | dist_logits = prediction_result["distogram"]["logits"][:num_res,:][:,:num_res] 690 | dist_mtx = dist_bins[dist_logits.argmax(-1)] 691 | contact_mtx = jax.nn.softmax(dist_logits)[:,:,dist_bins < 8].sum(-1) 692 | 693 | b_factors = prediction_result['plddt'][:,None] * prediction_result['structure_module']['final_atom_mask'] 694 | p = protein.from_prediction(processed_feature_dict, prediction_result, b_factors=b_factors) 695 | plddt = prediction_result['plddt'][:num_res] 696 | out = {"unrelaxed_protein": class_to_np(p), 697 | "plddt": to_np(plddt), 698 | "pLDDT": to_np(plddt.mean()), 699 | "dists": to_np(dist_mtx), 700 | "adj": to_np(contact_mtx), 701 | "recycles":to_np(r), 702 | "tol":to_np(t)} 703 | if "ptm" in prediction_result: 704 | out["pae"] = to_np(prediction_result['predicted_aligned_error'][:num_res,:][:,:num_res]) 705 | out["pTMscore"] = to_np(prediction_result['ptm']) 706 | return out 707 | 708 | num_res = len(feature_dict["residue_index"]) 709 | 710 | # if [opt]ions not defined 711 | if opt is None: 712 | opt = OPT_DEFAULT.copy() 713 | opt["N"] = len(feature_dict["msa"]) 714 | opt["L"] = num_res 715 | else: 716 | for k in OPT_DEFAULT.keys(): 717 | if k not in opt: opt[k] = OPT_DEFAULT[k] 718 | 719 | model_names = ['model_1', 'model_2', 'model_3', 'model_4', 'model_5'][:num_models] 720 | total = len(model_names) * num_samples 721 | outs = {} 722 | 723 | def do_report(key): 724 | o = outs[key] 725 | if verbose: 726 | line = f"{key} recycles:{o['recycles']} tol:{o['tol']:.2f} pLDDT:{o['pLDDT']:.2f}" 727 | if 'pTMscore' in o: 728 | line += f" pTMscore:{o['pTMscore']:.2f}" 729 | print(line) 730 | if show_images: 731 | fig = cf.plot_protein(o['unrelaxed_protein'], Ls=feature_dict["Ls"], dpi=100) 732 | plt.show() 733 | tmp_pdb_path = os.path.join(feature_dict["output_dir"],f'unranked_{key}_unrelaxed.pdb') 734 | pdb_lines = protein.to_pdb(o['unrelaxed_protein']) 735 | with open(tmp_pdb_path, 'w') as f: f.write(pdb_lines) 736 | 737 | disable_tqdm = not verbose 738 | with tqdm.notebook.tqdm(total=total, bar_format=TQDM_BAR_FORMAT, disable=disable_tqdm) as pbar: 739 | if opt["use_turbo"]: 740 | if runner is None: 741 | runner = prep_model_runner(opt,params_loc=params_loc) 742 | 743 | # go through each random_seed 744 | for seed in range(num_samples): 745 | # prep input features 746 | feat = do_subsample_msa(feature_dict, random_seed=seed) if subsample_msa else feature_dict 747 | processed_feature_dict = runner["model"].process_features(feat, random_seed=seed) 748 | if pad_feats: 749 | processed_feature_dict = make_fixed_size(processed_feature_dict, runner) 750 | 751 | # go through each model 752 | for num, model_name in 
enumerate(model_names):
753 | name = model_name+"_ptm" if opt["use_ptm"] else model_name
754 | key = f"{name}_seed_{seed}"
755 | pbar.set_description(f'Running {key}')
756 | 
757 | # replace model parameters
758 | params = data.get_model_haiku_params(name, colabfold_path + "/" + params_loc)
759 | for k in runner["model"].params.keys():
760 | runner["model"].params[k] = params[k]
761 | 
762 | # predict
763 | prediction_result, (r, t) = runner["model"].predict(processed_feature_dict, random_seed=seed)
764 | outs[key] = parse_results(prediction_result, processed_feature_dict, r=r, t=t, num_res=num_res)
765 | 
766 | # cleanup
767 | del prediction_result, params, r, t
768 | 
769 | # report
770 | do_report(key)
771 | pbar.update(n=1)
772 | 
773 | # cleanup
774 | del processed_feature_dict
775 | if subsample_msa: del feat
776 | 
777 | else:
778 | # go through each model
779 | for num, model_name in enumerate(model_names):
780 | name = model_name+"_ptm" if opt["use_ptm"] else model_name
781 | model_runner = prep_model_runner(opt, model_name=model_name, params_loc=params_loc)["model"]
782 | 
783 | # go through each random_seed
784 | for seed in range(num_samples):
785 | key = f"{name}_seed_{seed}"
786 | pbar.set_description(f'Running {key}')
787 | processed_feature_dict = model_runner.process_features(feature_dict, random_seed=seed)
788 | 
789 | # predict
790 | prediction_result, (r, t) = model_runner.predict(processed_feature_dict, random_seed=seed)
791 | outs[key] = parse_results(prediction_result, processed_feature_dict, r=r, t=t, num_res=num_res)
792 | 
793 | # cleanup
794 | del processed_feature_dict, prediction_result, r, t
795 | 
796 | # report
797 | do_report(key)
798 | pbar.update(n=1)
799 | 
800 | # cleanup
801 | del model_runner
802 | 
803 | # Rank the models by the chosen metric (rank_by; mean pLDDT by default).
804 | model_rank = list(outs.keys())
805 | model_rank = [model_rank[i] for i in np.argsort([outs[x][rank_by] for x in model_rank])[::-1]]
806 | 
807 | # Write out the ranked predictions
808 | for n,key in enumerate(model_rank):
809 | prefix = f"rank_{n+1}_{key}"
810 | pred_output_path = os.path.join(feature_dict["output_dir"],f'{prefix}_unrelaxed.pdb')
811 | fig = cf.plot_protein(outs[key]["unrelaxed_protein"], Ls=feature_dict["Ls"], dpi=200)
812 | plt.savefig(os.path.join(feature_dict["output_dir"],f'{prefix}.png'), bbox_inches = 'tight')
813 | plt.close(fig)
814 | pdb_lines = protein.to_pdb(outs[key]["unrelaxed_protein"])
815 | with open(pred_output_path, 'w') as f:
816 | f.write(pdb_lines)
817 | 
818 | tmp_pdb_path = os.path.join(feature_dict["output_dir"],f'unranked_{key}_unrelaxed.pdb')
819 | if os.path.isfile(tmp_pdb_path):
820 | os.remove(tmp_pdb_path)
821 | 
822 | ############################################################
823 | if verbose:
824 | print(f"model rank based on {rank_by}")
825 | for n,key in enumerate(model_rank):
826 | print(f"rank_{n+1}_{key} {rank_by}:{outs[key][rank_by]:.2f}")
827 | 
828 | return outs, model_rank
829 | 
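# ----------------------------------------------------------------------
# End-to-end usage sketch. The prep_inputs argument names are assumed
# from its body earlier in this file (check the signature before use),
# and all values below are illustrative:
#
#   I = prep_inputs("MVLSEGEWQLVLHVWAKVEADV", jobname="test")
#   I = prep_msa(I, msa_method="mmseqs2", pair_mode="unpaired")
#   I = prep_filter(I, trim="", cov=0, qid=0)  # optional trim/filter step
#   feature_dict = prep_feats(I)
#   outs, model_rank = run_alphafold(feature_dict, num_models=5, num_samples=1)
#
# outs[model_rank[0]] then holds the top-ranked result: the unrelaxed
# protein, per-residue "plddt", contact map "adj" and, for pTM models,
# the predicted aligned error "pae"; ranked PDB and PNG files are
# written to I["output_dir"].
# ----------------------------------------------------------------------
--------------------------------------------------------------------------------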