├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── Representations_AlphaFold2PredictStructure.ipynb ├── Representations_AlphaFold2_v3.ipynb ├── alphafold ├── __init__.py ├── common │ ├── __init__.py │ ├── confidence.py │ ├── protein.py │ ├── protein_test.py │ ├── residue_constants.py │ ├── residue_constants_test.py │ └── testdata │ │ └── 2rbg.pdb ├── data │ ├── __init__.py │ ├── mmcif_parsing.py │ ├── parsers.py │ ├── pipeline.py │ ├── templates.py │ └── tools │ │ ├── __init__.py │ │ ├── hhblits.py │ │ ├── hhsearch.py │ │ ├── hmmbuild.py │ │ ├── hmmsearch.py │ │ ├── jackhmmer.py │ │ ├── kalign.py │ │ └── utils.py ├── model │ ├── __init__.py │ ├── all_atom.py │ ├── all_atom_test.py │ ├── common_modules.py │ ├── config.py │ ├── data.py │ ├── features.py │ ├── folding.py │ ├── layer_stack.py │ ├── layer_stack_test.py │ ├── lddt.py │ ├── lddt_test.py │ ├── mapping.py │ ├── model.py │ ├── modules.py │ ├── prng.py │ ├── prng_test.py │ ├── quat_affine.py │ ├── quat_affine_test.py │ ├── r3.py │ ├── tf │ │ ├── __init__.py │ │ ├── data_transforms.py │ │ ├── input_pipeline.py │ │ ├── protein_features.py │ │ ├── protein_features_test.py │ │ ├── proteins_dataset.py │ │ ├── shape_helpers.py │ │ ├── shape_helpers_test.py │ │ ├── shape_placeholders.py │ │ └── utils.py │ └── utils.py └── relax │ ├── __init__.py │ ├── amber_minimize.py │ ├── amber_minimize_test.py │ ├── cleanup.py │ ├── cleanup_test.py │ ├── relax.py │ ├── relax_test.py │ ├── testdata │ ├── model_output.pdb │ ├── multiple_disulfides_target.pdb │ ├── with_violations.pdb │ └── with_violations_casp14.pdb │ ├── utils.py │ └── utils_test.py ├── docker ├── Dockerfile ├── openmm.patch ├── requirements.txt └── run_docker.py ├── header.jpg ├── imgs ├── casp14_predictions.gif └── header.jpg ├── notebooks └── AlphaFold.ipynb ├── requirements.txt ├── run_alphafold.py ├── run_alphafold_test.py ├── scripts ├── download_all_data.sh ├── download_alphafold_params.sh ├── download_bfd.sh ├── download_mgnify.sh ├── download_pdb70.sh ├── download_pdb_mmcif.sh ├── download_small_bfd.sh ├── download_uniclust30.sh └── download_uniref90.sh └── setup.py /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We welcome small patches related to bug fixes and documentation, but we do not 4 | plan to make any major changes to this repository. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to https://cla.developers.google.com/ to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 
24 | -------------------------------------------------------------------------------- /alphafold/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """An implementation of the inference pipeline of AlphaFold v2.0.""" 15 | -------------------------------------------------------------------------------- /alphafold/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Common data types and constants used within Alphafold.""" 15 | -------------------------------------------------------------------------------- /alphafold/common/confidence.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions for processing confidence metrics.""" 16 | 17 | from typing import Dict, Optional, Tuple 18 | import numpy as np 19 | import scipy.special 20 | 21 | 22 | def compute_plddt(logits: np.ndarray) -> np.ndarray: 23 | """Computes per-residue pLDDT from logits. 24 | 25 | Args: 26 | logits: [num_res, num_bins] output from the PredictedLDDTHead. 27 | 28 | Returns: 29 | plddt: [num_res] per-residue pLDDT. 30 | """ 31 | num_bins = logits.shape[-1] 32 | bin_width = 1.0 / num_bins 33 | bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width) 34 | probs = scipy.special.softmax(logits, axis=-1) 35 | predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1) 36 | return predicted_lddt_ca * 100 37 | 38 | 39 | def _calculate_bin_centers(breaks: np.ndarray): 40 | """Gets the bin centers from the bin edges. 
41 | 42 | Args: 43 | breaks: [num_bins - 1] the error bin edges. 44 | 45 | Returns: 46 | bin_centers: [num_bins] the error bin centers. 47 | """ 48 | step = (breaks[1] - breaks[0]) 49 | 50 | # Add half-step to get the center 51 | bin_centers = breaks + step / 2 52 | # Add a catch-all bin at the end. 53 | bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]], 54 | axis=0) 55 | return bin_centers 56 | 57 | 58 | def _calculate_expected_aligned_error( 59 | alignment_confidence_breaks: np.ndarray, 60 | aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 61 | """Calculates expected aligned distance errors for every pair of residues. 62 | 63 | Args: 64 | alignment_confidence_breaks: [num_bins - 1] the error bin edges. 65 | aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted 66 | probs for each error bin, for each pair of residues. 67 | 68 | Returns: 69 | predicted_aligned_error: [num_res, num_res] the expected aligned distance 70 | error for each pair of residues. 71 | max_predicted_aligned_error: The maximum predicted error possible. 72 | """ 73 | bin_centers = _calculate_bin_centers(alignment_confidence_breaks) 74 | 75 | # Tuple of expected aligned distance error and max possible error. 76 | return (np.sum(aligned_distance_error_probs * bin_centers, axis=-1), 77 | np.asarray(bin_centers[-1])) 78 | 79 | 80 | def compute_predicted_aligned_error( 81 | logits: np.ndarray, 82 | breaks: np.ndarray) -> Dict[str, np.ndarray]: 83 | """Computes aligned confidence metrics from logits. 84 | 85 | Args: 86 | logits: [num_res, num_res, num_bins] the logits output from 87 | PredictedAlignedErrorHead. 88 | breaks: [num_bins - 1] the error bin edges. 89 | 90 | Returns: 91 | aligned_confidence_probs: [num_res, num_res, num_bins] the predicted 92 | aligned error probabilities over bins for each residue pair. 93 | predicted_aligned_error: [num_res, num_res] the expected aligned distance 94 | error for each pair of residues. 95 | max_predicted_aligned_error: The maximum predicted error possible. 96 | """ 97 | aligned_confidence_probs = scipy.special.softmax( 98 | logits, 99 | axis=-1) 100 | predicted_aligned_error, max_predicted_aligned_error = ( 101 | _calculate_expected_aligned_error( 102 | alignment_confidence_breaks=breaks, 103 | aligned_distance_error_probs=aligned_confidence_probs)) 104 | return { 105 | 'aligned_confidence_probs': aligned_confidence_probs, 106 | 'predicted_aligned_error': predicted_aligned_error, 107 | 'max_predicted_aligned_error': max_predicted_aligned_error, 108 | } 109 | 110 | 111 | def predicted_tm_score( 112 | logits: np.ndarray, 113 | breaks: np.ndarray, 114 | residue_weights: Optional[np.ndarray] = None) -> np.ndarray: 115 | """Computes predicted TM alignment score. 116 | 117 | Args: 118 | logits: [num_res, num_res, num_bins] the logits output from 119 | PredictedAlignedErrorHead. 120 | breaks: [num_bins - 1] the error bin edges. 121 | residue_weights: [num_res] the per residue weights to use for the 122 | expectation. 123 | 124 | Returns: 125 | ptm_score: the predicted TM alignment score. 126 | """ 127 | 128 | # residue_weights has to be in [0, 1], but can be floating-point, i.e. the 129 | # exp. resolved head's probability. 130 | if residue_weights is None: 131 | residue_weights = np.ones(logits.shape[0]) 132 | 133 | bin_centers = _calculate_bin_centers(breaks) 134 | 135 | num_res = np.sum(residue_weights) 136 | # Clip num_res to avoid negative/undefined d0. 
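# With d0 = 1.24 * (N - 15)**(1/3) - 1.8, any N <= 18 gives d0 <= 0; clipping
# at 19 yields the smallest positive value, d0(19) ~= 0.17 Angstrom.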
137 | clipped_num_res = max(num_res, 19) 138 | 139 | # Compute d_0(num_res) as defined by TM-score, eqn. (5) in 140 | # http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf 141 | # Yang & Skolnick "Scoring function for automated 142 | # assessment of protein structure template quality" 2004 143 | d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8 144 | 145 | # Convert logits to probs 146 | probs = scipy.special.softmax(logits, axis=-1) 147 | 148 | # TM-Score term for every bin 149 | tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0)) 150 | # E_distances tm(distance) 151 | predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1) 152 | 153 | normed_residue_mask = residue_weights / (1e-8 + residue_weights.sum()) 154 | per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1) 155 | return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()]) 156 | -------------------------------------------------------------------------------- /alphafold/common/protein.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Protein data type.""" 16 | import io 17 | from typing import Any, Mapping, Optional 18 | 19 | from Bio.PDB import PDBParser 20 | import dataclasses 21 | import numpy as np 22 | 23 | from alphafold.common import residue_constants 24 | 25 | FeatureDict = Mapping[str, np.ndarray] 26 | ModelOutput = Mapping[str, Any]  # Is a nested dict. 27 | 28 | 29 | @dataclasses.dataclass(frozen=True) 30 | class Protein: 31 | """Protein structure representation.""" 32 | 33 | # Cartesian coordinates of atoms in angstroms. The atom types correspond to 34 | # residue_constants.atom_types, i.e. the first three are N, CA, C. 35 | atom_positions: np.ndarray  # [num_res, num_atom_type, 3] 36 | 37 | # Amino-acid type for each residue represented as an integer between 0 and 38 | # 20, where 20 is 'X'. 39 | aatype: np.ndarray  # [num_res] 40 | 41 | # Binary float mask to indicate presence of a particular atom. 1.0 if an atom 42 | # is present and 0.0 if not. This should be used for loss masking. 43 | atom_mask: np.ndarray  # [num_res, num_atom_type] 44 | 45 | # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. 46 | residue_index: np.ndarray  # [num_res] 47 | 48 | # B-factors, or temperature factors, of each residue (in square angstroms), 49 | # representing the displacement of the residue from its ground truth mean 50 | # value. 51 | b_factors: np.ndarray  # [num_res, num_atom_type] 52 | 53 | 54 | def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: 55 | """Takes a PDB string and constructs a Protein object. 56 | 57 | WARNING: All non-standard residue types will be converted into UNK. All 58 | non-standard atoms will be ignored. 
59 | 60 | Args: 61 | pdb_str: The contents of the pdb file 62 | chain_id: If None, then the pdb file must contain a single chain (which 63 | will be parsed). If chain_id is specified (e.g. A), then only that chain 64 | is parsed. 65 | 66 | Returns: 67 | A new `Protein` parsed from the pdb contents. 68 | """ 69 | pdb_fh = io.StringIO(pdb_str) 70 | parser = PDBParser(QUIET=True) 71 | structure = parser.get_structure('none', pdb_fh) 72 | models = list(structure.get_models()) 73 | if len(models) != 1: 74 | raise ValueError( 75 | f'Only single model PDBs are supported. Found {len(models)} models.') 76 | model = models[0] 77 | 78 | if chain_id is not None: 79 | chain = model[chain_id] 80 | else: 81 | chains = list(model.get_chains()) 82 | if len(chains) != 1: 83 | raise ValueError( 84 | 'Only single chain PDBs are supported when chain_id not specified. ' 85 | f'Found {len(chains)} chains.') 86 | else: 87 | chain = chains[0] 88 | 89 | atom_positions = [] 90 | aatype = [] 91 | atom_mask = [] 92 | residue_index = [] 93 | b_factors = [] 94 | 95 | for res in chain: 96 | if res.id[2] != ' ': 97 | raise ValueError( 98 | f'PDB contains an insertion code at chain {chain.id} and residue ' 99 | f'index {res.id[1]}. These are not supported.') 100 | res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') 101 | restype_idx = residue_constants.restype_order.get( 102 | res_shortname, residue_constants.restype_num) 103 | pos = np.zeros((residue_constants.atom_type_num, 3)) 104 | mask = np.zeros((residue_constants.atom_type_num,)) 105 | res_b_factors = np.zeros((residue_constants.atom_type_num,)) 106 | for atom in res: 107 | if atom.name not in residue_constants.atom_types: 108 | continue 109 | pos[residue_constants.atom_order[atom.name]] = atom.coord 110 | mask[residue_constants.atom_order[atom.name]] = 1. 111 | res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor 112 | if np.sum(mask) < 0.5: 113 | # If no known atom positions are reported for the residue then skip it. 114 | continue 115 | aatype.append(restype_idx) 116 | atom_positions.append(pos) 117 | atom_mask.append(mask) 118 | residue_index.append(res.id[1]) 119 | b_factors.append(res_b_factors) 120 | 121 | return Protein( 122 | atom_positions=np.array(atom_positions), 123 | atom_mask=np.array(atom_mask), 124 | aatype=np.array(aatype), 125 | residue_index=np.array(residue_index), 126 | b_factors=np.array(b_factors)) 127 | 128 | 129 | def to_pdb(prot: Protein) -> str: 130 | """Converts a `Protein` instance to a PDB string. 131 | 132 | Args: 133 | prot: The protein to convert to PDB. 134 | 135 | Returns: 136 | PDB string. 137 | """ 138 | restypes = residue_constants.restypes + ['X'] 139 | res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') 140 | atom_types = residue_constants.atom_types 141 | 142 | pdb_lines = [] 143 | 144 | atom_mask = prot.atom_mask 145 | aatype = prot.aatype 146 | atom_positions = prot.atom_positions 147 | residue_index = prot.residue_index.astype(np.int32) 148 | b_factors = prot.b_factors 149 | 150 | if np.any(aatype > residue_constants.restype_num): 151 | raise ValueError('Invalid aatypes.') 152 | 153 | pdb_lines.append('MODEL 1') 154 | atom_index = 1 155 | chain_id = 'A' 156 | # Add all atom sites. 
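# The f-string below follows the fixed-width PDB ATOM record layout: record
# name in columns 1-6, atom serial in 7-11, atom name in 13-16, residue name
# in 18-20, chain in 22, residue number in 23-26, insertion code in 27,
# x/y/z in 31-54 (8.3f each), occupancy in 55-60, B-factor in 61-66,
# element in 77-78 and charge in 79-80.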
157 | for i in range(aatype.shape[0]): 158 | res_name_3 = res_1to3(aatype[i]) 159 | for atom_name, pos, mask, b_factor in zip( 160 | atom_types, atom_positions[i], atom_mask[i], b_factors[i]): 161 | if mask < 0.5: 162 | continue 163 | 164 | record_type = 'ATOM' 165 | name = atom_name if len(atom_name) == 4 else f' {atom_name}' 166 | alt_loc = '' 167 | insertion_code = '' 168 | occupancy = 1.00 169 | element = atom_name[0]  # Protein supports only C, N, O, S, so this works. 170 | charge = '' 171 | # PDB is a columnar format; every space matters here! 172 | atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' 173 | f'{res_name_3:>3} {chain_id:>1}' 174 | f'{residue_index[i]:>4}{insertion_code:>1}   ' 175 | f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' 176 | f'{occupancy:>6.2f}{b_factor:>6.2f}          ' 177 | f'{element:>2}{charge:>2}') 178 | pdb_lines.append(atom_line) 179 | atom_index += 1 180 | 181 | # Close the chain. 182 | chain_end = 'TER' 183 | chain_termination_line = ( 184 | f'{chain_end:<6}{atom_index:>5}      {res_1to3(aatype[-1]):>3} ' 185 | f'{chain_id:>1}{residue_index[-1]:>4}') 186 | pdb_lines.append(chain_termination_line) 187 | pdb_lines.append('ENDMDL') 188 | 189 | pdb_lines.append('END') 190 | pdb_lines.append('') 191 | return '\n'.join(pdb_lines) 192 | 193 | 194 | def ideal_atom_mask(prot: Protein) -> np.ndarray: 195 | """Computes an ideal atom mask. 196 | 197 | `Protein.atom_mask` is typically defined according to the atoms that are 198 | reported in the PDB. This function computes a mask according to heavy atoms 199 | that should be present in the given sequence of amino acids. 200 | 201 | Args: 202 | prot: `Protein` whose fields are `numpy.ndarray` objects. 203 | 204 | Returns: 205 | An ideal atom mask. 206 | """ 207 | return residue_constants.STANDARD_ATOM_MASK[prot.aatype] 208 | 209 | 210 | def from_prediction(features: FeatureDict, result: ModelOutput, 211 | b_factors: Optional[np.ndarray] = None) -> Protein: 212 | """Assembles a protein from a prediction. 213 | 214 | Args: 215 | features: Dictionary holding model inputs. 216 | result: Dictionary holding model outputs. 217 | b_factors: (Optional) B-factors to use for the protein. 218 | 219 | Returns: 220 | A protein instance. 221 | """ 222 | fold_output = result['structure_module'] 223 | if b_factors is None: 224 | b_factors = np.zeros_like(fold_output['final_atom_mask']) 225 | 226 | return Protein( 227 | aatype=features['aatype'][0], 228 | atom_positions=fold_output['final_atom_positions'], 229 | atom_mask=fold_output['final_atom_mask'], 230 | residue_index=features['residue_index'][0] + 1, 231 | b_factors=b_factors) 232 | -------------------------------------------------------------------------------- /alphafold/common/protein_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for protein.""" 16 | 17 | import os 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import numpy as np 22 | 23 | from alphafold.common import protein 24 | from alphafold.common import residue_constants 25 | # Internal import (7716). 26 | 27 | TEST_DATA_DIR = 'alphafold/common/testdata/' 28 | 29 | 30 | class ProteinTest(parameterized.TestCase): 31 | 32 | def _check_shapes(self, prot, num_res): 33 | """Check that the processed shapes are correct.""" 34 | num_atoms = residue_constants.atom_type_num 35 | self.assertEqual((num_res, num_atoms, 3), prot.atom_positions.shape) 36 | self.assertEqual((num_res,), prot.aatype.shape) 37 | self.assertEqual((num_res, num_atoms), prot.atom_mask.shape) 38 | self.assertEqual((num_res,), prot.residue_index.shape) 39 | self.assertEqual((num_res, num_atoms), prot.b_factors.shape) 40 | 41 | @parameterized.parameters(('2rbg.pdb', 'A', 282), 42 | ('2rbg.pdb', 'B', 282)) 43 | def test_from_pdb_str(self, pdb_file, chain_id, num_res): 44 | pdb_file = os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR, 45 | pdb_file) 46 | with open(pdb_file) as f: 47 | pdb_string = f.read() 48 | prot = protein.from_pdb_string(pdb_string, chain_id) 49 | self._check_shapes(prot, num_res) 50 | self.assertGreaterEqual(prot.aatype.min(), 0) 51 | # Allow equal since unknown restypes have index equal to restype_num. 52 | self.assertLessEqual(prot.aatype.max(), residue_constants.restype_num) 53 | 54 | def test_to_pdb(self): 55 | with open( 56 | os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR, 57 | '2rbg.pdb')) as f: 58 | pdb_string = f.read() 59 | prot = protein.from_pdb_string(pdb_string, chain_id='A') 60 | pdb_string_reconstr = protein.to_pdb(prot) 61 | prot_reconstr = protein.from_pdb_string(pdb_string_reconstr) 62 | 63 | np.testing.assert_array_equal(prot_reconstr.aatype, prot.aatype) 64 | np.testing.assert_array_almost_equal( 65 | prot_reconstr.atom_positions, prot.atom_positions) 66 | np.testing.assert_array_almost_equal( 67 | prot_reconstr.atom_mask, prot.atom_mask) 68 | np.testing.assert_array_equal( 69 | prot_reconstr.residue_index, prot.residue_index) 70 | np.testing.assert_array_almost_equal( 71 | prot_reconstr.b_factors, prot.b_factors) 72 | 73 | def test_ideal_atom_mask(self): 74 | with open( 75 | os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR, 76 | '2rbg.pdb')) as f: 77 | pdb_string = f.read() 78 | prot = protein.from_pdb_string(pdb_string, chain_id='A') 79 | ideal_mask = protein.ideal_atom_mask(prot) 80 | non_ideal_residues = set([102] + list(range(127, 285))) 81 | for i, (res, atom_mask) in enumerate( 82 | zip(prot.residue_index, prot.atom_mask)): 83 | if res in non_ideal_residues: 84 | self.assertFalse(np.all(atom_mask == ideal_mask[i]), msg=f'{res}') 85 | else: 86 | self.assertTrue(np.all(atom_mask == ideal_mask[i]), msg=f'{res}') 87 | 88 | 89 | if __name__ == '__main__': 90 | absltest.main() 91 | -------------------------------------------------------------------------------- /alphafold/common/residue_constants_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Test that residue_constants generates correct values.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import numpy as np 20 | 21 | from alphafold.common import residue_constants 22 | 23 | 24 | class ResidueConstantsTest(parameterized.TestCase): 25 | 26 | @parameterized.parameters( 27 | ('ALA', 0), 28 | ('CYS', 1), 29 | ('HIS', 2), 30 | ('MET', 3), 31 | ('LYS', 4), 32 | ('ARG', 4), 33 | ) 34 | def testChiAnglesAtoms(self, residue_name, chi_num): 35 | chi_angles_atoms = residue_constants.chi_angles_atoms[residue_name] 36 | self.assertLen(chi_angles_atoms, chi_num) 37 | for chi_angle_atoms in chi_angles_atoms: 38 | self.assertLen(chi_angle_atoms, 4) 39 | 40 | def testChiGroupsForAtom(self): 41 | for k, chi_groups in residue_constants.chi_groups_for_atom.items(): 42 | res_name, atom_name = k 43 | for chi_group_i, atom_i in chi_groups: 44 | self.assertEqual( 45 | atom_name, 46 | residue_constants.chi_angles_atoms[res_name][chi_group_i][atom_i]) 47 | 48 | @parameterized.parameters( 49 | ('ALA', 5), ('ARG', 11), ('ASN', 8), ('ASP', 8), ('CYS', 6), ('GLN', 9), 50 | ('GLU', 9), ('GLY', 4), ('HIS', 10), ('ILE', 8), ('LEU', 8), ('LYS', 9), 51 | ('MET', 8), ('PHE', 11), ('PRO', 7), ('SER', 6), ('THR', 7), ('TRP', 14), 52 | ('TYR', 12), ('VAL', 7) 53 | ) 54 | def testResidueAtoms(self, atom_name, num_residue_atoms): 55 | residue_atoms = residue_constants.residue_atoms[atom_name] 56 | self.assertLen(residue_atoms, num_residue_atoms) 57 | 58 | def testStandardAtomMask(self): 59 | with self.subTest('Check shape'): 60 | self.assertEqual(residue_constants.STANDARD_ATOM_MASK.shape, (21, 37,)) 61 | 62 | with self.subTest('Check values'): 63 | str_to_row = lambda s: [c == '1' for c in s] # More clear/concise. 64 | np.testing.assert_array_equal( 65 | residue_constants.STANDARD_ATOM_MASK, 66 | np.array([ 67 | # NB This was defined by c+p but looks sane. 68 | str_to_row('11111 '), # ALA 69 | str_to_row('111111 1 1 11 1 '), # ARG 70 | str_to_row('111111 11 '), # ASP 71 | str_to_row('111111 11 '), # ASN 72 | str_to_row('11111 1 '), # CYS 73 | str_to_row('111111 1 11 '), # GLU 74 | str_to_row('111111 1 11 '), # GLN 75 | str_to_row('111 1 '), # GLY 76 | str_to_row('111111 11 1 1 '), # HIS 77 | str_to_row('11111 11 1 '), # ILE 78 | str_to_row('111111 11 '), # LEU 79 | str_to_row('111111 1 1 1 '), # LYS 80 | str_to_row('111111 11 '), # MET 81 | str_to_row('111111 11 11 1 '), # PHE 82 | str_to_row('111111 1 '), # PRO 83 | str_to_row('11111 1 '), # SER 84 | str_to_row('11111 1 1 '), # THR 85 | str_to_row('111111 11 11 1 1 11 '), # TRP 86 | str_to_row('111111 11 11 11 '), # TYR 87 | str_to_row('11111 11 '), # VAL 88 | str_to_row(' '), # UNK 89 | ])) 90 | 91 | with self.subTest('Check row totals'): 92 | # Check each row has the right number of atoms. 93 | for row, restype in enumerate(residue_constants.restypes): # A, R, ... 94 | long_restype = residue_constants.restype_1to3[restype] # ALA, ARG, ... 95 | atoms_names = residue_constants.residue_atoms[ 96 | long_restype] # ['C', 'CA', 'CB', 'N', 'O'], ... 
97 | self.assertLen(atoms_names, 98 | residue_constants.STANDARD_ATOM_MASK[row, :].sum(), 99 | long_restype) 100 | 101 | def testAtomTypes(self): 102 | self.assertEqual(residue_constants.atom_type_num, 37) 103 | 104 | self.assertEqual(residue_constants.atom_types[0], 'N') 105 | self.assertEqual(residue_constants.atom_types[1], 'CA') 106 | self.assertEqual(residue_constants.atom_types[2], 'C') 107 | self.assertEqual(residue_constants.atom_types[3], 'CB') 108 | self.assertEqual(residue_constants.atom_types[4], 'O') 109 | 110 | self.assertEqual(residue_constants.atom_order['N'], 0) 111 | self.assertEqual(residue_constants.atom_order['CA'], 1) 112 | self.assertEqual(residue_constants.atom_order['C'], 2) 113 | self.assertEqual(residue_constants.atom_order['CB'], 3) 114 | self.assertEqual(residue_constants.atom_order['O'], 4) 115 | self.assertEqual(residue_constants.atom_type_num, 37) 116 | 117 | def testRestypes(self): 118 | three_letter_restypes = [ 119 | residue_constants.restype_1to3[r] for r in residue_constants.restypes] 120 | for restype, exp_restype in zip( 121 | three_letter_restypes, sorted(residue_constants.restype_1to3.values())): 122 | self.assertEqual(restype, exp_restype) 123 | self.assertEqual(residue_constants.restype_num, 20) 124 | 125 | def testSequenceToOneHotHHBlits(self): 126 | one_hot = residue_constants.sequence_to_onehot( 127 | 'ABCDEFGHIJKLMNOPQRSTUVWXYZ-', residue_constants.HHBLITS_AA_TO_ID) 128 | exp_one_hot = np.array( 129 | [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 130 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 131 | [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 132 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 133 | [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 134 | [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 135 | [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 136 | [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 137 | [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 138 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 139 | [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 140 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 141 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 142 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 143 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 144 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], 145 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], 146 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], 147 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], 148 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], 149 | [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 150 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], 151 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], 152 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 153 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], 154 | [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 155 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]) 156 | np.testing.assert_array_equal(one_hot, exp_one_hot) 157 | 158 | def 
testSequenceToOneHotStandard(self): 159 | one_hot = residue_constants.sequence_to_onehot( 160 | 'ARNDCQEGHILKMFPSTWYV', residue_constants.restype_order) 161 | np.testing.assert_array_equal(one_hot, np.eye(20)) 162 | 163 | def testSequenceToOneHotUnknownMapping(self): 164 | seq = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 165 | expected_out = np.zeros([26, 21]) 166 | for row, position in enumerate( 167 | [0, 20, 4, 3, 6, 13, 7, 8, 9, 20, 11, 10, 12, 2, 20, 14, 5, 1, 15, 16, 168 | 20, 19, 17, 20, 18, 20]): 169 | expected_out[row, position] = 1 170 | aa_types = residue_constants.sequence_to_onehot( 171 | sequence=seq, 172 | mapping=residue_constants.restype_order_with_x, 173 | map_unknown_to_x=True) 174 | self.assertTrue((aa_types == expected_out).all()) 175 | 176 | @parameterized.named_parameters( 177 | ('lowercase', 'aaa'), # Insertions in A3M. 178 | ('gaps', '---'), # Gaps in A3M. 179 | ('dots', '...'), # Gaps in A3M. 180 | ('metadata', '>TEST'), # FASTA metadata line. 181 | ) 182 | def testSequenceToOneHotUnknownMappingError(self, seq): 183 | with self.assertRaises(ValueError): 184 | residue_constants.sequence_to_onehot( 185 | sequence=seq, 186 | mapping=residue_constants.restype_order_with_x, 187 | map_unknown_to_x=True) 188 | 189 | 190 | if __name__ == '__main__': 191 | absltest.main() 192 | -------------------------------------------------------------------------------- /alphafold/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Data pipeline for model features.""" 15 | -------------------------------------------------------------------------------- /alphafold/data/pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions for building the input features for the AlphaFold model.""" 16 | 17 | import os 18 | from typing import Mapping, Optional, Sequence 19 | 20 | import numpy as np 21 | 22 | # Internal import (7716). 
23 | 24 | from alphafold.common import residue_constants 25 | from alphafold.data import parsers 26 | from alphafold.data import templates 27 | from alphafold.data.tools import hhblits 28 | from alphafold.data.tools import hhsearch 29 | from alphafold.data.tools import jackhmmer 30 | 31 | FeatureDict = Mapping[str, np.ndarray] 32 | 33 | 34 | def make_sequence_features( 35 | sequence: str, description: str, num_res: int) -> FeatureDict: 36 | """Constructs a feature dict of sequence features.""" 37 | features = {} 38 | features['aatype'] = residue_constants.sequence_to_onehot( 39 | sequence=sequence, 40 | mapping=residue_constants.restype_order_with_x, 41 | map_unknown_to_x=True) 42 | features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32) 43 | features['domain_name'] = np.array([description.encode('utf-8')], 44 | dtype=np.object_) 45 | features['residue_index'] = np.array(range(num_res), dtype=np.int32) 46 | features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32) 47 | features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_) 48 | return features 49 | 50 | 51 | def make_msa_features( 52 | msas: Sequence[Sequence[str]], 53 | deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict: 54 | """Constructs a feature dict of MSA features.""" 55 | if not msas: 56 | raise ValueError('At least one MSA must be provided.') 57 | 58 | int_msa = [] 59 | deletion_matrix = [] 60 | seen_sequences = set() 61 | for msa_index, msa in enumerate(msas): 62 | if not msa: 63 | raise ValueError(f'MSA {msa_index} must contain at least one sequence.') 64 | for sequence_index, sequence in enumerate(msa): 65 | if sequence in seen_sequences: 66 | continue 67 | seen_sequences.add(sequence) 68 | int_msa.append( 69 | [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) 70 | deletion_matrix.append(deletion_matrices[msa_index][sequence_index]) 71 | 72 | num_res = len(msas[0][0]) 73 | num_alignments = len(int_msa) 74 | features = {} 75 | features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) 76 | features['msa'] = np.array(int_msa, dtype=np.int32) 77 | features['num_alignments'] = np.array( 78 | [num_alignments] * num_res, dtype=np.int32) 79 | return features 80 | 81 | 82 | class DataPipeline: 83 | """Runs the alignment tools and assembles the input features.""" 84 | 85 | def __init__(self, 86 | jackhmmer_binary_path: str, 87 | hhblits_binary_path: str, 88 | hhsearch_binary_path: str, 89 | uniref90_database_path: str, 90 | mgnify_database_path: str, 91 | bfd_database_path: Optional[str], 92 | uniclust30_database_path: Optional[str], 93 | small_bfd_database_path: Optional[str], 94 | pdb70_database_path: str, 95 | template_featurizer: templates.TemplateHitFeaturizer, 96 | use_small_bfd: bool, 97 | mgnify_max_hits: int = 501, 98 | uniref_max_hits: int = 10000): 99 | """Initializes the data pipeline.""" 100 | self._use_small_bfd = use_small_bfd 101 | self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( 102 | binary_path=jackhmmer_binary_path, 103 | database_path=uniref90_database_path) 104 | if use_small_bfd: 105 | self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer( 106 | binary_path=jackhmmer_binary_path, 107 | database_path=small_bfd_database_path) 108 | else: 109 | self.hhblits_bfd_uniclust_runner = hhblits.HHBlits( 110 | binary_path=hhblits_binary_path, 111 | databases=[bfd_database_path, uniclust30_database_path]) 112 | self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer( 113 | binary_path=jackhmmer_binary_path, 
114 | database_path=mgnify_database_path) 115 | self.hhsearch_pdb70_runner = hhsearch.HHSearch( 116 | binary_path=hhsearch_binary_path, 117 | databases=[pdb70_database_path]) 118 | self.template_featurizer = template_featurizer 119 | self.mgnify_max_hits = mgnify_max_hits 120 | self.uniref_max_hits = uniref_max_hits 121 | 122 | def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: 123 | """Runs alignment tools on the input sequence and creates features.""" 124 | with open(input_fasta_path) as f: 125 | input_fasta_str = f.read() 126 | input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) 127 | if len(input_seqs) != 1: 128 | raise ValueError( 129 | f'Exactly one input sequence expected in {input_fasta_path}, ' 130 | f'found {len(input_seqs)}.') 131 | input_sequence = input_seqs[0] 132 | input_description = input_descs[0] 133 | num_res = len(input_sequence) 134 | 135 | jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query( 136 | input_fasta_path)[0] 137 | jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query( 138 | input_fasta_path)[0] 139 | 140 | uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( 141 | jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits) 142 | hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m) 143 | 144 | uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') 145 | with open(uniref90_out_path, 'w') as f: 146 | f.write(jackhmmer_uniref90_result['sto']) 147 | 148 | mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') 149 | with open(mgnify_out_path, 'w') as f: 150 | f.write(jackhmmer_mgnify_result['sto']) 151 | 152 | uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm( 153 | jackhmmer_uniref90_result['sto']) 154 | mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm( 155 | jackhmmer_mgnify_result['sto']) 156 | hhsearch_hits = parsers.parse_hhr(hhsearch_result) 157 | mgnify_msa = mgnify_msa[:self.mgnify_max_hits] 158 | mgnify_deletion_matrix = mgnify_deletion_matrix[:self.mgnify_max_hits] 159 | 160 | if self._use_small_bfd: 161 | jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query( 162 | input_fasta_path)[0] 163 | 164 | # The jackhmmer hits are in Stockholm format, so name the file .sto. 165 | bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto') 166 | with open(bfd_out_path, 'w') as f: 167 | f.write(jackhmmer_small_bfd_result['sto']) 168 | 169 | bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm( 170 | jackhmmer_small_bfd_result['sto']) 171 | else: 172 | hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query( 173 | input_fasta_path) 174 | 175 | bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m') 176 | with open(bfd_out_path, 'w') as f: 177 | f.write(hhblits_bfd_uniclust_result['a3m']) 178 | 179 | bfd_msa, bfd_deletion_matrix = parsers.parse_a3m( 180 | hhblits_bfd_uniclust_result['a3m']) 181 | 182 | templates_result = self.template_featurizer.get_templates( 183 | query_sequence=input_sequence, 184 | query_pdb_code=None, 185 | query_release_date=None, 186 | hits=hhsearch_hits) 187 | 188 | sequence_features = make_sequence_features( 189 | sequence=input_sequence, 190 | description=input_description, 191 | num_res=num_res) 192 | 193 | msa_features = make_msa_features( 194 | msas=(uniref90_msa, bfd_msa, mgnify_msa), 195 | deletion_matrices=(uniref90_deletion_matrix, 196 | bfd_deletion_matrix, 197 | mgnify_deletion_matrix)) 198 | 199 | return {**sequence_features, **msa_features, **templates_result.features} 200 | 
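To make the two feature builders above concrete, here is a minimal usage sketch; the ten-residue sequence, its single-sequence MSA and the all-zeros deletion matrix are toy inputs invented for illustration, not pipeline defaults.

from alphafold.data import pipeline

seq = 'MKTAYIAKQR'  # hypothetical 10-residue query
seq_feats = pipeline.make_sequence_features(
    sequence=seq, description='toy_example', num_res=len(seq))
msa_feats = pipeline.make_msa_features(
    msas=[[seq]],                           # one MSA containing one sequence
    deletion_matrices=[[[0] * len(seq)]])   # no deletions anywhere
assert seq_feats['aatype'].shape == (10, 21)   # one-hot over 20 aa types + X
assert msa_feats['msa'].shape == (1, 10)       # HHblits integer encoding
assert int(msa_feats['num_alignments'][0]) == 1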
-------------------------------------------------------------------------------- /alphafold/data/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Python wrappers for third party tools.""" 15 | -------------------------------------------------------------------------------- /alphafold/data/tools/hhblits.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library to run HHblits from Python.""" 16 | 17 | import glob 18 | import os 19 | import subprocess 20 | from typing import Any, Mapping, Optional, Sequence 21 | 22 | from absl import logging 23 | from alphafold.data.tools import utils 24 | # Internal import (7716). 25 | 26 | 27 | _HHBLITS_DEFAULT_P = 20 28 | _HHBLITS_DEFAULT_Z = 500 29 | 30 | 31 | class HHBlits: 32 | """Python wrapper of the HHblits binary.""" 33 | 34 | def __init__(self, 35 | *, 36 | binary_path: str, 37 | databases: Sequence[str], 38 | n_cpu: int = 4, 39 | n_iter: int = 3, 40 | e_value: float = 0.001, 41 | maxseq: int = 1_000_000, 42 | realign_max: int = 100_000, 43 | maxfilt: int = 100_000, 44 | min_prefilter_hits: int = 1000, 45 | all_seqs: bool = False, 46 | alt: Optional[int] = None, 47 | p: int = _HHBLITS_DEFAULT_P, 48 | z: int = _HHBLITS_DEFAULT_Z): 49 | """Initializes the Python HHblits wrapper. 50 | 51 | Args: 52 | binary_path: The path to the HHblits executable. 53 | databases: A sequence of HHblits database paths. This should be the 54 | common prefix for the database files (i.e. up to but not including 55 | _hhm.ffindex etc.) 56 | n_cpu: The number of CPUs to give HHblits. 57 | n_iter: The number of HHblits iterations. 58 | e_value: The E-value, see HHblits docs for more details. 59 | maxseq: The maximum number of rows in an input alignment. Note that this 60 | parameter is only supported in HHBlits version 3.1 and higher. 61 | realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. 62 | maxfilt: Max number of hits allowed to pass the 2nd prefilter. 63 | HHblits default: 20000. 64 | min_prefilter_hits: Min number of hits to pass prefilter. 65 | HHblits default: 100. 66 | all_seqs: Return all sequences in the MSA / Do not filter the result MSA. 67 | HHblits default: False. 
68 | alt: Show up to this many alternative alignments. 69 | p: Minimum Prob for a hit to be included in the output hhr file. 70 | HHblits default: 20. 71 | z: Hard cap on number of hits reported in the hhr file. 72 | HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. 73 | 74 | Raises: 75 | RuntimeError: If HHblits binary not found within the path. 76 | """ 77 | self.binary_path = binary_path 78 | self.databases = databases 79 | 80 | for database_path in self.databases: 81 | if not glob.glob(database_path + '_*'): 82 | logging.error('Could not find HHBlits database %s', database_path) 83 | raise ValueError(f'Could not find HHBlits database {database_path}') 84 | 85 | self.n_cpu = n_cpu 86 | self.n_iter = n_iter 87 | self.e_value = e_value 88 | self.maxseq = maxseq 89 | self.realign_max = realign_max 90 | self.maxfilt = maxfilt 91 | self.min_prefilter_hits = min_prefilter_hits 92 | self.all_seqs = all_seqs 93 | self.alt = alt 94 | self.p = p 95 | self.z = z 96 | 97 | def query(self, input_fasta_path: str) -> Mapping[str, Any]: 98 | """Queries the database using HHblits.""" 99 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: 100 | a3m_path = os.path.join(query_tmp_dir, 'output.a3m') 101 | 102 | db_cmd = [] 103 | for db_path in self.databases: 104 | db_cmd.append('-d') 105 | db_cmd.append(db_path) 106 | cmd = [ 107 | self.binary_path, 108 | '-i', input_fasta_path, 109 | '-cpu', str(self.n_cpu), 110 | '-oa3m', a3m_path, 111 | '-o', '/dev/null', 112 | '-n', str(self.n_iter), 113 | '-e', str(self.e_value), 114 | '-maxseq', str(self.maxseq), 115 | '-realign_max', str(self.realign_max), 116 | '-maxfilt', str(self.maxfilt), 117 | '-min_prefilter_hits', str(self.min_prefilter_hits)] 118 | if self.all_seqs: 119 | cmd += ['-all'] 120 | if self.alt: 121 | cmd += ['-alt', str(self.alt)] 122 | if self.p != _HHBLITS_DEFAULT_P: 123 | cmd += ['-p', str(self.p)] 124 | if self.z != _HHBLITS_DEFAULT_Z: 125 | cmd += ['-Z', str(self.z)] 126 | cmd += db_cmd 127 | 128 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 129 | process = subprocess.Popen( 130 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 131 | 132 | with utils.timing('HHblits query'): 133 | stdout, stderr = process.communicate() 134 | retcode = process.wait() 135 | 136 | if retcode: 137 | # Logs have a 15k character limit, so log HHblits error line by line. 138 | logging.error('HHblits failed. HHblits stderr begin:') 139 | for error_line in stderr.decode('utf-8').splitlines(): 140 | if error_line.strip(): 141 | logging.error(error_line.strip()) 142 | logging.error('HHblits stderr end') 143 | raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % ( 144 | stdout.decode('utf-8'), stderr[:500_000].decode('utf-8'))) 145 | 146 | with open(a3m_path) as f: 147 | a3m = f.read() 148 | 149 | raw_output = dict( 150 | a3m=a3m, 151 | output=stdout, 152 | stderr=stderr, 153 | n_iter=self.n_iter, 154 | e_value=self.e_value) 155 | return raw_output 156 | -------------------------------------------------------------------------------- /alphafold/data/tools/hhsearch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library to run HHsearch from Python.""" 16 | 17 | import glob 18 | import os 19 | import subprocess 20 | from typing import Sequence 21 | 22 | from absl import logging 23 | 24 | from alphafold.data.tools import utils 25 | # Internal import (7716). 26 | 27 | 28 | class HHSearch: 29 | """Python wrapper of the HHsearch binary.""" 30 | 31 | def __init__(self, 32 | *, 33 | binary_path: str, 34 | databases: Sequence[str], 35 | maxseq: int = 1_000_000): 36 | """Initializes the Python HHsearch wrapper. 37 | 38 | Args: 39 | binary_path: The path to the HHsearch executable. 40 | databases: A sequence of HHsearch database paths. This should be the 41 | common prefix for the database files (i.e. up to but not including 42 | _hhm.ffindex etc.) 43 | maxseq: The maximum number of rows in an input alignment. Note that this 44 | parameter is only supported in HHBlits version 3.1 and higher. 45 | 46 | Raises: 47 | RuntimeError: If HHsearch binary not found within the path. 48 | """ 49 | self.binary_path = binary_path 50 | self.databases = databases 51 | self.maxseq = maxseq 52 | 53 | for database_path in self.databases: 54 | if not glob.glob(database_path + '_*'): 55 | logging.error('Could not find HHsearch database %s', database_path) 56 | raise ValueError(f'Could not find HHsearch database {database_path}') 57 | 58 | def query(self, a3m: str) -> str: 59 | """Queries the database using HHsearch using a given a3m.""" 60 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: 61 | input_path = os.path.join(query_tmp_dir, 'query.a3m') 62 | hhr_path = os.path.join(query_tmp_dir, 'output.hhr') 63 | with open(input_path, 'w') as f: 64 | f.write(a3m) 65 | 66 | db_cmd = [] 67 | for db_path in self.databases: 68 | db_cmd.append('-d') 69 | db_cmd.append(db_path) 70 | cmd = [self.binary_path, 71 | '-i', input_path, 72 | '-o', hhr_path, 73 | '-maxseq', str(self.maxseq) 74 | ] + db_cmd 75 | 76 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 77 | process = subprocess.Popen( 78 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 79 | with utils.timing('HHsearch query'): 80 | stdout, stderr = process.communicate() 81 | retcode = process.wait() 82 | 83 | if retcode: 84 | # Stderr is truncated to prevent proto size errors in Beam. 85 | raise RuntimeError( 86 | 'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % ( 87 | stdout.decode('utf-8'), stderr[:100_000].decode('utf-8'))) 88 | 89 | with open(hhr_path) as f: 90 | hhr = f.read() 91 | return hhr 92 | -------------------------------------------------------------------------------- /alphafold/data/tools/hmmbuild.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A Python wrapper for hmmbuild - construct HMM profiles from MSA.""" 16 | 17 | import os 18 | import re 19 | import subprocess 20 | 21 | from absl import logging 22 | 23 | # Internal import (7716). 24 | 25 | from alphafold.data.tools import utils 26 | 27 | 28 | class Hmmbuild(object): 29 | """Python wrapper of the hmmbuild binary.""" 30 | 31 | def __init__(self, 32 | *, 33 | binary_path: str, 34 | singlemx: bool = False): 35 | """Initializes the Python hmmbuild wrapper. 36 | 37 | Args: 38 | binary_path: The path to the hmmbuild executable. 39 | singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to 40 | just use a common substitution score matrix. 41 | 42 | Raises: 43 | RuntimeError: If hmmbuild binary not found within the path. 44 | """ 45 | self.binary_path = binary_path 46 | self.singlemx = singlemx 47 | 48 | def build_profile_from_sto(self, sto: str, model_construction='fast') -> str: 49 | """Builds an HMM for the aligned sequences given as a Stockholm string. 50 | 51 | Args: 52 | sto: A string with the aligned sequences in the Stockholm format. 53 | model_construction: Whether to use reference annotation in the msa to 54 | determine consensus columns ('hand') or default ('fast'). 55 | 56 | Returns: 57 | A string with the profile in the HMM format. 58 | 59 | Raises: 60 | RuntimeError: If hmmbuild fails. 61 | """ 62 | return self._build_profile(sto, model_construction=model_construction) 63 | 64 | def build_profile_from_a3m(self, a3m: str) -> str: 65 | """Builds an HMM for the aligned sequences given as an A3M string. 66 | 67 | Args: 68 | a3m: A string with the aligned sequences in the A3M format. 69 | 70 | Returns: 71 | A string with the profile in the HMM format. 72 | 73 | Raises: 74 | RuntimeError: If hmmbuild fails. 75 | """ 76 | lines = [] 77 | for line in a3m.splitlines(): 78 | if not line.startswith('>'): 79 | line = re.sub('[a-z]+', '', line)  # Remove inserted residues. 80 | lines.append(line + '\n') 81 | msa = ''.join(lines) 82 | return self._build_profile(msa, model_construction='fast') 83 | 84 | def _build_profile(self, msa: str, model_construction: str = 'fast') -> str: 85 | """Builds an HMM for the aligned sequences given as an MSA string. 86 | 87 | Args: 88 | msa: A string with the aligned sequences, in A3M or STO format. 89 | model_construction: Whether to use reference annotation in the msa to 90 | determine consensus columns ('hand') or default ('fast'). 91 | 92 | Returns: 93 | A string with the profile in the HMM format. 94 | 95 | Raises: 96 | RuntimeError: If hmmbuild fails. 97 | ValueError: If an invalid model_construction is specified. 
98 | """ 99 | if model_construction not in {'hand', 'fast'}: 100 | raise ValueError(f'Invalid model_construction {model_construction} - only ' 101 | 'hand and fast supported.') 102 | 103 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: 104 | input_query = os.path.join(query_tmp_dir, 'query.msa') 105 | output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm') 106 | 107 | with open(input_query, 'w') as f: 108 | f.write(msa) 109 | 110 | cmd = [self.binary_path] 111 | # If adding flags, we have to do so before the output and input: 112 | 113 | if model_construction == 'hand': 114 | cmd.append(f'--{model_construction}') 115 | if self.singlemx: 116 | cmd.append('--singlemx') 117 | cmd.extend([ 118 | '--amino', 119 | output_hmm_path, 120 | input_query, 121 | ]) 122 | 123 | logging.info('Launching subprocess %s', cmd) 124 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, 125 | stderr=subprocess.PIPE) 126 | 127 | with utils.timing('hmmbuild query'): 128 | stdout, stderr = process.communicate() 129 | retcode = process.wait() 130 | logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n', 131 | stdout.decode('utf-8'), stderr.decode('utf-8')) 132 | 133 | if retcode: 134 | raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n' 135 | % (stdout.decode('utf-8'), stderr.decode('utf-8'))) 136 | 137 | with open(output_hmm_path, encoding='utf-8') as f: 138 | hmm = f.read() 139 | 140 | return hmm 141 | -------------------------------------------------------------------------------- /alphafold/data/tools/hmmsearch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A Python wrapper for hmmsearch - search profile against a sequence db.""" 16 | 17 | import os 18 | import subprocess 19 | from typing import Optional, Sequence 20 | 21 | from absl import logging 22 | 23 | # Internal import (7716). 24 | 25 | from alphafold.data.tools import utils 26 | 27 | 28 | class Hmmsearch(object): 29 | """Python wrapper of the hmmsearch binary.""" 30 | 31 | def __init__(self, 32 | *, 33 | binary_path: str, 34 | database_path: str, 35 | flags: Optional[Sequence[str]] = None): 36 | """Initializes the Python hmmsearch wrapper. 37 | 38 | Args: 39 | binary_path: The path to the hmmsearch executable. 40 | database_path: The path to the hmmsearch database (FASTA format). 41 | flags: List of flags to be used by hmmsearch. 42 | 43 | Raises: 44 | RuntimeError: If hmmsearch binary not found within the path. 
45 | """ 46 | self.binary_path = binary_path 47 | self.database_path = database_path 48 | self.flags = flags 49 | 50 | if not os.path.exists(self.database_path): 51 | logging.error('Could not find hmmsearch database %s', database_path) 52 | raise ValueError(f'Could not find hmmsearch database {database_path}') 53 | 54 | def query(self, hmm: str) -> str: 55 | """Queries the database using hmmsearch using a given hmm.""" 56 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: 57 | hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') 58 | a3m_out_path = os.path.join(query_tmp_dir, 'output.a3m') 59 | with open(hmm_input_path, 'w') as f: 60 | f.write(hmm) 61 | 62 | cmd = [ 63 | self.binary_path, 64 | '--noali', # Don't include the alignment in stdout. 65 | '--cpu', '8' 66 | ] 67 | # If adding flags, we have to do so before the output and input: 68 | if self.flags: 69 | cmd.extend(self.flags) 70 | cmd.extend([ 71 | '-A', a3m_out_path, 72 | hmm_input_path, 73 | self.database_path, 74 | ]) 75 | 76 | logging.info('Launching sub-process %s', cmd) 77 | process = subprocess.Popen( 78 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 79 | with utils.timing( 80 | f'hmmsearch ({os.path.basename(self.database_path)}) query'): 81 | stdout, stderr = process.communicate() 82 | retcode = process.wait() 83 | 84 | if retcode: 85 | raise RuntimeError( 86 | 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % ( 87 | stdout.decode('utf-8'), stderr.decode('utf-8'))) 88 | 89 | with open(a3m_out_path) as f: 90 | a3m_out = f.read() 91 | 92 | return a3m_out 93 | -------------------------------------------------------------------------------- /alphafold/data/tools/jackhmmer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library to run Jackhmmer from Python.""" 16 | 17 | from concurrent import futures 18 | import glob 19 | import os 20 | import subprocess 21 | from typing import Any, Callable, Mapping, Optional, Sequence 22 | from urllib import request 23 | 24 | from absl import logging 25 | 26 | from alphafold.data.tools import utils 27 | # Internal import (7716). 28 | 29 | 30 | class Jackhmmer: 31 | """Python wrapper of the Jackhmmer binary.""" 32 | 33 | def __init__(self, 34 | *, 35 | binary_path: str, 36 | database_path: str, 37 | n_cpu: int = 8, 38 | n_iter: int = 1, 39 | e_value: float = 0.0001, 40 | z_value: Optional[int] = None, 41 | get_tblout: bool = False, 42 | filter_f1: float = 0.0005, 43 | filter_f2: float = 0.00005, 44 | filter_f3: float = 0.0000005, 45 | incdom_e: Optional[float] = None, 46 | dom_e: Optional[float] = None, 47 | num_streamed_chunks: Optional[int] = None, 48 | streaming_callback: Optional[Callable[[int], None]] = None): 49 | """Initializes the Python Jackhmmer wrapper. 50 | 51 | Args: 52 | binary_path: The path to the jackhmmer executable. 
53 | database_path: The path to the jackhmmer database (FASTA format). 54 | n_cpu: The number of CPUs to give Jackhmmer. 55 | n_iter: The number of Jackhmmer iterations. 56 | e_value: The E-value, see Jackhmmer docs for more details. 57 | z_value: The Z-value, see Jackhmmer docs for more details. 58 | get_tblout: Whether to save tblout string. 59 | filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. 60 | filter_f2: Viterbi pre-filter, set to >1.0 to turn off. 61 | filter_f3: Forward pre-filter, set to >1.0 to turn off. 62 | incdom_e: Domain e-value criteria for inclusion of domains in MSA/next 63 | round. 64 | dom_e: Domain e-value criteria for inclusion in tblout. 65 | num_streamed_chunks: Number of database chunks to stream over. 66 | streaming_callback: Callback function run after each chunk iteration with 67 | the iteration number as argument. 68 | """ 69 | self.binary_path = binary_path 70 | self.database_path = database_path 71 | self.num_streamed_chunks = num_streamed_chunks 72 | 73 | if not os.path.exists(self.database_path) and num_streamed_chunks is None: 74 | logging.error('Could not find Jackhmmer database %s', database_path) 75 | raise ValueError(f'Could not find Jackhmmer database {database_path}') 76 | 77 | self.n_cpu = n_cpu 78 | self.n_iter = n_iter 79 | self.e_value = e_value 80 | self.z_value = z_value 81 | self.filter_f1 = filter_f1 82 | self.filter_f2 = filter_f2 83 | self.filter_f3 = filter_f3 84 | self.incdom_e = incdom_e 85 | self.dom_e = dom_e 86 | self.get_tblout = get_tblout 87 | self.streaming_callback = streaming_callback 88 | 89 | def _query_chunk(self, input_fasta_path: str, database_path: str 90 | ) -> Mapping[str, Any]: 91 | """Queries the database chunk using Jackhmmer.""" 92 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: 93 | sto_path = os.path.join(query_tmp_dir, 'output.sto') 94 | 95 | # The F1/F2/F3 are the expected proportion to pass each of the filtering 96 | # stages (which get progressively more expensive); reducing these 97 | # speeds up the pipeline at the expense of sensitivity. They are 98 | # currently set very low to make querying Mgnify run in a reasonable 99 | # amount of time. 100 | cmd_flags = [ 101 | # Don't pollute stdout with Jackhmmer output. 102 | '-o', '/dev/null', 103 | '-A', sto_path, 104 | '--noali', 105 | '--F1', str(self.filter_f1), 106 | '--F2', str(self.filter_f2), 107 | '--F3', str(self.filter_f3), 108 | '--incE', str(self.e_value), 109 | # Report only sequences with E-values <= x in per-sequence output.
110 | '-E', str(self.e_value), 111 | '--cpu', str(self.n_cpu), 112 | '-N', str(self.n_iter) 113 | ] 114 | if self.get_tblout: 115 | tblout_path = os.path.join(query_tmp_dir, 'tblout.txt') 116 | cmd_flags.extend(['--tblout', tblout_path]) 117 | 118 | if self.z_value: 119 | cmd_flags.extend(['-Z', str(self.z_value)]) 120 | 121 | if self.dom_e is not None: 122 | cmd_flags.extend(['--domE', str(self.dom_e)]) 123 | 124 | if self.incdom_e is not None: 125 | cmd_flags.extend(['--incdomE', str(self.incdom_e)]) 126 | 127 | cmd = [self.binary_path] + cmd_flags + [input_fasta_path, 128 | database_path] 129 | 130 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 131 | process = subprocess.Popen( 132 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 133 | with utils.timing( 134 | f'Jackhmmer ({os.path.basename(database_path)}) query'): 135 | _, stderr = process.communicate() 136 | retcode = process.wait() 137 | 138 | if retcode: 139 | raise RuntimeError( 140 | 'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8')) 141 | 142 | # Get e-values for each target name 143 | tbl = '' 144 | if self.get_tblout: 145 | with open(tblout_path) as f: 146 | tbl = f.read() 147 | 148 | with open(sto_path) as f: 149 | sto = f.read() 150 | 151 | raw_output = dict( 152 | sto=sto, 153 | tbl=tbl, 154 | stderr=stderr, 155 | n_iter=self.n_iter, 156 | e_value=self.e_value) 157 | 158 | return raw_output 159 | 160 | def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]: 161 | """Queries the database using Jackhmmer.""" 162 | if self.num_streamed_chunks is None: 163 | return [self._query_chunk(input_fasta_path, self.database_path)] 164 | 165 | db_basename = os.path.basename(self.database_path) 166 | db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}' 167 | db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}' 168 | 169 | # Remove existing files to prevent OOM 170 | for f in glob.glob(db_local_chunk('[0-9]*')): 171 | try: 172 | os.remove(f) 173 | except OSError: 174 | logging.warning('OSError while deleting %s', f) 175 | 176 | # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk 177 | with futures.ThreadPoolExecutor(max_workers=2) as executor: 178 | chunked_output = [] 179 | for i in range(1, self.num_streamed_chunks + 1): 180 | # Copy the chunk locally 181 | if i == 1: 182 | future = executor.submit( 183 | request.urlretrieve, db_remote_chunk(i), db_local_chunk(i)) 184 | if i < self.num_streamed_chunks: 185 | next_future = executor.submit( 186 | request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1)) 187 | 188 | # Run Jackhmmer with the chunk 189 | future.result() 190 | chunked_output.append( 191 | self._query_chunk(input_fasta_path, db_local_chunk(i))) 192 | 193 | # Remove the local copy of the chunk 194 | os.remove(db_local_chunk(i)) 195 | if i < self.num_streamed_chunks: future = next_future  # No next chunk after the last one. 196 | if self.streaming_callback: 197 | self.streaming_callback(i) 198 | return chunked_output 199 | -------------------------------------------------------------------------------- /alphafold/data/tools/kalign.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A Python wrapper for Kalign.""" 16 | import os 17 | import subprocess 18 | from typing import Sequence 19 | 20 | from absl import logging 21 | 22 | from alphafold.data.tools import utils 23 | # Internal import (7716). 24 | 25 | 26 | def _to_a3m(sequences: Sequence[str]) -> str: 27 | """Converts sequences to an A3M string.""" 28 | names = ['sequence %d' % i for i in range(1, len(sequences) + 1)] 29 | a3m = [] 30 | for sequence, name in zip(sequences, names): 31 | a3m.append(u'>' + name + u'\n') 32 | a3m.append(sequence + u'\n') 33 | return ''.join(a3m) 34 | 35 | 36 | class Kalign: 37 | """Python wrapper of the Kalign binary.""" 38 | 39 | def __init__(self, *, binary_path: str): 40 | """Initializes the Python Kalign wrapper. 41 | 42 | Args: 43 | binary_path: The path to the Kalign binary. 44 | 45 | Raises: 46 | RuntimeError: If Kalign binary not found within the path. 47 | """ 48 | self.binary_path = binary_path 49 | 50 | def align(self, sequences: Sequence[str]) -> str: 51 | """Aligns the sequences and returns the alignment as an A3M string. 52 | 53 | Args: 54 | sequences: A list of query sequence strings. The sequences have to be at 55 | least 6 residues long (Kalign requires this). Note that the order in 56 | which you give the sequences might alter the output slightly as a 57 | different alignment tree might get constructed. 58 | 59 | Returns: 60 | A string with the alignment in A3M format. 61 | 62 | Raises: 63 | RuntimeError: If Kalign fails. 64 | ValueError: If any of the sequences is less than 6 residues long. 65 | """ 66 | logging.info('Aligning %d sequences', len(sequences)) 67 | 68 | for s in sequences: 69 | if len(s) < 6: 70 | raise ValueError('Kalign requires all sequences to be at least 6 ' 71 | 'residues long. Got %s (%d residues).'
% (s, len(s))) 72 | 73 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: 74 | input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta') 75 | output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m') 76 | 77 | with open(input_fasta_path, 'w') as f: 78 | f.write(_to_a3m(sequences)) 79 | 80 | cmd = [ 81 | self.binary_path, 82 | '-i', input_fasta_path, 83 | '-o', output_a3m_path, 84 | '-format', 'fasta', 85 | ] 86 | 87 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 88 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, 89 | stderr=subprocess.PIPE) 90 | 91 | with utils.timing('Kalign query'): 92 | stdout, stderr = process.communicate() 93 | retcode = process.wait() 94 | logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n', 95 | stdout.decode('utf-8'), stderr.decode('utf-8')) 96 | 97 | if retcode: 98 | raise RuntimeError('Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' 99 | % (stdout.decode('utf-8'), stderr.decode('utf-8'))) 100 | 101 | with open(output_a3m_path) as f: 102 | a3m = f.read() 103 | 104 | return a3m 105 | -------------------------------------------------------------------------------- /alphafold/data/tools/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Common utilities for data pipeline tools.""" 15 | import contextlib 16 | import shutil 17 | import tempfile 18 | import time 19 | from typing import Optional 20 | 21 | from absl import logging 22 | 23 | 24 | @contextlib.contextmanager 25 | def tmpdir_manager(base_dir: Optional[str] = None): 26 | """Context manager that deletes a temporary directory on exit.""" 27 | tmpdir = tempfile.mkdtemp(dir=base_dir) 28 | try: 29 | yield tmpdir 30 | finally: 31 | shutil.rmtree(tmpdir, ignore_errors=True) 32 | 33 | 34 | @contextlib.contextmanager 35 | def timing(msg: str): 36 | logging.info('Started %s', msg) 37 | tic = time.time() 38 | yield 39 | toc = time.time() 40 | logging.info('Finished %s in %.3f seconds', msg, toc - tic) 41 | -------------------------------------------------------------------------------- /alphafold/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
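# A minimal usage sketch for the tmpdir_manager and timing helpers defined in
# alphafold/data/tools/utils.py above; the file name and log message are
# illustrative only.

import os
from alphafold.data.tools import utils

with utils.timing('scratch work'):  # Logs start, end and elapsed seconds.
  with utils.tmpdir_manager(base_dir='/tmp') as tmp_dir:
    scratch_path = os.path.join(tmp_dir, 'scratch.txt')
    with open(scratch_path, 'w') as f:
      f.write('intermediate output')
  # tmp_dir and its contents are deleted on exit, even after an error.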
14 | """Alphafold model.""" 15 | -------------------------------------------------------------------------------- /alphafold/model/all_atom_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for all_atom.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import numpy as np 20 | from alphafold.model import all_atom 21 | from alphafold.model import r3 22 | 23 | L1_CLAMP_DISTANCE = 10 24 | 25 | 26 | def get_identity_rigid(shape): 27 | """Returns identity rigid transform.""" 28 | 29 | ones = np.ones(shape) 30 | zeros = np.zeros(shape) 31 | rot = r3.Rots(ones, zeros, zeros, 32 | zeros, ones, zeros, 33 | zeros, zeros, ones) 34 | trans = r3.Vecs(zeros, zeros, zeros) 35 | return r3.Rigids(rot, trans) 36 | 37 | 38 | def get_global_rigid_transform(rot_angle, translation, bcast_dims): 39 | """Returns rigid transform that globally rotates/translates by same amount.""" 40 | 41 | rot_angle = np.asarray(rot_angle) 42 | translation = np.asarray(translation) 43 | if bcast_dims: 44 | for _ in range(bcast_dims): 45 | rot_angle = np.expand_dims(rot_angle, 0) 46 | translation = np.expand_dims(translation, 0) 47 | sin_angle = np.sin(np.deg2rad(rot_angle)) 48 | cos_angle = np.cos(np.deg2rad(rot_angle)) 49 | ones = np.ones_like(sin_angle) 50 | zeros = np.zeros_like(sin_angle) 51 | rot = r3.Rots(ones, zeros, zeros, 52 | zeros, cos_angle, -sin_angle, 53 | zeros, sin_angle, cos_angle) 54 | trans = r3.Vecs(translation[..., 0], translation[..., 1], translation[..., 2]) 55 | return r3.Rigids(rot, trans) 56 | 57 | 58 | class AllAtomTest(parameterized.TestCase, absltest.TestCase): 59 | 60 | @parameterized.named_parameters( 61 | ('identity', 0, [0, 0, 0]), 62 | ('rot_90', 90, [0, 0, 0]), 63 | ('trans_10', 0, [0, 0, 10]), 64 | ('rot_174_trans_1', 174, [1, 1, 1])) 65 | def test_frame_aligned_point_error_perfect_on_global_transform( 66 | self, rot_angle, translation): 67 | """Tests global transform between target and preds gives perfect score.""" 68 | 69 | # pylint: disable=bad-whitespace 70 | target_positions = np.array( 71 | [[ 21.182, 23.095, 19.731], 72 | [ 22.055, 20.919, 17.294], 73 | [ 24.599, 20.005, 15.041], 74 | [ 25.567, 18.214, 12.166], 75 | [ 28.063, 17.082, 10.043], 76 | [ 28.779, 15.569, 6.985], 77 | [ 30.581, 13.815, 4.612], 78 | [ 29.258, 12.193, 2.296]]) 79 | # pylint: enable=bad-whitespace 80 | global_rigid_transform = get_global_rigid_transform( 81 | rot_angle, translation, 1) 82 | 83 | target_positions = r3.vecs_from_tensor(target_positions) 84 | pred_positions = r3.rigids_mul_vecs( 85 | global_rigid_transform, target_positions) 86 | positions_mask = np.ones(target_positions.x.shape[0]) 87 | 88 | target_frames = get_identity_rigid(10) 89 | pred_frames = r3.rigids_mul_rigids(global_rigid_transform, target_frames) 90 | frames_mask = np.ones(10) 91 | 92 | fape = 
all_atom.frame_aligned_point_error( 93 | pred_frames, target_frames, frames_mask, pred_positions, 94 | target_positions, positions_mask, L1_CLAMP_DISTANCE, 95 | L1_CLAMP_DISTANCE, epsilon=0) 96 | self.assertAlmostEqual(fape, 0.) 97 | 98 | @parameterized.named_parameters( 99 | ('identity', 100 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 101 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 102 | 0.), 103 | ('shift_2.5', 104 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 105 | [[2.5, 0, 0], [7.5, 0, 0], [7.5, 0, 0]], 106 | 0.25), 107 | ('shift_5', 108 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 109 | [[5, 0, 0], [10, 0, 0], [15, 0, 0]], 110 | 0.5), 111 | ('shift_10', 112 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 113 | [[10, 0, 0], [15, 0, 0], [0, 0, 0]], 114 | 1.)) 115 | def test_frame_aligned_point_error_matches_expected( 116 | self, target_positions, pred_positions, expected_alddt): 117 | """Tests score matches expected.""" 118 | 119 | target_frames = get_identity_rigid(2) 120 | pred_frames = target_frames 121 | frames_mask = np.ones(2) 122 | 123 | target_positions = r3.vecs_from_tensor(np.array(target_positions)) 124 | pred_positions = r3.vecs_from_tensor(np.array(pred_positions)) 125 | positions_mask = np.ones(target_positions.x.shape[0]) 126 | 127 | alddt = all_atom.frame_aligned_point_error( 128 | pred_frames, target_frames, frames_mask, pred_positions, 129 | target_positions, positions_mask, L1_CLAMP_DISTANCE, 130 | L1_CLAMP_DISTANCE, epsilon=0) 131 | self.assertAlmostEqual(alddt, expected_alddt) 132 | 133 | 134 | if __name__ == '__main__': 135 | absltest.main() 136 | -------------------------------------------------------------------------------- /alphafold/model/common_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A collection of common Haiku modules for use in protein folding.""" 16 | import haiku as hk 17 | import jax.numpy as jnp 18 | 19 | 20 | class Linear(hk.Module): 21 | """Protein folding specific Linear Module. 22 | 23 | This differs from the standard Haiku Linear in a few ways: 24 | * It supports inputs of arbitrary rank 25 | * Initializers are specified by strings 26 | """ 27 | 28 | def __init__(self, 29 | num_output: int, 30 | initializer: str = 'linear', 31 | use_bias: bool = True, 32 | bias_init: float = 0., 33 | name: str = 'linear'): 34 | """Constructs Linear Module. 35 | 36 | Args: 37 | num_output: number of output channels. 38 | initializer: What initializer to use, should be one of {'linear', 'relu', 39 | 'zeros'} 40 | use_bias: Whether to include trainable bias 41 | bias_init: Value used to initialize bias. 42 | name: name of module, used for name scopes. 
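Example (a minimal sketch; assumes `import jax` alongside the module's
`haiku as hk` and `jax.numpy as jnp` imports, with arbitrary shapes):

  def forward(x):
    return Linear(num_output=32, initializer='relu')(x)

  net = hk.transform(forward)
  params = net.init(jax.random.PRNGKey(0), jnp.zeros((4, 7, 16)))
  out = net.apply(params, None, jnp.zeros((4, 7, 16)))  # Shape (4, 7, 32).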
43 | """ 44 | 45 | super().__init__(name=name) 46 | self.num_output = num_output 47 | self.initializer = initializer 48 | self.use_bias = use_bias 49 | self.bias_init = bias_init 50 | 51 | def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: 52 | """Connects Module. 53 | 54 | Args: 55 | inputs: Tensor of shape [..., num_channel] 56 | 57 | Returns: 58 | output of shape [..., num_output] 59 | """ 60 | n_channels = int(inputs.shape[-1]) 61 | 62 | weight_shape = [n_channels, self.num_output] 63 | if self.initializer == 'linear': 64 | weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.) 65 | elif self.initializer == 'relu': 66 | weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.) 67 | elif self.initializer == 'zeros': 68 | weight_init = hk.initializers.Constant(0.0) 69 | 70 | weights = hk.get_parameter('weights', weight_shape, inputs.dtype, 71 | weight_init) 72 | 73 | # this is equivalent to einsum('...c,cd->...d', inputs, weights) 74 | # but turns out to be slightly faster 75 | inputs = jnp.swapaxes(inputs, -1, -2) 76 | output = jnp.einsum('...cb,cd->...db', inputs, weights) 77 | output = jnp.swapaxes(output, -1, -2) 78 | 79 | if self.use_bias: 80 | bias = hk.get_parameter('bias', [self.num_output], inputs.dtype, 81 | hk.initializers.Constant(self.bias_init)) 82 | output += bias 83 | 84 | return output 85 | -------------------------------------------------------------------------------- /alphafold/model/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Convenience functions for reading data.""" 16 | 17 | import io 18 | import os 19 | from typing import List 20 | 21 | import haiku as hk 22 | import numpy as np 23 | 24 | from alphafold.model import utils 25 | # Internal import (7716). 26 | 27 | 28 | def casp_model_names(data_dir: str) -> List[str]: 29 | params = os.listdir(os.path.join(data_dir, 'params')) 30 | return [os.path.splitext(filename)[0] for filename in params] 31 | 32 | 33 | def get_model_haiku_params(model_name: str, data_dir: str) -> hk.Params: 34 | """Get the Haiku parameters from a model name.""" 35 | 36 | path = os.path.join(data_dir, 'params', f'params_{model_name}.npz') 37 | 38 | with open(path, 'rb') as f: 39 | params = np.load(io.BytesIO(f.read()), allow_pickle=False) 40 | 41 | return utils.flat_params_to_haiku(params) 42 | -------------------------------------------------------------------------------- /alphafold/model/features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Code to generate processed features.""" 16 | import copy 17 | from typing import List, Mapping, Tuple 18 | 19 | import ml_collections 20 | import numpy as np 21 | import tensorflow.compat.v1 as tf 22 | 23 | from alphafold.model.tf import input_pipeline 24 | from alphafold.model.tf import proteins_dataset 25 | 26 | FeatureDict = Mapping[str, np.ndarray] 27 | 28 | 29 | def make_data_config( 30 | config: ml_collections.ConfigDict, 31 | num_res: int, 32 | ) -> Tuple[ml_collections.ConfigDict, List[str]]: 33 | """Makes a data config for the input pipeline.""" 34 | cfg = copy.deepcopy(config.data) 35 | 36 | feature_names = cfg.common.unsupervised_features 37 | if cfg.common.use_templates: 38 | feature_names += cfg.common.template_features 39 | 40 | with cfg.unlocked(): 41 | cfg.eval.crop_size = num_res 42 | 43 | return cfg, feature_names 44 | 45 | 46 | def tf_example_to_features(tf_example: tf.train.Example, 47 | config: ml_collections.ConfigDict, 48 | random_seed: int = 0) -> FeatureDict: 49 | """Converts tf_example to numpy feature dictionary.""" 50 | num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0]) 51 | cfg, feature_names = make_data_config(config, num_res=num_res) 52 | 53 | if 'deletion_matrix_int' in set(tf_example.features.feature): 54 | deletion_matrix_int = ( 55 | tf_example.features.feature['deletion_matrix_int'].int64_list.value) 56 | feat = tf.train.Feature(float_list=tf.train.FloatList( 57 | value=map(float, deletion_matrix_int))) 58 | tf_example.features.feature['deletion_matrix'].CopyFrom(feat) 59 | del tf_example.features.feature['deletion_matrix_int'] 60 | 61 | tf_graph = tf.Graph() 62 | with tf_graph.as_default(), tf.device('/device:CPU:0'): 63 | tf.compat.v1.set_random_seed(random_seed) 64 | tensor_dict = proteins_dataset.create_tensor_dict( 65 | raw_data=tf_example.SerializeToString(), 66 | features=feature_names) 67 | processed_batch = input_pipeline.process_tensors_from_config( 68 | tensor_dict, cfg) 69 | 70 | tf_graph.finalize() 71 | 72 | with tf.Session(graph=tf_graph) as sess: 73 | features = sess.run(processed_batch) 74 | 75 | return {k: v for k, v in features.items() if v.dtype != 'O'} 76 | 77 | 78 | def np_example_to_features(np_example: FeatureDict, 79 | config: ml_collections.ConfigDict, 80 | random_seed: int = 0) -> FeatureDict: 81 | """Preprocesses NumPy feature dict using TF pipeline.""" 82 | np_example = dict(np_example) 83 | num_res = int(np_example['seq_length'][0]) 84 | cfg, feature_names = make_data_config(config, num_res=num_res) 85 | 86 | if 'deletion_matrix_int' in np_example: 87 | np_example['deletion_matrix'] = ( 88 | np_example.pop('deletion_matrix_int').astype(np.float32)) 89 | 90 | tf_graph = tf.Graph() 91 | with tf_graph.as_default(), tf.device('/device:CPU:0'): 92 | tf.compat.v1.set_random_seed(random_seed) 93 | tensor_dict = proteins_dataset.np_to_tensor_dict( 94 | np_example=np_example, features=feature_names) 95 | 96 | processed_batch = input_pipeline.process_tensors_from_config( 97 | tensor_dict, cfg) 98 | 99 | tf_graph.finalize() 100 | 101 | with 
tf.Session(graph=tf_graph) as sess: 102 | features = sess.run(processed_batch) 103 | 104 | return {k: v for k, v in features.items() if v.dtype != 'O'} 105 | -------------------------------------------------------------------------------- /alphafold/model/lddt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """lDDT protein distance score.""" 16 | import jax.numpy as jnp 17 | 18 | 19 | def lddt(predicted_points, 20 | true_points, 21 | true_points_mask, 22 | cutoff=15., 23 | per_residue=False): 24 | """Measure (approximate) lDDT for a batch of coordinates. 25 | 26 | lDDT reference: 27 | Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local 28 | superposition-free score for comparing protein structures and models using 29 | distance difference tests. Bioinformatics 29, 2722–2728 (2013). 30 | 31 | lDDT is a measure of the difference between the true distance matrix and the 32 | distance matrix of the predicted points. The difference is computed only on 33 | points closer than cutoff *in the true structure*. 34 | 35 | This function does not compute the exact lDDT value that the original paper 36 | describes because it does not include terms for physical feasibility 37 | (e.g. bond length violations). Therefore this is only an approximate 38 | lDDT score. 39 | 40 | Args: 41 | predicted_points: (batch, length, 3) array of predicted 3D points 42 | true_points: (batch, length, 3) array of true 3D points 43 | true_points_mask: (batch, length, 1) binary-valued float array. This mask 44 | should be 1 for points that exist in the true points. 45 | cutoff: Maximum distance for a pair of points to be included 46 | per_residue: If true, return score for each residue. Note that the overall 47 | lDDT is not exactly the mean of the per_residue lDDT's because some 48 | residues have more contacts than others. 49 | 50 | Returns: 51 | An (approximate, see above) lDDT score in the range 0-1. 52 | """ 53 | 54 | assert len(predicted_points.shape) == 3 55 | assert predicted_points.shape[-1] == 3 56 | assert true_points_mask.shape[-1] == 1 57 | assert len(true_points_mask.shape) == 3 58 | 59 | # Compute true and predicted distance matrices. 60 | dmat_true = jnp.sqrt(1e-10 + jnp.sum( 61 | (true_points[:, :, None] - true_points[:, None, :])**2, axis=-1)) 62 | 63 | dmat_predicted = jnp.sqrt(1e-10 + jnp.sum( 64 | (predicted_points[:, :, None] - 65 | predicted_points[:, None, :])**2, axis=-1)) 66 | 67 | dists_to_score = ( 68 | (dmat_true < cutoff).astype(jnp.float32) * true_points_mask * 69 | jnp.transpose(true_points_mask, [0, 2, 1]) * 70 | (1. - jnp.eye(dmat_true.shape[1])) # Exclude self-interaction. 71 | ) 72 | 73 | # Shift unscored distances to be far away. 74 | dist_l1 = jnp.abs(dmat_true - dmat_predicted) 75 | 76 | # True lDDT uses a number of fixed bins. 
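# (The bins are distance-difference thresholds of 0.5, 1.0, 2.0 and 4.0
# Angstroms; each scored pair contributes 0.25 per threshold its error
# stays under, matching the 0.25 * (...) sum below.)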
77 | # We ignore the physical plausibility correction to lDDT, though. 78 | score = 0.25 * ((dist_l1 < 0.5).astype(jnp.float32) + 79 | (dist_l1 < 1.0).astype(jnp.float32) + 80 | (dist_l1 < 2.0).astype(jnp.float32) + 81 | (dist_l1 < 4.0).astype(jnp.float32)) 82 | 83 | # Normalize over the appropriate axes. 84 | reduce_axes = (-1,) if per_residue else (-2, -1) 85 | norm = 1. / (1e-10 + jnp.sum(dists_to_score, axis=reduce_axes)) 86 | score = norm * (1e-10 + jnp.sum(dists_to_score * score, axis=reduce_axes)) 87 | 88 | return score 89 | -------------------------------------------------------------------------------- /alphafold/model/lddt_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for lddt.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | import numpy as np 20 | from alphafold.model import lddt 21 | 22 | 23 | class LddtTest(parameterized.TestCase, absltest.TestCase): 24 | 25 | @parameterized.named_parameters( 26 | ('same', 27 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 28 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 29 | [1, 1, 1]), 30 | ('all_shifted', 31 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 32 | [[-1, 0, 0], [4, 0, 0], [9, 0, 0]], 33 | [1, 1, 1]), 34 | ('all_rotated', 35 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 36 | [[0, 0, 0], [0, 5, 0], [0, 10, 0]], 37 | [1, 1, 1]), 38 | ('half_a_dist', 39 | [[0, 0, 0], [5, 0, 0]], 40 | [[0, 0, 0], [5.5-1e-5, 0, 0]], 41 | [1, 1]), 42 | ('one_a_dist', 43 | [[0, 0, 0], [5, 0, 0]], 44 | [[0, 0, 0], [6-1e-5, 0, 0]], 45 | [0.75, 0.75]), 46 | ('two_a_dist', 47 | [[0, 0, 0], [5, 0, 0]], 48 | [[0, 0, 0], [7-1e-5, 0, 0]], 49 | [0.5, 0.5]), 50 | ('four_a_dist', 51 | [[0, 0, 0], [5, 0, 0]], 52 | [[0, 0, 0], [9-1e-5, 0, 0]], 53 | [0.25, 0.25],), 54 | ('five_a_dist', 55 | [[0, 0, 0], [16-1e-5, 0, 0]], 56 | [[0, 0, 0], [11, 0, 0]], 57 | [0, 0]), 58 | ('no_pairs', 59 | [[0, 0, 0], [20, 0, 0]], 60 | [[0, 0, 0], [25-1e-5, 0, 0]], 61 | [1, 1]), 62 | ) 63 | def test_lddt( 64 | self, predicted_pos, true_pos, exp_lddt): 65 | predicted_pos = np.array([predicted_pos], dtype=np.float32) 66 | true_points_mask = np.array([[[1]] * len(true_pos)], dtype=np.float32) 67 | true_pos = np.array([true_pos], dtype=np.float32) 68 | cutoff = 15.0 69 | per_residue = True 70 | 71 | result = lddt.lddt( 72 | predicted_pos, true_pos, true_points_mask, cutoff, 73 | per_residue) 74 | 75 | np.testing.assert_almost_equal(result, [exp_lddt], decimal=4) 76 | 77 | 78 | if __name__ == '__main__': 79 | absltest.main() 80 | -------------------------------------------------------------------------------- /alphafold/model/mapping.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in 
compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Specialized mapping functions.""" 16 | 17 | import functools 18 | 19 | from typing import Any, Callable, Optional, Sequence, Union 20 | 21 | import haiku as hk 22 | import jax 23 | import jax.numpy as jnp 24 | 25 | 26 | PYTREE = Any 27 | PYTREE_JAX_ARRAY = Any 28 | 29 | partial = functools.partial 30 | PROXY = object() 31 | 32 | 33 | def _maybe_slice(array, i, slice_size, axis): 34 | if axis is PROXY: 35 | return array 36 | else: 37 | return jax.lax.dynamic_slice_in_dim( 38 | array, i, slice_size=slice_size, axis=axis) 39 | 40 | 41 | def _maybe_get_size(array, axis): 42 | if axis == PROXY: 43 | return -1 44 | else: 45 | return array.shape[axis] 46 | 47 | 48 | def _expand_axes(axes, values, name='sharded_apply'): 49 | values_tree_def = jax.tree_flatten(values)[1] 50 | flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes) 51 | # Replace None's with PROXY 52 | flat_axes = [PROXY if x is None else x for x in flat_axes] 53 | return jax.tree_unflatten(values_tree_def, flat_axes) 54 | 55 | 56 | def sharded_map( 57 | fun: Callable[..., PYTREE_JAX_ARRAY], 58 | shard_size: Union[int, None] = 1, 59 | in_axes: Union[int, PYTREE] = 0, 60 | out_axes: Union[int, PYTREE] = 0) -> Callable[..., PYTREE_JAX_ARRAY]: 61 | """Sharded vmap. 62 | 63 | Maps `fun` over axes, in a way similar to vmap, but does so in shards of 64 | `shard_size`. This allows a smooth trade-off between memory usage 65 | (as in a plain map) vs higher throughput (as in a vmap). 66 | 67 | Args: 68 | fun: Function to apply smap transform to. 69 | shard_size: Integer denoting shard size. 70 | in_axes: Either integer or pytree describing which axis to map over for each 71 | input to `fun`, None denotes broadcasting. 72 | out_axes: integer or pytree denoting to what axis in the output the mapped 73 | over axis maps. 74 | 75 | Returns: 76 | function with smap applied. 77 | """ 78 | vmapped_fun = hk.vmap(fun, in_axes, out_axes) 79 | return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes) 80 | 81 | 82 | def sharded_apply( 83 | fun: Callable[..., PYTREE_JAX_ARRAY], # pylint: disable=g-bare-generic 84 | shard_size: Union[int, None] = 1, 85 | in_axes: Union[int, PYTREE] = 0, 86 | out_axes: Union[int, PYTREE] = 0, 87 | new_out_axes: bool = False) -> Callable[..., PYTREE_JAX_ARRAY]: 88 | """Sharded apply. 89 | 90 | Applies `fun` over shards to axes, in a way similar to vmap, 91 | but does so in shards of `shard_size`. Shards are stacked after. 92 | This allows a smooth trade-off between 93 | memory usage (as in a plain map) vs higher throughput (as in a vmap). 94 | 95 | Args: 96 | fun: Function to apply smap transform to. 97 | shard_size: Integer denoting shard size. 98 | in_axes: Either integer or pytree describing which axis to map over for each 99 | input to `fun`, None denotes broadcasting. 100 | out_axes: integer or pytree denoting to what axis in the output the mapped 101 | over axis maps. 102 | new_out_axes: whether to stack outputs on new axes. 
This assumes that the 103 | output sizes for each shard (including the possible remainder shard) are 104 | the same. 105 | 106 | Returns: 107 | function with smap applied. 108 | """ 109 | docstr = ('Mapped version of {fun}. Takes similar arguments to {fun} ' 110 | 'but with additional array axes over which {fun} is mapped.') 111 | if new_out_axes: 112 | raise NotImplementedError('New output axes not yet implemented.') 113 | 114 | # shard size None denotes no sharding 115 | if shard_size is None: 116 | return fun 117 | 118 | @jax.util.wraps(fun, docstr=docstr) 119 | def mapped_fn(*args): 120 | # Expand in axes and Determine Loop range 121 | in_axes_ = _expand_axes(in_axes, args) 122 | 123 | in_sizes = jax.tree_multimap(_maybe_get_size, args, in_axes_) 124 | flat_sizes = jax.tree_flatten(in_sizes)[0] 125 | in_size = max(flat_sizes) 126 | assert all(i in {in_size, -1} for i in flat_sizes) 127 | 128 | num_extra_shards = (in_size - 1) // shard_size 129 | 130 | # Fix Up if necessary 131 | last_shard_size = in_size % shard_size 132 | last_shard_size = shard_size if last_shard_size == 0 else last_shard_size 133 | 134 | def apply_fun_to_slice(slice_start, slice_size): 135 | input_slice = jax.tree_multimap( 136 | lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis 137 | ), args, in_axes_) 138 | return fun(*input_slice) 139 | 140 | remainder_shape_dtype = hk.eval_shape( 141 | partial(apply_fun_to_slice, 0, last_shard_size)) 142 | out_dtypes = jax.tree_map(lambda x: x.dtype, remainder_shape_dtype) 143 | out_shapes = jax.tree_map(lambda x: x.shape, remainder_shape_dtype) 144 | out_axes_ = _expand_axes(out_axes, remainder_shape_dtype) 145 | 146 | if num_extra_shards > 0: 147 | regular_shard_shape_dtype = hk.eval_shape( 148 | partial(apply_fun_to_slice, 0, shard_size)) 149 | shard_shapes = jax.tree_map(lambda x: x.shape, regular_shard_shape_dtype) 150 | 151 | def make_output_shape(axis, shard_shape, remainder_shape): 152 | return shard_shape[:axis] + ( 153 | shard_shape[axis] * num_extra_shards + 154 | remainder_shape[axis],) + shard_shape[axis + 1:] 155 | 156 | out_shapes = jax.tree_multimap(make_output_shape, out_axes_, shard_shapes, 157 | out_shapes) 158 | 159 | # Calls dynamic Update slice with different argument order 160 | # This is here since tree_multimap only works with positional arguments 161 | def dynamic_update_slice_in_dim(full_array, update, axis, i): 162 | return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis) 163 | 164 | def compute_shard(outputs, slice_start, slice_size): 165 | slice_out = apply_fun_to_slice(slice_start, slice_size) 166 | update_slice = partial( 167 | dynamic_update_slice_in_dim, i=slice_start) 168 | return jax.tree_multimap(update_slice, outputs, slice_out, out_axes_) 169 | 170 | def scan_iteration(outputs, i): 171 | new_outputs = compute_shard(outputs, i, shard_size) 172 | return new_outputs, () 173 | 174 | slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size) 175 | 176 | def allocate_buffer(dtype, shape): 177 | return jnp.zeros(shape, dtype=dtype) 178 | 179 | outputs = jax.tree_multimap(allocate_buffer, out_dtypes, out_shapes) 180 | 181 | if slice_starts.shape[0] > 0: 182 | outputs, _ = hk.scan(scan_iteration, outputs, slice_starts) 183 | 184 | if last_shard_size != shard_size: 185 | remainder_start = in_size - last_shard_size 186 | outputs = compute_shard(outputs, remainder_start, last_shard_size) 187 | 188 | return outputs 189 | 190 | return mapped_fn 191 | 192 | 193 | def inference_subbatch( 194 | module: 
Callable[..., PYTREE_JAX_ARRAY], 195 | subbatch_size: int, 196 | batched_args: Sequence[PYTREE_JAX_ARRAY], 197 | nonbatched_args: Sequence[PYTREE_JAX_ARRAY], 198 | low_memory: bool = True, 199 | input_subbatch_dim: int = 0, 200 | output_subbatch_dim: Optional[int] = None) -> PYTREE_JAX_ARRAY: 201 | """Run through subbatches (like batch apply but with split and concat).""" 202 | assert len(batched_args) > 0 # pylint: disable=g-explicit-length-test 203 | 204 | if not low_memory: 205 | args = list(batched_args) + list(nonbatched_args) 206 | return module(*args) 207 | 208 | if output_subbatch_dim is None: 209 | output_subbatch_dim = input_subbatch_dim 210 | 211 | def run_module(*batched_args): 212 | args = list(batched_args) + list(nonbatched_args) 213 | return module(*args) 214 | sharded_module = sharded_apply(run_module, 215 | shard_size=subbatch_size, 216 | in_axes=input_subbatch_dim, 217 | out_axes=output_subbatch_dim) 218 | return sharded_module(*batched_args) 219 | -------------------------------------------------------------------------------- /alphafold/model/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
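# A minimal sketch of sharded_apply from alphafold/model/mapping.py above:
# the mapped function runs on shards of 4 rows (plus a remainder of 2) and the
# shard outputs are stitched back together. Assumes the repo's pinned haiku/jax
# versions; the names below are illustrative.

import haiku as hk
import jax
import jax.numpy as jnp
from alphafold.model import mapping

def forward(x):
  row_norm = lambda v: jnp.sqrt(jnp.sum(v ** 2, axis=-1))
  return mapping.sharded_apply(row_norm, shard_size=4)(x)

f = hk.transform(forward)
x = jnp.ones((10, 3))
params = f.init(jax.random.PRNGKey(0), x)  # No parameters are created here.
norms = f.apply(params, None, x)           # Shape (10,), as with a plain vmap.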
14 | 15 | """Code for constructing the model.""" 16 | from typing import Any, Mapping, Optional, Union 17 | 18 | from absl import logging 19 | import haiku as hk 20 | import jax 21 | import ml_collections 22 | import numpy as np 23 | import tensorflow.compat.v1 as tf 24 | import tree 25 | 26 | from alphafold.common import confidence 27 | from alphafold.model import features 28 | from alphafold.model import modules 29 | 30 | 31 | def get_confidence_metrics( 32 | prediction_result: Mapping[str, Any]) -> Mapping[str, Any]: 33 | """Post processes prediction_result to get confidence metrics.""" 34 | 35 | confidence_metrics = {} 36 | confidence_metrics['plddt'] = confidence.compute_plddt( 37 | prediction_result['predicted_lddt']['logits']) 38 | if 'predicted_aligned_error' in prediction_result: 39 | confidence_metrics.update(confidence.compute_predicted_aligned_error( 40 | prediction_result['predicted_aligned_error']['logits'], 41 | prediction_result['predicted_aligned_error']['breaks'])) 42 | confidence_metrics['ptm'] = confidence.predicted_tm_score( 43 | prediction_result['predicted_aligned_error']['logits'], 44 | prediction_result['predicted_aligned_error']['breaks']) 45 | 46 | return confidence_metrics 47 | 48 | 49 | class RunModel: 50 | """Container for JAX model.""" 51 | 52 | def __init__(self, 53 | config: ml_collections.ConfigDict, 54 | params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None): 55 | self.config = config 56 | self.params = params 57 | 58 | def _forward_fn(batch): 59 | model = modules.AlphaFold(self.config.model) 60 | return model( 61 | batch, 62 | is_training=False, 63 | compute_loss=False, 64 | ensemble_representations=True) 65 | 66 | self.apply = jax.jit(hk.transform(_forward_fn).apply) 67 | self.init = jax.jit(hk.transform(_forward_fn).init) 68 | 69 | def init_params(self, feat: features.FeatureDict, random_seed: int = 0): 70 | """Initializes the model parameters. 71 | 72 | If none were provided when this class was instantiated then the parameters 73 | are randomly initialized. 74 | 75 | Args: 76 | feat: A dictionary of NumPy feature arrays as output by 77 | RunModel.process_features. 78 | random_seed: A random seed to use to initialize the parameters if none 79 | were set when this class was initialized. 80 | """ 81 | if not self.params: 82 | # Init params randomly. 83 | rng = jax.random.PRNGKey(random_seed) 84 | self.params = hk.data_structures.to_mutable_dict( 85 | self.init(rng, feat)) 86 | logging.warning('Initialized parameters randomly') 87 | 88 | def process_features( 89 | self, 90 | raw_features: Union[tf.train.Example, features.FeatureDict], 91 | random_seed: int) -> features.FeatureDict: 92 | """Processes features to prepare for feeding them into the model. 93 | 94 | Args: 95 | raw_features: The output of the data pipeline either as a dict of NumPy 96 | arrays or as a tf.train.Example. 97 | random_seed: The random seed to use when processing the features. 98 | 99 | Returns: 100 | A dict of NumPy feature arrays suitable for feeding into the model. 
101 | """ 102 | if isinstance(raw_features, dict): 103 | return features.np_example_to_features( 104 | np_example=raw_features, 105 | config=self.config, 106 | random_seed=random_seed) 107 | else: 108 | return features.tf_example_to_features( 109 | tf_example=raw_features, 110 | config=self.config, 111 | random_seed=random_seed) 112 | 113 | def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct: 114 | self.init_params(feat) 115 | logging.info('Running eval_shape with shape(feat) = %s', 116 | tree.map_structure(lambda x: x.shape, feat)) 117 | shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0), feat) 118 | logging.info('Output shape was %s', shape) 119 | return shape 120 | 121 | def predict(self, feat: features.FeatureDict) -> Mapping[str, Any]: 122 | """Makes a prediction by inferencing the model on the provided features. 123 | 124 | Args: 125 | feat: A dictionary of NumPy feature arrays as output by 126 | RunModel.process_features. 127 | 128 | Returns: 129 | A dictionary of model outputs. 130 | """ 131 | self.init_params(feat) 132 | logging.info('Running predict with shape(feat) = %s', 133 | tree.map_structure(lambda x: x.shape, feat)) 134 | result = self.apply(self.params, jax.random.PRNGKey(0), feat) 135 | # This block is to ensure benchmark timings are accurate. Some blocking is 136 | # already happening when computing get_confidence_metrics, and this ensures 137 | # all outputs are blocked on. 138 | jax.tree_map(lambda x: x.block_until_ready(), result) 139 | result.update(get_confidence_metrics(result)) 140 | logging.info('Output shape was %s', 141 | tree.map_structure(lambda x: x.shape, result)) 142 | return result 143 | 144 | -------------------------------------------------------------------------------- /alphafold/model/prng.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """A collection of utilities surrounding PRNG usage in protein folding.""" 16 | 17 | import haiku as hk 18 | import jax 19 | 20 | 21 | def safe_dropout(*, tensor, safe_key, rate, is_deterministic, is_training): 22 | if is_training and rate != 0.0 and not is_deterministic: 23 | return hk.dropout(safe_key.get(), rate, tensor) 24 | else: 25 | return tensor 26 | 27 | 28 | class SafeKey: 29 | """Safety wrapper for PRNG keys.""" 30 | 31 | def __init__(self, key): 32 | self._key = key 33 | self._used = False 34 | 35 | def _assert_not_used(self): 36 | if self._used: 37 | raise RuntimeError('Random key has been used previously.') 38 | 39 | def get(self): 40 | self._assert_not_used() 41 | self._used = True 42 | return self._key 43 | 44 | def split(self, num_keys=2): 45 | self._assert_not_used() 46 | self._used = True 47 | new_keys = jax.random.split(self._key, num_keys) 48 | return jax.tree_map(SafeKey, tuple(new_keys)) 49 | 50 | def duplicate(self, num_keys=2): 51 | self._assert_not_used() 52 | self._used = True 53 | return tuple(SafeKey(self._key) for _ in range(num_keys)) 54 | 55 | 56 | def _safe_key_flatten(safe_key): 57 | # Flatten transfers "ownership" to the tree 58 | return (safe_key._key,), safe_key._used # pylint: disable=protected-access 59 | 60 | 61 | def _safe_key_unflatten(aux_data, children): 62 | ret = SafeKey(children[0]) 63 | ret._used = aux_data # pylint: disable=protected-access 64 | return ret 65 | 66 | 67 | jax.tree_util.register_pytree_node( 68 | SafeKey, _safe_key_flatten, _safe_key_unflatten) 69 | 70 | -------------------------------------------------------------------------------- /alphafold/model/prng_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for prng.""" 16 | 17 | from absl.testing import absltest 18 | import jax 19 | 20 | from alphafold.model import prng 21 | 22 | 23 | class PrngTest(absltest.TestCase): 24 | 25 | def test_key_reuse(self): 26 | 27 | init_key = jax.random.PRNGKey(42) 28 | safe_key = prng.SafeKey(init_key) 29 | _, safe_key = safe_key.split() 30 | 31 | raw_key = safe_key.get() 32 | 33 | self.assertNotEqual(raw_key[0], init_key[0]) 34 | self.assertNotEqual(raw_key[1], init_key[1]) 35 | 36 | with self.assertRaises(RuntimeError): 37 | safe_key.get() 38 | 39 | with self.assertRaises(RuntimeError): 40 | safe_key.split() 41 | 42 | with self.assertRaises(RuntimeError): 43 | safe_key.duplicate() 44 | 45 | 46 | if __name__ == '__main__': 47 | absltest.main() 48 | -------------------------------------------------------------------------------- /alphafold/model/quat_affine_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for quat_affine.""" 16 | 17 | from absl import logging 18 | from absl.testing import absltest 19 | import jax 20 | import jax.numpy as jnp 21 | import numpy as np 22 | from alphafold.model import quat_affine 23 | 24 | VERBOSE = False 25 | np.set_printoptions(precision=3, suppress=True) 26 | 27 | r2t = quat_affine.rot_list_to_tensor 28 | v2t = quat_affine.vec_list_to_tensor 29 | 30 | q2r = lambda q: r2t(quat_affine.quat_to_rot(q)) 31 | 32 | 33 | class QuatAffineTest(absltest.TestCase): 34 | 35 | def _assert_check(self, to_check, tol=1e-5): 36 | for k, (correct, generated) in to_check.items(): 37 | if VERBOSE: 38 | logging.info(k) 39 | logging.info('Correct %s', correct) 40 | logging.info('Predicted %s', generated) 41 | self.assertLess(np.max(np.abs(correct - generated)), tol) 42 | 43 | def test_conversion(self): 44 | quat = jnp.array([-2., 5., -1., 4.]) 45 | 46 | rotation = jnp.array([ 47 | [0.26087, 0.130435, 0.956522], 48 | [-0.565217, -0.782609, 0.26087], 49 | [0.782609, -0.608696, -0.130435]]) 50 | 51 | translation = jnp.array([1., -3., 4.]) 52 | point = jnp.array([0.7, 3.2, -2.9]) 53 | 54 | a = quat_affine.QuatAffine(quat, translation, unstack_inputs=True) 55 | true_new_point = jnp.matmul(rotation, point[:, None])[:, 0] + translation 56 | 57 | self._assert_check({ 58 | 'rot': (rotation, r2t(a.rotation)), 59 | 'trans': (translation, v2t(a.translation)), 60 | 'point': (true_new_point, 61 | v2t(a.apply_to_point(jnp.moveaxis(point, -1, 0)))), 62 | # Because of the double cover, we must be careful and compare rotations 63 | 'quat': (q2r(a.quaternion), 64 | q2r(quat_affine.rot_to_quat(a.rotation))), 65 | 66 | }) 67 | 68 | def test_double_cover(self): 69 | """Test that -q is the same rotation as q.""" 70 | rng = jax.random.PRNGKey(42) 71 | keys = jax.random.split(rng) 72 | q = jax.random.normal(keys[0], (2, 4)) 73 | trans = jax.random.normal(keys[1], (2, 3)) 74 | a1 = quat_affine.QuatAffine(q, trans, unstack_inputs=True) 75 | a2 = quat_affine.QuatAffine(-q, trans, unstack_inputs=True) 76 | 77 | self._assert_check({ 78 | 'rot': (r2t(a1.rotation), 79 | r2t(a2.rotation)), 80 | 'trans': (v2t(a1.translation), 81 | v2t(a2.translation)), 82 | }) 83 | 84 | def test_homomorphism(self): 85 | rng = jax.random.PRNGKey(42) 86 | keys = jax.random.split(rng, 4) 87 | vec_q1 = jax.random.normal(keys[0], (2, 3)) 88 | 89 | q1 = jnp.concatenate([ 90 | jnp.ones_like(vec_q1)[:, :1], 91 | vec_q1], axis=-1) 92 | 93 | q2 = jax.random.normal(keys[1], (2, 4)) 94 | t1 = jax.random.normal(keys[2], (2, 3)) 95 | t2 = jax.random.normal(keys[3], (2, 3)) 96 | 97 | a1 = quat_affine.QuatAffine(q1, t1, unstack_inputs=True) 98 | a2 = quat_affine.QuatAffine(q2, t2, unstack_inputs=True) 99 | a21 = a2.pre_compose(jnp.concatenate([vec_q1, t1], axis=-1)) 100 | 101 | rng, key = jax.random.split(rng) 102 | x = jax.random.normal(key, (2, 3)) 103 | new_x = a21.apply_to_point(jnp.moveaxis(x, -1, 0)) 104 | new_x_apply2 = a2.apply_to_point(a1.apply_to_point(jnp.moveaxis(x, -1, 0))) 105 | 106 | self._assert_check({ 107 | 'quat': (q2r(quat_affine.quat_multiply(a2.quaternion, 
a1.quaternion)), 108 | q2r(a21.quaternion)), 109 | 'rot': (jnp.matmul(r2t(a2.rotation), r2t(a1.rotation)), 110 | r2t(a21.rotation)), 111 | 'point': (v2t(new_x_apply2), 112 | v2t(new_x)), 113 | 'inverse': (x, v2t(a21.invert_point(new_x))), 114 | }) 115 | 116 | def test_batching(self): 117 | """Test that affine applies batchwise.""" 118 | rng = jax.random.PRNGKey(42) 119 | keys = jax.random.split(rng, 3) 120 | q = jax.random.uniform(keys[0], (5, 2, 4)) 121 | t = jax.random.uniform(keys[1], (2, 3)) 122 | x = jax.random.uniform(keys[2], (5, 1, 3)) 123 | 124 | a = quat_affine.QuatAffine(q, t, unstack_inputs=True) 125 | y = v2t(a.apply_to_point(jnp.moveaxis(x, -1, 0))) 126 | 127 | y_list = [] 128 | for i in range(5): 129 | for j in range(2): 130 | a_local = quat_affine.QuatAffine(q[i, j], t[j], 131 | unstack_inputs=True) 132 | y_local = v2t(a_local.apply_to_point(jnp.moveaxis(x[i, 0], -1, 0))) 133 | y_list.append(y_local) 134 | y_combine = jnp.reshape(jnp.stack(y_list, axis=0), (5, 2, 3)) 135 | 136 | self._assert_check({ 137 | 'batch': (y_combine, y), 138 | 'quat': (q2r(a.quaternion), 139 | q2r(quat_affine.rot_to_quat(a.rotation))), 140 | }) 141 | 142 | def assertAllClose(self, a, b, rtol=1e-06, atol=1e-06): 143 | self.assertTrue(np.allclose(a, b, rtol=rtol, atol=atol)) 144 | 145 | def assertAllEqual(self, a, b): 146 | self.assertTrue(np.all(np.array(a) == np.array(b))) 147 | 148 | 149 | if __name__ == '__main__': 150 | absltest.main() 151 | -------------------------------------------------------------------------------- /alphafold/model/tf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Alphafold model TensorFlow code.""" 15 | -------------------------------------------------------------------------------- /alphafold/model/tf/input_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
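The double-cover and conversion tests above pin down the quaternion convention. A self-contained NumPy sketch (independent of alphafold, using the standard [w, x, y, z] quaternion-to-rotation formula) of why q and -q describe the same rotation:

import numpy as np

def quat_to_rot(q):
  # Normalizing first makes q and -q produce identical matrix entries.
  w, x, y, z = q / np.linalg.norm(q)
  return np.array([
      [1 - 2*(y*y + z*z), 2*(x*y - w*z),     2*(x*z + w*y)],
      [2*(x*y + w*z),     1 - 2*(x*x + z*z), 2*(y*z - w*x)],
      [2*(x*z - w*y),     2*(y*z + w*x),     1 - 2*(x*x + y*y)]])

q = np.array([-2., 5., -1., 4.])  # the quaternion from test_conversion above
assert np.allclose(quat_to_rot(q), quat_to_rot(-q))
assert np.allclose(quat_to_rot(q)[0], [0.26087, 0.130435, 0.956522], atol=1e-5)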
14 | 15 | """Feature pre-processing input pipeline for AlphaFold.""" 16 | import tensorflow.compat.v1 as tf 17 | import tree 18 | 19 | from alphafold.model.tf import data_transforms 20 | from alphafold.model.tf import shape_placeholders 21 | 22 | # Pylint gets confused by the curry1 decorator because it changes the number 23 | # of arguments to the function. 24 | # pylint:disable=no-value-for-parameter 25 | 26 | 27 | NUM_RES = shape_placeholders.NUM_RES 28 | NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ 29 | NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ 30 | NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES 31 | 32 | 33 | def nonensembled_map_fns(data_config): 34 | """Input pipeline functions which are not ensembled.""" 35 | common_cfg = data_config.common 36 | 37 | map_fns = [ 38 | data_transforms.correct_msa_restypes, 39 | data_transforms.add_distillation_flag(False), 40 | data_transforms.cast_64bit_ints, 41 | data_transforms.squeeze_features, 42 | # Keep to not disrupt RNG. 43 | data_transforms.randomly_replace_msa_with_unknown(0.0), 44 | data_transforms.make_seq_mask, 45 | data_transforms.make_msa_mask, 46 | # Compute the HHblits profile if it's not set. This has to be run before 47 | # sampling the MSA. 48 | data_transforms.make_hhblits_profile, 49 | data_transforms.make_random_crop_to_size_seed, 50 | ] 51 | if common_cfg.use_templates: 52 | map_fns.extend([ 53 | data_transforms.fix_templates_aatype, 54 | data_transforms.make_template_mask, 55 | data_transforms.make_pseudo_beta('template_') 56 | ]) 57 | map_fns.extend([ 58 | data_transforms.make_atom14_masks, 59 | ]) 60 | 61 | return map_fns 62 | 63 | 64 | def ensembled_map_fns(data_config): 65 | """Input pipeline functions that can be ensembled and averaged.""" 66 | common_cfg = data_config.common 67 | eval_cfg = data_config.eval 68 | 69 | map_fns = [] 70 | 71 | if common_cfg.reduce_msa_clusters_by_max_templates: 72 | pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates 73 | else: 74 | pad_msa_clusters = eval_cfg.max_msa_clusters 75 | 76 | max_msa_clusters = pad_msa_clusters 77 | max_extra_msa = common_cfg.max_extra_msa 78 | 79 | map_fns.append( 80 | data_transforms.sample_msa( 81 | max_msa_clusters, 82 | keep_extra=True)) 83 | 84 | if 'masked_msa' in common_cfg: 85 | # Masked MSA should come *before* MSA clustering so that 86 | # the clustering and full MSA profile do not leak information about 87 | # the masked locations and secret corrupted locations. 88 | map_fns.append( 89 | data_transforms.make_masked_msa(common_cfg.masked_msa, 90 | eval_cfg.masked_msa_replace_fraction)) 91 | 92 | if common_cfg.msa_cluster_features: 93 | map_fns.append(data_transforms.nearest_neighbor_clusters()) 94 | map_fns.append(data_transforms.summarize_clusters()) 95 | 96 | # Crop after creating the cluster profiles. 
97 | if max_extra_msa: 98 | map_fns.append(data_transforms.crop_extra_msa(max_extra_msa)) 99 | else: 100 | map_fns.append(data_transforms.delete_extra_msa) 101 | 102 | map_fns.append(data_transforms.make_msa_feat()) 103 | 104 | crop_feats = dict(eval_cfg.feat) 105 | 106 | if eval_cfg.fixed_size: 107 | map_fns.append(data_transforms.select_feat(list(crop_feats))) 108 | map_fns.append(data_transforms.random_crop_to_size( 109 | eval_cfg.crop_size, 110 | eval_cfg.max_templates, 111 | crop_feats, 112 | eval_cfg.subsample_templates)) 113 | map_fns.append(data_transforms.make_fixed_size( 114 | crop_feats, 115 | pad_msa_clusters, 116 | common_cfg.max_extra_msa, 117 | eval_cfg.crop_size, 118 | eval_cfg.max_templates)) 119 | else: 120 | map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates)) 121 | 122 | return map_fns 123 | 124 | 125 | def process_tensors_from_config(tensors, data_config): 126 | """Apply filters and maps to an existing dataset, based on the config.""" 127 | 128 | def wrap_ensemble_fn(data, i): 129 | """Function to be mapped over the ensemble dimension.""" 130 | d = data.copy() 131 | fns = ensembled_map_fns(data_config) 132 | fn = compose(fns) 133 | d['ensemble_index'] = i 134 | return fn(d) 135 | 136 | eval_cfg = data_config.eval 137 | tensors = compose( 138 | nonensembled_map_fns( 139 | data_config))( 140 | tensors) 141 | 142 | tensors_0 = wrap_ensemble_fn(tensors, tf.constant(0)) 143 | num_ensemble = eval_cfg.num_ensemble 144 | if data_config.common.resample_msa_in_recycling: 145 | # Separate batch per ensembling & recycling step. 146 | num_ensemble *= data_config.common.num_recycle + 1 147 | 148 | if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1: 149 | fn_output_signature = tree.map_structure( 150 | tf.TensorSpec.from_tensor, tensors_0) 151 | tensors = tf.map_fn( 152 | lambda x: wrap_ensemble_fn(tensors, x), 153 | tf.range(num_ensemble), 154 | parallel_iterations=1, 155 | fn_output_signature=fn_output_signature) 156 | else: 157 | tensors = tree.map_structure(lambda x: x[None], 158 | tensors_0) 159 | return tensors 160 | 161 | 162 | @data_transforms.curry1 163 | def compose(x, fs): 164 | for f in fs: 165 | x = f(x) 166 | return x 167 | -------------------------------------------------------------------------------- /alphafold/model/tf/protein_features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Contains descriptions of various protein features.""" 16 | import enum 17 | from typing import Dict, Optional, Sequence, Tuple, Union 18 | 19 | import tensorflow.compat.v1 as tf 20 | 21 | from alphafold.common import residue_constants 22 | 23 | # Type aliases. 
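Stepping back to input_pipeline.py above: the pipeline is organized as lists of dict-to-dict transforms chained by compose (the real compose is curried with data_transforms.curry1, so compose(fns) yields a callable). A toy sketch of that pattern, with hypothetical transforms rather than AlphaFold's:

def add_one(d):
  return {k: v + 1 for k, v in d.items()}

def double(d):
  return {k: v * 2 for k, v in d.items()}

def compose(fs):
  def apply_all(x):
    for f in fs:   # apply each transform in order, threading the dict through
      x = f(x)
    return x
  return apply_all

assert compose([add_one, double])({'a': 1}) == {'a': 4}  # (1 + 1) * 2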
24 | FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]] 25 | 26 | 27 | class FeatureType(enum.Enum): 28 | ZERO_DIM = 0 # Shape [x] 29 | ONE_DIM = 1 # Shape [num_res, x] 30 | TWO_DIM = 2 # Shape [num_res, num_res, x] 31 | MSA = 3 # Shape [msa_length, num_res, x] 32 | 33 | 34 | # Placeholder values that will be replaced with their true value at runtime. 35 | NUM_RES = "num residues placeholder" 36 | NUM_SEQ = "length msa placeholder" 37 | NUM_TEMPLATES = "num templates placeholder" 38 | # Sizes of the protein features. NUM_RES and NUM_SEQ are allowed as placeholders 39 | # to be replaced with the number of residues and the number of sequences in the 40 | # multiple sequence alignment, respectively. 41 | 42 | 43 | FEATURES = { 44 | #### Static features of a protein sequence #### 45 | "aatype": (tf.float32, [NUM_RES, 21]), 46 | "between_segment_residues": (tf.int64, [NUM_RES, 1]), 47 | "deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]), 48 | "domain_name": (tf.string, [1]), 49 | "msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]), 50 | "num_alignments": (tf.int64, [NUM_RES, 1]), 51 | "residue_index": (tf.int64, [NUM_RES, 1]), 52 | "seq_length": (tf.int64, [NUM_RES, 1]), 53 | "sequence": (tf.string, [1]), 54 | "all_atom_positions": (tf.float32, 55 | [NUM_RES, residue_constants.atom_type_num, 3]), 56 | "all_atom_mask": (tf.int64, [NUM_RES, residue_constants.atom_type_num]), 57 | "resolution": (tf.float32, [1]), 58 | "template_domain_names": (tf.string, [NUM_TEMPLATES]), 59 | "template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]), 60 | "template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]), 61 | "template_all_atom_positions": (tf.float32, [ 62 | NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3 63 | ]), 64 | "template_all_atom_masks": (tf.float32, [ 65 | NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1 66 | ]), 67 | } 68 | 69 | FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()} 70 | FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()} 71 | 72 | 73 | def register_feature(name: str, 74 | type_: tf.dtypes.DType, 75 | shape_: Tuple[Union[str, int]]): 76 | """Register extra features used in custom datasets.""" 77 | FEATURES[name] = (type_, shape_) 78 | FEATURE_TYPES[name] = type_ 79 | FEATURE_SIZES[name] = shape_ 80 | 81 | 82 | def shape(feature_name: str, 83 | num_residues: int, 84 | msa_length: int, 85 | num_templates: Optional[int] = None, 86 | features: Optional[FeaturesMetadata] = None): 87 | """Get the shape for the given feature name. 88 | 89 | This is nearly identical to _get_tf_shape_no_placeholders() but with 2 90 | differences: 91 | * This method does not calculate a single placeholder from the total number of 92 | elements (e.g. given [NUM_RES, 3] and size := 12, this won't deduce NUM_RES 93 | must be 4) 94 | * This method will work with tensors 95 | 96 | Args: 97 | feature_name: String identifier for the feature. If the feature name ends 98 | with "_unnormalized", this suffix is stripped off. 99 | num_residues: The number of residues in the current domain - some elements 100 | of the shape can be dynamic and will be replaced by this value. 101 | msa_length: The number of sequences in the multiple sequence alignment, some 102 | elements of the shape can be dynamic and will be replaced by this value. 103 | If the number of alignments is unknown / not read, please pass None for 104 | msa_length. 105 | num_templates (optional): The number of templates in this tfexample. 106 | features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES.
107 | 108 | Returns: 109 | List of ints representing the tensor size. 110 | 111 | Raises: 112 | ValueError: If a feature is requested but no concrete placeholder value is 113 | given. 114 | """ 115 | features = features or FEATURES 116 | if feature_name.endswith("_unnormalized"): 117 | feature_name = feature_name[:-13] 118 | 119 | unused_dtype, raw_sizes = features[feature_name] 120 | replacements = {NUM_RES: num_residues, 121 | NUM_SEQ: msa_length} 122 | 123 | if num_templates is not None: 124 | replacements[NUM_TEMPLATES] = num_templates 125 | 126 | sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes] 127 | for dimension in sizes: 128 | if isinstance(dimension, str): 129 | raise ValueError("Could not parse %s (shape: %s) with values: %s" % ( 130 | feature_name, raw_sizes, replacements)) 131 | return sizes 132 | 133 | -------------------------------------------------------------------------------- /alphafold/model/tf/protein_features_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for protein_features.""" 16 | import uuid 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import tensorflow.compat.v1 as tf 21 | 22 | from alphafold.model.tf import protein_features 23 | 24 | 25 | def _random_bytes(): 26 | return str(uuid.uuid4()).encode('utf-8') 27 | 28 | 29 | class FeaturesTest(parameterized.TestCase, tf.test.TestCase): 30 | 31 | def testFeatureNames(self): 32 | self.assertEqual(len(protein_features.FEATURE_SIZES), 33 | len(protein_features.FEATURE_TYPES)) 34 | sorted_size_names = sorted(protein_features.FEATURE_SIZES.keys()) 35 | sorted_type_names = sorted(protein_features.FEATURE_TYPES.keys()) 36 | for i, size_name in enumerate(sorted_size_names): 37 | self.assertEqual(size_name, sorted_type_names[i]) 38 | 39 | def testReplacement(self): 40 | for name in protein_features.FEATURE_SIZES.keys(): 41 | sizes = protein_features.shape(name, 42 | num_residues=12, 43 | msa_length=24, 44 | num_templates=3) 45 | for x in sizes: 46 | self.assertEqual(type(x), int) 47 | self.assertGreater(x, 0) 48 | 49 | 50 | if __name__ == '__main__': 51 | tf.disable_v2_behavior() 52 | absltest.main() 53 | -------------------------------------------------------------------------------- /alphafold/model/tf/proteins_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
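To make the placeholder substitution in protein_features.shape above concrete (this mirrors testReplacement, with the concrete values the test uses):

from alphafold.model.tf import protein_features

# 'msa' is declared as (tf.int64, [NUM_SEQ, NUM_RES, 1]); the string
# placeholders are replaced by the concrete values passed in.
assert protein_features.shape(
    'msa', num_residues=12, msa_length=24, num_templates=3) == [24, 12, 1]
# 'template_aatype' is [NUM_TEMPLATES, NUM_RES, 22]:
assert protein_features.shape(
    'template_aatype', num_residues=12, msa_length=24,
    num_templates=3) == [3, 12, 22]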
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Datasets consisting of proteins.""" 16 | from typing import Dict, Mapping, Optional, Sequence 17 | 18 | import numpy as np 19 | import tensorflow.compat.v1 as tf 20 | 21 | from alphafold.model.tf import protein_features 22 | 23 | TensorDict = Dict[str, tf.Tensor] 24 | 25 | 26 | def parse_tfexample( 27 | raw_data: bytes, 28 | features: protein_features.FeaturesMetadata, 29 | key: Optional[str] = None) -> Dict[str, tf.train.Feature]: 30 | """Read a single TF Example proto and return a subset of its features. 31 | 32 | Args: 33 | raw_data: A serialized tf.Example proto. 34 | features: A dictionary of features, mapping string feature names to a tuple 35 | (dtype, shape). This dictionary should be a subset of 36 | protein_features.FEATURES (or the dictionary itself for all features). 37 | key: Optional string with the SSTable key of that tf.Example. This will be 38 | added into features as a 'key' but only if requested in features. 39 | 40 | Returns: 41 | A dictionary of features mapping feature names to features. Only the given 42 | features are returned, all other ones are filtered out. 43 | """ 44 | feature_map = { 45 | k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True) 46 | for k, v in features.items() 47 | } 48 | parsed_features = tf.io.parse_single_example(raw_data, feature_map) 49 | reshaped_features = parse_reshape_logic(parsed_features, features, key=key) 50 | 51 | return reshaped_features 52 | 53 | 54 | def _first(tensor: tf.Tensor) -> tf.Tensor: 55 | """Returns the 1st element - the input can be a tensor or a scalar.""" 56 | return tf.reshape(tensor, shape=(-1,))[0] 57 | 58 | 59 | def parse_reshape_logic( 60 | parsed_features: TensorDict, 61 | features: protein_features.FeaturesMetadata, 62 | key: Optional[str] = None) -> TensorDict: 63 | """Transforms parsed serial features to the correct shape.""" 64 | # Find out what is the number of sequences and the number of alignments. 65 | num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32) 66 | 67 | if "num_alignments" in parsed_features: 68 | num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32) 69 | else: 70 | num_msa = 0 71 | 72 | if "template_domain_names" in parsed_features: 73 | num_templates = tf.cast( 74 | tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32) 75 | else: 76 | num_templates = 0 77 | 78 | if key is not None and "key" in features: 79 | parsed_features["key"] = [key] # Expand dims from () to (1,). 80 | 81 | # Reshape the tensors according to the sequence length and num alignments. 
82 | for k, v in parsed_features.items(): 83 | new_shape = protein_features.shape( 84 | feature_name=k, 85 | num_residues=num_residues, 86 | msa_length=num_msa, 87 | num_templates=num_templates, 88 | features=features) 89 | new_shape_size = tf.constant(1, dtype=tf.int32) 90 | for dim in new_shape: 91 | new_shape_size *= tf.cast(dim, tf.int32) 92 | 93 | assert_equal = tf.assert_equal( 94 | tf.size(v), new_shape_size, 95 | name="assert_%s_shape_correct" % k, 96 | message="The size of feature %s (%s) could not be reshaped " 97 | "into %s" % (k, tf.size(v), new_shape)) 98 | if "template" not in k: 99 | # Make sure the feature we are reshaping is not empty. 100 | assert_non_empty = tf.assert_greater( 101 | tf.size(v), 0, name="assert_%s_non_empty" % k, 102 | message="The feature %s is not set in the tf.Example. Either do not " 103 | "request the feature or use a tf.Example that has the " 104 | "feature set." % k) 105 | with tf.control_dependencies([assert_non_empty, assert_equal]): 106 | parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) 107 | else: 108 | with tf.control_dependencies([assert_equal]): 109 | parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) 110 | 111 | return parsed_features 112 | 113 | 114 | def _make_features_metadata( 115 | feature_names: Sequence[str]) -> protein_features.FeaturesMetadata: 116 | """Makes a feature name to type and shape mapping from a list of names.""" 117 | # Make sure these features are always read. 118 | required_features = ["aatype", "sequence", "seq_length"] 119 | feature_names = list(set(feature_names) | set(required_features)) 120 | 121 | features_metadata = {name: protein_features.FEATURES[name] 122 | for name in feature_names} 123 | return features_metadata 124 | 125 | 126 | def create_tensor_dict( 127 | raw_data: bytes, 128 | features: Sequence[str], 129 | key: Optional[str] = None, 130 | ) -> TensorDict: 131 | """Creates a dictionary of tensor features. 132 | 133 | Args: 134 | raw_data: A serialized tf.Example proto. 135 | features: A list of strings of feature names to be returned in the dataset. 136 | key: Optional string with the SSTable key of that tf.Example. This will be 137 | added into features as a 'key' but only if requested in features. 138 | 139 | Returns: 140 | A dictionary of features mapping feature names to features. Only the given 141 | features are returned, all other ones are filtered out. 142 | """ 143 | features_metadata = _make_features_metadata(features) 144 | return parse_tfexample(raw_data, features_metadata, key) 145 | 146 | 147 | def np_to_tensor_dict( 148 | np_example: Mapping[str, np.ndarray], 149 | features: Sequence[str], 150 | ) -> TensorDict: 151 | """Creates dict of tensors from a dict of NumPy arrays. 152 | 153 | Args: 154 | np_example: A dict of NumPy feature arrays. 155 | features: A list of strings of feature names to be returned in the dataset. 156 | 157 | Returns: 158 | A dictionary of features mapping feature names to features. Only the given 159 | features are returned, all other ones are filtered out. 160 | """ 161 | features_metadata = _make_features_metadata(features) 162 | tensor_dict = {k: tf.constant(v) for k, v in np_example.items() 163 | if k in features_metadata} 164 | 165 | # Ensures shapes are as expected. Needed for setting size of empty features 166 | # e.g. when no template hits were found. 
167 | tensor_dict = parse_reshape_logic(tensor_dict, features_metadata) 168 | return tensor_dict 169 | -------------------------------------------------------------------------------- /alphafold/model/tf/shape_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for dealing with shapes of TensorFlow tensors.""" 16 | import tensorflow.compat.v1 as tf 17 | 18 | 19 | def shape_list(x): 20 | """Return list of dimensions of a tensor, statically where possible. 21 | 22 | Like `x.shape.as_list()` but with tensors instead of `None`s. 23 | 24 | Args: 25 | x: A tensor. 26 | Returns: 27 | A list with length equal to the rank of the tensor. The n-th element of the 28 | list is an integer when that dimension is statically known otherwise it is 29 | the n-th element of `tf.shape(x)`. 30 | """ 31 | x = tf.convert_to_tensor(x) 32 | 33 | # If unknown rank, return dynamic shape 34 | if x.get_shape().dims is None: 35 | return tf.shape(x) 36 | 37 | static = x.get_shape().as_list() 38 | shape = tf.shape(x) 39 | 40 | ret = [] 41 | for i in range(len(static)): 42 | dim = static[i] 43 | if dim is None: 44 | dim = shape[i] 45 | ret.append(dim) 46 | return ret 47 | 48 | -------------------------------------------------------------------------------- /alphafold/model/tf/shape_helpers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
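A short sketch of shape_helpers.shape_list above (it assumes TF1 compatibility mode, which the test below also enables):

import tensorflow.compat.v1 as tf
from alphafold.model.tf import shape_helpers

tf.disable_v2_behavior()
p = tf.placeholder(tf.float32, shape=[None, 7])
dims = shape_helpers.shape_list(p)
# dims[1] is the static Python int 7; dims[0] is a dynamic tf.Tensor usable
# in tf.reshape where x.shape.as_list() would have given an unusable None.
assert dims[1] == 7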
14 | 15 | """Tests for shape_helpers.""" 16 | 17 | import numpy as np 18 | import tensorflow.compat.v1 as tf 19 | 20 | from alphafold.model.tf import shape_helpers 21 | 22 | 23 | class ShapeTest(tf.test.TestCase): 24 | 25 | def test_shape_list(self): 26 | """Test that shape_list can allow for reshaping to dynamic shapes.""" 27 | a = tf.zeros([10, 4, 4, 2]) 28 | p = tf.placeholder(tf.float32, shape=[None, None, 1, 4, 4]) 29 | shape_dyn = shape_helpers.shape_list(p)[:2] + [4, 4] 30 | 31 | b = tf.reshape(a, shape_dyn) 32 | with self.session() as sess: 33 | out = sess.run(b, feed_dict={p: np.ones((20, 1, 1, 4, 4))}) 34 | 35 | self.assertAllEqual(out.shape, (20, 1, 4, 4)) 36 | 37 | 38 | if __name__ == '__main__': 39 | tf.disable_v2_behavior() 40 | tf.test.main() 41 | -------------------------------------------------------------------------------- /alphafold/model/tf/shape_placeholders.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Placeholder values for run-time varying dimension sizes.""" 16 | 17 | NUM_RES = 'num residues placeholder' 18 | NUM_MSA_SEQ = 'msa placeholder' 19 | NUM_EXTRA_SEQ = 'extra msa placeholder' 20 | NUM_TEMPLATES = 'num templates placeholder' 21 | -------------------------------------------------------------------------------- /alphafold/model/tf/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Shared utilities for various components.""" 16 | import tensorflow.compat.v1 as tf 17 | 18 | 19 | def tf_combine_mask(*masks): 20 | """Take the intersection of float-valued masks.""" 21 | ret = 1 22 | for m in masks: 23 | ret *= m 24 | return ret 25 | 26 | 27 | class SeedMaker(object): 28 | """Return unique seeds.""" 29 | 30 | def __init__(self, initial_seed=0): 31 | self.next_seed = initial_seed 32 | 33 | def __call__(self): 34 | i = self.next_seed 35 | self.next_seed += 1 36 | return i 37 | 38 | seed_maker = SeedMaker() 39 | 40 | 41 | def make_random_seed(): 42 | return tf.random.uniform([2], 43 | tf.int32.min, 44 | tf.int32.max, 45 | tf.int32, 46 | seed=seed_maker()) 47 | 48 | -------------------------------------------------------------------------------- /alphafold/model/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A collection of JAX utility functions for use in protein folding.""" 16 | 17 | import collections 18 | import numbers 19 | from typing import Mapping 20 | 21 | import haiku as hk 22 | import jax 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | 27 | def final_init(config): 28 | if config.zero_init: 29 | return 'zeros' 30 | else: 31 | return 'linear' 32 | 33 | 34 | def batched_gather(params, indices, axis=0, batch_dims=0): 35 | """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`.""" 36 | take_fn = lambda p, i: jnp.take(p, i, axis=axis) 37 | for _ in range(batch_dims): 38 | take_fn = jax.vmap(take_fn) 39 | return take_fn(params, indices) 40 | 41 | 42 | def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10): 43 | """Masked mean.""" 44 | if drop_mask_channel: 45 | mask = mask[..., 0] 46 | 47 | mask_shape = mask.shape 48 | value_shape = value.shape 49 | 50 | assert len(mask_shape) == len(value_shape) 51 | 52 | if isinstance(axis, numbers.Integral): 53 | axis = [axis] 54 | elif axis is None: 55 | axis = list(range(len(mask_shape))) 56 | assert isinstance(axis, collections.Iterable), ( 57 | 'axis needs to be either an iterable, integer or "None"') 58 | 59 | broadcast_factor = 1. 
60 | for axis_ in axis: 61 | value_size = value_shape[axis_] 62 | mask_size = mask_shape[axis_] 63 | if mask_size == 1: 64 | broadcast_factor *= value_size 65 | else: 66 | assert mask_size == value_size 67 | 68 | return (jnp.sum(mask * value, axis=axis) / 69 | (jnp.sum(mask, axis=axis) * broadcast_factor + eps)) 70 | 71 | 72 | def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params: 73 | """Convert a dictionary of NumPy arrays to Haiku parameters.""" 74 | hk_params = {} 75 | for path, array in params.items(): 76 | scope, name = path.split('//') 77 | if scope not in hk_params: 78 | hk_params[scope] = {} 79 | hk_params[scope][name] = jnp.array(array) 80 | 81 | return hk_params 82 | -------------------------------------------------------------------------------- /alphafold/relax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Amber relaxation.""" 15 | -------------------------------------------------------------------------------- /alphafold/relax/amber_minimize_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for amber_minimize.""" 16 | import os 17 | 18 | from absl.testing import absltest 19 | import numpy as np 20 | 21 | from alphafold.common import protein 22 | from alphafold.relax import amber_minimize 23 | # Internal import (7716). 24 | 25 | 26 | def _load_test_protein(data_path): 27 | pdb_path = os.path.join(absltest.get_default_test_srcdir(), data_path) 28 | with open(pdb_path, 'r') as f: 29 | return protein.from_pdb_string(f.read()) 30 | 31 | 32 | class AmberMinimizeTest(absltest.TestCase): 33 | 34 | def test_multiple_disulfides_target(self): 35 | prot = _load_test_protein( 36 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb' 37 | ) 38 | ret = amber_minimize.run_pipeline(prot, max_iterations=10, max_attempts=1, 39 | stiffness=10.) 
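Returning to model/utils.py above, a small sketch (not part of the repository) of batched_gather and mask_mean:

import jax.numpy as jnp
from alphafold.model import utils

# batched_gather with batch_dims=1 vmaps the take over the leading axis:
params = jnp.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
indices = jnp.array([[2, 0], [1, 1]])
out = utils.batched_gather(params, indices, axis=0, batch_dims=1)
assert out.tolist() == [[2, 0], [4, 4]]

# mask_mean averages only over the unmasked entries:
value = jnp.array([[1., 2., 3., 4.]])
mask = jnp.array([[1., 1., 0., 0.]])
assert abs(float(utils.mask_mean(mask, value)) - 1.5) < 1e-6  # (1 + 2) / 2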
40 | self.assertIn('opt_time', ret) 41 | self.assertIn('min_attempts', ret) 42 | 43 | def test_raises_invalid_protein_assertion(self): 44 | prot = _load_test_protein( 45 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb' 46 | ) 47 | prot.atom_mask[4, :] = 0 48 | with self.assertRaisesRegex( 49 | ValueError, 50 | 'Amber minimization can only be performed on proteins with well-defined' 51 | ' residues. This protein contains at least one residue with no atoms.'): 52 | amber_minimize.run_pipeline(prot, max_iterations=10, 53 | stiffness=1., 54 | max_attempts=1) 55 | 56 | def test_iterative_relax(self): 57 | prot = _load_test_protein( 58 | 'alphafold/relax/testdata/with_violations.pdb' 59 | ) 60 | violations = amber_minimize.get_violation_metrics(prot) 61 | self.assertGreater(violations['num_residue_violations'], 0) 62 | out = amber_minimize.run_pipeline( 63 | prot=prot, max_outer_iterations=10, stiffness=10.) 64 | self.assertLess(out['efinal'], out['einit']) 65 | self.assertEqual(0, out['num_residue_violations']) 66 | 67 | def test_find_violations(self): 68 | prot = _load_test_protein( 69 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb' 70 | ) 71 | viols, _ = amber_minimize.find_violations(prot) 72 | 73 | expected_between_residues_connection_mask = np.zeros((191,), np.float32) 74 | for residue in (42, 43, 59, 60, 135, 136): 75 | expected_between_residues_connection_mask[residue] = 1.0 76 | 77 | expected_clash_indices = np.array([ 78 | [8, 4], 79 | [8, 5], 80 | [13, 3], 81 | [14, 1], 82 | [14, 4], 83 | [26, 4], 84 | [26, 5], 85 | [31, 8], 86 | [31, 10], 87 | [39, 0], 88 | [39, 1], 89 | [39, 2], 90 | [39, 3], 91 | [39, 4], 92 | [42, 5], 93 | [42, 6], 94 | [42, 7], 95 | [42, 8], 96 | [47, 7], 97 | [47, 8], 98 | [47, 9], 99 | [47, 10], 100 | [64, 4], 101 | [85, 5], 102 | [102, 4], 103 | [102, 5], 104 | [109, 13], 105 | [111, 5], 106 | [118, 6], 107 | [118, 7], 108 | [118, 8], 109 | [124, 4], 110 | [124, 5], 111 | [131, 5], 112 | [139, 7], 113 | [147, 4], 114 | [152, 7]], dtype=np.int32) 115 | expected_between_residues_clash_mask = np.zeros([191, 14]) 116 | expected_between_residues_clash_mask[expected_clash_indices[:, 0], 117 | expected_clash_indices[:, 1]] += 1 118 | expected_per_atom_violations = np.zeros([191, 14]) 119 | np.testing.assert_array_equal( 120 | viols['between_residues']['connections_per_residue_violation_mask'], 121 | expected_between_residues_connection_mask) 122 | np.testing.assert_array_equal( 123 | viols['between_residues']['clashes_per_atom_clash_mask'], 124 | expected_between_residues_clash_mask) 125 | np.testing.assert_array_equal( 126 | viols['within_residues']['per_atom_violations'], 127 | expected_per_atom_violations) 128 | 129 | 130 | if __name__ == '__main__': 131 | absltest.main() 132 | -------------------------------------------------------------------------------- /alphafold/relax/cleanup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations. 16 | 17 | fix_pdb uses a third-party tool. We also support fixing some additional edge 18 | cases like removing chains of length one (see clean_structure). 19 | """ 20 | import io 21 | 22 | import pdbfixer 23 | from simtk.openmm import app 24 | from simtk.openmm.app import element 25 | 26 | 27 | def fix_pdb(pdbfile, alterations_info): 28 | """Apply pdbfixer to the contents of a PDB file; return a PDB string result. 29 | 30 | 1) Replaces nonstandard residues. 31 | 2) Removes heterogens (non-protein residues) including water. 32 | 3) Adds missing residues and missing atoms within existing residues. 33 | 4) Adds hydrogens assuming pH=7.0. 34 | 5) KeepIds is currently true, so the fixer must keep the existing chain and 35 | residue identifiers. This will fail for some files in the wider PDB that 36 | have invalid IDs. 37 | 38 | Args: 39 | pdbfile: Input PDB file handle. 40 | alterations_info: A dict that will store details of changes made. 41 | 42 | Returns: 43 | A PDB string representing the fixed structure. 44 | """ 45 | fixer = pdbfixer.PDBFixer(pdbfile=pdbfile) 46 | fixer.findNonstandardResidues() 47 | alterations_info['nonstandard_residues'] = fixer.nonstandardResidues 48 | fixer.replaceNonstandardResidues() 49 | _remove_heterogens(fixer, alterations_info, keep_water=False) 50 | fixer.findMissingResidues() 51 | alterations_info['missing_residues'] = fixer.missingResidues 52 | fixer.findMissingAtoms() 53 | alterations_info['missing_heavy_atoms'] = fixer.missingAtoms 54 | alterations_info['missing_terminals'] = fixer.missingTerminals 55 | fixer.addMissingAtoms(seed=0) 56 | fixer.addMissingHydrogens() 57 | out_handle = io.StringIO() 58 | app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle, 59 | keepIds=True) 60 | return out_handle.getvalue() 61 | 62 | 63 | def clean_structure(pdb_structure, alterations_info): 64 | """Applies additional fixes to an OpenMM structure, to handle edge cases. 65 | 66 | Args: 67 | pdb_structure: An OpenMM structure to modify and fix. 68 | alterations_info: A dict that will store details of changes made. 69 | """ 70 | _replace_met_se(pdb_structure, alterations_info) 71 | _remove_chains_of_length_one(pdb_structure, alterations_info) 72 | 73 | 74 | def _remove_heterogens(fixer, alterations_info, keep_water): 75 | """Removes the residues that Pdbfixer considers to be heterogens. 76 | 77 | Args: 78 | fixer: A Pdbfixer instance. 79 | alterations_info: A dict that will store details of changes made. 80 | keep_water: If True, water (HOH) is not considered to be a heterogen.
81 | """ 82 | initial_resnames = set() 83 | for chain in fixer.topology.chains(): 84 | for residue in chain.residues(): 85 | initial_resnames.add(residue.name) 86 | fixer.removeHeterogens(keepWater=keep_water) 87 | final_resnames = set() 88 | for chain in fixer.topology.chains(): 89 | for residue in chain.residues(): 90 | final_resnames.add(residue.name) 91 | alterations_info['removed_heterogens'] = ( 92 | initial_resnames.difference(final_resnames)) 93 | 94 | 95 | def _replace_met_se(pdb_structure, alterations_info): 96 | """Replace the Se in any MET residues that were not marked as modified.""" 97 | modified_met_residues = [] 98 | for res in pdb_structure.iter_residues(): 99 | name = res.get_name_with_spaces().strip() 100 | if name == 'MET': 101 | s_atom = res.get_atom('SD') 102 | if s_atom.element_symbol == 'Se': 103 | s_atom.element_symbol = 'S' 104 | s_atom.element = element.get_by_symbol('S') 105 | modified_met_residues.append(s_atom.residue_number) 106 | alterations_info['Se_in_MET'] = modified_met_residues 107 | 108 | 109 | def _remove_chains_of_length_one(pdb_structure, alterations_info): 110 | """Removes chains that correspond to a single amino acid. 111 | 112 | A single amino acid in a chain is both N and C terminus. There is no force 113 | template for this case. 114 | 115 | Args: 116 | pdb_structure: An OpenMM pdb_structure to modify and fix. 117 | alterations_info: A dict that will store details of changes made. 118 | """ 119 | removed_chains = {} 120 | for model in pdb_structure.iter_models(): 121 | valid_chains = [c for c in model.iter_chains() if len(c) > 1] 122 | invalid_chain_ids = [c.chain_id for c in model.iter_chains() if len(c) <= 1] 123 | model.chains = valid_chains 124 | for chain_id in invalid_chain_ids: 125 | model.chains_by_id.pop(chain_id) 126 | removed_chains[model.number] = invalid_chain_ids 127 | alterations_info['removed_chains'] = removed_chains 128 | -------------------------------------------------------------------------------- /alphafold/relax/cleanup_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for relax.cleanup.""" 16 | import io 17 | 18 | from absl.testing import absltest 19 | from simtk.openmm.app.internal import pdbstructure 20 | 21 | from alphafold.relax import cleanup 22 | 23 | 24 | def _pdb_to_structure(pdb_str): 25 | handle = io.StringIO(pdb_str) 26 | return pdbstructure.PdbStructure(handle) 27 | 28 | 29 | def _lines_to_structure(pdb_lines): 30 | return _pdb_to_structure('\n'.join(pdb_lines)) 31 | 32 | 33 | class CleanupTest(absltest.TestCase): 34 | 35 | def test_missing_residues(self): 36 | pdb_lines = ['SEQRES 1 C 3 CYS GLY LEU', 37 | 'ATOM 1 N CYS C 1 -12.262 20.115 60.959 1.00 ' 38 | '19.08 N', 39 | 'ATOM 2 CA CYS C 1 -11.065 20.934 60.773 1.00 ' 40 | '17.23 C', 41 | 'ATOM 3 C CYS C 1 -10.002 20.742 61.844 1.00 ' 42 | '15.38 C', 43 | 'ATOM 4 O CYS C 1 -10.284 20.225 62.929 1.00 ' 44 | '16.04 O', 45 | 'ATOM 5 N LEU C 3 -7.688 18.700 62.045 1.00 ' 46 | '14.75 N', 47 | 'ATOM 6 CA LEU C 3 -7.256 17.320 62.234 1.00 ' 48 | '16.81 C', 49 | 'ATOM 7 C LEU C 3 -6.380 16.864 61.070 1.00 ' 50 | '16.95 C', 51 | 'ATOM 8 O LEU C 3 -6.551 17.332 59.947 1.00 ' 52 | '16.97 O'] 53 | input_handle = io.StringIO('\n'.join(pdb_lines)) 54 | alterations = {} 55 | result = cleanup.fix_pdb(input_handle, alterations) 56 | structure = _pdb_to_structure(result) 57 | residue_names = [r.get_name() for r in structure.iter_residues()] 58 | self.assertCountEqual(residue_names, ['CYS', 'GLY', 'LEU']) 59 | self.assertCountEqual(alterations['missing_residues'].values(), [['GLY']]) 60 | 61 | def test_missing_atoms(self): 62 | pdb_lines = ['SEQRES 1 A 1 PRO', 63 | 'ATOM 1 CA PRO A 1 1.000 1.000 1.000 1.00 ' 64 | ' 0.00 C'] 65 | input_handle = io.StringIO('\n'.join(pdb_lines)) 66 | alterations = {} 67 | result = cleanup.fix_pdb(input_handle, alterations) 68 | structure = _pdb_to_structure(result) 69 | atom_names = [a.get_name() for a in structure.iter_atoms()] 70 | self.assertCountEqual(atom_names, ['N', 'CD', 'HD2', 'HD3', 'CG', 'HG2', 71 | 'HG3', 'CB', 'HB2', 'HB3', 'CA', 'HA', 72 | 'C', 'O', 'H2', 'H3', 'OXT']) 73 | missing_atoms_by_residue = list(alterations['missing_heavy_atoms'].values()) 74 | self.assertLen(missing_atoms_by_residue, 1) 75 | atoms_added = [a.name for a in missing_atoms_by_residue[0]] 76 | self.assertCountEqual(atoms_added, ['N', 'CD', 'CG', 'CB', 'C', 'O']) 77 | missing_terminals_by_residue = alterations['missing_terminals'] 78 | self.assertLen(missing_terminals_by_residue, 1) 79 | has_missing_terminal = [r.name for r in missing_terminals_by_residue.keys()] 80 | self.assertCountEqual(has_missing_terminal, ['PRO']) 81 | self.assertCountEqual([t for t in missing_terminals_by_residue.values()], 82 | [['OXT']]) 83 | 84 | def test_remove_heterogens(self): 85 | pdb_lines = ['SEQRES 1 A 1 GLY', 86 | 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 ' 87 | ' 0.00 C', 88 | 'ATOM 2 O HOH A 2 0.000 0.000 0.000 1.00 ' 89 | ' 0.00 O'] 90 | input_handle = io.StringIO('\n'.join(pdb_lines)) 91 | alterations = {} 92 | result = cleanup.fix_pdb(input_handle, alterations) 93 | structure = _pdb_to_structure(result) 94 | self.assertCountEqual([res.get_name() for res in structure.iter_residues()], 95 | ['GLY']) 96 | self.assertEqual(alterations['removed_heterogens'], set(['HOH'])) 97 | 98 | def test_fix_nonstandard_residues(self): 99 | pdb_lines = ['SEQRES 1 A 1 DAL', 100 | 'ATOM 1 CA DAL A 1 0.000 0.000 0.000 1.00 ' 101 | ' 0.00 C'] 102 | input_handle = io.StringIO('\n'.join(pdb_lines)) 103 | alterations = {} 104 | result = cleanup.fix_pdb(input_handle, alterations) 105 | structure = 
_pdb_to_structure(result) 106 | residue_names = [res.get_name() for res in structure.iter_residues()] 107 | self.assertCountEqual(residue_names, ['ALA']) 108 | self.assertLen(alterations['nonstandard_residues'], 1) 109 | original_res, new_name = alterations['nonstandard_residues'][0] 110 | self.assertEqual(original_res.id, '1') 111 | self.assertEqual(new_name, 'ALA') 112 | 113 | def test_replace_met_se(self): 114 | pdb_lines = ['SEQRES 1 A 1 MET', 115 | 'ATOM 1 SD MET A 1 0.000 0.000 0.000 1.00 ' 116 | ' 0.00 Se'] 117 | structure = _lines_to_structure(pdb_lines) 118 | alterations = {} 119 | cleanup._replace_met_se(structure, alterations) 120 | sd = [a for a in structure.iter_atoms() if a.get_name() == 'SD'] 121 | self.assertLen(sd, 1) 122 | self.assertEqual(sd[0].element_symbol, 'S') 123 | self.assertCountEqual(alterations['Se_in_MET'], [sd[0].residue_number]) 124 | 125 | def test_remove_chains_of_length_one(self): 126 | pdb_lines = ['SEQRES 1 A 1 GLY', 127 | 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 ' 128 | ' 0.00 C'] 129 | structure = _lines_to_structure(pdb_lines) 130 | alterations = {} 131 | cleanup._remove_chains_of_length_one(structure, alterations) 132 | chains = list(structure.iter_chains()) 133 | self.assertEmpty(chains) 134 | self.assertCountEqual(alterations['removed_chains'].values(), [['A']]) 135 | 136 | 137 | if __name__ == '__main__': 138 | absltest.main() 139 | -------------------------------------------------------------------------------- /alphafold/relax/relax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Amber relaxation.""" 16 | from typing import Any, Dict, Sequence, Tuple 17 | 18 | import numpy as np 19 | 20 | from alphafold.common import protein 21 | from alphafold.relax import amber_minimize 22 | from alphafold.relax import utils 23 | 24 | 25 | class AmberRelaxation(object): 26 | """Amber relaxation.""" 27 | 28 | def __init__(self, 29 | *, 30 | max_iterations: int, 31 | tolerance: float, 32 | stiffness: float, 33 | exclude_residues: Sequence[int], 34 | max_outer_iterations: int): 35 | """Initialize Amber Relaxer. 36 | 37 | Args: 38 | max_iterations: Maximum number of L-BFGS iterations. 0 means no max. 39 | tolerance: kcal/mol, the energy tolerance of L-BFGS. 40 | stiffness: kcal/mol A**2, spring constant of heavy atom restraining 41 | potential. 42 | exclude_residues: Residues to exclude from per-atom restraining. 43 | Zero-indexed. 44 | max_outer_iterations: Maximum number of violation-informed relax 45 | iterations. A value of 1 will run the non-iterative procedure used in 46 | CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes 47 | as soon as there are no violations, hence in most cases this causes no 48 | slowdown. In the worst case we do 20 outer iterations. 
49 | """ 50 | 51 | self._max_iterations = max_iterations 52 | self._tolerance = tolerance 53 | self._stiffness = stiffness 54 | self._exclude_residues = exclude_residues 55 | self._max_outer_iterations = max_outer_iterations 56 | 57 | def process(self, *, 58 | prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]: 59 | """Runs Amber relax on a prediction, adds hydrogens, returns PDB string.""" 60 | out = amber_minimize.run_pipeline( 61 | prot=prot, max_iterations=self._max_iterations, 62 | tolerance=self._tolerance, stiffness=self._stiffness, 63 | exclude_residues=self._exclude_residues, 64 | max_outer_iterations=self._max_outer_iterations) 65 | min_pos = out['pos'] 66 | start_pos = out['posinit'] 67 | rmsd = np.sqrt(np.sum((start_pos - min_pos)**2) / start_pos.shape[0]) 68 | debug_data = { 69 | 'initial_energy': out['einit'], 70 | 'final_energy': out['efinal'], 71 | 'attempts': out['min_attempts'], 72 | 'rmsd': rmsd 73 | } 74 | pdb_str = amber_minimize.clean_protein(prot) 75 | min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos) 76 | min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors) 77 | utils.assert_equal_nonterminal_atom_types( 78 | protein.from_pdb_string(min_pdb).atom_mask, 79 | prot.atom_mask) 80 | violations = out['structural_violations'][ 81 | 'total_per_residue_violations_mask'] 82 | return min_pdb, debug_data, violations 83 | 84 | -------------------------------------------------------------------------------- /alphafold/relax/relax_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for relax.""" 16 | import os 17 | 18 | from absl.testing import absltest 19 | import numpy as np 20 | from alphafold.common import protein 21 | from alphafold.relax import relax 22 | # Internal import (7716). 23 | 24 | 25 | class RunAmberRelaxTest(absltest.TestCase): 26 | 27 | def setUp(self): 28 | super().setUp() 29 | self.test_dir = os.path.join( 30 | absltest.get_default_test_srcdir(), 31 | 'alphafold/relax/testdata/') 32 | self.test_config = { 33 | 'max_iterations': 1, 34 | 'tolerance': 2.39, 35 | 'stiffness': 10.0, 36 | 'exclude_residues': [], 37 | 'max_outer_iterations': 1} 38 | 39 | def test_process(self): 40 | amber_relax = relax.AmberRelaxation(**self.test_config) 41 | 42 | with open(os.path.join(self.test_dir, 'model_output.pdb')) as f: 43 | test_prot = protein.from_pdb_string(f.read()) 44 | pdb_min, debug_info, num_violations = amber_relax.process(prot=test_prot) 45 | 46 | self.assertCountEqual(debug_info.keys(), 47 | set({'initial_energy', 'final_energy', 48 | 'attempts', 'rmsd'})) 49 | self.assertLess(debug_info['final_energy'], debug_info['initial_energy']) 50 | self.assertGreater(debug_info['rmsd'], 0) 51 | 52 | prot_min = protein.from_pdb_string(pdb_min) 53 | # Most protein properties should be unchanged. 
54 | np.testing.assert_almost_equal(test_prot.aatype, prot_min.aatype) 55 | np.testing.assert_almost_equal(test_prot.residue_index, 56 | prot_min.residue_index) 57 | # Atom mask and bfactors identical except for terminal OXT of last residue. 58 | np.testing.assert_almost_equal(test_prot.atom_mask[:-1, :], 59 | prot_min.atom_mask[:-1, :]) 60 | np.testing.assert_almost_equal(test_prot.b_factors[:-1, :], 61 | prot_min.b_factors[:-1, :]) 62 | np.testing.assert_almost_equal(test_prot.atom_mask[:, :-1], 63 | prot_min.atom_mask[:, :-1]) 64 | np.testing.assert_almost_equal(test_prot.b_factors[:, :-1], 65 | prot_min.b_factors[:, :-1]) 66 | # There are no residues with violations. 67 | np.testing.assert_equal(num_violations, np.zeros_like(num_violations)) 68 | 69 | def test_unresolved_violations(self): 70 | amber_relax = relax.AmberRelaxation(**self.test_config) 71 | with open(os.path.join(self.test_dir, 72 | 'with_violations_casp14.pdb')) as f: 73 | test_prot = protein.from_pdb_string(f.read()) 74 | _, _, num_violations = amber_relax.process(prot=test_prot) 75 | exp_num_violations = np.array( 76 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 77 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 78 | 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 79 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 80 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81 | 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 82 | 0, 0, 0, 0]) 83 | # Check no violations were added. Can't check exactly due to stochasticity. 84 | self.assertTrue(np.all(num_violations <= exp_num_violations)) 85 | 86 | 87 | if __name__ == '__main__': 88 | absltest.main() 89 | -------------------------------------------------------------------------------- /alphafold/relax/testdata/model_output.pdb: -------------------------------------------------------------------------------- 1 | ATOM 1 C MET A 1 1.921 -46.152 7.786 1.00 4.39 C 2 | ATOM 2 CA MET A 1 1.631 -46.829 9.131 1.00 4.39 C 3 | ATOM 3 CB MET A 1 2.759 -47.768 9.578 1.00 4.39 C 4 | ATOM 4 CE MET A 1 3.466 -49.770 13.198 1.00 4.39 C 5 | ATOM 5 CG MET A 1 2.581 -48.221 11.034 1.00 4.39 C 6 | ATOM 6 H MET A 1 0.234 -48.249 8.549 1.00 4.39 H 7 | ATOM 7 H2 MET A 1 -0.424 -46.789 8.952 1.00 4.39 H 8 | ATOM 8 H3 MET A 1 0.111 -47.796 10.118 1.00 4.39 H 9 | ATOM 9 HA MET A 1 1.628 -46.009 9.849 1.00 4.39 H 10 | ATOM 10 HB2 MET A 1 3.701 -47.225 9.500 1.00 4.39 H 11 | ATOM 11 HB3 MET A 1 2.807 -48.640 8.926 1.00 4.39 H 12 | ATOM 12 HE1 MET A 1 2.747 -50.537 12.910 1.00 4.39 H 13 | ATOM 13 HE2 MET A 1 4.296 -50.241 13.725 1.00 4.39 H 14 | ATOM 14 HE3 MET A 1 2.988 -49.052 13.864 1.00 4.39 H 15 | ATOM 15 HG2 MET A 1 1.791 -48.971 11.083 1.00 4.39 H 16 | ATOM 16 HG3 MET A 1 2.295 -47.368 11.650 1.00 4.39 H 17 | ATOM 17 N MET A 1 0.291 -47.464 9.182 1.00 4.39 N 18 | ATOM 18 O MET A 1 2.091 -44.945 7.799 1.00 4.39 O 19 | ATOM 19 SD MET A 1 4.096 -48.921 11.725 1.00 4.39 S 20 | ATOM 20 C LYS A 2 1.366 -45.033 4.898 1.00 2.92 C 21 | ATOM 21 CA LYS A 2 2.235 -46.242 5.308 1.00 2.92 C 22 | ATOM 22 CB LYS A 2 2.206 -47.314 4.196 1.00 2.92 C 23 | ATOM 23 CD LYS A 2 3.331 -49.342 3.134 1.00 2.92 C 24 | ATOM 24 CE LYS A 2 4.434 -50.403 3.293 1.00 2.92 C 25 | ATOM 25 CG LYS A 2 3.294 -48.395 4.349 1.00 2.92 C 26 | ATOM 26 H LYS A 2 1.832 -47.853 6.656 1.00 2.92 H 27 | ATOM 27 HA LYS A 2 3.248 -45.841 5.355 1.00 2.92 H 28 | ATOM 28 HB2 LYS A 2 1.223 -47.785 4.167 
1.00 2.92 H 29 | ATOM 29 HB3 LYS A 2 2.363 -46.812 3.241 1.00 2.92 H 30 | ATOM 30 HD2 LYS A 2 3.524 -48.754 2.237 1.00 2.92 H 31 | ATOM 31 HD3 LYS A 2 2.364 -49.833 3.031 1.00 2.92 H 32 | ATOM 32 HE2 LYS A 2 5.383 -49.891 3.455 1.00 2.92 H 33 | ATOM 33 HE3 LYS A 2 4.225 -51.000 4.180 1.00 2.92 H 34 | ATOM 34 HG2 LYS A 2 3.102 -48.977 5.250 1.00 2.92 H 35 | ATOM 35 HG3 LYS A 2 4.264 -47.909 4.446 1.00 2.92 H 36 | ATOM 36 HZ1 LYS A 2 4.763 -50.747 1.274 1.00 2.92 H 37 | ATOM 37 HZ2 LYS A 2 3.681 -51.785 1.931 1.00 2.92 H 38 | ATOM 38 HZ3 LYS A 2 5.280 -51.965 2.224 1.00 2.92 H 39 | ATOM 39 N LYS A 2 1.907 -46.846 6.629 1.00 2.92 N 40 | ATOM 40 NZ LYS A 2 4.542 -51.286 2.100 1.00 2.92 N 41 | ATOM 41 O LYS A 2 1.882 -44.093 4.312 1.00 2.92 O 42 | ATOM 42 C PHE A 3 -0.511 -42.597 5.624 1.00 4.39 C 43 | ATOM 43 CA PHE A 3 -0.853 -43.933 4.929 1.00 4.39 C 44 | ATOM 44 CB PHE A 3 -2.271 -44.408 5.285 1.00 4.39 C 45 | ATOM 45 CD1 PHE A 3 -3.760 -43.542 3.432 1.00 4.39 C 46 | ATOM 46 CD2 PHE A 3 -4.050 -42.638 5.675 1.00 4.39 C 47 | ATOM 47 CE1 PHE A 3 -4.797 -42.715 2.965 1.00 4.39 C 48 | ATOM 48 CE2 PHE A 3 -5.091 -41.818 5.207 1.00 4.39 C 49 | ATOM 49 CG PHE A 3 -3.382 -43.505 4.788 1.00 4.39 C 50 | ATOM 50 CZ PHE A 3 -5.463 -41.853 3.853 1.00 4.39 C 51 | ATOM 51 H PHE A 3 -0.311 -45.868 5.655 1.00 4.39 H 52 | ATOM 52 HA PHE A 3 -0.817 -43.746 3.856 1.00 4.39 H 53 | ATOM 53 HB2 PHE A 3 -2.353 -44.512 6.367 1.00 4.39 H 54 | ATOM 54 HB3 PHE A 3 -2.432 -45.393 4.848 1.00 4.39 H 55 | ATOM 55 HD1 PHE A 3 -3.255 -44.198 2.739 1.00 4.39 H 56 | ATOM 56 HD2 PHE A 3 -3.768 -42.590 6.716 1.00 4.39 H 57 | ATOM 57 HE1 PHE A 3 -5.083 -42.735 1.923 1.00 4.39 H 58 | ATOM 58 HE2 PHE A 3 -5.604 -41.151 5.885 1.00 4.39 H 59 | ATOM 59 HZ PHE A 3 -6.257 -41.215 3.493 1.00 4.39 H 60 | ATOM 60 N PHE A 3 0.079 -45.027 5.253 1.00 4.39 N 61 | ATOM 61 O PHE A 3 -0.633 -41.541 5.014 1.00 4.39 O 62 | ATOM 62 C LEU A 4 1.598 -40.732 7.042 1.00 4.39 C 63 | ATOM 63 CA LEU A 4 0.367 -41.437 7.633 1.00 4.39 C 64 | ATOM 64 CB LEU A 4 0.628 -41.823 9.104 1.00 4.39 C 65 | ATOM 65 CD1 LEU A 4 -0.319 -42.778 11.228 1.00 4.39 C 66 | ATOM 66 CD2 LEU A 4 -1.300 -40.694 10.309 1.00 4.39 C 67 | ATOM 67 CG LEU A 4 -0.650 -42.027 9.937 1.00 4.39 C 68 | ATOM 68 H LEU A 4 0.163 -43.538 7.292 1.00 4.39 H 69 | ATOM 69 HA LEU A 4 -0.445 -40.712 7.588 1.00 4.39 H 70 | ATOM 70 HB2 LEU A 4 1.213 -41.034 9.576 1.00 4.39 H 71 | ATOM 71 HB3 LEU A 4 1.235 -42.728 9.127 1.00 4.39 H 72 | ATOM 72 HD11 LEU A 4 0.380 -42.191 11.824 1.00 4.39 H 73 | ATOM 73 HD12 LEU A 4 0.127 -43.747 11.002 1.00 4.39 H 74 | ATOM 74 HD13 LEU A 4 -1.230 -42.927 11.808 1.00 4.39 H 75 | ATOM 75 HD21 LEU A 4 -0.606 -40.080 10.883 1.00 4.39 H 76 | ATOM 76 HD22 LEU A 4 -2.193 -40.869 10.909 1.00 4.39 H 77 | ATOM 77 HD23 LEU A 4 -1.593 -40.147 9.413 1.00 4.39 H 78 | ATOM 78 HG LEU A 4 -1.359 -42.630 9.370 1.00 4.39 H 79 | ATOM 79 N LEU A 4 -0.012 -42.638 6.869 1.00 4.39 N 80 | ATOM 80 O LEU A 4 1.655 -39.508 7.028 1.00 4.39 O 81 | ATOM 81 C VAL A 5 3.372 -40.190 4.573 1.00 4.39 C 82 | ATOM 82 CA VAL A 5 3.752 -40.956 5.845 1.00 4.39 C 83 | ATOM 83 CB VAL A 5 4.757 -42.083 5.528 1.00 4.39 C 84 | ATOM 84 CG1 VAL A 5 6.019 -41.568 4.827 1.00 4.39 C 85 | ATOM 85 CG2 VAL A 5 5.199 -42.807 6.810 1.00 4.39 C 86 | ATOM 86 H VAL A 5 2.440 -42.503 6.548 1.00 4.39 H 87 | ATOM 87 HA VAL A 5 4.234 -40.242 6.512 1.00 4.39 H 88 | ATOM 88 HB VAL A 5 4.279 -42.813 4.875 1.00 4.39 H 89 | ATOM 89 HG11 VAL A 5 6.494 -40.795 5.431 1.00 4.39 H 90 | ATOM 90 HG12 VAL A 5 5.770 -41.145 3.853 1.00 4.39 H 91 | 
ATOM 91 HG13 VAL A 5 6.725 -42.383 4.670 1.00 4.39 H 92 | ATOM 92 HG21 VAL A 5 4.347 -43.283 7.297 1.00 4.39 H 93 | ATOM 93 HG22 VAL A 5 5.933 -43.575 6.568 1.00 4.39 H 94 | ATOM 94 HG23 VAL A 5 5.651 -42.093 7.498 1.00 4.39 H 95 | ATOM 95 N VAL A 5 2.554 -41.501 6.509 1.00 4.39 N 96 | ATOM 96 O VAL A 5 3.937 -39.138 4.297 1.00 4.39 O 97 | TER 96 VAL A 5 98 | END 99 | -------------------------------------------------------------------------------- /alphafold/relax/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utils for minimization.""" 16 | import io 17 | 18 | from Bio import PDB 19 | import numpy as np 20 | from simtk.openmm import app as openmm_app 21 | from simtk.openmm.app.internal.pdbstructure import PdbStructure 22 | 23 | from alphafold.common import residue_constants 24 | 25 | 26 | def overwrite_pdb_coordinates(pdb_str: str, pos) -> str: 27 | pdb_file = io.StringIO(pdb_str) 28 | structure = PdbStructure(pdb_file) 29 | topology = openmm_app.PDBFile(structure).getTopology() 30 | with io.StringIO() as f: 31 | openmm_app.PDBFile.writeFile(topology, pos, f) 32 | return f.getvalue() 33 | 34 | 35 | def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str: 36 | """Overwrites the B-factors in pdb_str with contents of bfactors array. 37 | 38 | Args: 39 | pdb_str: An input PDB string. 40 | bfactors: A numpy array with shape [n_residues, 37]. We assume that the 41 | B-factors are per residue; i.e. that the nonzero entries are identical in 42 | [i, :]. 43 | 44 | Returns: 45 | A new PDB string with the B-factors replaced. 46 | """ 47 | if bfactors.shape[-1] != residue_constants.atom_type_num: 48 | raise ValueError( 49 | f'Invalid final dimension size for bfactors: {bfactors.shape[-1]}.') 50 | 51 | parser = PDB.PDBParser(QUIET=True) 52 | handle = io.StringIO(pdb_str) 53 | structure = parser.get_structure('', handle) 54 | 55 | curr_resid = ('', '', '') 56 | idx = -1 57 | for atom in structure.get_atoms(): 58 | atom_resid = atom.parent.get_id() 59 | if atom_resid != curr_resid: 60 | idx += 1 61 | if idx >= bfactors.shape[0]: 62 | raise ValueError('Index into bfactors exceeds number of residues. ' 63 | f'B-factors shape: {bfactors.shape}, idx: {idx}.') 64 | curr_resid = atom_resid 65 | atom.bfactor = bfactors[idx, residue_constants.atom_order['CA']] 66 | 67 | new_pdb = io.StringIO() 68 | pdb_io = PDB.PDBIO() 69 | pdb_io.set_structure(structure) 70 | pdb_io.save(new_pdb) 71 | return new_pdb.getvalue() 72 | 73 | 74 | def assert_equal_nonterminal_atom_types( 75 | atom_mask: np.ndarray, ref_atom_mask: np.ndarray): 76 | """Checks that pre- and post-minimized proteins have same atom set.""" 77 | # Ignore any terminal OXT atoms which may have been added by minimization.
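# OXT is the extra carboxyl oxygen that minimization may place on the
# C-terminal residue, which is why it is masked out before the two masks
# are compared below.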
78 | oxt = residue_constants.atom_order['OXT'] 79 | no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool) 80 | no_oxt_mask[..., oxt] = False 81 | np.testing.assert_almost_equal(ref_atom_mask[no_oxt_mask], 82 | atom_mask[no_oxt_mask]) 83 | 84 | -------------------------------------------------------------------------------- /alphafold/relax/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for utils.""" 16 | 17 | import os 18 | 19 | from absl.testing import absltest 20 | import numpy as np 21 | 22 | from alphafold.common import protein 23 | from alphafold.relax import utils 24 | # Internal import (7716). 25 | 26 | 27 | class UtilsTest(absltest.TestCase): 28 | 29 | def test_overwrite_b_factors(self): 30 | testdir = os.path.join( 31 | absltest.get_default_test_srcdir(), 32 | 'alphafold/relax/testdata/' 33 | 'multiple_disulfides_target.pdb') 34 | with open(testdir) as f: 35 | test_pdb = f.read() 36 | n_residues = 191 37 | bfactors = np.stack([np.arange(0, n_residues)] * 37, axis=-1) 38 | 39 | output_pdb = utils.overwrite_b_factors(test_pdb, bfactors) 40 | 41 | # Check that the atom lines are unchanged apart from the B-factors. 42 | atom_lines_original = [l for l in test_pdb.split('\n') if l[:4] == 'ATOM'] 43 | atom_lines_new = [l for l in output_pdb.split('\n') if l[:4] == 'ATOM'] 44 | for line_original, line_new in zip(atom_lines_original, atom_lines_new): 45 | self.assertEqual(line_original[:60].strip(), line_new[:60].strip()) 46 | self.assertEqual(line_original[66:].strip(), line_new[66:].strip()) 47 | 48 | # Check B-factors are correctly set for all atoms present. 49 | as_protein = protein.from_pdb_string(output_pdb) 50 | np.testing.assert_almost_equal( 51 | np.where(as_protein.atom_mask > 0, as_protein.b_factors, 0), 52 | np.where(as_protein.atom_mask > 0, bfactors, 0)) 53 | 54 | 55 | if __name__ == '__main__': 56 | absltest.main() 57 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
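# Overview: this image compiles HH-suite v3.3.0 from source, installs OpenMM
# 7.5.1 plus a matching cudatoolkit via conda, applies docker/openmm.patch, and
# wraps run_alphafold.py in a small entrypoint script (all defined below).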
14 | 15 | ARG CUDA=11.0 16 | FROM nvidia/cuda:${CUDA}-base 17 | # FROM directive resets ARGS, so we specify again (the value is retained if 18 | # previously set). 19 | ARG CUDA 20 | 21 | # Use bash to support string substitution. 22 | SHELL ["/bin/bash", "-c"] 23 | 24 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ 25 | build-essential \ 26 | cmake \ 27 | cuda-command-line-tools-${CUDA/./-} \ 28 | git \ 29 | hmmer \ 30 | kalign \ 31 | tzdata \ 32 | wget \ 33 | && rm -rf /var/lib/apt/lists/* 34 | 35 | # Compile HHsuite from source. 36 | RUN git clone --branch v3.3.0 https://github.com/soedinglab/hh-suite.git /tmp/hh-suite \ 37 | && mkdir /tmp/hh-suite/build 38 | WORKDIR /tmp/hh-suite/build 39 | RUN cmake -DCMAKE_INSTALL_PREFIX=/opt/hhsuite .. \ 40 | && make -j 4 && make install \ 41 | && ln -s /opt/hhsuite/bin/* /usr/bin \ 42 | && rm -rf /tmp/hh-suite 43 | 44 | # Install Miniconda package manager. 45 | RUN wget -q -P /tmp \ 46 | https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ 47 | && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \ 48 | && rm /tmp/Miniconda3-latest-Linux-x86_64.sh 49 | 50 | # Install conda packages. 51 | ENV PATH="/opt/conda/bin:$PATH" 52 | RUN conda update -qy conda \ 53 | && conda install -y -c conda-forge \ 54 | openmm=7.5.1 \ 55 | cudatoolkit==${CUDA}.3 \ 56 | pdbfixer \ 57 | pip \ 58 | python=3.7 59 | 60 | COPY . /app/alphafold 61 | RUN wget -q -P /app/alphafold/alphafold/common/ \ 62 | https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt 63 | 64 | # Install pip packages. 65 | RUN pip3 install --upgrade pip \ 66 | && pip3 install -r /app/alphafold/requirements.txt \ 67 | && pip3 install --upgrade jax jaxlib==0.1.69+cuda${CUDA/./} -f \ 68 | https://storage.googleapis.com/jax-releases/jax_releases.html 69 | 70 | # Apply OpenMM patch. 71 | WORKDIR /opt/conda/lib/python3.7/site-packages 72 | RUN patch -p0 < /app/alphafold/docker/openmm.patch 73 | 74 | # We need to run `ldconfig` first to ensure GPUs are visible, due to some quirk 75 | # with Debian. See https://github.com/NVIDIA/nvidia-docker/issues/1399 for 76 | # details. 77 | # ENTRYPOINT does not support easily running multiple commands, so instead we 78 | # write a shell script to wrap them up. 79 | WORKDIR /app/alphafold 80 | RUN echo $'#!/bin/bash\n\ 81 | ldconfig\n\ 82 | python /app/alphafold/run_alphafold.py "$@"' > /app/run_alphafold.sh \ 83 | && chmod +x /app/run_alphafold.sh 84 | ENTRYPOINT ["/app/run_alphafold.sh"] 85 | -------------------------------------------------------------------------------- /docker/openmm.patch: -------------------------------------------------------------------------------- 1 | Index: simtk/openmm/app/topology.py 2 | =================================================================== 3 | --- simtk.orig/openmm/app/topology.py 4 | +++ simtk/openmm/app/topology.py 5 | @@ -356,19 +356,35 @@ 6 | def isCyx(res): 7 | names = [atom.name for atom in res._atoms] 8 | return 'SG' in names and 'HG' not in names 9 | + # This function is used to prevent multiple di-sulfide bonds from being 10 | + # assigned to a given atom. This is a DeepMind modification.
11 | + def isDisulfideBonded(atom): 12 | + for b in self._bonds: 13 | + if (atom in b and b[0].name == 'SG' and 14 | + b[1].name == 'SG'): 15 | + return True 16 | + 17 | + return False 18 | 19 | cyx = [res for res in self.residues() if res.name == 'CYS' and isCyx(res)] 20 | atomNames = [[atom.name for atom in res._atoms] for res in cyx] 21 | for i in range(len(cyx)): 22 | sg1 = cyx[i]._atoms[atomNames[i].index('SG')] 23 | pos1 = positions[sg1.index] 24 | + candidate_distance, candidate_atom = 0.3*nanometers, None 25 | for j in range(i): 26 | sg2 = cyx[j]._atoms[atomNames[j].index('SG')] 27 | pos2 = positions[sg2.index] 28 | delta = [x-y for (x,y) in zip(pos1, pos2)] 29 | distance = sqrt(delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2]) 30 | - if distance < 0.3*nanometers: 31 | - self.addBond(sg1, sg2) 32 | + if distance < candidate_distance and not isDisulfideBonded(sg2): 33 | + candidate_distance = distance 34 | + candidate_atom = sg2 35 | + # Assign bond to closest pair. 36 | + if candidate_atom: 37 | + self.addBond(sg1, candidate_atom) 38 | + 39 | + 40 | 41 | class Chain(object): 42 | """A Chain object represents a chain within a Topology.""" 43 | -------------------------------------------------------------------------------- /docker/requirements.txt: -------------------------------------------------------------------------------- 1 | # Dependencies necessary to execute run_docker.py 2 | absl-py==0.13.0 3 | docker==5.0.0 4 | -------------------------------------------------------------------------------- /docker/run_docker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Docker launch script for Alphafold docker image.""" 16 | 17 | import os 18 | import signal 19 | from typing import Tuple 20 | 21 | from absl import app 22 | from absl import flags 23 | from absl import logging 24 | import docker 25 | from docker import types 26 | 27 | 28 | #### USER CONFIGURATION #### 29 | 30 | # Set to target of scripts/download_all_data.sh 31 | DOWNLOAD_DIR = 'SET ME' 32 | 33 | # Name of the AlphaFold Docker image. 34 | docker_image_name = 'alphafold' 35 | 36 | # Path to a directory that will store the results. 37 | output_dir = '/tmp/alphafold' 38 | 39 | # Names of models to use. 40 | model_names = [ 41 | 'model_1', 42 | 'model_2', 43 | 'model_3', 44 | 'model_4', 45 | 'model_5', 46 | ] 47 | 48 | # You can individually override the following paths if you have placed the 49 | # data in locations other than the DOWNLOAD_DIR. 50 | 51 | # Path to directory of supporting data, contains 'params' dir. 52 | data_dir = DOWNLOAD_DIR 53 | 54 | # Path to the Uniref90 database for use by JackHMMER. 55 | uniref90_database_path = os.path.join( 56 | DOWNLOAD_DIR, 'uniref90', 'uniref90.fasta') 57 | 58 | # Path to the MGnify database for use by JackHMMER.
59 | mgnify_database_path = os.path.join( 60 | DOWNLOAD_DIR, 'mgnify', 'mgy_clusters_2018_08.fa') 61 | 62 | # Path to the BFD database for use by HHblits. 63 | bfd_database_path = os.path.join( 64 | DOWNLOAD_DIR, 'bfd', 65 | 'bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt') 66 | 67 | # Path to the Small BFD database for use by JackHMMER. 68 | small_bfd_database_path = os.path.join( 69 | DOWNLOAD_DIR, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta') 70 | 71 | # Path to the Uniclust30 database for use by HHblits. 72 | uniclust30_database_path = os.path.join( 73 | DOWNLOAD_DIR, 'uniclust30', 'uniclust30_2018_08', 'uniclust30_2018_08') 74 | 75 | # Path to the PDB70 database for use by HHsearch. 76 | pdb70_database_path = os.path.join(DOWNLOAD_DIR, 'pdb70', 'pdb70') 77 | 78 | # Path to a directory with template mmCIF structures, each named <pdb_id>.cif. 79 | template_mmcif_dir = os.path.join(DOWNLOAD_DIR, 'pdb_mmcif', 'mmcif_files') 80 | 81 | # Path to a file mapping obsolete PDB IDs to their replacements. 82 | obsolete_pdbs_path = os.path.join(DOWNLOAD_DIR, 'pdb_mmcif', 'obsolete.dat') 83 | 84 | #### END OF USER CONFIGURATION #### 85 | 86 | 87 | flags.DEFINE_bool('use_gpu', True, 'Enable NVIDIA runtime to run with GPUs.') 88 | flags.DEFINE_string('gpu_devices', 'all', 'Comma separated list of devices to ' 89 | 'pass to NVIDIA_VISIBLE_DEVICES.') 90 | flags.DEFINE_list('fasta_paths', None, 'Paths to FASTA files, each containing ' 91 | 'one sequence. Paths should be separated by commas. ' 92 | 'All FASTA paths must have a unique basename as the ' 93 | 'basename is used to name the output directories for ' 94 | 'each prediction.') 95 | flags.DEFINE_string('max_template_date', None, 'Maximum template release date ' 96 | 'to consider (ISO-8601 format - i.e. YYYY-MM-DD). ' 97 | 'Important if folding historical test sets.') 98 | flags.DEFINE_enum('preset', 'full_dbs', 99 | ['reduced_dbs', 'full_dbs', 'casp14'], 100 | 'Choose preset model configuration - no ensembling and ' 101 | 'smaller genetic database config (reduced_dbs), no ' 102 | 'ensembling and full genetic database config (full_dbs) or ' 103 | 'full genetic database config and 8 model ensemblings ' 104 | '(casp14).') 105 | flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations ' 106 | 'to obtain a timing that excludes the compilation time, ' 107 | 'which should be more indicative of the time required for ' 108 | 'inferencing many proteins.') 109 | 110 | FLAGS = flags.FLAGS 111 | 112 | _ROOT_MOUNT_DIRECTORY = '/mnt/' 113 | 114 | 115 | def _create_mount(mount_name: str, path: str) -> Tuple[types.Mount, str]: 116 | path = os.path.abspath(path) 117 | source_path = os.path.dirname(path) 118 | target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, mount_name) 119 | logging.info('Mounting %s -> %s', source_path, target_path) 120 | mount = types.Mount(target_path, source_path, type='bind', read_only=True) 121 | return mount, os.path.join(target_path, os.path.basename(path)) 122 | 123 | 124 | def main(argv): 125 | if len(argv) > 1: 126 | raise app.UsageError('Too many command-line arguments.') 127 | 128 | mounts = [] 129 | command_args = [] 130 | 131 | # Mount each fasta path as a unique target directory.
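# For a hypothetical input /data/targets/T1050.fasta, _create_mount binds
# /data/targets read-only at /mnt/fasta_path_0 and rewrites the flag value to
# /mnt/fasta_path_0/T1050.fasta as seen inside the container.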
132 | target_fasta_paths = [] 133 | for i, fasta_path in enumerate(FLAGS.fasta_paths): 134 | mount, target_path = _create_mount(f'fasta_path_{i}', fasta_path) 135 | mounts.append(mount) 136 | target_fasta_paths.append(target_path) 137 | command_args.append(f'--fasta_paths={",".join(target_fasta_paths)}') 138 | 139 | database_paths = [ 140 | ('uniref90_database_path', uniref90_database_path), 141 | ('mgnify_database_path', mgnify_database_path), 142 | ('pdb70_database_path', pdb70_database_path), 143 | ('data_dir', data_dir), 144 | ('template_mmcif_dir', template_mmcif_dir), 145 | ('obsolete_pdbs_path', obsolete_pdbs_path), 146 | ] 147 | if FLAGS.preset == 'reduced_dbs': 148 | database_paths.append(('small_bfd_database_path', small_bfd_database_path)) 149 | else: 150 | database_paths.extend([ 151 | ('uniclust30_database_path', uniclust30_database_path), 152 | ('bfd_database_path', bfd_database_path), 153 | ]) 154 | for name, path in database_paths: 155 | if path: 156 | mount, target_path = _create_mount(name, path) 157 | mounts.append(mount) 158 | command_args.append(f'--{name}={target_path}') 159 | 160 | output_target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, 'output') 161 | mounts.append(types.Mount(output_target_path, output_dir, type='bind')) 162 | 163 | command_args.extend([ 164 | f'--output_dir={output_target_path}', 165 | f'--model_names={",".join(model_names)}', 166 | f'--max_template_date={FLAGS.max_template_date}', 167 | f'--preset={FLAGS.preset}', 168 | f'--benchmark={FLAGS.benchmark}', 169 | '--logtostderr', 170 | ]) 171 | 172 | client = docker.from_env() 173 | container = client.containers.run( 174 | image=docker_image_name, 175 | command=command_args, 176 | runtime='nvidia' if FLAGS.use_gpu else None, 177 | remove=True, 178 | detach=True, 179 | mounts=mounts, 180 | environment={ 181 | 'NVIDIA_VISIBLE_DEVICES': FLAGS.gpu_devices, 182 | # The following flags allow us to make predictions on proteins that 183 | # would typically be too long to fit into GPU memory. 184 | 'TF_FORCE_UNIFIED_MEMORY': '1', 185 | 'XLA_PYTHON_CLIENT_MEM_FRACTION': '4.0', 186 | }) 187 | 188 | # Add signal handler to ensure CTRL+C also stops the running container. 
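# docker-py's container.kill() sends SIGKILL by default, so an interrupt also
# tears down the detached container and ends the log stream below.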
189 | signal.signal(signal.SIGINT, 190 | lambda unused_sig, unused_frame: container.kill()) 191 | 192 | for line in container.logs(stream=True): 193 | logging.info(line.strip().decode('utf-8')) 194 | 195 | 196 | if __name__ == '__main__': 197 | flags.mark_flags_as_required([ 198 | 'fasta_paths', 199 | 'max_template_date', 200 | ]) 201 | app.run(main) 202 | -------------------------------------------------------------------------------- /header.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinformatics/alphafold/18ddb85e42ab6363fe1d86dab403306366725bb2/header.jpg -------------------------------------------------------------------------------- /imgs/casp14_predictions.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinformatics/alphafold/18ddb85e42ab6363fe1d86dab403306366725bb2/imgs/casp14_predictions.gif -------------------------------------------------------------------------------- /imgs/header.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinformatics/alphafold/18ddb85e42ab6363fe1d86dab403306366725bb2/imgs/header.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | biopython==1.79 3 | chex==0.0.7 4 | dm-haiku==0.0.4 5 | dm-tree==0.1.6 6 | docker==5.0.0 7 | immutabledict==2.0.0 8 | jax==0.2.14 9 | ml-collections==0.1.0 10 | numpy==1.19.5 11 | scipy==1.7.0 12 | tensorflow==2.5.0 13 | -------------------------------------------------------------------------------- /run_alphafold_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for run_alphafold.""" 16 | 17 | import os 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import mock 22 | import numpy as np 23 | 24 | import run_alphafold 25 | # Internal import (7716). 
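# The end-to-end test below swaps the data pipeline, model runner and Amber
# relaxer for mock.Mock() stand-ins, so predict_structure can run without
# databases, model parameters or a GPU.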
26 | 27 | 28 | class RunAlphafoldTest(parameterized.TestCase): 29 | 30 | def test_end_to_end(self): 31 | 32 | data_pipeline_mock = mock.Mock() 33 | model_runner_mock = mock.Mock() 34 | amber_relaxer_mock = mock.Mock() 35 | 36 | data_pipeline_mock.process.return_value = {} 37 | model_runner_mock.process_features.return_value = { 38 | 'aatype': np.zeros((12, 10), dtype=np.int32), 39 | 'residue_index': np.tile(np.arange(10, dtype=np.int32)[None], (12, 1)), 40 | } 41 | model_runner_mock.predict.return_value = { 42 | 'structure_module': { 43 | 'final_atom_positions': np.zeros((10, 37, 3)), 44 | 'final_atom_mask': np.ones((10, 37)), 45 | }, 46 | 'predicted_lddt': { 47 | 'logits': np.ones((10, 50)), 48 | }, 49 | 'plddt': np.zeros(10), 50 | 'ptm': np.array(0.), 51 | 'aligned_confidence_probs': np.zeros((10, 10, 50)), 52 | 'predicted_aligned_error': np.zeros((10, 10)), 53 | 'max_predicted_aligned_error': np.array(0.), 54 | } 55 | amber_relaxer_mock.process.return_value = ('RELAXED', None, None) 56 | 57 | fasta_path = os.path.join(absltest.get_default_test_tmpdir(), 58 | 'target.fasta') 59 | with open(fasta_path, 'wt') as f: 60 | f.write('>A\nAAAAAAAAAAAAA') 61 | fasta_name = 'test' 62 | 63 | out_dir = absltest.get_default_test_tmpdir() 64 | 65 | run_alphafold.predict_structure( 66 | fasta_path=fasta_path, 67 | fasta_name=fasta_name, 68 | output_dir_base=out_dir, 69 | data_pipeline=data_pipeline_mock, 70 | model_runners={'model1': model_runner_mock}, 71 | amber_relaxer=amber_relaxer_mock, 72 | benchmark=False, 73 | random_seed=0) 74 | 75 | 76 | if __name__ == '__main__': 77 | absltest.main() 78 | -------------------------------------------------------------------------------- /scripts/download_all_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips all required data for AlphaFold. 18 | # 19 | # Usage: bash download_all_data.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | DOWNLOAD_MODE="${2:-full_dbs}" # Default mode to full_dbs. 34 | if [[ "${DOWNLOAD_MODE}" != full_dbs && "${DOWNLOAD_MODE}" != reduced_dbs ]] 35 | then 36 | echo "DOWNLOAD_MODE ${DOWNLOAD_MODE} not recognized." 37 | exit 1 38 | fi 39 | 40 | SCRIPT_DIR="$(dirname "$(realpath "$0")")" 41 | 42 | echo "Downloading AlphaFold parameters..." 43 | bash "${SCRIPT_DIR}/download_alphafold_params.sh" "${DOWNLOAD_DIR}" 44 | 45 | if [[ "${DOWNLOAD_MODE}" = full_dbs ]] ; then 46 | echo "Downloading BFD..." 
47 | bash "${SCRIPT_DIR}/download_bfd.sh" "${DOWNLOAD_DIR}" 48 | else 49 | echo "Downloading Small BFD..." 50 | bash "${SCRIPT_DIR}/download_small_bfd.sh" "${DOWNLOAD_DIR}" 51 | fi 52 | 53 | echo "Downloading MGnify..." 54 | bash "${SCRIPT_DIR}/download_mgnify.sh" "${DOWNLOAD_DIR}" 55 | 56 | echo "Downloading PDB70..." 57 | bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}" 58 | 59 | echo "Downloading PDB mmCIF files..." 60 | bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}" 61 | 62 | echo "Downloading Uniclust30..." 63 | bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}" 64 | 65 | echo "Downloading Uniref90..." 66 | bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}" 67 | 68 | echo "All data downloaded." 69 | -------------------------------------------------------------------------------- /scripts/download_alphafold_params.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips the AlphaFold parameters. 18 | # 19 | # Usage: bash download_alphafold_params.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | ROOT_DIR="${DOWNLOAD_DIR}/params" 34 | SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar" 35 | BASENAME=$(basename "${SOURCE_URL}") 36 | 37 | mkdir --parents "${ROOT_DIR}" 38 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" 39 | tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \ 40 | --directory="${ROOT_DIR}" --preserve-permissions 41 | rm "${ROOT_DIR}/${BASENAME}" 42 | -------------------------------------------------------------------------------- /scripts/download_bfd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips the BFD database for AlphaFold. 
18 | # 19 | # Usage: bash download_bfd.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | ROOT_DIR="${DOWNLOAD_DIR}/bfd" 34 | # Mirror of: 35 | # https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz. 36 | SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz" 37 | BASENAME=$(basename "${SOURCE_URL}") 38 | 39 | mkdir --parents "${ROOT_DIR}" 40 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" 41 | tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \ 42 | --directory="${ROOT_DIR}" 43 | rm "${ROOT_DIR}/${BASENAME}" 44 | -------------------------------------------------------------------------------- /scripts/download_mgnify.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips the MGnify database for AlphaFold. 18 | # 19 | # Usage: bash download_mgnify.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | ROOT_DIR="${DOWNLOAD_DIR}/mgnify" 34 | # Mirror of: 35 | # ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2018_12/mgy_clusters.fa.gz 36 | SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/mgy_clusters_2018_12.fa.gz" 37 | BASENAME=$(basename "${SOURCE_URL}") 38 | 39 | mkdir --parents "${ROOT_DIR}" 40 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" 41 | pushd "${ROOT_DIR}" 42 | gunzip "${ROOT_DIR}/${BASENAME}" 43 | popd 44 | -------------------------------------------------------------------------------- /scripts/download_pdb70.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips the PDB70 database for AlphaFold. 18 | # 19 | # Usage: bash download_pdb70.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | ROOT_DIR="${DOWNLOAD_DIR}/pdb70" 34 | SOURCE_URL="http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/old-releases/pdb70_from_mmcif_200401.tar.gz" 35 | BASENAME=$(basename "${SOURCE_URL}") 36 | 37 | mkdir --parents "${ROOT_DIR}" 38 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" 39 | tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \ 40 | --directory="${ROOT_DIR}" 41 | rm "${ROOT_DIR}/${BASENAME}" 42 | -------------------------------------------------------------------------------- /scripts/download_pdb_mmcif.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads, unzips and flattens the PDB database for AlphaFold. 18 | # 19 | # Usage: bash download_pdb_mmcif.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | if ! command -v rsync &> /dev/null ; then 33 | echo "Error: rsync could not be found. Please install rsync." 34 | exit 1 35 | fi 36 | 37 | DOWNLOAD_DIR="$1" 38 | ROOT_DIR="${DOWNLOAD_DIR}/pdb_mmcif" 39 | RAW_DIR="${ROOT_DIR}/raw" 40 | MMCIF_DIR="${ROOT_DIR}/mmcif_files" 41 | 42 | echo "Running rsync to fetch all mmCIF files (note that the rsync progress estimate might be inaccurate)..." 43 | mkdir --parents "${RAW_DIR}" 44 | rsync --recursive --links --perms --times --compress --info=progress2 --delete --port=33444 \ 45 | rsync.rcsb.org::ftp_data/structures/divided/mmCIF/ \ 46 | "${RAW_DIR}" 47 | 48 | echo "Unzipping all mmCIF files..." 49 | find "${RAW_DIR}/" -type f -iname "*.gz" -exec gunzip {} + 50 | 51 | echo "Flattening all mmCIF files..." 52 | mkdir --parents "${MMCIF_DIR}" 53 | find "${RAW_DIR}" -type d -empty -delete # Delete empty directories. 54 | for subdir in "${RAW_DIR}"/*; do 55 | mv "${subdir}/"*.cif "${MMCIF_DIR}" 56 | done 57 | 58 | # Delete empty download directory structure. 
59 | find "${RAW_DIR}" -type d -empty -delete 60 | 61 | aria2c "ftp://ftp.wwpdb.org/pub/pdb/data/status/obsolete.dat" --dir="${ROOT_DIR}" 62 | -------------------------------------------------------------------------------- /scripts/download_small_bfd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips the Small BFD database for AlphaFold. 18 | # 19 | # Usage: bash download_small_bfd.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | ROOT_DIR="${DOWNLOAD_DIR}/small_bfd" 34 | SOURCE_URL="https://storage.googleapis.com/alphafold-databases/reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz" 35 | BASENAME=$(basename "${SOURCE_URL}") 36 | 37 | mkdir --parents "${ROOT_DIR}" 38 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" 39 | pushd "${ROOT_DIR}" 40 | gunzip "${ROOT_DIR}/${BASENAME}" 41 | popd 42 | -------------------------------------------------------------------------------- /scripts/download_uniclust30.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips the Uniclust30 database for AlphaFold. 18 | # 19 | # Usage: bash download_uniclust30.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 
29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | ROOT_DIR="${DOWNLOAD_DIR}/uniclust30" 34 | # Mirror of: 35 | # http://wwwuser.gwdg.de/~compbiol/uniclust/2018_08/uniclust30_2018_08_hhsuite.tar.gz 36 | SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/uniclust30_2018_08_hhsuite.tar.gz" 37 | BASENAME=$(basename "${SOURCE_URL}") 38 | 39 | mkdir --parents "${ROOT_DIR}" 40 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" 41 | tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \ 42 | --directory="${ROOT_DIR}" 43 | rm "${ROOT_DIR}/${BASENAME}" 44 | -------------------------------------------------------------------------------- /scripts/download_uniref90.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips the UniRef90 database for AlphaFold. 18 | # 19 | # Usage: bash download_uniref90.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | ROOT_DIR="${DOWNLOAD_DIR}/uniref90" 34 | SOURCE_URL="ftp://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref90/uniref90.fasta.gz" 35 | BASENAME=$(basename "${SOURCE_URL}") 36 | 37 | mkdir --parents "${ROOT_DIR}" 38 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" 39 | pushd "${ROOT_DIR}" 40 | gunzip "${ROOT_DIR}/${BASENAME}" 41 | popd 42 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Install script for setuptools.""" 15 | 16 | from setuptools import find_packages 17 | from setuptools import setup 18 | 19 | setup( 20 | name='alphafold', 21 | version='2.0.0', 22 | description='An implementation of the inference pipeline of AlphaFold v2.0.' 
23 | ' This is a completely new model that was entered as AlphaFold2 in CASP14 ' 24 | 'and published in Nature.', 25 | author='DeepMind', 26 | author_email='alphafold@deepmind.com', 27 | license='Apache License, Version 2.0', 28 | url='https://github.com/deepmind/alphafold', 29 | packages=find_packages(), 30 | install_requires=[ 31 | 'absl-py', 32 | 'biopython', 33 | 'chex', 34 | 'dm-haiku', 35 | 'dm-tree', 36 | 'docker', 37 | 'immutabledict', 38 | 'jax', 39 | 'ml-collections', 40 | 'numpy', 41 | 'scipy', 42 | 'tensorflow', 43 | ], 44 | tests_require=['mock'], 45 | classifiers=[ 46 | 'Development Status :: 5 - Production/Stable', 47 | 'Intended Audience :: Science/Research', 48 | 'License :: OSI Approved :: Apache Software License', 49 | 'Operating System :: POSIX :: Linux', 50 | 'Programming Language :: Python :: 3.6', 51 | 'Programming Language :: Python :: 3.7', 52 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 53 | ], 54 | ) 55 | --------------------------------------------------------------------------------