├── .gitignore
├── Cfold.ipynb
├── Logo.svg
├── README.md
├── cfold.ipynb
├── data
│   └── test
│       ├── 4AVA.a3m
│       └── 4AVA.fasta
├── environment.yml
├── install_dependencies.sh
├── pip_pkgs.txt
├── predict.sh
└── src
    ├── alphafold
    │   ├── .DS_Store
    │   ├── common
    │   │   ├── __init__.py
    │   │   ├── confidence.py
    │   │   ├── protein.py
    │   │   ├── protein_test.py
    │   │   ├── residue_constants.py
    │   │   ├── residue_constants_test.py
    │   │   └── testdata
    │   │       └── 2rbg.pdb
    │   ├── data
    │   │   ├── __init__.py
    │   │   ├── mmcif_parsing.py
    │   │   ├── parsers.py
    │   │   ├── pipeline.py
    │   │   ├── templates.py
    │   │   └── tools
    │   │       ├── __init__.py
    │   │       ├── hhblits.py
    │   │       ├── hhsearch.py
    │   │       ├── hmmbuild.py
    │   │       ├── hmmsearch.py
    │   │       ├── jackhmmer.py
    │   │       ├── kalign.py
    │   │       └── utils.py
    │   ├── model
    │   │   ├── __init__.py
    │   │   ├── __pycache__
    │   │   │   ├── __init__.cpython-312.pyc
    │   │   │   ├── all_atom.cpython-312.pyc
    │   │   │   ├── common_modules.cpython-312.pyc
    │   │   │   ├── config.cpython-312.pyc
    │   │   │   ├── data.cpython-312.pyc
    │   │   │   ├── features.cpython-312.pyc
    │   │   │   ├── folding.cpython-312.pyc
    │   │   │   ├── layer_stack.cpython-312.pyc
    │   │   │   ├── lddt.cpython-312.pyc
    │   │   │   ├── mapping.cpython-312.pyc
    │   │   │   ├── modules.cpython-312.pyc
    │   │   │   ├── prng.cpython-312.pyc
    │   │   │   ├── quat_affine.cpython-312.pyc
    │   │   │   ├── r3.cpython-312.pyc
    │   │   │   └── utils.cpython-312.pyc
    │   │   ├── all_atom.py
    │   │   ├── all_atom_test.py
    │   │   ├── common_modules.py
    │   │   ├── config.py
    │   │   ├── data.py
    │   │   ├── features.py
    │   │   ├── folding.py
    │   │   ├── layer_stack.py
    │   │   ├── layer_stack_test.py
    │   │   ├── lddt.py
    │   │   ├── lddt_test.py
    │   │   ├── mapping.py
    │   │   ├── model.py
    │   │   ├── modules.py
    │   │   ├── prng.py
    │   │   ├── prng_test.py
    │   │   ├── quat_affine.py
    │   │   ├── quat_affine_test.py
    │   │   ├── r3.py
    │   │   ├── tf
    │   │   │   ├── __init__.py
    │   │   │   ├── data_transforms.py
    │   │   │   ├── input_pipeline.py
    │   │   │   ├── protein_features.py
    │   │   │   ├── protein_features_test.py
    │   │   │   ├── proteins_dataset.py
    │   │   │   ├── shape_helpers.py
    │   │   │   ├── shape_helpers_test.py
    │   │   │   ├── shape_placeholders.py
    │   │   │   └── utils.py
    │   │   └── utils.py
    │   └── relax
    │       ├── __init__.py
    │       ├── amber_minimize.py
    │       ├── amber_minimize_test.py
    │       ├── cleanup.py
    │       ├── cleanup_test.py
    │       ├── relax.py
    │       ├── relax_test.py
    │       ├── testdata
    │       │   ├── model_output.pdb
    │       │   ├── multiple_disulfides_target.pdb
    │       │   ├── with_violations.pdb
    │       │   └── with_violations_casp14.pdb
    │       ├── utils.py
    │       └── utils_test.py
    ├── animate_py3dmol.py
    ├── check_msa_colab.py
    ├── make_msa_seq_feats.py
    ├── make_msa_seq_feats_colab.py
    ├── predict_with_clusters.py
    ├── predict_with_clusters_colab.py
    ├── score_closest_ca.py
    └── score_closest_ca_colab.py
/.gitignore:
--------------------------------------------------------------------------------
1 | src/alphafold/common/__pycache__
2 | src/alphafold/data/__pycache__
3 | src/alphafold/data/tools/__pycache__
4 | src/alphafold/model/__pycache__
5 | src/alphafold/model/tf/__pycache__
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cfold
2 |
3 | **Structure prediction of alternative protein conformations**
4 |
5 |
6 |
7 |
8 |
9 | \
10 | Cfold is a structure prediction network similar to AlphaFold2 that is trained on a conformational split of the PDB. Cfold is designed for predicting alternative conformations of protein structures. [Read more about it in the paper here](https://www.nature.com/articles/s41467-024-51507-2)
11 | \
12 | \
13 | AlphaFold2 is available under the [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0) and so is Cfold, which is a derivative thereof. The Cfold parameters are made available under the terms of the [CC BY 4.0 license](https://creativecommons.org/licenses/by/4.0/legalcode).
14 | \
15 | \
16 | **You may not use these files except in compliance with the licenses.**
17 |
18 | # Colab (run in the browser)
19 |
20 | [Colab Notebook](https://colab.research.google.com/github/patrickbryant1/Cfold/blob/master/Cfold.ipynb)
21 |
22 | # Local installation
23 |
24 | The entire installation takes <1 hour on a standard computer. \
25 | The runtime will depend on the GPU you have available, the size of the protein
26 | you are predicting and the number of samples taken. On an NVIDIA A100 GPU, the
27 | prediction time is a few minutes per sample for a protein of a few hundred amino acids.
28 |
29 | We assume you have CUDA 12. For CUDA 11, you will have to change the installation of some packages (e.g. the JAX CUDA wheel). \
30 |
31 | First, install Miniconda; see: https://docs.conda.io/projects/miniconda/en/latest/miniconda-other-installer-links.html
32 |
33 | ```
34 | bash install_dependencies.sh
35 | ```
36 | If the conda installation doesn't work for you, see "pip_pkgs.txt" for the equivalent pip packages.
37 |
38 | # Run the test case
39 | ## (a few minutes)
40 | ```
41 | bash predict.sh
42 | ```
43 |
44 | # Data availability
45 | https://zenodo.org/records/10837082
46 |
47 | # Citation
48 | Bryant P, Noé F. Structure prediction of alternative protein conformations. Nat Commun 15, 7328 (2024).
49 |
--------------------------------------------------------------------------------
/data/test/4AVA.fasta:
--------------------------------------------------------------------------------
1 | >4AVA
2 | DGIAELTGARVEDLAGMDVFQGCPAEGLVSLAASVQPLRAAAGQVLLRQGEPAVSFLLISSGSAEVSHVGDDGVAIIARALPGMIVGEIALLRDSPRSATVTTIEPLTGWTGGRGAFATMVHIPGVGERLLRTARQRLAAFVSPIPVRLADGTQLMLRPVLPGDRERTVHGHIQFSGETLYRRFMSPALMHYLSEVDYVDHFVWVVTDGSDPVADARFVRDETDPTVAEIAFTVADAYQGRGIGSFLIGALSVAARVDGVERFAARMLSDNVPMRTIMDRYGAVWQREDVGVITTMIDVPGPGELSLGREMVDQINRVARQVIEAVG
3 |
--------------------------------------------------------------------------------
/install_dependencies.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #Python packages
3 | conda env create -f environment.yml
4 |
5 | wait
6 | eval "$(conda shell.bash hook)" && conda activate cfold # Makes 'conda activate' work in a non-interactive shell
7 | pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
8 | conda deactivate
9 |
10 | ## Get network parameters for Cfold
11 | wget https://zenodo.org/records/10517910/files/params10000.npy
12 |
13 |
14 | ## Uniclust30
15 | wget http://wwwuser.gwdg.de/~compbiol/uniclust/2018_08/uniclust30_2018_08_hhsuite.tar.gz --no-check-certificate
16 | mkdir data/uniclust30
17 | mv uniclust30_2018_08_hhsuite.tar.gz data/uniclust30
18 | tar -zxvf data/uniclust30/uniclust30_2018_08_hhsuite.tar.gz -C data/ # Extract into data/ so predict.sh finds data/uniclust30_2018_08/
19 |
20 |
21 | ## HHblits
22 | git clone https://github.com/soedinglab/hh-suite.git
23 | mkdir -p hh-suite/build && cd hh-suite/build
24 | cmake -DCMAKE_INSTALL_PREFIX=. ..
25 | make -j 4 && make install
26 | cd ../.. # Back to the repository root
27 |
--------------------------------------------------------------------------------
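After running the install script, a quick sanity check that the JAX + CUDA install succeeded can be useful. A minimal sketch (not part of the repo; it assumes the `cfold` environment is activated and simply asks JAX which devices it can see):

```
# check_jax_gpu.py - hypothetical helper, not part of the repo.
import jax

print(jax.__version__)
# Should list a CUDA/GPU device; seeing only a CPU device means the
# CUDA-enabled JAX wheel did not install correctly.
print(jax.devices())
```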
/pip_pkgs.txt:
--------------------------------------------------------------------------------
1 | pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2 | pip install ml-collections==0.1.1
3 | pip install dm-haiku==0.0.11
4 | pip install pandas==1.3.5
5 | pip install biopython==1.81
6 | pip install chex==0.1.5
7 | pip install dm-tree==0.1.8
8 | pip install immutabledict==2.0.0
9 | pip install scipy
10 | pip install tensorflow
11 | pip install numpy==1.21.6
12 |
--------------------------------------------------------------------------------
/predict.sh:
--------------------------------------------------------------------------------
1 | ## Search Uniclust30 with HHblits to generate an MSA
2 |
3 | ID=4AVA
4 | FASTA_DIR=./data/test/
5 | UNICLUST=./data/uniclust30_2018_08/uniclust30_2018_08
6 | OUTDIR=./data/test/
7 | ./hh-suite/build/bin/hhblits -i $FASTA_DIR/$ID.fasta -d $UNICLUST -E 0.001 -all -oa3m $OUTDIR/$ID'.a3m'
8 |
9 |
10 | ## MSA feats
11 | eval "$(conda shell.bash hook)" && conda activate cfold # Makes 'conda activate' work in a non-interactive shell
12 | MSA_DIR=./data/test/
13 | OUTDIR=./data/test/
14 |
15 | python3 ./src/make_msa_seq_feats.py --input_fasta_path $FASTA_DIR/$ID'.fasta' \
16 | --input_msas $MSA_DIR/$ID'.a3m' --outdir $OUTDIR
17 |
18 |
19 | ## Predict
20 | FEATURE_DIR=./data/test/
21 | PARAMS=./params10000.npy
22 | OUTDIR=./data/test/
23 | NUM_REC=3 # Increase for hard targets
24 | NUM_SAMPLES=13 # Increase for hard targets
25 |
26 | python3 ./src/predict_with_clusters.py --feature_dir $FEATURE_DIR \
27 | --predict_id $ID \
28 | --ckpt_params $PARAMS \
29 | --num_recycles $NUM_REC \
30 | --num_samples_per_cluster $NUM_SAMPLES \
31 | --outdir $OUTDIR/
32 |
--------------------------------------------------------------------------------
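Before the prediction step, it can be worth checking how deep the HHblits MSA is, since very shallow MSAs tend to give poor samples. A minimal sketch (the a3m path assumes the test case above has been run):

```
# count_msa_depth.py - hypothetical helper, not part of the repo.
def count_a3m_sequences(a3m_path: str) -> int:
    """Counts sequences in an a3m file (one '>' header per sequence)."""
    with open(a3m_path) as f:
        return sum(1 for line in f if line.startswith('>'))

print(count_a3m_sequences('./data/test/4AVA.a3m'))
```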
/src/alphafold/common/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Common data types and constants used within Alphafold."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/common/confidence.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functions for processing confidence metrics."""
16 |
17 | from typing import Dict, Optional, Tuple
18 | import numpy as np
19 | import scipy.special
20 |
21 |
22 | def compute_plddt(logits: np.ndarray) -> np.ndarray:
23 | """Computes per-residue pLDDT from logits.
24 |
25 | Args:
26 | logits: [num_res, num_bins] output from the PredictedLDDTHead.
27 |
28 | Returns:
29 | plddt: [num_res] per-residue pLDDT.
30 | """
31 | num_bins = logits.shape[-1]
32 | bin_width = 1.0 / num_bins
33 | bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width)
34 | probs = scipy.special.softmax(logits, axis=-1)
35 | predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1)
36 | return predicted_lddt_ca * 100
37 |
38 |
39 | def _calculate_bin_centers(breaks: np.ndarray):
40 | """Gets the bin centers from the bin edges.
41 |
42 | Args:
43 | breaks: [num_bins - 1] the error bin edges.
44 |
45 | Returns:
46 | bin_centers: [num_bins] the error bin centers.
47 | """
48 | step = (breaks[1] - breaks[0])
49 |
50 | # Add half-step to get the center
51 | bin_centers = breaks + step / 2
52 | # Add a catch-all bin at the end.
53 | bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]],
54 | axis=0)
55 | return bin_centers
56 |
57 |
58 | def _calculate_expected_aligned_error(
59 | alignment_confidence_breaks: np.ndarray,
60 | aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
61 | """Calculates expected aligned distance errors for every pair of residues.
62 |
63 | Args:
64 | alignment_confidence_breaks: [num_bins - 1] the error bin edges.
65 | aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted
66 | probs for each error bin, for each pair of residues.
67 |
68 | Returns:
69 | predicted_aligned_error: [num_res, num_res] the expected aligned distance
70 | error for each pair of residues.
71 | max_predicted_aligned_error: The maximum predicted error possible.
72 | """
73 | bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
74 |
75 | # Tuple of expected aligned distance error and max possible error.
76 | return (np.sum(aligned_distance_error_probs * bin_centers, axis=-1),
77 | np.asarray(bin_centers[-1]))
78 |
79 |
80 | def compute_predicted_aligned_error(
81 | logits: np.ndarray,
82 | breaks: np.ndarray) -> Dict[str, np.ndarray]:
83 | """Computes aligned confidence metrics from logits.
84 |
85 | Args:
86 | logits: [num_res, num_res, num_bins] the logits output from
87 | PredictedAlignedErrorHead.
88 | breaks: [num_bins - 1] the error bin edges.
89 |
90 | Returns:
91 | aligned_confidence_probs: [num_res, num_res, num_bins] the predicted
92 | aligned error probabilities over bins for each residue pair.
93 | predicted_aligned_error: [num_res, num_res] the expected aligned distance
94 | error for each pair of residues.
95 | max_predicted_aligned_error: The maximum predicted error possible.
96 | """
97 | aligned_confidence_probs = scipy.special.softmax(
98 | logits,
99 | axis=-1)
100 | predicted_aligned_error, max_predicted_aligned_error = (
101 | _calculate_expected_aligned_error(
102 | alignment_confidence_breaks=breaks,
103 | aligned_distance_error_probs=aligned_confidence_probs))
104 | return {
105 | 'aligned_confidence_probs': aligned_confidence_probs,
106 | 'predicted_aligned_error': predicted_aligned_error,
107 | 'max_predicted_aligned_error': max_predicted_aligned_error,
108 | }
109 |
110 |
111 | def predicted_tm_score(
112 | logits: np.ndarray,
113 | breaks: np.ndarray,
114 | residue_weights: Optional[np.ndarray] = None) -> np.ndarray:
115 | """Computes predicted TM alignment score.
116 |
117 | Args:
118 | logits: [num_res, num_res, num_bins] the logits output from
119 | PredictedAlignedErrorHead.
120 | breaks: [num_bins - 1] the error bin edges.
121 | residue_weights: [num_res] the per residue weights to use for the
122 | expectation.
123 |
124 | Returns:
125 | ptm_score: the predicted TM alignment score.
126 | """
127 |
128 | # residue_weights has to be in [0, 1], but can be floating-point, i.e. the
129 | # exp. resolved head's probability.
130 | if residue_weights is None:
131 | residue_weights = np.ones(logits.shape[0])
132 |
133 | bin_centers = _calculate_bin_centers(breaks)
134 |
135 | num_res = np.sum(residue_weights)
136 | # Clip num_res to avoid negative/undefined d0.
137 | clipped_num_res = max(num_res, 19)
138 |
139 | # Compute d_0(num_res) as defined by TM-score, eqn. (5) in
140 | # http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
141 | # Yang & Skolnick "Scoring function for automated
142 | # assessment of protein structure template quality" 2004
143 | d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8
144 |
145 | # Convert logits to probs
146 | probs = scipy.special.softmax(logits, axis=-1)
147 |
148 | # TM-Score term for every bin
149 | tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0))
150 | # E_distances tm(distance)
151 | predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1)
152 |
153 | normed_residue_mask = residue_weights / (1e-8 + residue_weights.sum())
154 | per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1)
155 | return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()])
156 |
--------------------------------------------------------------------------------
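The shapes in the docstrings above can be exercised end to end with random logits. A toy sketch (the bin layout and sizes are assumptions for illustration, not the model's actual configuration; it assumes `src/` is on PYTHONPATH):

```
import numpy as np
from alphafold.common import confidence

num_res, num_bins = 100, 64
rng = np.random.default_rng(0)

# pLDDT: [num_res, num_bins] logits -> [num_res] scores in [0, 100].
plddt = confidence.compute_plddt(rng.normal(size=(num_res, num_bins)))

# PAE/pTM: [num_res, num_res, num_bins] logits plus [num_bins - 1] bin edges.
breaks = np.linspace(0.0, 31.0, num_bins - 1)
pae_logits = rng.normal(size=(num_res, num_res, num_bins))
pae = confidence.compute_predicted_aligned_error(logits=pae_logits, breaks=breaks)
ptm = confidence.predicted_tm_score(logits=pae_logits, breaks=breaks)

print(plddt.shape, pae['predicted_aligned_error'].shape, float(ptm))
```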
/src/alphafold/common/protein.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Protein data type."""
16 | import dataclasses
17 | import io
18 | from typing import Any, Mapping, Optional
19 | from alphafold.common import residue_constants
20 | from Bio.PDB import PDBParser
21 | import numpy as np
22 |
23 | FeatureDict = Mapping[str, np.ndarray]
24 | ModelOutput = Mapping[str, Any] # Is a nested dict.
25 |
26 |
27 | @dataclasses.dataclass(frozen=True)
28 | class Protein:
29 | """Protein structure representation."""
30 |
31 | # Cartesian coordinates of atoms in angstroms. The atom types correspond to
32 | # residue_constants.atom_types, i.e. the first three are N, CA, C.
33 | atom_positions: np.ndarray # [num_res, num_atom_type, 3]
34 |
35 | # Amino-acid type for each residue represented as an integer between 0 and
36 | # 20, where 20 is 'X'.
37 | aatype: np.ndarray # [num_res]
38 |
39 | # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
40 | # is present and 0.0 if not. This should be used for loss masking.
41 | atom_mask: np.ndarray # [num_res, num_atom_type]
42 |
43 | # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
44 | residue_index: np.ndarray # [num_res]
45 |
46 | # B-factors, or temperature factors, of each residue (in sq. angstroms units),
47 | # representing the displacement of the residue from its ground truth mean
48 | # value.
49 | b_factors: np.ndarray # [num_res, num_atom_type]
50 |
51 |
52 | def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
53 | """Takes a PDB string and constructs a Protein object.
54 |
55 | WARNING: All non-standard residue types will be converted into UNK. All
56 | non-standard atoms will be ignored.
57 |
58 | Args:
59 | pdb_str: The contents of the pdb file
60 | chain_id: If None, then the pdb file must contain a single chain (which
61 | will be parsed). If chain_id is specified (e.g. A), then only that chain
62 | is parsed.
63 |
64 | Returns:
65 | A new `Protein` parsed from the pdb contents.
66 | """
67 | pdb_fh = io.StringIO(pdb_str)
68 | parser = PDBParser(QUIET=True)
69 | structure = parser.get_structure('none', pdb_fh)
70 | models = list(structure.get_models())
71 | if len(models) != 1:
72 | raise ValueError(
73 | f'Only single model PDBs are supported. Found {len(models)} models.')
74 | model = models[0]
75 |
76 | if chain_id is not None:
77 | chain = model[chain_id]
78 | else:
79 | chains = list(model.get_chains())
80 | if len(chains) != 1:
81 | raise ValueError(
82 | 'Only single chain PDBs are supported when chain_id not specified. '
83 | f'Found {len(chains)} chains.')
84 | else:
85 | chain = chains[0]
86 |
87 | atom_positions = []
88 | aatype = []
89 | atom_mask = []
90 | residue_index = []
91 | b_factors = []
92 |
93 | for res in chain:
94 | if res.id[2] != ' ':
95 | raise ValueError(
96 | f'PDB contains an insertion code at chain {chain.id} and residue '
97 | f'index {res.id[1]}. These are not supported.')
98 | res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
99 | restype_idx = residue_constants.restype_order.get(
100 | res_shortname, residue_constants.restype_num)
101 | pos = np.zeros((residue_constants.atom_type_num, 3))
102 | mask = np.zeros((residue_constants.atom_type_num,))
103 | res_b_factors = np.zeros((residue_constants.atom_type_num,))
104 | for atom in res:
105 | if atom.name not in residue_constants.atom_types:
106 | continue
107 | pos[residue_constants.atom_order[atom.name]] = atom.coord
108 | mask[residue_constants.atom_order[atom.name]] = 1.
109 | res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
110 | if np.sum(mask) < 0.5:
111 | # If no known atom positions are reported for the residue then skip it.
112 | continue
113 | aatype.append(restype_idx)
114 | atom_positions.append(pos)
115 | atom_mask.append(mask)
116 | residue_index.append(res.id[1])
117 | b_factors.append(res_b_factors)
118 |
119 | return Protein(
120 | atom_positions=np.array(atom_positions),
121 | atom_mask=np.array(atom_mask),
122 | aatype=np.array(aatype),
123 | residue_index=np.array(residue_index),
124 | b_factors=np.array(b_factors))
125 |
126 |
127 | def to_pdb(prot: Protein) -> str:
128 | """Converts a `Protein` instance to a PDB string.
129 |
130 | Args:
131 | prot: The protein to convert to PDB.
132 |
133 | Returns:
134 | PDB string.
135 | """
136 | restypes = residue_constants.restypes + ['X']
137 | res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
138 | atom_types = residue_constants.atom_types
139 |
140 | pdb_lines = []
141 |
142 | atom_mask = prot.atom_mask
143 | aatype = prot.aatype
144 | atom_positions = prot.atom_positions
145 | residue_index = prot.residue_index.astype(np.int32)
146 | b_factors = prot.b_factors
147 |
148 | if np.any(aatype > residue_constants.restype_num):
149 | raise ValueError('Invalid aatypes.')
150 |
151 | pdb_lines.append('MODEL 1')
152 | atom_index = 1
153 | chain_id = 'A'
154 | # Add all atom sites.
155 | for i in range(aatype.shape[0]):
156 | res_name_3 = res_1to3(aatype[i])
157 | for atom_name, pos, mask, b_factor in zip(
158 | atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
159 | if mask < 0.5:
160 | continue
161 |
162 | record_type = 'ATOM'
163 | name = atom_name if len(atom_name) == 4 else f' {atom_name}'
164 | alt_loc = ''
165 | insertion_code = ''
166 | occupancy = 1.00
167 | element = atom_name[0] # Protein supports only C, N, O, S, so this works.
168 | charge = ''
169 | # PDB is a columnar format, every space matters here!
170 | atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
171 | f'{res_name_3:>3} {chain_id:>1}'
172 | f'{residue_index[i]:>4}{insertion_code:>1} '
173 | f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
174 | f'{occupancy:>6.2f}{b_factor:>6.2f} '
175 | f'{element:>2}{charge:>2}')
176 | pdb_lines.append(atom_line)
177 | atom_index += 1
178 |
179 | # Close the chain.
180 | chain_end = 'TER'
181 | chain_termination_line = (
182 | f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} '
183 | f'{chain_id:>1}{residue_index[-1]:>4}')
184 | pdb_lines.append(chain_termination_line)
185 | pdb_lines.append('ENDMDL')
186 |
187 | pdb_lines.append('END')
188 | pdb_lines.append('')
189 | return '\n'.join(pdb_lines)
190 |
191 |
192 | def ideal_atom_mask(prot: Protein) -> np.ndarray:
193 | """Computes an ideal atom mask.
194 |
195 | `Protein.atom_mask` typically is defined according to the atoms that are
196 | reported in the PDB. This function computes a mask according to heavy atoms
197 | that should be present in the given sequence of amino acids.
198 |
199 | Args:
200 | prot: `Protein` whose fields are `numpy.ndarray` objects.
201 |
202 | Returns:
203 | An ideal atom mask.
204 | """
205 | return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
206 |
207 |
208 | def from_prediction(features: FeatureDict, result: ModelOutput,
209 | b_factors: Optional[np.ndarray] = None) -> Protein:
210 | """Assembles a protein from a prediction.
211 |
212 | Args:
213 | features: Dictionary holding model inputs.
214 | result: Dictionary holding model outputs.
215 | b_factors: (Optional) B-factors to use for the protein.
216 |
217 | Returns:
218 | A protein instance.
219 | """
220 | fold_output = result['structure_module']
221 | if b_factors is None:
222 | b_factors = np.zeros_like(fold_output['final_atom_mask'])
223 |
224 | return Protein(
225 | aatype=features['aatype'][0],
226 | atom_positions=fold_output['final_atom_positions'],
227 | atom_mask=fold_output['final_atom_mask'],
228 | residue_index=features['residue_index'][0] + 1,
229 | b_factors=b_factors)
230 |
--------------------------------------------------------------------------------
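A round trip through `from_pdb_string`/`to_pdb` is the easiest way to see the data layout. A sketch using the bundled test structure (it assumes the repo root as the working directory and `src/` on PYTHONPATH):

```
from alphafold.common import protein

with open('src/alphafold/common/testdata/2rbg.pdb') as f:
    prot = protein.from_pdb_string(f.read(), chain_id='A')

print(prot.atom_positions.shape)  # (num_res, 37, 3) - the atom37 layout
print(prot.atom_mask.shape)       # (num_res, 37)
pdb_str = protein.to_pdb(prot)    # back to a single-chain PDB string
```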
/src/alphafold/common/protein_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for protein."""
16 |
17 | import os
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | from alphafold.common import protein
22 | from alphafold.common import residue_constants
23 | import numpy as np
24 | # Internal import (7716).
25 |
26 | TEST_DATA_DIR = 'alphafold/common/testdata/'
27 |
28 |
29 | class ProteinTest(parameterized.TestCase):
30 |
31 | def _check_shapes(self, prot, num_res):
32 | """Check that the processed shapes are correct."""
33 | num_atoms = residue_constants.atom_type_num
34 | self.assertEqual((num_res, num_atoms, 3), prot.atom_positions.shape)
35 | self.assertEqual((num_res,), prot.aatype.shape)
36 | self.assertEqual((num_res, num_atoms), prot.atom_mask.shape)
37 | self.assertEqual((num_res,), prot.residue_index.shape)
38 | self.assertEqual((num_res, num_atoms), prot.b_factors.shape)
39 |
40 | @parameterized.parameters(('2rbg.pdb', 'A', 282),
41 | ('2rbg.pdb', 'B', 282))
42 | def test_from_pdb_str(self, pdb_file, chain_id, num_res):
43 | pdb_file = os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
44 | pdb_file)
45 | with open(pdb_file) as f:
46 | pdb_string = f.read()
47 | prot = protein.from_pdb_string(pdb_string, chain_id)
48 | self._check_shapes(prot, num_res)
49 | self.assertGreaterEqual(prot.aatype.min(), 0)
50 | # Allow equal since unknown restypes have index equal to restype_num.
51 | self.assertLessEqual(prot.aatype.max(), residue_constants.restype_num)
52 |
53 | def test_to_pdb(self):
54 | with open(
55 | os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
56 | '2rbg.pdb')) as f:
57 | pdb_string = f.read()
58 | prot = protein.from_pdb_string(pdb_string, chain_id='A')
59 | pdb_string_reconstr = protein.to_pdb(prot)
60 | prot_reconstr = protein.from_pdb_string(pdb_string_reconstr)
61 |
62 | np.testing.assert_array_equal(prot_reconstr.aatype, prot.aatype)
63 | np.testing.assert_array_almost_equal(
64 | prot_reconstr.atom_positions, prot.atom_positions)
65 | np.testing.assert_array_almost_equal(
66 | prot_reconstr.atom_mask, prot.atom_mask)
67 | np.testing.assert_array_equal(
68 | prot_reconstr.residue_index, prot.residue_index)
69 | np.testing.assert_array_almost_equal(
70 | prot_reconstr.b_factors, prot.b_factors)
71 |
72 | def test_ideal_atom_mask(self):
73 | with open(
74 | os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
75 | '2rbg.pdb')) as f:
76 | pdb_string = f.read()
77 | prot = protein.from_pdb_string(pdb_string, chain_id='A')
78 | ideal_mask = protein.ideal_atom_mask(prot)
79 | non_ideal_residues = set([102] + list(range(127, 285)))
80 | for i, (res, atom_mask) in enumerate(
81 | zip(prot.residue_index, prot.atom_mask)):
82 | if res in non_ideal_residues:
83 | self.assertFalse(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')
84 | else:
85 | self.assertTrue(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')
86 |
87 |
88 | if __name__ == '__main__':
89 | absltest.main()
90 |
--------------------------------------------------------------------------------
/src/alphafold/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Data pipeline for model features."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/data/pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functions for building the input features for the AlphaFold model."""
16 |
17 | import os
18 | from typing import Mapping, Optional, Sequence
19 | from absl import logging
20 | from alphafold.common import residue_constants
21 | from alphafold.data import parsers
22 | from alphafold.data import templates
23 | from alphafold.data.tools import hhblits
24 | from alphafold.data.tools import hhsearch
25 | from alphafold.data.tools import jackhmmer
26 | import numpy as np
27 |
28 | # Internal import (7716).
29 |
30 | FeatureDict = Mapping[str, np.ndarray]
31 |
32 |
33 | def make_sequence_features(
34 | sequence: str, description: str, num_res: int) -> FeatureDict:
35 | """Constructs a feature dict of sequence features."""
36 | features = {}
37 | features['aatype'] = residue_constants.sequence_to_onehot(
38 | sequence=sequence,
39 | mapping=residue_constants.restype_order_with_x,
40 | map_unknown_to_x=True)
41 | features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
42 | features['domain_name'] = np.array([description.encode('utf-8')],
43 | dtype=np.object_)
44 | features['residue_index'] = np.array(range(num_res), dtype=np.int32)
45 | features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
46 | features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_)
47 | return features
48 |
49 |
50 | def make_msa_features(
51 | msas: Sequence[Sequence[str]],
52 | deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
53 | """Constructs a feature dict of MSA features."""
54 | if not msas:
55 | raise ValueError('At least one MSA must be provided.')
56 |
57 | int_msa = []
58 | deletion_matrix = []
59 | seen_sequences = set()
60 | for msa_index, msa in enumerate(msas):
61 | if not msa:
62 | raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
63 | for sequence_index, sequence in enumerate(msa):
64 | if sequence in seen_sequences:
65 | continue
66 | seen_sequences.add(sequence)
67 | int_msa.append(
68 | [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
69 | deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
70 |
71 | num_res = len(msas[0][0])
72 | num_alignments = len(int_msa)
73 | features = {}
74 | features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
75 | features['msa'] = np.array(int_msa, dtype=np.int32)
76 | features['num_alignments'] = np.array(
77 | [num_alignments] * num_res, dtype=np.int32)
78 | return features
79 |
80 |
81 | class DataPipeline:
82 | """Runs the alignment tools and assembles the input features."""
83 |
84 | def __init__(self,
85 | jackhmmer_binary_path: str,
86 | hhblits_binary_path: str,
87 | hhsearch_binary_path: str,
88 | uniref90_database_path: str,
89 | mgnify_database_path: str,
90 | bfd_database_path: Optional[str],
91 | uniclust30_database_path: Optional[str],
92 | small_bfd_database_path: Optional[str],
93 | pdb70_database_path: str,
94 | template_featurizer: templates.TemplateHitFeaturizer,
95 | use_small_bfd: bool,
96 | mgnify_max_hits: int = 501,
97 | uniref_max_hits: int = 10000):
98 | """Constructs a feature dict for a given FASTA file."""
99 | self._use_small_bfd = use_small_bfd
100 | self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
101 | binary_path=jackhmmer_binary_path,
102 | database_path=uniref90_database_path)
103 | if use_small_bfd:
104 | self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
105 | binary_path=jackhmmer_binary_path,
106 | database_path=small_bfd_database_path)
107 | else:
108 | self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
109 | binary_path=hhblits_binary_path,
110 | databases=[bfd_database_path, uniclust30_database_path])
111 | self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
112 | binary_path=jackhmmer_binary_path,
113 | database_path=mgnify_database_path)
114 | self.hhsearch_pdb70_runner = hhsearch.HHSearch(
115 | binary_path=hhsearch_binary_path,
116 | databases=[pdb70_database_path])
117 | self.template_featurizer = template_featurizer
118 | self.mgnify_max_hits = mgnify_max_hits
119 | self.uniref_max_hits = uniref_max_hits
120 |
121 | def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
122 | """Runs alignment tools on the input sequence and creates features."""
123 | with open(input_fasta_path) as f:
124 | input_fasta_str = f.read()
125 | input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
126 | if len(input_seqs) != 1:
127 | raise ValueError(
128 | f'More than one input sequence found in {input_fasta_path}.')
129 | input_sequence = input_seqs[0]
130 | input_description = input_descs[0]
131 | num_res = len(input_sequence)
132 |
133 | jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
134 | input_fasta_path)[0]
135 | jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
136 | input_fasta_path)[0]
137 |
138 | uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
139 | jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits)
140 | hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
141 |
142 | uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
143 | with open(uniref90_out_path, 'w') as f:
144 | f.write(jackhmmer_uniref90_result['sto'])
145 |
146 | mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
147 | with open(mgnify_out_path, 'w') as f:
148 | f.write(jackhmmer_mgnify_result['sto'])
149 |
150 | pdb70_out_path = os.path.join(msa_output_dir, 'pdb70_hits.hhr')
151 | with open(pdb70_out_path, 'w') as f:
152 | f.write(hhsearch_result)
153 |
154 | uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm(
155 | jackhmmer_uniref90_result['sto'])
156 | mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm(
157 | jackhmmer_mgnify_result['sto'])
158 | hhsearch_hits = parsers.parse_hhr(hhsearch_result)
159 | mgnify_msa = mgnify_msa[:self.mgnify_max_hits]
160 | mgnify_deletion_matrix = mgnify_deletion_matrix[:self.mgnify_max_hits]
161 |
162 | if self._use_small_bfd:
163 | jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
164 | input_fasta_path)[0]
165 |
166 | bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
167 | with open(bfd_out_path, 'w') as f:
168 | f.write(jackhmmer_small_bfd_result['sto'])
169 |
170 | bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
171 | jackhmmer_small_bfd_result['sto'])
172 | else:
173 | hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(
174 | input_fasta_path)
175 |
176 | bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
177 | with open(bfd_out_path, 'w') as f:
178 | f.write(hhblits_bfd_uniclust_result['a3m'])
179 |
180 | bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(
181 | hhblits_bfd_uniclust_result['a3m'])
182 |
183 | templates_result = self.template_featurizer.get_templates(
184 | query_sequence=input_sequence,
185 | query_pdb_code=None,
186 | query_release_date=None,
187 | hits=hhsearch_hits)
188 |
189 | sequence_features = make_sequence_features(
190 | sequence=input_sequence,
191 | description=input_description,
192 | num_res=num_res)
193 |
194 | msa_features = make_msa_features(
195 | msas=(uniref90_msa, bfd_msa, mgnify_msa),
196 | deletion_matrices=(uniref90_deletion_matrix,
197 | bfd_deletion_matrix,
198 | mgnify_deletion_matrix))
199 |
200 | logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa))
201 | logging.info('BFD MSA size: %d sequences.', len(bfd_msa))
202 | logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa))
203 | logging.info('Final (deduplicated) MSA size: %d sequences.',
204 | msa_features['num_alignments'][0])
205 | logging.info('Total number of templates (NB: this can include bad '
206 | 'templates and is later filtered to top 4): %d.',
207 | templates_result.features['template_domain_names'].shape[0])
208 |
209 | return {**sequence_features, **msa_features, **templates_result.features}
210 |
--------------------------------------------------------------------------------
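`make_sequence_features` has no external tool dependencies, so it can be sanity-checked in isolation. A toy sketch (the sequence is arbitrary; it assumes `src/` is on PYTHONPATH):

```
from alphafold.data import pipeline

seq = 'MKTAYIAK'
feats = pipeline.make_sequence_features(
    sequence=seq, description='toy', num_res=len(seq))

print(feats['aatype'].shape)   # (8, 21): one-hot over 20 amino acids + X
print(feats['residue_index'])  # [0 1 2 3 4 5 6 7]
```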
/src/alphafold/data/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Python wrappers for third party tools."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/data/tools/hhblits.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library to run HHblits from Python."""
16 |
17 | import glob
18 | import os
19 | import subprocess
20 | from typing import Any, Mapping, Optional, Sequence
21 |
22 | from absl import logging
23 | from alphafold.data.tools import utils
24 | # Internal import (7716).
25 |
26 |
27 | _HHBLITS_DEFAULT_P = 20
28 | _HHBLITS_DEFAULT_Z = 500
29 |
30 |
31 | class HHBlits:
32 | """Python wrapper of the HHblits binary."""
33 |
34 | def __init__(self,
35 | *,
36 | binary_path: str,
37 | databases: Sequence[str],
38 | n_cpu: int = 4,
39 | n_iter: int = 3,
40 | e_value: float = 0.001,
41 | maxseq: int = 1_000_000,
42 | realign_max: int = 100_000,
43 | maxfilt: int = 100_000,
44 | min_prefilter_hits: int = 1000,
45 | all_seqs: bool = False,
46 | alt: Optional[int] = None,
47 | p: int = _HHBLITS_DEFAULT_P,
48 | z: int = _HHBLITS_DEFAULT_Z):
49 | """Initializes the Python HHblits wrapper.
50 |
51 | Args:
52 | binary_path: The path to the HHblits executable.
53 | databases: A sequence of HHblits database paths. This should be the
54 | common prefix for the database files (i.e. up to but not including
55 | _hhm.ffindex etc.)
56 | n_cpu: The number of CPUs to give HHblits.
57 | n_iter: The number of HHblits iterations.
58 | e_value: The E-value, see HHblits docs for more details.
59 | maxseq: The maximum number of rows in an input alignment. Note that this
60 | parameter is only supported in HHBlits version 3.1 and higher.
61 | realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
62 | maxfilt: Max number of hits allowed to pass the 2nd prefilter.
63 | HHblits default: 20000.
64 | min_prefilter_hits: Min number of hits to pass prefilter.
65 | HHblits default: 100.
66 | all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
67 | HHblits default: False.
68 | alt: Show up to this many alternative alignments.
69 | p: Minimum Prob for a hit to be included in the output hhr file.
70 | HHblits default: 20.
71 | z: Hard cap on number of hits reported in the hhr file.
72 | HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.
73 |
74 | Raises:
75 | RuntimeError: If HHblits binary not found within the path.
76 | """
77 | self.binary_path = binary_path
78 | self.databases = databases
79 |
80 | for database_path in self.databases:
81 | if not glob.glob(database_path + '_*'):
82 | logging.error('Could not find HHBlits database %s', database_path)
83 | raise ValueError(f'Could not find HHBlits database {database_path}')
84 |
85 | self.n_cpu = n_cpu
86 | self.n_iter = n_iter
87 | self.e_value = e_value
88 | self.maxseq = maxseq
89 | self.realign_max = realign_max
90 | self.maxfilt = maxfilt
91 | self.min_prefilter_hits = min_prefilter_hits
92 | self.all_seqs = all_seqs
93 | self.alt = alt
94 | self.p = p
95 | self.z = z
96 |
97 | def query(self, input_fasta_path: str) -> Mapping[str, Any]:
98 | """Queries the database using HHblits."""
99 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
100 | a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
101 |
102 | db_cmd = []
103 | for db_path in self.databases:
104 | db_cmd.append('-d')
105 | db_cmd.append(db_path)
106 | cmd = [
107 | self.binary_path,
108 | '-i', input_fasta_path,
109 | '-cpu', str(self.n_cpu),
110 | '-oa3m', a3m_path,
111 | '-o', '/dev/null',
112 | '-n', str(self.n_iter),
113 | '-e', str(self.e_value),
114 | '-maxseq', str(self.maxseq),
115 | '-realign_max', str(self.realign_max),
116 | '-maxfilt', str(self.maxfilt),
117 | '-min_prefilter_hits', str(self.min_prefilter_hits)]
118 | if self.all_seqs:
119 | cmd += ['-all']
120 | if self.alt:
121 | cmd += ['-alt', str(self.alt)]
122 | if self.p != _HHBLITS_DEFAULT_P:
123 | cmd += ['-p', str(self.p)]
124 | if self.z != _HHBLITS_DEFAULT_Z:
125 | cmd += ['-Z', str(self.z)]
126 | cmd += db_cmd
127 |
128 | logging.info('Launching subprocess "%s"', ' '.join(cmd))
129 | process = subprocess.Popen(
130 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
131 |
132 | with utils.timing('HHblits query'):
133 | stdout, stderr = process.communicate()
134 | retcode = process.wait()
135 |
136 | if retcode:
137 | # Logs have a 15k character limit, so log HHblits error line by line.
138 | logging.error('HHblits failed. HHblits stderr begin:')
139 | for error_line in stderr.decode('utf-8').splitlines():
140 | if error_line.strip():
141 | logging.error(error_line.strip())
142 | logging.error('HHblits stderr end')
143 | raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
144 | stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))
145 |
146 | with open(a3m_path) as f:
147 | a3m = f.read()
148 |
149 | raw_output = dict(
150 | a3m=a3m,
151 | output=stdout,
152 | stderr=stderr,
153 | n_iter=self.n_iter,
154 | e_value=self.e_value)
155 | return raw_output
156 |
--------------------------------------------------------------------------------
/src/alphafold/data/tools/hhsearch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library to run HHsearch from Python."""
16 |
17 | import glob
18 | import os
19 | import subprocess
20 | from typing import Sequence
21 |
22 | from absl import logging
23 |
24 | from alphafold.data.tools import utils
25 | # Internal import (7716).
26 |
27 |
28 | class HHSearch:
29 | """Python wrapper of the HHsearch binary."""
30 |
31 | def __init__(self,
32 | *,
33 | binary_path: str,
34 | databases: Sequence[str],
35 | maxseq: int = 1_000_000):
36 | """Initializes the Python HHsearch wrapper.
37 |
38 | Args:
39 | binary_path: The path to the HHsearch executable.
40 | databases: A sequence of HHsearch database paths. This should be the
41 | common prefix for the database files (i.e. up to but not including
42 | _hhm.ffindex etc.)
43 | maxseq: The maximum number of rows in an input alignment. Note that this
44 | parameter is only supported in HHBlits version 3.1 and higher.
45 |
46 | Raises:
47 | RuntimeError: If HHsearch binary not found within the path.
48 | """
49 | self.binary_path = binary_path
50 | self.databases = databases
51 | self.maxseq = maxseq
52 |
53 | for database_path in self.databases:
54 | if not glob.glob(database_path + '_*'):
55 | logging.error('Could not find HHsearch database %s', database_path)
56 | raise ValueError(f'Could not find HHsearch database {database_path}')
57 |
58 | def query(self, a3m: str) -> str:
59 | """Queries the database using HHsearch using a given a3m."""
60 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
61 | input_path = os.path.join(query_tmp_dir, 'query.a3m')
62 | hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
63 | with open(input_path, 'w') as f:
64 | f.write(a3m)
65 |
66 | db_cmd = []
67 | for db_path in self.databases:
68 | db_cmd.append('-d')
69 | db_cmd.append(db_path)
70 | cmd = [self.binary_path,
71 | '-i', input_path,
72 | '-o', hhr_path,
73 | '-maxseq', str(self.maxseq)
74 | ] + db_cmd
75 |
76 | logging.info('Launching subprocess "%s"', ' '.join(cmd))
77 | process = subprocess.Popen(
78 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
79 | with utils.timing('HHsearch query'):
80 | stdout, stderr = process.communicate()
81 | retcode = process.wait()
82 |
83 | if retcode:
84 | # Stderr is truncated to prevent proto size errors in Beam.
85 | raise RuntimeError(
86 | 'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
87 | stdout.decode('utf-8'), stderr[:100_000].decode('utf-8')))
88 |
89 | with open(hhr_path) as f:
90 | hhr = f.read()
91 | return hhr
92 |
--------------------------------------------------------------------------------
/src/alphafold/data/tools/hmmbuild.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
16 |
17 | import os
18 | import re
19 | import subprocess
20 |
21 | from absl import logging
22 | from alphafold.data.tools import utils
23 | # Internal import (7716).
24 |
25 |
26 | class Hmmbuild(object):
27 | """Python wrapper of the hmmbuild binary."""
28 |
29 | def __init__(self,
30 | *,
31 | binary_path: str,
32 | singlemx: bool = False):
33 | """Initializes the Python hmmbuild wrapper.
34 |
35 | Args:
36 | binary_path: The path to the hmmbuild executable.
37 | singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
38 | just use a common substitution score matrix.
39 |
40 | Raises:
41 | RuntimeError: If hmmbuild binary not found within the path.
42 | """
43 | self.binary_path = binary_path
44 | self.singlemx = singlemx
45 |
46 | def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
47 | """Builds a HHM for the aligned sequences given as an A3M string.
48 |
49 | Args:
50 | sto: A string with the aligned sequences in the Stockholm format.
51 | model_construction: Whether to use reference annotation in the msa to
52 | determine consensus columns ('hand') or default ('fast').
53 |
54 | Returns:
55 | A string with the profile in the HMM format.
56 |
57 | Raises:
58 | RuntimeError: If hmmbuild fails.
59 | """
60 | return self._build_profile(sto, model_construction=model_construction)
61 |
62 | def build_profile_from_a3m(self, a3m: str) -> str:
63 | """Builds a HHM for the aligned sequences given as an A3M string.
64 |
65 | Args:
66 | a3m: A string with the aligned sequences in the A3M format.
67 |
68 | Returns:
69 | A string with the profile in the HMM format.
70 |
71 | Raises:
72 | RuntimeError: If hmmbuild fails.
73 | """
74 | lines = []
75 | for line in a3m.splitlines():
76 | if not line.startswith('>'):
77 | line = re.sub('[a-z]+', '', line) # Remove inserted residues.
78 | lines.append(line + '\n')
79 | msa = ''.join(lines)
80 | return self._build_profile(msa, model_construction='fast')
81 |
82 | def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
83 | """Builds a HMM for the aligned sequences given as an MSA string.
84 |
85 | Args:
86 | msa: A string with the aligned sequences, in A3M or STO format.
87 | model_construction: Whether to use reference annotation in the msa to
88 | determine consensus columns ('hand') or default ('fast').
89 |
90 | Returns:
91 | A string with the profile in the HMM format.
92 |
93 | Raises:
94 | RuntimeError: If hmmbuild fails.
95 | ValueError: If unspecified arguments are provided.
96 | """
97 | if model_construction not in {'hand', 'fast'}:
98 | raise ValueError(f'Invalid model_construction {model_construction} - only '
99 | 'hand and fast supported.')
100 |
101 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
102 | input_query = os.path.join(query_tmp_dir, 'query.msa')
103 | output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')
104 |
105 | with open(input_query, 'w') as f:
106 | f.write(msa)
107 |
108 | cmd = [self.binary_path]
109 | # If adding flags, we have to do so before the output and input:
110 |
111 | if model_construction == 'hand':
112 | cmd.append(f'--{model_construction}')
113 | if self.singlemx:
114 | cmd.append('--singlemx')
115 | cmd.extend([
116 | '--amino',
117 | output_hmm_path,
118 | input_query,
119 | ])
120 |
121 | logging.info('Launching subprocess %s', cmd)
122 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
123 | stderr=subprocess.PIPE)
124 |
125 | with utils.timing('hmmbuild query'):
126 | stdout, stderr = process.communicate()
127 | retcode = process.wait()
128 | logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
129 | stdout.decode('utf-8'), stderr.decode('utf-8'))
130 |
131 | if retcode:
132 | raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
133 | % (stdout.decode('utf-8'), stderr.decode('utf-8')))
134 |
135 | with open(output_hmm_path, encoding='utf-8') as f:
136 | hmm = f.read()
137 |
138 | return hmm
139 |
--------------------------------------------------------------------------------
/src/alphafold/data/tools/hmmsearch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A Python wrapper for hmmsearch - search profile against a sequence db."""
16 |
17 | import os
18 | import subprocess
19 | from typing import Optional, Sequence
20 |
21 | from absl import logging
22 | from alphafold.data.tools import utils
23 | # Internal import (7716).
24 |
25 |
26 | class Hmmsearch(object):
27 | """Python wrapper of the hmmsearch binary."""
28 |
29 | def __init__(self,
30 | *,
31 | binary_path: str,
32 | database_path: str,
33 | flags: Optional[Sequence[str]] = None):
34 | """Initializes the Python hmmsearch wrapper.
35 |
36 | Args:
37 | binary_path: The path to the hmmsearch executable.
38 | database_path: The path to the hmmsearch database (FASTA format).
39 | flags: List of flags to be used by hmmsearch.
40 |
41 | Raises:
42 | RuntimeError: If hmmsearch binary not found within the path.
43 | """
44 | self.binary_path = binary_path
45 | self.database_path = database_path
46 | self.flags = flags
47 |
48 | if not os.path.exists(self.database_path):
49 | logging.error('Could not find hmmsearch database %s', database_path)
50 | raise ValueError(f'Could not find hmmsearch database {database_path}')
51 |
52 | def query(self, hmm: str) -> str:
53 | """Queries the database using hmmsearch using a given hmm."""
54 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
55 | hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
56 | a3m_out_path = os.path.join(query_tmp_dir, 'output.a3m')
57 | with open(hmm_input_path, 'w') as f:
58 | f.write(hmm)
59 |
60 | cmd = [
61 | self.binary_path,
62 | '--noali', # Don't include the alignment in stdout.
63 | '--cpu', '8'
64 | ]
65 | # If adding flags, we have to do so before the output and input:
66 | if self.flags:
67 | cmd.extend(self.flags)
68 | cmd.extend([
69 | '-A', a3m_out_path,
70 | hmm_input_path,
71 | self.database_path,
72 | ])
73 |
74 | logging.info('Launching sub-process %s', cmd)
75 | process = subprocess.Popen(
76 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
77 | with utils.timing(
78 | f'hmmsearch ({os.path.basename(self.database_path)}) query'):
79 | stdout, stderr = process.communicate()
80 | retcode = process.wait()
81 |
82 | if retcode:
83 | raise RuntimeError(
84 | 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
85 | stdout.decode('utf-8'), stderr.decode('utf-8')))
86 |
87 | with open(a3m_out_path) as f:
88 | a3m_out = f.read()
89 |
90 | return a3m_out
91 |
--------------------------------------------------------------------------------
/src/alphafold/data/tools/jackhmmer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library to run Jackhmmer from Python."""
16 |
17 | from concurrent import futures
18 | import glob
19 | import os
20 | import subprocess
21 | from typing import Any, Callable, Mapping, Optional, Sequence
22 | from urllib import request
23 |
24 | from absl import logging
25 |
26 | from alphafold.data.tools import utils
27 | # Internal import (7716).
28 |
29 |
30 | class Jackhmmer:
31 | """Python wrapper of the Jackhmmer binary."""
32 |
33 | def __init__(self,
34 | *,
35 | binary_path: str,
36 | database_path: str,
37 | n_cpu: int = 8,
38 | n_iter: int = 1,
39 | e_value: float = 0.0001,
40 | z_value: Optional[int] = None,
41 | get_tblout: bool = False,
42 | filter_f1: float = 0.0005,
43 | filter_f2: float = 0.00005,
44 | filter_f3: float = 0.0000005,
45 | incdom_e: Optional[float] = None,
46 | dom_e: Optional[float] = None,
47 | num_streamed_chunks: Optional[int] = None,
48 | streaming_callback: Optional[Callable[[int], None]] = None):
49 | """Initializes the Python Jackhmmer wrapper.
50 |
51 | Args:
52 | binary_path: The path to the jackhmmer executable.
53 | database_path: The path to the jackhmmer database (FASTA format).
54 | n_cpu: The number of CPUs to give Jackhmmer.
55 | n_iter: The number of Jackhmmer iterations.
56 | e_value: The E-value, see Jackhmmer docs for more details.
57 | z_value: The Z-value, see Jackhmmer docs for more details.
58 | get_tblout: Whether to save tblout string.
59 | filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
60 | filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
61 | filter_f3: Forward pre-filter, set to >1.0 to turn off.
62 | incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
63 | round.
64 | dom_e: Domain e-value criteria for inclusion in tblout.
65 | num_streamed_chunks: Number of database chunks to stream over.
66 | streaming_callback: Callback function run after each chunk iteration with
67 | the iteration number as argument.
68 | """
69 | self.binary_path = binary_path
70 | self.database_path = database_path
71 | self.num_streamed_chunks = num_streamed_chunks
72 |
73 | if not os.path.exists(self.database_path) and num_streamed_chunks is None:
74 | logging.error('Could not find Jackhmmer database %s', database_path)
75 | raise ValueError(f'Could not find Jackhmmer database {database_path}')
76 |
77 | self.n_cpu = n_cpu
78 | self.n_iter = n_iter
79 | self.e_value = e_value
80 | self.z_value = z_value
81 | self.filter_f1 = filter_f1
82 | self.filter_f2 = filter_f2
83 | self.filter_f3 = filter_f3
84 | self.incdom_e = incdom_e
85 | self.dom_e = dom_e
86 | self.get_tblout = get_tblout
87 | self.streaming_callback = streaming_callback
88 |
89 | def _query_chunk(self, input_fasta_path: str, database_path: str
90 | ) -> Mapping[str, Any]:
91 | """Queries the database chunk using Jackhmmer."""
92 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
93 | sto_path = os.path.join(query_tmp_dir, 'output.sto')
94 |
95 |       # The F1/F2/F3 are the expected proportions to pass each of the filtering
96 |       # stages (which get progressively more expensive); reducing them
97 |       # speeds up the pipeline at the expense of sensitivity. They are
98 |       # currently set very low to make querying MGnify run in a reasonable
99 |       # amount of time.
100 | cmd_flags = [
101 | # Don't pollute stdout with Jackhmmer output.
102 | '-o', '/dev/null',
103 | '-A', sto_path,
104 | '--noali',
105 | '--F1', str(self.filter_f1),
106 | '--F2', str(self.filter_f2),
107 | '--F3', str(self.filter_f3),
108 | '--incE', str(self.e_value),
109 | # Report only sequences with E-values <= x in per-sequence output.
110 | '-E', str(self.e_value),
111 | '--cpu', str(self.n_cpu),
112 | '-N', str(self.n_iter)
113 | ]
114 | if self.get_tblout:
115 | tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
116 | cmd_flags.extend(['--tblout', tblout_path])
117 |
118 | if self.z_value:
119 | cmd_flags.extend(['-Z', str(self.z_value)])
120 |
121 | if self.dom_e is not None:
122 | cmd_flags.extend(['--domE', str(self.dom_e)])
123 |
124 | if self.incdom_e is not None:
125 | cmd_flags.extend(['--incdomE', str(self.incdom_e)])
126 |
127 | cmd = [self.binary_path] + cmd_flags + [input_fasta_path,
128 | database_path]
129 |
130 | logging.info('Launching subprocess "%s"', ' '.join(cmd))
131 | process = subprocess.Popen(
132 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
133 | with utils.timing(
134 | f'Jackhmmer ({os.path.basename(database_path)}) query'):
135 | _, stderr = process.communicate()
136 | retcode = process.wait()
137 |
138 | if retcode:
139 | raise RuntimeError(
140 | 'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8'))
141 |
142 | # Get e-values for each target name
143 | tbl = ''
144 | if self.get_tblout:
145 | with open(tblout_path) as f:
146 | tbl = f.read()
147 |
148 | with open(sto_path) as f:
149 | sto = f.read()
150 |
151 | raw_output = dict(
152 | sto=sto,
153 | tbl=tbl,
154 | stderr=stderr,
155 | n_iter=self.n_iter,
156 | e_value=self.e_value)
157 |
158 | return raw_output
159 |
160 | def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
161 | """Queries the database using Jackhmmer."""
162 | if self.num_streamed_chunks is None:
163 | return [self._query_chunk(input_fasta_path, self.database_path)]
164 |
165 | db_basename = os.path.basename(self.database_path)
166 | db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
167 | db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}'
168 |
169 |     # Remove any leftover local chunks so the ramdisk does not fill up.
170 | for f in glob.glob(db_local_chunk('[0-9]*')):
171 | try:
172 | os.remove(f)
173 | except OSError:
174 |         logging.warning('OSError while deleting %s', f)
175 |
176 | # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
177 | with futures.ThreadPoolExecutor(max_workers=2) as executor:
178 | chunked_output = []
179 | for i in range(1, self.num_streamed_chunks + 1):
180 | # Copy the chunk locally
181 | if i == 1:
182 | future = executor.submit(
183 | request.urlretrieve, db_remote_chunk(i), db_local_chunk(i))
184 | if i < self.num_streamed_chunks:
185 | next_future = executor.submit(
186 | request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1))
187 |
188 | # Run Jackhmmer with the chunk
189 | future.result()
190 | chunked_output.append(
191 | self._query_chunk(input_fasta_path, db_local_chunk(i)))
192 |
193 | # Remove the local copy of the chunk
194 | os.remove(db_local_chunk(i))
195 |         # Only advance the future if another chunk was actually scheduled.
196 |         if i < self.num_streamed_chunks:
197 |           future = next_future
198 |         if self.streaming_callback:
199 |           self.streaming_callback(i)
200 |     return chunked_output
201 |
--------------------------------------------------------------------------------
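
A minimal usage sketch for the `Jackhmmer` wrapper above (illustrative only, not part of the repository; all paths are placeholders):

```python
from alphafold.data.tools import jackhmmer

runner = jackhmmer.Jackhmmer(
    binary_path='/usr/bin/jackhmmer',      # placeholder binary path
    database_path='/data/uniref90.fasta')  # placeholder FASTA database

# Without streaming, query() returns a single-element list; each element is a
# dict whose 'sto' entry holds the Stockholm-format alignment.
results = runner.query('/path/to/query.fasta')
sto = results[0]['sto']
```
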
/src/alphafold/data/tools/kalign.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A Python wrapper for Kalign."""
16 | import os
17 | import subprocess
18 | from typing import Sequence
19 |
20 | from absl import logging
21 |
22 | from alphafold.data.tools import utils
23 | # Internal import (7716).
24 |
25 |
26 | def _to_a3m(sequences: Sequence[str]) -> str:
27 | """Converts sequences to an a3m file."""
28 | names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
29 | a3m = []
30 | for sequence, name in zip(sequences, names):
31 |     a3m.append('>' + name + '\n')
32 |     a3m.append(sequence + '\n')
33 | return ''.join(a3m)
34 |
35 |
36 | class Kalign:
37 | """Python wrapper of the Kalign binary."""
38 |
39 | def __init__(self, *, binary_path: str):
40 | """Initializes the Python Kalign wrapper.
41 |
42 | Args:
43 | binary_path: The path to the Kalign binary.
44 |
45 | Raises:
46 | RuntimeError: If Kalign binary not found within the path.
47 | """
48 | self.binary_path = binary_path
49 |
50 | def align(self, sequences: Sequence[str]) -> str:
51 | """Aligns the sequences and returns the alignment in A3M string.
52 |
53 | Args:
54 | sequences: A list of query sequence strings. The sequences have to be at
55 | least 6 residues long (Kalign requires this). Note that the order in
56 |         which you give the sequences might alter the output slightly, as a
57 |         different alignment tree might get constructed.
58 |
59 | Returns:
60 | A string with the alignment in a3m format.
61 |
62 | Raises:
63 | RuntimeError: If Kalign fails.
64 | ValueError: If any of the sequences is less than 6 residues long.
65 | """
66 | logging.info('Aligning %d sequences', len(sequences))
67 |
68 | for s in sequences:
69 | if len(s) < 6:
70 | raise ValueError('Kalign requires all sequences to be at least 6 '
71 | 'residues long. Got %s (%d residues).' % (s, len(s)))
72 |
73 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
74 | input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
75 | output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
76 |
77 | with open(input_fasta_path, 'w') as f:
78 | f.write(_to_a3m(sequences))
79 |
80 | cmd = [
81 | self.binary_path,
82 | '-i', input_fasta_path,
83 | '-o', output_a3m_path,
84 | '-format', 'fasta',
85 | ]
86 |
87 | logging.info('Launching subprocess "%s"', ' '.join(cmd))
88 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
89 | stderr=subprocess.PIPE)
90 |
91 | with utils.timing('Kalign query'):
92 | stdout, stderr = process.communicate()
93 | retcode = process.wait()
94 | logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n',
95 | stdout.decode('utf-8'), stderr.decode('utf-8'))
96 |
97 | if retcode:
98 | raise RuntimeError('Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n'
99 | % (stdout.decode('utf-8'), stderr.decode('utf-8')))
100 |
101 | with open(output_a3m_path) as f:
102 | a3m = f.read()
103 |
104 | return a3m
105 |
--------------------------------------------------------------------------------
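
A minimal usage sketch for the `Kalign` wrapper above (illustrative only, not part of the repository; the binary path is a placeholder):

```python
from alphafold.data.tools import kalign

aligner = kalign.Kalign(binary_path='/usr/bin/kalign')  # placeholder path

# All sequences must be at least 6 residues long; align() enforces this and
# returns the alignment as an A3M string.
a3m = aligner.align(['MKTAYIAKQR', 'MKTAYLAKQR', 'MKTAYIAKQK'])
```
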
/src/alphafold/data/tools/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Common utilities for data pipeline tools."""
15 | import contextlib
16 | import shutil
17 | import tempfile
18 | import time
19 | from typing import Optional
20 |
21 | from absl import logging
22 |
23 |
24 | @contextlib.contextmanager
25 | def tmpdir_manager(base_dir: Optional[str] = None):
26 | """Context manager that deletes a temporary directory on exit."""
27 | tmpdir = tempfile.mkdtemp(dir=base_dir)
28 | try:
29 | yield tmpdir
30 | finally:
31 | shutil.rmtree(tmpdir, ignore_errors=True)
32 |
33 |
34 | @contextlib.contextmanager
35 | def timing(msg: str):
36 |   """Context manager that logs how long the wrapped block of code took."""
37 |   logging.info('Started %s', msg)
38 |   tic = time.time()
39 |   yield
40 |   toc = time.time()
41 |   logging.info('Finished %s in %.3f seconds', msg, toc - tic)
42 |
--------------------------------------------------------------------------------
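
A minimal sketch combining the two helpers above (illustrative only, not part of the repository): create a scratch directory that is deleted on exit, and log how long the work inside it took:

```python
import os

from alphafold.data.tools import utils

with utils.timing('scratch work'):
  with utils.tmpdir_manager(base_dir='/tmp') as tmp_dir:
    with open(os.path.join(tmp_dir, 'example.txt'), 'w') as f:
      f.write('temporary data')
# The directory and its contents are removed once the inner block exits.
```
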
/src/alphafold/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Alphafold model."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/__init__.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/__init__.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/all_atom.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/all_atom.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/common_modules.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/common_modules.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/config.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/config.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/data.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/data.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/features.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/features.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/folding.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/folding.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/layer_stack.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/layer_stack.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/lddt.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/lddt.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/mapping.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/mapping.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/modules.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/modules.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/prng.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/prng.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/quat_affine.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/quat_affine.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/r3.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/r3.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/__pycache__/utils.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/utils.cpython-312.pyc
--------------------------------------------------------------------------------
/src/alphafold/model/all_atom_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for all_atom."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from alphafold.model import all_atom
20 | from alphafold.model import r3
21 | import numpy as np
22 |
23 | L1_CLAMP_DISTANCE = 10
24 |
25 |
26 | def get_identity_rigid(shape):
27 | """Returns identity rigid transform."""
28 |
29 | ones = np.ones(shape)
30 | zeros = np.zeros(shape)
31 | rot = r3.Rots(ones, zeros, zeros,
32 | zeros, ones, zeros,
33 | zeros, zeros, ones)
34 | trans = r3.Vecs(zeros, zeros, zeros)
35 | return r3.Rigids(rot, trans)
36 |
37 |
38 | def get_global_rigid_transform(rot_angle, translation, bcast_dims):
39 | """Returns rigid transform that globally rotates/translates by same amount."""
40 |
41 | rot_angle = np.asarray(rot_angle)
42 | translation = np.asarray(translation)
43 | if bcast_dims:
44 | for _ in range(bcast_dims):
45 | rot_angle = np.expand_dims(rot_angle, 0)
46 | translation = np.expand_dims(translation, 0)
47 | sin_angle = np.sin(np.deg2rad(rot_angle))
48 | cos_angle = np.cos(np.deg2rad(rot_angle))
49 | ones = np.ones_like(sin_angle)
50 | zeros = np.zeros_like(sin_angle)
51 | rot = r3.Rots(ones, zeros, zeros,
52 | zeros, cos_angle, -sin_angle,
53 | zeros, sin_angle, cos_angle)
54 | trans = r3.Vecs(translation[..., 0], translation[..., 1], translation[..., 2])
55 | return r3.Rigids(rot, trans)
56 |
57 |
58 | class AllAtomTest(parameterized.TestCase, absltest.TestCase):
59 |
60 | @parameterized.named_parameters(
61 | ('identity', 0, [0, 0, 0]),
62 | ('rot_90', 90, [0, 0, 0]),
63 | ('trans_10', 0, [0, 0, 10]),
64 | ('rot_174_trans_1', 174, [1, 1, 1]))
65 | def test_frame_aligned_point_error_perfect_on_global_transform(
66 | self, rot_angle, translation):
67 | """Tests global transform between target and preds gives perfect score."""
68 |
69 | # pylint: disable=bad-whitespace
70 | target_positions = np.array(
71 | [[ 21.182, 23.095, 19.731],
72 | [ 22.055, 20.919, 17.294],
73 | [ 24.599, 20.005, 15.041],
74 | [ 25.567, 18.214, 12.166],
75 | [ 28.063, 17.082, 10.043],
76 | [ 28.779, 15.569, 6.985],
77 | [ 30.581, 13.815, 4.612],
78 | [ 29.258, 12.193, 2.296]])
79 | # pylint: enable=bad-whitespace
80 | global_rigid_transform = get_global_rigid_transform(
81 | rot_angle, translation, 1)
82 |
83 | target_positions = r3.vecs_from_tensor(target_positions)
84 | pred_positions = r3.rigids_mul_vecs(
85 | global_rigid_transform, target_positions)
86 | positions_mask = np.ones(target_positions.x.shape[0])
87 |
88 | target_frames = get_identity_rigid(10)
89 | pred_frames = r3.rigids_mul_rigids(global_rigid_transform, target_frames)
90 | frames_mask = np.ones(10)
91 |
92 | fape = all_atom.frame_aligned_point_error(
93 | pred_frames, target_frames, frames_mask, pred_positions,
94 | target_positions, positions_mask, L1_CLAMP_DISTANCE,
95 | L1_CLAMP_DISTANCE, epsilon=0)
96 | self.assertAlmostEqual(fape, 0.)
97 |
98 | @parameterized.named_parameters(
99 | ('identity',
100 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
101 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
102 | 0.),
103 | ('shift_2.5',
104 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
105 | [[2.5, 0, 0], [7.5, 0, 0], [7.5, 0, 0]],
106 | 0.25),
107 | ('shift_5',
108 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
109 | [[5, 0, 0], [10, 0, 0], [15, 0, 0]],
110 | 0.5),
111 | ('shift_10',
112 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
113 | [[10, 0, 0], [15, 0, 0], [0, 0, 0]],
114 | 1.))
115 |   def test_frame_aligned_point_error_matches_expected(
116 |       self, target_positions, pred_positions, expected_fape):
117 | """Tests score matches expected."""
118 |
119 | target_frames = get_identity_rigid(2)
120 | pred_frames = target_frames
121 | frames_mask = np.ones(2)
122 |
123 | target_positions = r3.vecs_from_tensor(np.array(target_positions))
124 | pred_positions = r3.vecs_from_tensor(np.array(pred_positions))
125 | positions_mask = np.ones(target_positions.x.shape[0])
126 |
127 |     fape = all_atom.frame_aligned_point_error(
128 |         pred_frames, target_frames, frames_mask, pred_positions,
129 |         target_positions, positions_mask, L1_CLAMP_DISTANCE,
130 |         L1_CLAMP_DISTANCE, epsilon=0)
131 |     self.assertAlmostEqual(fape, expected_fape)
132 |
133 |
134 | if __name__ == '__main__':
135 | absltest.main()
136 |
--------------------------------------------------------------------------------
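
The expected values in `test_frame_aligned_point_error_matches_expected` above follow from clamped FAPE with identity frames: the per-point error reduces to min(shift, clamp) / length_scale, with clamp = length_scale = L1_CLAMP_DISTANCE = 10. A quick numeric check of that reading (illustrative only, not part of the repository):

```python
# Reproduce the expected test values, assuming clamped FAPE with identity
# frames reduces to min(shift, clamp) / length_scale.
clamp = length_scale = 10.0  # L1_CLAMP_DISTANCE in the test above
for shift, expected in [(0.0, 0.0), (2.5, 0.25), (5.0, 0.5), (10.0, 1.0)]:
  assert min(shift, clamp) / length_scale == expected
```
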
/src/alphafold/model/common_modules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A collection of common Haiku modules for use in protein folding."""
16 | import haiku as hk
17 | import jax.numpy as jnp
18 |
19 |
20 | class Linear(hk.Module):
21 | """Protein folding specific Linear Module.
22 |
23 | This differs from the standard Haiku Linear in a few ways:
24 | * It supports inputs of arbitrary rank
25 | * Initializers are specified by strings
26 | """
27 |
28 | def __init__(self,
29 | num_output: int,
30 | initializer: str = 'linear',
31 | use_bias: bool = True,
32 | bias_init: float = 0.,
33 | name: str = 'linear'):
34 | """Constructs Linear Module.
35 |
36 | Args:
37 | num_output: number of output channels.
38 | initializer: What initializer to use, should be one of {'linear', 'relu',
39 | 'zeros'}
40 | use_bias: Whether to include trainable bias
41 | bias_init: Value used to initialize bias.
42 | name: name of module, used for name scopes.
43 | """
44 |
45 | super().__init__(name=name)
46 | self.num_output = num_output
47 | self.initializer = initializer
48 | self.use_bias = use_bias
49 | self.bias_init = bias_init
50 |
51 | def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
52 | """Connects Module.
53 |
54 | Args:
55 | inputs: Tensor of shape [..., num_channel]
56 |
57 | Returns:
58 | output of shape [..., num_output]
59 | """
60 | n_channels = int(inputs.shape[-1])
61 |
62 | weight_shape = [n_channels, self.num_output]
63 |     if self.initializer == 'linear':
64 |       weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.)
65 |     elif self.initializer == 'relu':
66 |       weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.)
67 |     elif self.initializer == 'zeros':
68 |       weight_init = hk.initializers.Constant(0.0)
69 |     else:
70 |       raise ValueError(f'Unknown initializer: {self.initializer}')
71 |
72 |     weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
73 |                                weight_init)
74 |
75 |     # This is equivalent to einsum('...c,cd->...d', inputs, weights),
76 |     # but turns out to be slightly faster.
77 |     inputs = jnp.swapaxes(inputs, -1, -2)
78 |     output = jnp.einsum('...cb,cd->...db', inputs, weights)
79 |     output = jnp.swapaxes(output, -1, -2)
80 |
81 |     if self.use_bias:
82 |       bias = hk.get_parameter('bias', [self.num_output], inputs.dtype,
83 |                               hk.initializers.Constant(self.bias_init))
84 |       output += bias
85 |
86 |     return output
87 |
--------------------------------------------------------------------------------
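
A minimal sketch showing the `Linear` module above inside a Haiku transform (illustrative only, not part of the repository):

```python
import haiku as hk
import jax
import jax.numpy as jnp

from alphafold.model import common_modules

def forward(x):
  # Linear works on inputs of arbitrary rank; only the last axis is projected.
  return common_modules.Linear(num_output=8, initializer='relu')(x)

fn = hk.transform(forward)
x = jnp.ones((4, 5, 16))  # [..., num_channel]
params = fn.init(jax.random.PRNGKey(0), x)
out = fn.apply(params, None, x)  # shape (4, 5, 8)
```
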
/src/alphafold/model/data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Convenience functions for reading data."""
16 |
17 | import io
18 | import os
19 | from typing import List
20 | from alphafold.model import utils
21 | import haiku as hk
22 | import numpy as np
23 | # Internal import (7716).
24 |
25 |
26 | def casp_model_names(data_dir: str) -> List[str]:
27 | params = os.listdir(os.path.join(data_dir, 'params'))
28 | return [os.path.splitext(filename)[0] for filename in params]
29 |
30 |
31 | def get_model_haiku_params(model_name: str, data_dir: str) -> hk.Params:
32 | """Get the Haiku parameters from a model name."""
33 |
34 | path = os.path.join(data_dir, 'params', f'params_{model_name}.npz')
35 |
36 | with open(path, 'rb') as f:
37 | params = np.load(io.BytesIO(f.read()), allow_pickle=False)
38 |
39 | return utils.flat_params_to_haiku(params)
40 |
--------------------------------------------------------------------------------
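
A minimal usage sketch for the loaders above (illustrative only, not part of the repository). `data_dir` is a placeholder directory that must contain a `params/` subfolder with files named like `params_<model_name>.npz`, and `model_1` is a placeholder model name:

```python
from alphafold.model import data

# Lists the stems of the files under <data_dir>/params/.
model_names = data.casp_model_names('/path/to/data')

# Loads <data_dir>/params/params_model_1.npz into a Haiku parameter tree.
params = data.get_model_haiku_params(model_name='model_1',
                                     data_dir='/path/to/data')
```
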
/src/alphafold/model/features.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Code to generate processed features."""
16 | import copy
17 | from typing import List, Mapping, Tuple
18 | from alphafold.model.tf import input_pipeline
19 | from alphafold.model.tf import proteins_dataset
20 | import ml_collections
21 | import numpy as np
22 | import os
23 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
24 | import tensorflow.compat.v1 as tf
25 | tf.config.set_visible_devices([], 'GPU')
26 |
27 |
28 | FeatureDict = Mapping[str, np.ndarray]
29 |
30 |
31 | def make_data_config(
32 | config: ml_collections.ConfigDict,
33 | num_res: int,
34 | ) -> Tuple[ml_collections.ConfigDict, List[str]]:
35 | """Makes a data config for the input pipeline."""
36 | cfg = copy.deepcopy(config.data)
37 |
38 | feature_names = cfg.common.unsupervised_features
39 | if cfg.common.use_templates:
40 | feature_names += cfg.common.template_features
41 |
42 | with cfg.unlocked():
43 | cfg.eval.crop_size = num_res
44 |
45 | return cfg, feature_names
46 |
47 |
48 | def tf_example_to_features(tf_example: tf.train.Example,
49 | config: ml_collections.ConfigDict,
50 | random_seed: int = 0) -> FeatureDict:
51 | """Converts tf_example to numpy feature dictionary."""
52 | num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0])
53 | cfg, feature_names = make_data_config(config, num_res=num_res)
54 |
55 | if 'deletion_matrix_int' in set(tf_example.features.feature):
56 | deletion_matrix_int = (
57 | tf_example.features.feature['deletion_matrix_int'].int64_list.value)
58 | feat = tf.train.Feature(float_list=tf.train.FloatList(
59 | value=map(float, deletion_matrix_int)))
60 | tf_example.features.feature['deletion_matrix'].CopyFrom(feat)
61 | del tf_example.features.feature['deletion_matrix_int']
62 |
63 | tf_graph = tf.Graph()
64 |   with tf_graph.as_default():
65 | tf.compat.v1.set_random_seed(random_seed)
66 | tensor_dict = proteins_dataset.create_tensor_dict(
67 | raw_data=tf_example.SerializeToString(),
68 | features=feature_names)
69 | processed_batch = input_pipeline.process_tensors_from_config(
70 | tensor_dict, cfg)
71 |
72 | tf_graph.finalize()
73 |
74 | with tf.Session(graph=tf_graph) as sess:
75 | features = sess.run(processed_batch)
76 |
77 | return {k: v for k, v in features.items() if v.dtype != 'O'}
78 |
79 |
80 | def np_example_to_features(np_example: FeatureDict,
81 | config: ml_collections.ConfigDict,
82 | random_seed: int = 0) -> FeatureDict:
83 | """Preprocesses NumPy feature dict using TF pipeline."""
84 | np_example = dict(np_example)
85 | num_res = int(np_example['seq_length'][0])
86 | cfg, feature_names = make_data_config(config, num_res=num_res)
87 |
88 | if 'deletion_matrix_int' in np_example:
89 | np_example['deletion_matrix'] = (
90 | np_example.pop('deletion_matrix_int').astype(np.float32))
91 |
92 | tf_graph = tf.Graph()
93 | with tf_graph.as_default():
94 | tf.compat.v1.set_random_seed(random_seed)
95 | tensor_dict = proteins_dataset.np_to_tensor_dict(
96 | np_example=np_example, features=feature_names)
97 |
98 | processed_batch = input_pipeline.process_tensors_from_config(
99 | tensor_dict, cfg)
100 |
101 | tf_graph.finalize()
102 |
103 | with tf.Session(graph=tf_graph) as sess:
104 | features = sess.run(processed_batch)
105 | return {k: v[0] for k, v in features.items() if v.dtype != 'O'}
106 |
--------------------------------------------------------------------------------
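
A call-shape sketch for the feature processing above (illustrative only, not part of the repository). `raw_features` is assumed to be a pipeline FeatureDict containing at least 'seq_length', and `cfg` a model ConfigDict with a `data` section, both produced elsewhere in the repository:

```python
from alphafold.model import features

# Runs the TF input pipeline over the raw NumPy features and returns a dict
# of processed NumPy arrays ready for the model.
processed = features.np_example_to_features(
    np_example=raw_features, config=cfg, random_seed=0)
```
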
/src/alphafold/model/lddt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """lDDT protein distance score."""
16 | import jax.numpy as jnp
17 |
18 |
19 | def lddt(predicted_points,
20 | true_points,
21 | true_points_mask,
22 | cutoff=15.,
23 | per_residue=False):
24 | """Measure (approximate) lDDT for a batch of coordinates.
25 |
26 | lDDT reference:
27 | Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local
28 | superposition-free score for comparing protein structures and models using
29 | distance difference tests. Bioinformatics 29, 2722–2728 (2013).
30 |
31 | lDDT is a measure of the difference between the true distance matrix and the
32 | distance matrix of the predicted points. The difference is computed only on
33 | points closer than cutoff *in the true structure*.
34 |
35 | This function does not compute the exact lDDT value that the original paper
36 | describes because it does not include terms for physical feasibility
37 | (e.g. bond length violations). Therefore this is only an approximate
38 | lDDT score.
39 |
40 | Args:
41 | predicted_points: (batch, length, 3) array of predicted 3D points
42 | true_points: (batch, length, 3) array of true 3D points
43 | true_points_mask: (batch, length, 1) binary-valued float array. This mask
44 | should be 1 for points that exist in the true points.
45 | cutoff: Maximum distance for a pair of points to be included
46 | per_residue: If true, return score for each residue. Note that the overall
47 | lDDT is not exactly the mean of the per_residue lDDT's because some
48 | residues have more contacts than others.
49 |
50 | Returns:
51 | An (approximate, see above) lDDT score in the range 0-1.
52 | """
53 |
54 | assert len(predicted_points.shape) == 3
55 | assert predicted_points.shape[-1] == 3
56 | assert true_points_mask.shape[-1] == 1
57 | assert len(true_points_mask.shape) == 3
58 |
59 | # Compute true and predicted distance matrices.
60 | dmat_true = jnp.sqrt(1e-10 + jnp.sum(
61 | (true_points[:, :, None] - true_points[:, None, :])**2, axis=-1))
62 |
63 | dmat_predicted = jnp.sqrt(1e-10 + jnp.sum(
64 | (predicted_points[:, :, None] -
65 | predicted_points[:, None, :])**2, axis=-1))
66 |
67 | dists_to_score = (
68 | (dmat_true < cutoff).astype(jnp.float32) * true_points_mask *
69 | jnp.transpose(true_points_mask, [0, 2, 1]) *
70 | (1. - jnp.eye(dmat_true.shape[1])) # Exclude self-interaction.
71 | )
72 |
73 |   # Absolute difference between the true and predicted distance matrices.
74 | dist_l1 = jnp.abs(dmat_true - dmat_predicted)
75 |
76 | # True lDDT uses a number of fixed bins.
77 | # We ignore the physical plausibility correction to lDDT, though.
78 | score = 0.25 * ((dist_l1 < 0.5).astype(jnp.float32) +
79 | (dist_l1 < 1.0).astype(jnp.float32) +
80 | (dist_l1 < 2.0).astype(jnp.float32) +
81 | (dist_l1 < 4.0).astype(jnp.float32))
82 |
83 | # Normalize over the appropriate axes.
84 | reduce_axes = (-1,) if per_residue else (-2, -1)
85 | norm = 1. / (1e-10 + jnp.sum(dists_to_score, axis=reduce_axes))
86 | score = norm * (1e-10 + jnp.sum(dists_to_score * score, axis=reduce_axes))
87 |
88 | return score
89 |
--------------------------------------------------------------------------------
/src/alphafold/model/lddt_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for lddt."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from alphafold.model import lddt
20 | import numpy as np
21 |
22 |
23 | class LddtTest(parameterized.TestCase, absltest.TestCase):
24 |
25 | @parameterized.named_parameters(
26 | ('same',
27 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
28 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
29 | [1, 1, 1]),
30 | ('all_shifted',
31 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
32 | [[-1, 0, 0], [4, 0, 0], [9, 0, 0]],
33 | [1, 1, 1]),
34 | ('all_rotated',
35 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
36 | [[0, 0, 0], [0, 5, 0], [0, 10, 0]],
37 | [1, 1, 1]),
38 | ('half_a_dist',
39 | [[0, 0, 0], [5, 0, 0]],
40 | [[0, 0, 0], [5.5-1e-5, 0, 0]],
41 | [1, 1]),
42 | ('one_a_dist',
43 | [[0, 0, 0], [5, 0, 0]],
44 | [[0, 0, 0], [6-1e-5, 0, 0]],
45 | [0.75, 0.75]),
46 | ('two_a_dist',
47 | [[0, 0, 0], [5, 0, 0]],
48 | [[0, 0, 0], [7-1e-5, 0, 0]],
49 | [0.5, 0.5]),
50 | ('four_a_dist',
51 | [[0, 0, 0], [5, 0, 0]],
52 | [[0, 0, 0], [9-1e-5, 0, 0]],
53 | [0.25, 0.25],),
54 | ('five_a_dist',
55 | [[0, 0, 0], [16-1e-5, 0, 0]],
56 | [[0, 0, 0], [11, 0, 0]],
57 | [0, 0]),
58 | ('no_pairs',
59 | [[0, 0, 0], [20, 0, 0]],
60 | [[0, 0, 0], [25-1e-5, 0, 0]],
61 | [1, 1]),
62 | )
63 | def test_lddt(
64 | self, predicted_pos, true_pos, exp_lddt):
65 | predicted_pos = np.array([predicted_pos], dtype=np.float32)
66 | true_points_mask = np.array([[[1]] * len(true_pos)], dtype=np.float32)
67 | true_pos = np.array([true_pos], dtype=np.float32)
68 | cutoff = 15.0
69 | per_residue = True
70 |
71 | result = lddt.lddt(
72 | predicted_pos, true_pos, true_points_mask, cutoff,
73 | per_residue)
74 |
75 | np.testing.assert_almost_equal(result, [exp_lddt], decimal=4)
76 |
77 |
78 | if __name__ == '__main__':
79 | absltest.main()
80 |
--------------------------------------------------------------------------------
/src/alphafold/model/mapping.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Specialized mapping functions."""
16 |
17 | import functools
18 |
19 | from typing import Any, Callable, Optional, Sequence, Union
20 |
21 | import haiku as hk
22 | import jax
23 | import jax.numpy as jnp
24 |
25 |
26 | PYTREE = Any
27 | PYTREE_JAX_ARRAY = Any
28 |
29 | partial = functools.partial
30 | PROXY = object()
31 |
32 |
33 | def _maybe_slice(array, i, slice_size, axis):
34 | if axis is PROXY:
35 | return array
36 | else:
37 | return jax.lax.dynamic_slice_in_dim(
38 | array, i, slice_size=slice_size, axis=axis)
39 |
40 |
41 | def _maybe_get_size(array, axis):
42 | if axis == PROXY:
43 | return -1
44 | else:
45 | return array.shape[axis]
46 |
47 |
48 | def _expand_axes(axes, values, name='sharded_apply'):
49 | values_tree_def = jax.tree_flatten(values)[1]
50 | flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes)
51 | # Replace None's with PROXY
52 | flat_axes = [PROXY if x is None else x for x in flat_axes]
53 | return jax.tree_unflatten(values_tree_def, flat_axes)
54 |
55 |
56 | def sharded_map(
57 | fun: Callable[..., PYTREE_JAX_ARRAY],
58 | shard_size: Union[int, None] = 1,
59 | in_axes: Union[int, PYTREE] = 0,
60 | out_axes: Union[int, PYTREE] = 0) -> Callable[..., PYTREE_JAX_ARRAY]:
61 | """Sharded vmap.
62 |
63 | Maps `fun` over axes, in a way similar to vmap, but does so in shards of
64 | `shard_size`. This allows a smooth trade-off between memory usage
65 | (as in a plain map) vs higher throughput (as in a vmap).
66 |
67 | Args:
68 | fun: Function to apply smap transform to.
69 | shard_size: Integer denoting shard size.
70 | in_axes: Either integer or pytree describing which axis to map over for each
71 | input to `fun`, None denotes broadcasting.
72 | out_axes: integer or pytree denoting to what axis in the output the mapped
73 | over axis maps.
74 |
75 | Returns:
76 | function with smap applied.
77 | """
78 | vmapped_fun = hk.vmap(fun, in_axes, out_axes)
79 | return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes)
80 |
81 |
82 | def sharded_apply(
83 | fun: Callable[..., PYTREE_JAX_ARRAY], # pylint: disable=g-bare-generic
84 | shard_size: Union[int, None] = 1,
85 | in_axes: Union[int, PYTREE] = 0,
86 | out_axes: Union[int, PYTREE] = 0,
87 | new_out_axes: bool = False) -> Callable[..., PYTREE_JAX_ARRAY]:
88 | """Sharded apply.
89 |
90 | Applies `fun` over shards to axes, in a way similar to vmap,
91 | but does so in shards of `shard_size`. Shards are stacked after.
92 | This allows a smooth trade-off between
93 | memory usage (as in a plain map) vs higher throughput (as in a vmap).
94 |
95 | Args:
96 | fun: Function to apply smap transform to.
97 | shard_size: Integer denoting shard size.
98 | in_axes: Either integer or pytree describing which axis to map over for each
99 | input to `fun`, None denotes broadcasting.
100 | out_axes: integer or pytree denoting to what axis in the output the mapped
101 | over axis maps.
102 | new_out_axes: whether to stack outputs on new axes. This assumes that the
103 | output sizes for each shard (including the possible remainder shard) are
104 | the same.
105 |
106 | Returns:
107 | function with smap applied.
108 | """
109 | docstr = ('Mapped version of {fun}. Takes similar arguments to {fun} '
110 | 'but with additional array axes over which {fun} is mapped.')
111 | if new_out_axes:
112 | raise NotImplementedError('New output axes not yet implemented.')
113 |
114 | # shard size None denotes no sharding
115 | if shard_size is None:
116 | return fun
117 |
118 | @jax.util.wraps(fun, docstr=docstr)
119 | def mapped_fn(*args):
120 |     # Expand in_axes and determine the loop range.
121 | in_axes_ = _expand_axes(in_axes, args)
122 |
123 | in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_)
124 | flat_sizes = jax.tree_flatten(in_sizes)[0]
125 | in_size = max(flat_sizes)
126 | assert all(i in {in_size, -1} for i in flat_sizes)
127 |
128 | num_extra_shards = (in_size - 1) // shard_size
129 |
130 |     # Fix up the last shard size if necessary.
131 | last_shard_size = in_size % shard_size
132 | last_shard_size = shard_size if last_shard_size == 0 else last_shard_size
133 |
134 | def apply_fun_to_slice(slice_start, slice_size):
135 | input_slice = jax.tree_map(
136 | lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis
137 | ), args, in_axes_)
138 | return fun(*input_slice)
139 |
140 | remainder_shape_dtype = hk.eval_shape(
141 | partial(apply_fun_to_slice, 0, last_shard_size))
142 | out_dtypes = jax.tree_map(lambda x: x.dtype, remainder_shape_dtype)
143 | out_shapes = jax.tree_map(lambda x: x.shape, remainder_shape_dtype)
144 | out_axes_ = _expand_axes(out_axes, remainder_shape_dtype)
145 |
146 | if num_extra_shards > 0:
147 | regular_shard_shape_dtype = hk.eval_shape(
148 | partial(apply_fun_to_slice, 0, shard_size))
149 | shard_shapes = jax.tree_map(lambda x: x.shape, regular_shard_shape_dtype)
150 |
151 | def make_output_shape(axis, shard_shape, remainder_shape):
152 | return shard_shape[:axis] + (
153 | shard_shape[axis] * num_extra_shards +
154 | remainder_shape[axis],) + shard_shape[axis + 1:]
155 |
156 | out_shapes = jax.tree_map(make_output_shape, out_axes_, shard_shapes,
157 | out_shapes)
158 |
159 |       # Calls dynamic_update_slice_in_dim with a different argument order.
160 |       # This is needed since jax.tree_map only works with positional arguments.
161 | def dynamic_update_slice_in_dim(full_array, update, axis, i):
162 | return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)
163 |
164 | def compute_shard(outputs, slice_start, slice_size):
165 | slice_out = apply_fun_to_slice(slice_start, slice_size)
166 | update_slice = partial(
167 | dynamic_update_slice_in_dim, i=slice_start)
168 | return jax.tree_map(update_slice, outputs, slice_out, out_axes_)
169 |
170 | def scan_iteration(outputs, i):
171 | new_outputs = compute_shard(outputs, i, shard_size)
172 | return new_outputs, ()
173 |
174 | slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size)
175 |
176 | def allocate_buffer(dtype, shape):
177 | return jnp.zeros(shape, dtype=dtype)
178 |
179 | outputs = jax.tree_map(allocate_buffer, out_dtypes, out_shapes)
180 |
181 | if slice_starts.shape[0] > 0:
182 | outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)
183 |
184 | if last_shard_size != shard_size:
185 | remainder_start = in_size - last_shard_size
186 | outputs = compute_shard(outputs, remainder_start, last_shard_size)
187 |
188 | return outputs
189 |
190 | return mapped_fn
191 |
192 |
193 | def inference_subbatch(
194 | module: Callable[..., PYTREE_JAX_ARRAY],
195 | subbatch_size: int,
196 | batched_args: Sequence[PYTREE_JAX_ARRAY],
197 | nonbatched_args: Sequence[PYTREE_JAX_ARRAY],
198 | low_memory: bool = True,
199 | input_subbatch_dim: int = 0,
200 | output_subbatch_dim: Optional[int] = None) -> PYTREE_JAX_ARRAY:
201 | """Run through subbatches (like batch apply but with split and concat)."""
202 | assert len(batched_args) > 0 # pylint: disable=g-explicit-length-test
203 |
204 | if not low_memory:
205 | args = list(batched_args) + list(nonbatched_args)
206 | return module(*args)
207 |
208 | if output_subbatch_dim is None:
209 | output_subbatch_dim = input_subbatch_dim
210 |
211 | def run_module(*batched_args):
212 | args = list(batched_args) + list(nonbatched_args)
213 | return module(*args)
214 | sharded_module = sharded_apply(run_module,
215 | shard_size=subbatch_size,
216 | in_axes=input_subbatch_dim,
217 | out_axes=output_subbatch_dim)
218 | return sharded_module(*batched_args)
219 |
--------------------------------------------------------------------------------
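
A minimal sketch of `sharded_apply` above (illustrative only, not part of the repository): a per-row function is vmapped and then applied in shards of 4 rows at a time. It is wrapped in `hk.transform` because `sharded_apply` uses Haiku's `eval_shape` and `scan` internally:

```python
import haiku as hk
import jax
import jax.numpy as jnp

from alphafold.model import mapping

def forward(x):
  row_norm = lambda r: r / (jnp.linalg.norm(r) + 1e-8)
  # Processes x in shards of 4 rows; 10 rows become shards of 4, 4 and 2.
  return mapping.sharded_apply(jax.vmap(row_norm), shard_size=4)(x)

fn = hk.transform(forward)
x = jnp.ones((10, 3))
out = fn.apply(fn.init(jax.random.PRNGKey(0), x), None, x)  # shape (10, 3)
```
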
/src/alphafold/model/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Code for constructing the model."""
16 | from typing import Any, Mapping, Optional, Union
17 |
18 | from absl import logging
19 | from alphafold.common import confidence
20 | from alphafold.model import features
21 | from alphafold.model import modules
22 | import haiku as hk
23 | import jax
24 | import ml_collections
25 | import numpy as np
26 | import tensorflow.compat.v1 as tf
27 | import tree
28 |
29 |
30 | def get_confidence_metrics(
31 | prediction_result: Mapping[str, Any]) -> Mapping[str, Any]:
32 | """Post processes prediction_result to get confidence metrics."""
33 |
34 | confidence_metrics = {}
35 | confidence_metrics['plddt'] = confidence.compute_plddt(
36 | prediction_result['predicted_lddt']['logits'])
37 | if 'predicted_aligned_error' in prediction_result:
38 | confidence_metrics.update(confidence.compute_predicted_aligned_error(
39 | prediction_result['predicted_aligned_error']['logits'],
40 | prediction_result['predicted_aligned_error']['breaks']))
41 | confidence_metrics['ptm'] = confidence.predicted_tm_score(
42 | prediction_result['predicted_aligned_error']['logits'],
43 | prediction_result['predicted_aligned_error']['breaks'])
44 |
45 | return confidence_metrics
46 |
47 |
48 | class RunModel:
49 | """Container for JAX model."""
50 |
51 | def __init__(self,
52 | config: ml_collections.ConfigDict,
53 | params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None):
54 | self.config = config
55 | self.params = params
56 |
57 | def _forward_fn(batch):
58 | model = modules.AlphaFold(self.config.model)
59 | return model(
60 | batch,
61 | is_training=False,
62 | compute_loss=False,
63 | ensemble_representations=True)
64 |
65 | self.apply = jax.jit(hk.transform(_forward_fn).apply)
66 | self.init = jax.jit(hk.transform(_forward_fn).init)
67 |
68 | def init_params(self, feat: features.FeatureDict, random_seed: int = 0):
69 | """Initializes the model parameters.
70 |
71 | If none were provided when this class was instantiated then the parameters
72 | are randomly initialized.
73 |
74 | Args:
75 | feat: A dictionary of NumPy feature arrays as output by
76 | RunModel.process_features.
77 | random_seed: A random seed to use to initialize the parameters if none
78 | were set when this class was initialized.
79 | """
80 | if not self.params:
81 | # Init params randomly.
82 | rng = jax.random.PRNGKey(random_seed)
83 | self.params = hk.data_structures.to_mutable_dict(
84 | self.init(rng, feat))
85 | logging.warning('Initialized parameters randomly')
86 |
87 | def process_features(
88 | self,
89 | raw_features: Union[tf.train.Example, features.FeatureDict],
90 | random_seed: int) -> features.FeatureDict:
91 | """Processes features to prepare for feeding them into the model.
92 |
93 | Args:
94 | raw_features: The output of the data pipeline either as a dict of NumPy
95 | arrays or as a tf.train.Example.
96 | random_seed: The random seed to use when processing the features.
97 |
98 | Returns:
99 | A dict of NumPy feature arrays suitable for feeding into the model.
100 | """
101 | if isinstance(raw_features, dict):
102 | return features.np_example_to_features(
103 | np_example=raw_features,
104 | config=self.config,
105 | random_seed=random_seed)
106 | else:
107 | return features.tf_example_to_features(
108 | tf_example=raw_features,
109 | config=self.config,
110 | random_seed=random_seed)
111 |
112 | def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct:
113 | self.init_params(feat)
114 | logging.info('Running eval_shape with shape(feat) = %s',
115 | tree.map_structure(lambda x: x.shape, feat))
116 | shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0), feat)
117 | logging.info('Output shape was %s', shape)
118 | return shape
119 |
120 | def predict(self, feat: features.FeatureDict) -> Mapping[str, Any]:
121 | """Makes a prediction by inferencing the model on the provided features.
122 |
123 | Args:
124 | feat: A dictionary of NumPy feature arrays as output by
125 | RunModel.process_features.
126 |
127 | Returns:
128 | A dictionary of model outputs.
129 | """
130 | self.init_params(feat)
131 | logging.info('Running predict with shape(feat) = %s',
132 | tree.map_structure(lambda x: x.shape, feat))
133 | result = self.apply(self.params, jax.random.PRNGKey(0), feat)
134 | # This block is to ensure benchmark timings are accurate. Some blocking is
135 | # already happening when computing get_confidence_metrics, and this ensures
136 | # all outputs are blocked on.
137 | jax.tree_map(lambda x: x.block_until_ready(), result)
138 | result.update(get_confidence_metrics(result))
139 | logging.info('Output shape was %s',
140 | tree.map_structure(lambda x: x.shape, result))
141 | return result
142 |
--------------------------------------------------------------------------------
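
A minimal end-to-end sketch for `RunModel` above (illustrative only, not part of the repository). It assumes an upstream-AlphaFold-style `config.model_config(name)` helper in the repo's config module, a params layout as expected by `data.get_model_haiku_params`, and `raw_features` from the data pipeline; all names and paths are placeholders:

```python
from alphafold.model import config, data, model

cfg = config.model_config('model_1')  # assumed helper; placeholder model name
params = data.get_model_haiku_params(model_name='model_1',
                                     data_dir='/path/to/data')
runner = model.RunModel(cfg, params)

feat = runner.process_features(raw_features, random_seed=0)  # pipeline output
prediction = runner.predict(feat)
plddt = prediction['plddt']
```
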
/src/alphafold/model/prng.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A collection of utilities surrounding PRNG usage in protein folding."""
16 |
17 | import haiku as hk
18 | import jax
19 |
20 |
21 | def safe_dropout(*, tensor, safe_key, rate, is_deterministic, is_training):
22 | if is_training and rate != 0.0 and not is_deterministic:
23 | return hk.dropout(safe_key.get(), rate, tensor)
24 | else:
25 | return tensor
26 |
27 |
28 | class SafeKey:
29 | """Safety wrapper for PRNG keys."""
30 |
31 | def __init__(self, key):
32 | self._key = key
33 | self._used = False
34 |
35 | def _assert_not_used(self):
36 | if self._used:
37 | raise RuntimeError('Random key has been used previously.')
38 |
39 | def get(self):
40 | self._assert_not_used()
41 | self._used = True
42 | return self._key
43 |
44 | def split(self, num_keys=2):
45 | self._assert_not_used()
46 | self._used = True
47 | new_keys = jax.random.split(self._key, num_keys)
48 | return jax.tree_map(SafeKey, tuple(new_keys))
49 |
50 | def duplicate(self, num_keys=2):
51 | self._assert_not_used()
52 | self._used = True
53 | return tuple(SafeKey(self._key) for _ in range(num_keys))
54 |
55 |
56 | def _safe_key_flatten(safe_key):
57 | # Flatten transfers "ownership" to the tree
58 | return (safe_key._key,), safe_key._used # pylint: disable=protected-access
59 |
60 |
61 | def _safe_key_unflatten(aux_data, children):
62 | ret = SafeKey(children[0])
63 | ret._used = aux_data # pylint: disable=protected-access
64 | return ret
65 |
66 |
67 | jax.tree_util.register_pytree_node(
68 | SafeKey, _safe_key_flatten, _safe_key_unflatten)
69 |
70 |
--------------------------------------------------------------------------------
/src/alphafold/model/prng_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for prng."""
16 |
17 | from absl.testing import absltest
18 | from alphafold.model import prng
19 | import jax
20 |
21 |
22 | class PrngTest(absltest.TestCase):
23 |
24 | def test_key_reuse(self):
25 |
26 | init_key = jax.random.PRNGKey(42)
27 | safe_key = prng.SafeKey(init_key)
28 | _, safe_key = safe_key.split()
29 |
30 | raw_key = safe_key.get()
31 |
32 | self.assertNotEqual(raw_key[0], init_key[0])
33 | self.assertNotEqual(raw_key[1], init_key[1])
34 |
35 | with self.assertRaises(RuntimeError):
36 | safe_key.get()
37 |
38 | with self.assertRaises(RuntimeError):
39 | safe_key.split()
40 |
41 | with self.assertRaises(RuntimeError):
42 | safe_key.duplicate()
43 |
44 |
45 | if __name__ == '__main__':
46 | absltest.main()
47 |
--------------------------------------------------------------------------------
/src/alphafold/model/quat_affine_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for quat_affine."""
16 |
17 | from absl import logging
18 | from absl.testing import absltest
19 | from alphafold.model import quat_affine
20 | import jax
21 | import jax.numpy as jnp
22 | import numpy as np
23 |
24 | VERBOSE = False
25 | np.set_printoptions(precision=3, suppress=True)
26 |
27 | r2t = quat_affine.rot_list_to_tensor
28 | v2t = quat_affine.vec_list_to_tensor
29 |
30 | q2r = lambda q: r2t(quat_affine.quat_to_rot(q))
31 |
32 |
33 | class QuatAffineTest(absltest.TestCase):
34 |
35 | def _assert_check(self, to_check, tol=1e-5):
36 | for k, (correct, generated) in to_check.items():
37 | if VERBOSE:
38 | logging.info(k)
39 | logging.info('Correct %s', correct)
40 | logging.info('Predicted %s', generated)
41 | self.assertLess(np.max(np.abs(correct - generated)), tol)
42 |
43 | def test_conversion(self):
44 | quat = jnp.array([-2., 5., -1., 4.])
45 |
46 | rotation = jnp.array([
47 | [0.26087, 0.130435, 0.956522],
48 | [-0.565217, -0.782609, 0.26087],
49 | [0.782609, -0.608696, -0.130435]])
50 |
51 | translation = jnp.array([1., -3., 4.])
52 | point = jnp.array([0.7, 3.2, -2.9])
53 |
54 | a = quat_affine.QuatAffine(quat, translation, unstack_inputs=True)
55 | true_new_point = jnp.matmul(rotation, point[:, None])[:, 0] + translation
56 |
57 | self._assert_check({
58 | 'rot': (rotation, r2t(a.rotation)),
59 | 'trans': (translation, v2t(a.translation)),
60 | 'point': (true_new_point,
61 | v2t(a.apply_to_point(jnp.moveaxis(point, -1, 0)))),
62 | # Because of the double cover, we must be careful and compare rotations
63 | 'quat': (q2r(a.quaternion),
64 | q2r(quat_affine.rot_to_quat(a.rotation))),
65 |
66 | })
67 |
68 | def test_double_cover(self):
69 | """Test that -q is the same rotation as q."""
70 | rng = jax.random.PRNGKey(42)
71 | keys = jax.random.split(rng)
72 | q = jax.random.normal(keys[0], (2, 4))
73 | trans = jax.random.normal(keys[1], (2, 3))
74 | a1 = quat_affine.QuatAffine(q, trans, unstack_inputs=True)
75 | a2 = quat_affine.QuatAffine(-q, trans, unstack_inputs=True)
76 |
77 | self._assert_check({
78 | 'rot': (r2t(a1.rotation),
79 | r2t(a2.rotation)),
80 | 'trans': (v2t(a1.translation),
81 | v2t(a2.translation)),
82 | })
83 |
84 | def test_homomorphism(self):
85 | rng = jax.random.PRNGKey(42)
86 | keys = jax.random.split(rng, 4)
87 | vec_q1 = jax.random.normal(keys[0], (2, 3))
88 |
89 | q1 = jnp.concatenate([
90 | jnp.ones_like(vec_q1)[:, :1],
91 | vec_q1], axis=-1)
92 |
93 | q2 = jax.random.normal(keys[1], (2, 4))
94 | t1 = jax.random.normal(keys[2], (2, 3))
95 | t2 = jax.random.normal(keys[3], (2, 3))
96 |
97 | a1 = quat_affine.QuatAffine(q1, t1, unstack_inputs=True)
98 | a2 = quat_affine.QuatAffine(q2, t2, unstack_inputs=True)
99 | a21 = a2.pre_compose(jnp.concatenate([vec_q1, t1], axis=-1))
100 |
101 | rng, key = jax.random.split(rng)
102 | x = jax.random.normal(key, (2, 3))
103 | new_x = a21.apply_to_point(jnp.moveaxis(x, -1, 0))
104 | new_x_apply2 = a2.apply_to_point(a1.apply_to_point(jnp.moveaxis(x, -1, 0)))
105 |
106 | self._assert_check({
107 | 'quat': (q2r(quat_affine.quat_multiply(a2.quaternion, a1.quaternion)),
108 | q2r(a21.quaternion)),
109 | 'rot': (jnp.matmul(r2t(a2.rotation), r2t(a1.rotation)),
110 | r2t(a21.rotation)),
111 | 'point': (v2t(new_x_apply2),
112 | v2t(new_x)),
113 | 'inverse': (x, v2t(a21.invert_point(new_x))),
114 | })
115 |
116 | def test_batching(self):
117 | """Test that affine applies batchwise."""
118 | rng = jax.random.PRNGKey(42)
119 | keys = jax.random.split(rng, 3)
120 | q = jax.random.uniform(keys[0], (5, 2, 4))
121 | t = jax.random.uniform(keys[1], (2, 3))
122 | x = jax.random.uniform(keys[2], (5, 1, 3))
123 |
124 | a = quat_affine.QuatAffine(q, t, unstack_inputs=True)
125 | y = v2t(a.apply_to_point(jnp.moveaxis(x, -1, 0)))
126 |
127 | y_list = []
128 | for i in range(5):
129 | for j in range(2):
130 | a_local = quat_affine.QuatAffine(q[i, j], t[j],
131 | unstack_inputs=True)
132 | y_local = v2t(a_local.apply_to_point(jnp.moveaxis(x[i, 0], -1, 0)))
133 | y_list.append(y_local)
134 | y_combine = jnp.reshape(jnp.stack(y_list, axis=0), (5, 2, 3))
135 |
136 | self._assert_check({
137 | 'batch': (y_combine, y),
138 | 'quat': (q2r(a.quaternion),
139 | q2r(quat_affine.rot_to_quat(a.rotation))),
140 | })
141 |
142 | def assertAllClose(self, a, b, rtol=1e-06, atol=1e-06):
143 | self.assertTrue(np.allclose(a, b, rtol=rtol, atol=atol))
144 |
145 | def assertAllEqual(self, a, b):
146 | self.assertTrue(np.all(np.array(a) == np.array(b)))
147 |
148 |
149 | if __name__ == '__main__':
150 | absltest.main()
151 |
--------------------------------------------------------------------------------
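
A note on the rotation comparisons above: unit quaternions double-cover the
rotation group, so R(q) = R(-q), and quaternion multiplication is a
homomorphism onto rotation composition, R(q2 * q1) = R(q2) R(q1). That is why
test_conversion and test_double_cover compare rotation matrices derived via
q2r rather than the raw quaternions, and why test_homomorphism checks the
composed affine against matrix products.
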
/src/alphafold/model/tf/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Alphafold model TensorFlow code."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/model/tf/input_pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Feature pre-processing input pipeline for AlphaFold."""
16 |
17 | from alphafold.model.tf import data_transforms
18 | from alphafold.model.tf import shape_placeholders
19 | import tensorflow.compat.v1 as tf
20 |
21 | import tree
22 |
23 | # Pylint gets confused by the curry1 decorator because it changes the number
24 | # of arguments to the function.
25 | # pylint:disable=no-value-for-parameter
26 |
27 |
28 | NUM_RES = shape_placeholders.NUM_RES
29 | NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
30 | NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
31 | NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
32 |
33 |
34 | def nonensembled_map_fns(data_config):
35 | """Input pipeline functions which are not ensembled."""
36 | common_cfg = data_config.common
37 |
38 | map_fns = [
39 | data_transforms.correct_msa_restypes,
40 | data_transforms.add_distillation_flag(False),
41 | data_transforms.cast_64bit_ints,
42 | data_transforms.squeeze_features,
43 | # Kept (with replace fraction 0.0) so as not to disrupt the RNG stream.
44 | data_transforms.randomly_replace_msa_with_unknown(0.0),
45 | data_transforms.make_seq_mask,
46 | data_transforms.make_msa_mask,
47 | # Compute the HHblits profile if it's not set. This has to be run before
48 | # sampling the MSA.
49 | data_transforms.make_hhblits_profile,
50 | data_transforms.make_random_crop_to_size_seed,
51 | ]
52 | if common_cfg.use_templates:
53 | map_fns.extend([
54 | data_transforms.fix_templates_aatype,
55 | data_transforms.make_template_mask,
56 | data_transforms.make_pseudo_beta('template_')
57 | ])
58 | map_fns.extend([
59 | data_transforms.make_atom14_masks,
60 | ])
61 |
62 | return map_fns
63 |
64 |
65 | def ensembled_map_fns(data_config):
66 | """Input pipeline functions that can be ensembled and averaged."""
67 | common_cfg = data_config.common
68 | eval_cfg = data_config.eval
69 |
70 | map_fns = []
71 |
72 | if common_cfg.reduce_msa_clusters_by_max_templates:
73 | pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
74 | else:
75 | pad_msa_clusters = eval_cfg.max_msa_clusters
76 |
77 | max_msa_clusters = pad_msa_clusters
78 | max_extra_msa = common_cfg.max_extra_msa
79 |
80 | map_fns.append(
81 | data_transforms.sample_msa(
82 | max_msa_clusters,
83 | keep_extra=True))
84 |
85 | if 'masked_msa' in common_cfg:
86 | # Masked MSA should come *before* MSA clustering so that
87 | # the clustering and full MSA profile do not leak information about
88 | # the masked locations and secret corrupted locations.
89 | map_fns.append(
90 | data_transforms.make_masked_msa(common_cfg.masked_msa,
91 | eval_cfg.masked_msa_replace_fraction))
92 |
93 | if common_cfg.msa_cluster_features:
94 | map_fns.append(data_transforms.nearest_neighbor_clusters())
95 | map_fns.append(data_transforms.summarize_clusters())
96 |
97 | # Crop after creating the cluster profiles.
98 | if max_extra_msa:
99 | map_fns.append(data_transforms.crop_extra_msa(max_extra_msa))
100 | else:
101 | map_fns.append(data_transforms.delete_extra_msa)
102 |
103 | map_fns.append(data_transforms.make_msa_feat())
104 |
105 | crop_feats = dict(eval_cfg.feat)
106 |
107 | if eval_cfg.fixed_size:
108 | map_fns.append(data_transforms.select_feat(list(crop_feats)))
109 | map_fns.append(data_transforms.random_crop_to_size(
110 | eval_cfg.crop_size,
111 | eval_cfg.max_templates,
112 | crop_feats,
113 | eval_cfg.subsample_templates))
114 | map_fns.append(data_transforms.make_fixed_size(
115 | crop_feats,
116 | pad_msa_clusters,
117 | common_cfg.max_extra_msa,
118 | eval_cfg.crop_size,
119 | eval_cfg.max_templates))
120 | else:
121 | map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates))
122 |
123 | return map_fns
124 |
125 |
126 | def process_tensors_from_config(tensors, data_config):
127 | """Apply filters and maps to an existing dataset, based on the config."""
128 |
129 | def wrap_ensemble_fn(data, i):
130 | """Function to be mapped over the ensemble dimension."""
131 | d = data.copy()
132 | fns = ensembled_map_fns(data_config)
133 | fn = compose(fns)
134 | d['ensemble_index'] = i
135 | return fn(d)
136 |
137 | eval_cfg = data_config.eval
138 | tensors = compose(
139 | nonensembled_map_fns(
140 | data_config))(
141 | tensors)
142 |
143 | tensors_0 = wrap_ensemble_fn(tensors, tf.constant(0))
144 | num_ensemble = eval_cfg.num_ensemble
145 | if data_config.common.resample_msa_in_recycling:
146 | # Separate batch per ensembling & recycling step.
147 | num_ensemble *= data_config.common.num_recycle + 1
148 |
149 | if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1:
150 | fn_output_signature = tree.map_structure(
151 | tf.TensorSpec.from_tensor, tensors_0)
152 | tensors = tf.map_fn(
153 | lambda x: wrap_ensemble_fn(tensors, x),
154 | tf.range(num_ensemble),
155 | parallel_iterations=1,
156 | fn_output_signature=fn_output_signature)
157 | else:
158 | tensors = tree.map_structure(lambda x: x[None],
159 | tensors_0)
160 | return tensors
161 |
162 |
163 | @data_transforms.curry1
164 | def compose(x, fs):
165 | for f in fs:
166 | x = f(x)
167 | return x
168 |
--------------------------------------------------------------------------------
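
The compose function above relies on data_transforms.curry1 (defined in
data_transforms.py), which supplies every argument except the first and
returns a function of that first argument; this is why compose(fns)(tensors)
works in process_tensors_from_config. A minimal standalone sketch of the
pattern, with hypothetical toy transforms standing in for the data_transforms
functions:

    def curry1(f):
      """Supply all arguments but the first, mirroring data_transforms.curry1."""
      def fc(*args, **kwargs):
        return lambda x: f(x, *args, **kwargs)
      return fc

    @curry1
    def compose(x, fs):
      for f in fs:
        x = f(x)
      return x

    # Hypothetical stand-ins for data_transforms.* map functions.
    add_one = lambda d: {k: v + 1 for k, v in d.items()}
    double = lambda d: {k: v * 2 for k, v in d.items()}

    print(compose([add_one, double])({'feat': 1}))  # {'feat': 4}
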
/src/alphafold/model/tf/protein_features.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Contains descriptions of various protein features."""
16 | import enum
17 | from typing import Dict, Optional, Sequence, Tuple, Union
18 | from alphafold.common import residue_constants
19 | import tensorflow.compat.v1 as tf
20 |
21 |
22 | # Type aliases.
23 | FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]]
24 |
25 |
26 | class FeatureType(enum.Enum):
27 | ZERO_DIM = 0 # Shape [x]
28 | ONE_DIM = 1 # Shape [num_res, x]
29 | TWO_DIM = 2 # Shape [num_res, num_res, x]
30 | MSA = 3 # Shape [msa_length, num_res, x]
31 |
32 |
33 | # Placeholder values that will be replaced with their true value at runtime.
34 | NUM_RES = "num residues placeholder"
35 | NUM_SEQ = "length msa placeholder"
36 | NUM_TEMPLATES = "num templates placeholder"
37 | # In the sizes of the protein features, NUM_RES and NUM_SEQ are allowed as
38 | # placeholders to be replaced with the number of residues and the number of
39 | # sequences in the multiple sequence alignment, respectively.
40 |
41 |
42 | FEATURES = {
43 | #### Static features of a protein sequence ####
44 | "aatype": (tf.float32, [NUM_RES, 21]),
45 | "between_segment_residues": (tf.int64, [NUM_RES, 1]),
46 | "deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]),
47 | "domain_name": (tf.string, [1]),
48 | "msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]),
49 | "num_alignments": (tf.int64, [NUM_RES, 1]),
50 | "residue_index": (tf.int64, [NUM_RES, 1]),
51 | "seq_length": (tf.int64, [NUM_RES, 1]),
52 | "sequence": (tf.string, [1]),
53 | "all_atom_positions": (tf.float32,
54 | [NUM_RES, residue_constants.atom_type_num, 3]),
55 | "all_atom_mask": (tf.int64, [NUM_RES, residue_constants.atom_type_num]),
56 | "resolution": (tf.float32, [1]),
57 | "template_domain_names": (tf.string, [NUM_TEMPLATES]),
58 | "template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]),
59 | "template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]),
60 | "template_all_atom_positions": (tf.float32, [
61 | NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3
62 | ]),
63 | "template_all_atom_masks": (tf.float32, [
64 | NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1
65 | ]),
66 | }
67 |
68 | FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()}
69 | FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()}
70 |
71 |
72 | def register_feature(name: str,
73 | type_: tf.dtypes.DType,
74 | shape_: Sequence[Union[str, int]]):
75 | """Register extra features used in custom datasets."""
76 | FEATURES[name] = (type_, shape_)
77 | FEATURE_TYPES[name] = type_
78 | FEATURE_SIZES[name] = shape_
79 |
80 |
81 | def shape(feature_name: str,
82 | num_residues: int,
83 | msa_length: int,
84 | num_templates: Optional[int] = None,
85 | features: Optional[FeaturesMetadata] = None):
86 | """Get the shape for the given feature name.
87 |
88 | This is nearly identical to _get_tf_shape_no_placeholders() but with two
89 | differences:
90 | * This method does not calculate a single placeholder from the total number of
91 | elements (e.g. given <NUM_RES, 3> and size := 12, this won't deduce NUM_RES
92 | must be 4)
93 | * This method will work with tensors
94 |
95 | Args:
96 | feature_name: String identifier for the feature. If the feature name ends
97 | with "_unnormalized", this suffix is stripped off.
98 | num_residues: The number of residues in the current domain - some elements
99 | of the shape can be dynamic and will be replaced by this value.
100 | msa_length: The number of sequences in the multiple sequence alignment; some
101 | elements of the shape can be dynamic and will be replaced by this value.
102 | If the number of alignments is unknown / not read, please pass None for
103 | msa_length.
104 | num_templates (optional): The number of templates in this tfexample.
105 | features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES.
106 |
107 | Returns:
108 | List of ints representing the tensor size.
109 |
110 | Raises:
111 | ValueError: If a feature is requested but no concrete placeholder value is
112 | given.
113 | """
114 | features = features or FEATURES
115 | if feature_name.endswith("_unnormalized"):
116 | feature_name = feature_name[:-13]
117 |
118 | unused_dtype, raw_sizes = features[feature_name]
119 | replacements = {NUM_RES: num_residues,
120 | NUM_SEQ: msa_length}
121 |
122 | if num_templates is not None:
123 | replacements[NUM_TEMPLATES] = num_templates
124 |
125 | sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]
126 | for dimension in sizes:
127 | if isinstance(dimension, str):
128 | raise ValueError("Could not parse %s (shape: %s) with values: %s" % (
129 | feature_name, raw_sizes, replacements))
130 | return sizes
131 |
--------------------------------------------------------------------------------
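
For concreteness, a small usage sketch of shape(): placeholders in the
registered feature shape are swapped for the concrete sizes passed in (the
values mirror those used in protein_features_test.py below):

    from alphafold.model.tf import protein_features

    # "msa" is registered as (tf.int64, [NUM_SEQ, NUM_RES, 1]), so the two
    # placeholders resolve to msa_length and num_residues respectively.
    print(protein_features.shape('msa', num_residues=12, msa_length=24))
    # -> [24, 12, 1]

    # Requesting a template feature without num_templates leaves a string
    # placeholder in the shape, which raises ValueError.
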
/src/alphafold/model/tf/protein_features_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for protein_features."""
16 | import uuid
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from alphafold.model.tf import protein_features
21 | import tensorflow.compat.v1 as tf
22 |
23 |
24 | def _random_bytes():
25 | return str(uuid.uuid4()).encode('utf-8')
26 |
27 |
28 | class FeaturesTest(parameterized.TestCase, tf.test.TestCase):
29 |
30 | def testFeatureNames(self):
31 | self.assertEqual(len(protein_features.FEATURE_SIZES),
32 | len(protein_features.FEATURE_TYPES))
33 | sorted_size_names = sorted(protein_features.FEATURE_SIZES.keys())
34 | sorted_type_names = sorted(protein_features.FEATURE_TYPES.keys())
35 | for i, size_name in enumerate(sorted_size_names):
36 | self.assertEqual(size_name, sorted_type_names[i])
37 |
38 | def testReplacement(self):
39 | for name in protein_features.FEATURE_SIZES.keys():
40 | sizes = protein_features.shape(name,
41 | num_residues=12,
42 | msa_length=24,
43 | num_templates=3)
44 | for x in sizes:
45 | self.assertEqual(type(x), int)
46 | self.assertGreater(x, 0)
47 |
48 |
49 | if __name__ == '__main__':
50 | tf.disable_v2_behavior()
51 | absltest.main()
52 |
--------------------------------------------------------------------------------
/src/alphafold/model/tf/proteins_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Datasets consisting of proteins."""
16 | from typing import Dict, Mapping, Optional, Sequence
17 | from alphafold.model.tf import protein_features
18 | import numpy as np
19 | import tensorflow.compat.v1 as tf
20 |
21 | TensorDict = Dict[str, tf.Tensor]
22 |
23 |
24 | def parse_tfexample(
25 | raw_data: bytes,
26 | features: protein_features.FeaturesMetadata,
27 | key: Optional[str] = None) -> Dict[str, tf.train.Feature]:
28 | """Read a single TF Example proto and return a subset of its features.
29 |
30 | Args:
31 | raw_data: A serialized tf.Example proto.
32 | features: A dictionary of features, mapping string feature names to a tuple
33 | (dtype, shape). This dictionary should be a subset of
34 | protein_features.FEATURES (or the dictionary itself for all features).
35 | key: Optional string with the SSTable key of that tf.Example. This will be
36 | added into features as a 'key' but only if requested in features.
37 |
38 | Returns:
39 | A dictionary of features mapping feature names to features. Only the given
40 | features are returned, all other ones are filtered out.
41 | """
42 | feature_map = {
43 | k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True)
44 | for k, v in features.items()
45 | }
46 | parsed_features = tf.io.parse_single_example(raw_data, feature_map)
47 | reshaped_features = parse_reshape_logic(parsed_features, features, key=key)
48 |
49 | return reshaped_features
50 |
51 |
52 | def _first(tensor: tf.Tensor) -> tf.Tensor:
53 | """Returns the 1st element - the input can be a tensor or a scalar."""
54 | return tf.reshape(tensor, shape=(-1,))[0]
55 |
56 |
57 | def parse_reshape_logic(
58 | parsed_features: TensorDict,
59 | features: protein_features.FeaturesMetadata,
60 | key: Optional[str] = None) -> TensorDict:
61 | """Transforms parsed serial features to the correct shape."""
62 | # Find out the number of residues and the number of alignments.
63 | num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32)
64 |
65 | if "num_alignments" in parsed_features:
66 | num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32)
67 | else:
68 | num_msa = 0
69 |
70 | num_templates = 0
71 | # if "template_domain_names" in parsed_features:
72 | # num_templates = tf.cast(
73 | # tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32)
74 | # else:
75 | # num_templates = 0
76 |
77 | if key is not None and "key" in features:
78 | parsed_features["key"] = [key] # Expand dims from () to (1,).
79 |
80 | # Reshape the tensors according to the sequence length and num alignments.
81 | for k, v in parsed_features.items():
82 | new_shape = protein_features.shape(
83 | feature_name=k,
84 | num_residues=num_residues,
85 | msa_length=num_msa,
86 | num_templates=num_templates,
87 | features=features)
88 | new_shape_size = tf.constant(1, dtype=tf.int32)
89 | for dim in new_shape:
90 | new_shape_size *= tf.cast(dim, tf.int32)
91 |
92 | assert_equal = tf.assert_equal(
93 | tf.size(v), new_shape_size,
94 | name="assert_%s_shape_correct" % k,
95 | message="The size of feature %s (%s) could not be reshaped "
96 | "into %s" % (k, tf.size(v), new_shape))
97 | if "template" not in k:
98 | # Make sure the feature we are reshaping is not empty.
99 | assert_non_empty = tf.assert_greater(
100 | tf.size(v), 0, name="assert_%s_non_empty" % k,
101 | message="The feature %s is not set in the tf.Example. Either do not "
102 | "request the feature or use a tf.Example that has the "
103 | "feature set." % k)
104 | with tf.control_dependencies([assert_non_empty, assert_equal]):
105 | parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)
106 | else:
107 | with tf.control_dependencies([assert_equal]):
108 | parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)
109 |
110 | return parsed_features
111 |
112 |
113 | def _make_features_metadata(
114 | feature_names: Sequence[str]) -> protein_features.FeaturesMetadata:
115 | """Makes a feature name to type and shape mapping from a list of names."""
116 | # Make sure these features are always read.
117 | required_features = ["aatype", "sequence", "seq_length"]
118 | feature_names = list(set(feature_names) | set(required_features))
119 |
120 | features_metadata = {name: protein_features.FEATURES[name]
121 | for name in feature_names}
122 | return features_metadata
123 |
124 |
125 | def create_tensor_dict(
126 | raw_data: bytes,
127 | features: Sequence[str],
128 | key: Optional[str] = None,
129 | ) -> TensorDict:
130 | """Creates a dictionary of tensor features.
131 |
132 | Args:
133 | raw_data: A serialized tf.Example proto.
134 | features: A list of strings of feature names to be returned in the dataset.
135 | key: Optional string with the SSTable key of that tf.Example. This will be
136 | added into features as a 'key' but only if requested in features.
137 |
138 | Returns:
139 | A dictionary of features mapping feature names to features. Only the given
140 | features are returned, all other ones are filtered out.
141 | """
142 | features_metadata = _make_features_metadata(features)
143 | return parse_tfexample(raw_data, features_metadata, key)
144 |
145 |
146 | def np_to_tensor_dict(
147 | np_example: Mapping[str, np.ndarray],
148 | features: Sequence[str],
149 | ) -> TensorDict:
150 | """Creates dict of tensors from a dict of NumPy arrays.
151 |
152 | Args:
153 | np_example: A dict of NumPy feature arrays.
154 | features: A list of strings of feature names to be returned in the dataset.
155 |
156 | Returns:
157 | A dictionary of features mapping feature names to features. Only the given
158 | features are returned, all other ones are filtered out.
159 | """
160 | features_metadata = _make_features_metadata(features)
161 | tensor_dict = {k: tf.constant(v) for k, v in np_example.items()
162 | if k in features_metadata}
163 |
164 | # Ensures shapes are as expected. Needed for setting size of empty features
165 | # e.g. when no template hits were found.
166 | tensor_dict = parse_reshape_logic(tensor_dict, features_metadata)
167 | return tensor_dict
168 |
--------------------------------------------------------------------------------
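
A minimal sketch of np_to_tensor_dict with hypothetical toy arrays (a real
pipeline passes the full feature dict produced upstream). The required
features are merged in automatically, and seq_length drives the reshape
logic:

    import numpy as np
    from alphafold.model.tf import proteins_dataset

    num_res = 2
    np_example = {
        'aatype': np.zeros((num_res, 21), np.float32),
        'sequence': np.array([b'MK'], dtype=np.object_),
        'seq_length': np.full((num_res, 1), num_res, dtype=np.int64),
    }
    # 'sequence' and 'seq_length' are always read, so all three features come
    # back, reshaped to [NUM_RES, 21], [1] and [NUM_RES, 1].
    tensors = proteins_dataset.np_to_tensor_dict(np_example, features=['aatype'])
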
/src/alphafold/model/tf/shape_helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for dealing with shapes of TensorFlow tensors."""
16 | import tensorflow.compat.v1 as tf
17 |
18 |
19 | def shape_list(x):
20 | """Return list of dimensions of a tensor, statically where possible.
21 |
22 | Like `x.shape.as_list()` but with tensors instead of `None`s.
23 |
24 | Args:
25 | x: A tensor.
26 | Returns:
27 | A list with length equal to the rank of the tensor. The n-th element of the
28 | list is an integer when that dimension is statically known otherwise it is
29 | the n-th element of `tf.shape(x)`.
30 | """
31 | x = tf.convert_to_tensor(x)
32 |
33 | # If unknown rank, return dynamic shape
34 | if x.get_shape().dims is None:
35 | return tf.shape(x)
36 |
37 | static = x.get_shape().as_list()
38 | shape = tf.shape(x)
39 |
40 | ret = []
41 | for i in range(len(static)):
42 | dim = static[i]
43 | if dim is None:
44 | dim = shape[i]
45 | ret.append(dim)
46 | return ret
47 |
48 |
--------------------------------------------------------------------------------
/src/alphafold/model/tf/shape_helpers_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for shape_helpers."""
16 |
17 | from alphafold.model.tf import shape_helpers
18 | import numpy as np
19 | import tensorflow.compat.v1 as tf
20 |
21 |
22 | class ShapeTest(tf.test.TestCase):
23 |
24 | def test_shape_list(self):
25 | """Test that shape_list can allow for reshaping to dynamic shapes."""
26 | a = tf.zeros([10, 4, 4, 2])
27 | p = tf.placeholder(tf.float32, shape=[None, None, 1, 4, 4])
28 | shape_dyn = shape_helpers.shape_list(p)[:2] + [4, 4]
29 |
30 | b = tf.reshape(a, shape_dyn)
31 | with self.session() as sess:
32 | out = sess.run(b, feed_dict={p: np.ones((20, 1, 1, 4, 4))})
33 |
34 | self.assertAllEqual(out.shape, (20, 1, 4, 4))
35 |
36 |
37 | if __name__ == '__main__':
38 | tf.disable_v2_behavior()
39 | tf.test.main()
40 |
--------------------------------------------------------------------------------
/src/alphafold/model/tf/shape_placeholders.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Placeholder values for run-time varying dimension sizes."""
16 |
17 | NUM_RES = 'num residues placeholder'
18 | NUM_MSA_SEQ = 'msa placeholder'
19 | NUM_EXTRA_SEQ = 'extra msa placeholder'
20 | NUM_TEMPLATES = 'num templates placeholder'
21 |
--------------------------------------------------------------------------------
/src/alphafold/model/tf/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Shared utilities for various components."""
16 | import tensorflow.compat.v1 as tf
17 |
18 |
19 | def tf_combine_mask(*masks):
20 | """Take the intersection of float-valued masks."""
21 | ret = 1
22 | for m in masks:
23 | ret *= m
24 | return ret
25 |
26 |
27 | class SeedMaker(object):
28 | """Return unique seeds."""
29 |
30 | def __init__(self, initial_seed=0):
31 | self.next_seed = initial_seed
32 |
33 | def __call__(self):
34 | i = self.next_seed
35 | self.next_seed += 1
36 | return i
37 |
38 | seed_maker = SeedMaker()
39 |
40 |
41 | def make_random_seed():
42 | return tf.random.uniform([2],
43 | tf.int32.min,
44 | tf.int32.max,
45 | tf.int32,
46 | seed=seed_maker())
47 |
48 |
--------------------------------------------------------------------------------
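
SeedMaker is just an incrementing counter, so every call site that builds a
random op gets a distinct, reproducible op-level seed. A tiny sketch (a fresh
instance is used here; the module-level seed_maker plays the same role for
make_random_seed):

    from alphafold.model.tf import utils

    sm = utils.SeedMaker()
    assert (sm(), sm(), sm()) == (0, 1, 2)
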
/src/alphafold/model/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A collection of JAX utility functions for use in protein folding."""
16 |
17 | import collections.abc as collections
18 | import numbers
19 | from typing import Mapping
20 |
21 | import haiku as hk
22 | import jax
23 | import jax.numpy as jnp
24 | import numpy as np
25 |
26 |
27 | def final_init(config):
28 | if config.zero_init:
29 | return 'zeros'
30 | else:
31 | return 'linear'
32 |
33 |
34 | def batched_gather(params, indices, axis=0, batch_dims=0):
35 | """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`."""
36 | take_fn = lambda p, i: jnp.take(p, i, axis=axis)
37 | for _ in range(batch_dims):
38 | take_fn = jax.vmap(take_fn)
39 | return take_fn(params, indices)
40 |
41 |
42 | def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
43 | """Masked mean."""
44 | if drop_mask_channel:
45 | mask = mask[..., 0]
46 |
47 | mask_shape = mask.shape
48 | value_shape = value.shape
49 |
50 | assert len(mask_shape) == len(value_shape)
51 |
52 | if isinstance(axis, numbers.Integral):
53 | axis = [axis]
54 | elif axis is None:
55 | axis = list(range(len(mask_shape)))
56 | assert isinstance(axis, collections.Iterable), (
57 | 'axis needs to be either an iterable, integer or "None"')
58 |
59 | broadcast_factor = 1.
60 | for axis_ in axis:
61 | value_size = value_shape[axis_]
62 | mask_size = mask_shape[axis_]
63 | if mask_size == 1:
64 | broadcast_factor *= value_size
65 | else:
66 | assert mask_size == value_size
67 |
68 | return (jnp.sum(mask * value, axis=axis) /
69 | (jnp.sum(mask, axis=axis) * broadcast_factor + eps))
70 |
71 |
72 | def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params:
73 | """Convert a dictionary of NumPy arrays to Haiku parameters."""
74 | hk_params = {}
75 | for path, array in params.items():
76 | scope, name = path.split('//')
77 | if scope not in hk_params:
78 | hk_params[scope] = {}
79 | hk_params[scope][name] = jnp.array(array)
80 |
81 | return hk_params
82 |
--------------------------------------------------------------------------------
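
A small numeric sketch of mask_mean's broadcast handling: when the mask has
size 1 along an axis where the value does not, the mask is broadcast and
broadcast_factor scales the denominator so the mean stays correct:

    import jax.numpy as jnp
    from alphafold.model import utils

    mask = jnp.array([[1.], [0.]])        # (2, 1): row 0 valid, row 1 masked
    value = jnp.array([[1., 2., 3.],
                       [10., 10., 10.]])  # (2, 3)

    # Axis 1 has mask_size 1 vs value_size 3, so broadcast_factor = 3:
    # sum(mask * value) = 6 and the denominator is sum(mask) * 3 = 3.
    print(utils.mask_mean(mask, value))   # ~2.0, the mean of the unmasked row
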
/src/alphafold/relax/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Amber relaxation."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/relax/amber_minimize_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for amber_minimize."""
16 | import os
17 |
18 | from absl.testing import absltest
19 | from alphafold.common import protein
20 | from alphafold.relax import amber_minimize
21 | import numpy as np
22 | # Internal import (7716).
23 |
24 |
25 | def _load_test_protein(data_path):
26 | pdb_path = os.path.join(absltest.get_default_test_srcdir(), data_path)
27 | with open(pdb_path, 'r') as f:
28 | return protein.from_pdb_string(f.read())
29 |
30 |
31 | class AmberMinimizeTest(absltest.TestCase):
32 |
33 | def test_multiple_disulfides_target(self):
34 | prot = _load_test_protein(
35 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb'
36 | )
37 | ret = amber_minimize.run_pipeline(prot, max_iterations=10, max_attempts=1,
38 | stiffness=10.)
39 | self.assertIn('opt_time', ret)
40 | self.assertIn('min_attempts', ret)
41 |
42 | def test_raises_invalid_protein_assertion(self):
43 | prot = _load_test_protein(
44 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb'
45 | )
46 | prot.atom_mask[4, :] = 0
47 | with self.assertRaisesRegex(
48 | ValueError,
49 | 'Amber minimization can only be performed on proteins with well-defined'
50 | ' residues. This protein contains at least one residue with no atoms.'):
51 | amber_minimize.run_pipeline(prot, max_iterations=10,
52 | stiffness=1.,
53 | max_attempts=1)
54 |
55 | def test_iterative_relax(self):
56 | prot = _load_test_protein(
57 | 'alphafold/relax/testdata/with_violations.pdb'
58 | )
59 | violations = amber_minimize.get_violation_metrics(prot)
60 | self.assertGreater(violations['num_residue_violations'], 0)
61 | out = amber_minimize.run_pipeline(
62 | prot=prot, max_outer_iterations=10, stiffness=10.)
63 | self.assertLess(out['efinal'], out['einit'])
64 | self.assertEqual(0, out['num_residue_violations'])
65 |
66 | def test_find_violations(self):
67 | prot = _load_test_protein(
68 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb'
69 | )
70 | viols, _ = amber_minimize.find_violations(prot)
71 |
72 | expected_between_residues_connection_mask = np.zeros((191,), np.float32)
73 | for residue in (42, 43, 59, 60, 135, 136):
74 | expected_between_residues_connection_mask[residue] = 1.0
75 |
76 | expected_clash_indices = np.array([
77 | [8, 4],
78 | [8, 5],
79 | [13, 3],
80 | [14, 1],
81 | [14, 4],
82 | [26, 4],
83 | [26, 5],
84 | [31, 8],
85 | [31, 10],
86 | [39, 0],
87 | [39, 1],
88 | [39, 2],
89 | [39, 3],
90 | [39, 4],
91 | [42, 5],
92 | [42, 6],
93 | [42, 7],
94 | [42, 8],
95 | [47, 7],
96 | [47, 8],
97 | [47, 9],
98 | [47, 10],
99 | [64, 4],
100 | [85, 5],
101 | [102, 4],
102 | [102, 5],
103 | [109, 13],
104 | [111, 5],
105 | [118, 6],
106 | [118, 7],
107 | [118, 8],
108 | [124, 4],
109 | [124, 5],
110 | [131, 5],
111 | [139, 7],
112 | [147, 4],
113 | [152, 7]], dtype=np.int32)
114 | expected_between_residues_clash_mask = np.zeros([191, 14])
115 | expected_between_residues_clash_mask[expected_clash_indices[:, 0],
116 | expected_clash_indices[:, 1]] += 1
117 | expected_per_atom_violations = np.zeros([191, 14])
118 | np.testing.assert_array_equal(
119 | viols['between_residues']['connections_per_residue_violation_mask'],
120 | expected_between_residues_connection_mask)
121 | np.testing.assert_array_equal(
122 | viols['between_residues']['clashes_per_atom_clash_mask'],
123 | expected_between_residues_clash_mask)
124 | np.testing.assert_array_equal(
125 | viols['within_residues']['per_atom_violations'],
126 | expected_per_atom_violations)
127 |
128 |
129 | if __name__ == '__main__':
130 | absltest.main()
131 |
--------------------------------------------------------------------------------
/src/alphafold/relax/cleanup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
16 |
17 | fix_pdb uses a third-party tool. We also support fixing some additional edge
18 | cases like removing chains of length one (see clean_structure).
19 | """
20 | import io
21 |
22 | import pdbfixer
23 | from simtk.openmm import app
24 | from simtk.openmm.app import element
25 |
26 |
27 | def fix_pdb(pdbfile, alterations_info):
28 | """Apply pdbfixer to the contents of a PDB file; return a PDB string result.
29 |
30 | 1) Replaces nonstandard residues.
31 | 2) Removes heterogens (non protein residues) including water.
32 | 3) Adds missing residues and missing atoms within existing residues.
33 | 4) Adds hydrogens assuming pH=7.0.
34 | 5) KeepIds is currently true, so the fixer must keep the existing chain and
35 | residue identifiers. This will fail for some files in the wider PDB that
36 | have invalid IDs.
37 |
38 | Args:
39 | pdbfile: Input PDB file handle.
40 | alterations_info: A dict that will store details of changes made.
41 |
42 | Returns:
43 | A PDB string representing the fixed structure.
44 | """
45 | fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
46 | fixer.findNonstandardResidues()
47 | alterations_info['nonstandard_residues'] = fixer.nonstandardResidues
48 | fixer.replaceNonstandardResidues()
49 | _remove_heterogens(fixer, alterations_info, keep_water=False)
50 | fixer.findMissingResidues()
51 | alterations_info['missing_residues'] = fixer.missingResidues
52 | fixer.findMissingAtoms()
53 | alterations_info['missing_heavy_atoms'] = fixer.missingAtoms
54 | alterations_info['missing_terminals'] = fixer.missingTerminals
55 | fixer.addMissingAtoms(seed=0)
56 | fixer.addMissingHydrogens()
57 | out_handle = io.StringIO()
58 | app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle,
59 | keepIds=True)
60 | return out_handle.getvalue()
61 |
62 |
63 | def clean_structure(pdb_structure, alterations_info):
64 | """Applies additional fixes to an OpenMM structure, to handle edge cases.
65 |
66 | Args:
67 | pdb_structure: An OpenMM structure to modify and fix.
68 | alterations_info: A dict that will store details of changes made.
69 | """
70 | _replace_met_se(pdb_structure, alterations_info)
71 | _remove_chains_of_length_one(pdb_structure, alterations_info)
72 |
73 |
74 | def _remove_heterogens(fixer, alterations_info, keep_water):
75 | """Removes the residues that Pdbfixer considers to be heterogens.
76 |
77 | Args:
78 | fixer: A Pdbfixer instance.
79 | alterations_info: A dict that will store details of changes made.
80 | keep_water: If True, water (HOH) is not considered to be a heterogen.
81 | """
82 | initial_resnames = set()
83 | for chain in fixer.topology.chains():
84 | for residue in chain.residues():
85 | initial_resnames.add(residue.name)
86 | fixer.removeHeterogens(keepWater=keep_water)
87 | final_resnames = set()
88 | for chain in fixer.topology.chains():
89 | for residue in chain.residues():
90 | final_resnames.add(residue.name)
91 | alterations_info['removed_heterogens'] = (
92 | initial_resnames.difference(final_resnames))
93 |
94 |
95 | def _replace_met_se(pdb_structure, alterations_info):
96 | """Replace the Se in any MET residues that were not marked as modified."""
97 | modified_met_residues = []
98 | for res in pdb_structure.iter_residues():
99 | name = res.get_name_with_spaces().strip()
100 | if name == 'MET':
101 | s_atom = res.get_atom('SD')
102 | if s_atom.element_symbol == 'Se':
103 | s_atom.element_symbol = 'S'
104 | s_atom.element = element.get_by_symbol('S')
105 | modified_met_residues.append(s_atom.residue_number)
106 | alterations_info['Se_in_MET'] = modified_met_residues
107 |
108 |
109 | def _remove_chains_of_length_one(pdb_structure, alterations_info):
110 | """Removes chains that correspond to a single amino acid.
111 |
112 | A single amino acid in a chain is both N and C terminus. There is no force
113 | template for this case.
114 |
115 | Args:
116 | pdb_structure: An OpenMM pdb_structure to modify and fix.
117 | alterations_info: A dict that will store details of changes made.
118 | """
119 | removed_chains = {}
120 | for model in pdb_structure.iter_models():
121 | valid_chains = [c for c in model.iter_chains() if len(c) > 1]
122 | invalid_chain_ids = [c.chain_id for c in model.iter_chains() if len(c) <= 1]
123 | model.chains = valid_chains
124 | for chain_id in invalid_chain_ids:
125 | model.chains_by_id.pop(chain_id)
126 | removed_chains[model.number] = invalid_chain_ids
127 | alterations_info['removed_chains'] = removed_chains
128 |
--------------------------------------------------------------------------------
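
Putting the two entry points together: the intended flow is fix_pdb on the
raw file handle, then clean_structure on the parsed result. A minimal sketch
(hypothetical input path; this mirrors how the tests below drive the module):

    import io
    from alphafold.relax import cleanup
    from simtk.openmm.app.internal import pdbstructure

    alterations = {}
    with open('input.pdb') as f:  # hypothetical path
      fixed_pdb = cleanup.fix_pdb(f, alterations)

    structure = pdbstructure.PdbStructure(io.StringIO(fixed_pdb))
    cleanup.clean_structure(structure, alterations)
    # alterations now records nonstandard residues, removed heterogens,
    # missing atoms/terminals, Se->S swaps and removed length-one chains.
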
/src/alphafold/relax/cleanup_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for relax.cleanup."""
16 | import io
17 |
18 | from absl.testing import absltest
19 | from alphafold.relax import cleanup
20 | from simtk.openmm.app.internal import pdbstructure
21 |
22 |
23 | def _pdb_to_structure(pdb_str):
24 | handle = io.StringIO(pdb_str)
25 | return pdbstructure.PdbStructure(handle)
26 |
27 |
28 | def _lines_to_structure(pdb_lines):
29 | return _pdb_to_structure('\n'.join(pdb_lines))
30 |
31 |
32 | class CleanupTest(absltest.TestCase):
33 |
34 | def test_missing_residues(self):
35 | pdb_lines = ['SEQRES 1 C 3 CYS GLY LEU',
36 | 'ATOM 1 N CYS C 1 -12.262 20.115 60.959 1.00 '
37 | '19.08 N',
38 | 'ATOM 2 CA CYS C 1 -11.065 20.934 60.773 1.00 '
39 | '17.23 C',
40 | 'ATOM 3 C CYS C 1 -10.002 20.742 61.844 1.00 '
41 | '15.38 C',
42 | 'ATOM 4 O CYS C 1 -10.284 20.225 62.929 1.00 '
43 | '16.04 O',
44 | 'ATOM 5 N LEU C 3 -7.688 18.700 62.045 1.00 '
45 | '14.75 N',
46 | 'ATOM 6 CA LEU C 3 -7.256 17.320 62.234 1.00 '
47 | '16.81 C',
48 | 'ATOM 7 C LEU C 3 -6.380 16.864 61.070 1.00 '
49 | '16.95 C',
50 | 'ATOM 8 O LEU C 3 -6.551 17.332 59.947 1.00 '
51 | '16.97 O']
52 | input_handle = io.StringIO('\n'.join(pdb_lines))
53 | alterations = {}
54 | result = cleanup.fix_pdb(input_handle, alterations)
55 | structure = _pdb_to_structure(result)
56 | residue_names = [r.get_name() for r in structure.iter_residues()]
57 | self.assertCountEqual(residue_names, ['CYS', 'GLY', 'LEU'])
58 | self.assertCountEqual(alterations['missing_residues'].values(), [['GLY']])
59 |
60 | def test_missing_atoms(self):
61 | pdb_lines = ['SEQRES 1 A 1 PRO',
62 | 'ATOM 1 CA PRO A 1 1.000 1.000 1.000 1.00 '
63 | ' 0.00 C']
64 | input_handle = io.StringIO('\n'.join(pdb_lines))
65 | alterations = {}
66 | result = cleanup.fix_pdb(input_handle, alterations)
67 | structure = _pdb_to_structure(result)
68 | atom_names = [a.get_name() for a in structure.iter_atoms()]
69 | self.assertCountEqual(atom_names, ['N', 'CD', 'HD2', 'HD3', 'CG', 'HG2',
70 | 'HG3', 'CB', 'HB2', 'HB3', 'CA', 'HA',
71 | 'C', 'O', 'H2', 'H3', 'OXT'])
72 | missing_atoms_by_residue = list(alterations['missing_heavy_atoms'].values())
73 | self.assertLen(missing_atoms_by_residue, 1)
74 | atoms_added = [a.name for a in missing_atoms_by_residue[0]]
75 | self.assertCountEqual(atoms_added, ['N', 'CD', 'CG', 'CB', 'C', 'O'])
76 | missing_terminals_by_residue = alterations['missing_terminals']
77 | self.assertLen(missing_terminals_by_residue, 1)
78 | has_missing_terminal = [r.name for r in missing_terminals_by_residue.keys()]
79 | self.assertCountEqual(has_missing_terminal, ['PRO'])
80 | self.assertCountEqual([t for t in missing_terminals_by_residue.values()],
81 | [['OXT']])
82 |
83 | def test_remove_heterogens(self):
84 | pdb_lines = ['SEQRES 1 A 1 GLY',
85 | 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 '
86 | ' 0.00 C',
87 | 'ATOM 2 O HOH A 2 0.000 0.000 0.000 1.00 '
88 | ' 0.00 O']
89 | input_handle = io.StringIO('\n'.join(pdb_lines))
90 | alterations = {}
91 | result = cleanup.fix_pdb(input_handle, alterations)
92 | structure = _pdb_to_structure(result)
93 | self.assertCountEqual([res.get_name() for res in structure.iter_residues()],
94 | ['GLY'])
95 | self.assertEqual(alterations['removed_heterogens'], set(['HOH']))
96 |
97 | def test_fix_nonstandard_residues(self):
98 | pdb_lines = ['SEQRES 1 A 1 DAL',
99 | 'ATOM 1 CA DAL A 1 0.000 0.000 0.000 1.00 '
100 | ' 0.00 C']
101 | input_handle = io.StringIO('\n'.join(pdb_lines))
102 | alterations = {}
103 | result = cleanup.fix_pdb(input_handle, alterations)
104 | structure = _pdb_to_structure(result)
105 | residue_names = [res.get_name() for res in structure.iter_residues()]
106 | self.assertCountEqual(residue_names, ['ALA'])
107 | self.assertLen(alterations['nonstandard_residues'], 1)
108 | original_res, new_name = alterations['nonstandard_residues'][0]
109 | self.assertEqual(original_res.id, '1')
110 | self.assertEqual(new_name, 'ALA')
111 |
112 | def test_replace_met_se(self):
113 | pdb_lines = ['SEQRES 1 A 1 MET',
114 | 'ATOM 1 SD MET A 1 0.000 0.000 0.000 1.00 '
115 | ' 0.00 Se']
116 | structure = _lines_to_structure(pdb_lines)
117 | alterations = {}
118 | cleanup._replace_met_se(structure, alterations)
119 | sd = [a for a in structure.iter_atoms() if a.get_name() == 'SD']
120 | self.assertLen(sd, 1)
121 | self.assertEqual(sd[0].element_symbol, 'S')
122 | self.assertCountEqual(alterations['Se_in_MET'], [sd[0].residue_number])
123 |
124 | def test_remove_chains_of_length_one(self):
125 | pdb_lines = ['SEQRES 1 A 1 GLY',
126 | 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 '
127 | ' 0.00 C']
128 | structure = _lines_to_structure(pdb_lines)
129 | alterations = {}
130 | cleanup._remove_chains_of_length_one(structure, alterations)
131 | chains = list(structure.iter_chains())
132 | self.assertEmpty(chains)
133 | self.assertCountEqual(alterations['removed_chains'].values(), [['A']])
134 |
135 |
136 | if __name__ == '__main__':
137 | absltest.main()
138 |
--------------------------------------------------------------------------------
/src/alphafold/relax/relax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Amber relaxation."""
16 | from typing import Any, Dict, Sequence, Tuple
17 | from alphafold.common import protein
18 | from alphafold.relax import amber_minimize
19 | from alphafold.relax import utils
20 | import numpy as np
21 |
22 |
23 | class AmberRelaxation(object):
24 | """Amber relaxation."""
25 |
26 | def __init__(self,
27 | *,
28 | max_iterations: int,
29 | tolerance: float,
30 | stiffness: float,
31 | exclude_residues: Sequence[int],
32 | max_outer_iterations: int):
33 | """Initialize Amber Relaxer.
34 |
35 | Args:
36 | max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
37 | tolerance: kcal/mol, the energy tolerance of L-BFGS.
38 | stiffness: kcal/mol A**2, spring constant of heavy atom restraining
39 | potential.
40 | exclude_residues: Residues to exclude from per-atom restraining.
41 | Zero-indexed.
42 | max_outer_iterations: Maximum number of violation-informed relax
43 | iterations. A value of 1 will run the non-iterative procedure used in
44 | CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
45 | as soon as there are no violations, hence in most cases this causes no
46 | slowdown. In the worst case we do 20 outer iterations.
47 | """
48 |
49 | self._max_iterations = max_iterations
50 | self._tolerance = tolerance
51 | self._stiffness = stiffness
52 | self._exclude_residues = exclude_residues
53 | self._max_outer_iterations = max_outer_iterations
54 |
55 | def process(self, *,
56 | prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]:
57 | """Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
58 | out = amber_minimize.run_pipeline(
59 | prot=prot, max_iterations=self._max_iterations,
60 | tolerance=self._tolerance, stiffness=self._stiffness,
61 | exclude_residues=self._exclude_residues,
62 | max_outer_iterations=self._max_outer_iterations)
63 | min_pos = out['pos']
64 | start_pos = out['posinit']
65 | rmsd = np.sqrt(np.sum((start_pos - min_pos)**2) / start_pos.shape[0])
66 | debug_data = {
67 | 'initial_energy': out['einit'],
68 | 'final_energy': out['efinal'],
69 | 'attempts': out['min_attempts'],
70 | 'rmsd': rmsd
71 | }
72 | pdb_str = amber_minimize.clean_protein(prot)
73 | min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos)
74 | min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
75 | utils.assert_equal_nonterminal_atom_types(
76 | protein.from_pdb_string(min_pdb).atom_mask,
77 | prot.atom_mask)
78 | violations = out['structural_violations'][
79 | 'total_per_residue_violations_mask']
80 | return min_pdb, debug_data, violations
81 |
--------------------------------------------------------------------------------
/src/alphafold/relax/relax_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for relax."""
16 | import os
17 |
18 | from absl.testing import absltest
19 | from alphafold.common import protein
20 | from alphafold.relax import relax
21 | import numpy as np
22 | # Internal import (7716).
23 |
24 |
25 | class RunAmberRelaxTest(absltest.TestCase):
26 |
27 | def setUp(self):
28 | super().setUp()
29 | self.test_dir = os.path.join(
30 | absltest.get_default_test_srcdir(),
31 | 'alphafold/relax/testdata/')
32 | self.test_config = {
33 | 'max_iterations': 1,
34 | 'tolerance': 2.39,
35 | 'stiffness': 10.0,
36 | 'exclude_residues': [],
37 | 'max_outer_iterations': 1}
38 |
39 | def test_process(self):
40 | amber_relax = relax.AmberRelaxation(**self.test_config)
41 |
42 | with open(os.path.join(self.test_dir, 'model_output.pdb')) as f:
43 | test_prot = protein.from_pdb_string(f.read())
44 | pdb_min, debug_info, num_violations = amber_relax.process(prot=test_prot)
45 |
46 | self.assertCountEqual(debug_info.keys(),
47 | set({'initial_energy', 'final_energy',
48 | 'attempts', 'rmsd'}))
49 | self.assertLess(debug_info['final_energy'], debug_info['initial_energy'])
50 | self.assertGreater(debug_info['rmsd'], 0)
51 |
52 | prot_min = protein.from_pdb_string(pdb_min)
53 | # Most protein properties should be unchanged.
54 | np.testing.assert_almost_equal(test_prot.aatype, prot_min.aatype)
55 | np.testing.assert_almost_equal(test_prot.residue_index,
56 | prot_min.residue_index)
57 | # Atom mask and bfactors identical except for terminal OXT of last residue.
58 | np.testing.assert_almost_equal(test_prot.atom_mask[:-1, :],
59 | prot_min.atom_mask[:-1, :])
60 | np.testing.assert_almost_equal(test_prot.b_factors[:-1, :],
61 | prot_min.b_factors[:-1, :])
62 | np.testing.assert_almost_equal(test_prot.atom_mask[:, :-1],
63 | prot_min.atom_mask[:, :-1])
64 | np.testing.assert_almost_equal(test_prot.b_factors[:, :-1],
65 | prot_min.b_factors[:, :-1])
66 | # There are no residues with violations.
67 | np.testing.assert_equal(num_violations, np.zeros_like(num_violations))
68 |
69 | def test_unresolved_violations(self):
70 | amber_relax = relax.AmberRelaxation(**self.test_config)
71 | with open(os.path.join(self.test_dir,
72 | 'with_violations_casp14.pdb')) as f:
73 | test_prot = protein.from_pdb_string(f.read())
74 | _, _, num_violations = amber_relax.process(prot=test_prot)
75 | exp_num_violations = np.array(
76 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
77 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
78 | 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
79 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
80 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
81 | 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
82 | 0, 0, 0, 0])
83 | # Check no violations were added. Can't check exactly due to stochasticity.
84 | self.assertTrue(np.all(num_violations <= exp_num_violations))
85 |
86 |
87 | if __name__ == '__main__':
88 | absltest.main()
89 |
--------------------------------------------------------------------------------
/src/alphafold/relax/testdata/model_output.pdb:
--------------------------------------------------------------------------------
1 | ATOM 1 C MET A 1 1.921 -46.152 7.786 1.00 4.39 C
2 | ATOM 2 CA MET A 1 1.631 -46.829 9.131 1.00 4.39 C
3 | ATOM 3 CB MET A 1 2.759 -47.768 9.578 1.00 4.39 C
4 | ATOM 4 CE MET A 1 3.466 -49.770 13.198 1.00 4.39 C
5 | ATOM 5 CG MET A 1 2.581 -48.221 11.034 1.00 4.39 C
6 | ATOM 6 H MET A 1 0.234 -48.249 8.549 1.00 4.39 H
7 | ATOM 7 H2 MET A 1 -0.424 -46.789 8.952 1.00 4.39 H
8 | ATOM 8 H3 MET A 1 0.111 -47.796 10.118 1.00 4.39 H
9 | ATOM 9 HA MET A 1 1.628 -46.009 9.849 1.00 4.39 H
10 | ATOM 10 HB2 MET A 1 3.701 -47.225 9.500 1.00 4.39 H
11 | ATOM 11 HB3 MET A 1 2.807 -48.640 8.926 1.00 4.39 H
12 | ATOM 12 HE1 MET A 1 2.747 -50.537 12.910 1.00 4.39 H
13 | ATOM 13 HE2 MET A 1 4.296 -50.241 13.725 1.00 4.39 H
14 | ATOM 14 HE3 MET A 1 2.988 -49.052 13.864 1.00 4.39 H
15 | ATOM 15 HG2 MET A 1 1.791 -48.971 11.083 1.00 4.39 H
16 | ATOM 16 HG3 MET A 1 2.295 -47.368 11.650 1.00 4.39 H
17 | ATOM 17 N MET A 1 0.291 -47.464 9.182 1.00 4.39 N
18 | ATOM 18 O MET A 1 2.091 -44.945 7.799 1.00 4.39 O
19 | ATOM 19 SD MET A 1 4.096 -48.921 11.725 1.00 4.39 S
20 | ATOM 20 C LYS A 2 1.366 -45.033 4.898 1.00 2.92 C
21 | ATOM 21 CA LYS A 2 2.235 -46.242 5.308 1.00 2.92 C
22 | ATOM 22 CB LYS A 2 2.206 -47.314 4.196 1.00 2.92 C
23 | ATOM 23 CD LYS A 2 3.331 -49.342 3.134 1.00 2.92 C
24 | ATOM 24 CE LYS A 2 4.434 -50.403 3.293 1.00 2.92 C
25 | ATOM 25 CG LYS A 2 3.294 -48.395 4.349 1.00 2.92 C
26 | ATOM 26 H LYS A 2 1.832 -47.853 6.656 1.00 2.92 H
27 | ATOM 27 HA LYS A 2 3.248 -45.841 5.355 1.00 2.92 H
28 | ATOM 28 HB2 LYS A 2 1.223 -47.785 4.167 1.00 2.92 H
29 | ATOM 29 HB3 LYS A 2 2.363 -46.812 3.241 1.00 2.92 H
30 | ATOM 30 HD2 LYS A 2 3.524 -48.754 2.237 1.00 2.92 H
31 | ATOM 31 HD3 LYS A 2 2.364 -49.833 3.031 1.00 2.92 H
32 | ATOM 32 HE2 LYS A 2 5.383 -49.891 3.455 1.00 2.92 H
33 | ATOM 33 HE3 LYS A 2 4.225 -51.000 4.180 1.00 2.92 H
34 | ATOM 34 HG2 LYS A 2 3.102 -48.977 5.250 1.00 2.92 H
35 | ATOM 35 HG3 LYS A 2 4.264 -47.909 4.446 1.00 2.92 H
36 | ATOM 36 HZ1 LYS A 2 4.763 -50.747 1.274 1.00 2.92 H
37 | ATOM 37 HZ2 LYS A 2 3.681 -51.785 1.931 1.00 2.92 H
38 | ATOM 38 HZ3 LYS A 2 5.280 -51.965 2.224 1.00 2.92 H
39 | ATOM 39 N LYS A 2 1.907 -46.846 6.629 1.00 2.92 N
40 | ATOM 40 NZ LYS A 2 4.542 -51.286 2.100 1.00 2.92 N
41 | ATOM 41 O LYS A 2 1.882 -44.093 4.312 1.00 2.92 O
42 | ATOM 42 C PHE A 3 -0.511 -42.597 5.624 1.00 4.39 C
43 | ATOM 43 CA PHE A 3 -0.853 -43.933 4.929 1.00 4.39 C
44 | ATOM 44 CB PHE A 3 -2.271 -44.408 5.285 1.00 4.39 C
45 | ATOM 45 CD1 PHE A 3 -3.760 -43.542 3.432 1.00 4.39 C
46 | ATOM 46 CD2 PHE A 3 -4.050 -42.638 5.675 1.00 4.39 C
47 | ATOM 47 CE1 PHE A 3 -4.797 -42.715 2.965 1.00 4.39 C
48 | ATOM 48 CE2 PHE A 3 -5.091 -41.818 5.207 1.00 4.39 C
49 | ATOM 49 CG PHE A 3 -3.382 -43.505 4.788 1.00 4.39 C
50 | ATOM 50 CZ PHE A 3 -5.463 -41.853 3.853 1.00 4.39 C
51 | ATOM 51 H PHE A 3 -0.311 -45.868 5.655 1.00 4.39 H
52 | ATOM 52 HA PHE A 3 -0.817 -43.746 3.856 1.00 4.39 H
53 | ATOM 53 HB2 PHE A 3 -2.353 -44.512 6.367 1.00 4.39 H
54 | ATOM 54 HB3 PHE A 3 -2.432 -45.393 4.848 1.00 4.39 H
55 | ATOM 55 HD1 PHE A 3 -3.255 -44.198 2.739 1.00 4.39 H
56 | ATOM 56 HD2 PHE A 3 -3.768 -42.590 6.716 1.00 4.39 H
57 | ATOM 57 HE1 PHE A 3 -5.083 -42.735 1.923 1.00 4.39 H
58 | ATOM 58 HE2 PHE A 3 -5.604 -41.151 5.885 1.00 4.39 H
59 | ATOM 59 HZ PHE A 3 -6.257 -41.215 3.493 1.00 4.39 H
60 | ATOM 60 N PHE A 3 0.079 -45.027 5.253 1.00 4.39 N
61 | ATOM 61 O PHE A 3 -0.633 -41.541 5.014 1.00 4.39 O
62 | ATOM 62 C LEU A 4 1.598 -40.732 7.042 1.00 4.39 C
63 | ATOM 63 CA LEU A 4 0.367 -41.437 7.633 1.00 4.39 C
64 | ATOM 64 CB LEU A 4 0.628 -41.823 9.104 1.00 4.39 C
65 | ATOM 65 CD1 LEU A 4 -0.319 -42.778 11.228 1.00 4.39 C
66 | ATOM 66 CD2 LEU A 4 -1.300 -40.694 10.309 1.00 4.39 C
67 | ATOM 67 CG LEU A 4 -0.650 -42.027 9.937 1.00 4.39 C
68 | ATOM 68 H LEU A 4 0.163 -43.538 7.292 1.00 4.39 H
69 | ATOM 69 HA LEU A 4 -0.445 -40.712 7.588 1.00 4.39 H
70 | ATOM 70 HB2 LEU A 4 1.213 -41.034 9.576 1.00 4.39 H
71 | ATOM 71 HB3 LEU A 4 1.235 -42.728 9.127 1.00 4.39 H
72 | ATOM 72 HD11 LEU A 4 0.380 -42.191 11.824 1.00 4.39 H
73 | ATOM 73 HD12 LEU A 4 0.127 -43.747 11.002 1.00 4.39 H
74 | ATOM 74 HD13 LEU A 4 -1.230 -42.927 11.808 1.00 4.39 H
75 | ATOM 75 HD21 LEU A 4 -0.606 -40.080 10.883 1.00 4.39 H
76 | ATOM 76 HD22 LEU A 4 -2.193 -40.869 10.909 1.00 4.39 H
77 | ATOM 77 HD23 LEU A 4 -1.593 -40.147 9.413 1.00 4.39 H
78 | ATOM 78 HG LEU A 4 -1.359 -42.630 9.370 1.00 4.39 H
79 | ATOM 79 N LEU A 4 -0.012 -42.638 6.869 1.00 4.39 N
80 | ATOM 80 O LEU A 4 1.655 -39.508 7.028 1.00 4.39 O
81 | ATOM 81 C VAL A 5 3.372 -40.190 4.573 1.00 4.39 C
82 | ATOM 82 CA VAL A 5 3.752 -40.956 5.845 1.00 4.39 C
83 | ATOM 83 CB VAL A 5 4.757 -42.083 5.528 1.00 4.39 C
84 | ATOM 84 CG1 VAL A 5 6.019 -41.568 4.827 1.00 4.39 C
85 | ATOM 85 CG2 VAL A 5 5.199 -42.807 6.810 1.00 4.39 C
86 | ATOM 86 H VAL A 5 2.440 -42.503 6.548 1.00 4.39 H
87 | ATOM 87 HA VAL A 5 4.234 -40.242 6.512 1.00 4.39 H
88 | ATOM 88 HB VAL A 5 4.279 -42.813 4.875 1.00 4.39 H
89 | ATOM 89 HG11 VAL A 5 6.494 -40.795 5.431 1.00 4.39 H
90 | ATOM 90 HG12 VAL A 5 5.770 -41.145 3.853 1.00 4.39 H
91 | ATOM 91 HG13 VAL A 5 6.725 -42.383 4.670 1.00 4.39 H
92 | ATOM 92 HG21 VAL A 5 4.347 -43.283 7.297 1.00 4.39 H
93 | ATOM 93 HG22 VAL A 5 5.933 -43.575 6.568 1.00 4.39 H
94 | ATOM 94 HG23 VAL A 5 5.651 -42.093 7.498 1.00 4.39 H
95 | ATOM 95 N VAL A 5 2.554 -41.501 6.509 1.00 4.39 N
96 | ATOM 96 O VAL A 5 3.937 -39.138 4.297 1.00 4.39 O
97 | TER 96 VAL A 5
98 | END
99 |
--------------------------------------------------------------------------------
/src/alphafold/relax/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utils for minimization."""
16 | import io
17 | from alphafold.common import residue_constants
18 | from Bio import PDB
19 | import numpy as np
20 | from simtk.openmm import app as openmm_app
21 | from simtk.openmm.app.internal.pdbstructure import PdbStructure
22 |
23 |
24 | def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
25 | pdb_file = io.StringIO(pdb_str)
26 | structure = PdbStructure(pdb_file)
27 | topology = openmm_app.PDBFile(structure).getTopology()
28 | with io.StringIO() as f:
29 | openmm_app.PDBFile.writeFile(topology, pos, f)
30 | return f.getvalue()
31 |
32 |
33 | def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
34 | """Overwrites the B-factors in pdb_str with contents of bfactors array.
35 |
36 | Args:
37 | pdb_str: An input PDB string.
38 |     bfactors: A numpy array with shape [n_residues, 37]. We assume that the
39 |       B-factors are per residue; i.e. that the nonzero entries are identical in
40 |       [i, :].
41 |
42 | Returns:
43 | A new PDB string with the B-factors replaced.
44 | """
45 | if bfactors.shape[-1] != residue_constants.atom_type_num:
46 | raise ValueError(
47 | f'Invalid final dimension size for bfactors: {bfactors.shape[-1]}.')
48 |
49 | parser = PDB.PDBParser(QUIET=True)
50 | handle = io.StringIO(pdb_str)
51 | structure = parser.get_structure('', handle)
52 |
53 | curr_resid = ('', '', '')
54 | idx = -1
55 | for atom in structure.get_atoms():
56 | atom_resid = atom.parent.get_id()
57 | if atom_resid != curr_resid:
58 | idx += 1
59 | if idx >= bfactors.shape[0]:
60 |         raise ValueError('Index into bfactors exceeds number of residues. '
61 |                          f'B-factors shape: {bfactors.shape}, idx: {idx}.')
62 | curr_resid = atom_resid
63 | atom.bfactor = bfactors[idx, residue_constants.atom_order['CA']]
64 |
65 | new_pdb = io.StringIO()
66 | pdb_io = PDB.PDBIO()
67 | pdb_io.set_structure(structure)
68 | pdb_io.save(new_pdb)
69 | return new_pdb.getvalue()
70 |
71 |
72 | def assert_equal_nonterminal_atom_types(
73 | atom_mask: np.ndarray, ref_atom_mask: np.ndarray):
74 | """Checks that pre- and post-minimized proteins have same atom set."""
75 | # Ignore any terminal OXT atoms which may have been added by minimization.
76 | oxt = residue_constants.atom_order['OXT']
77 |   no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
78 | no_oxt_mask[..., oxt] = False
79 | np.testing.assert_almost_equal(ref_atom_mask[no_oxt_mask],
80 | atom_mask[no_oxt_mask])
81 |
--------------------------------------------------------------------------------
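A minimal sketch of `overwrite_b_factors`, e.g. for stamping per-residue confidence values into a structure. It assumes `pred.pdb` (a placeholder) contains exactly `n_res` residues; the function reads the CA column of each row and writes that value to every atom of the residue.

import numpy as np
from alphafold.common import residue_constants
from alphafold.relax import utils

with open('pred.pdb') as f:
    pdb_str = f.read()

n_res = 191  #Must equal the number of residues in pdb_str
plddt = np.linspace(0.0, 1.0, n_res)  #Hypothetical per-residue scores
#One row of 37 atom-type entries per residue, identical within each row
bfactors = np.repeat(plddt[:, None], residue_constants.atom_type_num, axis=-1)
new_pdb = utils.overwrite_b_factors(pdb_str, bfactors)
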
/src/alphafold/relax/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for utils."""
16 |
17 | import os
18 |
19 | from absl.testing import absltest
20 | from alphafold.common import protein
21 | from alphafold.relax import utils
22 | import numpy as np
23 | # Internal import (7716).
24 |
25 |
26 | class UtilsTest(absltest.TestCase):
27 |
28 | def test_overwrite_b_factors(self):
29 | testdir = os.path.join(
30 | absltest.get_default_test_srcdir(),
31 | 'alphafold/relax/testdata/'
32 | 'multiple_disulfides_target.pdb')
33 | with open(testdir) as f:
34 | test_pdb = f.read()
35 | n_residues = 191
36 | bfactors = np.stack([np.arange(0, n_residues)] * 37, axis=-1)
37 |
38 | output_pdb = utils.overwrite_b_factors(test_pdb, bfactors)
39 |
40 | # Check that the atom lines are unchanged apart from the B-factors.
41 |     atom_lines_original = [l for l in test_pdb.split('\n') if l[:4] == 'ATOM']
42 |     atom_lines_new = [l for l in output_pdb.split('\n') if l[:4] == 'ATOM']
43 | for line_original, line_new in zip(atom_lines_original, atom_lines_new):
44 | self.assertEqual(line_original[:60].strip(), line_new[:60].strip())
45 | self.assertEqual(line_original[66:].strip(), line_new[66:].strip())
46 |
47 | # Check B-factors are correctly set for all atoms present.
48 | as_protein = protein.from_pdb_string(output_pdb)
49 | np.testing.assert_almost_equal(
50 | np.where(as_protein.atom_mask > 0, as_protein.b_factors, 0),
51 | np.where(as_protein.atom_mask > 0, bfactors, 0))
52 |
53 |
54 | if __name__ == '__main__':
55 | absltest.main()
56 |
--------------------------------------------------------------------------------
/src/animate_py3dmol.py:
--------------------------------------------------------------------------------
1 | import py3Dmol
2 |
3 | import sys
4 | import os
5 | import pdb
6 |
7 | import pandas as pd
8 | import numpy as np
9 | import argparse
10 |
11 |
12 | parser = argparse.ArgumentParser(description = '''Visualise sampled conformations''')
13 | parser.add_argument('--pdbdir', nargs=1, type= str, required=True, help = "Path to directory with PDB files")
14 | parser.add_argument('--order_df', nargs=1, type= str, required=True, help = "Path to csv with order of PDB files")
15 | parser.add_argument('--outdir', nargs=1, type= str, help = 'Outdir.')
16 |
17 |
18 | #################FUNCTIONS#################
19 |
20 | class Atom(dict):
21 | def __init__(self, line):
22 | self["type"] = line[0:6].strip()
23 | self["idx"] = line[6:11].strip()
24 | self["name"] = line[12:16].strip()
25 | self["resname"] = line[17:20].strip()
26 |         self["resid"] = int(line[22:26])
27 | self["x"] = float(line[30:38])
28 | self["y"] = float(line[38:46])
29 | self["z"] = float(line[46:54])
30 | self["sym"] = line[76:78].strip()
31 |
32 | def __str__(self):
33 | line = list(" " * 80)
34 |
35 | line[0:6] = self["type"].ljust(6)
36 | line[6:11] = self["idx"].ljust(5)
37 | line[12:16] = self["name"].ljust(4)
38 | line[17:20] = self["resname"].ljust(3)
39 | line[22:26] = str(self["resid"]).ljust(4)
40 |         line[30:38] = "{:8.3f}".format(self["x"])
41 |         line[38:46] = "{:8.3f}".format(self["y"])
42 |         line[46:54] = "{:8.3f}".format(self["z"])
43 | line[76:78] = self["sym"].rjust(2)
44 | return "".join(line) + "\n"
45 |
46 | class Molecule(list):
47 | def __init__(self, file):
48 | for line in file:
49 | if "ATOM" in line or "HETATM" in line:
50 | self.append(Atom(line))
51 |
52 | def __str__(self):
53 | outstr = ""
54 | for at in self:
55 | outstr += str(at)
56 |
57 | return outstr
58 |
59 |
60 |
61 |
62 |
63 | #################MAIN####################
64 |
65 | #Parse args
66 | args = parser.parse_args()
67 | #Get data
68 | pdbdir = args.pdbdir[0]
69 | order_df = pd.read_csv(args.order_df[0]).loc[:10] #Animate only the first 11 conformations
70 | outdir = args.outdir[0] if args.outdir else None #--outdir is optional
71 |
72 | molecules = []
73 | for ind,row in order_df.iterrows():
74 | with open(pdbdir+row['name']) as ifile:
75 | molecules.append(Molecule(ifile))
76 |
77 |
78 | view = py3Dmol.view(width=400, height=300)
79 |
80 | models = ""
81 | for i, mol in enumerate(molecules):
82 | models += "MODEL " + str(i) + "\n"
83 | models += str(mol)
84 | models += "ENDMDL\n"
85 | view.addModelsAsFrames(models)
86 |
87 | for i, at in enumerate(molecules[0]):
88 | default = {"cartoon": {'color': 'grey'}}
89 | view.setStyle({'model': -1, 'serial': i+1}, at.get("pymol", default))
90 |
91 | view.zoomTo()
92 | view.animate({'loop': 'backAndForth'})
93 | view.show()
94 |
--------------------------------------------------------------------------------
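The script above builds a multi-frame model by concatenating the sampled conformations as MODEL/ENDMDL blocks and handing the result to py3Dmol, which renders inline in a notebook. A self-contained sketch of that pattern, with placeholder PDB paths:

import py3Dmol

#Two placeholder single-model PDB strings
frames = [open('conf1.pdb').read(), open('conf2.pdb').read()]

#Concatenate the conformations as MODEL/ENDMDL frames
models = ""
for i, pdb_str in enumerate(frames):
    models += "MODEL " + str(i) + "\n" + pdb_str + "ENDMDL\n"

view = py3Dmol.view(width=400, height=300)
view.addModelsAsFrames(models)
view.setStyle({'model': -1}, {"cartoon": {"color": "grey"}})
view.zoomTo()
view.animate({'loop': 'backAndForth'})
view.show()
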
/src/check_msa_colab.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import re
3 |
4 | ##############FUNCTIONS##############
5 |
6 | def read_a3m(a3mfile):
7 | '''Read an a3m file
8 | '''
9 | parsed_a3m = {}
10 |     #Regex pattern: keep match states (uppercase) and gaps, drop insertions (lowercase)
11 |     pattern = '[A-Z-]+'
12 | nhits = 0
13 | with open(a3mfile, 'r') as file:
14 |         for line in file:
15 |             line = line.rstrip()
16 |             if not line or line[0]=='#':
17 |                 #Skip empty lines and the a3m '#' annotation line
18 |                 continue
19 |             if line[0]=='>':
20 |                 nhits+=1
21 |             else:
22 |                 parsed_a3m[nhits] = ''.join(re.findall(pattern, line))
23 |
24 | return parsed_a3m
25 |
26 |
27 | def process_a3m(a3mfile, sequence, outname):
28 | '''Process a3m file - remove insertions and get only matches and gaps
29 | Write these to a new file
30 | '''
31 |
32 | parsed_a3m = read_a3m(a3mfile)
33 | if len([*parsed_a3m.values()][0])!=len(sequence):
34 |         print('The sequence length does not match the MSA.')
35 | sys.exit()
36 | with open(outname, 'w') as file:
37 | for key in parsed_a3m:
38 | file.write('>'+str(key)+'\n')
39 | file.write(parsed_a3m[key]+'\n')
40 |
--------------------------------------------------------------------------------
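`process_a3m` removes lowercase insertion states from an a3m alignment so that every row has the query's length, then writes the rows as numbered FASTA entries. A minimal sketch of calling it from the notebook where this module is loaded; the file names and query sequence are placeholders:

query = 'MKTAYIAKQR'  #Hypothetical query sequence
process_a3m('hits.a3m', query, 'hits_matchstates.a3m')
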
/src/make_msa_seq_feats.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | from typing import Mapping, Optional, Sequence
18 | from alphafold.common import residue_constants
19 | from alphafold.data import parsers
20 | import numpy as np
21 | import argparse
22 | import pickle
23 | import sys
24 | import pdb
25 | # Internal import (7716).
26 |
27 | parser = argparse.ArgumentParser(description = """Builds the input features for training the structure prediction model.""")
28 |
29 | parser.add_argument('--input_fasta_path', nargs=1, type= str, default=sys.stdin, help = 'Path to fasta.')
30 | parser.add_argument('--input_msas', nargs=1, type= str, default=sys.stdin, help = 'Path to MSAs. Separated by comma.')
31 | parser.add_argument('--outdir', nargs=1, type= str, default=sys.stdin, help = 'Path to output directory. Include /in end')
32 |
33 | FeatureDict = Mapping[str, np.ndarray]
34 |
35 | ##############FUNCTIONS##############
36 | def make_sequence_features(
37 | sequence: str, description: str, num_res: int) -> FeatureDict:
38 | """Constructs a feature dict of sequence features."""
39 | features = {}
40 | features['aatype'] = residue_constants.sequence_to_onehot(
41 | sequence=sequence,
42 | mapping=residue_constants.restype_order_with_x,
43 | map_unknown_to_x=True)
44 | features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
45 | features['domain_name'] = np.array([description.encode('utf-8')],
46 | dtype=np.object_)
47 | features['residue_index'] = np.array(range(num_res), dtype=np.int32)
48 | features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
49 | features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_)
50 | return features
51 |
52 |
53 | def make_msa_features(
54 | msas: Sequence[Sequence[str]],
55 | deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
56 | """Constructs a feature dict of MSA features."""
57 | if not msas:
58 | raise ValueError('At least one MSA must be provided.')
59 |
60 | int_msa = []
61 | deletion_matrix = []
62 | seen_sequences = set()
63 | for msa_index, msa in enumerate(msas):
64 | if not msa:
65 | raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
66 | for sequence_index, sequence in enumerate(msa):
67 | if sequence in seen_sequences:
68 | continue
69 | seen_sequences.add(sequence)
70 | int_msa.append(
71 | [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
72 | deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
73 |
74 | num_res = len(msas[0][0])
75 | num_alignments = len(int_msa)
76 | features = {}
77 | features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
78 | features['msa'] = np.array(int_msa, dtype=np.int32)
79 | features['num_alignments'] = np.array(
80 | [num_alignments] * num_res, dtype=np.int32)
81 | return features
82 |
83 |
84 |
85 | def process(input_fasta_path: str, input_msas: list) -> FeatureDict:
86 | """Runs alignment tools on the input sequence and creates features."""
87 | with open(input_fasta_path) as f:
88 | input_fasta_str = f.read()
89 | input_seqs, input_desc = parsers.parse_fasta(input_fasta_str)
90 | if len(input_seqs) != 1:
91 | raise ValueError(
92 | f'More than one input sequence found in {input_fasta_path}.')
93 | input_sequence = input_seqs[0]
94 | input_description = input_desc[0]
95 | num_res = len(input_sequence)
96 |
97 | parsed_msas = []
98 | parsed_delmat = []
99 | for custom_msa in input_msas:
100 |     msa = open(custom_msa).read()
101 | if custom_msa[-3:] == 'sto':
102 | parsed_msa, parsed_deletion_matrix, _ = parsers.parse_stockholm(msa)
103 | elif custom_msa[-3:] == 'a3m':
104 | parsed_msa, parsed_deletion_matrix = parsers.parse_a3m(msa)
105 |     else: raise TypeError('Unknown format for input MSA, please make sure '
106 |                           'the MSA files you provide end with (and '
107 |                           'are formatted as) .sto or .a3m')
108 | parsed_msas.append(parsed_msa)
109 | parsed_delmat.append(parsed_deletion_matrix)
110 |
111 | sequence_features = make_sequence_features(
112 | sequence=input_sequence,
113 | description=input_description,
114 | num_res=num_res)
115 |
116 | msa_features = make_msa_features(
117 | msas=parsed_msas, deletion_matrices=parsed_delmat)
118 |
119 | return {**sequence_features, **msa_features}
120 |
121 |
122 | ##################MAIN#######################
123 |
124 | #Parse args
125 | args = parser.parse_args()
126 | #Data
127 | input_fasta_path = args.input_fasta_path[0]
128 | input_msas = args.input_msas[0].split(',')
129 | outdir = args.outdir[0]
130 | #Get feats
131 | feature_dict = process(input_fasta_path, input_msas)
132 |
133 | #Write out features as a pickled dictionary.
134 | features_output_path = os.path.join(outdir, 'msa_features.pkl')
135 | with open(features_output_path, 'wb') as f:
136 | pickle.dump(feature_dict, f, protocol=4)
137 | print('Saved features to',features_output_path)
138 |
--------------------------------------------------------------------------------
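The script pickles the merged sequence and MSA features to <outdir>/msa_features.pkl. A sketch of inspecting that file, with a placeholder path:

import pickle

with open('./out/msa_features.pkl', 'rb') as f:
    feats = pickle.load(f)

#'msa' holds HHblits amino-acid ids with shape [num_alignments, num_res];
#'deletion_matrix_int' counts the insertions removed before each match state
print(feats['msa'].shape, feats['num_alignments'][0], feats['seq_length'][0])
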
/src/make_msa_seq_feats_colab.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | from typing import Mapping, Optional, Sequence
18 | from alphafold.common import residue_constants
19 | from alphafold.data import parsers
20 | import numpy as np
21 | import argparse
22 | import pickle
23 | import sys
24 | import pdb
25 | # Internal import (7716).
26 |
27 |
28 | FeatureDict = Mapping[str, np.ndarray]
29 |
30 | ##############FUNCTIONS##############
31 | def make_sequence_features(
32 | sequence: str, description: str, num_res: int) -> FeatureDict:
33 | """Constructs a feature dict of sequence features."""
34 | features = {}
35 | features['aatype'] = residue_constants.sequence_to_onehot(
36 | sequence=sequence,
37 | mapping=residue_constants.restype_order_with_x,
38 | map_unknown_to_x=True)
39 | features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
40 | features['domain_name'] = np.array([description.encode('utf-8')],
41 | dtype=np.object_)
42 | features['residue_index'] = np.array(range(num_res), dtype=np.int32)
43 | features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
44 | features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_)
45 | return features
46 |
47 |
48 | def make_msa_features(
49 | msas: Sequence[Sequence[str]],
50 | deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
51 | """Constructs a feature dict of MSA features."""
52 | if not msas:
53 | raise ValueError('At least one MSA must be provided.')
54 |
55 | int_msa = []
56 | deletion_matrix = []
57 | seen_sequences = set()
58 | for msa_index, msa in enumerate(msas):
59 | if not msa:
60 | raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
61 | for sequence_index, sequence in enumerate(msa):
62 | if sequence in seen_sequences:
63 | continue
64 | seen_sequences.add(sequence)
65 | int_msa.append(
66 | [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
67 | deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
68 |
69 | num_res = len(msas[0][0])
70 | num_alignments = len(int_msa)
71 | features = {}
72 | features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
73 | features['msa'] = np.array(int_msa, dtype=np.int32)
74 | features['num_alignments'] = np.array(
75 | [num_alignments] * num_res, dtype=np.int32)
76 | return features
77 |
78 |
79 |
80 | def process(input_fasta_path: str, input_msas: list) -> FeatureDict:
81 | """Runs alignment tools on the input sequence and creates features."""
82 | with open(input_fasta_path) as f:
83 | input_fasta_str = f.read()
84 | input_seqs, input_desc = parsers.parse_fasta(input_fasta_str)
85 | if len(input_seqs) != 1:
86 | raise ValueError(
87 | f'More than one input sequence found in {input_fasta_path}.')
88 | input_sequence = input_seqs[0]
89 | input_description = input_desc[0]
90 | num_res = len(input_sequence)
91 |
92 | parsed_msas = []
93 | parsed_delmat = []
94 | for custom_msa in input_msas:
95 |     msa = open(custom_msa).read()
96 | if custom_msa[-3:] == 'sto':
97 | parsed_msa, parsed_deletion_matrix, _ = parsers.parse_stockholm(msa)
98 | elif custom_msa[-3:] == 'a3m':
99 | parsed_msa, parsed_deletion_matrix = parsers.parse_a3m(msa)
100 |     else: raise TypeError('Unknown format for input MSA, please make sure '
101 |                           'the MSA files you provide end with (and '
102 |                           'are formatted as) .sto or .a3m')
103 | parsed_msas.append(parsed_msa)
104 | parsed_delmat.append(parsed_deletion_matrix)
105 |
106 | sequence_features = make_sequence_features(
107 | sequence=input_sequence,
108 | description=input_description,
109 | num_res=num_res)
110 |
111 | msa_features = make_msa_features(
112 | msas=parsed_msas, deletion_matrices=parsed_delmat)
113 |
114 | return {**sequence_features, **msa_features}
115 |
116 |
117 | ##################MAIN#######################
118 |
119 |
120 | #Get feats
121 | #feature_dict = process(input_fasta_path, input_msas)
122 |
123 | #Write out features as a pickled dictionary.
124 | #features_output_path = os.path.join(outdir, 'msa_features.pkl')
125 | #with open(features_output_path, 'wb') as f:
126 | # pickle.dump(feature_dict, f, protocol=4)
127 | #print('Saved features to',features_output_path)
128 |
--------------------------------------------------------------------------------
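The colab variant keeps the same feature-building functions but comments out the argument parsing and MAIN block, since the notebook calls `process` directly. A sketch of that call, mirroring the commented-out block above (paths are placeholders):

import pickle

feature_dict = process('./query.fasta', ['./query.a3m'])
with open('./out/msa_features.pkl', 'wb') as f:
    pickle.dump(feature_dict, f, protocol=4)
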
/src/predict_with_clusters.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import warnings
4 | import pathlib
5 | import pickle
6 | import random
7 | import sys
8 | import time
9 | from typing import Dict, Optional
10 | from typing import NamedTuple
11 | import haiku as hk
12 | import jax
13 | import jax.numpy as jnp
14 |
15 | #Silence tf
16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
17 | import tensorflow.compat.v1 as tf
18 | tf.config.set_visible_devices([], 'GPU')
19 |
20 | import argparse
21 | import pandas as pd
22 | import numpy as np
23 | from collections import Counter
24 | from scipy.special import softmax
25 | import pdb
26 |
27 | #AlphaFold imports
28 | from alphafold.common import protein
29 | from alphafold.common import residue_constants
30 | from alphafold.data import templates
31 | from alphafold.model import data
32 | from alphafold.model import config
33 | from alphafold.model import features
34 | from alphafold.model import modules
35 |
36 | #JAX will preallocate 90% of the currently available GPU memory when the first JAX operation is run.
37 | #Uncomment the line below to prevent this
38 | #os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
39 |
40 |
41 | parser = argparse.ArgumentParser(description = """Predict using trained weights and a varying amount of sampled MSA clusters for more varied predictions.""")
42 |
43 | parser.add_argument('--feature_dir', nargs=1, type= str, default=sys.stdin, help = 'Path to location of features.')
44 | parser.add_argument('--predict_id', nargs=1, type= str, default=sys.stdin, help = 'Id to predict.')
45 | parser.add_argument('--ckpt_params', nargs=1, type= str, default=sys.stdin, help = 'Params to start from.')
46 | parser.add_argument('--num_recycles', nargs=1, type= int, default=sys.stdin, help = 'Number of recycles to use.')
47 | parser.add_argument('--num_samples_per_cluster', nargs=1, type= int, default=sys.stdin, help = 'Number of samples to use per cluster selection.')
48 | parser.add_argument('--outdir', nargs=1, type= str, default=sys.stdin, help = 'Path to output directory. Include /in end')
49 |
50 | ##############FUNCTIONS##############
51 | ##########INPUT DATA#########
52 |
53 | def process_features(raw_features, config, random_seed):
54 | """Processes features to prepare for feeding them into the model.
55 |
56 | Args:
57 | raw_features: The output of the data pipeline either as a dict of NumPy
58 | arrays or as a tf.train.Example.
59 | random_seed: The random seed to use when processing the features.
60 |
61 | Returns:
62 | A dict of NumPy feature arrays suitable for feeding into the model.
63 | """
64 | return features.np_example_to_features(np_example=raw_features,
65 | config=config,
66 | random_seed=random_seed)
67 |
68 |
69 | def load_input_feats(id, feature_dir, config, num_clusters):
70 | """
71 | Load all input feats.
72 | """
73 |
74 | #Load raw features
75 | msa_feature_dict = np.load(feature_dir+'/msa_features.pkl', allow_pickle=True)
76 | #Process the features on CPU (sample MSA)
77 | #Set the config to determine the number of clusters
78 | config.data.eval.max_msa_clusters = num_clusters
79 | #processed_feature_dict['msa_feat'].shape = num_clusts, L, 49
80 | processed_feature_dict = process_features(msa_feature_dict, config, np.random.choice(sys.maxsize))
81 |
82 | return processed_feature_dict
83 |
84 |
85 |
86 | ##########MODEL#########
87 |
88 | def predict(config,
89 | feature_dir,
90 | predict_ids,
91 | num_recycles,
92 | num_samples_per_cluster,
93 | ckpt_params=None,
94 | outdir=None):
95 | """Predict a structure
96 | """
97 |
98 |     #The config does not have to be updated here:
99 |     #the number of MSA clusters can be changed per call.
100 | #Define the forward function
101 | def _forward_fn(batch):
102 | '''Define the forward function - has to be a function for JAX
103 | '''
104 | model = modules.AlphaFold(config.model)
105 |
106 | return model(batch,
107 | is_training=False,
108 | compute_loss=False,
109 | ensemble_representations=False,
110 | return_representations=True)
111 |
112 | #The forward function is here transformed to apply and init functions which
113 | #can be called during training and initialisation (JAX needs functions)
114 | forward = hk.transform(_forward_fn)
115 | apply_fwd = forward.apply
116 | #Get a random key
117 | rng = jax.random.PRNGKey(42)
118 |
119 | for id in predict_ids:
120 | for num_clusts in [16, 32, 64, 128, 256, 512, 1024, 5120]:
121 | for i in range(num_samples_per_cluster):
122 | if os.path.exists(outdir+'/'+id+'_'+str(num_clusts)+'_'+str(i)+'_pred.pdb'):
123 | print('Prediction',num_clusts, i+1, 'exists...')
124 | continue
125 |
126 | #Load input feats
127 | batch = load_input_feats(id, feature_dir, config, num_clusts)
128 | for key in batch:
129 | batch[key] = np.reshape(batch[key], (1, *batch[key].shape))
130 |
131 | batch['num_iter_recycling'] = [num_recycles]
132 | ret = apply_fwd(ckpt_params, rng, batch)
133 | #Save structure
134 | save_feats = {'aatype':batch['aatype'], 'residue_index':batch['residue_index']}
135 | result = {'predicted_lddt':ret['predicted_lddt'],
136 | 'structure_module':{'final_atom_positions':ret['structure_module']['final_atom_positions'],
137 | 'final_atom_mask': ret['structure_module']['final_atom_mask']
138 | }}
139 | save_structure(save_feats, result, id+'_'+str(num_clusts)+'_'+str(i), outdir)
140 |
141 |
142 |
143 | def save_structure(save_feats, result, id, outdir):
144 | """Save prediction
145 |
146 | save_feats = {'aatype':batch['aatype'][0][0], 'residue_index':batch['residue_index'][0][0]}
147 | result = {'predicted_lddt':aux['predicted_lddt'],
148 | 'structure_module':{'final_atom_positions':aux['structure_module']['final_atom_positions'][0],
149 | 'final_atom_mask': aux['structure_module']['final_atom_mask'][0]
150 | }}
151 | save_structure(save_feats, result, step_num, outdir)
152 |
153 | """
154 | #Define the plDDT bins
155 | bin_width = 1.0 / 50
156 | bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width)
157 |
158 | # Add the predicted LDDT in the b-factor column.
159 | plddt_per_pos = jnp.sum(jax.nn.softmax(result['predicted_lddt']['logits']) * bin_centers[None, :], axis=-1)
160 | plddt_b_factors = np.repeat(plddt_per_pos[:, None], residue_constants.atom_type_num, axis=-1)
161 | unrelaxed_protein = protein.from_prediction(features=save_feats, result=result, b_factors=plddt_b_factors)
162 | unrelaxed_pdb = protein.to_pdb(unrelaxed_protein)
163 | unrelaxed_pdb_path = os.path.join(outdir+'/', id+'_pred.pdb')
164 | with open(unrelaxed_pdb_path, 'w') as f:
165 | f.write(unrelaxed_pdb)
166 |
167 |
168 |
169 | ##################MAIN#######################
170 |
171 | #Parse args
172 | args = parser.parse_args()
173 | feature_dir = args.feature_dir[0]
174 | predict_id = args.predict_id[0]
175 | ckpt_params = np.load(args.ckpt_params[0] , allow_pickle=True)
176 | num_recycles = args.num_recycles[0]
177 | num_samples_per_cluster = args.num_samples_per_cluster[0]
178 | outdir = args.outdir[0]
179 | #Predict
180 | predict(config.CONFIG,
181 | feature_dir,
182 | [predict_id],
183 | num_recycles,
184 | num_samples_per_cluster,
185 | ckpt_params=ckpt_params,
186 | outdir=outdir)
187 |
--------------------------------------------------------------------------------
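Each sample is written as <id>_<num_clusters>_<sample>_pred.pdb with the per-residue plDDT (on a 0-1 scale here, computed from 50 bin centers) stored in the B-factor column, so conformations can be ranked afterwards without rerunning the network. A sketch of such a ranking, with a placeholder output directory:

import glob
import numpy as np

def mean_plddt(pdb_path):
    #Average the B-factor column (per-residue plDDT) over CA atoms
    plddt = [float(line[60:66]) for line in open(pdb_path)
             if line.startswith('ATOM') and line[12:16].strip() == 'CA']
    return np.mean(plddt)

preds = sorted(glob.glob('./out/*_pred.pdb'), key=mean_plddt, reverse=True)
print('Highest mean plDDT:', preds[0])
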
/src/predict_with_clusters_colab.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import warnings
4 | import pathlib
5 | import pickle
6 | import random
7 | import sys
8 | import time
9 | from typing import Dict, Optional
10 | from typing import NamedTuple
11 | import haiku as hk
12 | import jax
13 | import jax.numpy as jnp
14 |
15 | #Silence tf
16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
17 | import tensorflow.compat.v1 as tf
18 | tf.config.set_visible_devices([], 'GPU')
19 |
20 | import argparse
21 | import pandas as pd
22 | import numpy as np
23 | from collections import Counter
24 | from scipy.special import softmax
25 | import pdb
26 |
27 | #AlphaFold imports
28 | from alphafold.common import protein
29 | from alphafold.common import residue_constants
30 | from alphafold.data import templates
31 | from alphafold.model import data
32 | from alphafold.model import config
33 | from alphafold.model import features
34 | from alphafold.model import modules
35 |
36 | #JAX will preallocate 90% of the currently available GPU memory when the first JAX operation is run.
37 | #Uncomment the line below to prevent this
38 | #os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
39 |
40 |
41 | ##############FUNCTIONS##############
42 | ##########INPUT DATA#########
43 |
44 | def process_features(raw_features, config, random_seed):
45 | """Processes features to prepare for feeding them into the model.
46 |
47 | Args:
48 | raw_features: The output of the data pipeline either as a dict of NumPy
49 | arrays or as a tf.train.Example.
50 | random_seed: The random seed to use when processing the features.
51 |
52 | Returns:
53 | A dict of NumPy feature arrays suitable for feeding into the model.
54 | """
55 | return features.np_example_to_features(np_example=raw_features,
56 | config=config,
57 | random_seed=random_seed)
58 |
59 |
60 | def load_input_feats(id, feature_dir, config, num_clusters):
61 | """
62 | Load all input feats.
63 | """
64 |
65 | #Load raw features
66 | msa_feature_dict = np.load(feature_dir+id+'/msa_features.pkl', allow_pickle=True)
67 | #Process the features on CPU (sample MSA)
68 | #Set the config to determine the number of clusters
69 | config.data.eval.max_msa_clusters = num_clusters
70 | #processed_feature_dict['msa_feat'].shape = num_clusts, L, 49
71 | processed_feature_dict = process_features(msa_feature_dict, config, np.random.choice(sys.maxsize))
72 |
73 | return processed_feature_dict
74 |
75 |
76 |
77 | ##########MODEL#########
78 |
79 | def predict(feature_dir,
80 | predict_id,
81 | num_recycles,
82 | num_samples_per_cluster,
83 | clusters,
84 | ckpt_params=None,
85 | outdir=None):
86 | """Predict a structure
87 | """
88 |
89 |
90 |     #The config does not have to be updated here:
91 |     #the number of MSA clusters can be changed per call.
92 | #Define the forward function
93 | def _forward_fn(batch):
94 | '''Define the forward function - has to be a function for JAX
95 | '''
96 | model = modules.AlphaFold(config.CONFIG.model)
97 |
98 | return model(batch,
99 | is_training=False,
100 | compute_loss=False,
101 | ensemble_representations=False,
102 | return_representations=True)
103 |
104 | #The forward function is here transformed to apply and init functions which
105 | #can be called during training and initialisation (JAX needs functions)
106 | forward = hk.transform(_forward_fn)
107 | apply_fwd = forward.apply
108 | #Get a random key
109 | rng = jax.random.PRNGKey(42)
110 |
111 | for num_clusts in clusters:
112 | for i in range(num_samples_per_cluster):
113 | print('Predicting: number of clusters',num_clusts, 'Sample', i+1, '...')
114 | if os.path.exists(outdir+'/'+predict_id+'_'+str(num_clusts)+'_'+str(i)+'_pred.pdb'):
115 | print('Prediction',num_clusts, i+1, 'exists...')
116 | continue
117 |
118 | #Load input feats
119 | batch = load_input_feats(predict_id, feature_dir, config.CONFIG, num_clusts)
120 | for key in batch:
121 | batch[key] = np.reshape(batch[key], (1, *batch[key].shape))
122 |
123 | batch['num_iter_recycling'] = [num_recycles]
124 | ret = apply_fwd(ckpt_params, rng, batch)
125 | #Save structure
126 | save_feats = {'aatype':batch['aatype'], 'residue_index':batch['residue_index']}
127 | result = {'predicted_lddt':ret['predicted_lddt'],
128 | 'structure_module':{'final_atom_positions':ret['structure_module']['final_atom_positions'],
129 | 'final_atom_mask': ret['structure_module']['final_atom_mask']
130 | }}
131 | save_structure(save_feats, result, predict_id+'_'+str(num_clusts)+'_'+str(i), outdir)
132 |
133 |
134 |
135 | def save_structure(save_feats, result, id, outdir):
136 | """Save prediction
137 |
138 | save_feats = {'aatype':batch['aatype'][0][0], 'residue_index':batch['residue_index'][0][0]}
139 | result = {'predicted_lddt':aux['predicted_lddt'],
140 | 'structure_module':{'final_atom_positions':aux['structure_module']['final_atom_positions'][0],
141 | 'final_atom_mask': aux['structure_module']['final_atom_mask'][0]
142 | }}
143 | save_structure(save_feats, result, step_num, outdir)
144 |
145 | """
146 | #Define the plDDT bins
147 | bin_width = 1.0 / 50
148 | bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width)
149 |
150 | # Add the predicted LDDT in the b-factor column.
151 | plddt_per_pos = jnp.sum(jax.nn.softmax(result['predicted_lddt']['logits']) * bin_centers[None, :], axis=-1)
152 | plddt_b_factors = np.repeat(plddt_per_pos[:, None], residue_constants.atom_type_num, axis=-1)
153 | unrelaxed_protein = protein.from_prediction(features=save_feats, result=result, b_factors=plddt_b_factors)
154 | unrelaxed_pdb = protein.to_pdb(unrelaxed_protein)
155 | unrelaxed_pdb_path = os.path.join(outdir+'/', id+'_pred.pdb')
156 | with open(unrelaxed_pdb_path, 'w') as f:
157 | f.write(unrelaxed_pdb)
158 |
159 |
160 |
161 | ##################MAIN#######################
162 |
163 | # #Predict
164 | # predict(feature_dir,
165 | #         predict_id,
166 | #         num_recycles,
167 | #         num_samples_per_cluster,
168 | #         clusters,
169 | #         ckpt_params=ckpt_params, outdir=outdir)
170 |
--------------------------------------------------------------------------------
/src/score_closest_ca.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import pdb
4 |
5 | import matplotlib.pyplot as plt
6 | import pandas as pd
7 | import numpy as np
8 | from collections import defaultdict
9 | from Bio.SVDSuperimposer import SVDSuperimposer
10 | import shutil
11 | import glob
12 | import argparse
13 |
14 |
15 | parser = argparse.ArgumentParser(description = '''Structurally align all files in a directory to a reference and write new roto-translated pdb files''')
16 | parser.add_argument('--pdbdir', nargs=1, type= str, required=True, help = "Path to directory with PDB files")
17 | parser.add_argument('--outdir', nargs=1, type= str, help = 'Outdir.')
18 |
19 |
20 | #################FUNCTIONS#################
21 |
22 |
23 | def parse_atm_record(line):
24 | '''Get the atm record
25 | '''
26 | record = defaultdict()
27 | record['name'] = line[0:6].strip()
28 | record['atm_no'] = int(line[6:11])
29 | record['atm_name'] = line[12:16].strip()
30 |     record['atm_alt'] = line[16] #altLoc is PDB column 17 (0-based index 16)
31 | record['res_name'] = line[17:20].strip()
32 | record['chain'] = line[21]
33 | record['res_no'] = int(line[22:26])
34 | record['insert'] = line[26].strip()
35 | record['resid'] = line[22:29]
36 | record['x'] = float(line[30:38])
37 | record['y'] = float(line[38:46])
38 | record['z'] = float(line[46:54])
39 | record['occ'] = float(line[54:60])
40 | record['B'] = float(line[60:66])
41 |
42 | return record
43 |
44 | def read_pdb(pdbfile):
45 | '''Read a pdb file per chain
46 | '''
47 | pdb_file_info = []
48 | all_coords = []
49 | CA_coords = []
50 |
51 | with open(pdbfile) as file:
52 | for line in file:
53 | if line.startswith('ATOM'):
54 | #Parse line
55 | record = parse_atm_record(line)
56 | #Save line
57 | pdb_file_info.append(line.rstrip())
58 | #Save coords
59 | all_coords.append([record['x'],record['y'],record['z']])
60 | if record['atm_name']=='CA':
61 | CA_coords.append([record['x'],record['y'],record['z']])
62 |
63 |
64 | return pdb_file_info, np.array(all_coords), np.array(CA_coords)
65 |
66 | def score_ca_diff(all_ca_coords):
67 | """Score all CA coords against each other
68 | """
69 |
70 | num_samples = len(all_ca_coords)
71 | score_mat = np.zeros((num_samples, num_samples))
72 | for i in range(num_samples):
73 | dmat_i = np.sqrt(np.sum((all_ca_coords[i][:,None]-all_ca_coords[i][None,:])**2,axis=-1))
74 |         score_mat[i,i]=1000 #High self-score so the argmin below skips the diagonal
75 | for j in range(i+1, num_samples):
76 | dmat_j = np.sqrt(np.sum((all_ca_coords[j][:,None]-all_ca_coords[j][None,:])**2,axis=-1))
77 |             #Mean absolute difference between the two CA distance matrices
78 |             diff = np.mean(np.abs(dmat_i-dmat_j))
79 | score_mat[i,j] = diff
80 | score_mat[j,i] = diff
81 |
82 | #Establish the order from the score mat
83 | order = []
84 | min_ind = np.unravel_index(score_mat.argmin(), score_mat.shape)
85 | all_inds = np.arange(num_samples)
86 | order.append(min_ind[0])
87 | ci = min_ind[1]
88 | while len(order)