├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── Representations_AlphaFold2PredictStructure.ipynb
├── Representations_AlphaFold2_v3.ipynb
├── alphafold
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── confidence.py
│   │   ├── protein.py
│   │   ├── protein_test.py
│   │   ├── residue_constants.py
│   │   ├── residue_constants_test.py
│   │   └── testdata
│   │       └── 2rbg.pdb
│   ├── data
│   │   ├── __init__.py
│   │   ├── mmcif_parsing.py
│   │   ├── parsers.py
│   │   ├── pipeline.py
│   │   ├── templates.py
│   │   └── tools
│   │       ├── __init__.py
│   │       ├── hhblits.py
│   │       ├── hhsearch.py
│   │       ├── hmmbuild.py
│   │       ├── hmmsearch.py
│   │       ├── jackhmmer.py
│   │       ├── kalign.py
│   │       └── utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── all_atom.py
│   │   ├── all_atom_test.py
│   │   ├── common_modules.py
│   │   ├── config.py
│   │   ├── data.py
│   │   ├── features.py
│   │   ├── folding.py
│   │   ├── layer_stack.py
│   │   ├── layer_stack_test.py
│   │   ├── lddt.py
│   │   ├── lddt_test.py
│   │   ├── mapping.py
│   │   ├── model.py
│   │   ├── modules.py
│   │   ├── prng.py
│   │   ├── prng_test.py
│   │   ├── quat_affine.py
│   │   ├── quat_affine_test.py
│   │   ├── r3.py
│   │   ├── tf
│   │   │   ├── __init__.py
│   │   │   ├── data_transforms.py
│   │   │   ├── input_pipeline.py
│   │   │   ├── protein_features.py
│   │   │   ├── protein_features_test.py
│   │   │   ├── proteins_dataset.py
│   │   │   ├── shape_helpers.py
│   │   │   ├── shape_helpers_test.py
│   │   │   ├── shape_placeholders.py
│   │   │   └── utils.py
│   │   └── utils.py
│   └── relax
│       ├── __init__.py
│       ├── amber_minimize.py
│       ├── amber_minimize_test.py
│       ├── cleanup.py
│       ├── cleanup_test.py
│       ├── relax.py
│       ├── relax_test.py
│       ├── testdata
│       │   ├── model_output.pdb
│       │   ├── multiple_disulfides_target.pdb
│       │   ├── with_violations.pdb
│       │   └── with_violations_casp14.pdb
│       ├── utils.py
│       └── utils_test.py
├── docker
│   ├── Dockerfile
│   ├── openmm.patch
│   ├── requirements.txt
│   └── run_docker.py
├── header.jpg
├── imgs
│   ├── casp14_predictions.gif
│   └── header.jpg
├── notebooks
│   └── AlphaFold.ipynb
├── requirements.txt
├── run_alphafold.py
├── run_alphafold_test.py
├── scripts
│   ├── download_all_data.sh
│   ├── download_alphafold_params.sh
│   ├── download_bfd.sh
│   ├── download_mgnify.sh
│   ├── download_pdb70.sh
│   ├── download_pdb_mmcif.sh
│   ├── download_small_bfd.sh
│   ├── download_uniclust30.sh
│   └── download_uniref90.sh
└── setup.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We welcome small patches related to bug fixes and documentation, but we do not
4 | plan to make any major changes to this repository.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
--------------------------------------------------------------------------------
/alphafold/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """An implementation of the inference pipeline of AlphaFold v2.0."""
15 |
--------------------------------------------------------------------------------
/alphafold/common/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Common data types and constants used within Alphafold."""
15 |
--------------------------------------------------------------------------------
/alphafold/common/confidence.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functions for processing confidence metrics."""
16 |
17 | from typing import Dict, Optional, Tuple
18 | import numpy as np
19 | import scipy.special
20 |
21 |
22 | def compute_plddt(logits: np.ndarray) -> np.ndarray:
23 | """Computes per-residue pLDDT from logits.
24 |
25 | Args:
26 | logits: [num_res, num_bins] output from the PredictedLDDTHead.
27 |
28 | Returns:
29 | plddt: [num_res] per-residue pLDDT.
30 | """
31 | num_bins = logits.shape[-1]
32 | bin_width = 1.0 / num_bins
33 | bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width)
34 | probs = scipy.special.softmax(logits, axis=-1)
35 | predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1)
36 | return predicted_lddt_ca * 100
37 |
38 |
39 | def _calculate_bin_centers(breaks: np.ndarray):
40 | """Gets the bin centers from the bin edges.
41 |
42 | Args:
43 | breaks: [num_bins - 1] the error bin edges.
44 |
45 | Returns:
46 | bin_centers: [num_bins] the error bin centers.
47 | """
48 | step = (breaks[1] - breaks[0])
49 |
50 | # Add half-step to get the center
51 | bin_centers = breaks + step / 2
52 | # Add a catch-all bin at the end.
53 | bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]],
54 | axis=0)
55 | return bin_centers
56 |
57 |
58 | def _calculate_expected_aligned_error(
59 | alignment_confidence_breaks: np.ndarray,
60 | aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
61 | """Calculates expected aligned distance errors for every pair of residues.
62 |
63 | Args:
64 | alignment_confidence_breaks: [num_bins - 1] the error bin edges.
65 | aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted
66 | probs for each error bin, for each pair of residues.
67 |
68 | Returns:
69 | predicted_aligned_error: [num_res, num_res] the expected aligned distance
70 | error for each pair of residues.
71 | max_predicted_aligned_error: The maximum predicted error possible.
72 | """
73 | bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
74 |
75 | # Tuple of expected aligned distance error and max possible error.
76 | return (np.sum(aligned_distance_error_probs * bin_centers, axis=-1),
77 | np.asarray(bin_centers[-1]))
78 |
79 |
80 | def compute_predicted_aligned_error(
81 | logits: np.ndarray,
82 | breaks: np.ndarray) -> Dict[str, np.ndarray]:
83 | """Computes aligned confidence metrics from logits.
84 |
85 | Args:
86 | logits: [num_res, num_res, num_bins] the logits output from
87 | PredictedAlignedErrorHead.
88 | breaks: [num_bins - 1] the error bin edges.
89 |
90 | Returns:
91 | aligned_confidence_probs: [num_res, num_res, num_bins] the predicted
92 | aligned error probabilities over bins for each residue pair.
93 | predicted_aligned_error: [num_res, num_res] the expected aligned distance
94 | error for each pair of residues.
95 | max_predicted_aligned_error: The maximum predicted error possible.
96 | """
97 | aligned_confidence_probs = scipy.special.softmax(
98 | logits,
99 | axis=-1)
100 | predicted_aligned_error, max_predicted_aligned_error = (
101 | _calculate_expected_aligned_error(
102 | alignment_confidence_breaks=breaks,
103 | aligned_distance_error_probs=aligned_confidence_probs))
104 | return {
105 | 'aligned_confidence_probs': aligned_confidence_probs,
106 | 'predicted_aligned_error': predicted_aligned_error,
107 | 'max_predicted_aligned_error': max_predicted_aligned_error,
108 | }
109 |
110 |
111 | def predicted_tm_score(
112 | logits: np.ndarray,
113 | breaks: np.ndarray,
114 | residue_weights: Optional[np.ndarray] = None) -> np.ndarray:
115 | """Computes predicted TM alignment score.
116 |
117 | Args:
118 | logits: [num_res, num_res, num_bins] the logits output from
119 | PredictedAlignedErrorHead.
119 |     breaks: [num_bins - 1] the error bin edges.
121 | residue_weights: [num_res] the per residue weights to use for the
122 | expectation.
123 |
124 | Returns:
125 | ptm_score: the predicted TM alignment score.
126 | """
127 |
128 | # residue_weights has to be in [0, 1], but can be floating-point, i.e. the
129 |   # experimentally resolved head's probability.
130 | if residue_weights is None:
131 | residue_weights = np.ones(logits.shape[0])
132 |
133 | bin_centers = _calculate_bin_centers(breaks)
134 |
135 | num_res = np.sum(residue_weights)
136 | # Clip num_res to avoid negative/undefined d0.
137 | clipped_num_res = max(num_res, 19)
138 |
139 | # Compute d_0(num_res) as defined by TM-score, eqn. (5) in
140 | # http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
141 | # Yang & Skolnick "Scoring function for automated
142 | # assessment of protein structure template quality" 2004
143 | d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8
144 |
145 | # Convert logits to probs
146 | probs = scipy.special.softmax(logits, axis=-1)
147 |
148 | # TM-Score term for every bin
149 | tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0))
150 | # E_distances tm(distance)
151 | predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1)
152 |
153 | normed_residue_mask = residue_weights / (1e-8 + residue_weights.sum())
154 | per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1)
155 | return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()])
156 |
--------------------------------------------------------------------------------
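A minimal usage sketch of the functions above, on random logits. The bin counts here (50 for pLDDT, 64 for aligned error) are illustrative assumptions; the functions accept any number of bins.

    import numpy as np
    from alphafold.common import confidence

    rng = np.random.default_rng(0)
    num_res, plddt_bins, pae_bins = 10, 50, 64

    # Per-residue pLDDT in [0, 100].
    plddt = confidence.compute_plddt(rng.normal(size=(num_res, plddt_bins)))

    # Pairwise aligned-error metrics; breaks holds the num_bins - 1 bin edges.
    pae_logits = rng.normal(size=(num_res, num_res, pae_bins))
    breaks = np.linspace(0.0, 31.0, pae_bins - 1)
    pae = confidence.compute_predicted_aligned_error(pae_logits, breaks)
    print(pae['predicted_aligned_error'].shape)  # (10, 10)

    # Scalar predicted TM-score from the same logits and bin edges.
    ptm = confidence.predicted_tm_score(pae_logits, breaks)
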
/alphafold/common/protein.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Protein data type."""
16 | import io
17 | from typing import Any, Mapping, Optional
18 |
19 | from Bio.PDB import PDBParser
20 | import dataclasses
21 | import numpy as np
22 |
23 | from alphafold.common import residue_constants
24 |
25 | FeatureDict = Mapping[str, np.ndarray]
26 | ModelOutput = Mapping[str, Any] # Is a nested dict.
27 |
28 |
29 | @dataclasses.dataclass(frozen=True)
30 | class Protein:
31 | """Protein structure representation."""
32 |
33 | # Cartesian coordinates of atoms in angstroms. The atom types correspond to
34 |   # residue_constants.atom_types, i.e. the first three are N, CA, C.
35 | atom_positions: np.ndarray # [num_res, num_atom_type, 3]
36 |
37 | # Amino-acid type for each residue represented as an integer between 0 and
38 | # 20, where 20 is 'X'.
39 | aatype: np.ndarray # [num_res]
40 |
41 | # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
42 | # is present and 0.0 if not. This should be used for loss masking.
43 | atom_mask: np.ndarray # [num_res, num_atom_type]
44 |
45 | # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
46 | residue_index: np.ndarray # [num_res]
47 |
48 |   # B-factors, or temperature factors, of each residue (in square angstroms),
49 | # representing the displacement of the residue from its ground truth mean
50 | # value.
51 | b_factors: np.ndarray # [num_res, num_atom_type]
52 |
53 |
54 | def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
55 | """Takes a PDB string and constructs a Protein object.
56 |
57 | WARNING: All non-standard residue types will be converted into UNK. All
58 | non-standard atoms will be ignored.
59 |
60 | Args:
61 | pdb_str: The contents of the pdb file
62 | chain_id: If None, then the pdb file must contain a single chain (which
63 | will be parsed). If chain_id is specified (e.g. A), then only that chain
64 | is parsed.
65 |
66 | Returns:
67 | A new `Protein` parsed from the pdb contents.
68 | """
69 | pdb_fh = io.StringIO(pdb_str)
70 | parser = PDBParser(QUIET=True)
71 | structure = parser.get_structure('none', pdb_fh)
72 | models = list(structure.get_models())
73 | if len(models) != 1:
74 | raise ValueError(
75 | f'Only single model PDBs are supported. Found {len(models)} models.')
76 | model = models[0]
77 |
78 | if chain_id is not None:
79 | chain = model[chain_id]
80 | else:
81 | chains = list(model.get_chains())
82 | if len(chains) != 1:
83 | raise ValueError(
84 | 'Only single chain PDBs are supported when chain_id not specified. '
85 | f'Found {len(chains)} chains.')
86 | else:
87 | chain = chains[0]
88 |
89 | atom_positions = []
90 | aatype = []
91 | atom_mask = []
92 | residue_index = []
93 | b_factors = []
94 |
95 | for res in chain:
96 | if res.id[2] != ' ':
97 | raise ValueError(
98 | f'PDB contains an insertion code at chain {chain.id} and residue '
99 | f'index {res.id[1]}. These are not supported.')
100 | res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
101 | restype_idx = residue_constants.restype_order.get(
102 | res_shortname, residue_constants.restype_num)
103 | pos = np.zeros((residue_constants.atom_type_num, 3))
104 | mask = np.zeros((residue_constants.atom_type_num,))
105 | res_b_factors = np.zeros((residue_constants.atom_type_num,))
106 | for atom in res:
107 | if atom.name not in residue_constants.atom_types:
108 | continue
109 | pos[residue_constants.atom_order[atom.name]] = atom.coord
110 | mask[residue_constants.atom_order[atom.name]] = 1.
111 | res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
112 | if np.sum(mask) < 0.5:
113 | # If no known atom positions are reported for the residue then skip it.
114 | continue
115 | aatype.append(restype_idx)
116 | atom_positions.append(pos)
117 | atom_mask.append(mask)
118 | residue_index.append(res.id[1])
119 | b_factors.append(res_b_factors)
120 |
121 | return Protein(
122 | atom_positions=np.array(atom_positions),
123 | atom_mask=np.array(atom_mask),
124 | aatype=np.array(aatype),
125 | residue_index=np.array(residue_index),
126 | b_factors=np.array(b_factors))
127 |
128 |
129 | def to_pdb(prot: Protein) -> str:
130 | """Converts a `Protein` instance to a PDB string.
131 |
132 | Args:
133 | prot: The protein to convert to PDB.
134 |
135 | Returns:
136 | PDB string.
137 | """
138 | restypes = residue_constants.restypes + ['X']
139 | res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
140 | atom_types = residue_constants.atom_types
141 |
142 | pdb_lines = []
143 |
144 | atom_mask = prot.atom_mask
145 | aatype = prot.aatype
146 | atom_positions = prot.atom_positions
147 | residue_index = prot.residue_index.astype(np.int32)
148 | b_factors = prot.b_factors
149 |
150 | if np.any(aatype > residue_constants.restype_num):
151 | raise ValueError('Invalid aatypes.')
152 |
153 | pdb_lines.append('MODEL 1')
154 | atom_index = 1
155 | chain_id = 'A'
156 | # Add all atom sites.
157 | for i in range(aatype.shape[0]):
158 | res_name_3 = res_1to3(aatype[i])
159 | for atom_name, pos, mask, b_factor in zip(
160 | atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
161 | if mask < 0.5:
162 | continue
163 |
164 | record_type = 'ATOM'
165 | name = atom_name if len(atom_name) == 4 else f' {atom_name}'
166 | alt_loc = ''
167 | insertion_code = ''
168 | occupancy = 1.00
169 |       element = atom_name[0]  # Protein supports only C, N, O, S, so this works.
170 | charge = ''
171 | # PDB is a columnar format, every space matters here!
172 | atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
173 | f'{res_name_3:>3} {chain_id:>1}'
174 | f'{residue_index[i]:>4}{insertion_code:>1} '
175 | f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
176 | f'{occupancy:>6.2f}{b_factor:>6.2f} '
177 | f'{element:>2}{charge:>2}')
178 | pdb_lines.append(atom_line)
179 | atom_index += 1
180 |
181 | # Close the chain.
182 | chain_end = 'TER'
183 | chain_termination_line = (
184 | f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} '
185 | f'{chain_id:>1}{residue_index[-1]:>4}')
186 | pdb_lines.append(chain_termination_line)
187 | pdb_lines.append('ENDMDL')
188 |
189 | pdb_lines.append('END')
190 | pdb_lines.append('')
191 | return '\n'.join(pdb_lines)
192 |
193 |
194 | def ideal_atom_mask(prot: Protein) -> np.ndarray:
195 | """Computes an ideal atom mask.
196 |
197 | `Protein.atom_mask` typically is defined according to the atoms that are
198 | reported in the PDB. This function computes a mask according to heavy atoms
199 | that should be present in the given sequence of amino acids.
200 |
201 | Args:
202 | prot: `Protein` whose fields are `numpy.ndarray` objects.
203 |
204 | Returns:
205 | An ideal atom mask.
206 | """
207 | return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
208 |
209 |
210 | def from_prediction(features: FeatureDict, result: ModelOutput,
211 | b_factors: Optional[np.ndarray] = None) -> Protein:
212 | """Assembles a protein from a prediction.
213 |
214 | Args:
215 | features: Dictionary holding model inputs.
216 | result: Dictionary holding model outputs.
217 | b_factors: (Optional) B-factors to use for the protein.
218 |
219 | Returns:
220 | A protein instance.
221 | """
222 | fold_output = result['structure_module']
223 | if b_factors is None:
224 | b_factors = np.zeros_like(fold_output['final_atom_mask'])
225 |
226 | return Protein(
227 | aatype=features['aatype'][0],
228 | atom_positions=fold_output['final_atom_positions'],
229 | atom_mask=fold_output['final_atom_mask'],
230 | residue_index=features['residue_index'][0] + 1,
231 | b_factors=b_factors)
232 |
--------------------------------------------------------------------------------
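A round-trip sketch for the Protein type above; 'input.pdb' is a placeholder for any single-model PDB file.

    from alphafold.common import protein

    with open('input.pdb') as f:
        prot = protein.from_pdb_string(f.read(), chain_id='A')
    print(prot.atom_positions.shape)  # (num_res, 37, 3)

    # Serialize back to PDB text; non-standard residues were parsed as UNK.
    with open('roundtrip.pdb', 'w') as f:
        f.write(protein.to_pdb(prot))
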
/alphafold/common/protein_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for protein."""
16 |
17 | import os
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | import numpy as np
22 |
23 | from alphafold.common import protein
24 | from alphafold.common import residue_constants
25 | # Internal import (7716).
26 |
27 | TEST_DATA_DIR = 'alphafold/common/testdata/'
28 |
29 |
30 | class ProteinTest(parameterized.TestCase):
31 |
32 | def _check_shapes(self, prot, num_res):
33 | """Check that the processed shapes are correct."""
34 | num_atoms = residue_constants.atom_type_num
35 | self.assertEqual((num_res, num_atoms, 3), prot.atom_positions.shape)
36 | self.assertEqual((num_res,), prot.aatype.shape)
37 | self.assertEqual((num_res, num_atoms), prot.atom_mask.shape)
38 | self.assertEqual((num_res,), prot.residue_index.shape)
39 | self.assertEqual((num_res, num_atoms), prot.b_factors.shape)
40 |
41 | @parameterized.parameters(('2rbg.pdb', 'A', 282),
42 | ('2rbg.pdb', 'B', 282))
43 | def test_from_pdb_str(self, pdb_file, chain_id, num_res):
44 | pdb_file = os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
45 | pdb_file)
46 | with open(pdb_file) as f:
47 | pdb_string = f.read()
48 | prot = protein.from_pdb_string(pdb_string, chain_id)
49 | self._check_shapes(prot, num_res)
50 | self.assertGreaterEqual(prot.aatype.min(), 0)
51 | # Allow equal since unknown restypes have index equal to restype_num.
52 | self.assertLessEqual(prot.aatype.max(), residue_constants.restype_num)
53 |
54 | def test_to_pdb(self):
55 | with open(
56 | os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
57 | '2rbg.pdb')) as f:
58 | pdb_string = f.read()
59 | prot = protein.from_pdb_string(pdb_string, chain_id='A')
60 | pdb_string_reconstr = protein.to_pdb(prot)
61 | prot_reconstr = protein.from_pdb_string(pdb_string_reconstr)
62 |
63 | np.testing.assert_array_equal(prot_reconstr.aatype, prot.aatype)
64 | np.testing.assert_array_almost_equal(
65 | prot_reconstr.atom_positions, prot.atom_positions)
66 | np.testing.assert_array_almost_equal(
67 | prot_reconstr.atom_mask, prot.atom_mask)
68 | np.testing.assert_array_equal(
69 | prot_reconstr.residue_index, prot.residue_index)
70 | np.testing.assert_array_almost_equal(
71 | prot_reconstr.b_factors, prot.b_factors)
72 |
73 | def test_ideal_atom_mask(self):
74 | with open(
75 | os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
76 | '2rbg.pdb')) as f:
77 | pdb_string = f.read()
78 | prot = protein.from_pdb_string(pdb_string, chain_id='A')
79 | ideal_mask = protein.ideal_atom_mask(prot)
80 | non_ideal_residues = set([102] + list(range(127, 285)))
81 | for i, (res, atom_mask) in enumerate(
82 | zip(prot.residue_index, prot.atom_mask)):
83 | if res in non_ideal_residues:
84 | self.assertFalse(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')
85 | else:
86 | self.assertTrue(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')
87 |
88 |
89 | if __name__ == '__main__':
90 | absltest.main()
91 |
--------------------------------------------------------------------------------
/alphafold/common/residue_constants_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Test that residue_constants generates correct values."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import numpy as np
20 |
21 | from alphafold.common import residue_constants
22 |
23 |
24 | class ResidueConstantsTest(parameterized.TestCase):
25 |
26 | @parameterized.parameters(
27 | ('ALA', 0),
28 | ('CYS', 1),
29 | ('HIS', 2),
30 | ('MET', 3),
31 | ('LYS', 4),
32 | ('ARG', 4),
33 | )
34 | def testChiAnglesAtoms(self, residue_name, chi_num):
35 | chi_angles_atoms = residue_constants.chi_angles_atoms[residue_name]
36 | self.assertLen(chi_angles_atoms, chi_num)
37 | for chi_angle_atoms in chi_angles_atoms:
38 | self.assertLen(chi_angle_atoms, 4)
39 |
40 | def testChiGroupsForAtom(self):
41 | for k, chi_groups in residue_constants.chi_groups_for_atom.items():
42 | res_name, atom_name = k
43 | for chi_group_i, atom_i in chi_groups:
44 | self.assertEqual(
45 | atom_name,
46 | residue_constants.chi_angles_atoms[res_name][chi_group_i][atom_i])
47 |
48 | @parameterized.parameters(
49 | ('ALA', 5), ('ARG', 11), ('ASN', 8), ('ASP', 8), ('CYS', 6), ('GLN', 9),
50 | ('GLU', 9), ('GLY', 4), ('HIS', 10), ('ILE', 8), ('LEU', 8), ('LYS', 9),
51 | ('MET', 8), ('PHE', 11), ('PRO', 7), ('SER', 6), ('THR', 7), ('TRP', 14),
52 | ('TYR', 12), ('VAL', 7)
53 | )
54 | def testResidueAtoms(self, atom_name, num_residue_atoms):
55 | residue_atoms = residue_constants.residue_atoms[atom_name]
56 | self.assertLen(residue_atoms, num_residue_atoms)
57 |
58 | def testStandardAtomMask(self):
59 | with self.subTest('Check shape'):
60 | self.assertEqual(residue_constants.STANDARD_ATOM_MASK.shape, (21, 37,))
61 |
62 | with self.subTest('Check values'):
63 | str_to_row = lambda s: [c == '1' for c in s] # More clear/concise.
64 | np.testing.assert_array_equal(
65 | residue_constants.STANDARD_ATOM_MASK,
66 | np.array([
67 |               # NB: This was defined by copy-paste but looks sane.
68 | str_to_row('11111 '), # ALA
69 | str_to_row('111111 1 1 11 1 '), # ARG
70 | str_to_row('111111 11 '), # ASP
71 | str_to_row('111111 11 '), # ASN
72 | str_to_row('11111 1 '), # CYS
73 | str_to_row('111111 1 11 '), # GLU
74 | str_to_row('111111 1 11 '), # GLN
75 | str_to_row('111 1 '), # GLY
76 | str_to_row('111111 11 1 1 '), # HIS
77 | str_to_row('11111 11 1 '), # ILE
78 | str_to_row('111111 11 '), # LEU
79 | str_to_row('111111 1 1 1 '), # LYS
80 | str_to_row('111111 11 '), # MET
81 | str_to_row('111111 11 11 1 '), # PHE
82 | str_to_row('111111 1 '), # PRO
83 | str_to_row('11111 1 '), # SER
84 | str_to_row('11111 1 1 '), # THR
85 | str_to_row('111111 11 11 1 1 11 '), # TRP
86 | str_to_row('111111 11 11 11 '), # TYR
87 | str_to_row('11111 11 '), # VAL
88 | str_to_row(' '), # UNK
89 | ]))
90 |
91 | with self.subTest('Check row totals'):
92 | # Check each row has the right number of atoms.
93 | for row, restype in enumerate(residue_constants.restypes): # A, R, ...
94 | long_restype = residue_constants.restype_1to3[restype] # ALA, ARG, ...
95 | atoms_names = residue_constants.residue_atoms[
96 | long_restype] # ['C', 'CA', 'CB', 'N', 'O'], ...
97 | self.assertLen(atoms_names,
98 | residue_constants.STANDARD_ATOM_MASK[row, :].sum(),
99 | long_restype)
100 |
101 | def testAtomTypes(self):
102 | self.assertEqual(residue_constants.atom_type_num, 37)
103 |
104 | self.assertEqual(residue_constants.atom_types[0], 'N')
105 | self.assertEqual(residue_constants.atom_types[1], 'CA')
106 | self.assertEqual(residue_constants.atom_types[2], 'C')
107 | self.assertEqual(residue_constants.atom_types[3], 'CB')
108 | self.assertEqual(residue_constants.atom_types[4], 'O')
109 |
110 | self.assertEqual(residue_constants.atom_order['N'], 0)
111 | self.assertEqual(residue_constants.atom_order['CA'], 1)
112 | self.assertEqual(residue_constants.atom_order['C'], 2)
113 | self.assertEqual(residue_constants.atom_order['CB'], 3)
114 | self.assertEqual(residue_constants.atom_order['O'], 4)
115 | self.assertEqual(residue_constants.atom_type_num, 37)
116 |
117 | def testRestypes(self):
118 | three_letter_restypes = [
119 | residue_constants.restype_1to3[r] for r in residue_constants.restypes]
120 | for restype, exp_restype in zip(
121 | three_letter_restypes, sorted(residue_constants.restype_1to3.values())):
122 | self.assertEqual(restype, exp_restype)
123 | self.assertEqual(residue_constants.restype_num, 20)
124 |
125 | def testSequenceToOneHotHHBlits(self):
126 | one_hot = residue_constants.sequence_to_onehot(
127 | 'ABCDEFGHIJKLMNOPQRSTUVWXYZ-', residue_constants.HHBLITS_AA_TO_ID)
128 | exp_one_hot = np.array(
129 | [[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
130 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
131 | [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
132 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
133 | [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
134 | [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
135 | [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
136 | [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
137 | [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
138 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
139 | [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
140 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
141 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
142 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
143 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
144 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
145 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
146 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
147 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
148 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
149 | [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
150 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
151 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
152 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
153 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
154 | [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
155 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])
156 | np.testing.assert_array_equal(one_hot, exp_one_hot)
157 |
158 | def testSequenceToOneHotStandard(self):
159 | one_hot = residue_constants.sequence_to_onehot(
160 | 'ARNDCQEGHILKMFPSTWYV', residue_constants.restype_order)
161 | np.testing.assert_array_equal(one_hot, np.eye(20))
162 |
163 | def testSequenceToOneHotUnknownMapping(self):
164 | seq = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
165 | expected_out = np.zeros([26, 21])
166 | for row, position in enumerate(
167 | [0, 20, 4, 3, 6, 13, 7, 8, 9, 20, 11, 10, 12, 2, 20, 14, 5, 1, 15, 16,
168 | 20, 19, 17, 20, 18, 20]):
169 | expected_out[row, position] = 1
170 | aa_types = residue_constants.sequence_to_onehot(
171 | sequence=seq,
172 | mapping=residue_constants.restype_order_with_x,
173 | map_unknown_to_x=True)
174 | self.assertTrue((aa_types == expected_out).all())
175 |
176 | @parameterized.named_parameters(
177 | ('lowercase', 'aaa'), # Insertions in A3M.
178 | ('gaps', '---'), # Gaps in A3M.
179 | ('dots', '...'), # Gaps in A3M.
180 | ('metadata', '>TEST'), # FASTA metadata line.
181 | )
182 | def testSequenceToOneHotUnknownMappingError(self, seq):
183 | with self.assertRaises(ValueError):
184 | residue_constants.sequence_to_onehot(
185 | sequence=seq,
186 | mapping=residue_constants.restype_order_with_x,
187 | map_unknown_to_x=True)
188 |
189 |
190 | if __name__ == '__main__':
191 | absltest.main()
192 |
--------------------------------------------------------------------------------
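A short sketch of the residue_constants lookups exercised by the tests above (the input sequence 'MKTAB' is an arbitrary example):

    from alphafold.common import residue_constants

    # The fixed 37-atom layout: names <-> indices.
    print(residue_constants.atom_types[:5])    # ['N', 'CA', 'C', 'CB', 'O']
    print(residue_constants.atom_order['CB'])  # 3

    # One-hot encode a sequence; the ambiguous 'B' maps to X when
    # map_unknown_to_x=True (see testSequenceToOneHotUnknownMapping).
    one_hot = residue_constants.sequence_to_onehot(
        sequence='MKTAB',
        mapping=residue_constants.restype_order_with_x,
        map_unknown_to_x=True)
    print(one_hot.shape)  # (5, 21)
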
/alphafold/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Data pipeline for model features."""
15 |
--------------------------------------------------------------------------------
/alphafold/data/pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functions for building the input features for the AlphaFold model."""
16 |
17 | import os
18 | from typing import Mapping, Optional, Sequence
19 |
20 | import numpy as np
21 |
22 | # Internal import (7716).
23 |
24 | from alphafold.common import residue_constants
25 | from alphafold.data import parsers
26 | from alphafold.data import templates
27 | from alphafold.data.tools import hhblits
28 | from alphafold.data.tools import hhsearch
29 | from alphafold.data.tools import jackhmmer
30 |
31 | FeatureDict = Mapping[str, np.ndarray]
32 |
33 |
34 | def make_sequence_features(
35 | sequence: str, description: str, num_res: int) -> FeatureDict:
36 | """Constructs a feature dict of sequence features."""
37 | features = {}
38 | features['aatype'] = residue_constants.sequence_to_onehot(
39 | sequence=sequence,
40 | mapping=residue_constants.restype_order_with_x,
41 | map_unknown_to_x=True)
42 | features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
43 | features['domain_name'] = np.array([description.encode('utf-8')],
44 | dtype=np.object_)
45 | features['residue_index'] = np.array(range(num_res), dtype=np.int32)
46 | features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
47 | features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_)
48 | return features
49 |
50 |
51 | def make_msa_features(
52 | msas: Sequence[Sequence[str]],
53 | deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
54 | """Constructs a feature dict of MSA features."""
55 | if not msas:
56 | raise ValueError('At least one MSA must be provided.')
57 |
58 | int_msa = []
59 | deletion_matrix = []
60 | seen_sequences = set()
61 | for msa_index, msa in enumerate(msas):
62 | if not msa:
63 | raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
64 | for sequence_index, sequence in enumerate(msa):
65 | if sequence in seen_sequences:
66 | continue
67 | seen_sequences.add(sequence)
68 | int_msa.append(
69 | [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
70 | deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
71 |
72 | num_res = len(msas[0][0])
73 | num_alignments = len(int_msa)
74 | features = {}
75 | features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
76 | features['msa'] = np.array(int_msa, dtype=np.int32)
77 | features['num_alignments'] = np.array(
78 | [num_alignments] * num_res, dtype=np.int32)
79 | return features
80 |
81 |
82 | class DataPipeline:
83 | """Runs the alignment tools and assembles the input features."""
84 |
85 | def __init__(self,
86 | jackhmmer_binary_path: str,
87 | hhblits_binary_path: str,
88 | hhsearch_binary_path: str,
89 | uniref90_database_path: str,
90 | mgnify_database_path: str,
91 | bfd_database_path: Optional[str],
92 | uniclust30_database_path: Optional[str],
93 | small_bfd_database_path: Optional[str],
94 | pdb70_database_path: str,
95 | template_featurizer: templates.TemplateHitFeaturizer,
96 | use_small_bfd: bool,
97 | mgnify_max_hits: int = 501,
98 | uniref_max_hits: int = 10000):
99 | """Constructs a feature dict for a given FASTA file."""
100 | self._use_small_bfd = use_small_bfd
101 | self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
102 | binary_path=jackhmmer_binary_path,
103 | database_path=uniref90_database_path)
104 | if use_small_bfd:
105 | self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
106 | binary_path=jackhmmer_binary_path,
107 | database_path=small_bfd_database_path)
108 | else:
109 | self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
110 | binary_path=hhblits_binary_path,
111 | databases=[bfd_database_path, uniclust30_database_path])
112 | self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
113 | binary_path=jackhmmer_binary_path,
114 | database_path=mgnify_database_path)
115 | self.hhsearch_pdb70_runner = hhsearch.HHSearch(
116 | binary_path=hhsearch_binary_path,
117 | databases=[pdb70_database_path])
118 | self.template_featurizer = template_featurizer
119 | self.mgnify_max_hits = mgnify_max_hits
120 | self.uniref_max_hits = uniref_max_hits
121 |
122 | def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
123 | """Runs alignment tools on the input sequence and creates features."""
124 | with open(input_fasta_path) as f:
125 | input_fasta_str = f.read()
126 | input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
127 | if len(input_seqs) != 1:
128 | raise ValueError(
129 |           f'Exactly one input sequence expected in {input_fasta_path}, found {len(input_seqs)}.')
130 | input_sequence = input_seqs[0]
131 | input_description = input_descs[0]
132 | num_res = len(input_sequence)
133 |
134 | jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
135 | input_fasta_path)[0]
136 | jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
137 | input_fasta_path)[0]
138 |
139 | uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
140 | jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits)
141 | hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
142 |
143 | uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
144 | with open(uniref90_out_path, 'w') as f:
145 | f.write(jackhmmer_uniref90_result['sto'])
146 |
147 | mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
148 | with open(mgnify_out_path, 'w') as f:
149 | f.write(jackhmmer_mgnify_result['sto'])
150 |
151 | uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm(
152 | jackhmmer_uniref90_result['sto'])
153 | mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm(
154 | jackhmmer_mgnify_result['sto'])
155 | hhsearch_hits = parsers.parse_hhr(hhsearch_result)
156 | mgnify_msa = mgnify_msa[:self.mgnify_max_hits]
157 | mgnify_deletion_matrix = mgnify_deletion_matrix[:self.mgnify_max_hits]
158 |
159 | if self._use_small_bfd:
160 | jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
161 | input_fasta_path)[0]
162 |
163 |       bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
164 | with open(bfd_out_path, 'w') as f:
165 | f.write(jackhmmer_small_bfd_result['sto'])
166 |
167 | bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
168 | jackhmmer_small_bfd_result['sto'])
169 | else:
170 | hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(
171 | input_fasta_path)
172 |
173 | bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
174 | with open(bfd_out_path, 'w') as f:
175 | f.write(hhblits_bfd_uniclust_result['a3m'])
176 |
177 | bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(
178 | hhblits_bfd_uniclust_result['a3m'])
179 |
180 | templates_result = self.template_featurizer.get_templates(
181 | query_sequence=input_sequence,
182 | query_pdb_code=None,
183 | query_release_date=None,
184 | hits=hhsearch_hits)
185 |
186 | sequence_features = make_sequence_features(
187 | sequence=input_sequence,
188 | description=input_description,
189 | num_res=num_res)
190 |
191 | msa_features = make_msa_features(
192 | msas=(uniref90_msa, bfd_msa, mgnify_msa),
193 | deletion_matrices=(uniref90_deletion_matrix,
194 | bfd_deletion_matrix,
195 | mgnify_deletion_matrix))
196 |
197 | return {**sequence_features, **msa_features, **templates_result.features}
198 |
--------------------------------------------------------------------------------
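A wiring sketch for DataPipeline in the small-BFD configuration. Every path below is a placeholder, and template_featurizer must be constructed beforehand from templates.TemplateHitFeaturizer (templates.py, not shown here):

    from alphafold.data import pipeline

    data_pipeline = pipeline.DataPipeline(
        jackhmmer_binary_path='/usr/bin/jackhmmer',
        hhblits_binary_path='/usr/bin/hhblits',
        hhsearch_binary_path='/usr/bin/hhsearch',
        uniref90_database_path='/data/uniref90/uniref90.fasta',
        mgnify_database_path='/data/mgnify/mgy_clusters.fa',
        bfd_database_path=None,         # Unused when use_small_bfd=True.
        uniclust30_database_path=None,  # Unused when use_small_bfd=True.
        small_bfd_database_path='/data/small_bfd/sequences.fasta',
        pdb70_database_path='/data/pdb70/pdb70',
        template_featurizer=template_featurizer,  # templates.TemplateHitFeaturizer.
        use_small_bfd=True)

    # Runs the alignment tools, writes raw MSAs to msa_output_dir, and returns
    # the merged sequence, MSA and template FeatureDict.
    features = data_pipeline.process(
        input_fasta_path='query.fasta',  # Must contain exactly one sequence.
        msa_output_dir='output/msas')
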
/alphafold/data/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Python wrappers for third party tools."""
15 |
--------------------------------------------------------------------------------
/alphafold/data/tools/hhblits.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library to run HHblits from Python."""
16 |
17 | import glob
18 | import os
19 | import subprocess
20 | from typing import Any, Mapping, Optional, Sequence
21 |
22 | from absl import logging
23 | from alphafold.data.tools import utils
24 | # Internal import (7716).
25 |
26 |
27 | _HHBLITS_DEFAULT_P = 20
28 | _HHBLITS_DEFAULT_Z = 500
29 |
30 |
31 | class HHBlits:
32 | """Python wrapper of the HHblits binary."""
33 |
34 | def __init__(self,
35 | *,
36 | binary_path: str,
37 | databases: Sequence[str],
38 | n_cpu: int = 4,
39 | n_iter: int = 3,
40 | e_value: float = 0.001,
41 | maxseq: int = 1_000_000,
42 | realign_max: int = 100_000,
43 | maxfilt: int = 100_000,
44 | min_prefilter_hits: int = 1000,
45 | all_seqs: bool = False,
46 | alt: Optional[int] = None,
47 | p: int = _HHBLITS_DEFAULT_P,
48 | z: int = _HHBLITS_DEFAULT_Z):
49 | """Initializes the Python HHblits wrapper.
50 |
51 | Args:
52 | binary_path: The path to the HHblits executable.
53 | databases: A sequence of HHblits database paths. This should be the
54 | common prefix for the database files (i.e. up to but not including
55 | _hhm.ffindex etc.)
56 | n_cpu: The number of CPUs to give HHblits.
57 | n_iter: The number of HHblits iterations.
58 | e_value: The E-value, see HHblits docs for more details.
59 | maxseq: The maximum number of rows in an input alignment. Note that this
60 | parameter is only supported in HHBlits version 3.1 and higher.
61 | realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
62 | maxfilt: Max number of hits allowed to pass the 2nd prefilter.
63 | HHblits default: 20000.
64 | min_prefilter_hits: Min number of hits to pass prefilter.
65 | HHblits default: 100.
66 | all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
67 | HHblits default: False.
68 | alt: Show up to this many alternative alignments.
69 | p: Minimum Prob for a hit to be included in the output hhr file.
70 | HHblits default: 20.
71 | z: Hard cap on number of hits reported in the hhr file.
72 | HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.
73 |
74 | Raises:
75 | RuntimeError: If HHblits binary not found within the path.
76 | """
77 | self.binary_path = binary_path
78 | self.databases = databases
79 |
80 | for database_path in self.databases:
81 | if not glob.glob(database_path + '_*'):
82 | logging.error('Could not find HHBlits database %s', database_path)
83 | raise ValueError(f'Could not find HHBlits database {database_path}')
84 |
85 | self.n_cpu = n_cpu
86 | self.n_iter = n_iter
87 | self.e_value = e_value
88 | self.maxseq = maxseq
89 | self.realign_max = realign_max
90 | self.maxfilt = maxfilt
91 | self.min_prefilter_hits = min_prefilter_hits
92 | self.all_seqs = all_seqs
93 | self.alt = alt
94 | self.p = p
95 | self.z = z
96 |
97 | def query(self, input_fasta_path: str) -> Mapping[str, Any]:
98 | """Queries the database using HHblits."""
99 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
100 | a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
101 |
102 | db_cmd = []
103 | for db_path in self.databases:
104 | db_cmd.append('-d')
105 | db_cmd.append(db_path)
106 | cmd = [
107 | self.binary_path,
108 | '-i', input_fasta_path,
109 | '-cpu', str(self.n_cpu),
110 | '-oa3m', a3m_path,
111 | '-o', '/dev/null',
112 | '-n', str(self.n_iter),
113 | '-e', str(self.e_value),
114 | '-maxseq', str(self.maxseq),
115 | '-realign_max', str(self.realign_max),
116 | '-maxfilt', str(self.maxfilt),
117 | '-min_prefilter_hits', str(self.min_prefilter_hits)]
118 | if self.all_seqs:
119 | cmd += ['-all']
120 | if self.alt:
121 | cmd += ['-alt', str(self.alt)]
122 | if self.p != _HHBLITS_DEFAULT_P:
123 | cmd += ['-p', str(self.p)]
124 | if self.z != _HHBLITS_DEFAULT_Z:
125 | cmd += ['-Z', str(self.z)]
126 | cmd += db_cmd
127 |
128 | logging.info('Launching subprocess "%s"', ' '.join(cmd))
129 | process = subprocess.Popen(
130 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
131 |
132 | with utils.timing('HHblits query'):
133 | stdout, stderr = process.communicate()
134 | retcode = process.wait()
135 |
136 | if retcode:
137 | # Logs have a 15k character limit, so log HHblits error line by line.
138 | logging.error('HHblits failed. HHblits stderr begin:')
139 | for error_line in stderr.decode('utf-8').splitlines():
140 | if error_line.strip():
141 | logging.error(error_line.strip())
142 | logging.error('HHblits stderr end')
143 | raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
144 | stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))
145 |
146 | with open(a3m_path) as f:
147 | a3m = f.read()
148 |
149 | raw_output = dict(
150 | a3m=a3m,
151 | output=stdout,
152 | stderr=stderr,
153 | n_iter=self.n_iter,
154 | e_value=self.e_value)
155 | return raw_output
156 |
--------------------------------------------------------------------------------
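A usage sketch for the wrapper above; the binary path and database prefix are placeholders (each database argument is the common file prefix, not a single file):

    from alphafold.data.tools import hhblits

    runner = hhblits.HHBlits(
        binary_path='/usr/bin/hhblits',
        databases=['/data/uniclust30/uniclust30_2018_08'])
    result = runner.query('query.fasta')
    a3m = result['a3m']  # The MSA in A3M format.
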
/alphafold/data/tools/hhsearch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library to run HHsearch from Python."""
16 |
17 | import glob
18 | import os
19 | import subprocess
20 | from typing import Sequence
21 |
22 | from absl import logging
23 |
24 | from alphafold.data.tools import utils
25 | # Internal import (7716).
26 |
27 |
28 | class HHSearch:
29 | """Python wrapper of the HHsearch binary."""
30 |
31 | def __init__(self,
32 | *,
33 | binary_path: str,
34 | databases: Sequence[str],
35 | maxseq: int = 1_000_000):
36 | """Initializes the Python HHsearch wrapper.
37 |
38 | Args:
39 | binary_path: The path to the HHsearch executable.
40 | databases: A sequence of HHsearch database paths. This should be the
41 | common prefix for the database files (i.e. up to but not including
42 | _hhm.ffindex etc.)
43 | maxseq: The maximum number of rows in an input alignment. Note that this
44 |         parameter is only supported in HH-suite version 3.1 and higher.
45 |
46 | Raises:
47 | RuntimeError: If HHsearch binary not found within the path.
48 | """
49 | self.binary_path = binary_path
50 | self.databases = databases
51 | self.maxseq = maxseq
52 |
53 | for database_path in self.databases:
54 | if not glob.glob(database_path + '_*'):
55 | logging.error('Could not find HHsearch database %s', database_path)
56 | raise ValueError(f'Could not find HHsearch database {database_path}')
57 |
58 | def query(self, a3m: str) -> str:
59 | """Queries the database using HHsearch using a given a3m."""
60 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
61 | input_path = os.path.join(query_tmp_dir, 'query.a3m')
62 | hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
63 | with open(input_path, 'w') as f:
64 | f.write(a3m)
65 |
66 | db_cmd = []
67 | for db_path in self.databases:
68 | db_cmd.append('-d')
69 | db_cmd.append(db_path)
70 | cmd = [self.binary_path,
71 | '-i', input_path,
72 | '-o', hhr_path,
73 | '-maxseq', str(self.maxseq)
74 | ] + db_cmd
75 |
76 | logging.info('Launching subprocess "%s"', ' '.join(cmd))
77 | process = subprocess.Popen(
78 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
79 | with utils.timing('HHsearch query'):
80 | stdout, stderr = process.communicate()
81 | retcode = process.wait()
82 |
83 | if retcode:
84 | # Stderr is truncated to prevent proto size errors in Beam.
85 | raise RuntimeError(
86 | 'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
87 | stdout.decode('utf-8'), stderr[:100_000].decode('utf-8')))
88 |
89 | with open(hhr_path) as f:
90 | hhr = f.read()
91 | return hhr
92 |
--------------------------------------------------------------------------------
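A usage sketch feeding an existing A3M alignment into HHsearch; all paths are placeholders:

    from alphafold.data.tools import hhsearch

    searcher = hhsearch.HHSearch(
        binary_path='/usr/bin/hhsearch',
        databases=['/data/pdb70/pdb70'])  # Common prefix of the pdb70 files.
    with open('query.a3m') as f:
        hhr = searcher.query(f.read())    # Raw hit file (.hhr) as a string.
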
/alphafold/data/tools/hmmbuild.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
16 |
17 | import os
18 | import re
19 | import subprocess
20 |
21 | from absl import logging
22 |
23 | # Internal import (7716).
24 |
25 | from alphafold.data.tools import utils
26 |
27 |
28 | class Hmmbuild(object):
29 | """Python wrapper of the hmmbuild binary."""
30 |
31 | def __init__(self,
32 | *,
33 | binary_path: str,
34 | singlemx: bool = False):
35 | """Initializes the Python hmmbuild wrapper.
36 |
37 | Args:
38 | binary_path: The path to the hmmbuild executable.
39 | singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
40 | just use a common substitution score matrix.
41 |
42 | Raises:
43 | RuntimeError: If hmmbuild binary not found within the path.
44 | """
45 | self.binary_path = binary_path
46 | self.singlemx = singlemx
47 |
48 | def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
49 | """Builds a HHM for the aligned sequences given as an A3M string.
50 |
51 | Args:
52 | sto: A string with the aligned sequences in the Stockholm format.
53 | model_construction: Whether to use reference annotation in the msa to
54 | determine consensus columns ('hand') or default ('fast').
55 |
56 | Returns:
57 | A string with the profile in the HMM format.
58 |
59 | Raises:
60 | RuntimeError: If hmmbuild fails.
61 | """
62 | return self._build_profile(sto, model_construction=model_construction)
63 |
64 | def build_profile_from_a3m(self, a3m: str) -> str:
65 | """Builds a HHM for the aligned sequences given as an A3M string.
66 |
67 | Args:
68 | a3m: A string with the aligned sequences in the A3M format.
69 |
70 | Returns:
71 | A string with the profile in the HMM format.
72 |
73 | Raises:
74 | RuntimeError: If hmmbuild fails.
75 | """
76 | lines = []
77 | for line in a3m.splitlines():
78 | if not line.startswith('>'):
79 | line = re.sub('[a-z]+', '', line) # Remove inserted residues.
80 | lines.append(line + '\n')
81 | msa = ''.join(lines)
82 | return self._build_profile(msa, model_construction='fast')
83 |
84 | def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
85 | """Builds a HMM for the aligned sequences given as an MSA string.
86 |
87 | Args:
88 | msa: A string with the aligned sequences, in A3M or STO format.
89 | model_construction: Whether to use reference annotation in the msa to
90 | determine consensus columns ('hand') or default ('fast').
91 |
92 | Returns:
93 | A string with the profile in the HMM format.
94 |
95 | Raises:
96 | RuntimeError: If hmmbuild fails.
97 | ValueError: If unspecified arguments are provided.
98 | """
99 | if model_construction not in {'hand', 'fast'}:
100 |       raise ValueError(f'Invalid model_construction {model_construction} - only '
101 |                        '"hand" and "fast" are supported.')
102 |
103 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
104 | input_query = os.path.join(query_tmp_dir, 'query.msa')
105 | output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')
106 |
107 | with open(input_query, 'w') as f:
108 | f.write(msa)
109 |
110 | cmd = [self.binary_path]
111 | # If adding flags, we have to do so before the output and input:
112 |
113 | if model_construction == 'hand':
114 | cmd.append(f'--{model_construction}')
115 | if self.singlemx:
116 | cmd.append('--singlemx')
117 | cmd.extend([
118 | '--amino',
119 | output_hmm_path,
120 | input_query,
121 | ])
122 |
123 | logging.info('Launching subprocess %s', cmd)
124 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
125 | stderr=subprocess.PIPE)
126 |
127 | with utils.timing('hmmbuild query'):
128 | stdout, stderr = process.communicate()
129 | retcode = process.wait()
130 | logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
131 | stdout.decode('utf-8'), stderr.decode('utf-8'))
132 |
133 | if retcode:
134 | raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
135 | % (stdout.decode('utf-8'), stderr.decode('utf-8')))
136 |
137 | with open(output_hmm_path, encoding='utf-8') as f:
138 | hmm = f.read()
139 |
140 | return hmm
141 |
--------------------------------------------------------------------------------
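A usage sketch; the binary path and input file are placeholders:

    from alphafold.data.tools import hmmbuild

    builder = hmmbuild.Hmmbuild(binary_path='/usr/bin/hmmbuild')
    with open('alignment.sto') as f:
        hmm = builder.build_profile_from_sto(f.read(), model_construction='fast')
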
/alphafold/data/tools/hmmsearch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A Python wrapper for hmmsearch - search profile against a sequence db."""
16 |
17 | import os
18 | import subprocess
19 | from typing import Optional, Sequence
20 |
21 | from absl import logging
22 |
23 | # Internal import (7716).
24 |
25 | from alphafold.data.tools import utils
26 |
27 |
28 | class Hmmsearch(object):
29 | """Python wrapper of the hmmsearch binary."""
30 |
31 | def __init__(self,
32 | *,
33 | binary_path: str,
34 | database_path: str,
35 | flags: Optional[Sequence[str]] = None):
36 | """Initializes the Python hmmsearch wrapper.
37 |
38 | Args:
39 | binary_path: The path to the hmmsearch executable.
40 | database_path: The path to the hmmsearch database (FASTA format).
41 | flags: List of flags to be used by hmmsearch.
42 |
43 | Raises:
44 |       ValueError: If the hmmsearch database is not found at the given path.
45 | """
46 | self.binary_path = binary_path
47 | self.database_path = database_path
48 | self.flags = flags
49 |
50 | if not os.path.exists(self.database_path):
51 | logging.error('Could not find hmmsearch database %s', database_path)
52 | raise ValueError(f'Could not find hmmsearch database {database_path}')
53 |
54 | def query(self, hmm: str) -> str:
55 | """Queries the database using hmmsearch using a given hmm."""
56 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
57 | hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
58 | a3m_out_path = os.path.join(query_tmp_dir, 'output.a3m')
59 | with open(hmm_input_path, 'w') as f:
60 | f.write(hmm)
61 |
62 | cmd = [
63 | self.binary_path,
64 | '--noali', # Don't include the alignment in stdout.
65 | '--cpu', '8'
66 | ]
67 | # If adding flags, we have to do so before the output and input:
68 | if self.flags:
69 | cmd.extend(self.flags)
70 | cmd.extend([
71 | '-A', a3m_out_path,
72 | hmm_input_path,
73 | self.database_path,
74 | ])
75 |
76 |       logging.info('Launching subprocess %s', cmd)
77 | process = subprocess.Popen(
78 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
79 | with utils.timing(
80 | f'hmmsearch ({os.path.basename(self.database_path)}) query'):
81 | stdout, stderr = process.communicate()
82 | retcode = process.wait()
83 |
84 | if retcode:
85 | raise RuntimeError(
86 | 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
87 | stdout.decode('utf-8'), stderr.decode('utf-8')))
88 |
89 | with open(a3m_out_path) as f:
90 | a3m_out = f.read()
91 |
92 | return a3m_out
93 |
--------------------------------------------------------------------------------
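A hypothetical usage sketch for the wrapper above; the binary and database paths are placeholders, and the HMM string would typically come from the hmmbuild wrapper earlier in this package. The `-E`/`--domE` flags shown are standard HMMER reporting thresholds.

```python
from alphafold.data.tools import hmmsearch

searcher = hmmsearch.Hmmsearch(
    binary_path='/usr/bin/hmmsearch',        # placeholder path
    database_path='/data/pdb_seqres.fasta',  # placeholder path
    flags=['-E', '100', '--domE', '100'])

with open('query.hmm') as f:                 # e.g. output of the HMM builder
    alignment = searcher.query(f.read())
```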
/alphafold/data/tools/jackhmmer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library to run Jackhmmer from Python."""
16 |
17 | from concurrent import futures
18 | import glob
19 | import os
20 | import subprocess
21 | from typing import Any, Callable, Mapping, Optional, Sequence
22 | from urllib import request
23 |
24 | from absl import logging
25 |
26 | from alphafold.data.tools import utils
27 | # Internal import (7716).
28 |
29 |
30 | class Jackhmmer:
31 | """Python wrapper of the Jackhmmer binary."""
32 |
33 | def __init__(self,
34 | *,
35 | binary_path: str,
36 | database_path: str,
37 | n_cpu: int = 8,
38 | n_iter: int = 1,
39 | e_value: float = 0.0001,
40 | z_value: Optional[int] = None,
41 | get_tblout: bool = False,
42 | filter_f1: float = 0.0005,
43 | filter_f2: float = 0.00005,
44 | filter_f3: float = 0.0000005,
45 | incdom_e: Optional[float] = None,
46 | dom_e: Optional[float] = None,
47 | num_streamed_chunks: Optional[int] = None,
48 | streaming_callback: Optional[Callable[[int], None]] = None):
49 | """Initializes the Python Jackhmmer wrapper.
50 |
51 | Args:
52 | binary_path: The path to the jackhmmer executable.
53 | database_path: The path to the jackhmmer database (FASTA format).
54 | n_cpu: The number of CPUs to give Jackhmmer.
55 | n_iter: The number of Jackhmmer iterations.
56 | e_value: The E-value, see Jackhmmer docs for more details.
57 | z_value: The Z-value, see Jackhmmer docs for more details.
58 | get_tblout: Whether to save tblout string.
59 | filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
60 | filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
61 | filter_f3: Forward pre-filter, set to >1.0 to turn off.
62 | incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
63 | round.
64 | dom_e: Domain e-value criteria for inclusion in tblout.
65 | num_streamed_chunks: Number of database chunks to stream over.
66 | streaming_callback: Callback function run after each chunk iteration with
67 | the iteration number as argument.
68 | """
69 | self.binary_path = binary_path
70 | self.database_path = database_path
71 | self.num_streamed_chunks = num_streamed_chunks
72 |
73 | if not os.path.exists(self.database_path) and num_streamed_chunks is None:
74 | logging.error('Could not find Jackhmmer database %s', database_path)
75 | raise ValueError(f'Could not find Jackhmmer database {database_path}')
76 |
77 | self.n_cpu = n_cpu
78 | self.n_iter = n_iter
79 | self.e_value = e_value
80 | self.z_value = z_value
81 | self.filter_f1 = filter_f1
82 | self.filter_f2 = filter_f2
83 | self.filter_f3 = filter_f3
84 | self.incdom_e = incdom_e
85 | self.dom_e = dom_e
86 | self.get_tblout = get_tblout
87 | self.streaming_callback = streaming_callback
88 |
89 | def _query_chunk(self, input_fasta_path: str, database_path: str
90 | ) -> Mapping[str, Any]:
91 | """Queries the database chunk using Jackhmmer."""
92 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
93 | sto_path = os.path.join(query_tmp_dir, 'output.sto')
94 |
95 |       # F1/F2/F3 are the expected proportions of sequences that pass each of
96 |       # the filtering stages (which get progressively more expensive);
97 |       # reducing them speeds up the pipeline at the expense of sensitivity.
98 |       # They are currently set very low to make querying Mgnify run in a
99 |       # reasonable amount of time.
100 | cmd_flags = [
101 | # Don't pollute stdout with Jackhmmer output.
102 | '-o', '/dev/null',
103 | '-A', sto_path,
104 | '--noali',
105 | '--F1', str(self.filter_f1),
106 | '--F2', str(self.filter_f2),
107 | '--F3', str(self.filter_f3),
108 | '--incE', str(self.e_value),
109 |           # Report only sequences with E-value <= e_value in per-sequence output.
110 | '-E', str(self.e_value),
111 | '--cpu', str(self.n_cpu),
112 | '-N', str(self.n_iter)
113 | ]
114 | if self.get_tblout:
115 | tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
116 | cmd_flags.extend(['--tblout', tblout_path])
117 |
118 | if self.z_value:
119 | cmd_flags.extend(['-Z', str(self.z_value)])
120 |
121 | if self.dom_e is not None:
122 | cmd_flags.extend(['--domE', str(self.dom_e)])
123 |
124 | if self.incdom_e is not None:
125 | cmd_flags.extend(['--incdomE', str(self.incdom_e)])
126 |
127 | cmd = [self.binary_path] + cmd_flags + [input_fasta_path,
128 | database_path]
129 |
130 | logging.info('Launching subprocess "%s"', ' '.join(cmd))
131 | process = subprocess.Popen(
132 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
133 | with utils.timing(
134 | f'Jackhmmer ({os.path.basename(database_path)}) query'):
135 | _, stderr = process.communicate()
136 | retcode = process.wait()
137 |
138 | if retcode:
139 | raise RuntimeError(
140 | 'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8'))
141 |
142 | # Get e-values for each target name
143 | tbl = ''
144 | if self.get_tblout:
145 | with open(tblout_path) as f:
146 | tbl = f.read()
147 |
148 | with open(sto_path) as f:
149 | sto = f.read()
150 |
151 | raw_output = dict(
152 | sto=sto,
153 | tbl=tbl,
154 | stderr=stderr,
155 | n_iter=self.n_iter,
156 | e_value=self.e_value)
157 |
158 | return raw_output
159 |
160 | def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
161 | """Queries the database using Jackhmmer."""
162 | if self.num_streamed_chunks is None:
163 | return [self._query_chunk(input_fasta_path, self.database_path)]
164 |
165 | db_basename = os.path.basename(self.database_path)
166 | db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
167 | db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}'
168 |
169 | # Remove existing files to prevent OOM
170 | for f in glob.glob(db_local_chunk('[0-9]*')):
171 | try:
172 | os.remove(f)
173 | except OSError:
174 |         logging.warning('OSError while deleting %s', f)
175 |
176 | # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
177 | with futures.ThreadPoolExecutor(max_workers=2) as executor:
178 | chunked_output = []
179 | for i in range(1, self.num_streamed_chunks + 1):
180 | # Copy the chunk locally
181 | if i == 1:
182 | future = executor.submit(
183 | request.urlretrieve, db_remote_chunk(i), db_local_chunk(i))
184 | if i < self.num_streamed_chunks:
185 | next_future = executor.submit(
186 | request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1))
187 |
188 | # Run Jackhmmer with the chunk
189 | future.result()
190 | chunked_output.append(
191 | self._query_chunk(input_fasta_path, db_local_chunk(i)))
192 |
193 | # Remove the local copy of the chunk
194 | os.remove(db_local_chunk(i))
195 | future = next_future
196 | if self.streaming_callback:
197 | self.streaming_callback(i)
198 | return chunked_output
199 |
--------------------------------------------------------------------------------
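The streaming path in `query` overlaps network and compute: chunk i+1 downloads on a worker thread while Jackhmmer processes chunk i. A stripped-down, runnable sketch of the same pattern, with `fetch` and `process` as stand-ins for `request.urlretrieve` and `_query_chunk`:

```python
from concurrent import futures

def fetch(i):
    return f'chunk-{i}'            # stand-in for request.urlretrieve

def process(chunk):
    return f'result for {chunk}'   # stand-in for _query_chunk

def run(num_chunks):
    results = []
    with futures.ThreadPoolExecutor(max_workers=2) as executor:
        future = executor.submit(fetch, 1)
        for i in range(1, num_chunks + 1):
            if i < num_chunks:
                next_future = executor.submit(fetch, i + 1)  # prefetch i+1
            results.append(process(future.result()))         # block, then work
            if i < num_chunks:
                future = next_future
    return results

print(run(3))
```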
/alphafold/data/tools/kalign.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A Python wrapper for Kalign."""
16 | import os
17 | import subprocess
18 | from typing import Sequence
19 |
20 | from absl import logging
21 |
22 | from alphafold.data.tools import utils
23 | # Internal import (7716).
24 |
25 |
26 | def _to_a3m(sequences: Sequence[str]) -> str:
27 | """Converts sequences to an a3m file."""
28 | names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
29 | a3m = []
30 | for sequence, name in zip(sequences, names):
31 | a3m.append(u'>' + name + u'\n')
32 | a3m.append(sequence + u'\n')
33 | return ''.join(a3m)
34 |
35 |
36 | class Kalign:
37 | """Python wrapper of the Kalign binary."""
38 |
39 | def __init__(self, *, binary_path: str):
40 | """Initializes the Python Kalign wrapper.
41 |
42 | Args:
43 | binary_path: The path to the Kalign binary.
44 |
45 | Raises:
46 |       RuntimeError: If the Kalign binary is not found at the given path.
47 | """
48 | self.binary_path = binary_path
49 |
50 | def align(self, sequences: Sequence[str]) -> str:
51 | """Aligns the sequences and returns the alignment in A3M string.
52 |
53 | Args:
54 | sequences: A list of query sequence strings. The sequences have to be at
55 | least 6 residues long (Kalign requires this). Note that the order in
56 |         which you give the sequences might alter the output slightly, as a
57 |         different alignment tree might be constructed.
58 |
59 | Returns:
60 | A string with the alignment in a3m format.
61 |
62 | Raises:
63 | RuntimeError: If Kalign fails.
64 | ValueError: If any of the sequences is less than 6 residues long.
65 | """
66 | logging.info('Aligning %d sequences', len(sequences))
67 |
68 | for s in sequences:
69 | if len(s) < 6:
70 | raise ValueError('Kalign requires all sequences to be at least 6 '
71 | 'residues long. Got %s (%d residues).' % (s, len(s)))
72 |
73 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
74 | input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
75 | output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
76 |
77 | with open(input_fasta_path, 'w') as f:
78 | f.write(_to_a3m(sequences))
79 |
80 | cmd = [
81 | self.binary_path,
82 | '-i', input_fasta_path,
83 | '-o', output_a3m_path,
84 | '-format', 'fasta',
85 | ]
86 |
87 | logging.info('Launching subprocess "%s"', ' '.join(cmd))
88 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
89 | stderr=subprocess.PIPE)
90 |
91 | with utils.timing('Kalign query'):
92 | stdout, stderr = process.communicate()
93 | retcode = process.wait()
94 | logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n',
95 | stdout.decode('utf-8'), stderr.decode('utf-8'))
96 |
97 | if retcode:
98 | raise RuntimeError('Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n'
99 | % (stdout.decode('utf-8'), stderr.decode('utf-8')))
100 |
101 | with open(output_a3m_path) as f:
102 | a3m = f.read()
103 |
104 | return a3m
105 |
--------------------------------------------------------------------------------
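For reference, what the (module-private) `_to_a3m` helper above produces for a pair of illustrative sequences: a FASTA-style block with auto-generated names, which Kalign then reads via `-i`.

```python
from alphafold.data.tools import kalign

print(kalign._to_a3m(['MKVLAA', 'MKVIAA']), end='')
# >sequence 1
# MKVLAA
# >sequence 2
# MKVIAA
```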
/alphafold/data/tools/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Common utilities for data pipeline tools."""
15 | import contextlib
16 | import shutil
17 | import tempfile
18 | import time
19 | from typing import Optional
20 |
21 | from absl import logging
22 |
23 |
24 | @contextlib.contextmanager
25 | def tmpdir_manager(base_dir: Optional[str] = None):
26 | """Context manager that deletes a temporary directory on exit."""
27 | tmpdir = tempfile.mkdtemp(dir=base_dir)
28 | try:
29 | yield tmpdir
30 | finally:
31 | shutil.rmtree(tmpdir, ignore_errors=True)
32 |
33 |
34 | @contextlib.contextmanager
35 | def timing(msg: str):
36 | logging.info('Started %s', msg)
37 | tic = time.time()
38 | yield
39 | toc = time.time()
40 | logging.info('Finished %s in %.3f seconds', msg, toc - tic)
41 |
--------------------------------------------------------------------------------
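A small usage sketch for the two context managers above: the temporary directory is removed on exit, even when an exception is raised, and the elapsed time is logged via absl.

```python
import os

from alphafold.data.tools import utils

with utils.timing('scratch work'):
    with utils.tmpdir_manager(base_dir='/tmp') as tmp_dir:
        scratch_path = os.path.join(tmp_dir, 'scratch.txt')
        with open(scratch_path, 'w') as f:
            f.write('intermediate data')
# tmp_dir is gone here; 'Started/Finished scratch work' was logged.
```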
/alphafold/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Alphafold model."""
15 |
--------------------------------------------------------------------------------
/alphafold/model/all_atom_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for all_atom."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import numpy as np
20 | from alphafold.model import all_atom
21 | from alphafold.model import r3
22 |
23 | L1_CLAMP_DISTANCE = 10
24 |
25 |
26 | def get_identity_rigid(shape):
27 | """Returns identity rigid transform."""
28 |
29 | ones = np.ones(shape)
30 | zeros = np.zeros(shape)
31 | rot = r3.Rots(ones, zeros, zeros,
32 | zeros, ones, zeros,
33 | zeros, zeros, ones)
34 | trans = r3.Vecs(zeros, zeros, zeros)
35 | return r3.Rigids(rot, trans)
36 |
37 |
38 | def get_global_rigid_transform(rot_angle, translation, bcast_dims):
39 | """Returns rigid transform that globally rotates/translates by same amount."""
40 |
41 | rot_angle = np.asarray(rot_angle)
42 | translation = np.asarray(translation)
43 | if bcast_dims:
44 | for _ in range(bcast_dims):
45 | rot_angle = np.expand_dims(rot_angle, 0)
46 | translation = np.expand_dims(translation, 0)
47 | sin_angle = np.sin(np.deg2rad(rot_angle))
48 | cos_angle = np.cos(np.deg2rad(rot_angle))
49 | ones = np.ones_like(sin_angle)
50 | zeros = np.zeros_like(sin_angle)
51 | rot = r3.Rots(ones, zeros, zeros,
52 | zeros, cos_angle, -sin_angle,
53 | zeros, sin_angle, cos_angle)
54 | trans = r3.Vecs(translation[..., 0], translation[..., 1], translation[..., 2])
55 | return r3.Rigids(rot, trans)
56 |
57 |
58 | class AllAtomTest(parameterized.TestCase, absltest.TestCase):
59 |
60 | @parameterized.named_parameters(
61 | ('identity', 0, [0, 0, 0]),
62 | ('rot_90', 90, [0, 0, 0]),
63 | ('trans_10', 0, [0, 0, 10]),
64 | ('rot_174_trans_1', 174, [1, 1, 1]))
65 | def test_frame_aligned_point_error_perfect_on_global_transform(
66 | self, rot_angle, translation):
67 | """Tests global transform between target and preds gives perfect score."""
68 |
69 | # pylint: disable=bad-whitespace
70 | target_positions = np.array(
71 | [[ 21.182, 23.095, 19.731],
72 | [ 22.055, 20.919, 17.294],
73 | [ 24.599, 20.005, 15.041],
74 | [ 25.567, 18.214, 12.166],
75 | [ 28.063, 17.082, 10.043],
76 | [ 28.779, 15.569, 6.985],
77 | [ 30.581, 13.815, 4.612],
78 | [ 29.258, 12.193, 2.296]])
79 | # pylint: enable=bad-whitespace
80 | global_rigid_transform = get_global_rigid_transform(
81 | rot_angle, translation, 1)
82 |
83 | target_positions = r3.vecs_from_tensor(target_positions)
84 | pred_positions = r3.rigids_mul_vecs(
85 | global_rigid_transform, target_positions)
86 | positions_mask = np.ones(target_positions.x.shape[0])
87 |
88 | target_frames = get_identity_rigid(10)
89 | pred_frames = r3.rigids_mul_rigids(global_rigid_transform, target_frames)
90 | frames_mask = np.ones(10)
91 |
92 | fape = all_atom.frame_aligned_point_error(
93 | pred_frames, target_frames, frames_mask, pred_positions,
94 | target_positions, positions_mask, L1_CLAMP_DISTANCE,
95 | L1_CLAMP_DISTANCE, epsilon=0)
96 | self.assertAlmostEqual(fape, 0.)
97 |
98 | @parameterized.named_parameters(
99 | ('identity',
100 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
101 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
102 | 0.),
103 | ('shift_2.5',
104 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
105 | [[2.5, 0, 0], [7.5, 0, 0], [7.5, 0, 0]],
106 | 0.25),
107 | ('shift_5',
108 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
109 | [[5, 0, 0], [10, 0, 0], [15, 0, 0]],
110 | 0.5),
111 | ('shift_10',
112 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
113 | [[10, 0, 0], [15, 0, 0], [0, 0, 0]],
114 | 1.))
115 |   def test_frame_aligned_point_error_matches_expected(
116 |       self, target_positions, pred_positions, expected_fape):
117 | """Tests score matches expected."""
118 |
119 | target_frames = get_identity_rigid(2)
120 | pred_frames = target_frames
121 | frames_mask = np.ones(2)
122 |
123 | target_positions = r3.vecs_from_tensor(np.array(target_positions))
124 | pred_positions = r3.vecs_from_tensor(np.array(pred_positions))
125 | positions_mask = np.ones(target_positions.x.shape[0])
126 |
127 |     fape = all_atom.frame_aligned_point_error(
128 |         pred_frames, target_frames, frames_mask, pred_positions,
129 |         target_positions, positions_mask, L1_CLAMP_DISTANCE,
130 |         L1_CLAMP_DISTANCE, epsilon=0)
131 |     self.assertAlmostEqual(fape, expected_fape)
132 |
133 |
134 | if __name__ == '__main__':
135 | absltest.main()
136 |
--------------------------------------------------------------------------------
/alphafold/model/common_modules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A collection of common Haiku modules for use in protein folding."""
16 | import haiku as hk
17 | import jax.numpy as jnp
18 |
19 |
20 | class Linear(hk.Module):
21 | """Protein folding specific Linear Module.
22 |
23 | This differs from the standard Haiku Linear in a few ways:
24 | * It supports inputs of arbitrary rank
25 | * Initializers are specified by strings
26 | """
27 |
28 | def __init__(self,
29 | num_output: int,
30 | initializer: str = 'linear',
31 | use_bias: bool = True,
32 | bias_init: float = 0.,
33 | name: str = 'linear'):
34 | """Constructs Linear Module.
35 |
36 | Args:
37 | num_output: number of output channels.
38 | initializer: What initializer to use, should be one of {'linear', 'relu',
39 | 'zeros'}
40 |       use_bias: Whether to include a trainable bias.
41 | bias_init: Value used to initialize bias.
42 | name: name of module, used for name scopes.
43 | """
44 |
45 | super().__init__(name=name)
46 | self.num_output = num_output
47 | self.initializer = initializer
48 | self.use_bias = use_bias
49 | self.bias_init = bias_init
50 |
51 | def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
52 | """Connects Module.
53 |
54 | Args:
55 | inputs: Tensor of shape [..., num_channel]
56 |
57 | Returns:
58 | output of shape [..., num_output]
59 | """
60 | n_channels = int(inputs.shape[-1])
61 |
62 | weight_shape = [n_channels, self.num_output]
63 | if self.initializer == 'linear':
64 | weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.)
65 | elif self.initializer == 'relu':
66 | weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.)
67 | elif self.initializer == 'zeros':
68 | weight_init = hk.initializers.Constant(0.0)
69 |
70 | weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
71 | weight_init)
72 |
73 |     # This is equivalent to einsum('...c,cd->...d', inputs, weights),
74 |     # but turns out to be slightly faster.
75 | inputs = jnp.swapaxes(inputs, -1, -2)
76 | output = jnp.einsum('...cb,cd->...db', inputs, weights)
77 | output = jnp.swapaxes(output, -1, -2)
78 |
79 | if self.use_bias:
80 | bias = hk.get_parameter('bias', [self.num_output], inputs.dtype,
81 | hk.initializers.Constant(self.bias_init))
82 | output += bias
83 |
84 | return output
85 |
--------------------------------------------------------------------------------
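The comment in `Linear.__call__` claims the swapaxes/einsum route is equivalent to a direct `'...c,cd->...d'` contraction; a quick numerical check of that claim (the shapes are chosen arbitrarily):

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.RandomState(0)
inputs = jnp.asarray(rng.randn(4, 7, 3))   # [..., num_channel=3]
weights = jnp.asarray(rng.randn(3, 5))     # [num_channel, num_output]

# Direct contraction.
direct = jnp.einsum('...c,cd->...d', inputs, weights)

# The route used in Linear.__call__.
swapped = jnp.swapaxes(inputs, -1, -2)
routed = jnp.einsum('...cb,cd->...db', swapped, weights)
routed = jnp.swapaxes(routed, -1, -2)

np.testing.assert_allclose(direct, routed, rtol=1e-5, atol=1e-5)
```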
/alphafold/model/data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Convenience functions for reading data."""
16 |
17 | import io
18 | import os
19 | from typing import List
20 |
21 | import haiku as hk
22 | import numpy as np
23 |
24 | from alphafold.model import utils
25 | # Internal import (7716).
26 |
27 |
28 | def casp_model_names(data_dir: str) -> List[str]:
29 | params = os.listdir(os.path.join(data_dir, 'params'))
30 | return [os.path.splitext(filename)[0] for filename in params]
31 |
32 |
33 | def get_model_haiku_params(model_name: str, data_dir: str) -> hk.Params:
34 | """Get the Haiku parameters from a model name."""
35 |
36 | path = os.path.join(data_dir, 'params', f'params_{model_name}.npz')
37 |
38 | with open(path, 'rb') as f:
39 | params = np.load(io.BytesIO(f.read()), allow_pickle=False)
40 |
41 | return utils.flat_params_to_haiku(params)
42 |
--------------------------------------------------------------------------------
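A hypothetical call into the loader above; `DATA_DIR` is a placeholder and must contain `params/params_<model_name>.npz` files produced by the download scripts.

```python
from alphafold.model import data

DATA_DIR = '/path/to/alphafold_data'  # placeholder
model_params = data.get_model_haiku_params(model_name='model_1',
                                           data_dir=DATA_DIR)
print(sorted(data.casp_model_names(DATA_DIR)))  # available model names
```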
/alphafold/model/features.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Code to generate processed features."""
16 | import copy
17 | from typing import List, Mapping, Tuple
18 |
19 | import ml_collections
20 | import numpy as np
21 | import tensorflow.compat.v1 as tf
22 |
23 | from alphafold.model.tf import input_pipeline
24 | from alphafold.model.tf import proteins_dataset
25 |
26 | FeatureDict = Mapping[str, np.ndarray]
27 |
28 |
29 | def make_data_config(
30 | config: ml_collections.ConfigDict,
31 | num_res: int,
32 | ) -> Tuple[ml_collections.ConfigDict, List[str]]:
33 | """Makes a data config for the input pipeline."""
34 | cfg = copy.deepcopy(config.data)
35 |
36 | feature_names = cfg.common.unsupervised_features
37 | if cfg.common.use_templates:
38 | feature_names += cfg.common.template_features
39 |
40 | with cfg.unlocked():
41 | cfg.eval.crop_size = num_res
42 |
43 | return cfg, feature_names
44 |
45 |
46 | def tf_example_to_features(tf_example: tf.train.Example,
47 | config: ml_collections.ConfigDict,
48 | random_seed: int = 0) -> FeatureDict:
49 | """Converts tf_example to numpy feature dictionary."""
50 | num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0])
51 | cfg, feature_names = make_data_config(config, num_res=num_res)
52 |
53 | if 'deletion_matrix_int' in set(tf_example.features.feature):
54 | deletion_matrix_int = (
55 | tf_example.features.feature['deletion_matrix_int'].int64_list.value)
56 | feat = tf.train.Feature(float_list=tf.train.FloatList(
57 | value=map(float, deletion_matrix_int)))
58 | tf_example.features.feature['deletion_matrix'].CopyFrom(feat)
59 | del tf_example.features.feature['deletion_matrix_int']
60 |
61 | tf_graph = tf.Graph()
62 | with tf_graph.as_default(), tf.device('/device:CPU:0'):
63 | tf.compat.v1.set_random_seed(random_seed)
64 | tensor_dict = proteins_dataset.create_tensor_dict(
65 | raw_data=tf_example.SerializeToString(),
66 | features=feature_names)
67 | processed_batch = input_pipeline.process_tensors_from_config(
68 | tensor_dict, cfg)
69 |
70 | tf_graph.finalize()
71 |
72 | with tf.Session(graph=tf_graph) as sess:
73 | features = sess.run(processed_batch)
74 |
75 | return {k: v for k, v in features.items() if v.dtype != 'O'}
76 |
77 |
78 | def np_example_to_features(np_example: FeatureDict,
79 | config: ml_collections.ConfigDict,
80 | random_seed: int = 0) -> FeatureDict:
81 | """Preprocesses NumPy feature dict using TF pipeline."""
82 | np_example = dict(np_example)
83 | num_res = int(np_example['seq_length'][0])
84 | cfg, feature_names = make_data_config(config, num_res=num_res)
85 |
86 | if 'deletion_matrix_int' in np_example:
87 | np_example['deletion_matrix'] = (
88 | np_example.pop('deletion_matrix_int').astype(np.float32))
89 |
90 | tf_graph = tf.Graph()
91 | with tf_graph.as_default(), tf.device('/device:CPU:0'):
92 | tf.compat.v1.set_random_seed(random_seed)
93 | tensor_dict = proteins_dataset.np_to_tensor_dict(
94 | np_example=np_example, features=feature_names)
95 |
96 | processed_batch = input_pipeline.process_tensors_from_config(
97 | tensor_dict, cfg)
98 |
99 | tf_graph.finalize()
100 |
101 | with tf.Session(graph=tf_graph) as sess:
102 | features = sess.run(processed_batch)
103 |
104 | return {k: v for k, v in features.items() if v.dtype != 'O'}
105 |
--------------------------------------------------------------------------------
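Both converters above normalize the optional `deletion_matrix_int` feature to a float `deletion_matrix` before the TF pipeline runs; a self-contained check of that step for the NumPy path, with an invented two-residue example:

```python
import numpy as np

np_example = {
    'deletion_matrix_int': np.array([[0, 2], [1, 0]], dtype=np.int64),
}
if 'deletion_matrix_int' in np_example:
    np_example['deletion_matrix'] = (
        np_example.pop('deletion_matrix_int').astype(np.float32))

print(np_example['deletion_matrix'].dtype)   # float32
print('deletion_matrix_int' in np_example)   # False
```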
/alphafold/model/lddt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """lDDT protein distance score."""
16 | import jax.numpy as jnp
17 |
18 |
19 | def lddt(predicted_points,
20 | true_points,
21 | true_points_mask,
22 | cutoff=15.,
23 | per_residue=False):
24 | """Measure (approximate) lDDT for a batch of coordinates.
25 |
26 | lDDT reference:
27 | Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local
28 | superposition-free score for comparing protein structures and models using
29 | distance difference tests. Bioinformatics 29, 2722–2728 (2013).
30 |
31 | lDDT is a measure of the difference between the true distance matrix and the
32 | distance matrix of the predicted points. The difference is computed only on
33 | points closer than cutoff *in the true structure*.
34 |
35 | This function does not compute the exact lDDT value that the original paper
36 | describes because it does not include terms for physical feasibility
37 | (e.g. bond length violations). Therefore this is only an approximate
38 | lDDT score.
39 |
40 | Args:
41 | predicted_points: (batch, length, 3) array of predicted 3D points
42 | true_points: (batch, length, 3) array of true 3D points
43 | true_points_mask: (batch, length, 1) binary-valued float array. This mask
44 | should be 1 for points that exist in the true points.
45 | cutoff: Maximum distance for a pair of points to be included
46 | per_residue: If true, return score for each residue. Note that the overall
47 | lDDT is not exactly the mean of the per_residue lDDT's because some
48 | residues have more contacts than others.
49 |
50 | Returns:
51 | An (approximate, see above) lDDT score in the range 0-1.
52 | """
53 |
54 | assert len(predicted_points.shape) == 3
55 | assert predicted_points.shape[-1] == 3
56 | assert true_points_mask.shape[-1] == 1
57 | assert len(true_points_mask.shape) == 3
58 |
59 | # Compute true and predicted distance matrices.
60 | dmat_true = jnp.sqrt(1e-10 + jnp.sum(
61 | (true_points[:, :, None] - true_points[:, None, :])**2, axis=-1))
62 |
63 | dmat_predicted = jnp.sqrt(1e-10 + jnp.sum(
64 | (predicted_points[:, :, None] -
65 | predicted_points[:, None, :])**2, axis=-1))
66 |
67 | dists_to_score = (
68 | (dmat_true < cutoff).astype(jnp.float32) * true_points_mask *
69 | jnp.transpose(true_points_mask, [0, 2, 1]) *
70 | (1. - jnp.eye(dmat_true.shape[1])) # Exclude self-interaction.
71 | )
72 |
73 |   # Absolute distance error between corresponding pairs of points.
74 | dist_l1 = jnp.abs(dmat_true - dmat_predicted)
75 |
76 | # True lDDT uses a number of fixed bins.
77 | # We ignore the physical plausibility correction to lDDT, though.
78 | score = 0.25 * ((dist_l1 < 0.5).astype(jnp.float32) +
79 | (dist_l1 < 1.0).astype(jnp.float32) +
80 | (dist_l1 < 2.0).astype(jnp.float32) +
81 | (dist_l1 < 4.0).astype(jnp.float32))
82 |
83 | # Normalize over the appropriate axes.
84 | reduce_axes = (-1,) if per_residue else (-2, -1)
85 | norm = 1. / (1e-10 + jnp.sum(dists_to_score, axis=reduce_axes))
86 | score = norm * (1e-10 + jnp.sum(dists_to_score * score, axis=reduce_axes))
87 |
88 | return score
89 |
--------------------------------------------------------------------------------
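A worked example of the binned scoring above (it mirrors the `one_a_dist` case in `lddt_test.py` below): a single residue pair whose distance error is just under 1 Å passes three of the four thresholds, so the score is 0.25 * 3 = 0.75.

```python
import numpy as np

from alphafold.model import lddt

true_pos = np.array([[[0., 0., 0.], [5., 0., 0.]]], dtype=np.float32)
pred_pos = np.array([[[0., 0., 0.], [6. - 1e-5, 0., 0.]]], dtype=np.float32)
mask = np.ones((1, 2, 1), dtype=np.float32)

score = lddt.lddt(pred_pos, true_pos, mask, cutoff=15., per_residue=False)
print(score)  # ~0.75: |d_true - d_pred| is just under the 1, 2 and 4 A bins.
```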
/alphafold/model/lddt_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for lddt."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | import numpy as np
20 | from alphafold.model import lddt
21 |
22 |
23 | class LddtTest(parameterized.TestCase, absltest.TestCase):
24 |
25 | @parameterized.named_parameters(
26 | ('same',
27 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
28 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
29 | [1, 1, 1]),
30 | ('all_shifted',
31 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
32 | [[-1, 0, 0], [4, 0, 0], [9, 0, 0]],
33 | [1, 1, 1]),
34 | ('all_rotated',
35 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
36 | [[0, 0, 0], [0, 5, 0], [0, 10, 0]],
37 | [1, 1, 1]),
38 | ('half_a_dist',
39 | [[0, 0, 0], [5, 0, 0]],
40 | [[0, 0, 0], [5.5-1e-5, 0, 0]],
41 | [1, 1]),
42 | ('one_a_dist',
43 | [[0, 0, 0], [5, 0, 0]],
44 | [[0, 0, 0], [6-1e-5, 0, 0]],
45 | [0.75, 0.75]),
46 | ('two_a_dist',
47 | [[0, 0, 0], [5, 0, 0]],
48 | [[0, 0, 0], [7-1e-5, 0, 0]],
49 | [0.5, 0.5]),
50 | ('four_a_dist',
51 | [[0, 0, 0], [5, 0, 0]],
52 | [[0, 0, 0], [9-1e-5, 0, 0]],
53 | [0.25, 0.25],),
54 | ('five_a_dist',
55 | [[0, 0, 0], [16-1e-5, 0, 0]],
56 | [[0, 0, 0], [11, 0, 0]],
57 | [0, 0]),
58 | ('no_pairs',
59 | [[0, 0, 0], [20, 0, 0]],
60 | [[0, 0, 0], [25-1e-5, 0, 0]],
61 | [1, 1]),
62 | )
63 | def test_lddt(
64 | self, predicted_pos, true_pos, exp_lddt):
65 | predicted_pos = np.array([predicted_pos], dtype=np.float32)
66 | true_points_mask = np.array([[[1]] * len(true_pos)], dtype=np.float32)
67 | true_pos = np.array([true_pos], dtype=np.float32)
68 | cutoff = 15.0
69 | per_residue = True
70 |
71 | result = lddt.lddt(
72 | predicted_pos, true_pos, true_points_mask, cutoff,
73 | per_residue)
74 |
75 | np.testing.assert_almost_equal(result, [exp_lddt], decimal=4)
76 |
77 |
78 | if __name__ == '__main__':
79 | absltest.main()
80 |
--------------------------------------------------------------------------------
/alphafold/model/mapping.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Specialized mapping functions."""
16 |
17 | import functools
18 |
19 | from typing import Any, Callable, Optional, Sequence, Union
20 |
21 | import haiku as hk
22 | import jax
23 | import jax.numpy as jnp
24 |
25 |
26 | PYTREE = Any
27 | PYTREE_JAX_ARRAY = Any
28 |
29 | partial = functools.partial
30 | PROXY = object()
31 |
32 |
33 | def _maybe_slice(array, i, slice_size, axis):
34 | if axis is PROXY:
35 | return array
36 | else:
37 | return jax.lax.dynamic_slice_in_dim(
38 | array, i, slice_size=slice_size, axis=axis)
39 |
40 |
41 | def _maybe_get_size(array, axis):
42 | if axis == PROXY:
43 | return -1
44 | else:
45 | return array.shape[axis]
46 |
47 |
48 | def _expand_axes(axes, values, name='sharded_apply'):
49 | values_tree_def = jax.tree_flatten(values)[1]
50 | flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
51 | # Replace None's with PROXY
52 | flat_axes = [PROXY if x is None else x for x in flat_axes]
53 | return jax.tree_unflatten(values_tree_def, flat_axes)
54 |
55 |
56 | def sharded_map(
57 | fun: Callable[..., PYTREE_JAX_ARRAY],
58 | shard_size: Union[int, None] = 1,
59 | in_axes: Union[int, PYTREE] = 0,
60 | out_axes: Union[int, PYTREE] = 0) -> Callable[..., PYTREE_JAX_ARRAY]:
61 | """Sharded vmap.
62 |
63 | Maps `fun` over axes, in a way similar to vmap, but does so in shards of
64 | `shard_size`. This allows a smooth trade-off between memory usage
65 | (as in a plain map) vs higher throughput (as in a vmap).
66 |
67 | Args:
68 | fun: Function to apply smap transform to.
69 | shard_size: Integer denoting shard size.
70 | in_axes: Either integer or pytree describing which axis to map over for each
71 | input to `fun`, None denotes broadcasting.
72 | out_axes: integer or pytree denoting to what axis in the output the mapped
73 | over axis maps.
74 |
75 | Returns:
76 | function with smap applied.
77 | """
78 | vmapped_fun = hk.vmap(fun, in_axes, out_axes)
79 | return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes)
80 |
81 |
82 | def sharded_apply(
83 | fun: Callable[..., PYTREE_JAX_ARRAY], # pylint: disable=g-bare-generic
84 | shard_size: Union[int, None] = 1,
85 | in_axes: Union[int, PYTREE] = 0,
86 | out_axes: Union[int, PYTREE] = 0,
87 | new_out_axes: bool = False) -> Callable[..., PYTREE_JAX_ARRAY]:
88 | """Sharded apply.
89 |
90 | Applies `fun` over shards to axes, in a way similar to vmap,
91 | but does so in shards of `shard_size`. Shards are stacked after.
92 | This allows a smooth trade-off between
93 | memory usage (as in a plain map) vs higher throughput (as in a vmap).
94 |
95 | Args:
96 | fun: Function to apply smap transform to.
97 | shard_size: Integer denoting shard size.
98 | in_axes: Either integer or pytree describing which axis to map over for each
99 | input to `fun`, None denotes broadcasting.
100 | out_axes: integer or pytree denoting to what axis in the output the mapped
101 | over axis maps.
102 | new_out_axes: whether to stack outputs on new axes. This assumes that the
103 | output sizes for each shard (including the possible remainder shard) are
104 | the same.
105 |
106 | Returns:
107 | function with smap applied.
108 | """
109 | docstr = ('Mapped version of {fun}. Takes similar arguments to {fun} '
110 | 'but with additional array axes over which {fun} is mapped.')
111 | if new_out_axes:
112 | raise NotImplementedError('New output axes not yet implemented.')
113 |
114 | # shard size None denotes no sharding
115 | if shard_size is None:
116 | return fun
117 |
118 | @jax.util.wraps(fun, docstr=docstr)
119 | def mapped_fn(*args):
120 |     # Expand in_axes and determine the loop range.
121 | in_axes_ = _expand_axes(in_axes, args)
122 |
123 | in_sizes = jax.tree_multimap(_maybe_get_size, args, in_axes_)
124 | flat_sizes = jax.tree_flatten(in_sizes)[0]
125 | in_size = max(flat_sizes)
126 | assert all(i in {in_size, -1} for i in flat_sizes)
127 |
128 | num_extra_shards = (in_size - 1) // shard_size
129 |
130 |     # Fix up the last shard size if necessary.
131 | last_shard_size = in_size % shard_size
132 | last_shard_size = shard_size if last_shard_size == 0 else last_shard_size
133 |
134 | def apply_fun_to_slice(slice_start, slice_size):
135 | input_slice = jax.tree_multimap(
136 | lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis
137 | ), args, in_axes_)
138 | return fun(*input_slice)
139 |
140 | remainder_shape_dtype = hk.eval_shape(
141 | partial(apply_fun_to_slice, 0, last_shard_size))
142 | out_dtypes = jax.tree_map(lambda x: x.dtype, remainder_shape_dtype)
143 | out_shapes = jax.tree_map(lambda x: x.shape, remainder_shape_dtype)
144 | out_axes_ = _expand_axes(out_axes, remainder_shape_dtype)
145 |
146 | if num_extra_shards > 0:
147 | regular_shard_shape_dtype = hk.eval_shape(
148 | partial(apply_fun_to_slice, 0, shard_size))
149 | shard_shapes = jax.tree_map(lambda x: x.shape, regular_shard_shape_dtype)
150 |
151 | def make_output_shape(axis, shard_shape, remainder_shape):
152 | return shard_shape[:axis] + (
153 | shard_shape[axis] * num_extra_shards +
154 | remainder_shape[axis],) + shard_shape[axis + 1:]
155 |
156 | out_shapes = jax.tree_multimap(make_output_shape, out_axes_, shard_shapes,
157 | out_shapes)
158 |
159 |     # Calls dynamic_update_slice_in_dim with a different argument order.
160 |     # This is needed since tree_multimap only works with positional arguments.
161 | def dynamic_update_slice_in_dim(full_array, update, axis, i):
162 | return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)
163 |
164 | def compute_shard(outputs, slice_start, slice_size):
165 | slice_out = apply_fun_to_slice(slice_start, slice_size)
166 | update_slice = partial(
167 | dynamic_update_slice_in_dim, i=slice_start)
168 | return jax.tree_multimap(update_slice, outputs, slice_out, out_axes_)
169 |
170 | def scan_iteration(outputs, i):
171 | new_outputs = compute_shard(outputs, i, shard_size)
172 | return new_outputs, ()
173 |
174 | slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size)
175 |
176 | def allocate_buffer(dtype, shape):
177 | return jnp.zeros(shape, dtype=dtype)
178 |
179 | outputs = jax.tree_multimap(allocate_buffer, out_dtypes, out_shapes)
180 |
181 | if slice_starts.shape[0] > 0:
182 | outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)
183 |
184 | if last_shard_size != shard_size:
185 | remainder_start = in_size - last_shard_size
186 | outputs = compute_shard(outputs, remainder_start, last_shard_size)
187 |
188 | return outputs
189 |
190 | return mapped_fn
191 |
192 |
193 | def inference_subbatch(
194 | module: Callable[..., PYTREE_JAX_ARRAY],
195 | subbatch_size: int,
196 | batched_args: Sequence[PYTREE_JAX_ARRAY],
197 | nonbatched_args: Sequence[PYTREE_JAX_ARRAY],
198 | low_memory: bool = True,
199 | input_subbatch_dim: int = 0,
200 | output_subbatch_dim: Optional[int] = None) -> PYTREE_JAX_ARRAY:
201 | """Run through subbatches (like batch apply but with split and concat)."""
202 | assert len(batched_args) > 0 # pylint: disable=g-explicit-length-test
203 |
204 | if not low_memory:
205 | args = list(batched_args) + list(nonbatched_args)
206 | return module(*args)
207 |
208 | if output_subbatch_dim is None:
209 | output_subbatch_dim = input_subbatch_dim
210 |
211 | def run_module(*batched_args):
212 | args = list(batched_args) + list(nonbatched_args)
213 | return module(*args)
214 | sharded_module = sharded_apply(run_module,
215 | shard_size=subbatch_size,
216 | in_axes=input_subbatch_dim,
217 | out_axes=output_subbatch_dim)
218 | return sharded_module(*batched_args)
219 |
--------------------------------------------------------------------------------
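A minimal sketch of `sharded_apply` in action, wrapped in `hk.transform` since the helper relies on `hk.scan`/`hk.eval_shape`; the doubling function is arbitrary, and a length-10 input is processed in two shards of 4 plus a remainder of 2.

```python
import haiku as hk
import jax
import jax.numpy as jnp

from alphafold.model import mapping

def forward(x):
    double = lambda chunk: chunk * 2.0
    return mapping.sharded_apply(double, shard_size=4)(x)

transformed = hk.transform(forward)
x = jnp.arange(10, dtype=jnp.float32)
params = transformed.init(jax.random.PRNGKey(0), x)
y = transformed.apply(params, jax.random.PRNGKey(0), x)
# y == 2 * x, assembled from shards of sizes 4, 4 and 2.
```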
/alphafold/model/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Code for constructing the model."""
16 | from typing import Any, Mapping, Optional, Union
17 |
18 | from absl import logging
19 | import haiku as hk
20 | import jax
21 | import ml_collections
22 | import numpy as np
23 | import tensorflow.compat.v1 as tf
24 | import tree
25 |
26 | from alphafold.common import confidence
27 | from alphafold.model import features
28 | from alphafold.model import modules
29 |
30 |
31 | def get_confidence_metrics(
32 | prediction_result: Mapping[str, Any]) -> Mapping[str, Any]:
33 | """Post processes prediction_result to get confidence metrics."""
34 |
35 | confidence_metrics = {}
36 | confidence_metrics['plddt'] = confidence.compute_plddt(
37 | prediction_result['predicted_lddt']['logits'])
38 | if 'predicted_aligned_error' in prediction_result:
39 | confidence_metrics.update(confidence.compute_predicted_aligned_error(
40 | prediction_result['predicted_aligned_error']['logits'],
41 | prediction_result['predicted_aligned_error']['breaks']))
42 | confidence_metrics['ptm'] = confidence.predicted_tm_score(
43 | prediction_result['predicted_aligned_error']['logits'],
44 | prediction_result['predicted_aligned_error']['breaks'])
45 |
46 | return confidence_metrics
47 |
48 |
49 | class RunModel:
50 | """Container for JAX model."""
51 |
52 | def __init__(self,
53 | config: ml_collections.ConfigDict,
54 | params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None):
55 | self.config = config
56 | self.params = params
57 |
58 | def _forward_fn(batch):
59 | model = modules.AlphaFold(self.config.model)
60 | return model(
61 | batch,
62 | is_training=False,
63 | compute_loss=False,
64 | ensemble_representations=True)
65 |
66 | self.apply = jax.jit(hk.transform(_forward_fn).apply)
67 | self.init = jax.jit(hk.transform(_forward_fn).init)
68 |
69 | def init_params(self, feat: features.FeatureDict, random_seed: int = 0):
70 | """Initializes the model parameters.
71 |
72 | If none were provided when this class was instantiated then the parameters
73 | are randomly initialized.
74 |
75 | Args:
76 | feat: A dictionary of NumPy feature arrays as output by
77 | RunModel.process_features.
78 | random_seed: A random seed to use to initialize the parameters if none
79 | were set when this class was initialized.
80 | """
81 | if not self.params:
82 | # Init params randomly.
83 | rng = jax.random.PRNGKey(random_seed)
84 | self.params = hk.data_structures.to_mutable_dict(
85 | self.init(rng, feat))
86 | logging.warning('Initialized parameters randomly')
87 |
88 | def process_features(
89 | self,
90 | raw_features: Union[tf.train.Example, features.FeatureDict],
91 | random_seed: int) -> features.FeatureDict:
92 | """Processes features to prepare for feeding them into the model.
93 |
94 | Args:
95 | raw_features: The output of the data pipeline either as a dict of NumPy
96 | arrays or as a tf.train.Example.
97 | random_seed: The random seed to use when processing the features.
98 |
99 | Returns:
100 | A dict of NumPy feature arrays suitable for feeding into the model.
101 | """
102 | if isinstance(raw_features, dict):
103 | return features.np_example_to_features(
104 | np_example=raw_features,
105 | config=self.config,
106 | random_seed=random_seed)
107 | else:
108 | return features.tf_example_to_features(
109 | tf_example=raw_features,
110 | config=self.config,
111 | random_seed=random_seed)
112 |
113 | def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct:
114 | self.init_params(feat)
115 | logging.info('Running eval_shape with shape(feat) = %s',
116 | tree.map_structure(lambda x: x.shape, feat))
117 | shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0), feat)
118 | logging.info('Output shape was %s', shape)
119 | return shape
120 |
121 | def predict(self, feat: features.FeatureDict) -> Mapping[str, Any]:
122 | """Makes a prediction by inferencing the model on the provided features.
123 |
124 | Args:
125 | feat: A dictionary of NumPy feature arrays as output by
126 | RunModel.process_features.
127 |
128 | Returns:
129 | A dictionary of model outputs.
130 | """
131 | self.init_params(feat)
132 | logging.info('Running predict with shape(feat) = %s',
133 | tree.map_structure(lambda x: x.shape, feat))
134 | result = self.apply(self.params, jax.random.PRNGKey(0), feat)
135 | # This block is to ensure benchmark timings are accurate. Some blocking is
136 | # already happening when computing get_confidence_metrics, and this ensures
137 | # all outputs are blocked on.
138 | jax.tree_map(lambda x: x.block_until_ready(), result)
139 | result.update(get_confidence_metrics(result))
140 | logging.info('Output shape was %s',
141 | tree.map_structure(lambda x: x.shape, result))
142 | return result
143 |
144 |
--------------------------------------------------------------------------------
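Hypothetical driver code for `RunModel`, along the lines of `run_alphafold.py`; the data directory is a placeholder, `raw_features` stands in for the output of the data pipeline, and `config.model_config` is assumed from `config.py`.

```python
from alphafold.model import config
from alphafold.model import data
from alphafold.model import model

model_config = config.model_config('model_1')          # assumed config.py API
model_params = data.get_model_haiku_params(
    model_name='model_1', data_dir='/path/to/data')    # placeholder path

runner = model.RunModel(model_config, model_params)
processed = runner.process_features(raw_features, random_seed=0)  # stand-in
prediction = runner.predict(processed)
print(prediction['plddt'].mean())                      # mean per-residue confidence
```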
/alphafold/model/prng.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A collection of utilities surrounding PRNG usage in protein folding."""
16 |
17 | import haiku as hk
18 | import jax
19 |
20 |
21 | def safe_dropout(*, tensor, safe_key, rate, is_deterministic, is_training):
22 | if is_training and rate != 0.0 and not is_deterministic:
23 | return hk.dropout(safe_key.get(), rate, tensor)
24 | else:
25 | return tensor
26 |
27 |
28 | class SafeKey:
29 | """Safety wrapper for PRNG keys."""
30 |
31 | def __init__(self, key):
32 | self._key = key
33 | self._used = False
34 |
35 | def _assert_not_used(self):
36 | if self._used:
37 | raise RuntimeError('Random key has been used previously.')
38 |
39 | def get(self):
40 | self._assert_not_used()
41 | self._used = True
42 | return self._key
43 |
44 | def split(self, num_keys=2):
45 | self._assert_not_used()
46 | self._used = True
47 | new_keys = jax.random.split(self._key, num_keys)
48 | return jax.tree_map(SafeKey, tuple(new_keys))
49 |
50 | def duplicate(self, num_keys=2):
51 | self._assert_not_used()
52 | self._used = True
53 | return tuple(SafeKey(self._key) for _ in range(num_keys))
54 |
55 |
56 | def _safe_key_flatten(safe_key):
57 | # Flatten transfers "ownership" to the tree
58 | return (safe_key._key,), safe_key._used # pylint: disable=protected-access
59 |
60 |
61 | def _safe_key_unflatten(aux_data, children):
62 | ret = SafeKey(children[0])
63 | ret._used = aux_data # pylint: disable=protected-access
64 | return ret
65 |
66 |
67 | jax.tree_util.register_pytree_node(
68 | SafeKey, _safe_key_flatten, _safe_key_unflatten)
69 |
70 |
--------------------------------------------------------------------------------
/alphafold/model/prng_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for prng."""
16 |
17 | from absl.testing import absltest
18 | import jax
19 |
20 | from alphafold.model import prng
21 |
22 |
23 | class PrngTest(absltest.TestCase):
24 |
25 | def test_key_reuse(self):
26 |
27 | init_key = jax.random.PRNGKey(42)
28 | safe_key = prng.SafeKey(init_key)
29 | _, safe_key = safe_key.split()
30 |
31 | raw_key = safe_key.get()
32 |
33 | self.assertNotEqual(raw_key[0], init_key[0])
34 | self.assertNotEqual(raw_key[1], init_key[1])
35 |
36 | with self.assertRaises(RuntimeError):
37 | safe_key.get()
38 |
39 | with self.assertRaises(RuntimeError):
40 | safe_key.split()
41 |
42 | with self.assertRaises(RuntimeError):
43 | safe_key.duplicate()
44 |
45 |
46 | if __name__ == '__main__':
47 | absltest.main()
48 |
--------------------------------------------------------------------------------
/alphafold/model/quat_affine_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for quat_affine."""
16 |
17 | from absl import logging
18 | from absl.testing import absltest
19 | import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 | from alphafold.model import quat_affine
23 |
24 | VERBOSE = False
25 | np.set_printoptions(precision=3, suppress=True)
26 |
27 | r2t = quat_affine.rot_list_to_tensor
28 | v2t = quat_affine.vec_list_to_tensor
29 |
30 | q2r = lambda q: r2t(quat_affine.quat_to_rot(q))
31 |
32 |
33 | class QuatAffineTest(absltest.TestCase):
34 |
35 | def _assert_check(self, to_check, tol=1e-5):
36 | for k, (correct, generated) in to_check.items():
37 | if VERBOSE:
38 | logging.info(k)
39 | logging.info('Correct %s', correct)
40 | logging.info('Predicted %s', generated)
41 | self.assertLess(np.max(np.abs(correct - generated)), tol)
42 |
43 | def test_conversion(self):
44 | quat = jnp.array([-2., 5., -1., 4.])
45 |
46 | rotation = jnp.array([
47 | [0.26087, 0.130435, 0.956522],
48 | [-0.565217, -0.782609, 0.26087],
49 | [0.782609, -0.608696, -0.130435]])
50 |
51 | translation = jnp.array([1., -3., 4.])
52 | point = jnp.array([0.7, 3.2, -2.9])
53 |
54 | a = quat_affine.QuatAffine(quat, translation, unstack_inputs=True)
55 | true_new_point = jnp.matmul(rotation, point[:, None])[:, 0] + translation
56 |
57 | self._assert_check({
58 | 'rot': (rotation, r2t(a.rotation)),
59 | 'trans': (translation, v2t(a.translation)),
60 | 'point': (true_new_point,
61 | v2t(a.apply_to_point(jnp.moveaxis(point, -1, 0)))),
62 | # Because of the double cover, we must be careful and compare rotations
63 | 'quat': (q2r(a.quaternion),
64 | q2r(quat_affine.rot_to_quat(a.rotation))),
65 |
66 | })
67 |
68 | def test_double_cover(self):
69 | """Test that -q is the same rotation as q."""
70 | rng = jax.random.PRNGKey(42)
71 | keys = jax.random.split(rng)
72 | q = jax.random.normal(keys[0], (2, 4))
73 | trans = jax.random.normal(keys[1], (2, 3))
74 | a1 = quat_affine.QuatAffine(q, trans, unstack_inputs=True)
75 | a2 = quat_affine.QuatAffine(-q, trans, unstack_inputs=True)
76 |
77 | self._assert_check({
78 | 'rot': (r2t(a1.rotation),
79 | r2t(a2.rotation)),
80 | 'trans': (v2t(a1.translation),
81 | v2t(a2.translation)),
82 | })
83 |
84 | def test_homomorphism(self):
85 | rng = jax.random.PRNGKey(42)
86 | keys = jax.random.split(rng, 4)
87 | vec_q1 = jax.random.normal(keys[0], (2, 3))
88 |
89 | q1 = jnp.concatenate([
90 | jnp.ones_like(vec_q1)[:, :1],
91 | vec_q1], axis=-1)
92 |
93 | q2 = jax.random.normal(keys[1], (2, 4))
94 | t1 = jax.random.normal(keys[2], (2, 3))
95 | t2 = jax.random.normal(keys[3], (2, 3))
96 |
97 | a1 = quat_affine.QuatAffine(q1, t1, unstack_inputs=True)
98 | a2 = quat_affine.QuatAffine(q2, t2, unstack_inputs=True)
99 | a21 = a2.pre_compose(jnp.concatenate([vec_q1, t1], axis=-1))
100 |
101 | rng, key = jax.random.split(rng)
102 | x = jax.random.normal(key, (2, 3))
103 | new_x = a21.apply_to_point(jnp.moveaxis(x, -1, 0))
104 | new_x_apply2 = a2.apply_to_point(a1.apply_to_point(jnp.moveaxis(x, -1, 0)))
105 |
106 | self._assert_check({
107 | 'quat': (q2r(quat_affine.quat_multiply(a2.quaternion, a1.quaternion)),
108 | q2r(a21.quaternion)),
109 | 'rot': (jnp.matmul(r2t(a2.rotation), r2t(a1.rotation)),
110 | r2t(a21.rotation)),
111 | 'point': (v2t(new_x_apply2),
112 | v2t(new_x)),
113 | 'inverse': (x, v2t(a21.invert_point(new_x))),
114 | })
115 |
116 | def test_batching(self):
117 | """Test that affine applies batchwise."""
118 | rng = jax.random.PRNGKey(42)
119 | keys = jax.random.split(rng, 3)
120 | q = jax.random.uniform(keys[0], (5, 2, 4))
121 | t = jax.random.uniform(keys[1], (2, 3))
122 | x = jax.random.uniform(keys[2], (5, 1, 3))
123 |
124 | a = quat_affine.QuatAffine(q, t, unstack_inputs=True)
125 | y = v2t(a.apply_to_point(jnp.moveaxis(x, -1, 0)))
126 |
127 | y_list = []
128 | for i in range(5):
129 | for j in range(2):
130 | a_local = quat_affine.QuatAffine(q[i, j], t[j],
131 | unstack_inputs=True)
132 | y_local = v2t(a_local.apply_to_point(jnp.moveaxis(x[i, 0], -1, 0)))
133 | y_list.append(y_local)
134 | y_combine = jnp.reshape(jnp.stack(y_list, axis=0), (5, 2, 3))
135 |
136 | self._assert_check({
137 | 'batch': (y_combine, y),
138 | 'quat': (q2r(a.quaternion),
139 | q2r(quat_affine.rot_to_quat(a.rotation))),
140 | })
141 |
142 | def assertAllClose(self, a, b, rtol=1e-06, atol=1e-06):
143 | self.assertTrue(np.allclose(a, b, rtol=rtol, atol=atol))
144 |
145 | def assertAllEqual(self, a, b):
146 | self.assertTrue(np.all(np.array(a) == np.array(b)))
147 |
148 |
149 | if __name__ == '__main__':
150 | absltest.main()
151 |
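For intuition on test_double_cover: the unit-quaternion-to-rotation map sends q and -q to the same matrix because every matrix entry is quadratic in the quaternion components. A self-contained NumPy check, using the standard (w, x, y, z) convention and the same quaternion as test_conversion (this is the textbook formula, stated independently of quat_affine):

import numpy as np


def quat_to_rot(q):
  """Unit quaternion (w, x, y, z) to a 3x3 rotation matrix."""
  w, x, y, z = q / np.linalg.norm(q)
  return np.array([
      [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
      [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
      [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)]])


q = np.array([-2., 5., -1., 4.])
# Every entry above is a product of two quaternion components, so negating q
# leaves the matrix unchanged: the double cover of SO(3) by unit quaternions.
assert np.allclose(quat_to_rot(q), quat_to_rot(-q))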
--------------------------------------------------------------------------------
/alphafold/model/tf/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Alphafold model TensorFlow code."""
15 |
--------------------------------------------------------------------------------
/alphafold/model/tf/input_pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Feature pre-processing input pipeline for AlphaFold."""
16 | import tensorflow.compat.v1 as tf
17 | import tree
18 |
19 | from alphafold.model.tf import data_transforms
20 | from alphafold.model.tf import shape_placeholders
21 |
22 | # Pylint gets confused by the curry1 decorator because it changes the number
23 | # of arguments to the function.
24 | # pylint:disable=no-value-for-parameter
25 |
26 |
27 | NUM_RES = shape_placeholders.NUM_RES
28 | NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
29 | NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
30 | NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
31 |
32 |
33 | def nonensembled_map_fns(data_config):
34 | """Input pipeline functions which are not ensembled."""
35 | common_cfg = data_config.common
36 |
37 | map_fns = [
38 | data_transforms.correct_msa_restypes,
39 | data_transforms.add_distillation_flag(False),
40 | data_transforms.cast_64bit_ints,
41 | data_transforms.squeeze_features,
42 | # Keep to not disrupt RNG.
43 | data_transforms.randomly_replace_msa_with_unknown(0.0),
44 | data_transforms.make_seq_mask,
45 | data_transforms.make_msa_mask,
46 | # Compute the HHblits profile if it's not set. This has to be run before
47 | # sampling the MSA.
48 | data_transforms.make_hhblits_profile,
49 | data_transforms.make_random_crop_to_size_seed,
50 | ]
51 | if common_cfg.use_templates:
52 | map_fns.extend([
53 | data_transforms.fix_templates_aatype,
54 | data_transforms.make_template_mask,
55 | data_transforms.make_pseudo_beta('template_')
56 | ])
57 | map_fns.extend([
58 | data_transforms.make_atom14_masks,
59 | ])
60 |
61 | return map_fns
62 |
63 |
64 | def ensembled_map_fns(data_config):
65 | """Input pipeline functions that can be ensembled and averaged."""
66 | common_cfg = data_config.common
67 | eval_cfg = data_config.eval
68 |
69 | map_fns = []
70 |
71 | if common_cfg.reduce_msa_clusters_by_max_templates:
72 | pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
73 | else:
74 | pad_msa_clusters = eval_cfg.max_msa_clusters
75 |
76 | max_msa_clusters = pad_msa_clusters
77 | max_extra_msa = common_cfg.max_extra_msa
78 |
79 | map_fns.append(
80 | data_transforms.sample_msa(
81 | max_msa_clusters,
82 | keep_extra=True))
83 |
84 | if 'masked_msa' in common_cfg:
85 | # Masked MSA should come *before* MSA clustering so that
86 | # the clustering and full MSA profile do not leak information about
87 | # the masked locations and secret corrupted locations.
88 | map_fns.append(
89 | data_transforms.make_masked_msa(common_cfg.masked_msa,
90 | eval_cfg.masked_msa_replace_fraction))
91 |
92 | if common_cfg.msa_cluster_features:
93 | map_fns.append(data_transforms.nearest_neighbor_clusters())
94 | map_fns.append(data_transforms.summarize_clusters())
95 |
96 | # Crop after creating the cluster profiles.
97 | if max_extra_msa:
98 | map_fns.append(data_transforms.crop_extra_msa(max_extra_msa))
99 | else:
100 | map_fns.append(data_transforms.delete_extra_msa)
101 |
102 | map_fns.append(data_transforms.make_msa_feat())
103 |
104 | crop_feats = dict(eval_cfg.feat)
105 |
106 | if eval_cfg.fixed_size:
107 | map_fns.append(data_transforms.select_feat(list(crop_feats)))
108 | map_fns.append(data_transforms.random_crop_to_size(
109 | eval_cfg.crop_size,
110 | eval_cfg.max_templates,
111 | crop_feats,
112 | eval_cfg.subsample_templates))
113 | map_fns.append(data_transforms.make_fixed_size(
114 | crop_feats,
115 | pad_msa_clusters,
116 | common_cfg.max_extra_msa,
117 | eval_cfg.crop_size,
118 | eval_cfg.max_templates))
119 | else:
120 | map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates))
121 |
122 | return map_fns
123 |
124 |
125 | def process_tensors_from_config(tensors, data_config):
126 | """Apply filters and maps to an existing dataset, based on the config."""
127 |
128 | def wrap_ensemble_fn(data, i):
129 | """Function to be mapped over the ensemble dimension."""
130 | d = data.copy()
131 | fns = ensembled_map_fns(data_config)
132 | fn = compose(fns)
133 | d['ensemble_index'] = i
134 | return fn(d)
135 |
136 | eval_cfg = data_config.eval
137 |   tensors = compose(nonensembled_map_fns(data_config))(tensors)
141 |
142 | tensors_0 = wrap_ensemble_fn(tensors, tf.constant(0))
143 | num_ensemble = eval_cfg.num_ensemble
144 | if data_config.common.resample_msa_in_recycling:
145 | # Separate batch per ensembling & recycling step.
146 | num_ensemble *= data_config.common.num_recycle + 1
147 |
148 | if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1:
149 | fn_output_signature = tree.map_structure(
150 | tf.TensorSpec.from_tensor, tensors_0)
151 | tensors = tf.map_fn(
152 | lambda x: wrap_ensemble_fn(tensors, x),
153 | tf.range(num_ensemble),
154 | parallel_iterations=1,
155 | fn_output_signature=fn_output_signature)
156 | else:
157 | tensors = tree.map_structure(lambda x: x[None],
158 | tensors_0)
159 | return tensors
160 |
161 |
162 | @data_transforms.curry1
163 | def compose(x, fs):
164 | for f in fs:
165 | x = f(x)
166 | return x
167 |
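The curry1-decorated compose above is what makes lists of map functions applicable as a single callable: compose(fns) supplies every argument except the first, returning a function of the data dict alone. A minimal self-contained sketch of the pattern (curry1 here is a stand-in mirroring data_transforms.curry1, whose details may differ):

def curry1(f):
  """Supply all arguments but the first."""
  def fc(*args, **kwargs):
    return lambda x: f(x, *args, **kwargs)
  return fc


@curry1
def compose(x, fs):
  for f in fs:
    x = f(x)
  return x


add_one = lambda d: {**d, 'n': d['n'] + 1}
double = lambda d: {**d, 'n': d['n'] * 2}
# compose([...]) is itself a function; applying it threads x through each fn.
assert compose([add_one, double])({'n': 3})['n'] == 8

This is why process_tensors_from_config can write compose(nonensembled_map_fns(data_config))(tensors).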
--------------------------------------------------------------------------------
/alphafold/model/tf/protein_features.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Contains descriptions of various protein features."""
16 | import enum
17 | from typing import Dict, Optional, Sequence, Tuple, Union
18 |
19 | import tensorflow.compat.v1 as tf
20 |
21 | from alphafold.common import residue_constants
22 |
23 | # Type aliases.
24 | FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]]
25 |
26 |
27 | class FeatureType(enum.Enum):
28 | ZERO_DIM = 0 # Shape [x]
29 | ONE_DIM = 1 # Shape [num_res, x]
30 | TWO_DIM = 2 # Shape [num_res, num_res, x]
31 | MSA = 3 # Shape [msa_length, num_res, x]
32 |
33 |
34 | # Placeholder values that will be replaced with their true value at runtime.
35 | NUM_RES = "num residues placeholder"
36 | NUM_SEQ = "length msa placeholder"
37 | NUM_TEMPLATES = "num templates placeholder"
38 | # Sizes of the protein features. NUM_RES and NUM_SEQ are allowed as
39 | # placeholders to be replaced with the number of residues and the number of
40 | # sequences in the multiple sequence alignment, respectively.
41 |
42 |
43 | FEATURES = {
44 | #### Static features of a protein sequence ####
45 | "aatype": (tf.float32, [NUM_RES, 21]),
46 | "between_segment_residues": (tf.int64, [NUM_RES, 1]),
47 | "deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]),
48 | "domain_name": (tf.string, [1]),
49 | "msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]),
50 | "num_alignments": (tf.int64, [NUM_RES, 1]),
51 | "residue_index": (tf.int64, [NUM_RES, 1]),
52 | "seq_length": (tf.int64, [NUM_RES, 1]),
53 | "sequence": (tf.string, [1]),
54 | "all_atom_positions": (tf.float32,
55 | [NUM_RES, residue_constants.atom_type_num, 3]),
56 | "all_atom_mask": (tf.int64, [NUM_RES, residue_constants.atom_type_num]),
57 | "resolution": (tf.float32, [1]),
58 | "template_domain_names": (tf.string, [NUM_TEMPLATES]),
59 | "template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]),
60 | "template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]),
61 | "template_all_atom_positions": (tf.float32, [
62 | NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3
63 | ]),
64 | "template_all_atom_masks": (tf.float32, [
65 | NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1
66 | ]),
67 | }
68 |
69 | FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()}
70 | FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()}
71 |
72 |
73 | def register_feature(name: str,
74 | type_: tf.dtypes.DType,
75 |                      shape_: Sequence[Union[str, int]]):
76 | """Register extra features used in custom datasets."""
77 | FEATURES[name] = (type_, shape_)
78 | FEATURE_TYPES[name] = type_
79 | FEATURE_SIZES[name] = shape_
80 |
81 |
82 | def shape(feature_name: str,
83 | num_residues: int,
84 | msa_length: int,
85 | num_templates: Optional[int] = None,
86 | features: Optional[FeaturesMetadata] = None):
87 | """Get the shape for the given feature name.
88 |
89 | This is near identical to _get_tf_shape_no_placeholders() but with 2
90 | differences:
91 |   * This method does not calculate a single placeholder from the total number
92 |     of elements (e.g. given a shape of [NUM_RES, 3] and a total size of 12,
93 |     this won't deduce that NUM_RES must be 4).
94 | * This method will work with tensors
95 |
96 | Args:
97 | feature_name: String identifier for the feature. If the feature name ends
98 | with "_unnormalized", theis suffix is stripped off.
99 | num_residues: The number of residues in the current domain - some elements
100 | of the shape can be dynamic and will be replaced by this value.
101 |     msa_length: The number of sequences in the multiple sequence alignment; some
102 | elements of the shape can be dynamic and will be replaced by this value.
103 | If the number of alignments is unknown / not read, please pass None for
104 | msa_length.
105 | num_templates (optional): The number of templates in this tfexample.
106 | features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES.
107 |
108 | Returns:
109 |     List of ints representing the tensor size.
110 |
111 | Raises:
112 | ValueError: If a feature is requested but no concrete placeholder value is
113 | given.
114 | """
115 | features = features or FEATURES
116 | if feature_name.endswith("_unnormalized"):
117 | feature_name = feature_name[:-13]
118 |
119 | unused_dtype, raw_sizes = features[feature_name]
120 | replacements = {NUM_RES: num_residues,
121 | NUM_SEQ: msa_length}
122 |
123 | if num_templates is not None:
124 | replacements[NUM_TEMPLATES] = num_templates
125 |
126 | sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]
127 | for dimension in sizes:
128 | if isinstance(dimension, str):
129 | raise ValueError("Could not parse %s (shape: %s) with values: %s" % (
130 | feature_name, raw_sizes, replacements))
131 | return sizes
132 |
133 |
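As a usage sketch (assuming the alphafold package is importable): for the 'msa' entry declared above as (tf.int64, [NUM_SEQ, NUM_RES, 1]), shape() simply substitutes the concrete values for the placeholders:

from alphafold.model.tf import protein_features

# NUM_SEQ -> msa_length, NUM_RES -> num_residues; fixed dims pass through.
sizes = protein_features.shape(
    'msa', num_residues=12, msa_length=24, num_templates=3)
assert sizes == [24, 12, 1]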
--------------------------------------------------------------------------------
/alphafold/model/tf/protein_features_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for protein_features."""
16 | import uuid
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | import tensorflow.compat.v1 as tf
21 |
22 | from alphafold.model.tf import protein_features
23 |
24 |
25 | def _random_bytes():
26 | return str(uuid.uuid4()).encode('utf-8')
27 |
28 |
29 | class FeaturesTest(parameterized.TestCase, tf.test.TestCase):
30 |
31 | def testFeatureNames(self):
32 | self.assertEqual(len(protein_features.FEATURE_SIZES),
33 | len(protein_features.FEATURE_TYPES))
34 | sorted_size_names = sorted(protein_features.FEATURE_SIZES.keys())
35 | sorted_type_names = sorted(protein_features.FEATURE_TYPES.keys())
36 | for i, size_name in enumerate(sorted_size_names):
37 | self.assertEqual(size_name, sorted_type_names[i])
38 |
39 | def testReplacement(self):
40 | for name in protein_features.FEATURE_SIZES.keys():
41 | sizes = protein_features.shape(name,
42 | num_residues=12,
43 | msa_length=24,
44 | num_templates=3)
45 | for x in sizes:
46 | self.assertEqual(type(x), int)
47 | self.assertGreater(x, 0)
48 |
49 |
50 | if __name__ == '__main__':
51 | tf.disable_v2_behavior()
52 | absltest.main()
53 |
--------------------------------------------------------------------------------
/alphafold/model/tf/proteins_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Datasets consisting of proteins."""
16 | from typing import Dict, Mapping, Optional, Sequence
17 |
18 | import numpy as np
19 | import tensorflow.compat.v1 as tf
20 |
21 | from alphafold.model.tf import protein_features
22 |
23 | TensorDict = Dict[str, tf.Tensor]
24 |
25 |
26 | def parse_tfexample(
27 | raw_data: bytes,
28 | features: protein_features.FeaturesMetadata,
29 |     key: Optional[str] = None) -> TensorDict:
30 | """Read a single TF Example proto and return a subset of its features.
31 |
32 | Args:
33 | raw_data: A serialized tf.Example proto.
34 | features: A dictionary of features, mapping string feature names to a tuple
35 | (dtype, shape). This dictionary should be a subset of
36 | protein_features.FEATURES (or the dictionary itself for all features).
37 | key: Optional string with the SSTable key of that tf.Example. This will be
38 | added into features as a 'key' but only if requested in features.
39 |
40 | Returns:
41 | A dictionary of features mapping feature names to features. Only the given
42 | features are returned, all other ones are filtered out.
43 | """
44 | feature_map = {
45 | k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True)
46 | for k, v in features.items()
47 | }
48 | parsed_features = tf.io.parse_single_example(raw_data, feature_map)
49 | reshaped_features = parse_reshape_logic(parsed_features, features, key=key)
50 |
51 | return reshaped_features
52 |
53 |
54 | def _first(tensor: tf.Tensor) -> tf.Tensor:
55 | """Returns the 1st element - the input can be a tensor or a scalar."""
56 | return tf.reshape(tensor, shape=(-1,))[0]
57 |
58 |
59 | def parse_reshape_logic(
60 | parsed_features: TensorDict,
61 | features: protein_features.FeaturesMetadata,
62 | key: Optional[str] = None) -> TensorDict:
63 | """Transforms parsed serial features to the correct shape."""
64 |   # Find out the number of residues, alignments and templates.
65 | num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32)
66 |
67 | if "num_alignments" in parsed_features:
68 | num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32)
69 | else:
70 | num_msa = 0
71 |
72 | if "template_domain_names" in parsed_features:
73 | num_templates = tf.cast(
74 | tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32)
75 | else:
76 | num_templates = 0
77 |
78 | if key is not None and "key" in features:
79 | parsed_features["key"] = [key] # Expand dims from () to (1,).
80 |
81 | # Reshape the tensors according to the sequence length and num alignments.
82 | for k, v in parsed_features.items():
83 | new_shape = protein_features.shape(
84 | feature_name=k,
85 | num_residues=num_residues,
86 | msa_length=num_msa,
87 | num_templates=num_templates,
88 | features=features)
89 | new_shape_size = tf.constant(1, dtype=tf.int32)
90 | for dim in new_shape:
91 | new_shape_size *= tf.cast(dim, tf.int32)
92 |
93 | assert_equal = tf.assert_equal(
94 | tf.size(v), new_shape_size,
95 | name="assert_%s_shape_correct" % k,
96 | message="The size of feature %s (%s) could not be reshaped "
97 | "into %s" % (k, tf.size(v), new_shape))
98 | if "template" not in k:
99 | # Make sure the feature we are reshaping is not empty.
100 | assert_non_empty = tf.assert_greater(
101 | tf.size(v), 0, name="assert_%s_non_empty" % k,
102 | message="The feature %s is not set in the tf.Example. Either do not "
103 | "request the feature or use a tf.Example that has the "
104 | "feature set." % k)
105 | with tf.control_dependencies([assert_non_empty, assert_equal]):
106 | parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)
107 | else:
108 | with tf.control_dependencies([assert_equal]):
109 | parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)
110 |
111 | return parsed_features
112 |
113 |
114 | def _make_features_metadata(
115 | feature_names: Sequence[str]) -> protein_features.FeaturesMetadata:
116 | """Makes a feature name to type and shape mapping from a list of names."""
117 | # Make sure these features are always read.
118 | required_features = ["aatype", "sequence", "seq_length"]
119 | feature_names = list(set(feature_names) | set(required_features))
120 |
121 | features_metadata = {name: protein_features.FEATURES[name]
122 | for name in feature_names}
123 | return features_metadata
124 |
125 |
126 | def create_tensor_dict(
127 | raw_data: bytes,
128 | features: Sequence[str],
129 | key: Optional[str] = None,
130 | ) -> TensorDict:
131 | """Creates a dictionary of tensor features.
132 |
133 | Args:
134 | raw_data: A serialized tf.Example proto.
135 | features: A list of strings of feature names to be returned in the dataset.
136 | key: Optional string with the SSTable key of that tf.Example. This will be
137 | added into features as a 'key' but only if requested in features.
138 |
139 | Returns:
140 | A dictionary of features mapping feature names to features. Only the given
141 | features are returned, all other ones are filtered out.
142 | """
143 | features_metadata = _make_features_metadata(features)
144 | return parse_tfexample(raw_data, features_metadata, key)
145 |
146 |
147 | def np_to_tensor_dict(
148 | np_example: Mapping[str, np.ndarray],
149 | features: Sequence[str],
150 | ) -> TensorDict:
151 | """Creates dict of tensors from a dict of NumPy arrays.
152 |
153 | Args:
154 | np_example: A dict of NumPy feature arrays.
155 | features: A list of strings of feature names to be returned in the dataset.
156 |
157 | Returns:
158 | A dictionary of features mapping feature names to features. Only the given
159 | features are returned, all other ones are filtered out.
160 | """
161 | features_metadata = _make_features_metadata(features)
162 | tensor_dict = {k: tf.constant(v) for k, v in np_example.items()
163 | if k in features_metadata}
164 |
165 | # Ensures shapes are as expected. Needed for setting size of empty features
166 | # e.g. when no template hits were found.
167 | tensor_dict = parse_reshape_logic(tensor_dict, features_metadata)
168 | return tensor_dict
169 |
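A hypothetical minimal usage sketch of np_to_tensor_dict (the array values are illustrative; the shapes follow the FEATURES table for a 4-residue, 2-sequence alignment):

import numpy as np

from alphafold.model.tf import proteins_dataset

np_example = {
    # Required features are always read, even if not requested.
    'aatype': np.zeros((4, 21), dtype=np.float32),
    'sequence': np.array([b'AAAA'], dtype=np.object_),
    'seq_length': np.full((4, 1), 4, dtype=np.int64),
    # Requested features.
    'msa': np.zeros((2, 4, 1), dtype=np.int64),
    'num_alignments': np.full((4, 1), 2, dtype=np.int64),
}
tensors = proteins_dataset.np_to_tensor_dict(
    np_example, features=['msa', 'num_alignments'])
# tensors['msa'] now has the declared shape [NUM_SEQ, NUM_RES, 1] = [2, 4, 1].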
--------------------------------------------------------------------------------
/alphafold/model/tf/shape_helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for dealing with shapes of TensorFlow tensors."""
16 | import tensorflow.compat.v1 as tf
17 |
18 |
19 | def shape_list(x):
20 | """Return list of dimensions of a tensor, statically where possible.
21 |
22 | Like `x.shape.as_list()` but with tensors instead of `None`s.
23 |
24 | Args:
25 | x: A tensor.
26 | Returns:
27 | A list with length equal to the rank of the tensor. The n-th element of the
28 |     list is an integer when that dimension is statically known; otherwise it is
29 | the n-th element of `tf.shape(x)`.
30 | """
31 | x = tf.convert_to_tensor(x)
32 |
33 | # If unknown rank, return dynamic shape
34 | if x.get_shape().dims is None:
35 | return tf.shape(x)
36 |
37 | static = x.get_shape().as_list()
38 | shape = tf.shape(x)
39 |
40 | ret = []
41 | for i in range(len(static)):
42 | dim = static[i]
43 | if dim is None:
44 | dim = shape[i]
45 | ret.append(dim)
46 | return ret
47 |
48 |
--------------------------------------------------------------------------------
/alphafold/model/tf/shape_helpers_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for shape_helpers."""
16 |
17 | import numpy as np
18 | import tensorflow.compat.v1 as tf
19 |
20 | from alphafold.model.tf import shape_helpers
21 |
22 |
23 | class ShapeTest(tf.test.TestCase):
24 |
25 | def test_shape_list(self):
26 | """Test that shape_list can allow for reshaping to dynamic shapes."""
27 | a = tf.zeros([10, 4, 4, 2])
28 | p = tf.placeholder(tf.float32, shape=[None, None, 1, 4, 4])
29 | shape_dyn = shape_helpers.shape_list(p)[:2] + [4, 4]
30 |
31 | b = tf.reshape(a, shape_dyn)
32 | with self.session() as sess:
33 | out = sess.run(b, feed_dict={p: np.ones((20, 1, 1, 4, 4))})
34 |
35 | self.assertAllEqual(out.shape, (20, 1, 4, 4))
36 |
37 |
38 | if __name__ == '__main__':
39 | tf.disable_v2_behavior()
40 | tf.test.main()
41 |
--------------------------------------------------------------------------------
/alphafold/model/tf/shape_placeholders.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Placeholder values for run-time varying dimension sizes."""
16 |
17 | NUM_RES = 'num residues placeholder'
18 | NUM_MSA_SEQ = 'msa placeholder'
19 | NUM_EXTRA_SEQ = 'extra msa placeholder'
20 | NUM_TEMPLATES = 'num templates placeholder'
21 |
--------------------------------------------------------------------------------
/alphafold/model/tf/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Shared utilities for various components."""
16 | import tensorflow.compat.v1 as tf
17 |
18 |
19 | def tf_combine_mask(*masks):
20 | """Take the intersection of float-valued masks."""
21 | ret = 1
22 | for m in masks:
23 | ret *= m
24 | return ret
25 |
26 |
27 | class SeedMaker(object):
28 | """Return unique seeds."""
29 |
30 | def __init__(self, initial_seed=0):
31 | self.next_seed = initial_seed
32 |
33 | def __call__(self):
34 | i = self.next_seed
35 | self.next_seed += 1
36 | return i
37 |
38 | seed_maker = SeedMaker()
39 |
40 |
41 | def make_random_seed():
42 | return tf.random.uniform([2],
43 | tf.int32.min,
44 | tf.int32.max,
45 | tf.int32,
46 | seed=seed_maker())
47 |
48 |
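A short usage sketch of the two helpers above:

import tensorflow.compat.v1 as tf

from alphafold.model.tf import utils

# tf_combine_mask multiplies float masks elementwise (their intersection).
combined = utils.tf_combine_mask(tf.constant([1., 0., 1.]),
                                 tf.constant([1., 1., 0.]))  # -> [1., 0., 0.]

# seed_maker is a module-level counter: every call yields the next integer,
# so each op built through make_random_seed() gets a unique, repeatable seed.
first = utils.seed_maker()
second = utils.seed_maker()
assert second == first + 1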
--------------------------------------------------------------------------------
/alphafold/model/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A collection of JAX utility functions for use in protein folding."""
16 |
17 | import collections.abc
18 | import numbers
19 | from typing import Mapping
20 |
21 | import haiku as hk
22 | import jax
23 | import jax.numpy as jnp
24 | import numpy as np
25 |
26 |
27 | def final_init(config):
28 | if config.zero_init:
29 | return 'zeros'
30 | else:
31 | return 'linear'
32 |
33 |
34 | def batched_gather(params, indices, axis=0, batch_dims=0):
35 | """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`."""
36 | take_fn = lambda p, i: jnp.take(p, i, axis=axis)
37 | for _ in range(batch_dims):
38 | take_fn = jax.vmap(take_fn)
39 | return take_fn(params, indices)
40 |
41 |
42 | def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
43 | """Masked mean."""
44 | if drop_mask_channel:
45 | mask = mask[..., 0]
46 |
47 | mask_shape = mask.shape
48 | value_shape = value.shape
49 |
50 | assert len(mask_shape) == len(value_shape)
51 |
52 | if isinstance(axis, numbers.Integral):
53 | axis = [axis]
54 | elif axis is None:
55 | axis = list(range(len(mask_shape)))
56 |   assert isinstance(axis, collections.abc.Iterable), (
57 |       'axis needs to be either an iterable, integer or "None"')
58 |
59 | broadcast_factor = 1.
60 | for axis_ in axis:
61 | value_size = value_shape[axis_]
62 | mask_size = mask_shape[axis_]
63 | if mask_size == 1:
64 | broadcast_factor *= value_size
65 | else:
66 | assert mask_size == value_size
67 |
68 | return (jnp.sum(mask * value, axis=axis) /
69 | (jnp.sum(mask, axis=axis) * broadcast_factor + eps))
70 |
71 |
72 | def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params:
73 | """Convert a dictionary of NumPy arrays to Haiku parameters."""
74 | hk_params = {}
75 | for path, array in params.items():
76 | scope, name = path.split('//')
77 | if scope not in hk_params:
78 | hk_params[scope] = {}
79 | hk_params[scope][name] = jnp.array(array)
80 |
81 | return hk_params
82 |
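One subtlety in mask_mean is worth a concrete example: a mask dimension of size 1 broadcasts against the value, and broadcast_factor scales the denominator so the mean stays correct. A small sketch:

import jax.numpy as jnp

from alphafold.model import utils

value = jnp.array([[1., 2.],
                   [3., 5.]])
# A [2, 1] mask broadcasts over the size-2 second axis; the denominator
# becomes sum(mask) * broadcast_factor = 2 * 2 = 4.
mask = jnp.array([[1.], [1.]])
mean = utils.mask_mean(mask, value)
assert jnp.allclose(mean, 11. / 4., atol=1e-6)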
--------------------------------------------------------------------------------
/alphafold/relax/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Amber relaxation."""
15 |
--------------------------------------------------------------------------------
/alphafold/relax/amber_minimize_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for amber_minimize."""
16 | import os
17 |
18 | from absl.testing import absltest
19 | import numpy as np
20 |
21 | from alphafold.common import protein
22 | from alphafold.relax import amber_minimize
23 | # Internal import (7716).
24 |
25 |
26 | def _load_test_protein(data_path):
27 | pdb_path = os.path.join(absltest.get_default_test_srcdir(), data_path)
28 | with open(pdb_path, 'r') as f:
29 | return protein.from_pdb_string(f.read())
30 |
31 |
32 | class AmberMinimizeTest(absltest.TestCase):
33 |
34 | def test_multiple_disulfides_target(self):
35 | prot = _load_test_protein(
36 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb'
37 | )
38 | ret = amber_minimize.run_pipeline(prot, max_iterations=10, max_attempts=1,
39 | stiffness=10.)
40 | self.assertIn('opt_time', ret)
41 | self.assertIn('min_attempts', ret)
42 |
43 | def test_raises_invalid_protein_assertion(self):
44 | prot = _load_test_protein(
45 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb'
46 | )
47 | prot.atom_mask[4, :] = 0
48 | with self.assertRaisesRegex(
49 | ValueError,
50 | 'Amber minimization can only be performed on proteins with well-defined'
51 | ' residues. This protein contains at least one residue with no atoms.'):
52 | amber_minimize.run_pipeline(prot, max_iterations=10,
53 | stiffness=1.,
54 | max_attempts=1)
55 |
56 | def test_iterative_relax(self):
57 | prot = _load_test_protein(
58 | 'alphafold/relax/testdata/with_violations.pdb'
59 | )
60 | violations = amber_minimize.get_violation_metrics(prot)
61 | self.assertGreater(violations['num_residue_violations'], 0)
62 | out = amber_minimize.run_pipeline(
63 | prot=prot, max_outer_iterations=10, stiffness=10.)
64 | self.assertLess(out['efinal'], out['einit'])
65 | self.assertEqual(0, out['num_residue_violations'])
66 |
67 | def test_find_violations(self):
68 | prot = _load_test_protein(
69 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb'
70 | )
71 | viols, _ = amber_minimize.find_violations(prot)
72 |
73 | expected_between_residues_connection_mask = np.zeros((191,), np.float32)
74 | for residue in (42, 43, 59, 60, 135, 136):
75 | expected_between_residues_connection_mask[residue] = 1.0
76 |
77 | expected_clash_indices = np.array([
78 | [8, 4],
79 | [8, 5],
80 | [13, 3],
81 | [14, 1],
82 | [14, 4],
83 | [26, 4],
84 | [26, 5],
85 | [31, 8],
86 | [31, 10],
87 | [39, 0],
88 | [39, 1],
89 | [39, 2],
90 | [39, 3],
91 | [39, 4],
92 | [42, 5],
93 | [42, 6],
94 | [42, 7],
95 | [42, 8],
96 | [47, 7],
97 | [47, 8],
98 | [47, 9],
99 | [47, 10],
100 | [64, 4],
101 | [85, 5],
102 | [102, 4],
103 | [102, 5],
104 | [109, 13],
105 | [111, 5],
106 | [118, 6],
107 | [118, 7],
108 | [118, 8],
109 | [124, 4],
110 | [124, 5],
111 | [131, 5],
112 | [139, 7],
113 | [147, 4],
114 | [152, 7]], dtype=np.int32)
115 | expected_between_residues_clash_mask = np.zeros([191, 14])
116 | expected_between_residues_clash_mask[expected_clash_indices[:, 0],
117 | expected_clash_indices[:, 1]] += 1
118 | expected_per_atom_violations = np.zeros([191, 14])
119 | np.testing.assert_array_equal(
120 | viols['between_residues']['connections_per_residue_violation_mask'],
121 | expected_between_residues_connection_mask)
122 | np.testing.assert_array_equal(
123 | viols['between_residues']['clashes_per_atom_clash_mask'],
124 | expected_between_residues_clash_mask)
125 | np.testing.assert_array_equal(
126 | viols['within_residues']['per_atom_violations'],
127 | expected_per_atom_violations)
128 |
129 |
130 | if __name__ == '__main__':
131 | absltest.main()
132 |
--------------------------------------------------------------------------------
/alphafold/relax/cleanup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
16 |
17 | fix_pdb uses a third-party tool. We also support fixing some additional edge
18 | cases like removing chains of length one (see clean_structure).
19 | """
20 | import io
21 |
22 | import pdbfixer
23 | from simtk.openmm import app
24 | from simtk.openmm.app import element
25 |
26 |
27 | def fix_pdb(pdbfile, alterations_info):
28 | """Apply pdbfixer to the contents of a PDB file; return a PDB string result.
29 |
30 | 1) Replaces nonstandard residues.
31 | 2) Removes heterogens (non protein residues) including water.
32 | 3) Adds missing residues and missing atoms within existing residues.
33 | 4) Adds hydrogens assuming pH=7.0.
34 | 5) KeepIds is currently true, so the fixer must keep the existing chain and
35 |      residue identifiers. This will fail for some files in the wider PDB
36 |      that have invalid IDs.
37 |
38 | Args:
39 | pdbfile: Input PDB file handle.
40 | alterations_info: A dict that will store details of changes made.
41 |
42 | Returns:
43 | A PDB string representing the fixed structure.
44 | """
45 | fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
46 | fixer.findNonstandardResidues()
47 | alterations_info['nonstandard_residues'] = fixer.nonstandardResidues
48 | fixer.replaceNonstandardResidues()
49 | _remove_heterogens(fixer, alterations_info, keep_water=False)
50 | fixer.findMissingResidues()
51 | alterations_info['missing_residues'] = fixer.missingResidues
52 | fixer.findMissingAtoms()
53 | alterations_info['missing_heavy_atoms'] = fixer.missingAtoms
54 | alterations_info['missing_terminals'] = fixer.missingTerminals
55 | fixer.addMissingAtoms(seed=0)
56 | fixer.addMissingHydrogens()
57 | out_handle = io.StringIO()
58 | app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle,
59 | keepIds=True)
60 | return out_handle.getvalue()
61 |
62 |
63 | def clean_structure(pdb_structure, alterations_info):
64 | """Applies additional fixes to an OpenMM structure, to handle edge cases.
65 |
66 | Args:
67 | pdb_structure: An OpenMM structure to modify and fix.
68 | alterations_info: A dict that will store details of changes made.
69 | """
70 | _replace_met_se(pdb_structure, alterations_info)
71 | _remove_chains_of_length_one(pdb_structure, alterations_info)
72 |
73 |
74 | def _remove_heterogens(fixer, alterations_info, keep_water):
75 | """Removes the residues that Pdbfixer considers to be heterogens.
76 |
77 | Args:
78 | fixer: A Pdbfixer instance.
79 | alterations_info: A dict that will store details of changes made.
80 | keep_water: If True, water (HOH) is not considered to be a heterogen.
81 | """
82 | initial_resnames = set()
83 | for chain in fixer.topology.chains():
84 | for residue in chain.residues():
85 | initial_resnames.add(residue.name)
86 | fixer.removeHeterogens(keepWater=keep_water)
87 | final_resnames = set()
88 | for chain in fixer.topology.chains():
89 | for residue in chain.residues():
90 | final_resnames.add(residue.name)
91 | alterations_info['removed_heterogens'] = (
92 | initial_resnames.difference(final_resnames))
93 |
94 |
95 | def _replace_met_se(pdb_structure, alterations_info):
96 | """Replace the Se in any MET residues that were not marked as modified."""
97 | modified_met_residues = []
98 | for res in pdb_structure.iter_residues():
99 | name = res.get_name_with_spaces().strip()
100 | if name == 'MET':
101 | s_atom = res.get_atom('SD')
102 | if s_atom.element_symbol == 'Se':
103 | s_atom.element_symbol = 'S'
104 | s_atom.element = element.get_by_symbol('S')
105 | modified_met_residues.append(s_atom.residue_number)
106 | alterations_info['Se_in_MET'] = modified_met_residues
107 |
108 |
109 | def _remove_chains_of_length_one(pdb_structure, alterations_info):
110 | """Removes chains that correspond to a single amino acid.
111 |
112 | A single amino acid in a chain is both N and C terminus. There is no force
113 | template for this case.
114 |
115 | Args:
116 | pdb_structure: An OpenMM pdb_structure to modify and fix.
117 | alterations_info: A dict that will store details of changes made.
118 | """
119 | removed_chains = {}
120 | for model in pdb_structure.iter_models():
121 | valid_chains = [c for c in model.iter_chains() if len(c) > 1]
122 | invalid_chain_ids = [c.chain_id for c in model.iter_chains() if len(c) <= 1]
123 | model.chains = valid_chains
124 | for chain_id in invalid_chain_ids:
125 | model.chains_by_id.pop(chain_id)
126 | removed_chains[model.number] = invalid_chain_ids
127 | alterations_info['removed_chains'] = removed_chains
128 |
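A usage sketch of the full cleanup flow ('input.pdb' is a hypothetical path; fix_pdb works on text handles, while clean_structure expects an OpenMM PdbStructure, as in cleanup_test.py below):

import io

from simtk.openmm.app.internal import pdbstructure

from alphafold.relax import cleanup

with open('input.pdb') as f:  # hypothetical input file
  pdb_string = f.read()

alterations = {}
fixed_pdb = cleanup.fix_pdb(io.StringIO(pdb_string), alterations)
structure = pdbstructure.PdbStructure(io.StringIO(fixed_pdb))
cleanup.clean_structure(structure, alterations)
# alterations now records nonstandard residues, removed heterogens, missing
# atoms/terminals, Se->S substitutions and removed single-residue chains.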
--------------------------------------------------------------------------------
/alphafold/relax/cleanup_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for relax.cleanup."""
16 | import io
17 |
18 | from absl.testing import absltest
19 | from simtk.openmm.app.internal import pdbstructure
20 |
21 | from alphafold.relax import cleanup
22 |
23 |
24 | def _pdb_to_structure(pdb_str):
25 | handle = io.StringIO(pdb_str)
26 | return pdbstructure.PdbStructure(handle)
27 |
28 |
29 | def _lines_to_structure(pdb_lines):
30 | return _pdb_to_structure('\n'.join(pdb_lines))
31 |
32 |
33 | class CleanupTest(absltest.TestCase):
34 |
35 | def test_missing_residues(self):
36 | pdb_lines = ['SEQRES 1 C 3 CYS GLY LEU',
37 | 'ATOM 1 N CYS C 1 -12.262 20.115 60.959 1.00 '
38 | '19.08 N',
39 | 'ATOM 2 CA CYS C 1 -11.065 20.934 60.773 1.00 '
40 | '17.23 C',
41 | 'ATOM 3 C CYS C 1 -10.002 20.742 61.844 1.00 '
42 | '15.38 C',
43 | 'ATOM 4 O CYS C 1 -10.284 20.225 62.929 1.00 '
44 | '16.04 O',
45 | 'ATOM 5 N LEU C 3 -7.688 18.700 62.045 1.00 '
46 | '14.75 N',
47 | 'ATOM 6 CA LEU C 3 -7.256 17.320 62.234 1.00 '
48 | '16.81 C',
49 | 'ATOM 7 C LEU C 3 -6.380 16.864 61.070 1.00 '
50 | '16.95 C',
51 | 'ATOM 8 O LEU C 3 -6.551 17.332 59.947 1.00 '
52 | '16.97 O']
53 | input_handle = io.StringIO('\n'.join(pdb_lines))
54 | alterations = {}
55 | result = cleanup.fix_pdb(input_handle, alterations)
56 | structure = _pdb_to_structure(result)
57 | residue_names = [r.get_name() for r in structure.iter_residues()]
58 | self.assertCountEqual(residue_names, ['CYS', 'GLY', 'LEU'])
59 | self.assertCountEqual(alterations['missing_residues'].values(), [['GLY']])
60 |
61 | def test_missing_atoms(self):
62 | pdb_lines = ['SEQRES 1 A 1 PRO',
63 | 'ATOM 1 CA PRO A 1 1.000 1.000 1.000 1.00 '
64 | ' 0.00 C']
65 | input_handle = io.StringIO('\n'.join(pdb_lines))
66 | alterations = {}
67 | result = cleanup.fix_pdb(input_handle, alterations)
68 | structure = _pdb_to_structure(result)
69 | atom_names = [a.get_name() for a in structure.iter_atoms()]
70 | self.assertCountEqual(atom_names, ['N', 'CD', 'HD2', 'HD3', 'CG', 'HG2',
71 | 'HG3', 'CB', 'HB2', 'HB3', 'CA', 'HA',
72 | 'C', 'O', 'H2', 'H3', 'OXT'])
73 | missing_atoms_by_residue = list(alterations['missing_heavy_atoms'].values())
74 | self.assertLen(missing_atoms_by_residue, 1)
75 | atoms_added = [a.name for a in missing_atoms_by_residue[0]]
76 | self.assertCountEqual(atoms_added, ['N', 'CD', 'CG', 'CB', 'C', 'O'])
77 | missing_terminals_by_residue = alterations['missing_terminals']
78 | self.assertLen(missing_terminals_by_residue, 1)
79 | has_missing_terminal = [r.name for r in missing_terminals_by_residue.keys()]
80 | self.assertCountEqual(has_missing_terminal, ['PRO'])
81 | self.assertCountEqual([t for t in missing_terminals_by_residue.values()],
82 | [['OXT']])
83 |
84 | def test_remove_heterogens(self):
85 | pdb_lines = ['SEQRES 1 A 1 GLY',
86 | 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 '
87 | ' 0.00 C',
88 | 'ATOM 2 O HOH A 2 0.000 0.000 0.000 1.00 '
89 | ' 0.00 O']
90 | input_handle = io.StringIO('\n'.join(pdb_lines))
91 | alterations = {}
92 | result = cleanup.fix_pdb(input_handle, alterations)
93 | structure = _pdb_to_structure(result)
94 | self.assertCountEqual([res.get_name() for res in structure.iter_residues()],
95 | ['GLY'])
96 | self.assertEqual(alterations['removed_heterogens'], set(['HOH']))
97 |
98 | def test_fix_nonstandard_residues(self):
99 | pdb_lines = ['SEQRES 1 A 1 DAL',
100 | 'ATOM 1 CA DAL A 1 0.000 0.000 0.000 1.00 '
101 | ' 0.00 C']
102 | input_handle = io.StringIO('\n'.join(pdb_lines))
103 | alterations = {}
104 | result = cleanup.fix_pdb(input_handle, alterations)
105 | structure = _pdb_to_structure(result)
106 | residue_names = [res.get_name() for res in structure.iter_residues()]
107 | self.assertCountEqual(residue_names, ['ALA'])
108 | self.assertLen(alterations['nonstandard_residues'], 1)
109 | original_res, new_name = alterations['nonstandard_residues'][0]
110 | self.assertEqual(original_res.id, '1')
111 | self.assertEqual(new_name, 'ALA')
112 |
113 | def test_replace_met_se(self):
114 | pdb_lines = ['SEQRES 1 A 1 MET',
115 | 'ATOM 1 SD MET A 1 0.000 0.000 0.000 1.00 '
116 | ' 0.00 Se']
117 | structure = _lines_to_structure(pdb_lines)
118 | alterations = {}
119 | cleanup._replace_met_se(structure, alterations)
120 | sd = [a for a in structure.iter_atoms() if a.get_name() == 'SD']
121 | self.assertLen(sd, 1)
122 | self.assertEqual(sd[0].element_symbol, 'S')
123 | self.assertCountEqual(alterations['Se_in_MET'], [sd[0].residue_number])
124 |
125 | def test_remove_chains_of_length_one(self):
126 | pdb_lines = ['SEQRES 1 A 1 GLY',
127 | 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 '
128 | ' 0.00 C']
129 | structure = _lines_to_structure(pdb_lines)
130 | alterations = {}
131 | cleanup._remove_chains_of_length_one(structure, alterations)
132 | chains = list(structure.iter_chains())
133 | self.assertEmpty(chains)
134 | self.assertCountEqual(alterations['removed_chains'].values(), [['A']])
135 |
136 |
137 | if __name__ == '__main__':
138 | absltest.main()
139 |
--------------------------------------------------------------------------------
/alphafold/relax/relax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Amber relaxation."""
16 | from typing import Any, Dict, Sequence, Tuple
17 |
18 | import numpy as np
19 |
20 | from alphafold.common import protein
21 | from alphafold.relax import amber_minimize
22 | from alphafold.relax import utils
23 |
24 |
25 | class AmberRelaxation(object):
26 | """Amber relaxation."""
27 |
28 | def __init__(self,
29 | *,
30 | max_iterations: int,
31 | tolerance: float,
32 | stiffness: float,
33 | exclude_residues: Sequence[int],
34 | max_outer_iterations: int):
35 | """Initialize Amber Relaxer.
36 |
37 | Args:
38 | max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
39 | tolerance: kcal/mol, the energy tolerance of L-BFGS.
40 | stiffness: kcal/mol A**2, spring constant of heavy atom restraining
41 | potential.
42 | exclude_residues: Residues to exclude from per-atom restraining.
43 | Zero-indexed.
44 | max_outer_iterations: Maximum number of violation-informed relax
45 | iterations. A value of 1 will run the non-iterative procedure used in
46 | CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
47 | as soon as there are no violations, hence in most cases this causes no
48 | slowdown. In the worst case we do 20 outer iterations.
49 | """
50 |
51 | self._max_iterations = max_iterations
52 | self._tolerance = tolerance
53 | self._stiffness = stiffness
54 | self._exclude_residues = exclude_residues
55 | self._max_outer_iterations = max_outer_iterations
56 |
57 | def process(self, *,
58 | prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]:
59 | """Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
60 | out = amber_minimize.run_pipeline(
61 | prot=prot, max_iterations=self._max_iterations,
62 | tolerance=self._tolerance, stiffness=self._stiffness,
63 | exclude_residues=self._exclude_residues,
64 | max_outer_iterations=self._max_outer_iterations)
65 | min_pos = out['pos']
66 | start_pos = out['posinit']
67 | rmsd = np.sqrt(np.sum((start_pos - min_pos)**2) / start_pos.shape[0])
68 | debug_data = {
69 | 'initial_energy': out['einit'],
70 | 'final_energy': out['efinal'],
71 | 'attempts': out['min_attempts'],
72 | 'rmsd': rmsd
73 | }
74 | pdb_str = amber_minimize.clean_protein(prot)
75 | min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
76 | min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
77 | utils.assert_equal_nonterminal_atom_types(
78 | protein.from_pdb_string(min_pdb).atom_mask,
79 | prot.atom_mask)
80 | violations = out['structural_violations'][
81 | 'total_per_residue_violations_mask']
82 | return min_pdb, debug_data, violations
83 |
84 |
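A usage sketch of AmberRelaxation ('prediction.pdb' is a hypothetical model output; the parameter values follow the guidance in the docstring above):

from alphafold.common import protein
from alphafold.relax import relax

amber_relax = relax.AmberRelaxation(
    max_iterations=0,         # 0 means no cap on L-BFGS steps
    tolerance=2.39,           # kcal/mol
    stiffness=10.0,           # kcal/mol A**2 restraint on heavy atoms
    exclude_residues=[],
    max_outer_iterations=20)  # iterate until no violations remain

with open('prediction.pdb') as f:  # hypothetical input file
  prot = protein.from_pdb_string(f.read())
min_pdb, debug_data, violations = amber_relax.process(prot=prot)
# min_pdb is the relaxed PDB string; debug_data holds energies, attempt count
# and RMSD; violations is a per-residue violation mask.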
--------------------------------------------------------------------------------
/alphafold/relax/relax_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for relax."""
16 | import os
17 |
18 | from absl.testing import absltest
19 | import numpy as np
20 | from alphafold.common import protein
21 | from alphafold.relax import relax
22 | # Internal import (7716).
23 |
24 |
25 | class RunAmberRelaxTest(absltest.TestCase):
26 |
27 | def setUp(self):
28 | super().setUp()
29 | self.test_dir = os.path.join(
30 | absltest.get_default_test_srcdir(),
31 | 'alphafold/relax/testdata/')
32 | self.test_config = {
33 | 'max_iterations': 1,
34 | 'tolerance': 2.39,
35 | 'stiffness': 10.0,
36 | 'exclude_residues': [],
37 | 'max_outer_iterations': 1}
38 |
39 | def test_process(self):
40 | amber_relax = relax.AmberRelaxation(**self.test_config)
41 |
42 | with open(os.path.join(self.test_dir, 'model_output.pdb')) as f:
43 | test_prot = protein.from_pdb_string(f.read())
44 | pdb_min, debug_info, num_violations = amber_relax.process(prot=test_prot)
45 |
46 | self.assertCountEqual(debug_info.keys(),
47 | set({'initial_energy', 'final_energy',
48 | 'attempts', 'rmsd'}))
49 | self.assertLess(debug_info['final_energy'], debug_info['initial_energy'])
50 | self.assertGreater(debug_info['rmsd'], 0)
51 |
52 | prot_min = protein.from_pdb_string(pdb_min)
53 | # Most protein properties should be unchanged.
54 | np.testing.assert_almost_equal(test_prot.aatype, prot_min.aatype)
55 | np.testing.assert_almost_equal(test_prot.residue_index,
56 | prot_min.residue_index)
57 | # Atom mask and bfactors identical except for terminal OXT of last residue.
58 | np.testing.assert_almost_equal(test_prot.atom_mask[:-1, :],
59 | prot_min.atom_mask[:-1, :])
60 | np.testing.assert_almost_equal(test_prot.b_factors[:-1, :],
61 | prot_min.b_factors[:-1, :])
62 | np.testing.assert_almost_equal(test_prot.atom_mask[:, :-1],
63 | prot_min.atom_mask[:, :-1])
64 | np.testing.assert_almost_equal(test_prot.b_factors[:, :-1],
65 | prot_min.b_factors[:, :-1])
66 | # There are no residues with violations.
67 | np.testing.assert_equal(num_violations, np.zeros_like(num_violations))
68 |
69 | def test_unresolved_violations(self):
70 | amber_relax = relax.AmberRelaxation(**self.test_config)
71 | with open(os.path.join(self.test_dir,
72 | 'with_violations_casp14.pdb')) as f:
73 | test_prot = protein.from_pdb_string(f.read())
74 | _, _, num_violations = amber_relax.process(prot=test_prot)
75 | exp_num_violations = np.array(
76 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
77 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
78 | 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
79 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
80 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
81 | 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
82 | 0, 0, 0, 0])
83 | # Check no violations were added. Can't check exactly due to stochasticity.
84 | self.assertTrue(np.all(num_violations <= exp_num_violations))
85 |
86 |
87 | if __name__ == '__main__':
88 | absltest.main()
89 |
--------------------------------------------------------------------------------
/alphafold/relax/testdata/model_output.pdb:
--------------------------------------------------------------------------------
1 | ATOM 1 C MET A 1 1.921 -46.152 7.786 1.00 4.39 C
2 | ATOM 2 CA MET A 1 1.631 -46.829 9.131 1.00 4.39 C
3 | ATOM 3 CB MET A 1 2.759 -47.768 9.578 1.00 4.39 C
4 | ATOM 4 CE MET A 1 3.466 -49.770 13.198 1.00 4.39 C
5 | ATOM 5 CG MET A 1 2.581 -48.221 11.034 1.00 4.39 C
6 | ATOM 6 H MET A 1 0.234 -48.249 8.549 1.00 4.39 H
7 | ATOM 7 H2 MET A 1 -0.424 -46.789 8.952 1.00 4.39 H
8 | ATOM 8 H3 MET A 1 0.111 -47.796 10.118 1.00 4.39 H
9 | ATOM 9 HA MET A 1 1.628 -46.009 9.849 1.00 4.39 H
10 | ATOM 10 HB2 MET A 1 3.701 -47.225 9.500 1.00 4.39 H
11 | ATOM 11 HB3 MET A 1 2.807 -48.640 8.926 1.00 4.39 H
12 | ATOM 12 HE1 MET A 1 2.747 -50.537 12.910 1.00 4.39 H
13 | ATOM 13 HE2 MET A 1 4.296 -50.241 13.725 1.00 4.39 H
14 | ATOM 14 HE3 MET A 1 2.988 -49.052 13.864 1.00 4.39 H
15 | ATOM 15 HG2 MET A 1 1.791 -48.971 11.083 1.00 4.39 H
16 | ATOM 16 HG3 MET A 1 2.295 -47.368 11.650 1.00 4.39 H
17 | ATOM 17 N MET A 1 0.291 -47.464 9.182 1.00 4.39 N
18 | ATOM 18 O MET A 1 2.091 -44.945 7.799 1.00 4.39 O
19 | ATOM 19 SD MET A 1 4.096 -48.921 11.725 1.00 4.39 S
20 | ATOM 20 C LYS A 2 1.366 -45.033 4.898 1.00 2.92 C
21 | ATOM 21 CA LYS A 2 2.235 -46.242 5.308 1.00 2.92 C
22 | ATOM 22 CB LYS A 2 2.206 -47.314 4.196 1.00 2.92 C
23 | ATOM 23 CD LYS A 2 3.331 -49.342 3.134 1.00 2.92 C
24 | ATOM 24 CE LYS A 2 4.434 -50.403 3.293 1.00 2.92 C
25 | ATOM 25 CG LYS A 2 3.294 -48.395 4.349 1.00 2.92 C
26 | ATOM 26 H LYS A 2 1.832 -47.853 6.656 1.00 2.92 H
27 | ATOM 27 HA LYS A 2 3.248 -45.841 5.355 1.00 2.92 H
28 | ATOM 28 HB2 LYS A 2 1.223 -47.785 4.167 1.00 2.92 H
29 | ATOM 29 HB3 LYS A 2 2.363 -46.812 3.241 1.00 2.92 H
30 | ATOM 30 HD2 LYS A 2 3.524 -48.754 2.237 1.00 2.92 H
31 | ATOM 31 HD3 LYS A 2 2.364 -49.833 3.031 1.00 2.92 H
32 | ATOM 32 HE2 LYS A 2 5.383 -49.891 3.455 1.00 2.92 H
33 | ATOM 33 HE3 LYS A 2 4.225 -51.000 4.180 1.00 2.92 H
34 | ATOM 34 HG2 LYS A 2 3.102 -48.977 5.250 1.00 2.92 H
35 | ATOM 35 HG3 LYS A 2 4.264 -47.909 4.446 1.00 2.92 H
36 | ATOM 36 HZ1 LYS A 2 4.763 -50.747 1.274 1.00 2.92 H
37 | ATOM 37 HZ2 LYS A 2 3.681 -51.785 1.931 1.00 2.92 H
38 | ATOM 38 HZ3 LYS A 2 5.280 -51.965 2.224 1.00 2.92 H
39 | ATOM 39 N LYS A 2 1.907 -46.846 6.629 1.00 2.92 N
40 | ATOM 40 NZ LYS A 2 4.542 -51.286 2.100 1.00 2.92 N
41 | ATOM 41 O LYS A 2 1.882 -44.093 4.312 1.00 2.92 O
42 | ATOM 42 C PHE A 3 -0.511 -42.597 5.624 1.00 4.39 C
43 | ATOM 43 CA PHE A 3 -0.853 -43.933 4.929 1.00 4.39 C
44 | ATOM 44 CB PHE A 3 -2.271 -44.408 5.285 1.00 4.39 C
45 | ATOM 45 CD1 PHE A 3 -3.760 -43.542 3.432 1.00 4.39 C
46 | ATOM 46 CD2 PHE A 3 -4.050 -42.638 5.675 1.00 4.39 C
47 | ATOM 47 CE1 PHE A 3 -4.797 -42.715 2.965 1.00 4.39 C
48 | ATOM 48 CE2 PHE A 3 -5.091 -41.818 5.207 1.00 4.39 C
49 | ATOM 49 CG PHE A 3 -3.382 -43.505 4.788 1.00 4.39 C
50 | ATOM 50 CZ PHE A 3 -5.463 -41.853 3.853 1.00 4.39 C
51 | ATOM 51 H PHE A 3 -0.311 -45.868 5.655 1.00 4.39 H
52 | ATOM 52 HA PHE A 3 -0.817 -43.746 3.856 1.00 4.39 H
53 | ATOM 53 HB2 PHE A 3 -2.353 -44.512 6.367 1.00 4.39 H
54 | ATOM 54 HB3 PHE A 3 -2.432 -45.393 4.848 1.00 4.39 H
55 | ATOM 55 HD1 PHE A 3 -3.255 -44.198 2.739 1.00 4.39 H
56 | ATOM 56 HD2 PHE A 3 -3.768 -42.590 6.716 1.00 4.39 H
57 | ATOM 57 HE1 PHE A 3 -5.083 -42.735 1.923 1.00 4.39 H
58 | ATOM 58 HE2 PHE A 3 -5.604 -41.151 5.885 1.00 4.39 H
59 | ATOM 59 HZ PHE A 3 -6.257 -41.215 3.493 1.00 4.39 H
60 | ATOM 60 N PHE A 3 0.079 -45.027 5.253 1.00 4.39 N
61 | ATOM 61 O PHE A 3 -0.633 -41.541 5.014 1.00 4.39 O
62 | ATOM 62 C LEU A 4 1.598 -40.732 7.042 1.00 4.39 C
63 | ATOM 63 CA LEU A 4 0.367 -41.437 7.633 1.00 4.39 C
64 | ATOM 64 CB LEU A 4 0.628 -41.823 9.104 1.00 4.39 C
65 | ATOM 65 CD1 LEU A 4 -0.319 -42.778 11.228 1.00 4.39 C
66 | ATOM 66 CD2 LEU A 4 -1.300 -40.694 10.309 1.00 4.39 C
67 | ATOM 67 CG LEU A 4 -0.650 -42.027 9.937 1.00 4.39 C
68 | ATOM 68 H LEU A 4 0.163 -43.538 7.292 1.00 4.39 H
69 | ATOM 69 HA LEU A 4 -0.445 -40.712 7.588 1.00 4.39 H
70 | ATOM 70 HB2 LEU A 4 1.213 -41.034 9.576 1.00 4.39 H
71 | ATOM 71 HB3 LEU A 4 1.235 -42.728 9.127 1.00 4.39 H
72 | ATOM 72 HD11 LEU A 4 0.380 -42.191 11.824 1.00 4.39 H
73 | ATOM 73 HD12 LEU A 4 0.127 -43.747 11.002 1.00 4.39 H
74 | ATOM 74 HD13 LEU A 4 -1.230 -42.927 11.808 1.00 4.39 H
75 | ATOM 75 HD21 LEU A 4 -0.606 -40.080 10.883 1.00 4.39 H
76 | ATOM 76 HD22 LEU A 4 -2.193 -40.869 10.909 1.00 4.39 H
77 | ATOM 77 HD23 LEU A 4 -1.593 -40.147 9.413 1.00 4.39 H
78 | ATOM 78 HG LEU A 4 -1.359 -42.630 9.370 1.00 4.39 H
79 | ATOM 79 N LEU A 4 -0.012 -42.638 6.869 1.00 4.39 N
80 | ATOM 80 O LEU A 4 1.655 -39.508 7.028 1.00 4.39 O
81 | ATOM 81 C VAL A 5 3.372 -40.190 4.573 1.00 4.39 C
82 | ATOM 82 CA VAL A 5 3.752 -40.956 5.845 1.00 4.39 C
83 | ATOM 83 CB VAL A 5 4.757 -42.083 5.528 1.00 4.39 C
84 | ATOM 84 CG1 VAL A 5 6.019 -41.568 4.827 1.00 4.39 C
85 | ATOM 85 CG2 VAL A 5 5.199 -42.807 6.810 1.00 4.39 C
86 | ATOM 86 H VAL A 5 2.440 -42.503 6.548 1.00 4.39 H
87 | ATOM 87 HA VAL A 5 4.234 -40.242 6.512 1.00 4.39 H
88 | ATOM 88 HB VAL A 5 4.279 -42.813 4.875 1.00 4.39 H
89 | ATOM 89 HG11 VAL A 5 6.494 -40.795 5.431 1.00 4.39 H
90 | ATOM 90 HG12 VAL A 5 5.770 -41.145 3.853 1.00 4.39 H
91 | ATOM 91 HG13 VAL A 5 6.725 -42.383 4.670 1.00 4.39 H
92 | ATOM 92 HG21 VAL A 5 4.347 -43.283 7.297 1.00 4.39 H
93 | ATOM 93 HG22 VAL A 5 5.933 -43.575 6.568 1.00 4.39 H
94 | ATOM 94 HG23 VAL A 5 5.651 -42.093 7.498 1.00 4.39 H
95 | ATOM 95 N VAL A 5 2.554 -41.501 6.509 1.00 4.39 N
96 | ATOM 96 O VAL A 5 3.937 -39.138 4.297 1.00 4.39 O
97 | TER 96 VAL A 5
98 | END
99 |
--------------------------------------------------------------------------------
/alphafold/relax/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utils for minimization."""
16 | import io
17 |
18 | from Bio import PDB
19 | import numpy as np
20 | from simtk.openmm import app as openmm_app
21 | from simtk.openmm.app.internal.pdbstructure import PdbStructure
22 |
23 | from alphafold.common import residue_constants
24 |
25 |
26 | def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
27 | pdb_file = io.StringIO(pdb_str)
28 | structure = PdbStructure(pdb_file)
29 | topology = openmm_app.PDBFile(structure).getTopology()
30 | with io.StringIO() as f:
31 | openmm_app.PDBFile.writeFile(topology, pos, f)
32 | return f.getvalue()
33 |
34 |
35 | def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
36 | """Overwrites the B-factors in pdb_str with contents of bfactors array.
37 |
38 | Args:
39 | pdb_str: An input PDB string.
40 |     bfactors: A numpy array with shape [n_residues, 37]. We assume that the
41 |       B-factors are per residue; i.e. that the nonzero entries are identical in
42 |       [i, :].
43 |
44 | Returns:
45 | A new PDB string with the B-factors replaced.
46 | """
47 | if bfactors.shape[-1] != residue_constants.atom_type_num:
48 | raise ValueError(
49 | f'Invalid final dimension size for bfactors: {bfactors.shape[-1]}.')
50 |
51 | parser = PDB.PDBParser(QUIET=True)
52 | handle = io.StringIO(pdb_str)
53 | structure = parser.get_structure('', handle)
54 |
55 | curr_resid = ('', '', '')
56 | idx = -1
57 | for atom in structure.get_atoms():
58 | atom_resid = atom.parent.get_id()
59 | if atom_resid != curr_resid:
60 | idx += 1
61 | if idx >= bfactors.shape[0]:
62 |         raise ValueError('Index into bfactors exceeds number of residues. '
63 |                          f'B-factors shape: {bfactors.shape}, idx: {idx}.')
64 | curr_resid = atom_resid
65 | atom.bfactor = bfactors[idx, residue_constants.atom_order['CA']]
66 |
67 | new_pdb = io.StringIO()
68 | pdb_io = PDB.PDBIO()
69 | pdb_io.set_structure(structure)
70 | pdb_io.save(new_pdb)
71 | return new_pdb.getvalue()
72 |
73 |
74 | def assert_equal_nonterminal_atom_types(
75 | atom_mask: np.ndarray, ref_atom_mask: np.ndarray):
76 | """Checks that pre- and post-minimized proteins have same atom set."""
77 | # Ignore any terminal OXT atoms which may have been added by minimization.
78 | oxt = residue_constants.atom_order['OXT']
79 |   no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
80 | no_oxt_mask[..., oxt] = False
81 | np.testing.assert_almost_equal(ref_atom_mask[no_oxt_mask],
82 | atom_mask[no_oxt_mask])
83 |
84 |
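85 | # Usage sketch (editor's note, not part of the original module): writing
86 | # per-residue confidence values into the B-factor column of a PDB string.
87 | # `pdb_str` and `plddt` (a float array of shape [n_residues]) are assumed
88 | # inputs from an earlier prediction step.
89 | #
90 | #   bfactors = np.repeat(plddt[:, None], residue_constants.atom_type_num,
91 | #                        axis=-1)
92 | #   pdb_with_plddt = overwrite_b_factors(pdb_str, bfactors)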
--------------------------------------------------------------------------------
/alphafold/relax/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for utils."""
16 |
17 | import os
18 |
19 | from absl.testing import absltest
20 | import numpy as np
21 |
22 | from alphafold.common import protein
23 | from alphafold.relax import utils
24 | # Internal import (7716).
25 |
26 |
27 | class UtilsTest(absltest.TestCase):
28 |
29 | def test_overwrite_b_factors(self):
30 |     pdb_path = os.path.join(
31 | absltest.get_default_test_srcdir(),
32 | 'alphafold/relax/testdata/'
33 | 'multiple_disulfides_target.pdb')
34 |     with open(pdb_path) as f:
35 | test_pdb = f.read()
36 | n_residues = 191
37 | bfactors = np.stack([np.arange(0, n_residues)] * 37, axis=-1)
38 |
39 | output_pdb = utils.overwrite_b_factors(test_pdb, bfactors)
40 |
41 | # Check that the atom lines are unchanged apart from the B-factors.
42 |     atom_lines_original = [l for l in test_pdb.split('\n') if l[:4] == 'ATOM']
43 |     atom_lines_new = [l for l in output_pdb.split('\n') if l[:4] == 'ATOM']
44 | for line_original, line_new in zip(atom_lines_original, atom_lines_new):
45 | self.assertEqual(line_original[:60].strip(), line_new[:60].strip())
46 | self.assertEqual(line_original[66:].strip(), line_new[66:].strip())
47 |
48 | # Check B-factors are correctly set for all atoms present.
49 | as_protein = protein.from_pdb_string(output_pdb)
50 | np.testing.assert_almost_equal(
51 | np.where(as_protein.atom_mask > 0, as_protein.b_factors, 0),
52 | np.where(as_protein.atom_mask > 0, bfactors, 0))
53 |
54 |
55 | if __name__ == '__main__':
56 | absltest.main()
57 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | ARG CUDA=11.0
16 | FROM nvidia/cuda:${CUDA}-base
17 | # FROM directive resets ARGS, so we specify again (the value is retained if
18 | # previously set).
19 | ARG CUDA
20 |
21 | # Use bash to support string substitution.
22 | SHELL ["/bin/bash", "-c"]
23 |
24 | RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
25 | build-essential \
26 | cmake \
27 | cuda-command-line-tools-${CUDA/./-} \
28 | git \
29 | hmmer \
30 | kalign \
31 | tzdata \
32 | wget \
33 | && rm -rf /var/lib/apt/lists/*
34 |
35 | # Compile HHsuite from source.
36 | RUN git clone --branch v3.3.0 https://github.com/soedinglab/hh-suite.git /tmp/hh-suite \
37 | && mkdir /tmp/hh-suite/build
38 | WORKDIR /tmp/hh-suite/build
39 | RUN cmake -DCMAKE_INSTALL_PREFIX=/opt/hhsuite .. \
40 | && make -j 4 && make install \
41 | && ln -s /opt/hhsuite/bin/* /usr/bin \
42 | && rm -rf /tmp/hh-suite
43 |
44 | # Install Miniconda package manager.
45 | RUN wget -q -P /tmp \
46 | https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
47 | && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \
48 | && rm /tmp/Miniconda3-latest-Linux-x86_64.sh
49 |
50 | # Install conda packages.
51 | ENV PATH="/opt/conda/bin:$PATH"
52 | RUN conda update -qy conda \
53 | && conda install -y -c conda-forge \
54 | openmm=7.5.1 \
55 | cudatoolkit==${CUDA}.3 \
56 | pdbfixer \
57 | pip \
58 | python=3.7
59 |
60 | COPY . /app/alphafold
61 | RUN wget -q -P /app/alphafold/alphafold/common/ \
62 | https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
63 |
64 | # Install pip packages.
65 | RUN pip3 install --upgrade pip \
66 | && pip3 install -r /app/alphafold/requirements.txt \
67 | && pip3 install --upgrade jax jaxlib==0.1.69+cuda${CUDA/./} -f \
68 | https://storage.googleapis.com/jax-releases/jax_releases.html
69 |
70 | # Apply OpenMM patch.
71 | WORKDIR /opt/conda/lib/python3.7/site-packages
72 | RUN patch -p0 < /app/alphafold/docker/openmm.patch
73 |
74 | # We need to run `ldconfig` first to ensure GPUs are visible, due to some quirk
75 | # with Debian. See https://github.com/NVIDIA/nvidia-docker/issues/1399 for
76 | # details.
77 | # ENTRYPOINT does not support easily running multiple commands, so instead we
78 | # write a shell script to wrap them up.
79 | WORKDIR /app/alphafold
80 | RUN echo $'#!/bin/bash\n\
81 | ldconfig\n\
82 | python /app/alphafold/run_alphafold.py "$@"' > /app/run_alphafold.sh \
83 | && chmod +x /app/run_alphafold.sh
84 | ENTRYPOINT ["/app/run_alphafold.sh"]
85 |
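86 | # Example build command (editor's note; run from the repository root, and
87 | # keep the tag in sync with `docker_image_name` in docker/run_docker.py):
88 | #   docker build -f docker/Dockerfile -t alphafold .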
--------------------------------------------------------------------------------
/docker/openmm.patch:
--------------------------------------------------------------------------------
1 | Index: simtk/openmm/app/topology.py
2 | ===================================================================
3 | --- simtk.orig/openmm/app/topology.py
4 | +++ simtk/openmm/app/topology.py
5 | @@ -356,19 +356,35 @@
6 | def isCyx(res):
7 | names = [atom.name for atom in res._atoms]
8 | return 'SG' in names and 'HG' not in names
9 | + # This function is used to prevent multiple di-sulfide bonds from being
10 | + # assigned to a given atom. This is a DeepMind modification.
11 | + def isDisulfideBonded(atom):
12 | + for b in self._bonds:
13 | + if (atom in b and b[0].name == 'SG' and
14 | + b[1].name == 'SG'):
15 | + return True
16 | +
17 | + return False
18 |
19 | cyx = [res for res in self.residues() if res.name == 'CYS' and isCyx(res)]
20 | atomNames = [[atom.name for atom in res._atoms] for res in cyx]
21 | for i in range(len(cyx)):
22 | sg1 = cyx[i]._atoms[atomNames[i].index('SG')]
23 | pos1 = positions[sg1.index]
24 | + candidate_distance, candidate_atom = 0.3*nanometers, None
25 | for j in range(i):
26 | sg2 = cyx[j]._atoms[atomNames[j].index('SG')]
27 | pos2 = positions[sg2.index]
28 | delta = [x-y for (x,y) in zip(pos1, pos2)]
29 | distance = sqrt(delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2])
30 | - if distance < 0.3*nanometers:
31 | - self.addBond(sg1, sg2)
32 | + if distance < candidate_distance and not isDisulfideBonded(sg2):
33 | + candidate_distance = distance
34 | + candidate_atom = sg2
35 | + # Assign bond to closest pair.
36 | + if candidate_atom:
37 | + self.addBond(sg1, candidate_atom)
38 | +
39 | +
40 |
41 | class Chain(object):
42 | """A Chain object represents a chain within a Topology."""
43 |
--------------------------------------------------------------------------------
/docker/requirements.txt:
--------------------------------------------------------------------------------
1 | # Dependencies necessary to execute run_docker.py
2 | absl-py==0.13.0
3 | docker==5.0.0
4 |
--------------------------------------------------------------------------------
/docker/run_docker.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Docker launch script for Alphafold docker image."""
16 |
17 | import os
18 | import signal
19 | from typing import Tuple
20 |
21 | from absl import app
22 | from absl import flags
23 | from absl import logging
24 | import docker
25 | from docker import types
26 |
27 |
28 | #### USER CONFIGURATION ####
29 |
30 | # Set to target of scripts/download_all_data.sh
31 | DOWNLOAD_DIR = 'SET ME'
32 |
33 | # Name of the AlphaFold Docker image.
34 | docker_image_name = 'alphafold'
35 |
36 | # Path to a directory that will store the results.
37 | output_dir = '/tmp/alphafold'
38 |
39 | # Names of models to use.
40 | model_names = [
41 | 'model_1',
42 | 'model_2',
43 | 'model_3',
44 | 'model_4',
45 | 'model_5',
46 | ]
47 |
48 | # You can individually override the following paths if you have placed the
49 | # data in locations other than the DOWNLOAD_DIR.
50 |
51 | # Path to directory of supporting data, contains 'params' dir.
52 | data_dir = DOWNLOAD_DIR
53 |
54 | # Path to the Uniref90 database for use by JackHMMER.
55 | uniref90_database_path = os.path.join(
56 | DOWNLOAD_DIR, 'uniref90', 'uniref90.fasta')
57 |
58 | # Path to the MGnify database for use by JackHMMER.
59 | mgnify_database_path = os.path.join(
60 |     DOWNLOAD_DIR, 'mgnify', 'mgy_clusters_2018_12.fa')
61 |
62 | # Path to the BFD database for use by HHblits.
63 | bfd_database_path = os.path.join(
64 | DOWNLOAD_DIR, 'bfd',
65 | 'bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt')
66 |
67 | # Path to the Small BFD database for use by JackHMMER.
68 | small_bfd_database_path = os.path.join(
69 | DOWNLOAD_DIR, 'small_bfd', 'bfd-first_non_consensus_sequences.fasta')
70 |
71 | # Path to the Uniclust30 database for use by HHblits.
72 | uniclust30_database_path = os.path.join(
73 | DOWNLOAD_DIR, 'uniclust30', 'uniclust30_2018_08', 'uniclust30_2018_08')
74 |
75 | # Path to the PDB70 database for use by HHsearch.
76 | pdb70_database_path = os.path.join(DOWNLOAD_DIR, 'pdb70', 'pdb70')
77 |
78 | # Path to a directory with template mmCIF structures, each named <pdb_id>.cif.
79 | template_mmcif_dir = os.path.join(DOWNLOAD_DIR, 'pdb_mmcif', 'mmcif_files')
80 |
81 | # Path to a file mapping obsolete PDB IDs to their replacements.
82 | obsolete_pdbs_path = os.path.join(DOWNLOAD_DIR, 'pdb_mmcif', 'obsolete.dat')
83 |
84 | #### END OF USER CONFIGURATION ####
85 |
86 |
87 | flags.DEFINE_bool('use_gpu', True, 'Enable NVIDIA runtime to run with GPUs.')
88 | flags.DEFINE_string('gpu_devices', 'all', 'Comma separated list of devices to '
89 | 'pass to NVIDIA_VISIBLE_DEVICES.')
90 | flags.DEFINE_list('fasta_paths', None, 'Paths to FASTA files, each containing '
91 | 'one sequence. Paths should be separated by commas. '
92 | 'All FASTA paths must have a unique basename as the '
93 | 'basename is used to name the output directories for '
94 | 'each prediction.')
95 | flags.DEFINE_string('max_template_date', None, 'Maximum template release date '
96 | 'to consider (ISO-8601 format - i.e. YYYY-MM-DD). '
97 | 'Important if folding historical test sets.')
98 | flags.DEFINE_enum('preset', 'full_dbs',
99 | ['reduced_dbs', 'full_dbs', 'casp14'],
100 | 'Choose preset model configuration - no ensembling and '
101 | 'smaller genetic database config (reduced_dbs), no '
102 | 'ensembling and full genetic database config (full_dbs) or '
103 | 'full genetic database config and 8 model ensemblings '
104 | '(casp14).')
105 | flags.DEFINE_boolean('benchmark', False, 'Run multiple JAX model evaluations '
106 | 'to obtain a timing that excludes the compilation time, '
107 |                      'which should be more indicative of the time required '
108 |                      'for running inference on many proteins.')
109 |
110 | FLAGS = flags.FLAGS
111 |
112 | _ROOT_MOUNT_DIRECTORY = '/mnt/'
113 |
114 |
115 | def _create_mount(mount_name: str, path: str) -> Tuple[types.Mount, str]:
116 | path = os.path.abspath(path)
117 | source_path = os.path.dirname(path)
118 | target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, mount_name)
119 | logging.info('Mounting %s -> %s', source_path, target_path)
120 | mount = types.Mount(target_path, source_path, type='bind', read_only=True)
121 | return mount, os.path.join(target_path, os.path.basename(path))
122 |
123 |
124 | def main(argv):
125 | if len(argv) > 1:
126 | raise app.UsageError('Too many command-line arguments.')
127 |
128 | mounts = []
129 | command_args = []
130 |
131 | # Mount each fasta path as a unique target directory.
132 | target_fasta_paths = []
133 | for i, fasta_path in enumerate(FLAGS.fasta_paths):
134 | mount, target_path = _create_mount(f'fasta_path_{i}', fasta_path)
135 | mounts.append(mount)
136 | target_fasta_paths.append(target_path)
137 | command_args.append(f'--fasta_paths={",".join(target_fasta_paths)}')
138 |
139 | database_paths = [
140 | ('uniref90_database_path', uniref90_database_path),
141 | ('mgnify_database_path', mgnify_database_path),
142 | ('pdb70_database_path', pdb70_database_path),
143 | ('data_dir', data_dir),
144 | ('template_mmcif_dir', template_mmcif_dir),
145 | ('obsolete_pdbs_path', obsolete_pdbs_path),
146 | ]
147 | if FLAGS.preset == 'reduced_dbs':
148 | database_paths.append(('small_bfd_database_path', small_bfd_database_path))
149 | else:
150 | database_paths.extend([
151 | ('uniclust30_database_path', uniclust30_database_path),
152 | ('bfd_database_path', bfd_database_path),
153 | ])
154 | for name, path in database_paths:
155 | if path:
156 | mount, target_path = _create_mount(name, path)
157 | mounts.append(mount)
158 | command_args.append(f'--{name}={target_path}')
159 |
160 | output_target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, 'output')
161 | mounts.append(types.Mount(output_target_path, output_dir, type='bind'))
162 |
163 | command_args.extend([
164 | f'--output_dir={output_target_path}',
165 | f'--model_names={",".join(model_names)}',
166 | f'--max_template_date={FLAGS.max_template_date}',
167 | f'--preset={FLAGS.preset}',
168 | f'--benchmark={FLAGS.benchmark}',
169 | '--logtostderr',
170 | ])
171 |
172 | client = docker.from_env()
173 | container = client.containers.run(
174 | image=docker_image_name,
175 | command=command_args,
176 | runtime='nvidia' if FLAGS.use_gpu else None,
177 | remove=True,
178 | detach=True,
179 | mounts=mounts,
180 | environment={
181 | 'NVIDIA_VISIBLE_DEVICES': FLAGS.gpu_devices,
182 | # The following flags allow us to make predictions on proteins that
183 | # would typically be too long to fit into GPU memory.
184 | 'TF_FORCE_UNIFIED_MEMORY': '1',
185 | 'XLA_PYTHON_CLIENT_MEM_FRACTION': '4.0',
186 | })
187 |
188 | # Add signal handler to ensure CTRL+C also stops the running container.
189 | signal.signal(signal.SIGINT,
190 | lambda unused_sig, unused_frame: container.kill())
191 |
192 | for line in container.logs(stream=True):
193 | logging.info(line.strip().decode('utf-8'))
194 |
195 |
196 | if __name__ == '__main__':
197 | flags.mark_flags_as_required([
198 | 'fasta_paths',
199 | 'max_template_date',
200 | ])
201 | app.run(main)
202 |
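203 | # Example invocation (editor's sketch; the FASTA path is illustrative and
204 | # DOWNLOAD_DIR/output_dir above must be set first):
205 | #   python3 docker/run_docker.py \
206 | #     --fasta_paths=/path/to/query.fasta \
207 | #     --max_template_date=2021-07-14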
--------------------------------------------------------------------------------
/header.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xinformatics/alphafold/18ddb85e42ab6363fe1d86dab403306366725bb2/header.jpg
--------------------------------------------------------------------------------
/imgs/casp14_predictions.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xinformatics/alphafold/18ddb85e42ab6363fe1d86dab403306366725bb2/imgs/casp14_predictions.gif
--------------------------------------------------------------------------------
/imgs/header.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xinformatics/alphafold/18ddb85e42ab6363fe1d86dab403306366725bb2/imgs/header.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.13.0
2 | biopython==1.79
3 | chex==0.0.7
4 | dm-haiku==0.0.4
5 | dm-tree==0.1.6
6 | docker==5.0.0
7 | immutabledict==2.0.0
8 | jax==0.2.14
9 | ml-collections==0.1.0
10 | numpy==1.19.5
11 | scipy==1.7.0
12 | tensorflow==2.5.0
13 |
--------------------------------------------------------------------------------
/run_alphafold_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for run_alphafold."""
16 |
17 | import os
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | import mock
22 | import numpy as np
23 |
24 | import run_alphafold
25 | # Internal import (7716).
26 |
27 |
28 | class RunAlphafoldTest(parameterized.TestCase):
29 |
30 | def test_end_to_end(self):
31 |
32 | data_pipeline_mock = mock.Mock()
33 | model_runner_mock = mock.Mock()
34 | amber_relaxer_mock = mock.Mock()
35 |
36 | data_pipeline_mock.process.return_value = {}
37 | model_runner_mock.process_features.return_value = {
38 | 'aatype': np.zeros((12, 10), dtype=np.int32),
39 | 'residue_index': np.tile(np.arange(10, dtype=np.int32)[None], (12, 1)),
40 | }
41 | model_runner_mock.predict.return_value = {
42 | 'structure_module': {
43 | 'final_atom_positions': np.zeros((10, 37, 3)),
44 | 'final_atom_mask': np.ones((10, 37)),
45 | },
46 | 'predicted_lddt': {
47 | 'logits': np.ones((10, 50)),
48 | },
49 | 'plddt': np.zeros(10),
50 | 'ptm': np.array(0.),
51 | 'aligned_confidence_probs': np.zeros((10, 10, 50)),
52 | 'predicted_aligned_error': np.zeros((10, 10)),
53 | 'max_predicted_aligned_error': np.array(0.),
54 | }
55 | amber_relaxer_mock.process.return_value = ('RELAXED', None, None)
56 |
57 | fasta_path = os.path.join(absltest.get_default_test_tmpdir(),
58 | 'target.fasta')
59 | with open(fasta_path, 'wt') as f:
60 | f.write('>A\nAAAAAAAAAAAAA')
61 | fasta_name = 'test'
62 |
63 | out_dir = absltest.get_default_test_tmpdir()
64 |
65 | run_alphafold.predict_structure(
66 | fasta_path=fasta_path,
67 | fasta_name=fasta_name,
68 | output_dir_base=out_dir,
69 | data_pipeline=data_pipeline_mock,
70 | model_runners={'model1': model_runner_mock},
71 | amber_relaxer=amber_relaxer_mock,
72 | benchmark=False,
73 | random_seed=0)
74 |
75 |
76 | if __name__ == '__main__':
77 | absltest.main()
78 |
--------------------------------------------------------------------------------
/scripts/download_all_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright 2021 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | # Downloads and unzips all required data for AlphaFold.
18 | #
19 | # Usage: bash download_all_data.sh /path/to/download/directory [full_dbs|reduced_dbs]
20 | set -e
21 |
22 | if [[ $# -eq 0 ]]; then
23 | echo "Error: download directory must be provided as an input argument."
24 | exit 1
25 | fi
26 |
27 | if ! command -v aria2c &> /dev/null ; then
28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
29 | exit 1
30 | fi
31 |
32 | DOWNLOAD_DIR="$1"
33 | DOWNLOAD_MODE="${2:-full_dbs}" # Default mode to full_dbs.
34 | if [[ "${DOWNLOAD_MODE}" != full_dbs && "${DOWNLOAD_MODE}" != reduced_dbs ]]
35 | then
36 | echo "DOWNLOAD_MODE ${DOWNLOAD_MODE} not recognized."
37 | exit 1
38 | fi
39 |
40 | SCRIPT_DIR="$(dirname "$(realpath "$0")")"
41 |
42 | echo "Downloading AlphaFold parameters..."
43 | bash "${SCRIPT_DIR}/download_alphafold_params.sh" "${DOWNLOAD_DIR}"
44 |
45 | if [[ "${DOWNLOAD_MODE}" = full_dbs ]] ; then
46 | echo "Downloading BFD..."
47 | bash "${SCRIPT_DIR}/download_bfd.sh" "${DOWNLOAD_DIR}"
48 | else
49 | echo "Downloading Small BFD..."
50 | bash "${SCRIPT_DIR}/download_small_bfd.sh" "${DOWNLOAD_DIR}"
51 | fi
52 |
53 | echo "Downloading MGnify..."
54 | bash "${SCRIPT_DIR}/download_mgnify.sh" "${DOWNLOAD_DIR}"
55 |
56 | echo "Downloading PDB70..."
57 | bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}"
58 |
59 | echo "Downloading PDB mmCIF files..."
60 | bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}"
61 |
62 | echo "Downloading Uniclust30..."
63 | bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}"
64 |
65 | echo "Downloading Uniref90..."
66 | bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}"
67 |
68 | echo "All data downloaded."
69 |
--------------------------------------------------------------------------------
/scripts/download_alphafold_params.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright 2021 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | # Downloads and unzips the AlphaFold parameters.
18 | #
19 | # Usage: bash download_alphafold_params.sh /path/to/download/directory
20 | set -e
21 |
22 | if [[ $# -eq 0 ]]; then
23 | echo "Error: download directory must be provided as an input argument."
24 | exit 1
25 | fi
26 |
27 | if ! command -v aria2c &> /dev/null ; then
28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
29 | exit 1
30 | fi
31 |
32 | DOWNLOAD_DIR="$1"
33 | ROOT_DIR="${DOWNLOAD_DIR}/params"
34 | SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2021-07-14.tar"
35 | BASENAME=$(basename "${SOURCE_URL}")
36 |
37 | mkdir --parents "${ROOT_DIR}"
38 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
39 | tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
40 | --directory="${ROOT_DIR}" --preserve-permissions
41 | rm "${ROOT_DIR}/${BASENAME}"
42 |
--------------------------------------------------------------------------------
/scripts/download_bfd.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright 2021 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | # Downloads and unzips the BFD database for AlphaFold.
18 | #
19 | # Usage: bash download_bfd.sh /path/to/download/directory
20 | set -e
21 |
22 | if [[ $# -eq 0 ]]; then
23 | echo "Error: download directory must be provided as an input argument."
24 | exit 1
25 | fi
26 |
27 | if ! command -v aria2c &> /dev/null ; then
28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
29 | exit 1
30 | fi
31 |
32 | DOWNLOAD_DIR="$1"
33 | ROOT_DIR="${DOWNLOAD_DIR}/bfd"
34 | # Mirror of:
35 | # https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz.
36 | SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz"
37 | BASENAME=$(basename "${SOURCE_URL}")
38 |
39 | mkdir --parents "${ROOT_DIR}"
40 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
41 | tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
42 | --directory="${ROOT_DIR}"
43 | rm "${ROOT_DIR}/${BASENAME}"
44 |
--------------------------------------------------------------------------------
/scripts/download_mgnify.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright 2021 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | # Downloads and unzips the MGnify database for AlphaFold.
18 | #
19 | # Usage: bash download_mgnify.sh /path/to/download/directory
20 | set -e
21 |
22 | if [[ $# -eq 0 ]]; then
23 | echo "Error: download directory must be provided as an input argument."
24 | exit 1
25 | fi
26 |
27 | if ! command -v aria2c &> /dev/null ; then
28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
29 | exit 1
30 | fi
31 |
32 | DOWNLOAD_DIR="$1"
33 | ROOT_DIR="${DOWNLOAD_DIR}/mgnify"
34 | # Mirror of:
35 | # ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2018_12/mgy_clusters.fa.gz
36 | SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/mgy_clusters_2018_12.fa.gz"
37 | BASENAME=$(basename "${SOURCE_URL}")
38 |
39 | mkdir --parents "${ROOT_DIR}"
40 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
41 | pushd "${ROOT_DIR}"
42 | gunzip "${ROOT_DIR}/${BASENAME}"
43 | popd
44 |
--------------------------------------------------------------------------------
/scripts/download_pdb70.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright 2021 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | # Downloads and unzips the PDB70 database for AlphaFold.
18 | #
19 | # Usage: bash download_pdb70.sh /path/to/download/directory
20 | set -e
21 |
22 | if [[ $# -eq 0 ]]; then
23 | echo "Error: download directory must be provided as an input argument."
24 | exit 1
25 | fi
26 |
27 | if ! command -v aria2c &> /dev/null ; then
28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
29 | exit 1
30 | fi
31 |
32 | DOWNLOAD_DIR="$1"
33 | ROOT_DIR="${DOWNLOAD_DIR}/pdb70"
34 | SOURCE_URL="http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/old-releases/pdb70_from_mmcif_200401.tar.gz"
35 | BASENAME=$(basename "${SOURCE_URL}")
36 |
37 | mkdir --parents "${ROOT_DIR}"
38 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
39 | tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
40 | --directory="${ROOT_DIR}"
41 | rm "${ROOT_DIR}/${BASENAME}"
42 |
--------------------------------------------------------------------------------
/scripts/download_pdb_mmcif.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright 2021 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | # Downloads, unzips and flattens the PDB database for AlphaFold.
18 | #
19 | # Usage: bash download_pdb_mmcif.sh /path/to/download/directory
20 | set -e
21 |
22 | if [[ $# -eq 0 ]]; then
23 | echo "Error: download directory must be provided as an input argument."
24 | exit 1
25 | fi
26 |
27 | if ! command -v aria2c &> /dev/null ; then
28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
29 | exit 1
30 | fi
31 |
32 | if ! command -v rsync &> /dev/null ; then
33 | echo "Error: rsync could not be found. Please install rsync."
34 | exit 1
35 | fi
36 |
37 | DOWNLOAD_DIR="$1"
38 | ROOT_DIR="${DOWNLOAD_DIR}/pdb_mmcif"
39 | RAW_DIR="${ROOT_DIR}/raw"
40 | MMCIF_DIR="${ROOT_DIR}/mmcif_files"
41 |
42 | echo "Running rsync to fetch all mmCIF files (note that the rsync progress estimate might be inaccurate)..."
43 | mkdir --parents "${RAW_DIR}"
44 | rsync --recursive --links --perms --times --compress --info=progress2 --delete --port=33444 \
45 | rsync.rcsb.org::ftp_data/structures/divided/mmCIF/ \
46 | "${RAW_DIR}"
47 |
48 | echo "Unzipping all mmCIF files..."
49 | find "${RAW_DIR}/" -type f -iname "*.gz" -exec gunzip {} +
50 |
51 | echo "Flattening all mmCIF files..."
52 | mkdir --parents "${MMCIF_DIR}"
53 | find "${RAW_DIR}" -type d -empty -delete # Delete empty directories.
54 | for subdir in "${RAW_DIR}"/*; do
55 | mv "${subdir}/"*.cif "${MMCIF_DIR}"
56 | done
57 |
58 | # Delete empty download directory structure.
59 | find "${RAW_DIR}" -type d -empty -delete
60 |
61 | aria2c "ftp://ftp.wwpdb.org/pub/pdb/data/status/obsolete.dat" --dir="${ROOT_DIR}"
62 |
--------------------------------------------------------------------------------
/scripts/download_small_bfd.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright 2021 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | # Downloads and unzips the Small BFD database for AlphaFold.
18 | #
19 | # Usage: bash download_small_bfd.sh /path/to/download/directory
20 | set -e
21 |
22 | if [[ $# -eq 0 ]]; then
23 | echo "Error: download directory must be provided as an input argument."
24 | exit 1
25 | fi
26 |
27 | if ! command -v aria2c &> /dev/null ; then
28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
29 | exit 1
30 | fi
31 |
32 | DOWNLOAD_DIR="$1"
33 | ROOT_DIR="${DOWNLOAD_DIR}/small_bfd"
34 | SOURCE_URL="https://storage.googleapis.com/alphafold-databases/reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz"
35 | BASENAME=$(basename "${SOURCE_URL}")
36 |
37 | mkdir --parents "${ROOT_DIR}"
38 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
39 | pushd "${ROOT_DIR}"
40 | gunzip "${ROOT_DIR}/${BASENAME}"
41 | popd
42 |
--------------------------------------------------------------------------------
/scripts/download_uniclust30.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright 2021 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | # Downloads and unzips the Uniclust30 database for AlphaFold.
18 | #
19 | # Usage: bash download_uniclust30.sh /path/to/download/directory
20 | set -e
21 |
22 | if [[ $# -eq 0 ]]; then
23 | echo "Error: download directory must be provided as an input argument."
24 | exit 1
25 | fi
26 |
27 | if ! command -v aria2c &> /dev/null ; then
28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
29 | exit 1
30 | fi
31 |
32 | DOWNLOAD_DIR="$1"
33 | ROOT_DIR="${DOWNLOAD_DIR}/uniclust30"
34 | # Mirror of:
35 | # http://wwwuser.gwdg.de/~compbiol/uniclust/2018_08/uniclust30_2018_08_hhsuite.tar.gz
36 | SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/uniclust30_2018_08_hhsuite.tar.gz"
37 | BASENAME=$(basename "${SOURCE_URL}")
38 |
39 | mkdir --parents "${ROOT_DIR}"
40 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
41 | tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
42 | --directory="${ROOT_DIR}"
43 | rm "${ROOT_DIR}/${BASENAME}"
44 |
--------------------------------------------------------------------------------
/scripts/download_uniref90.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright 2021 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | # Downloads and unzips the UniRef90 database for AlphaFold.
18 | #
19 | # Usage: bash download_uniref90.sh /path/to/download/directory
20 | set -e
21 |
22 | if [[ $# -eq 0 ]]; then
23 | echo "Error: download directory must be provided as an input argument."
24 | exit 1
25 | fi
26 |
27 | if ! command -v aria2c &> /dev/null ; then
28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
29 | exit 1
30 | fi
31 |
32 | DOWNLOAD_DIR="$1"
33 | ROOT_DIR="${DOWNLOAD_DIR}/uniref90"
34 | SOURCE_URL="ftp://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref90/uniref90.fasta.gz"
35 | BASENAME=$(basename "${SOURCE_URL}")
36 |
37 | mkdir --parents "${ROOT_DIR}"
38 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
39 | pushd "${ROOT_DIR}"
40 | gunzip "${ROOT_DIR}/${BASENAME}"
41 | popd
42 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Install script for setuptools."""
15 |
16 | from setuptools import find_packages
17 | from setuptools import setup
18 |
19 | setup(
20 | name='alphafold',
21 | version='2.0.0',
22 |     description='An implementation of the inference pipeline of AlphaFold v2.0. '
23 | 'This is a completely new model that was entered as AlphaFold2 in CASP14 '
24 | 'and published in Nature.',
25 | author='DeepMind',
26 | author_email='alphafold@deepmind.com',
27 | license='Apache License, Version 2.0',
28 | url='https://github.com/deepmind/alphafold',
29 | packages=find_packages(),
30 | install_requires=[
31 | 'absl-py',
32 | 'biopython',
33 | 'chex',
34 | 'dm-haiku',
35 | 'dm-tree',
36 | 'docker',
37 | 'immutabledict',
38 | 'jax',
39 | 'ml-collections',
40 | 'numpy',
41 | 'scipy',
42 | 'tensorflow',
43 | ],
44 | tests_require=['mock'],
45 | classifiers=[
46 | 'Development Status :: 5 - Production/Stable',
47 | 'Intended Audience :: Science/Research',
48 | 'License :: OSI Approved :: Apache Software License',
49 | 'Operating System :: POSIX :: Linux',
50 | 'Programming Language :: Python :: 3.6',
51 | 'Programming Language :: Python :: 3.7',
52 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
53 | ],
54 | )
55 |
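56 | # Editor's note: from the repository root, the package can be installed in
57 | # editable mode with `pip install -e .` (standard setuptools usage; not a
58 | # command documented elsewhere in this file).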
--------------------------------------------------------------------------------