├── .gitignore ├── Cfold.ipynb ├── Logo.svg ├── README.md ├── cfold.ipynb ├── data └── test │ ├── 4AVA.a3m │ └── 4AVA.fasta ├── environment.yml ├── install_dependencies.sh ├── pip_pkgs.txt ├── predict.sh └── src ├── alphafold ├── .DS_Store ├── common │ ├── __init__.py │ ├── confidence.py │ ├── protein.py │ ├── protein_test.py │ ├── residue_constants.py │ ├── residue_constants_test.py │ └── testdata │ │ └── 2rbg.pdb ├── data │ ├── __init__.py │ ├── mmcif_parsing.py │ ├── parsers.py │ ├── pipeline.py │ ├── templates.py │ └── tools │ │ ├── __init__.py │ │ ├── hhblits.py │ │ ├── hhsearch.py │ │ ├── hmmbuild.py │ │ ├── hmmsearch.py │ │ ├── jackhmmer.py │ │ ├── kalign.py │ │ └── utils.py ├── model │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-312.pyc │ │ ├── all_atom.cpython-312.pyc │ │ ├── common_modules.cpython-312.pyc │ │ ├── config.cpython-312.pyc │ │ ├── data.cpython-312.pyc │ │ ├── features.cpython-312.pyc │ │ ├── folding.cpython-312.pyc │ │ ├── layer_stack.cpython-312.pyc │ │ ├── lddt.cpython-312.pyc │ │ ├── mapping.cpython-312.pyc │ │ ├── modules.cpython-312.pyc │ │ ├── prng.cpython-312.pyc │ │ ├── quat_affine.cpython-312.pyc │ │ ├── r3.cpython-312.pyc │ │ └── utils.cpython-312.pyc │ ├── all_atom.py │ ├── all_atom_test.py │ ├── common_modules.py │ ├── config.py │ ├── data.py │ ├── features.py │ ├── folding.py │ ├── layer_stack.py │ ├── layer_stack_test.py │ ├── lddt.py │ ├── lddt_test.py │ ├── mapping.py │ ├── model.py │ ├── modules.py │ ├── prng.py │ ├── prng_test.py │ ├── quat_affine.py │ ├── quat_affine_test.py │ ├── r3.py │ ├── tf │ │ ├── __init__.py │ │ ├── data_transforms.py │ │ ├── input_pipeline.py │ │ ├── protein_features.py │ │ ├── protein_features_test.py │ │ ├── proteins_dataset.py │ │ ├── shape_helpers.py │ │ ├── shape_helpers_test.py │ │ ├── shape_placeholders.py │ │ └── utils.py │ └── utils.py └── relax │ ├── __init__.py │ ├── amber_minimize.py │ ├── amber_minimize_test.py │ ├── cleanup.py │ ├── cleanup_test.py │ ├── relax.py │ ├── relax_test.py │ ├── testdata │ ├── model_output.pdb │ ├── multiple_disulfides_target.pdb │ ├── with_violations.pdb │ └── with_violations_casp14.pdb │ ├── utils.py │ └── utils_test.py ├── animate_py3dmol.py ├── check_msa_colab.py ├── make_msa_seq_feats.py ├── make_msa_seq_feats_colab.py ├── predict_with_clusters.py ├── predict_with_clusters_colab.py ├── score_closest_ca.py └── score_closest_ca_colab.py /.gitignore: -------------------------------------------------------------------------------- 1 | src/alphafold/common/__pycache__ 2 | src/alphafold/data/__pycache__ 3 | src/alphafold/data/tools/__pycache__ 4 | src/alphafold/model/__pycache__ 5 | src/alphafold/model/tf/__pycache__ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cfold 2 | 3 | **Structure prediction of alternative protein conformations** 4 | 5 | 6 | 7 | 8 | 9 | \ 10 | Cfold is a structure prediction network similar to AlphaFold2 that is trained on a conformational split of the PDB. Cfold is designed for predicting alternative conformations of protein structures. [Read more about it in the paper here](https://www.nature.com/articles/s41467-024-51507-2) 11 | \ 12 | \ 13 | AlphaFold2 is available under the [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0) and so is Cfold, which is a derivative thereof. 
The Cfold parameters are made available under the terms of the [CC BY 4.0 license](https://creativecommons.org/licenses/by/4.0/legalcode). 14 | \ 15 | \ 16 | **You may not use these files except in compliance with the licenses.** 17 | 18 | # Colab (run in the browser) 19 | 20 | [Colab Notebook](https://colab.research.google.com/github/patrickbryant1/Cfold/blob/master/Cfold.ipynb) 21 | 22 | # Local installation 23 | 24 | The entire installation takes <1 hour on a standard computer. \ 25 | The runtime will depend on the GPU you have available, the size of the protein 26 | you are predicting and the number of samples taken. On an NVIDIA A100 GPU, the 27 | prediction time is a few minutes per sample for a protein of a few hundred amino acids. 28 | 29 | We assume you have CUDA12. For CUDA11, you will have to change the installation of some packages. 30 | 31 | First install miniconda, see: https://docs.conda.io/projects/miniconda/en/latest/miniconda-other-installer-links.html 32 | 33 | ``` 34 | bash install_dependencies.sh 35 | ``` 36 | If the conda installation doesn't work for you, see `pip_pkgs.txt` 37 | 38 | # Run the test case 39 | ## (a few minutes) 40 | ``` 41 | bash predict.sh 42 | ``` 43 | 44 | # Data availability 45 | https://zenodo.org/records/10837082 46 | 47 | # Citation 48 | Bryant P, Noé F. Structure prediction of alternative protein conformations. Nat Commun 15, 7328 (2024). 49 | -------------------------------------------------------------------------------- /data/test/4AVA.fasta: -------------------------------------------------------------------------------- 1 | >4AVA 2 | DGIAELTGARVEDLAGMDVFQGCPAEGLVSLAASVQPLRAAAGQVLLRQGEPAVSFLLISSGSAEVSHVGDDGVAIIARALPGMIVGEIALLRDSPRSATVTTIEPLTGWTGGRGAFATMVHIPGVGERLLRTARQRLAAFVSPIPVRLADGTQLMLRPVLPGDRERTVHGHIQFSGETLYRRFMSPALMHYLSEVDYVDHFVWVVTDGSDPVADARFVRDETDPTVAEIAFTVADAYQGRGIGSFLIGALSVAARVDGVERFAARMLSDNVPMRTIMDRYGAVWQREDVGVITTMIDVPGPGELSLGREMVDQINRVARQVIEAVG 3 | -------------------------------------------------------------------------------- /install_dependencies.sh: -------------------------------------------------------------------------------- 1 | 2 | # Python packages 3 | conda env create -f environment.yml 4 | 5 | wait 6 | source "$(conda info --base)/etc/profile.d/conda.sh" && conda activate cfold 7 | pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 8 | conda deactivate 9 | 10 | ## Get network parameters for Cfold 11 | wget https://zenodo.org/records/10517910/files/params10000.npy 12 | 13 | 14 | ## Uniclust30 15 | wget http://wwwuser.gwdg.de/~compbiol/uniclust/2018_08/uniclust30_2018_08_hhsuite.tar.gz --no-check-certificate 16 | mkdir data/uniclust30 17 | mv uniclust30_2018_08_hhsuite.tar.gz data/uniclust30 18 | tar -zxvf data/uniclust30/uniclust30_2018_08_hhsuite.tar.gz -C data/ # extract into data/ so that predict.sh finds ./data/uniclust30_2018_08 19 | 20 | 21 | ## HHblits 22 | git clone https://github.com/soedinglab/hh-suite.git 23 | mkdir -p hh-suite/build && cd hh-suite/build 24 | cmake -DCMAKE_INSTALL_PREFIX=. .. 25 | make -j 4 && make install 26 | cd ../.. # return to the repository root
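## Optional sanity check (editorial sketch, not part of the original script).
## File names below assume the standard HH-suite layout of the Uniclust30
## tarball (a uniclust30_2018_08/ directory) and the paths used by predict.sh.
test -x hh-suite/build/bin/hhblits || echo "WARNING: hhblits binary not found"
test -f params10000.npy || echo "WARNING: Cfold parameters (params10000.npy) not found"
test -f data/uniclust30_2018_08/uniclust30_2018_08_a3m.ffdata || echo "WARNING: Uniclust30 database not found"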
27 | -------------------------------------------------------------------------------- /pip_pkgs.txt: -------------------------------------------------------------------------------- 1 | pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 2 | pip install ml-collections==0.1.1 3 | pip install dm-haiku==0.0.11 4 | pip install pandas==1.3.5 5 | pip install biopython==1.81 6 | pip install chex==0.1.5 7 | pip install dm-tree==0.1.8 8 | pip install immutabledict==2.0.0 9 | pip install scipy 10 | pip install tensorflow 11 | pip install numpy==1.21.6 12 | -------------------------------------------------------------------------------- /predict.sh: -------------------------------------------------------------------------------- 1 | ## Search Uniclust30 with HHblits to generate an MSA 2 | 3 | ID=4AVA 4 | FASTA_DIR=./data/test/ 5 | UNICLUST=./data/uniclust30_2018_08/uniclust30_2018_08 6 | OUTDIR=./data/test/ 7 | ./hh-suite/build/bin/hhblits -i $FASTA_DIR/$ID.fasta -d $UNICLUST -E 0.001 -all -oa3m $OUTDIR/$ID'.a3m' 8 | 9 | 10 | ## MSA feats 11 | source "$(conda info --base)/etc/profile.d/conda.sh" && conda activate cfold 12 | MSA_DIR=./data/test/ 13 | OUTDIR=./data/test/ 14 | 15 | python3 ./src/make_msa_seq_feats.py --input_fasta_path $FASTA_DIR/$ID'.fasta' \ 16 | --input_msas $MSA_DIR/$ID'.a3m' --outdir $OUTDIR 17 | 18 | 19 | ## Predict 20 | FEATURE_DIR=./data/test/ 21 | PARAMS=./params10000.npy 22 | OUTDIR=./data/test/ 23 | NUM_REC=3 # Increase for hard targets 24 | NUM_SAMPLES=13 # Increase for hard targets 25 | 26 | python3 ./src/predict_with_clusters.py --feature_dir $FEATURE_DIR \ 27 | --predict_id $ID \ 28 | --ckpt_params $PARAMS \ 29 | --num_recycles $NUM_REC \ 30 | --num_samples_per_cluster $NUM_SAMPLES \ 31 | --outdir $OUTDIR/ 32 | -------------------------------------------------------------------------------- /src/alphafold/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/.DS_Store -------------------------------------------------------------------------------- /src/alphafold/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Common data types and constants used within Alphafold.""" 15 | -------------------------------------------------------------------------------- /src/alphafold/common/confidence.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions for processing confidence metrics.""" 16 | 17 | from typing import Dict, Optional, Tuple 18 | import numpy as np 19 | import scipy.special 20 | 21 | 22 | def compute_plddt(logits: np.ndarray) -> np.ndarray: 23 | """Computes per-residue pLDDT from logits. 24 | 25 | Args: 26 | logits: [num_res, num_bins] output from the PredictedLDDTHead. 27 | 28 | Returns: 29 | plddt: [num_res] per-residue pLDDT. 30 | """ 31 | num_bins = logits.shape[-1] 32 | bin_width = 1.0 / num_bins 33 | bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width) 34 | probs = scipy.special.softmax(logits, axis=-1) 35 | predicted_lddt_ca = np.sum(probs * bin_centers[None, :], axis=-1) 36 | return predicted_lddt_ca * 100 37 | 38 | 39 | def _calculate_bin_centers(breaks: np.ndarray): 40 | """Gets the bin centers from the bin edges. 41 | 42 | Args: 43 | breaks: [num_bins - 1] the error bin edges. 44 | 45 | Returns: 46 | bin_centers: [num_bins] the error bin centers. 47 | """ 48 | step = (breaks[1] - breaks[0]) 49 | 50 | # Add half-step to get the center 51 | bin_centers = breaks + step / 2 52 | # Add a catch-all bin at the end. 53 | bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]], 54 | axis=0) 55 | return bin_centers 56 | 57 | 58 | def _calculate_expected_aligned_error( 59 | alignment_confidence_breaks: np.ndarray, 60 | aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 61 | """Calculates expected aligned distance errors for every pair of residues. 62 | 63 | Args: 64 | alignment_confidence_breaks: [num_bins - 1] the error bin edges. 65 | aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted 66 | probs for each error bin, for each pair of residues. 67 | 68 | Returns: 69 | predicted_aligned_error: [num_res, num_res] the expected aligned distance 70 | error for each pair of residues. 71 | max_predicted_aligned_error: The maximum predicted error possible. 72 | """ 73 | bin_centers = _calculate_bin_centers(alignment_confidence_breaks) 74 | 75 | # Tuple of expected aligned distance error and max possible error. 76 | return (np.sum(aligned_distance_error_probs * bin_centers, axis=-1), 77 | np.asarray(bin_centers[-1])) 78 | 79 | 80 | def compute_predicted_aligned_error( 81 | logits: np.ndarray, 82 | breaks: np.ndarray) -> Dict[str, np.ndarray]: 83 | """Computes aligned confidence metrics from logits. 84 | 85 | Args: 86 | logits: [num_res, num_res, num_bins] the logits output from 87 | PredictedAlignedErrorHead. 88 | breaks: [num_bins - 1] the error bin edges. 89 | 90 | Returns: 91 | aligned_confidence_probs: [num_res, num_res, num_bins] the predicted 92 | aligned error probabilities over bins for each residue pair. 93 | predicted_aligned_error: [num_res, num_res] the expected aligned distance 94 | error for each pair of residues. 95 | max_predicted_aligned_error: The maximum predicted error possible. 
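Example (editorial sketch; `pae_logits` and `pae_breaks` are assumed to be
taken from the model output dict):
  out = compute_predicted_aligned_error(logits=pae_logits, breaks=pae_breaks)
  pae = out['predicted_aligned_error']  # [num_res, num_res], in Angstroms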
96 | """ 97 | aligned_confidence_probs = scipy.special.softmax( 98 | logits, 99 | axis=-1) 100 | predicted_aligned_error, max_predicted_aligned_error = ( 101 | _calculate_expected_aligned_error( 102 | alignment_confidence_breaks=breaks, 103 | aligned_distance_error_probs=aligned_confidence_probs)) 104 | return { 105 | 'aligned_confidence_probs': aligned_confidence_probs, 106 | 'predicted_aligned_error': predicted_aligned_error, 107 | 'max_predicted_aligned_error': max_predicted_aligned_error, 108 | } 109 | 110 | 111 | def predicted_tm_score( 112 | logits: np.ndarray, 113 | breaks: np.ndarray, 114 | residue_weights: Optional[np.ndarray] = None) -> np.ndarray: 115 | """Computes predicted TM alignment score. 116 | 117 | Args: 118 | logits: [num_res, num_res, num_bins] the logits output from 119 | PredictedAlignedErrorHead. 120 | breaks: [num_bins] the error bins. 121 | residue_weights: [num_res] the per residue weights to use for the 122 | expectation. 123 | 124 | Returns: 125 | ptm_score: the predicted TM alignment score. 126 | """ 127 | 128 | # residue_weights has to be in [0, 1], but can be floating-point, i.e. the 129 | # exp. resolved head's probability. 130 | if residue_weights is None: 131 | residue_weights = np.ones(logits.shape[0]) 132 | 133 | bin_centers = _calculate_bin_centers(breaks) 134 | 135 | num_res = np.sum(residue_weights) 136 | # Clip num_res to avoid negative/undefined d0. 137 | clipped_num_res = max(num_res, 19) 138 | 139 | # Compute d_0(num_res) as defined by TM-score, eqn. (5) in 140 | # http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf 141 | # Yang & Skolnick "Scoring function for automated 142 | # assessment of protein structure template quality" 2004 143 | d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8 144 | 145 | # Convert logits to probs 146 | probs = scipy.special.softmax(logits, axis=-1) 147 | 148 | # TM-Score term for every bin 149 | tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0)) 150 | # E_distances tm(distance) 151 | predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1) 152 | 153 | normed_residue_mask = residue_weights / (1e-8 + residue_weights.sum()) 154 | per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1) 155 | return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()]) 156 | -------------------------------------------------------------------------------- /src/alphafold/common/protein.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Protein data type.""" 16 | import dataclasses 17 | import io 18 | from typing import Any, Mapping, Optional 19 | from alphafold.common import residue_constants 20 | from Bio.PDB import PDBParser 21 | import numpy as np 22 | 23 | FeatureDict = Mapping[str, np.ndarray] 24 | ModelOutput = Mapping[str, Any] # Is a nested dict. 
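# Usage sketch (editorial note, not part of the original module): the helpers
# defined below round-trip a structure, e.g. for a single-chain PDB file:
#   with open('model.pdb') as f:
#     prot = from_pdb_string(f.read(), chain_id='A')
#   pdb_str = to_pdb(prot)  # prot.atom_positions has shape [num_res, 37, 3]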
25 | 26 | 27 | @dataclasses.dataclass(frozen=True) 28 | class Protein: 29 | """Protein structure representation.""" 30 | 31 | # Cartesian coordinates of atoms in angstroms. The atom types correspond to 32 | # residue_constants.atom_types, i.e. the first three are N, CA, C. 33 | atom_positions: np.ndarray # [num_res, num_atom_type, 3] 34 | 35 | # Amino-acid type for each residue represented as an integer between 0 and 36 | # 20, where 20 is 'X'. 37 | aatype: np.ndarray # [num_res] 38 | 39 | # Binary float mask to indicate presence of a particular atom. 1.0 if an atom 40 | # is present and 0.0 if not. This should be used for loss masking. 41 | atom_mask: np.ndarray # [num_res, num_atom_type] 42 | 43 | # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. 44 | residue_index: np.ndarray # [num_res] 45 | 46 | # B-factors, or temperature factors, of each residue (in squared angstroms), 47 | # representing the displacement of the residue from its ground truth mean 48 | # value. 49 | b_factors: np.ndarray # [num_res, num_atom_type] 50 | 51 | 52 | def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: 53 | """Takes a PDB string and constructs a Protein object. 54 | 55 | WARNING: All non-standard residue types will be converted into UNK. All 56 | non-standard atoms will be ignored. 57 | 58 | Args: 59 | pdb_str: The contents of the pdb file 60 | chain_id: If None, then the pdb file must contain a single chain (which 61 | will be parsed). If chain_id is specified (e.g. A), then only that chain 62 | is parsed. 63 | 64 | Returns: 65 | A new `Protein` parsed from the pdb contents. 66 | """ 67 | pdb_fh = io.StringIO(pdb_str) 68 | parser = PDBParser(QUIET=True) 69 | structure = parser.get_structure('none', pdb_fh) 70 | models = list(structure.get_models()) 71 | if len(models) != 1: 72 | raise ValueError( 73 | f'Only single model PDBs are supported. Found {len(models)} models.') 74 | model = models[0] 75 | 76 | if chain_id is not None: 77 | chain = model[chain_id] 78 | else: 79 | chains = list(model.get_chains()) 80 | if len(chains) != 1: 81 | raise ValueError( 82 | 'Only single chain PDBs are supported when chain_id not specified. ' 83 | f'Found {len(chains)} chains.') 84 | else: 85 | chain = chains[0] 86 | 87 | atom_positions = [] 88 | aatype = [] 89 | atom_mask = [] 90 | residue_index = [] 91 | b_factors = [] 92 | 93 | for res in chain: 94 | if res.id[2] != ' ': 95 | raise ValueError( 96 | f'PDB contains an insertion code at chain {chain.id} and residue ' 97 | f'index {res.id[1]}. These are not supported.') 98 | res_shortname = residue_constants.restype_3to1.get(res.resname, 'X') 99 | restype_idx = residue_constants.restype_order.get( 100 | res_shortname, residue_constants.restype_num) 101 | pos = np.zeros((residue_constants.atom_type_num, 3)) 102 | mask = np.zeros((residue_constants.atom_type_num,)) 103 | res_b_factors = np.zeros((residue_constants.atom_type_num,)) 104 | for atom in res: 105 | if atom.name not in residue_constants.atom_types: 106 | continue 107 | pos[residue_constants.atom_order[atom.name]] = atom.coord 108 | mask[residue_constants.atom_order[atom.name]] = 1. 109 | res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor 110 | if np.sum(mask) < 0.5: 111 | # If no known atom positions are reported for the residue then skip it.
112 | continue 113 | aatype.append(restype_idx) 114 | atom_positions.append(pos) 115 | atom_mask.append(mask) 116 | residue_index.append(res.id[1]) 117 | b_factors.append(res_b_factors) 118 | 119 | return Protein( 120 | atom_positions=np.array(atom_positions), 121 | atom_mask=np.array(atom_mask), 122 | aatype=np.array(aatype), 123 | residue_index=np.array(residue_index), 124 | b_factors=np.array(b_factors)) 125 | 126 | 127 | def to_pdb(prot: Protein) -> str: 128 | """Converts a `Protein` instance to a PDB string. 129 | 130 | Args: 131 | prot: The protein to convert to PDB. 132 | 133 | Returns: 134 | PDB string. 135 | """ 136 | restypes = residue_constants.restypes + ['X'] 137 | res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') 138 | atom_types = residue_constants.atom_types 139 | 140 | pdb_lines = [] 141 | 142 | atom_mask = prot.atom_mask 143 | aatype = prot.aatype 144 | atom_positions = prot.atom_positions 145 | residue_index = prot.residue_index.astype(np.int32) 146 | b_factors = prot.b_factors 147 | 148 | if np.any(aatype > residue_constants.restype_num): 149 | raise ValueError('Invalid aatypes.') 150 | 151 | pdb_lines.append('MODEL 1') 152 | atom_index = 1 153 | chain_id = 'A' 154 | # Add all atom sites. 155 | for i in range(aatype.shape[0]): 156 | res_name_3 = res_1to3(aatype[i]) 157 | for atom_name, pos, mask, b_factor in zip( 158 | atom_types, atom_positions[i], atom_mask[i], b_factors[i]): 159 | if mask < 0.5: 160 | continue 161 | 162 | record_type = 'ATOM' 163 | name = atom_name if len(atom_name) == 4 else f' {atom_name}' 164 | alt_loc = '' 165 | insertion_code = '' 166 | occupancy = 1.00 167 | element = atom_name[0] # Protein supports only C, N, O, S, this works. 168 | charge = '' 169 | # PDB is a columnar format, every space matters here! 170 | atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' 171 | f'{res_name_3:>3} {chain_id:>1}' 172 | f'{residue_index[i]:>4}{insertion_code:>1} ' 173 | f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' 174 | f'{occupancy:>6.2f}{b_factor:>6.2f} ' 175 | f'{element:>2}{charge:>2}') 176 | pdb_lines.append(atom_line) 177 | atom_index += 1 178 | 179 | # Close the chain. 180 | chain_end = 'TER' 181 | chain_termination_line = ( 182 | f'{chain_end:<6}{atom_index:>5} {res_1to3(aatype[-1]):>3} ' 183 | f'{chain_id:>1}{residue_index[-1]:>4}') 184 | pdb_lines.append(chain_termination_line) 185 | pdb_lines.append('ENDMDL') 186 | 187 | pdb_lines.append('END') 188 | pdb_lines.append('') 189 | return '\n'.join(pdb_lines) 190 | 191 | 192 | def ideal_atom_mask(prot: Protein) -> np.ndarray: 193 | """Computes an ideal atom mask. 194 | 195 | `Protein.atom_mask` typically is defined according to the atoms that are 196 | reported in the PDB. This function computes a mask according to heavy atoms 197 | that should be present in the given sequence of amino acids. 198 | 199 | Args: 200 | prot: `Protein` whose fields are `numpy.ndarray` objects. 201 | 202 | Returns: 203 | An ideal atom mask. 204 | """ 205 | return residue_constants.STANDARD_ATOM_MASK[prot.aatype] 206 | 207 | 208 | def from_prediction(features: FeatureDict, result: ModelOutput, 209 | b_factors: Optional[np.ndarray] = None) -> Protein: 210 | """Assembles a protein from a prediction. 211 | 212 | Args: 213 | features: Dictionary holding model inputs. 214 | result: Dictionary holding model outputs. 215 | b_factors: (Optional) B-factors to use for the protein. 216 | 217 | Returns: 218 | A protein instance. 
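Example (editorial sketch; `features` and `result` are the model input and
output dicts described above):
  prot = from_prediction(features, result)
  pdb_str = to_pdb(prot)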
219 | """ 220 | fold_output = result['structure_module'] 221 | if b_factors is None: 222 | b_factors = np.zeros_like(fold_output['final_atom_mask']) 223 | 224 | return Protein( 225 | aatype=features['aatype'][0], 226 | atom_positions=fold_output['final_atom_positions'], 227 | atom_mask=fold_output['final_atom_mask'], 228 | residue_index=features['residue_index'][0] + 1, 229 | b_factors=b_factors) 230 | -------------------------------------------------------------------------------- /src/alphafold/common/protein_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for protein.""" 16 | 17 | import os 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from alphafold.common import protein 22 | from alphafold.common import residue_constants 23 | import numpy as np 24 | # Internal import (7716). 25 | 26 | TEST_DATA_DIR = 'alphafold/common/testdata/' 27 | 28 | 29 | class ProteinTest(parameterized.TestCase): 30 | 31 | def _check_shapes(self, prot, num_res): 32 | """Check that the processed shapes are correct.""" 33 | num_atoms = residue_constants.atom_type_num 34 | self.assertEqual((num_res, num_atoms, 3), prot.atom_positions.shape) 35 | self.assertEqual((num_res,), prot.aatype.shape) 36 | self.assertEqual((num_res, num_atoms), prot.atom_mask.shape) 37 | self.assertEqual((num_res,), prot.residue_index.shape) 38 | self.assertEqual((num_res, num_atoms), prot.b_factors.shape) 39 | 40 | @parameterized.parameters(('2rbg.pdb', 'A', 282), 41 | ('2rbg.pdb', 'B', 282)) 42 | def test_from_pdb_str(self, pdb_file, chain_id, num_res): 43 | pdb_file = os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR, 44 | pdb_file) 45 | with open(pdb_file) as f: 46 | pdb_string = f.read() 47 | prot = protein.from_pdb_string(pdb_string, chain_id) 48 | self._check_shapes(prot, num_res) 49 | self.assertGreaterEqual(prot.aatype.min(), 0) 50 | # Allow equal since unknown restypes have index equal to restype_num. 
51 | self.assertLessEqual(prot.aatype.max(), residue_constants.restype_num) 52 | 53 | def test_to_pdb(self): 54 | with open( 55 | os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR, 56 | '2rbg.pdb')) as f: 57 | pdb_string = f.read() 58 | prot = protein.from_pdb_string(pdb_string, chain_id='A') 59 | pdb_string_reconstr = protein.to_pdb(prot) 60 | prot_reconstr = protein.from_pdb_string(pdb_string_reconstr) 61 | 62 | np.testing.assert_array_equal(prot_reconstr.aatype, prot.aatype) 63 | np.testing.assert_array_almost_equal( 64 | prot_reconstr.atom_positions, prot.atom_positions) 65 | np.testing.assert_array_almost_equal( 66 | prot_reconstr.atom_mask, prot.atom_mask) 67 | np.testing.assert_array_equal( 68 | prot_reconstr.residue_index, prot.residue_index) 69 | np.testing.assert_array_almost_equal( 70 | prot_reconstr.b_factors, prot.b_factors) 71 | 72 | def test_ideal_atom_mask(self): 73 | with open( 74 | os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR, 75 | '2rbg.pdb')) as f: 76 | pdb_string = f.read() 77 | prot = protein.from_pdb_string(pdb_string, chain_id='A') 78 | ideal_mask = protein.ideal_atom_mask(prot) 79 | non_ideal_residues = set([102] + list(range(127, 285))) 80 | for i, (res, atom_mask) in enumerate( 81 | zip(prot.residue_index, prot.atom_mask)): 82 | if res in non_ideal_residues: 83 | self.assertFalse(np.all(atom_mask == ideal_mask[i]), msg=f'{res}') 84 | else: 85 | self.assertTrue(np.all(atom_mask == ideal_mask[i]), msg=f'{res}') 86 | 87 | 88 | if __name__ == '__main__': 89 | absltest.main() 90 | -------------------------------------------------------------------------------- /src/alphafold/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Data pipeline for model features.""" 15 | -------------------------------------------------------------------------------- /src/alphafold/data/pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Functions for building the input features for the AlphaFold model.""" 16 | 17 | import os 18 | from typing import Mapping, Optional, Sequence 19 | from absl import logging 20 | from alphafold.common import residue_constants 21 | from alphafold.data import parsers 22 | from alphafold.data import templates 23 | from alphafold.data.tools import hhblits 24 | from alphafold.data.tools import hhsearch 25 | from alphafold.data.tools import jackhmmer 26 | import numpy as np 27 | 28 | # Internal import (7716). 29 | 30 | FeatureDict = Mapping[str, np.ndarray] 31 | 32 | 33 | def make_sequence_features( 34 | sequence: str, description: str, num_res: int) -> FeatureDict: 35 | """Constructs a feature dict of sequence features.""" 36 | features = {} 37 | features['aatype'] = residue_constants.sequence_to_onehot( 38 | sequence=sequence, 39 | mapping=residue_constants.restype_order_with_x, 40 | map_unknown_to_x=True) 41 | features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32) 42 | features['domain_name'] = np.array([description.encode('utf-8')], 43 | dtype=np.object_) 44 | features['residue_index'] = np.array(range(num_res), dtype=np.int32) 45 | features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32) 46 | features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_) 47 | return features 48 | 49 | 50 | def make_msa_features( 51 | msas: Sequence[Sequence[str]], 52 | deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict: 53 | """Constructs a feature dict of MSA features.""" 54 | if not msas: 55 | raise ValueError('At least one MSA must be provided.') 56 | 57 | int_msa = [] 58 | deletion_matrix = [] 59 | seen_sequences = set() 60 | for msa_index, msa in enumerate(msas): 61 | if not msa: 62 | raise ValueError(f'MSA {msa_index} must contain at least one sequence.') 63 | for sequence_index, sequence in enumerate(msa): 64 | if sequence in seen_sequences: 65 | continue 66 | seen_sequences.add(sequence) 67 | int_msa.append( 68 | [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) 69 | deletion_matrix.append(deletion_matrices[msa_index][sequence_index]) 70 | 71 | num_res = len(msas[0][0]) 72 | num_alignments = len(int_msa) 73 | features = {} 74 | features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) 75 | features['msa'] = np.array(int_msa, dtype=np.int32) 76 | features['num_alignments'] = np.array( 77 | [num_alignments] * num_res, dtype=np.int32) 78 | return features 79 | 80 | 81 | class DataPipeline: 82 | """Runs the alignment tools and assembles the input features.""" 83 | 84 | def __init__(self, 85 | jackhmmer_binary_path: str, 86 | hhblits_binary_path: str, 87 | hhsearch_binary_path: str, 88 | uniref90_database_path: str, 89 | mgnify_database_path: str, 90 | bfd_database_path: Optional[str], 91 | uniclust30_database_path: Optional[str], 92 | small_bfd_database_path: Optional[str], 93 | pdb70_database_path: str, 94 | template_featurizer: templates.TemplateHitFeaturizer, 95 | use_small_bfd: bool, 96 | mgnify_max_hits: int = 501, 97 | uniref_max_hits: int = 10000): 98 | """Initializes the data pipeline.""" 99 | self._use_small_bfd = use_small_bfd 100 | self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( 101 | binary_path=jackhmmer_binary_path, 102 | database_path=uniref90_database_path) 103 | if use_small_bfd: 104 | self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer( 105 | binary_path=jackhmmer_binary_path, 106 | database_path=small_bfd_database_path) 107 | else: 108 |
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits( 109 | binary_path=hhblits_binary_path, 110 | databases=[bfd_database_path, uniclust30_database_path]) 111 | self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer( 112 | binary_path=jackhmmer_binary_path, 113 | database_path=mgnify_database_path) 114 | self.hhsearch_pdb70_runner = hhsearch.HHSearch( 115 | binary_path=hhsearch_binary_path, 116 | databases=[pdb70_database_path]) 117 | self.template_featurizer = template_featurizer 118 | self.mgnify_max_hits = mgnify_max_hits 119 | self.uniref_max_hits = uniref_max_hits 120 | 121 | def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: 122 | """Runs alignment tools on the input sequence and creates features.""" 123 | with open(input_fasta_path) as f: 124 | input_fasta_str = f.read() 125 | input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) 126 | if len(input_seqs) != 1: 127 | raise ValueError( 128 | f'More than one input sequence found in {input_fasta_path}.') 129 | input_sequence = input_seqs[0] 130 | input_description = input_descs[0] 131 | num_res = len(input_sequence) 132 | 133 | jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query( 134 | input_fasta_path)[0] 135 | jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query( 136 | input_fasta_path)[0] 137 | 138 | uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( 139 | jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits) 140 | hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m) 141 | 142 | uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') 143 | with open(uniref90_out_path, 'w') as f: 144 | f.write(jackhmmer_uniref90_result['sto']) 145 | 146 | mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') 147 | with open(mgnify_out_path, 'w') as f: 148 | f.write(jackhmmer_mgnify_result['sto']) 149 | 150 | pdb70_out_path = os.path.join(msa_output_dir, 'pdb70_hits.hhr') 151 | with open(pdb70_out_path, 'w') as f: 152 | f.write(hhsearch_result) 153 | 154 | uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm( 155 | jackhmmer_uniref90_result['sto']) 156 | mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm( 157 | jackhmmer_mgnify_result['sto']) 158 | hhsearch_hits = parsers.parse_hhr(hhsearch_result) 159 | mgnify_msa = mgnify_msa[:self.mgnify_max_hits] 160 | mgnify_deletion_matrix = mgnify_deletion_matrix[:self.mgnify_max_hits] 161 | 162 | if self._use_small_bfd: 163 | jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query( 164 | input_fasta_path)[0] 165 | 166 | bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.a3m') 167 | with open(bfd_out_path, 'w') as f: 168 | f.write(jackhmmer_small_bfd_result['sto']) 169 | 170 | bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm( 171 | jackhmmer_small_bfd_result['sto']) 172 | else: 173 | hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query( 174 | input_fasta_path) 175 | 176 | bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m') 177 | with open(bfd_out_path, 'w') as f: 178 | f.write(hhblits_bfd_uniclust_result['a3m']) 179 | 180 | bfd_msa, bfd_deletion_matrix = parsers.parse_a3m( 181 | hhblits_bfd_uniclust_result['a3m']) 182 | 183 | templates_result = self.template_featurizer.get_templates( 184 | query_sequence=input_sequence, 185 | query_pdb_code=None, 186 | query_release_date=None, 187 | hits=hhsearch_hits) 188 | 189 | sequence_features = make_sequence_features( 190 | sequence=input_sequence, 191 | 
description=input_description, 192 | num_res=num_res) 193 | 194 | msa_features = make_msa_features( 195 | msas=(uniref90_msa, bfd_msa, mgnify_msa), 196 | deletion_matrices=(uniref90_deletion_matrix, 197 | bfd_deletion_matrix, 198 | mgnify_deletion_matrix)) 199 | 200 | logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa)) 201 | logging.info('BFD MSA size: %d sequences.', len(bfd_msa)) 202 | logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa)) 203 | logging.info('Final (deduplicated) MSA size: %d sequences.', 204 | msa_features['num_alignments'][0]) 205 | logging.info('Total number of templates (NB: this can include bad ' 206 | 'templates and is later filtered to top 4): %d.', 207 | templates_result.features['template_domain_names'].shape[0]) 208 | 209 | return {**sequence_features, **msa_features, **templates_result.features} 210 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Python wrappers for third party tools.""" 15 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/hhblits.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library to run HHblits from Python.""" 16 | 17 | import glob 18 | import os 19 | import subprocess 20 | from typing import Any, Mapping, Optional, Sequence 21 | 22 | from absl import logging 23 | from alphafold.data.tools import utils 24 | # Internal import (7716). 25 | 26 | 27 | _HHBLITS_DEFAULT_P = 20 28 | _HHBLITS_DEFAULT_Z = 500 29 | 30 | 31 | class HHBlits: 32 | """Python wrapper of the HHblits binary.""" 33 | 34 | def __init__(self, 35 | *, 36 | binary_path: str, 37 | databases: Sequence[str], 38 | n_cpu: int = 4, 39 | n_iter: int = 3, 40 | e_value: float = 0.001, 41 | maxseq: int = 1_000_000, 42 | realign_max: int = 100_000, 43 | maxfilt: int = 100_000, 44 | min_prefilter_hits: int = 1000, 45 | all_seqs: bool = False, 46 | alt: Optional[int] = None, 47 | p: int = _HHBLITS_DEFAULT_P, 48 | z: int = _HHBLITS_DEFAULT_Z): 49 | """Initializes the Python HHblits wrapper. 
50 | 51 | Args: 52 | binary_path: The path to the HHblits executable. 53 | databases: A sequence of HHblits database paths. This should be the 54 | common prefix for the database files (i.e. up to but not including 55 | _hhm.ffindex etc.) 56 | n_cpu: The number of CPUs to give HHblits. 57 | n_iter: The number of HHblits iterations. 58 | e_value: The E-value, see HHblits docs for more details. 59 | maxseq: The maximum number of rows in an input alignment. Note that this 60 | parameter is only supported in HHBlits version 3.1 and higher. 61 | realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. 62 | maxfilt: Max number of hits allowed to pass the 2nd prefilter. 63 | HHblits default: 20000. 64 | min_prefilter_hits: Min number of hits to pass prefilter. 65 | HHblits default: 100. 66 | all_seqs: Return all sequences in the MSA / Do not filter the result MSA. 67 | HHblits default: False. 68 | alt: Show up to this many alternative alignments. 69 | p: Minimum Prob for a hit to be included in the output hhr file. 70 | HHblits default: 20. 71 | z: Hard cap on number of hits reported in the hhr file. 72 | HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. 73 | 74 | Raises: 75 | RuntimeError: If HHblits binary not found within the path. 76 | """ 77 | self.binary_path = binary_path 78 | self.databases = databases 79 | 80 | for database_path in self.databases: 81 | if not glob.glob(database_path + '_*'): 82 | logging.error('Could not find HHBlits database %s', database_path) 83 | raise ValueError(f'Could not find HHBlits database {database_path}') 84 | 85 | self.n_cpu = n_cpu 86 | self.n_iter = n_iter 87 | self.e_value = e_value 88 | self.maxseq = maxseq 89 | self.realign_max = realign_max 90 | self.maxfilt = maxfilt 91 | self.min_prefilter_hits = min_prefilter_hits 92 | self.all_seqs = all_seqs 93 | self.alt = alt 94 | self.p = p 95 | self.z = z 96 | 97 | def query(self, input_fasta_path: str) -> Mapping[str, Any]: 98 | """Queries the database using HHblits.""" 99 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: 100 | a3m_path = os.path.join(query_tmp_dir, 'output.a3m') 101 | 102 | db_cmd = [] 103 | for db_path in self.databases: 104 | db_cmd.append('-d') 105 | db_cmd.append(db_path) 106 | cmd = [ 107 | self.binary_path, 108 | '-i', input_fasta_path, 109 | '-cpu', str(self.n_cpu), 110 | '-oa3m', a3m_path, 111 | '-o', '/dev/null', 112 | '-n', str(self.n_iter), 113 | '-e', str(self.e_value), 114 | '-maxseq', str(self.maxseq), 115 | '-realign_max', str(self.realign_max), 116 | '-maxfilt', str(self.maxfilt), 117 | '-min_prefilter_hits', str(self.min_prefilter_hits)] 118 | if self.all_seqs: 119 | cmd += ['-all'] 120 | if self.alt: 121 | cmd += ['-alt', str(self.alt)] 122 | if self.p != _HHBLITS_DEFAULT_P: 123 | cmd += ['-p', str(self.p)] 124 | if self.z != _HHBLITS_DEFAULT_Z: 125 | cmd += ['-Z', str(self.z)] 126 | cmd += db_cmd 127 | 128 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 129 | process = subprocess.Popen( 130 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 131 | 132 | with utils.timing('HHblits query'): 133 | stdout, stderr = process.communicate() 134 | retcode = process.wait() 135 | 136 | if retcode: 137 | # Logs have a 15k character limit, so log HHblits error line by line. 138 | logging.error('HHblits failed. 
HHblits stderr begin:') 139 | for error_line in stderr.decode('utf-8').splitlines(): 140 | if error_line.strip(): 141 | logging.error(error_line.strip()) 142 | logging.error('HHblits stderr end') 143 | raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % ( 144 | stdout.decode('utf-8'), stderr[:500_000].decode('utf-8'))) 145 | 146 | with open(a3m_path) as f: 147 | a3m = f.read() 148 | 149 | raw_output = dict( 150 | a3m=a3m, 151 | output=stdout, 152 | stderr=stderr, 153 | n_iter=self.n_iter, 154 | e_value=self.e_value) 155 | return raw_output 156 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/hhsearch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library to run HHsearch from Python.""" 16 | 17 | import glob 18 | import os 19 | import subprocess 20 | from typing import Sequence 21 | 22 | from absl import logging 23 | 24 | from alphafold.data.tools import utils 25 | # Internal import (7716). 26 | 27 | 28 | class HHSearch: 29 | """Python wrapper of the HHsearch binary.""" 30 | 31 | def __init__(self, 32 | *, 33 | binary_path: str, 34 | databases: Sequence[str], 35 | maxseq: int = 1_000_000): 36 | """Initializes the Python HHsearch wrapper. 37 | 38 | Args: 39 | binary_path: The path to the HHsearch executable. 40 | databases: A sequence of HHsearch database paths. This should be the 41 | common prefix for the database files (i.e. up to but not including 42 | _hhm.ffindex etc.) 43 | maxseq: The maximum number of rows in an input alignment. Note that this 44 | parameter is only supported in HHBlits version 3.1 and higher. 45 | 46 | Raises: 47 | RuntimeError: If HHsearch binary not found within the path. 
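Example (editorial sketch; paths are placeholders):
  searcher = HHSearch(binary_path='/usr/bin/hhsearch',
                      databases=['/data/pdb70/pdb70'])
  hhr = searcher.query(a3m_string)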
48 | """ 49 | self.binary_path = binary_path 50 | self.databases = databases 51 | self.maxseq = maxseq 52 | 53 | for database_path in self.databases: 54 | if not glob.glob(database_path + '_*'): 55 | logging.error('Could not find HHsearch database %s', database_path) 56 | raise ValueError(f'Could not find HHsearch database {database_path}') 57 | 58 | def query(self, a3m: str) -> str: 59 | """Queries the database using HHsearch using a given a3m.""" 60 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: 61 | input_path = os.path.join(query_tmp_dir, 'query.a3m') 62 | hhr_path = os.path.join(query_tmp_dir, 'output.hhr') 63 | with open(input_path, 'w') as f: 64 | f.write(a3m) 65 | 66 | db_cmd = [] 67 | for db_path in self.databases: 68 | db_cmd.append('-d') 69 | db_cmd.append(db_path) 70 | cmd = [self.binary_path, 71 | '-i', input_path, 72 | '-o', hhr_path, 73 | '-maxseq', str(self.maxseq) 74 | ] + db_cmd 75 | 76 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 77 | process = subprocess.Popen( 78 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 79 | with utils.timing('HHsearch query'): 80 | stdout, stderr = process.communicate() 81 | retcode = process.wait() 82 | 83 | if retcode: 84 | # Stderr is truncated to prevent proto size errors in Beam. 85 | raise RuntimeError( 86 | 'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % ( 87 | stdout.decode('utf-8'), stderr[:100_000].decode('utf-8'))) 88 | 89 | with open(hhr_path) as f: 90 | hhr = f.read() 91 | return hhr 92 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/hmmbuild.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A Python wrapper for hmmbuild - construct HMM profiles from MSA.""" 16 | 17 | import os 18 | import re 19 | import subprocess 20 | 21 | from absl import logging 22 | from alphafold.data.tools import utils 23 | # Internal import (7716). 24 | 25 | 26 | class Hmmbuild(object): 27 | """Python wrapper of the hmmbuild binary.""" 28 | 29 | def __init__(self, 30 | *, 31 | binary_path: str, 32 | singlemx: bool = False): 33 | """Initializes the Python hmmbuild wrapper. 34 | 35 | Args: 36 | binary_path: The path to the hmmbuild executable. 37 | singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to 38 | just use a common substitution score matrix. 39 | 40 | Raises: 41 | RuntimeError: If hmmbuild binary not found within the path. 42 | """ 43 | self.binary_path = binary_path 44 | self.singlemx = singlemx 45 | 46 | def build_profile_from_sto(self, sto: str, model_construction='fast') -> str: 47 | """Builds an HMM for the aligned sequences given as a Stockholm string. 48 | 49 | Args: 50 | sto: A string with the aligned sequences in the Stockholm format.
51 | model_construction: Whether to use reference annotation in the msa to 52 | determine consensus columns ('hand') or default ('fast'). 53 | 54 | Returns: 55 | A string with the profile in the HMM format. 56 | 57 | Raises: 58 | RuntimeError: If hmmbuild fails. 59 | """ 60 | return self._build_profile(sto, model_construction=model_construction) 61 | 62 | def build_profile_from_a3m(self, a3m: str) -> str: 63 | """Builds an HMM for the aligned sequences given as an A3M string. 64 | 65 | Args: 66 | a3m: A string with the aligned sequences in the A3M format. 67 | 68 | Returns: 69 | A string with the profile in the HMM format. 70 | 71 | Raises: 72 | RuntimeError: If hmmbuild fails. 73 | """ 74 | lines = [] 75 | for line in a3m.splitlines(): 76 | if not line.startswith('>'): 77 | line = re.sub('[a-z]+', '', line) # Remove inserted residues. 78 | lines.append(line + '\n') 79 | msa = ''.join(lines) 80 | return self._build_profile(msa, model_construction='fast') 81 | 82 | def _build_profile(self, msa: str, model_construction: str = 'fast') -> str: 83 | """Builds an HMM for the aligned sequences given as an MSA string. 84 | 85 | Args: 86 | msa: A string with the aligned sequences, in A3M or STO format. 87 | model_construction: Whether to use reference annotation in the msa to 88 | determine consensus columns ('hand') or default ('fast'). 89 | 90 | Returns: 91 | A string with the profile in the HMM format. 92 | 93 | Raises: 94 | RuntimeError: If hmmbuild fails. 95 | ValueError: If unspecified arguments are provided. 96 | """ 97 | if model_construction not in {'hand', 'fast'}: 98 | raise ValueError(f'Invalid model_construction {model_construction} - only ' 99 | 'hand and fast supported.') 100 | 101 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: 102 | input_query = os.path.join(query_tmp_dir, 'query.msa') 103 | output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm') 104 | 105 | with open(input_query, 'w') as f: 106 | f.write(msa) 107 | 108 | cmd = [self.binary_path] 109 | # If adding flags, we have to do so before the output and input: 110 | 111 | if model_construction == 'hand': 112 | cmd.append(f'--{model_construction}') 113 | if self.singlemx: 114 | cmd.append('--singlemx') 115 | cmd.extend([ 116 | '--amino', 117 | output_hmm_path, 118 | input_query, 119 | ]) 120 | 121 | logging.info('Launching subprocess %s', cmd) 122 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, 123 | stderr=subprocess.PIPE) 124 | 125 | with utils.timing('hmmbuild query'): 126 | stdout, stderr = process.communicate() 127 | retcode = process.wait() 128 | logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n', 129 | stdout.decode('utf-8'), stderr.decode('utf-8')) 130 | 131 | if retcode: 132 | raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n' 133 | % (stdout.decode('utf-8'), stderr.decode('utf-8'))) 134 | 135 | with open(output_hmm_path, encoding='utf-8') as f: 136 | hmm = f.read() 137 | 138 | return hmm 139 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/hmmsearch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A Python wrapper for hmmsearch - search profile against a sequence db.""" 16 | 17 | import os 18 | import subprocess 19 | from typing import Optional, Sequence 20 | 21 | from absl import logging 22 | from alphafold.data.tools import utils 23 | # Internal import (7716). 24 | 25 | 26 | class Hmmsearch(object): 27 | """Python wrapper of the hmmsearch binary.""" 28 | 29 | def __init__(self, 30 | *, 31 | binary_path: str, 32 | database_path: str, 33 | flags: Optional[Sequence[str]] = None): 34 | """Initializes the Python hmmsearch wrapper. 35 | 36 | Args: 37 | binary_path: The path to the hmmsearch executable. 38 | database_path: The path to the hmmsearch database (FASTA format). 39 | flags: List of flags to be used by hmmsearch. 40 | 41 | Raises: 42 | RuntimeError: If hmmsearch binary not found within the path. 43 | """ 44 | self.binary_path = binary_path 45 | self.database_path = database_path 46 | self.flags = flags 47 | 48 | if not os.path.exists(self.database_path): 49 | logging.error('Could not find hmmsearch database %s', database_path) 50 | raise ValueError(f'Could not find hmmsearch database {database_path}') 51 | 52 | def query(self, hmm: str) -> str: 53 | """Queries the database using hmmsearch using a given hmm.""" 54 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: 55 | hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') 56 | a3m_out_path = os.path.join(query_tmp_dir, 'output.a3m') 57 | with open(hmm_input_path, 'w') as f: 58 | f.write(hmm) 59 | 60 | cmd = [ 61 | self.binary_path, 62 | '--noali', # Don't include the alignment in stdout. 63 | '--cpu', '8' 64 | ] 65 | # If adding flags, we have to do so before the output and input: 66 | if self.flags: 67 | cmd.extend(self.flags) 68 | cmd.extend([ 69 | '-A', a3m_out_path, 70 | hmm_input_path, 71 | self.database_path, 72 | ]) 73 | 74 | logging.info('Launching sub-process %s', cmd) 75 | process = subprocess.Popen( 76 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 77 | with utils.timing( 78 | f'hmmsearch ({os.path.basename(self.database_path)}) query'): 79 | stdout, stderr = process.communicate() 80 | retcode = process.wait() 81 | 82 | if retcode: 83 | raise RuntimeError( 84 | 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % ( 85 | stdout.decode('utf-8'), stderr.decode('utf-8'))) 86 | 87 | with open(a3m_out_path) as f: 88 | a3m_out = f.read() 89 | 90 | return a3m_out 91 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/jackhmmer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library to run Jackhmmer from Python.""" 16 | 17 | from concurrent import futures 18 | import glob 19 | import os 20 | import subprocess 21 | from typing import Any, Callable, Mapping, Optional, Sequence 22 | from urllib import request 23 | 24 | from absl import logging 25 | 26 | from alphafold.data.tools import utils 27 | # Internal import (7716). 28 | 29 | 30 | class Jackhmmer: 31 | """Python wrapper of the Jackhmmer binary.""" 32 | 33 | def __init__(self, 34 | *, 35 | binary_path: str, 36 | database_path: str, 37 | n_cpu: int = 8, 38 | n_iter: int = 1, 39 | e_value: float = 0.0001, 40 | z_value: Optional[int] = None, 41 | get_tblout: bool = False, 42 | filter_f1: float = 0.0005, 43 | filter_f2: float = 0.00005, 44 | filter_f3: float = 0.0000005, 45 | incdom_e: Optional[float] = None, 46 | dom_e: Optional[float] = None, 47 | num_streamed_chunks: Optional[int] = None, 48 | streaming_callback: Optional[Callable[[int], None]] = None): 49 | """Initializes the Python Jackhmmer wrapper. 50 | 51 | Args: 52 | binary_path: The path to the jackhmmer executable. 53 | database_path: The path to the jackhmmer database (FASTA format). 54 | n_cpu: The number of CPUs to give Jackhmmer. 55 | n_iter: The number of Jackhmmer iterations. 56 | e_value: The E-value, see Jackhmmer docs for more details. 57 | z_value: The Z-value, see Jackhmmer docs for more details. 58 | get_tblout: Whether to save tblout string. 59 | filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. 60 | filter_f2: Viterbi pre-filter, set to >1.0 to turn off. 61 | filter_f3: Forward pre-filter, set to >1.0 to turn off. 62 | incdom_e: Domain e-value criteria for inclusion of domains in MSA/next 63 | round. 64 | dom_e: Domain e-value criteria for inclusion in tblout. 65 | num_streamed_chunks: Number of database chunks to stream over. 66 | streaming_callback: Callback function run after each chunk iteration with 67 | the iteration number as argument. 
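Example (editorial sketch; paths are placeholders):
  jh = Jackhmmer(binary_path='/usr/bin/jackhmmer',
                 database_path='/data/uniref90/uniref90.fasta')
  results = jh.query('query.fasta')  # one result dict per database chunk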
68 | """ 69 | self.binary_path = binary_path 70 | self.database_path = database_path 71 | self.num_streamed_chunks = num_streamed_chunks 72 | 73 | if not os.path.exists(self.database_path) and num_streamed_chunks is None: 74 | logging.error('Could not find Jackhmmer database %s', database_path) 75 | raise ValueError(f'Could not find Jackhmmer database {database_path}') 76 | 77 | self.n_cpu = n_cpu 78 | self.n_iter = n_iter 79 | self.e_value = e_value 80 | self.z_value = z_value 81 | self.filter_f1 = filter_f1 82 | self.filter_f2 = filter_f2 83 | self.filter_f3 = filter_f3 84 | self.incdom_e = incdom_e 85 | self.dom_e = dom_e 86 | self.get_tblout = get_tblout 87 | self.streaming_callback = streaming_callback 88 | 89 | def _query_chunk(self, input_fasta_path: str, database_path: str 90 | ) -> Mapping[str, Any]: 91 | """Queries the database chunk using Jackhmmer.""" 92 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: 93 | sto_path = os.path.join(query_tmp_dir, 'output.sto') 94 | 95 | # The F1/F2/F3 are the expected proportion to pass each of the filtering 96 | # stages (which get progressively more expensive), reducing these 97 | # speeds up the pipeline at the expensive of sensitivity. They are 98 | # currently set very low to make querying Mgnify run in a reasonable 99 | # amount of time. 100 | cmd_flags = [ 101 | # Don't pollute stdout with Jackhmmer output. 102 | '-o', '/dev/null', 103 | '-A', sto_path, 104 | '--noali', 105 | '--F1', str(self.filter_f1), 106 | '--F2', str(self.filter_f2), 107 | '--F3', str(self.filter_f3), 108 | '--incE', str(self.e_value), 109 | # Report only sequences with E-values <= x in per-sequence output. 110 | '-E', str(self.e_value), 111 | '--cpu', str(self.n_cpu), 112 | '-N', str(self.n_iter) 113 | ] 114 | if self.get_tblout: 115 | tblout_path = os.path.join(query_tmp_dir, 'tblout.txt') 116 | cmd_flags.extend(['--tblout', tblout_path]) 117 | 118 | if self.z_value: 119 | cmd_flags.extend(['-Z', str(self.z_value)]) 120 | 121 | if self.dom_e is not None: 122 | cmd_flags.extend(['--domE', str(self.dom_e)]) 123 | 124 | if self.incdom_e is not None: 125 | cmd_flags.extend(['--incdomE', str(self.incdom_e)]) 126 | 127 | cmd = [self.binary_path] + cmd_flags + [input_fasta_path, 128 | database_path] 129 | 130 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 131 | process = subprocess.Popen( 132 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 133 | with utils.timing( 134 | f'Jackhmmer ({os.path.basename(database_path)}) query'): 135 | _, stderr = process.communicate() 136 | retcode = process.wait() 137 | 138 | if retcode: 139 | raise RuntimeError( 140 | 'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8')) 141 | 142 | # Get e-values for each target name 143 | tbl = '' 144 | if self.get_tblout: 145 | with open(tblout_path) as f: 146 | tbl = f.read() 147 | 148 | with open(sto_path) as f: 149 | sto = f.read() 150 | 151 | raw_output = dict( 152 | sto=sto, 153 | tbl=tbl, 154 | stderr=stderr, 155 | n_iter=self.n_iter, 156 | e_value=self.e_value) 157 | 158 | return raw_output 159 | 160 | def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]: 161 | """Queries the database using Jackhmmer.""" 162 | if self.num_streamed_chunks is None: 163 | return [self._query_chunk(input_fasta_path, self.database_path)] 164 | 165 | db_basename = os.path.basename(self.database_path) 166 | db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}' 167 | db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}' 168 | 169 | # 
Remove any existing chunk files to avoid running out of memory on the ramdisk. 170 | for f in glob.glob(db_local_chunk('[0-9]*')): 171 | try: 172 | os.remove(f) 173 | except OSError: 174 | logging.error('OSError while deleting %s', f) 175 | 176 | # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk. 177 | with futures.ThreadPoolExecutor(max_workers=2) as executor: 178 | chunked_output = [] 179 | for i in range(1, self.num_streamed_chunks + 1): 180 | # Copy the chunk locally 181 | if i == 1: 182 | future = executor.submit( 183 | request.urlretrieve, db_remote_chunk(i), db_local_chunk(i)) 184 | if i < self.num_streamed_chunks: 185 | next_future = executor.submit( 186 | request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1)) 187 | 188 | # Run Jackhmmer with the chunk 189 | future.result() 190 | chunked_output.append( 191 | self._query_chunk(input_fasta_path, db_local_chunk(i))) 192 | 193 | # Remove the local copy of the chunk 194 | os.remove(db_local_chunk(i)) 195 | # next_future is only defined while another chunk remains to be fetched. 196 | if i < self.num_streamed_chunks: 197 | future = next_future 198 | if self.streaming_callback: 199 | self.streaming_callback(i) 200 | return chunked_output 201 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/kalign.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #      http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A Python wrapper for Kalign.""" 16 | import os 17 | import subprocess 18 | from typing import Sequence 19 | 20 | from absl import logging 21 | 22 | from alphafold.data.tools import utils 23 | # Internal import (7716). 24 | 25 | 26 | def _to_a3m(sequences: Sequence[str]) -> str: 27 | """Converts sequences to an A3M-formatted string.""" 28 | names = ['sequence %d' % i for i in range(1, len(sequences) + 1)] 29 | a3m = [] 30 | for sequence, name in zip(sequences, names): 31 | a3m.append(u'>' + name + u'\n') 32 | a3m.append(sequence + u'\n') 33 | return ''.join(a3m) 34 | 35 | 36 | class Kalign: 37 | """Python wrapper of the Kalign binary.""" 38 | 39 | def __init__(self, *, binary_path: str): 40 | """Initializes the Python Kalign wrapper. 41 | 42 | Args: 43 | binary_path: The path to the Kalign binary. 44 | 45 | Raises: 46 | RuntimeError: If Kalign binary not found within the path. 47 | """ 48 | self.binary_path = binary_path 49 | 50 | def align(self, sequences: Sequence[str]) -> str: 51 | """Aligns the sequences and returns the alignment as an A3M string. 52 | 53 | Args: 54 | sequences: A list of query sequence strings. The sequences have to be at 55 | least 6 residues long (Kalign requires this). Note that the order in 56 | which you give the sequences might alter the output slightly, as a 57 | different alignment tree might be constructed. 58 | 59 | Returns: 60 | A string with the alignment in a3m format. 61 | 62 | Raises: 63 | RuntimeError: If Kalign fails. 64 | ValueError: If any of the sequences is less than 6 residues long. 
65 | """ 66 | logging.info('Aligning %d sequences', len(sequences)) 67 | 68 | for s in sequences: 69 | if len(s) < 6: 70 | raise ValueError('Kalign requires all sequences to be at least 6 ' 71 | 'residues long. Got %s (%d residues).' % (s, len(s))) 72 | 73 | with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir: 74 | input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta') 75 | output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m') 76 | 77 | with open(input_fasta_path, 'w') as f: 78 | f.write(_to_a3m(sequences)) 79 | 80 | cmd = [ 81 | self.binary_path, 82 | '-i', input_fasta_path, 83 | '-o', output_a3m_path, 84 | '-format', 'fasta', 85 | ] 86 | 87 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 88 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, 89 | stderr=subprocess.PIPE) 90 | 91 | with utils.timing('Kalign query'): 92 | stdout, stderr = process.communicate() 93 | retcode = process.wait() 94 | logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n', 95 | stdout.decode('utf-8'), stderr.decode('utf-8')) 96 | 97 | if retcode: 98 | raise RuntimeError('Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' 99 | % (stdout.decode('utf-8'), stderr.decode('utf-8'))) 100 | 101 | with open(output_a3m_path) as f: 102 | a3m = f.read() 103 | 104 | return a3m 105 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Common utilities for data pipeline tools.""" 15 | import contextlib 16 | import shutil 17 | import tempfile 18 | import time 19 | from typing import Optional 20 | 21 | from absl import logging 22 | 23 | 24 | @contextlib.contextmanager 25 | def tmpdir_manager(base_dir: Optional[str] = None): 26 | """Context manager that deletes a temporary directory on exit.""" 27 | tmpdir = tempfile.mkdtemp(dir=base_dir) 28 | try: 29 | yield tmpdir 30 | finally: 31 | shutil.rmtree(tmpdir, ignore_errors=True) 32 | 33 | 34 | @contextlib.contextmanager 35 | def timing(msg: str): 36 | logging.info('Started %s', msg) 37 | tic = time.time() 38 | yield 39 | toc = time.time() 40 | logging.info('Finished %s in %.3f seconds', msg, toc - tic) 41 | -------------------------------------------------------------------------------- /src/alphafold/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Alphafold model.""" 15 | -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/all_atom.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/all_atom.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/common_modules.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/common_modules.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/config.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/config.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/data.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/data.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/features.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/features.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/folding.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/folding.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/layer_stack.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/layer_stack.cpython-312.pyc 
-------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/lddt.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/lddt.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/mapping.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/mapping.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/modules.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/modules.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/prng.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/prng.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/quat_affine.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/quat_affine.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/r3.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/r3.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/__pycache__/utils.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/Cfold/a66406e981ce434b985120f8c40712d17290408c/src/alphafold/model/__pycache__/utils.cpython-312.pyc -------------------------------------------------------------------------------- /src/alphafold/model/all_atom_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for all_atom.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from alphafold.model import all_atom 20 | from alphafold.model import r3 21 | import numpy as np 22 | 23 | L1_CLAMP_DISTANCE = 10 24 | 25 | 26 | def get_identity_rigid(shape): 27 | """Returns identity rigid transform.""" 28 | 29 | ones = np.ones(shape) 30 | zeros = np.zeros(shape) 31 | rot = r3.Rots(ones, zeros, zeros, 32 | zeros, ones, zeros, 33 | zeros, zeros, ones) 34 | trans = r3.Vecs(zeros, zeros, zeros) 35 | return r3.Rigids(rot, trans) 36 | 37 | 38 | def get_global_rigid_transform(rot_angle, translation, bcast_dims): 39 | """Returns rigid transform that globally rotates/translates by same amount.""" 40 | 41 | rot_angle = np.asarray(rot_angle) 42 | translation = np.asarray(translation) 43 | if bcast_dims: 44 | for _ in range(bcast_dims): 45 | rot_angle = np.expand_dims(rot_angle, 0) 46 | translation = np.expand_dims(translation, 0) 47 | sin_angle = np.sin(np.deg2rad(rot_angle)) 48 | cos_angle = np.cos(np.deg2rad(rot_angle)) 49 | ones = np.ones_like(sin_angle) 50 | zeros = np.zeros_like(sin_angle) 51 | rot = r3.Rots(ones, zeros, zeros, 52 | zeros, cos_angle, -sin_angle, 53 | zeros, sin_angle, cos_angle) 54 | trans = r3.Vecs(translation[..., 0], translation[..., 1], translation[..., 2]) 55 | return r3.Rigids(rot, trans) 56 | 57 | 58 | class AllAtomTest(parameterized.TestCase, absltest.TestCase): 59 | 60 | @parameterized.named_parameters( 61 | ('identity', 0, [0, 0, 0]), 62 | ('rot_90', 90, [0, 0, 0]), 63 | ('trans_10', 0, [0, 0, 10]), 64 | ('rot_174_trans_1', 174, [1, 1, 1])) 65 | def test_frame_aligned_point_error_perfect_on_global_transform( 66 | self, rot_angle, translation): 67 | """Tests global transform between target and preds gives perfect score.""" 68 | 69 | # pylint: disable=bad-whitespace 70 | target_positions = np.array( 71 | [[ 21.182, 23.095, 19.731], 72 | [ 22.055, 20.919, 17.294], 73 | [ 24.599, 20.005, 15.041], 74 | [ 25.567, 18.214, 12.166], 75 | [ 28.063, 17.082, 10.043], 76 | [ 28.779, 15.569, 6.985], 77 | [ 30.581, 13.815, 4.612], 78 | [ 29.258, 12.193, 2.296]]) 79 | # pylint: enable=bad-whitespace 80 | global_rigid_transform = get_global_rigid_transform( 81 | rot_angle, translation, 1) 82 | 83 | target_positions = r3.vecs_from_tensor(target_positions) 84 | pred_positions = r3.rigids_mul_vecs( 85 | global_rigid_transform, target_positions) 86 | positions_mask = np.ones(target_positions.x.shape[0]) 87 | 88 | target_frames = get_identity_rigid(10) 89 | pred_frames = r3.rigids_mul_rigids(global_rigid_transform, target_frames) 90 | frames_mask = np.ones(10) 91 | 92 | fape = all_atom.frame_aligned_point_error( 93 | pred_frames, target_frames, frames_mask, pred_positions, 94 | target_positions, positions_mask, L1_CLAMP_DISTANCE, 95 | L1_CLAMP_DISTANCE, epsilon=0) 96 | self.assertAlmostEqual(fape, 0.) 
97 | 98 | @parameterized.named_parameters( 99 | ('identity', 100 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 101 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 102 | 0.), 103 | ('shift_2.5', 104 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 105 | [[2.5, 0, 0], [7.5, 0, 0], [7.5, 0, 0]], 106 | 0.25), 107 | ('shift_5', 108 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 109 | [[5, 0, 0], [10, 0, 0], [15, 0, 0]], 110 | 0.5), 111 | ('shift_10', 112 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 113 | [[10, 0, 0], [15, 0, 0], [0, 0, 0]], 114 | 1.)) 115 | def test_frame_aligned_point_error_matches_expected( 116 | self, target_positions, pred_positions, expected_fape): 117 | """Tests score matches expected.""" 118 | 119 | target_frames = get_identity_rigid(2) 120 | pred_frames = target_frames 121 | frames_mask = np.ones(2) 122 | 123 | target_positions = r3.vecs_from_tensor(np.array(target_positions)) 124 | pred_positions = r3.vecs_from_tensor(np.array(pred_positions)) 125 | positions_mask = np.ones(target_positions.x.shape[0]) 126 | 127 | fape = all_atom.frame_aligned_point_error( 128 | pred_frames, target_frames, frames_mask, pred_positions, 129 | target_positions, positions_mask, L1_CLAMP_DISTANCE, 130 | L1_CLAMP_DISTANCE, epsilon=0) 131 | self.assertAlmostEqual(fape, expected_fape) 132 | 133 | 134 | if __name__ == '__main__': 135 | absltest.main() 136 | -------------------------------------------------------------------------------- /src/alphafold/model/common_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | #      http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A collection of common Haiku modules for use in protein folding.""" 16 | import haiku as hk 17 | import jax.numpy as jnp 18 | 19 | 20 | class Linear(hk.Module): 21 | """Protein folding specific Linear Module. 22 | 23 | This differs from the standard Haiku Linear in a few ways: 24 | * It supports inputs of arbitrary rank 25 | * Initializers are specified by strings 26 | """ 27 | 28 | def __init__(self, 29 | num_output: int, 30 | initializer: str = 'linear', 31 | use_bias: bool = True, 32 | bias_init: float = 0., 33 | name: str = 'linear'): 34 | """Constructs Linear Module. 35 | 36 | Args: 37 | num_output: number of output channels. 38 | initializer: What initializer to use, should be one of {'linear', 'relu', 39 | 'zeros'} 40 | use_bias: Whether to include trainable bias 41 | bias_init: Value used to initialize bias. 42 | name: name of module, used for name scopes. 43 | """ 44 | 45 | super().__init__(name=name) 46 | self.num_output = num_output 47 | self.initializer = initializer 48 | self.use_bias = use_bias 49 | self.bias_init = bias_init 50 | 51 | def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: 52 | """Connects Module. 
53 | 54 | Args: 55 | inputs: Tensor of shape [..., num_channel] 56 | 57 | Returns: 58 | output of shape [..., num_output] 59 | """ 60 | n_channels = int(inputs.shape[-1]) 61 | 62 | weight_shape = [n_channels, self.num_output] 63 | if self.initializer == 'linear': 64 | weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=1.) 65 | elif self.initializer == 'relu': 66 | weight_init = hk.initializers.VarianceScaling(mode='fan_in', scale=2.) 67 | elif self.initializer == 'zeros': 68 | weight_init = hk.initializers.Constant(0.0) 69 | 70 | weights = hk.get_parameter('weights', weight_shape, inputs.dtype, 71 | weight_init) 72 | 73 | # this is equivalent to einsum('...c,cd->...d', inputs, weights) 74 | # but turns out to be slightly faster 75 | inputs = jnp.swapaxes(inputs, -1, -2) 76 | output = jnp.einsum('...cb,cd->...db', inputs, weights) 77 | output = jnp.swapaxes(output, -1, -2) 78 | 79 | if self.use_bias: 80 | bias = hk.get_parameter('bias', [self.num_output], inputs.dtype, 81 | hk.initializers.Constant(self.bias_init)) 82 | output += bias 83 | 84 | return output 85 | -------------------------------------------------------------------------------- /src/alphafold/model/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Convenience functions for reading data.""" 16 | 17 | import io 18 | import os 19 | from typing import List 20 | from alphafold.model import utils 21 | import haiku as hk 22 | import numpy as np 23 | # Internal import (7716). 24 | 25 | 26 | def casp_model_names(data_dir: str) -> List[str]: 27 | params = os.listdir(os.path.join(data_dir, 'params')) 28 | return [os.path.splitext(filename)[0] for filename in params] 29 | 30 | 31 | def get_model_haiku_params(model_name: str, data_dir: str) -> hk.Params: 32 | """Get the Haiku parameters from a model name.""" 33 | 34 | path = os.path.join(data_dir, 'params', f'params_{model_name}.npz') 35 | 36 | with open(path, 'rb') as f: 37 | params = np.load(io.BytesIO(f.read()), allow_pickle=False) 38 | 39 | return utils.flat_params_to_haiku(params) 40 | -------------------------------------------------------------------------------- /src/alphafold/model/features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Code to generate processed features.""" 16 | import copy 17 | from typing import List, Mapping, Tuple 18 | from alphafold.model.tf import input_pipeline 19 | from alphafold.model.tf import proteins_dataset 20 | import ml_collections 21 | import numpy as np 22 | import os 23 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 24 | import tensorflow.compat.v1 as tf 25 | tf.config.set_visible_devices([], 'GPU') 26 | 27 | 28 | FeatureDict = Mapping[str, np.ndarray] 29 | 30 | 31 | def make_data_config( 32 | config: ml_collections.ConfigDict, 33 | num_res: int, 34 | ) -> Tuple[ml_collections.ConfigDict, List[str]]: 35 | """Makes a data config for the input pipeline.""" 36 | cfg = copy.deepcopy(config.data) 37 | 38 | feature_names = cfg.common.unsupervised_features 39 | if cfg.common.use_templates: 40 | feature_names += cfg.common.template_features 41 | 42 | with cfg.unlocked(): 43 | cfg.eval.crop_size = num_res 44 | 45 | return cfg, feature_names 46 | 47 | 48 | def tf_example_to_features(tf_example: tf.train.Example, 49 | config: ml_collections.ConfigDict, 50 | random_seed: int = 0) -> FeatureDict: 51 | """Converts tf_example to numpy feature dictionary.""" 52 | num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0]) 53 | cfg, feature_names = make_data_config(config, num_res=num_res) 54 | 55 | if 'deletion_matrix_int' in set(tf_example.features.feature): 56 | deletion_matrix_int = ( 57 | tf_example.features.feature['deletion_matrix_int'].int64_list.value) 58 | feat = tf.train.Feature(float_list=tf.train.FloatList( 59 | value=map(float, deletion_matrix_int))) 60 | tf_example.features.feature['deletion_matrix'].CopyFrom(feat) 61 | del tf_example.features.feature['deletion_matrix_int'] 62 | 63 | tf_graph = tf.Graph() 64 | with tf_graph.as_default(): 65 | tf.compat.v1.set_random_seed(random_seed) 66 | tensor_dict = proteins_dataset.create_tensor_dict( 67 | raw_data=tf_example.SerializeToString(), 68 | features=feature_names) 69 | processed_batch = input_pipeline.process_tensors_from_config( 70 | tensor_dict, cfg) 71 | 72 | tf_graph.finalize() 73 | 74 | with tf.Session(graph=tf_graph) as sess: 75 | features = sess.run(processed_batch) 76 | 77 | return {k: v for k, v in features.items() if v.dtype != 'O'} 78 | 79 | 80 | def np_example_to_features(np_example: FeatureDict, 81 | config: ml_collections.ConfigDict, 82 | random_seed: int = 0) -> FeatureDict: 83 | """Preprocesses NumPy feature dict using TF pipeline.""" 84 | np_example = dict(np_example) 85 | num_res = int(np_example['seq_length'][0]) 86 | cfg, feature_names = make_data_config(config, num_res=num_res) 87 | 88 | if 'deletion_matrix_int' in np_example: 89 | np_example['deletion_matrix'] = ( 90 | np_example.pop('deletion_matrix_int').astype(np.float32)) 91 | 92 | tf_graph = tf.Graph() 93 | with tf_graph.as_default(): 94 | tf.compat.v1.set_random_seed(random_seed) 95 | tensor_dict = proteins_dataset.np_to_tensor_dict( 96 | np_example=np_example, features=feature_names) 97 | 98 | processed_batch = input_pipeline.process_tensors_from_config( 99 | tensor_dict, cfg) 100 | 101 | tf_graph.finalize() 102 | 103 | with tf.Session(graph=tf_graph) as sess: 104 | features = sess.run(processed_batch) 105 | return {k: v[0] for k, v in features.items() if v.dtype != 'O'} 106 | -------------------------------------------------------------------------------- 
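Both conversion functions above perform the same `deletion_matrix_int` to `deletion_matrix` rewrite before handing the features to the TF pipeline. A minimal, self-contained sketch of just that step (the toy dictionary below is illustrative; a real feature dict from the data pipeline carries many more keys):

```
import numpy as np

# Toy stand-in for a pipeline output with only the key of interest.
np_example = {'deletion_matrix_int': np.array([[0, 2, 1]], dtype=np.int64)}

# Same rewrite as in np_example_to_features: pop the integer feature and
# re-insert it as float32 under the name the TF input pipeline expects.
np_example['deletion_matrix'] = (
    np_example.pop('deletion_matrix_int').astype(np.float32))

assert np_example['deletion_matrix'].dtype == np.float32
```

--------------------------------------------------------------------------------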
/src/alphafold/model/lddt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """lDDT protein distance score.""" 16 | import jax.numpy as jnp 17 | 18 | 19 | def lddt(predicted_points, 20 | true_points, 21 | true_points_mask, 22 | cutoff=15., 23 | per_residue=False): 24 | """Measure (approximate) lDDT for a batch of coordinates. 25 | 26 | lDDT reference: 27 | Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local 28 | superposition-free score for comparing protein structures and models using 29 | distance difference tests. Bioinformatics 29, 2722–2728 (2013). 30 | 31 | lDDT is a measure of the difference between the true distance matrix and the 32 | distance matrix of the predicted points. The difference is computed only on 33 | points closer than cutoff *in the true structure*. 34 | 35 | This function does not compute the exact lDDT value that the original paper 36 | describes because it does not include terms for physical feasibility 37 | (e.g. bond length violations). Therefore this is only an approximate 38 | lDDT score. 39 | 40 | Args: 41 | predicted_points: (batch, length, 3) array of predicted 3D points 42 | true_points: (batch, length, 3) array of true 3D points 43 | true_points_mask: (batch, length, 1) binary-valued float array. This mask 44 | should be 1 for points that exist in the true points. 45 | cutoff: Maximum distance for a pair of points to be included 46 | per_residue: If true, return score for each residue. Note that the overall 47 | lDDT is not exactly the mean of the per_residue lDDT's because some 48 | residues have more contacts than others. 49 | 50 | Returns: 51 | An (approximate, see above) lDDT score in the range 0-1. 52 | """ 53 | 54 | assert len(predicted_points.shape) == 3 55 | assert predicted_points.shape[-1] == 3 56 | assert true_points_mask.shape[-1] == 1 57 | assert len(true_points_mask.shape) == 3 58 | 59 | # Compute true and predicted distance matrices. 60 | dmat_true = jnp.sqrt(1e-10 + jnp.sum( 61 | (true_points[:, :, None] - true_points[:, None, :])**2, axis=-1)) 62 | 63 | dmat_predicted = jnp.sqrt(1e-10 + jnp.sum( 64 | (predicted_points[:, :, None] - 65 | predicted_points[:, None, :])**2, axis=-1)) 66 | 67 | dists_to_score = ( 68 | (dmat_true < cutoff).astype(jnp.float32) * true_points_mask * 69 | jnp.transpose(true_points_mask, [0, 2, 1]) * 70 | (1. - jnp.eye(dmat_true.shape[1])) # Exclude self-interaction. 71 | ) 72 | 73 | # Shift unscored distances to be far away. 74 | dist_l1 = jnp.abs(dmat_true - dmat_predicted) 75 | 76 | # True lDDT uses a number of fixed bins. 77 | # We ignore the physical plausibility correction to lDDT, though. 
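# The four thresholds below are the standard lDDT bins: each scored pair
# contributes 0.25 for every one of |d_true - d_pred| < 0.5, 1.0, 2.0 and
# 4.0 angstroms that it satisfies. A pair with a 1.5 angstrom distance
# error, for example, passes only the 2.0 and 4.0 bins and contributes 0.5.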
78 | score = 0.25 * ((dist_l1 < 0.5).astype(jnp.float32) + 79 | (dist_l1 < 1.0).astype(jnp.float32) + 80 | (dist_l1 < 2.0).astype(jnp.float32) + 81 | (dist_l1 < 4.0).astype(jnp.float32)) 82 | 83 | # Normalize over the appropriate axes. 84 | reduce_axes = (-1,) if per_residue else (-2, -1) 85 | norm = 1. / (1e-10 + jnp.sum(dists_to_score, axis=reduce_axes)) 86 | score = norm * (1e-10 + jnp.sum(dists_to_score * score, axis=reduce_axes)) 87 | 88 | return score 89 | -------------------------------------------------------------------------------- /src/alphafold/model/lddt_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for lddt.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from alphafold.model import lddt 20 | import numpy as np 21 | 22 | 23 | class LddtTest(parameterized.TestCase, absltest.TestCase): 24 | 25 | @parameterized.named_parameters( 26 | ('same', 27 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 28 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 29 | [1, 1, 1]), 30 | ('all_shifted', 31 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 32 | [[-1, 0, 0], [4, 0, 0], [9, 0, 0]], 33 | [1, 1, 1]), 34 | ('all_rotated', 35 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 36 | [[0, 0, 0], [0, 5, 0], [0, 10, 0]], 37 | [1, 1, 1]), 38 | ('half_a_dist', 39 | [[0, 0, 0], [5, 0, 0]], 40 | [[0, 0, 0], [5.5-1e-5, 0, 0]], 41 | [1, 1]), 42 | ('one_a_dist', 43 | [[0, 0, 0], [5, 0, 0]], 44 | [[0, 0, 0], [6-1e-5, 0, 0]], 45 | [0.75, 0.75]), 46 | ('two_a_dist', 47 | [[0, 0, 0], [5, 0, 0]], 48 | [[0, 0, 0], [7-1e-5, 0, 0]], 49 | [0.5, 0.5]), 50 | ('four_a_dist', 51 | [[0, 0, 0], [5, 0, 0]], 52 | [[0, 0, 0], [9-1e-5, 0, 0]], 53 | [0.25, 0.25],), 54 | ('five_a_dist', 55 | [[0, 0, 0], [16-1e-5, 0, 0]], 56 | [[0, 0, 0], [11, 0, 0]], 57 | [0, 0]), 58 | ('no_pairs', 59 | [[0, 0, 0], [20, 0, 0]], 60 | [[0, 0, 0], [25-1e-5, 0, 0]], 61 | [1, 1]), 62 | ) 63 | def test_lddt( 64 | self, predicted_pos, true_pos, exp_lddt): 65 | predicted_pos = np.array([predicted_pos], dtype=np.float32) 66 | true_points_mask = np.array([[[1]] * len(true_pos)], dtype=np.float32) 67 | true_pos = np.array([true_pos], dtype=np.float32) 68 | cutoff = 15.0 69 | per_residue = True 70 | 71 | result = lddt.lddt( 72 | predicted_pos, true_pos, true_points_mask, cutoff, 73 | per_residue) 74 | 75 | np.testing.assert_almost_equal(result, [exp_lddt], decimal=4) 76 | 77 | 78 | if __name__ == '__main__': 79 | absltest.main() 80 | -------------------------------------------------------------------------------- /src/alphafold/model/mapping.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Specialized mapping functions.""" 16 | 17 | import functools 18 | 19 | from typing import Any, Callable, Optional, Sequence, Union 20 | 21 | import haiku as hk 22 | import jax 23 | import jax.numpy as jnp 24 | 25 | 26 | PYTREE = Any 27 | PYTREE_JAX_ARRAY = Any 28 | 29 | partial = functools.partial 30 | PROXY = object() 31 | 32 | 33 | def _maybe_slice(array, i, slice_size, axis): 34 | if axis is PROXY: 35 | return array 36 | else: 37 | return jax.lax.dynamic_slice_in_dim( 38 | array, i, slice_size=slice_size, axis=axis) 39 | 40 | 41 | def _maybe_get_size(array, axis): 42 | if axis == PROXY: 43 | return -1 44 | else: 45 | return array.shape[axis] 46 | 47 | 48 | def _expand_axes(axes, values, name='sharded_apply'): 49 | values_tree_def = jax.tree_flatten(values)[1] 50 | flat_axes = jax.api_util.flatten_axes(name, values_tree_def, axes) 51 | # Replace None's with PROXY 52 | flat_axes = [PROXY if x is None else x for x in flat_axes] 53 | return jax.tree_unflatten(values_tree_def, flat_axes) 54 | 55 | 56 | def sharded_map( 57 | fun: Callable[..., PYTREE_JAX_ARRAY], 58 | shard_size: Union[int, None] = 1, 59 | in_axes: Union[int, PYTREE] = 0, 60 | out_axes: Union[int, PYTREE] = 0) -> Callable[..., PYTREE_JAX_ARRAY]: 61 | """Sharded vmap. 62 | 63 | Maps `fun` over axes, in a way similar to vmap, but does so in shards of 64 | `shard_size`. This allows a smooth trade-off between memory usage 65 | (as in a plain map) vs higher throughput (as in a vmap). 66 | 67 | Args: 68 | fun: Function to apply smap transform to. 69 | shard_size: Integer denoting shard size. 70 | in_axes: Either integer or pytree describing which axis to map over for each 71 | input to `fun`, None denotes broadcasting. 72 | out_axes: integer or pytree denoting to what axis in the output the mapped 73 | over axis maps. 74 | 75 | Returns: 76 | function with smap applied. 77 | """ 78 | vmapped_fun = hk.vmap(fun, in_axes, out_axes) 79 | return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes) 80 | 81 | 82 | def sharded_apply( 83 | fun: Callable[..., PYTREE_JAX_ARRAY], # pylint: disable=g-bare-generic 84 | shard_size: Union[int, None] = 1, 85 | in_axes: Union[int, PYTREE] = 0, 86 | out_axes: Union[int, PYTREE] = 0, 87 | new_out_axes: bool = False) -> Callable[..., PYTREE_JAX_ARRAY]: 88 | """Sharded apply. 89 | 90 | Applies `fun` over shards to axes, in a way similar to vmap, 91 | but does so in shards of `shard_size`. Shards are stacked after. 92 | This allows a smooth trade-off between 93 | memory usage (as in a plain map) vs higher throughput (as in a vmap). 94 | 95 | Args: 96 | fun: Function to apply smap transform to. 97 | shard_size: Integer denoting shard size. 98 | in_axes: Either integer or pytree describing which axis to map over for each 99 | input to `fun`, None denotes broadcasting. 100 | out_axes: integer or pytree denoting to what axis in the output the mapped 101 | over axis maps. 102 | new_out_axes: whether to stack outputs on new axes. This assumes that the 103 | output sizes for each shard (including the possible remainder shard) are 104 | the same. 
105 | 106 | Returns: 107 | function with smap applied. 108 | """ 109 | docstr = ('Mapped version of {fun}. Takes similar arguments to {fun} ' 110 | 'but with additional array axes over which {fun} is mapped.') 111 | if new_out_axes: 112 | raise NotImplementedError('New output axes not yet implemented.') 113 | 114 | # A shard_size of None denotes no sharding. 115 | if shard_size is None: 116 | return fun 117 | 118 | @jax.util.wraps(fun, docstr=docstr) 119 | def mapped_fn(*args): 120 | # Expand in_axes and determine the loop range. 121 | in_axes_ = _expand_axes(in_axes, args) 122 | 123 | in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_) 124 | flat_sizes = jax.tree_flatten(in_sizes)[0] 125 | in_size = max(flat_sizes) 126 | assert all(i in {in_size, -1} for i in flat_sizes) 127 | 128 | num_extra_shards = (in_size - 1) // shard_size 129 | 130 | # Fix up the final shard size if in_size is not a multiple of shard_size. 131 | last_shard_size = in_size % shard_size 132 | last_shard_size = shard_size if last_shard_size == 0 else last_shard_size 133 | 134 | def apply_fun_to_slice(slice_start, slice_size): 135 | input_slice = jax.tree_map( 136 | lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis 137 | ), args, in_axes_) 138 | return fun(*input_slice) 139 | 140 | remainder_shape_dtype = hk.eval_shape( 141 | partial(apply_fun_to_slice, 0, last_shard_size)) 142 | out_dtypes = jax.tree_map(lambda x: x.dtype, remainder_shape_dtype) 143 | out_shapes = jax.tree_map(lambda x: x.shape, remainder_shape_dtype) 144 | out_axes_ = _expand_axes(out_axes, remainder_shape_dtype) 145 | 146 | if num_extra_shards > 0: 147 | regular_shard_shape_dtype = hk.eval_shape( 148 | partial(apply_fun_to_slice, 0, shard_size)) 149 | shard_shapes = jax.tree_map(lambda x: x.shape, regular_shard_shape_dtype) 150 | 151 | def make_output_shape(axis, shard_shape, remainder_shape): 152 | return shard_shape[:axis] + ( 153 | shard_shape[axis] * num_extra_shards + 154 | remainder_shape[axis],) + shard_shape[axis + 1:] 155 | 156 | out_shapes = jax.tree_map(make_output_shape, out_axes_, shard_shapes, 157 | out_shapes) 158 | 159 | # Calls jax.lax.dynamic_update_slice_in_dim with a different argument order. 160 | # This is needed since tree_multimap only works with positional arguments. 161 | def dynamic_update_slice_in_dim(full_array, update, axis, i): 162 | return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis) 163 | 164 | def compute_shard(outputs, slice_start, slice_size): 165 | slice_out = apply_fun_to_slice(slice_start, slice_size) 166 | update_slice = partial( 167 | dynamic_update_slice_in_dim, i=slice_start) 168 | return jax.tree_map(update_slice, outputs, slice_out, out_axes_) 169 | 170 | def scan_iteration(outputs, i): 171 | new_outputs = compute_shard(outputs, i, shard_size) 172 | return new_outputs, () 173 | 174 | slice_starts = jnp.arange(0, in_size - shard_size + 1, shard_size) 175 | 176 | def allocate_buffer(dtype, shape): 177 | return jnp.zeros(shape, dtype=dtype) 178 | 179 | outputs = jax.tree_map(allocate_buffer, out_dtypes, out_shapes) 180 | 181 | if slice_starts.shape[0] > 0: 182 | outputs, _ = hk.scan(scan_iteration, outputs, slice_starts) 183 | 184 | if last_shard_size != shard_size: 185 | remainder_start = in_size - last_shard_size 186 | outputs = compute_shard(outputs, remainder_start, last_shard_size) 187 | 188 | return outputs 189 | 190 | return mapped_fn 191 | 192 | 193 | def inference_subbatch( 194 | module: Callable[..., PYTREE_JAX_ARRAY], 195 | subbatch_size: int, 196 | batched_args: Sequence[PYTREE_JAX_ARRAY], 197 | nonbatched_args: 
Sequence[PYTREE_JAX_ARRAY], 198 | low_memory: bool = True, 199 | input_subbatch_dim: int = 0, 200 | output_subbatch_dim: Optional[int] = None) -> PYTREE_JAX_ARRAY: 201 | """Run through subbatches (like batch apply but with split and concat).""" 202 | assert len(batched_args) > 0 # pylint: disable=g-explicit-length-test 203 | 204 | if not low_memory: 205 | args = list(batched_args) + list(nonbatched_args) 206 | return module(*args) 207 | 208 | if output_subbatch_dim is None: 209 | output_subbatch_dim = input_subbatch_dim 210 | 211 | def run_module(*batched_args): 212 | args = list(batched_args) + list(nonbatched_args) 213 | return module(*args) 214 | sharded_module = sharded_apply(run_module, 215 | shard_size=subbatch_size, 216 | in_axes=input_subbatch_dim, 217 | out_axes=output_subbatch_dim) 218 | return sharded_module(*batched_args) 219 | -------------------------------------------------------------------------------- /src/alphafold/model/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Code for constructing the model.""" 16 | from typing import Any, Mapping, Optional, Union 17 | 18 | from absl import logging 19 | from alphafold.common import confidence 20 | from alphafold.model import features 21 | from alphafold.model import modules 22 | import haiku as hk 23 | import jax 24 | import ml_collections 25 | import numpy as np 26 | import tensorflow.compat.v1 as tf 27 | import tree 28 | 29 | 30 | def get_confidence_metrics( 31 | prediction_result: Mapping[str, Any]) -> Mapping[str, Any]: 32 | """Post processes prediction_result to get confidence metrics.""" 33 | 34 | confidence_metrics = {} 35 | confidence_metrics['plddt'] = confidence.compute_plddt( 36 | prediction_result['predicted_lddt']['logits']) 37 | if 'predicted_aligned_error' in prediction_result: 38 | confidence_metrics.update(confidence.compute_predicted_aligned_error( 39 | prediction_result['predicted_aligned_error']['logits'], 40 | prediction_result['predicted_aligned_error']['breaks'])) 41 | confidence_metrics['ptm'] = confidence.predicted_tm_score( 42 | prediction_result['predicted_aligned_error']['logits'], 43 | prediction_result['predicted_aligned_error']['breaks']) 44 | 45 | return confidence_metrics 46 | 47 | 48 | class RunModel: 49 | """Container for JAX model.""" 50 | 51 | def __init__(self, 52 | config: ml_collections.ConfigDict, 53 | params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None): 54 | self.config = config 55 | self.params = params 56 | 57 | def _forward_fn(batch): 58 | model = modules.AlphaFold(self.config.model) 59 | return model( 60 | batch, 61 | is_training=False, 62 | compute_loss=False, 63 | ensemble_representations=True) 64 | 65 | self.apply = jax.jit(hk.transform(_forward_fn).apply) 66 | self.init = jax.jit(hk.transform(_forward_fn).init) 67 | 68 | def init_params(self, feat: features.FeatureDict, 
random_seed: int = 0): 69 | """Initializes the model parameters. 70 | 71 | If none were provided when this class was instantiated then the parameters 72 | are randomly initialized. 73 | 74 | Args: 75 | feat: A dictionary of NumPy feature arrays as output by 76 | RunModel.process_features. 77 | random_seed: A random seed to use to initialize the parameters if none 78 | were set when this class was initialized. 79 | """ 80 | if not self.params: 81 | # Init params randomly. 82 | rng = jax.random.PRNGKey(random_seed) 83 | self.params = hk.data_structures.to_mutable_dict( 84 | self.init(rng, feat)) 85 | logging.warning('Initialized parameters randomly') 86 | 87 | def process_features( 88 | self, 89 | raw_features: Union[tf.train.Example, features.FeatureDict], 90 | random_seed: int) -> features.FeatureDict: 91 | """Processes features to prepare for feeding them into the model. 92 | 93 | Args: 94 | raw_features: The output of the data pipeline either as a dict of NumPy 95 | arrays or as a tf.train.Example. 96 | random_seed: The random seed to use when processing the features. 97 | 98 | Returns: 99 | A dict of NumPy feature arrays suitable for feeding into the model. 100 | """ 101 | if isinstance(raw_features, dict): 102 | return features.np_example_to_features( 103 | np_example=raw_features, 104 | config=self.config, 105 | random_seed=random_seed) 106 | else: 107 | return features.tf_example_to_features( 108 | tf_example=raw_features, 109 | config=self.config, 110 | random_seed=random_seed) 111 | 112 | def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct: 113 | self.init_params(feat) 114 | logging.info('Running eval_shape with shape(feat) = %s', 115 | tree.map_structure(lambda x: x.shape, feat)) 116 | shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0), feat) 117 | logging.info('Output shape was %s', shape) 118 | return shape 119 | 120 | def predict(self, feat: features.FeatureDict) -> Mapping[str, Any]: 121 | """Makes a prediction by inferencing the model on the provided features. 122 | 123 | Args: 124 | feat: A dictionary of NumPy feature arrays as output by 125 | RunModel.process_features. 126 | 127 | Returns: 128 | A dictionary of model outputs. 129 | """ 130 | self.init_params(feat) 131 | logging.info('Running predict with shape(feat) = %s', 132 | tree.map_structure(lambda x: x.shape, feat)) 133 | result = self.apply(self.params, jax.random.PRNGKey(0), feat) 134 | # This block is to ensure benchmark timings are accurate. Some blocking is 135 | # already happening when computing get_confidence_metrics, and this ensures 136 | # all outputs are blocked on. 137 | jax.tree_map(lambda x: x.block_until_ready(), result) 138 | result.update(get_confidence_metrics(result)) 139 | logging.info('Output shape was %s', 140 | tree.map_structure(lambda x: x.shape, result)) 141 | return result 142 | -------------------------------------------------------------------------------- /src/alphafold/model/prng.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A collection of utilities surrounding PRNG usage in protein folding.""" 16 | 17 | import haiku as hk 18 | import jax 19 | 20 | 21 | def safe_dropout(*, tensor, safe_key, rate, is_deterministic, is_training): 22 | if is_training and rate != 0.0 and not is_deterministic: 23 | return hk.dropout(safe_key.get(), rate, tensor) 24 | else: 25 | return tensor 26 | 27 | 28 | class SafeKey: 29 | """Safety wrapper for PRNG keys.""" 30 | 31 | def __init__(self, key): 32 | self._key = key 33 | self._used = False 34 | 35 | def _assert_not_used(self): 36 | if self._used: 37 | raise RuntimeError('Random key has been used previously.') 38 | 39 | def get(self): 40 | self._assert_not_used() 41 | self._used = True 42 | return self._key 43 | 44 | def split(self, num_keys=2): 45 | self._assert_not_used() 46 | self._used = True 47 | new_keys = jax.random.split(self._key, num_keys) 48 | return jax.tree_map(SafeKey, tuple(new_keys)) 49 | 50 | def duplicate(self, num_keys=2): 51 | self._assert_not_used() 52 | self._used = True 53 | return tuple(SafeKey(self._key) for _ in range(num_keys)) 54 | 55 | 56 | def _safe_key_flatten(safe_key): 57 | # Flatten transfers "ownership" to the tree 58 | return (safe_key._key,), safe_key._used # pylint: disable=protected-access 59 | 60 | 61 | def _safe_key_unflatten(aux_data, children): 62 | ret = SafeKey(children[0]) 63 | ret._used = aux_data # pylint: disable=protected-access 64 | return ret 65 | 66 | 67 | jax.tree_util.register_pytree_node( 68 | SafeKey, _safe_key_flatten, _safe_key_unflatten) 69 | 70 | -------------------------------------------------------------------------------- /src/alphafold/model/prng_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for prng.""" 16 | 17 | from absl.testing import absltest 18 | from alphafold.model import prng 19 | import jax 20 | 21 | 22 | class PrngTest(absltest.TestCase): 23 | 24 | def test_key_reuse(self): 25 | 26 | init_key = jax.random.PRNGKey(42) 27 | safe_key = prng.SafeKey(init_key) 28 | _, safe_key = safe_key.split() 29 | 30 | raw_key = safe_key.get() 31 | 32 | self.assertNotEqual(raw_key[0], init_key[0]) 33 | self.assertNotEqual(raw_key[1], init_key[1]) 34 | 35 | with self.assertRaises(RuntimeError): 36 | safe_key.get() 37 | 38 | with self.assertRaises(RuntimeError): 39 | safe_key.split() 40 | 41 | with self.assertRaises(RuntimeError): 42 | safe_key.duplicate() 43 | 44 | 45 | if __name__ == '__main__': 46 | absltest.main() 47 | -------------------------------------------------------------------------------- /src/alphafold/model/quat_affine_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for quat_affine.""" 16 | 17 | from absl import logging 18 | from absl.testing import absltest 19 | from alphafold.model import quat_affine 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | VERBOSE = False 25 | np.set_printoptions(precision=3, suppress=True) 26 | 27 | r2t = quat_affine.rot_list_to_tensor 28 | v2t = quat_affine.vec_list_to_tensor 29 | 30 | q2r = lambda q: r2t(quat_affine.quat_to_rot(q)) 31 | 32 | 33 | class QuatAffineTest(absltest.TestCase): 34 | 35 | def _assert_check(self, to_check, tol=1e-5): 36 | for k, (correct, generated) in to_check.items(): 37 | if VERBOSE: 38 | logging.info(k) 39 | logging.info('Correct %s', correct) 40 | logging.info('Predicted %s', generated) 41 | self.assertLess(np.max(np.abs(correct - generated)), tol) 42 | 43 | def test_conversion(self): 44 | quat = jnp.array([-2., 5., -1., 4.]) 45 | 46 | rotation = jnp.array([ 47 | [0.26087, 0.130435, 0.956522], 48 | [-0.565217, -0.782609, 0.26087], 49 | [0.782609, -0.608696, -0.130435]]) 50 | 51 | translation = jnp.array([1., -3., 4.]) 52 | point = jnp.array([0.7, 3.2, -2.9]) 53 | 54 | a = quat_affine.QuatAffine(quat, translation, unstack_inputs=True) 55 | true_new_point = jnp.matmul(rotation, point[:, None])[:, 0] + translation 56 | 57 | self._assert_check({ 58 | 'rot': (rotation, r2t(a.rotation)), 59 | 'trans': (translation, v2t(a.translation)), 60 | 'point': (true_new_point, 61 | v2t(a.apply_to_point(jnp.moveaxis(point, -1, 0)))), 62 | # Because of the double cover, we must be careful and compare rotations 63 | 'quat': (q2r(a.quaternion), 64 | q2r(quat_affine.rot_to_quat(a.rotation))), 65 | 66 | }) 67 | 68 | def test_double_cover(self): 69 | """Test that -q is the same rotation as q.""" 70 | rng = jax.random.PRNGKey(42) 71 | keys = jax.random.split(rng) 72 | q = jax.random.normal(keys[0], (2, 4)) 73 | trans = jax.random.normal(keys[1], (2, 3)) 74 | a1 = quat_affine.QuatAffine(q, trans, 
unstack_inputs=True) 75 | a2 = quat_affine.QuatAffine(-q, trans, unstack_inputs=True) 76 | 77 | self._assert_check({ 78 | 'rot': (r2t(a1.rotation), 79 | r2t(a2.rotation)), 80 | 'trans': (v2t(a1.translation), 81 | v2t(a2.translation)), 82 | }) 83 | 84 | def test_homomorphism(self): 85 | rng = jax.random.PRNGKey(42) 86 | keys = jax.random.split(rng, 4) 87 | vec_q1 = jax.random.normal(keys[0], (2, 3)) 88 | 89 | q1 = jnp.concatenate([ 90 | jnp.ones_like(vec_q1)[:, :1], 91 | vec_q1], axis=-1) 92 | 93 | q2 = jax.random.normal(keys[1], (2, 4)) 94 | t1 = jax.random.normal(keys[2], (2, 3)) 95 | t2 = jax.random.normal(keys[3], (2, 3)) 96 | 97 | a1 = quat_affine.QuatAffine(q1, t1, unstack_inputs=True) 98 | a2 = quat_affine.QuatAffine(q2, t2, unstack_inputs=True) 99 | a21 = a2.pre_compose(jnp.concatenate([vec_q1, t1], axis=-1)) 100 | 101 | rng, key = jax.random.split(rng) 102 | x = jax.random.normal(key, (2, 3)) 103 | new_x = a21.apply_to_point(jnp.moveaxis(x, -1, 0)) 104 | new_x_apply2 = a2.apply_to_point(a1.apply_to_point(jnp.moveaxis(x, -1, 0))) 105 | 106 | self._assert_check({ 107 | 'quat': (q2r(quat_affine.quat_multiply(a2.quaternion, a1.quaternion)), 108 | q2r(a21.quaternion)), 109 | 'rot': (jnp.matmul(r2t(a2.rotation), r2t(a1.rotation)), 110 | r2t(a21.rotation)), 111 | 'point': (v2t(new_x_apply2), 112 | v2t(new_x)), 113 | 'inverse': (x, v2t(a21.invert_point(new_x))), 114 | }) 115 | 116 | def test_batching(self): 117 | """Test that affine applies batchwise.""" 118 | rng = jax.random.PRNGKey(42) 119 | keys = jax.random.split(rng, 3) 120 | q = jax.random.uniform(keys[0], (5, 2, 4)) 121 | t = jax.random.uniform(keys[1], (2, 3)) 122 | x = jax.random.uniform(keys[2], (5, 1, 3)) 123 | 124 | a = quat_affine.QuatAffine(q, t, unstack_inputs=True) 125 | y = v2t(a.apply_to_point(jnp.moveaxis(x, -1, 0))) 126 | 127 | y_list = [] 128 | for i in range(5): 129 | for j in range(2): 130 | a_local = quat_affine.QuatAffine(q[i, j], t[j], 131 | unstack_inputs=True) 132 | y_local = v2t(a_local.apply_to_point(jnp.moveaxis(x[i, 0], -1, 0))) 133 | y_list.append(y_local) 134 | y_combine = jnp.reshape(jnp.stack(y_list, axis=0), (5, 2, 3)) 135 | 136 | self._assert_check({ 137 | 'batch': (y_combine, y), 138 | 'quat': (q2r(a.quaternion), 139 | q2r(quat_affine.rot_to_quat(a.rotation))), 140 | }) 141 | 142 | def assertAllClose(self, a, b, rtol=1e-06, atol=1e-06): 143 | self.assertTrue(np.allclose(a, b, rtol=rtol, atol=atol)) 144 | 145 | def assertAllEqual(self, a, b): 146 | self.assertTrue(np.all(np.array(a) == np.array(b))) 147 | 148 | 149 | if __name__ == '__main__': 150 | absltest.main() 151 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Alphafold model TensorFlow code.""" 15 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/input_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Feature pre-processing input pipeline for AlphaFold.""" 16 | 17 | from alphafold.model.tf import data_transforms 18 | from alphafold.model.tf import shape_placeholders 19 | import tensorflow.compat.v1 as tf 20 | import pdb 21 | import tree 22 | 23 | # Pylint gets confused by the curry1 decorator because it changes the number 24 | # of arguments to the function. 25 | # pylint:disable=no-value-for-parameter 26 | 27 | 28 | NUM_RES = shape_placeholders.NUM_RES 29 | NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ 30 | NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ 31 | NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES 32 | 33 | 34 | def nonensembled_map_fns(data_config): 35 | """Input pipeline functions which are not ensembled.""" 36 | common_cfg = data_config.common 37 | 38 | map_fns = [ 39 | data_transforms.correct_msa_restypes, 40 | data_transforms.add_distillation_flag(False), 41 | data_transforms.cast_64bit_ints, 42 | data_transforms.squeeze_features, 43 | # Keep to not disrupt RNG. 44 | data_transforms.randomly_replace_msa_with_unknown(0.0), 45 | data_transforms.make_seq_mask, 46 | data_transforms.make_msa_mask, 47 | # Compute the HHblits profile if it's not set. This has to be run before 48 | # sampling the MSA. 49 | data_transforms.make_hhblits_profile, 50 | data_transforms.make_random_crop_to_size_seed, 51 | ] 52 | if common_cfg.use_templates: 53 | map_fns.extend([ 54 | data_transforms.fix_templates_aatype, 55 | data_transforms.make_template_mask, 56 | data_transforms.make_pseudo_beta('template_') 57 | ]) 58 | map_fns.extend([ 59 | data_transforms.make_atom14_masks, 60 | ]) 61 | 62 | return map_fns 63 | 64 | 65 | def ensembled_map_fns(data_config): 66 | """Input pipeline functions that can be ensembled and averaged.""" 67 | common_cfg = data_config.common 68 | eval_cfg = data_config.eval 69 | 70 | map_fns = [] 71 | 72 | if common_cfg.reduce_msa_clusters_by_max_templates: 73 | pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates 74 | else: 75 | pad_msa_clusters = eval_cfg.max_msa_clusters 76 | 77 | max_msa_clusters = pad_msa_clusters 78 | max_extra_msa = common_cfg.max_extra_msa 79 | 80 | map_fns.append( 81 | data_transforms.sample_msa( 82 | max_msa_clusters, 83 | keep_extra=True)) 84 | 85 | if 'masked_msa' in common_cfg: 86 | # Masked MSA should come *before* MSA clustering so that 87 | # the clustering and full MSA profile do not leak information about 88 | # the masked locations and secret corrupted locations. 
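    # In other words, cluster centres are selected from the already-masked
    # MSA, mirroring the BERT-style masked-MSA objective used to train
    # AlphaFold2, in which the model must recover the corrupted entries.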
89 | map_fns.append( 90 | data_transforms.make_masked_msa(common_cfg.masked_msa, 91 | eval_cfg.masked_msa_replace_fraction)) 92 | 93 | if common_cfg.msa_cluster_features: 94 | map_fns.append(data_transforms.nearest_neighbor_clusters()) 95 | map_fns.append(data_transforms.summarize_clusters()) 96 | 97 | # Crop after creating the cluster profiles. 98 | if max_extra_msa: 99 | map_fns.append(data_transforms.crop_extra_msa(max_extra_msa)) 100 | else: 101 | map_fns.append(data_transforms.delete_extra_msa) 102 | 103 | map_fns.append(data_transforms.make_msa_feat()) 104 | 105 | crop_feats = dict(eval_cfg.feat) 106 | 107 | if eval_cfg.fixed_size: 108 | map_fns.append(data_transforms.select_feat(list(crop_feats))) 109 | map_fns.append(data_transforms.random_crop_to_size( 110 | eval_cfg.crop_size, 111 | eval_cfg.max_templates, 112 | crop_feats, 113 | eval_cfg.subsample_templates)) 114 | map_fns.append(data_transforms.make_fixed_size( 115 | crop_feats, 116 | pad_msa_clusters, 117 | common_cfg.max_extra_msa, 118 | eval_cfg.crop_size, 119 | eval_cfg.max_templates)) 120 | else: 121 | map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates)) 122 | 123 | return map_fns 124 | 125 | 126 | def process_tensors_from_config(tensors, data_config): 127 | """Apply filters and maps to an existing dataset, based on the config.""" 128 | 129 | def wrap_ensemble_fn(data, i): 130 | """Function to be mapped over the ensemble dimension.""" 131 | d = data.copy() 132 | fns = ensembled_map_fns(data_config) 133 | fn = compose(fns) 134 | d['ensemble_index'] = i 135 | return fn(d) 136 | 137 | eval_cfg = data_config.eval 138 | tensors = compose( 139 | nonensembled_map_fns( 140 | data_config))( 141 | tensors) 142 | 143 | tensors_0 = wrap_ensemble_fn(tensors, tf.constant(0)) 144 | num_ensemble = eval_cfg.num_ensemble 145 | if data_config.common.resample_msa_in_recycling: 146 | # Separate batch per ensembling & recycling step. 147 | num_ensemble *= data_config.common.num_recycle + 1 148 | 149 | if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1: 150 | fn_output_signature = tree.map_structure( 151 | tf.TensorSpec.from_tensor, tensors_0) 152 | tensors = tf.map_fn( 153 | lambda x: wrap_ensemble_fn(tensors, x), 154 | tf.range(num_ensemble), 155 | parallel_iterations=1, 156 | fn_output_signature=fn_output_signature) 157 | else: 158 | tensors = tree.map_structure(lambda x: x[None], 159 | tensors_0) 160 | return tensors 161 | 162 | 163 | @data_transforms.curry1 164 | def compose(x, fs): 165 | for f in fs: 166 | x = f(x) 167 | return x 168 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/protein_features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | """Contains descriptions of various protein features."""
16 | import enum
17 | from typing import Dict, Optional, Sequence, Tuple, Union
18 | from alphafold.common import residue_constants
19 | import tensorflow.compat.v1 as tf
20 | import pdb
21 | 
22 | # Type aliases.
23 | FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]]
24 | 
25 | 
26 | class FeatureType(enum.Enum):
27 |   ZERO_DIM = 0  # Shape [x]
28 |   ONE_DIM = 1  # Shape [num_res, x]
29 |   TWO_DIM = 2  # Shape [num_res, num_res, x]
30 |   MSA = 3  # Shape [msa_length, num_res, x]
31 | 
32 | 
33 | # Placeholder values that will be replaced with their true value at runtime.
34 | NUM_RES = "num residues placeholder"
35 | NUM_SEQ = "length msa placeholder"
36 | NUM_TEMPLATES = "num templates placeholder"
37 | # Sizes of the protein features, NUM_RES and NUM_SEQ are allowed as placeholders
38 | # to be replaced with the number of residues and the number of sequences in the
39 | # multiple sequence alignment, respectively.
40 | 
41 | 
42 | FEATURES = {
43 |     #### Static features of a protein sequence ####
44 |     "aatype": (tf.float32, [NUM_RES, 21]),
45 |     "between_segment_residues": (tf.int64, [NUM_RES, 1]),
46 |     "deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]),
47 |     "domain_name": (tf.string, [1]),
48 |     "msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]),
49 |     "num_alignments": (tf.int64, [NUM_RES, 1]),
50 |     "residue_index": (tf.int64, [NUM_RES, 1]),
51 |     "seq_length": (tf.int64, [NUM_RES, 1]),
52 |     "sequence": (tf.string, [1]),
53 |     "all_atom_positions": (tf.float32,
54 |                            [NUM_RES, residue_constants.atom_type_num, 3]),
55 |     "all_atom_mask": (tf.int64, [NUM_RES, residue_constants.atom_type_num]),
56 |     "resolution": (tf.float32, [1]),
57 |     "template_domain_names": (tf.string, [NUM_TEMPLATES]),
58 |     "template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]),
59 |     "template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]),
60 |     "template_all_atom_positions": (tf.float32, [
61 |         NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3
62 |     ]),
63 |     "template_all_atom_masks": (tf.float32, [
64 |         NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1
65 |     ]),
66 | }
67 | 
68 | FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()}
69 | FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()}
70 | 
71 | 
72 | def register_feature(name: str,
73 |                      type_: tf.dtypes.DType,
74 |                      shape_: Tuple[Union[str, int]]):
75 |   """Register extra features used in custom datasets."""
76 |   FEATURES[name] = (type_, shape_)
77 |   FEATURE_TYPES[name] = type_
78 |   FEATURE_SIZES[name] = shape_
79 | 
80 | 
81 | def shape(feature_name: str,
82 |           num_residues: int,
83 |           msa_length: int,
84 |           num_templates: Optional[int] = None,
85 |           features: Optional[FeaturesMetadata] = None):
86 |   """Get the shape for the given feature name.
87 | 
88 |   This is near identical to _get_tf_shape_no_placeholders() but with 2
89 |   differences:
90 |   * This method does not calculate a single placeholder from the total number of
91 |     elements (eg given <NUM_RES, 3> and size := 12, this won't deduce NUM_RES
92 |     must be 4)
93 |   * This method will work with tensors
94 | 
95 |   Args:
96 |     feature_name: String identifier for the feature. If the feature name ends
97 |       with "_unnormalized", this suffix is stripped off.
98 |     num_residues: The number of residues in the current domain - some elements
99 |       of the shape can be dynamic and will be replaced by this value.
100 |     msa_length: The number of sequences in the multiple sequence alignment, some
101 |       elements of the shape can be dynamic and will be replaced by this value.
102 |       If the number of alignments is unknown / not read, please pass None for
103 |       msa_length.
104 |     num_templates (optional): The number of templates in this tfexample.
105 |     features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES.
106 | 
107 |   Returns:
108 |     List of ints representing the tensor size.
109 | 
110 |   Raises:
111 |     ValueError: If a feature is requested but no concrete placeholder value is
112 |       given.
113 |   """
114 |   features = features or FEATURES
115 |   if feature_name.endswith("_unnormalized"):
116 |     feature_name = feature_name[:-13]
117 | 
118 |   unused_dtype, raw_sizes = features[feature_name]
119 |   replacements = {NUM_RES: num_residues,
120 |                   NUM_SEQ: msa_length}
121 | 
122 |   if num_templates is not None:
123 |     replacements[NUM_TEMPLATES] = num_templates
124 | 
125 |   sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]
126 |   for dimension in sizes:
127 |     if isinstance(dimension, str):
128 |       raise ValueError("Could not parse %s (shape: %s) with values: %s" % (
129 |           feature_name, raw_sizes, replacements))
130 |   return sizes
131 | 
--------------------------------------------------------------------------------
/src/alphafold/model/tf/protein_features_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #      http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
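# A minimal illustrative call of protein_features.shape(), using the same
# values as testReplacement below: FEATURES["msa"] is
# (tf.int64, [NUM_SEQ, NUM_RES, 1]), so
# shape("msa", num_residues=12, msa_length=24, num_templates=3)
# resolves the placeholders and returns [24, 12, 1].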
14 | 15 | """Tests for protein_features.""" 16 | import uuid 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from alphafold.model.tf import protein_features 21 | import tensorflow.compat.v1 as tf 22 | 23 | 24 | def _random_bytes(): 25 | return str(uuid.uuid4()).encode('utf-8') 26 | 27 | 28 | class FeaturesTest(parameterized.TestCase, tf.test.TestCase): 29 | 30 | def testFeatureNames(self): 31 | self.assertEqual(len(protein_features.FEATURE_SIZES), 32 | len(protein_features.FEATURE_TYPES)) 33 | sorted_size_names = sorted(protein_features.FEATURE_SIZES.keys()) 34 | sorted_type_names = sorted(protein_features.FEATURE_TYPES.keys()) 35 | for i, size_name in enumerate(sorted_size_names): 36 | self.assertEqual(size_name, sorted_type_names[i]) 37 | 38 | def testReplacement(self): 39 | for name in protein_features.FEATURE_SIZES.keys(): 40 | sizes = protein_features.shape(name, 41 | num_residues=12, 42 | msa_length=24, 43 | num_templates=3) 44 | for x in sizes: 45 | self.assertEqual(type(x), int) 46 | self.assertGreater(x, 0) 47 | 48 | 49 | if __name__ == '__main__': 50 | tf.disable_v2_behavior() 51 | absltest.main() 52 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/proteins_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Datasets consisting of proteins.""" 16 | from typing import Dict, Mapping, Optional, Sequence 17 | from alphafold.model.tf import protein_features 18 | import numpy as np 19 | import tensorflow.compat.v1 as tf 20 | 21 | TensorDict = Dict[str, tf.Tensor] 22 | 23 | 24 | def parse_tfexample( 25 | raw_data: bytes, 26 | features: protein_features.FeaturesMetadata, 27 | key: Optional[str] = None) -> Dict[str, tf.train.Feature]: 28 | """Read a single TF Example proto and return a subset of its features. 29 | 30 | Args: 31 | raw_data: A serialized tf.Example proto. 32 | features: A dictionary of features, mapping string feature names to a tuple 33 | (dtype, shape). This dictionary should be a subset of 34 | protein_features.FEATURES (or the dictionary itself for all features). 35 | key: Optional string with the SSTable key of that tf.Example. This will be 36 | added into features as a 'key' but only if requested in features. 37 | 38 | Returns: 39 | A dictionary of features mapping feature names to features. Only the given 40 | features are returned, all other ones are filtered out. 
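
  Example (illustrative; `serialized` stands in for a real serialized proto):
    metadata = {k: protein_features.FEATURES[k]
                for k in ('aatype', 'sequence', 'seq_length')}
    tensors = parse_tfexample(serialized, metadata)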
41 | """ 42 | feature_map = { 43 | k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True) 44 | for k, v in features.items() 45 | } 46 | parsed_features = tf.io.parse_single_example(raw_data, feature_map) 47 | reshaped_features = parse_reshape_logic(parsed_features, features, key=key) 48 | 49 | return reshaped_features 50 | 51 | 52 | def _first(tensor: tf.Tensor) -> tf.Tensor: 53 | """Returns the 1st element - the input can be a tensor or a scalar.""" 54 | return tf.reshape(tensor, shape=(-1,))[0] 55 | 56 | 57 | def parse_reshape_logic( 58 | parsed_features: TensorDict, 59 | features: protein_features.FeaturesMetadata, 60 | key: Optional[str] = None) -> TensorDict: 61 | """Transforms parsed serial features to the correct shape.""" 62 | # Find out what is the number of sequences and the number of alignments. 63 | num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32) 64 | 65 | if "num_alignments" in parsed_features: 66 | num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32) 67 | else: 68 | num_msa = 0 69 | 70 | num_templates = 0 71 | # if "template_domain_names" in parsed_features: 72 | # num_templates = tf.cast( 73 | # tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32) 74 | # else: 75 | # num_templates = 0 76 | 77 | if key is not None and "key" in features: 78 | parsed_features["key"] = [key] # Expand dims from () to (1,). 79 | 80 | # Reshape the tensors according to the sequence length and num alignments. 81 | for k, v in parsed_features.items(): 82 | new_shape = protein_features.shape( 83 | feature_name=k, 84 | num_residues=num_residues, 85 | msa_length=num_msa, 86 | num_templates=num_templates, 87 | features=features) 88 | new_shape_size = tf.constant(1, dtype=tf.int32) 89 | for dim in new_shape: 90 | new_shape_size *= tf.cast(dim, tf.int32) 91 | 92 | assert_equal = tf.assert_equal( 93 | tf.size(v), new_shape_size, 94 | name="assert_%s_shape_correct" % k, 95 | message="The size of feature %s (%s) could not be reshaped " 96 | "into %s" % (k, tf.size(v), new_shape)) 97 | if "template" not in k: 98 | # Make sure the feature we are reshaping is not empty. 99 | assert_non_empty = tf.assert_greater( 100 | tf.size(v), 0, name="assert_%s_non_empty" % k, 101 | message="The feature %s is not set in the tf.Example. Either do not " 102 | "request the feature or use a tf.Example that has the " 103 | "feature set." % k) 104 | with tf.control_dependencies([assert_non_empty, assert_equal]): 105 | parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) 106 | else: 107 | with tf.control_dependencies([assert_equal]): 108 | parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) 109 | 110 | return parsed_features 111 | 112 | 113 | def _make_features_metadata( 114 | feature_names: Sequence[str]) -> protein_features.FeaturesMetadata: 115 | """Makes a feature name to type and shape mapping from a list of names.""" 116 | # Make sure these features are always read. 117 | required_features = ["aatype", "sequence", "seq_length"] 118 | feature_names = list(set(feature_names) | set(required_features)) 119 | 120 | features_metadata = {name: protein_features.FEATURES[name] 121 | for name in feature_names} 122 | return features_metadata 123 | 124 | 125 | def create_tensor_dict( 126 | raw_data: bytes, 127 | features: Sequence[str], 128 | key: Optional[str] = None, 129 | ) -> TensorDict: 130 | """Creates a dictionary of tensor features. 131 | 132 | Args: 133 | raw_data: A serialized tf.Example proto. 
134 | features: A list of strings of feature names to be returned in the dataset. 135 | key: Optional string with the SSTable key of that tf.Example. This will be 136 | added into features as a 'key' but only if requested in features. 137 | 138 | Returns: 139 | A dictionary of features mapping feature names to features. Only the given 140 | features are returned, all other ones are filtered out. 141 | """ 142 | features_metadata = _make_features_metadata(features) 143 | return parse_tfexample(raw_data, features_metadata, key) 144 | 145 | 146 | def np_to_tensor_dict( 147 | np_example: Mapping[str, np.ndarray], 148 | features: Sequence[str], 149 | ) -> TensorDict: 150 | """Creates dict of tensors from a dict of NumPy arrays. 151 | 152 | Args: 153 | np_example: A dict of NumPy feature arrays. 154 | features: A list of strings of feature names to be returned in the dataset. 155 | 156 | Returns: 157 | A dictionary of features mapping feature names to features. Only the given 158 | features are returned, all other ones are filtered out. 159 | """ 160 | features_metadata = _make_features_metadata(features) 161 | tensor_dict = {k: tf.constant(v) for k, v in np_example.items() 162 | if k in features_metadata} 163 | 164 | # Ensures shapes are as expected. Needed for setting size of empty features 165 | # e.g. when no template hits were found. 166 | tensor_dict = parse_reshape_logic(tensor_dict, features_metadata) 167 | return tensor_dict 168 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/shape_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for dealing with shapes of TensorFlow tensors.""" 16 | import tensorflow.compat.v1 as tf 17 | 18 | 19 | def shape_list(x): 20 | """Return list of dimensions of a tensor, statically where possible. 21 | 22 | Like `x.shape.as_list()` but with tensors instead of `None`s. 23 | 24 | Args: 25 | x: A tensor. 26 | Returns: 27 | A list with length equal to the rank of the tensor. The n-th element of the 28 | list is an integer when that dimension is statically known otherwise it is 29 | the n-th element of `tf.shape(x)`. 
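
  Example (illustrative, TF1 graph mode):
    p = tf.placeholder(tf.float32, shape=[None, 3])
    dims = shape_list(p)  # [<scalar int32 Tensor for dim 0>, 3]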
30 | """ 31 | x = tf.convert_to_tensor(x) 32 | 33 | # If unknown rank, return dynamic shape 34 | if x.get_shape().dims is None: 35 | return tf.shape(x) 36 | 37 | static = x.get_shape().as_list() 38 | shape = tf.shape(x) 39 | 40 | ret = [] 41 | for i in range(len(static)): 42 | dim = static[i] 43 | if dim is None: 44 | dim = shape[i] 45 | ret.append(dim) 46 | return ret 47 | 48 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/shape_helpers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for shape_helpers.""" 16 | 17 | from alphafold.model.tf import shape_helpers 18 | import numpy as np 19 | import tensorflow.compat.v1 as tf 20 | 21 | 22 | class ShapeTest(tf.test.TestCase): 23 | 24 | def test_shape_list(self): 25 | """Test that shape_list can allow for reshaping to dynamic shapes.""" 26 | a = tf.zeros([10, 4, 4, 2]) 27 | p = tf.placeholder(tf.float32, shape=[None, None, 1, 4, 4]) 28 | shape_dyn = shape_helpers.shape_list(p)[:2] + [4, 4] 29 | 30 | b = tf.reshape(a, shape_dyn) 31 | with self.session() as sess: 32 | out = sess.run(b, feed_dict={p: np.ones((20, 1, 1, 4, 4))}) 33 | 34 | self.assertAllEqual(out.shape, (20, 1, 4, 4)) 35 | 36 | 37 | if __name__ == '__main__': 38 | tf.disable_v2_behavior() 39 | tf.test.main() 40 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/shape_placeholders.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Placeholder values for run-time varying dimension sizes.""" 16 | 17 | NUM_RES = 'num residues placeholder' 18 | NUM_MSA_SEQ = 'msa placeholder' 19 | NUM_EXTRA_SEQ = 'extra msa placeholder' 20 | NUM_TEMPLATES = 'num templates placeholder' 21 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Shared utilities for various components.""" 16 | import tensorflow.compat.v1 as tf 17 | 18 | 19 | def tf_combine_mask(*masks): 20 | """Take the intersection of float-valued masks.""" 21 | ret = 1 22 | for m in masks: 23 | ret *= m 24 | return ret 25 | 26 | 27 | class SeedMaker(object): 28 | """Return unique seeds.""" 29 | 30 | def __init__(self, initial_seed=0): 31 | self.next_seed = initial_seed 32 | 33 | def __call__(self): 34 | i = self.next_seed 35 | self.next_seed += 1 36 | return i 37 | 38 | seed_maker = SeedMaker() 39 | 40 | 41 | def make_random_seed(): 42 | return tf.random.uniform([2], 43 | tf.int32.min, 44 | tf.int32.max, 45 | tf.int32, 46 | seed=seed_maker()) 47 | 48 | -------------------------------------------------------------------------------- /src/alphafold/model/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A collection of JAX utility functions for use in protein folding.""" 16 | 17 | import collections.abc as collections 18 | import numbers 19 | from typing import Mapping 20 | 21 | import haiku as hk 22 | import jax 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | 27 | def final_init(config): 28 | if config.zero_init: 29 | return 'zeros' 30 | else: 31 | return 'linear' 32 | 33 | 34 | def batched_gather(params, indices, axis=0, batch_dims=0): 35 | """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`.""" 36 | take_fn = lambda p, i: jnp.take(p, i, axis=axis) 37 | for _ in range(batch_dims): 38 | take_fn = jax.vmap(take_fn) 39 | return take_fn(params, indices) 40 | 41 | 42 | def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10): 43 | """Masked mean.""" 44 | if drop_mask_channel: 45 | mask = mask[..., 0] 46 | 47 | mask_shape = mask.shape 48 | value_shape = value.shape 49 | 50 | assert len(mask_shape) == len(value_shape) 51 | 52 | if isinstance(axis, numbers.Integral): 53 | axis = [axis] 54 | elif axis is None: 55 | axis = list(range(len(mask_shape))) 56 | assert isinstance(axis, collections.Iterable), ( 57 | 'axis needs to be either an iterable, integer or "None"') 58 | 59 | broadcast_factor = 1. 
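  # Axes where the mask has size 1 are broadcast against `value`; the factor
  # below scales the denominator by the broadcast size so the mean stays
  # correctly normalised.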
60 | for axis_ in axis: 61 | value_size = value_shape[axis_] 62 | mask_size = mask_shape[axis_] 63 | if mask_size == 1: 64 | broadcast_factor *= value_size 65 | else: 66 | assert mask_size == value_size 67 | 68 | return (jnp.sum(mask * value, axis=axis) / 69 | (jnp.sum(mask, axis=axis) * broadcast_factor + eps)) 70 | 71 | 72 | def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params: 73 | """Convert a dictionary of NumPy arrays to Haiku parameters.""" 74 | hk_params = {} 75 | for path, array in params.items(): 76 | scope, name = path.split('//') 77 | if scope not in hk_params: 78 | hk_params[scope] = {} 79 | hk_params[scope][name] = jnp.array(array) 80 | 81 | return hk_params 82 | -------------------------------------------------------------------------------- /src/alphafold/relax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Amber relaxation.""" 15 | -------------------------------------------------------------------------------- /src/alphafold/relax/amber_minimize_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for amber_minimize.""" 16 | import os 17 | 18 | from absl.testing import absltest 19 | from alphafold.common import protein 20 | from alphafold.relax import amber_minimize 21 | import numpy as np 22 | # Internal import (7716). 23 | 24 | 25 | def _load_test_protein(data_path): 26 | pdb_path = os.path.join(absltest.get_default_test_srcdir(), data_path) 27 | with open(pdb_path, 'r') as f: 28 | return protein.from_pdb_string(f.read()) 29 | 30 | 31 | class AmberMinimizeTest(absltest.TestCase): 32 | 33 | def test_multiple_disulfides_target(self): 34 | prot = _load_test_protein( 35 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb' 36 | ) 37 | ret = amber_minimize.run_pipeline(prot, max_iterations=10, max_attempts=1, 38 | stiffness=10.) 
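    # run_pipeline returns a dict of debug metrics; spot-check two of the
    # expected keys.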
39 | self.assertIn('opt_time', ret) 40 | self.assertIn('min_attempts', ret) 41 | 42 | def test_raises_invalid_protein_assertion(self): 43 | prot = _load_test_protein( 44 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb' 45 | ) 46 | prot.atom_mask[4, :] = 0 47 | with self.assertRaisesRegex( 48 | ValueError, 49 | 'Amber minimization can only be performed on proteins with well-defined' 50 | ' residues. This protein contains at least one residue with no atoms.'): 51 | amber_minimize.run_pipeline(prot, max_iterations=10, 52 | stiffness=1., 53 | max_attempts=1) 54 | 55 | def test_iterative_relax(self): 56 | prot = _load_test_protein( 57 | 'alphafold/relax/testdata/with_violations.pdb' 58 | ) 59 | violations = amber_minimize.get_violation_metrics(prot) 60 | self.assertGreater(violations['num_residue_violations'], 0) 61 | out = amber_minimize.run_pipeline( 62 | prot=prot, max_outer_iterations=10, stiffness=10.) 63 | self.assertLess(out['efinal'], out['einit']) 64 | self.assertEqual(0, out['num_residue_violations']) 65 | 66 | def test_find_violations(self): 67 | prot = _load_test_protein( 68 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb' 69 | ) 70 | viols, _ = amber_minimize.find_violations(prot) 71 | 72 | expected_between_residues_connection_mask = np.zeros((191,), np.float32) 73 | for residue in (42, 43, 59, 60, 135, 136): 74 | expected_between_residues_connection_mask[residue] = 1.0 75 | 76 | expected_clash_indices = np.array([ 77 | [8, 4], 78 | [8, 5], 79 | [13, 3], 80 | [14, 1], 81 | [14, 4], 82 | [26, 4], 83 | [26, 5], 84 | [31, 8], 85 | [31, 10], 86 | [39, 0], 87 | [39, 1], 88 | [39, 2], 89 | [39, 3], 90 | [39, 4], 91 | [42, 5], 92 | [42, 6], 93 | [42, 7], 94 | [42, 8], 95 | [47, 7], 96 | [47, 8], 97 | [47, 9], 98 | [47, 10], 99 | [64, 4], 100 | [85, 5], 101 | [102, 4], 102 | [102, 5], 103 | [109, 13], 104 | [111, 5], 105 | [118, 6], 106 | [118, 7], 107 | [118, 8], 108 | [124, 4], 109 | [124, 5], 110 | [131, 5], 111 | [139, 7], 112 | [147, 4], 113 | [152, 7]], dtype=np.int32) 114 | expected_between_residues_clash_mask = np.zeros([191, 14]) 115 | expected_between_residues_clash_mask[expected_clash_indices[:, 0], 116 | expected_clash_indices[:, 1]] += 1 117 | expected_per_atom_violations = np.zeros([191, 14]) 118 | np.testing.assert_array_equal( 119 | viols['between_residues']['connections_per_residue_violation_mask'], 120 | expected_between_residues_connection_mask) 121 | np.testing.assert_array_equal( 122 | viols['between_residues']['clashes_per_atom_clash_mask'], 123 | expected_between_residues_clash_mask) 124 | np.testing.assert_array_equal( 125 | viols['within_residues']['per_atom_violations'], 126 | expected_per_atom_violations) 127 | 128 | 129 | if __name__ == '__main__': 130 | absltest.main() 131 | -------------------------------------------------------------------------------- /src/alphafold/relax/cleanup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations. 16 | 17 | fix_pdb uses a third-party tool. We also support fixing some additional edge 18 | cases like removing chains of length one (see clean_structure). 19 | """ 20 | import io 21 | 22 | import pdbfixer 23 | from simtk.openmm import app 24 | from simtk.openmm.app import element 25 | 26 | 27 | def fix_pdb(pdbfile, alterations_info): 28 | """Apply pdbfixer to the contents of a PDB file; return a PDB string result. 29 | 30 | 1) Replaces nonstandard residues. 31 | 2) Removes heterogens (non protein residues) including water. 32 | 3) Adds missing residues and missing atoms within existing residues. 33 | 4) Adds hydrogens assuming pH=7.0. 34 | 5) KeepIds is currently true, so the fixer must keep the existing chain and 35 | residue identifiers. This will fail for some files in wider PDB that have 36 | invalid IDs. 37 | 38 | Args: 39 | pdbfile: Input PDB file handle. 40 | alterations_info: A dict that will store details of changes made. 41 | 42 | Returns: 43 | A PDB string representing the fixed structure. 44 | """ 45 | fixer = pdbfixer.PDBFixer(pdbfile=pdbfile) 46 | fixer.findNonstandardResidues() 47 | alterations_info['nonstandard_residues'] = fixer.nonstandardResidues 48 | fixer.replaceNonstandardResidues() 49 | _remove_heterogens(fixer, alterations_info, keep_water=False) 50 | fixer.findMissingResidues() 51 | alterations_info['missing_residues'] = fixer.missingResidues 52 | fixer.findMissingAtoms() 53 | alterations_info['missing_heavy_atoms'] = fixer.missingAtoms 54 | alterations_info['missing_terminals'] = fixer.missingTerminals 55 | fixer.addMissingAtoms(seed=0) 56 | fixer.addMissingHydrogens() 57 | out_handle = io.StringIO() 58 | app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle, 59 | keepIds=True) 60 | return out_handle.getvalue() 61 | 62 | 63 | def clean_structure(pdb_structure, alterations_info): 64 | """Applies additional fixes to an OpenMM structure, to handle edge cases. 65 | 66 | Args: 67 | pdb_structure: An OpenMM structure to modify and fix. 68 | alterations_info: A dict that will store details of changes made. 69 | """ 70 | _replace_met_se(pdb_structure, alterations_info) 71 | _remove_chains_of_length_one(pdb_structure, alterations_info) 72 | 73 | 74 | def _remove_heterogens(fixer, alterations_info, keep_water): 75 | """Removes the residues that Pdbfixer considers to be heterogens. 76 | 77 | Args: 78 | fixer: A Pdbfixer instance. 79 | alterations_info: A dict that will store details of changes made. 80 | keep_water: If True, water (HOH) is not considered to be a heterogen. 
81 | """ 82 | initial_resnames = set() 83 | for chain in fixer.topology.chains(): 84 | for residue in chain.residues(): 85 | initial_resnames.add(residue.name) 86 | fixer.removeHeterogens(keepWater=keep_water) 87 | final_resnames = set() 88 | for chain in fixer.topology.chains(): 89 | for residue in chain.residues(): 90 | final_resnames.add(residue.name) 91 | alterations_info['removed_heterogens'] = ( 92 | initial_resnames.difference(final_resnames)) 93 | 94 | 95 | def _replace_met_se(pdb_structure, alterations_info): 96 | """Replace the Se in any MET residues that were not marked as modified.""" 97 | modified_met_residues = [] 98 | for res in pdb_structure.iter_residues(): 99 | name = res.get_name_with_spaces().strip() 100 | if name == 'MET': 101 | s_atom = res.get_atom('SD') 102 | if s_atom.element_symbol == 'Se': 103 | s_atom.element_symbol = 'S' 104 | s_atom.element = element.get_by_symbol('S') 105 | modified_met_residues.append(s_atom.residue_number) 106 | alterations_info['Se_in_MET'] = modified_met_residues 107 | 108 | 109 | def _remove_chains_of_length_one(pdb_structure, alterations_info): 110 | """Removes chains that correspond to a single amino acid. 111 | 112 | A single amino acid in a chain is both N and C terminus. There is no force 113 | template for this case. 114 | 115 | Args: 116 | pdb_structure: An OpenMM pdb_structure to modify and fix. 117 | alterations_info: A dict that will store details of changes made. 118 | """ 119 | removed_chains = {} 120 | for model in pdb_structure.iter_models(): 121 | valid_chains = [c for c in model.iter_chains() if len(c) > 1] 122 | invalid_chain_ids = [c.chain_id for c in model.iter_chains() if len(c) <= 1] 123 | model.chains = valid_chains 124 | for chain_id in invalid_chain_ids: 125 | model.chains_by_id.pop(chain_id) 126 | removed_chains[model.number] = invalid_chain_ids 127 | alterations_info['removed_chains'] = removed_chains 128 | -------------------------------------------------------------------------------- /src/alphafold/relax/cleanup_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
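# A minimal illustrative use of cleanup.fix_pdb (the input path is
# hypothetical):
#   alterations = {}
#   with open('input.pdb') as f:
#       fixed_pdb_str = cleanup.fix_pdb(f, alterations)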
14 | 15 | """Tests for relax.cleanup.""" 16 | import io 17 | 18 | from absl.testing import absltest 19 | from alphafold.relax import cleanup 20 | from simtk.openmm.app.internal import pdbstructure 21 | 22 | 23 | def _pdb_to_structure(pdb_str): 24 | handle = io.StringIO(pdb_str) 25 | return pdbstructure.PdbStructure(handle) 26 | 27 | 28 | def _lines_to_structure(pdb_lines): 29 | return _pdb_to_structure('\n'.join(pdb_lines)) 30 | 31 | 32 | class CleanupTest(absltest.TestCase): 33 | 34 | def test_missing_residues(self): 35 | pdb_lines = ['SEQRES 1 C 3 CYS GLY LEU', 36 | 'ATOM 1 N CYS C 1 -12.262 20.115 60.959 1.00 ' 37 | '19.08 N', 38 | 'ATOM 2 CA CYS C 1 -11.065 20.934 60.773 1.00 ' 39 | '17.23 C', 40 | 'ATOM 3 C CYS C 1 -10.002 20.742 61.844 1.00 ' 41 | '15.38 C', 42 | 'ATOM 4 O CYS C 1 -10.284 20.225 62.929 1.00 ' 43 | '16.04 O', 44 | 'ATOM 5 N LEU C 3 -7.688 18.700 62.045 1.00 ' 45 | '14.75 N', 46 | 'ATOM 6 CA LEU C 3 -7.256 17.320 62.234 1.00 ' 47 | '16.81 C', 48 | 'ATOM 7 C LEU C 3 -6.380 16.864 61.070 1.00 ' 49 | '16.95 C', 50 | 'ATOM 8 O LEU C 3 -6.551 17.332 59.947 1.00 ' 51 | '16.97 O'] 52 | input_handle = io.StringIO('\n'.join(pdb_lines)) 53 | alterations = {} 54 | result = cleanup.fix_pdb(input_handle, alterations) 55 | structure = _pdb_to_structure(result) 56 | residue_names = [r.get_name() for r in structure.iter_residues()] 57 | self.assertCountEqual(residue_names, ['CYS', 'GLY', 'LEU']) 58 | self.assertCountEqual(alterations['missing_residues'].values(), [['GLY']]) 59 | 60 | def test_missing_atoms(self): 61 | pdb_lines = ['SEQRES 1 A 1 PRO', 62 | 'ATOM 1 CA PRO A 1 1.000 1.000 1.000 1.00 ' 63 | ' 0.00 C'] 64 | input_handle = io.StringIO('\n'.join(pdb_lines)) 65 | alterations = {} 66 | result = cleanup.fix_pdb(input_handle, alterations) 67 | structure = _pdb_to_structure(result) 68 | atom_names = [a.get_name() for a in structure.iter_atoms()] 69 | self.assertCountEqual(atom_names, ['N', 'CD', 'HD2', 'HD3', 'CG', 'HG2', 70 | 'HG3', 'CB', 'HB2', 'HB3', 'CA', 'HA', 71 | 'C', 'O', 'H2', 'H3', 'OXT']) 72 | missing_atoms_by_residue = list(alterations['missing_heavy_atoms'].values()) 73 | self.assertLen(missing_atoms_by_residue, 1) 74 | atoms_added = [a.name for a in missing_atoms_by_residue[0]] 75 | self.assertCountEqual(atoms_added, ['N', 'CD', 'CG', 'CB', 'C', 'O']) 76 | missing_terminals_by_residue = alterations['missing_terminals'] 77 | self.assertLen(missing_terminals_by_residue, 1) 78 | has_missing_terminal = [r.name for r in missing_terminals_by_residue.keys()] 79 | self.assertCountEqual(has_missing_terminal, ['PRO']) 80 | self.assertCountEqual([t for t in missing_terminals_by_residue.values()], 81 | [['OXT']]) 82 | 83 | def test_remove_heterogens(self): 84 | pdb_lines = ['SEQRES 1 A 1 GLY', 85 | 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 ' 86 | ' 0.00 C', 87 | 'ATOM 2 O HOH A 2 0.000 0.000 0.000 1.00 ' 88 | ' 0.00 O'] 89 | input_handle = io.StringIO('\n'.join(pdb_lines)) 90 | alterations = {} 91 | result = cleanup.fix_pdb(input_handle, alterations) 92 | structure = _pdb_to_structure(result) 93 | self.assertCountEqual([res.get_name() for res in structure.iter_residues()], 94 | ['GLY']) 95 | self.assertEqual(alterations['removed_heterogens'], set(['HOH'])) 96 | 97 | def test_fix_nonstandard_residues(self): 98 | pdb_lines = ['SEQRES 1 A 1 DAL', 99 | 'ATOM 1 CA DAL A 1 0.000 0.000 0.000 1.00 ' 100 | ' 0.00 C'] 101 | input_handle = io.StringIO('\n'.join(pdb_lines)) 102 | alterations = {} 103 | result = cleanup.fix_pdb(input_handle, alterations) 104 | structure = 
_pdb_to_structure(result) 105 | residue_names = [res.get_name() for res in structure.iter_residues()] 106 | self.assertCountEqual(residue_names, ['ALA']) 107 | self.assertLen(alterations['nonstandard_residues'], 1) 108 | original_res, new_name = alterations['nonstandard_residues'][0] 109 | self.assertEqual(original_res.id, '1') 110 | self.assertEqual(new_name, 'ALA') 111 | 112 | def test_replace_met_se(self): 113 | pdb_lines = ['SEQRES 1 A 1 MET', 114 | 'ATOM 1 SD MET A 1 0.000 0.000 0.000 1.00 ' 115 | ' 0.00 Se'] 116 | structure = _lines_to_structure(pdb_lines) 117 | alterations = {} 118 | cleanup._replace_met_se(structure, alterations) 119 | sd = [a for a in structure.iter_atoms() if a.get_name() == 'SD'] 120 | self.assertLen(sd, 1) 121 | self.assertEqual(sd[0].element_symbol, 'S') 122 | self.assertCountEqual(alterations['Se_in_MET'], [sd[0].residue_number]) 123 | 124 | def test_remove_chains_of_length_one(self): 125 | pdb_lines = ['SEQRES 1 A 1 GLY', 126 | 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 ' 127 | ' 0.00 C'] 128 | structure = _lines_to_structure(pdb_lines) 129 | alterations = {} 130 | cleanup._remove_chains_of_length_one(structure, alterations) 131 | chains = list(structure.iter_chains()) 132 | self.assertEmpty(chains) 133 | self.assertCountEqual(alterations['removed_chains'].values(), [['A']]) 134 | 135 | 136 | if __name__ == '__main__': 137 | absltest.main() 138 | -------------------------------------------------------------------------------- /src/alphafold/relax/relax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Amber relaxation.""" 16 | from typing import Any, Dict, Sequence, Tuple 17 | from alphafold.common import protein 18 | from alphafold.relax import amber_minimize 19 | from alphafold.relax import utils 20 | import numpy as np 21 | 22 | 23 | class AmberRelaxation(object): 24 | """Amber relaxation.""" 25 | 26 | def __init__(self, 27 | *, 28 | max_iterations: int, 29 | tolerance: float, 30 | stiffness: float, 31 | exclude_residues: Sequence[int], 32 | max_outer_iterations: int): 33 | """Initialize Amber Relaxer. 34 | 35 | Args: 36 | max_iterations: Maximum number of L-BFGS iterations. 0 means no max. 37 | tolerance: kcal/mol, the energy tolerance of L-BFGS. 38 | stiffness: kcal/mol A**2, spring constant of heavy atom restraining 39 | potential. 40 | exclude_residues: Residues to exclude from per-atom restraining. 41 | Zero-indexed. 42 | max_outer_iterations: Maximum number of violation-informed relax 43 | iterations. A value of 1 will run the non-iterative procedure used in 44 | CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes 45 | as soon as there are no violations, hence in most cases this causes no 46 | slowdown. In the worst case we do 20 outer iterations. 
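
    Example (illustrative values, matching the test config in relax_test.py):
      relaxer = AmberRelaxation(
          max_iterations=1, tolerance=2.39, stiffness=10.0,
          exclude_residues=[], max_outer_iterations=1)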
47 | """ 48 | 49 | self._max_iterations = max_iterations 50 | self._tolerance = tolerance 51 | self._stiffness = stiffness 52 | self._exclude_residues = exclude_residues 53 | self._max_outer_iterations = max_outer_iterations 54 | 55 | def process(self, *, 56 | prot: protein.Protein) -> Tuple[str, Dict[str, Any], np.ndarray]: 57 | """Runs Amber relax on a prediction, adds hydrogens, returns PDB string.""" 58 | out = amber_minimize.run_pipeline( 59 | prot=prot, max_iterations=self._max_iterations, 60 | tolerance=self._tolerance, stiffness=self._stiffness, 61 | exclude_residues=self._exclude_residues, 62 | max_outer_iterations=self._max_outer_iterations) 63 | min_pos = out['pos'] 64 | start_pos = out['posinit'] 65 | rmsd = np.sqrt(np.sum((start_pos - min_pos)**2) / start_pos.shape[0]) 66 | debug_data = { 67 | 'initial_energy': out['einit'], 68 | 'final_energy': out['efinal'], 69 | 'attempts': out['min_attempts'], 70 | 'rmsd': rmsd 71 | } 72 | pdb_str = amber_minimize.clean_protein(prot) 73 | min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos) 74 | min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors) 75 | utils.assert_equal_nonterminal_atom_types( 76 | protein.from_pdb_string(min_pdb).atom_mask, 77 | prot.atom_mask) 78 | violations = out['structural_violations'][ 79 | 'total_per_residue_violations_mask'] 80 | return min_pdb, debug_data, violations 81 | -------------------------------------------------------------------------------- /src/alphafold/relax/relax_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for relax.""" 16 | import os 17 | 18 | from absl.testing import absltest 19 | from alphafold.common import protein 20 | from alphafold.relax import relax 21 | import numpy as np 22 | # Internal import (7716). 23 | 24 | 25 | class RunAmberRelaxTest(absltest.TestCase): 26 | 27 | def setUp(self): 28 | super().setUp() 29 | self.test_dir = os.path.join( 30 | absltest.get_default_test_srcdir(), 31 | 'alphafold/relax/testdata/') 32 | self.test_config = { 33 | 'max_iterations': 1, 34 | 'tolerance': 2.39, 35 | 'stiffness': 10.0, 36 | 'exclude_residues': [], 37 | 'max_outer_iterations': 1} 38 | 39 | def test_process(self): 40 | amber_relax = relax.AmberRelaxation(**self.test_config) 41 | 42 | with open(os.path.join(self.test_dir, 'model_output.pdb')) as f: 43 | test_prot = protein.from_pdb_string(f.read()) 44 | pdb_min, debug_info, num_violations = amber_relax.process(prot=test_prot) 45 | 46 | self.assertCountEqual(debug_info.keys(), 47 | set({'initial_energy', 'final_energy', 48 | 'attempts', 'rmsd'})) 49 | self.assertLess(debug_info['final_energy'], debug_info['initial_energy']) 50 | self.assertGreater(debug_info['rmsd'], 0) 51 | 52 | prot_min = protein.from_pdb_string(pdb_min) 53 | # Most protein properties should be unchanged. 
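    # (Relaxation moves coordinates and adds hydrogens; aatype and
    # residue_index are expected to survive untouched.)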
54 | np.testing.assert_almost_equal(test_prot.aatype, prot_min.aatype) 55 | np.testing.assert_almost_equal(test_prot.residue_index, 56 | prot_min.residue_index) 57 | # Atom mask and bfactors identical except for terminal OXT of last residue. 58 | np.testing.assert_almost_equal(test_prot.atom_mask[:-1, :], 59 | prot_min.atom_mask[:-1, :]) 60 | np.testing.assert_almost_equal(test_prot.b_factors[:-1, :], 61 | prot_min.b_factors[:-1, :]) 62 | np.testing.assert_almost_equal(test_prot.atom_mask[:, :-1], 63 | prot_min.atom_mask[:, :-1]) 64 | np.testing.assert_almost_equal(test_prot.b_factors[:, :-1], 65 | prot_min.b_factors[:, :-1]) 66 | # There are no residues with violations. 67 | np.testing.assert_equal(num_violations, np.zeros_like(num_violations)) 68 | 69 | def test_unresolved_violations(self): 70 | amber_relax = relax.AmberRelaxation(**self.test_config) 71 | with open(os.path.join(self.test_dir, 72 | 'with_violations_casp14.pdb')) as f: 73 | test_prot = protein.from_pdb_string(f.read()) 74 | _, _, num_violations = amber_relax.process(prot=test_prot) 75 | exp_num_violations = np.array( 76 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 77 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 78 | 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 79 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 80 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81 | 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 82 | 0, 0, 0, 0]) 83 | # Check no violations were added. Can't check exactly due to stochasticity. 84 | self.assertTrue(np.all(num_violations <= exp_num_violations)) 85 | 86 | 87 | if __name__ == '__main__': 88 | absltest.main() 89 | -------------------------------------------------------------------------------- /src/alphafold/relax/testdata/model_output.pdb: -------------------------------------------------------------------------------- 1 | ATOM 1 C MET A 1 1.921 -46.152 7.786 1.00 4.39 C 2 | ATOM 2 CA MET A 1 1.631 -46.829 9.131 1.00 4.39 C 3 | ATOM 3 CB MET A 1 2.759 -47.768 9.578 1.00 4.39 C 4 | ATOM 4 CE MET A 1 3.466 -49.770 13.198 1.00 4.39 C 5 | ATOM 5 CG MET A 1 2.581 -48.221 11.034 1.00 4.39 C 6 | ATOM 6 H MET A 1 0.234 -48.249 8.549 1.00 4.39 H 7 | ATOM 7 H2 MET A 1 -0.424 -46.789 8.952 1.00 4.39 H 8 | ATOM 8 H3 MET A 1 0.111 -47.796 10.118 1.00 4.39 H 9 | ATOM 9 HA MET A 1 1.628 -46.009 9.849 1.00 4.39 H 10 | ATOM 10 HB2 MET A 1 3.701 -47.225 9.500 1.00 4.39 H 11 | ATOM 11 HB3 MET A 1 2.807 -48.640 8.926 1.00 4.39 H 12 | ATOM 12 HE1 MET A 1 2.747 -50.537 12.910 1.00 4.39 H 13 | ATOM 13 HE2 MET A 1 4.296 -50.241 13.725 1.00 4.39 H 14 | ATOM 14 HE3 MET A 1 2.988 -49.052 13.864 1.00 4.39 H 15 | ATOM 15 HG2 MET A 1 1.791 -48.971 11.083 1.00 4.39 H 16 | ATOM 16 HG3 MET A 1 2.295 -47.368 11.650 1.00 4.39 H 17 | ATOM 17 N MET A 1 0.291 -47.464 9.182 1.00 4.39 N 18 | ATOM 18 O MET A 1 2.091 -44.945 7.799 1.00 4.39 O 19 | ATOM 19 SD MET A 1 4.096 -48.921 11.725 1.00 4.39 S 20 | ATOM 20 C LYS A 2 1.366 -45.033 4.898 1.00 2.92 C 21 | ATOM 21 CA LYS A 2 2.235 -46.242 5.308 1.00 2.92 C 22 | ATOM 22 CB LYS A 2 2.206 -47.314 4.196 1.00 2.92 C 23 | ATOM 23 CD LYS A 2 3.331 -49.342 3.134 1.00 2.92 C 24 | ATOM 24 CE LYS A 2 4.434 -50.403 3.293 1.00 2.92 C 25 | ATOM 25 CG LYS A 2 3.294 -48.395 4.349 1.00 2.92 C 26 | ATOM 26 H LYS A 2 1.832 -47.853 6.656 1.00 2.92 H 27 | ATOM 27 HA LYS A 2 3.248 -45.841 5.355 1.00 2.92 H 28 | ATOM 28 HB2 LYS A 2 1.223 -47.785 
4.167 1.00 2.92 H 29 | ATOM 29 HB3 LYS A 2 2.363 -46.812 3.241 1.00 2.92 H 30 | ATOM 30 HD2 LYS A 2 3.524 -48.754 2.237 1.00 2.92 H 31 | ATOM 31 HD3 LYS A 2 2.364 -49.833 3.031 1.00 2.92 H 32 | ATOM 32 HE2 LYS A 2 5.383 -49.891 3.455 1.00 2.92 H 33 | ATOM 33 HE3 LYS A 2 4.225 -51.000 4.180 1.00 2.92 H 34 | ATOM 34 HG2 LYS A 2 3.102 -48.977 5.250 1.00 2.92 H 35 | ATOM 35 HG3 LYS A 2 4.264 -47.909 4.446 1.00 2.92 H 36 | ATOM 36 HZ1 LYS A 2 4.763 -50.747 1.274 1.00 2.92 H 37 | ATOM 37 HZ2 LYS A 2 3.681 -51.785 1.931 1.00 2.92 H 38 | ATOM 38 HZ3 LYS A 2 5.280 -51.965 2.224 1.00 2.92 H 39 | ATOM 39 N LYS A 2 1.907 -46.846 6.629 1.00 2.92 N 40 | ATOM 40 NZ LYS A 2 4.542 -51.286 2.100 1.00 2.92 N 41 | ATOM 41 O LYS A 2 1.882 -44.093 4.312 1.00 2.92 O 42 | ATOM 42 C PHE A 3 -0.511 -42.597 5.624 1.00 4.39 C 43 | ATOM 43 CA PHE A 3 -0.853 -43.933 4.929 1.00 4.39 C 44 | ATOM 44 CB PHE A 3 -2.271 -44.408 5.285 1.00 4.39 C 45 | ATOM 45 CD1 PHE A 3 -3.760 -43.542 3.432 1.00 4.39 C 46 | ATOM 46 CD2 PHE A 3 -4.050 -42.638 5.675 1.00 4.39 C 47 | ATOM 47 CE1 PHE A 3 -4.797 -42.715 2.965 1.00 4.39 C 48 | ATOM 48 CE2 PHE A 3 -5.091 -41.818 5.207 1.00 4.39 C 49 | ATOM 49 CG PHE A 3 -3.382 -43.505 4.788 1.00 4.39 C 50 | ATOM 50 CZ PHE A 3 -5.463 -41.853 3.853 1.00 4.39 C 51 | ATOM 51 H PHE A 3 -0.311 -45.868 5.655 1.00 4.39 H 52 | ATOM 52 HA PHE A 3 -0.817 -43.746 3.856 1.00 4.39 H 53 | ATOM 53 HB2 PHE A 3 -2.353 -44.512 6.367 1.00 4.39 H 54 | ATOM 54 HB3 PHE A 3 -2.432 -45.393 4.848 1.00 4.39 H 55 | ATOM 55 HD1 PHE A 3 -3.255 -44.198 2.739 1.00 4.39 H 56 | ATOM 56 HD2 PHE A 3 -3.768 -42.590 6.716 1.00 4.39 H 57 | ATOM 57 HE1 PHE A 3 -5.083 -42.735 1.923 1.00 4.39 H 58 | ATOM 58 HE2 PHE A 3 -5.604 -41.151 5.885 1.00 4.39 H 59 | ATOM 59 HZ PHE A 3 -6.257 -41.215 3.493 1.00 4.39 H 60 | ATOM 60 N PHE A 3 0.079 -45.027 5.253 1.00 4.39 N 61 | ATOM 61 O PHE A 3 -0.633 -41.541 5.014 1.00 4.39 O 62 | ATOM 62 C LEU A 4 1.598 -40.732 7.042 1.00 4.39 C 63 | ATOM 63 CA LEU A 4 0.367 -41.437 7.633 1.00 4.39 C 64 | ATOM 64 CB LEU A 4 0.628 -41.823 9.104 1.00 4.39 C 65 | ATOM 65 CD1 LEU A 4 -0.319 -42.778 11.228 1.00 4.39 C 66 | ATOM 66 CD2 LEU A 4 -1.300 -40.694 10.309 1.00 4.39 C 67 | ATOM 67 CG LEU A 4 -0.650 -42.027 9.937 1.00 4.39 C 68 | ATOM 68 H LEU A 4 0.163 -43.538 7.292 1.00 4.39 H 69 | ATOM 69 HA LEU A 4 -0.445 -40.712 7.588 1.00 4.39 H 70 | ATOM 70 HB2 LEU A 4 1.213 -41.034 9.576 1.00 4.39 H 71 | ATOM 71 HB3 LEU A 4 1.235 -42.728 9.127 1.00 4.39 H 72 | ATOM 72 HD11 LEU A 4 0.380 -42.191 11.824 1.00 4.39 H 73 | ATOM 73 HD12 LEU A 4 0.127 -43.747 11.002 1.00 4.39 H 74 | ATOM 74 HD13 LEU A 4 -1.230 -42.927 11.808 1.00 4.39 H 75 | ATOM 75 HD21 LEU A 4 -0.606 -40.080 10.883 1.00 4.39 H 76 | ATOM 76 HD22 LEU A 4 -2.193 -40.869 10.909 1.00 4.39 H 77 | ATOM 77 HD23 LEU A 4 -1.593 -40.147 9.413 1.00 4.39 H 78 | ATOM 78 HG LEU A 4 -1.359 -42.630 9.370 1.00 4.39 H 79 | ATOM 79 N LEU A 4 -0.012 -42.638 6.869 1.00 4.39 N 80 | ATOM 80 O LEU A 4 1.655 -39.508 7.028 1.00 4.39 O 81 | ATOM 81 C VAL A 5 3.372 -40.190 4.573 1.00 4.39 C 82 | ATOM 82 CA VAL A 5 3.752 -40.956 5.845 1.00 4.39 C 83 | ATOM 83 CB VAL A 5 4.757 -42.083 5.528 1.00 4.39 C 84 | ATOM 84 CG1 VAL A 5 6.019 -41.568 4.827 1.00 4.39 C 85 | ATOM 85 CG2 VAL A 5 5.199 -42.807 6.810 1.00 4.39 C 86 | ATOM 86 H VAL A 5 2.440 -42.503 6.548 1.00 4.39 H 87 | ATOM 87 HA VAL A 5 4.234 -40.242 6.512 1.00 4.39 H 88 | ATOM 88 HB VAL A 5 4.279 -42.813 4.875 1.00 4.39 H 89 | ATOM 89 HG11 VAL A 5 6.494 -40.795 5.431 1.00 4.39 H 90 | ATOM 90 HG12 VAL A 5 5.770 -41.145 3.853 1.00 4.39 
H 91 | ATOM 91 HG13 VAL A 5 6.725 -42.383 4.670 1.00 4.39 H 92 | ATOM 92 HG21 VAL A 5 4.347 -43.283 7.297 1.00 4.39 H 93 | ATOM 93 HG22 VAL A 5 5.933 -43.575 6.568 1.00 4.39 H 94 | ATOM 94 HG23 VAL A 5 5.651 -42.093 7.498 1.00 4.39 H 95 | ATOM 95 N VAL A 5 2.554 -41.501 6.509 1.00 4.39 N 96 | ATOM 96 O VAL A 5 3.937 -39.138 4.297 1.00 4.39 O 97 | TER 96 VAL A 5 98 | END 99 | -------------------------------------------------------------------------------- /src/alphafold/relax/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utils for minimization.""" 16 | import io 17 | from alphafold.common import residue_constants 18 | from Bio import PDB 19 | import numpy as np 20 | from simtk.openmm import app as openmm_app 21 | from simtk.openmm.app.internal.pdbstructure import PdbStructure 22 | 23 | 24 | def overwrite_pdb_coordinates(pdb_str: str, pos) -> str: 25 | pdb_file = io.StringIO(pdb_str) 26 | structure = PdbStructure(pdb_file) 27 | topology = openmm_app.PDBFile(structure).getTopology() 28 | with io.StringIO() as f: 29 | openmm_app.PDBFile.writeFile(topology, pos, f) 30 | return f.getvalue() 31 | 32 | 33 | def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str: 34 | """Overwrites the B-factors in pdb_str with contents of bfactors array. 35 | 36 | Args: 37 | pdb_str: An input PDB string. 38 | bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the 39 | B-factors are per residue; i.e. that the nonzero entries are identical in 40 | [0, i, :]. 41 | 42 | Returns: 43 | A new PDB string with the B-factors replaced. 44 | """ 45 | if bfactors.shape[-1] != residue_constants.atom_type_num: 46 | raise ValueError( 47 | f'Invalid final dimension size for bfactors: {bfactors.shape[-1]}.') 48 | 49 | parser = PDB.PDBParser(QUIET=True) 50 | handle = io.StringIO(pdb_str) 51 | structure = parser.get_structure('', handle) 52 | 53 | curr_resid = ('', '', '') 54 | idx = -1 55 | for atom in structure.get_atoms(): 56 | atom_resid = atom.parent.get_id() 57 | if atom_resid != curr_resid: 58 | idx += 1 59 | if idx >= bfactors.shape[0]: 60 | raise ValueError('Index into bfactors exceeds number of residues. ' 61 | 'B-factors shape: {shape}, idx: {idx}.') 62 | curr_resid = atom_resid 63 | atom.bfactor = bfactors[idx, residue_constants.atom_order['CA']] 64 | 65 | new_pdb = io.StringIO() 66 | pdb_io = PDB.PDBIO() 67 | pdb_io.set_structure(structure) 68 | pdb_io.save(new_pdb) 69 | return new_pdb.getvalue() 70 | 71 | 72 | def assert_equal_nonterminal_atom_types( 73 | atom_mask: np.ndarray, ref_atom_mask: np.ndarray): 74 | """Checks that pre- and post-minimized proteins have same atom set.""" 75 | # Ignore any terminal OXT atoms which may have been added by minimization. 
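# (OXT is the second carboxyl oxygen that can be added to the C-terminal
# residue during minimization; the mask built below excludes it so that only
# atoms present in both structures are compared.)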
76 |   oxt = residue_constants.atom_order['OXT']
77 |   no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
78 |   no_oxt_mask[..., oxt] = False
79 |   np.testing.assert_almost_equal(ref_atom_mask[no_oxt_mask],
80 |                                  atom_mask[no_oxt_mask])
81 | 
--------------------------------------------------------------------------------
/src/alphafold/relax/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #      http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Tests for utils."""
16 | 
17 | import os
18 | 
19 | from absl.testing import absltest
20 | from alphafold.common import protein
21 | from alphafold.relax import utils
22 | import numpy as np
23 | # Internal import (7716).
24 | 
25 | 
26 | class UtilsTest(absltest.TestCase):
27 | 
28 |   def test_overwrite_b_factors(self):
29 |     testdir = os.path.join(
30 |         absltest.get_default_test_srcdir(),
31 |         'alphafold/relax/testdata/'
32 |         'multiple_disulfides_target.pdb')
33 |     with open(testdir) as f:
34 |       test_pdb = f.read()
35 |     n_residues = 191
36 |     bfactors = np.stack([np.arange(0, n_residues)] * 37, axis=-1)
37 | 
38 |     output_pdb = utils.overwrite_b_factors(test_pdb, bfactors)
39 | 
40 |     # Check that the atom lines are unchanged apart from the B-factors.
41 |     atom_lines_original = [l for l in test_pdb.split('\n') if l[:4] == 'ATOM']
42 |     atom_lines_new = [l for l in output_pdb.split('\n') if l[:4] == 'ATOM']
43 |     for line_original, line_new in zip(atom_lines_original, atom_lines_new):
44 |       self.assertEqual(line_original[:60].strip(), line_new[:60].strip())
45 |       self.assertEqual(line_original[66:].strip(), line_new[66:].strip())
46 | 
47 |     # Check B-factors are correctly set for all atoms present.
48 | as_protein = protein.from_pdb_string(output_pdb) 49 | np.testing.assert_almost_equal( 50 | np.where(as_protein.atom_mask > 0, as_protein.b_factors, 0), 51 | np.where(as_protein.atom_mask > 0, bfactors, 0)) 52 | 53 | 54 | if __name__ == '__main__': 55 | absltest.main() 56 | -------------------------------------------------------------------------------- /src/animate_py3dmol.py: -------------------------------------------------------------------------------- 1 | import py3Dmol 2 | 3 | import sys 4 | import os 5 | import pdb 6 | 7 | import pandas as pd 8 | import numpy as np 9 | import argparse 10 | 11 | 12 | parser = argparse.ArgumentParser(description = '''Visualise sampled conformations''') 13 | parser.add_argument('--pdbdir', nargs=1, type= str, required=True, help = "Path to directory with PDB files") 14 | parser.add_argument('--order_df', nargs=1, type= str, required=True, help = "Path to csv with order of PDB files") 15 | parser.add_argument('--outdir', nargs=1, type= str, help = 'Outdir.') 16 | 17 | 18 | #################FUNCTIONS################# 19 | 20 | class Atom(dict): 21 | def __init__(self, line): 22 | self["type"] = line[0:6].strip() 23 | self["idx"] = line[6:11].strip() 24 | self["name"] = line[12:16].strip() 25 | self["resname"] = line[17:20].strip() 26 | self["resid"] = int(int(line[22:26])) 27 | self["x"] = float(line[30:38]) 28 | self["y"] = float(line[38:46]) 29 | self["z"] = float(line[46:54]) 30 | self["sym"] = line[76:78].strip() 31 | 32 | def __str__(self): 33 | line = list(" " * 80) 34 | 35 | line[0:6] = self["type"].ljust(6) 36 | line[6:11] = self["idx"].ljust(5) 37 | line[12:16] = self["name"].ljust(4) 38 | line[17:20] = self["resname"].ljust(3) 39 | line[22:26] = str(self["resid"]).ljust(4) 40 | line[30:38] = str(self["x"]).rjust(8) 41 | line[38:46] = str(self["y"]).rjust(8) 42 | line[46:54] = str(self["z"]).rjust(8) 43 | line[76:78] = self["sym"].rjust(2) 44 | return "".join(line) + "\n" 45 | 46 | class Molecule(list): 47 | def __init__(self, file): 48 | for line in file: 49 | if "ATOM" in line or "HETATM" in line: 50 | self.append(Atom(line)) 51 | 52 | def __str__(self): 53 | outstr = "" 54 | for at in self: 55 | outstr += str(at) 56 | 57 | return outstr 58 | 59 | 60 | 61 | 62 | 63 | #################MAIN#################### 64 | 65 | #Parse args 66 | args = parser.parse_args() 67 | #Get data 68 | pdbdir = args.pdbdir[0] 69 | order_df = pd.read_csv(args.order_df[0]).loc[:10] 70 | outdir = args.outdir[0] 71 | 72 | molecules = [] 73 | for ind,row in order_df.iterrows(): 74 | with open(pdbdir+row['name']) as ifile: 75 | molecules.append(Molecule(ifile)) 76 | 77 | 78 | view = py3Dmol.view(width=400, height=300) 79 | 80 | models = "" 81 | for i, mol in enumerate(molecules): 82 | models += "MODEL " + str(i) + "\n" 83 | models += str(mol) 84 | models += "ENDMDL\n" 85 | view.addModelsAsFrames(models) 86 | 87 | for i, at in enumerate(molecules[0]): 88 | default = {"cartoon": {'color': 'grey'}} 89 | view.setStyle({'model': -1, 'serial': i+1}, at.get("pymol", default)) 90 | 91 | view.zoomTo() 92 | view.animate({'loop': 'backAndForth'}) 93 | view.show() 94 | -------------------------------------------------------------------------------- /src/check_msa_colab.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | 4 | ##############FUNCTIONS############## 5 | 6 | def read_a3m(a3mfile): 7 | '''Read an a3m file 8 | ''' 9 | parsed_a3m = {} 10 | #Regex pattern 11 | pattern = '[A-Z]+' 12 | nhits = 0 13 | with 
open(a3mfile, 'r') as file: 14 | for line in file: 15 | line = line.rstrip() 16 | if line[0]=='>': 17 | nhits+=1 18 | else: 19 | if line[0]=='#': 20 | continue 21 | else: 22 | parsed_a3m[nhits] = ''.join(re.findall('[A-Z,-]+', line)) 23 | 24 | return parsed_a3m 25 | 26 | 27 | def process_a3m(a3mfile, sequence, outname): 28 | '''Process a3m file - remove insertions and get only matches and gaps 29 | Write these to a new file 30 | ''' 31 | 32 | parsed_a3m = read_a3m(a3mfile) 33 | if len([*parsed_a3m.values()][0])!=len(sequence): 34 | print('The sequence length does not match with the MSA.') 35 | sys.exit() 36 | with open(outname, 'w') as file: 37 | for key in parsed_a3m: 38 | file.write('>'+str(key)+'\n') 39 | file.write(parsed_a3m[key]+'\n') 40 | -------------------------------------------------------------------------------- /src/make_msa_seq_feats.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | from typing import Mapping, Optional, Sequence 18 | from alphafold.common import residue_constants 19 | from alphafold.data import parsers 20 | import numpy as np 21 | import argparse 22 | import pickle 23 | import sys 24 | import pdb 25 | # Internal import (7716). 26 | 27 | parser = argparse.ArgumentParser(description = """Builds the input features for training the structure prediction model.""") 28 | 29 | parser.add_argument('--input_fasta_path', nargs=1, type= str, default=sys.stdin, help = 'Path to fasta.') 30 | parser.add_argument('--input_msas', nargs=1, type= str, default=sys.stdin, help = 'Path to MSAs. Separated by comma.') 31 | parser.add_argument('--outdir', nargs=1, type= str, default=sys.stdin, help = 'Path to output directory. 
Include / at the end.')
32 | 
33 | FeatureDict = Mapping[str, np.ndarray]
34 | 
35 | ##############FUNCTIONS##############
36 | def make_sequence_features(
37 |     sequence: str, description: str, num_res: int) -> FeatureDict:
38 |   """Constructs a feature dict of sequence features."""
39 |   features = {}
40 |   features['aatype'] = residue_constants.sequence_to_onehot(
41 |       sequence=sequence,
42 |       mapping=residue_constants.restype_order_with_x,
43 |       map_unknown_to_x=True)
44 |   features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
45 |   features['domain_name'] = np.array([description.encode('utf-8')],
46 |                                      dtype=np.object_)
47 |   features['residue_index'] = np.array(range(num_res), dtype=np.int32)
48 |   features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
49 |   features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_)
50 |   return features
51 | 
52 | 
53 | def make_msa_features(
54 |     msas: Sequence[Sequence[str]],
55 |     deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
56 |   """Constructs a feature dict of MSA features."""
57 |   if not msas:
58 |     raise ValueError('At least one MSA must be provided.')
59 | 
60 |   int_msa = []
61 |   deletion_matrix = []
62 |   seen_sequences = set()
63 |   for msa_index, msa in enumerate(msas):
64 |     if not msa:
65 |       raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
66 |     for sequence_index, sequence in enumerate(msa):
67 |       if sequence in seen_sequences:
68 |         continue
69 |       seen_sequences.add(sequence)
70 |       int_msa.append(
71 |           [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
72 |       deletion_matrix.append(deletion_matrices[msa_index][sequence_index])
73 | 
74 |   num_res = len(msas[0][0])
75 |   num_alignments = len(int_msa)
76 |   features = {}
77 |   features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
78 |   features['msa'] = np.array(int_msa, dtype=np.int32)
79 |   features['num_alignments'] = np.array(
80 |       [num_alignments] * num_res, dtype=np.int32)
81 |   return features
82 | 
83 | 
84 | 
85 | def process(input_fasta_path: str, input_msas: list) -> FeatureDict:
86 |   """Reads the input sequence and precomputed MSAs and creates features."""
87 |   with open(input_fasta_path) as f:
88 |     input_fasta_str = f.read()
89 |   input_seqs, input_desc = parsers.parse_fasta(input_fasta_str)
90 |   if len(input_seqs) != 1:
91 |     raise ValueError(
92 |         f'More than one input sequence found in {input_fasta_path}.')
93 |   input_sequence = input_seqs[0]
94 |   input_description = input_desc[0]
95 |   num_res = len(input_sequence)
96 | 
97 |   parsed_msas = []
98 |   parsed_delmat = []
99 |   for custom_msa in input_msas:
100 |     msa = ''.join([line for line in open(custom_msa)])
101 |     if custom_msa[-3:] == 'sto':
102 |       parsed_msa, parsed_deletion_matrix, _ = parsers.parse_stockholm(msa)
103 |     elif custom_msa[-3:] == 'a3m':
104 |       parsed_msa, parsed_deletion_matrix = parsers.parse_a3m(msa)
105 |     else: raise TypeError('Unknown format for input MSA, please make sure '
106 |                           'the MSA files you provide terminate with (and '
107 |                           'are formatted as) .sto or .a3m')
108 |     parsed_msas.append(parsed_msa)
109 |     parsed_delmat.append(parsed_deletion_matrix)
110 | 
111 |   sequence_features = make_sequence_features(
112 |       sequence=input_sequence,
113 |       description=input_description,
114 |       num_res=num_res)
115 | 
116 |   msa_features = make_msa_features(
117 |       msas=parsed_msas, deletion_matrices=parsed_delmat)
118 | 
119 |   return {**sequence_features, **msa_features}
120 | 
121 | 
122 | ##################MAIN#######################
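# Hypothetical invocation (flags from the argparse definitions above, paths
# from the repo's test data; adjust to your setup):
#   python3 ./src/make_msa_seq_feats.py --input_fasta_path data/test/4AVA.fasta \
#     --input_msas data/test/4AVA.a3m --outdir ./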
123 | 124 | #Parse args 125 | args = parser.parse_args() 126 | #Data 127 | input_fasta_path = args.input_fasta_path[0] 128 | input_msas = args.input_msas[0].split(',') 129 | outdir = args.outdir[0] 130 | #Get feats 131 | feature_dict = process(input_fasta_path, input_msas) 132 | 133 | #Write out features as a pickled dictionary. 134 | features_output_path = os.path.join(outdir, 'msa_features.pkl') 135 | with open(features_output_path, 'wb') as f: 136 | pickle.dump(feature_dict, f, protocol=4) 137 | print('Saved features to',features_output_path) 138 | -------------------------------------------------------------------------------- /src/make_msa_seq_feats_colab.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | from typing import Mapping, Optional, Sequence 18 | from alphafold.common import residue_constants 19 | from alphafold.data import parsers 20 | import numpy as np 21 | import argparse 22 | import pickle 23 | import sys 24 | import pdb 25 | # Internal import (7716). 26 | 27 | 28 | FeatureDict = Mapping[str, np.ndarray] 29 | 30 | ##############FUNCTIONS############## 31 | def make_sequence_features( 32 | sequence: str, description: str, num_res: int) -> FeatureDict: 33 | """Constructs a feature dict of sequence features.""" 34 | features = {} 35 | features['aatype'] = residue_constants.sequence_to_onehot( 36 | sequence=sequence, 37 | mapping=residue_constants.restype_order_with_x, 38 | map_unknown_to_x=True) 39 | features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32) 40 | features['domain_name'] = np.array([description.encode('utf-8')], 41 | dtype=np.object_) 42 | features['residue_index'] = np.array(range(num_res), dtype=np.int32) 43 | features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32) 44 | features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_) 45 | return features 46 | 47 | 48 | def make_msa_features( 49 | msas: Sequence[Sequence[str]], 50 | deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict: 51 | """Constructs a feature dict of MSA features.""" 52 | if not msas: 53 | raise ValueError('At least one MSA must be provided.') 54 | 55 | int_msa = [] 56 | deletion_matrix = [] 57 | seen_sequences = set() 58 | for msa_index, msa in enumerate(msas): 59 | if not msa: 60 | raise ValueError(f'MSA {msa_index} must contain at least one sequence.') 61 | for sequence_index, sequence in enumerate(msa): 62 | if sequence in seen_sequences: 63 | continue 64 | seen_sequences.add(sequence) 65 | int_msa.append( 66 | [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) 67 | deletion_matrix.append(deletion_matrices[msa_index][sequence_index]) 68 | 69 | num_res = len(msas[0][0]) 70 | num_alignments = len(int_msa) 71 | features = {} 72 | features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) 73 | 
features['msa'] = np.array(int_msa, dtype=np.int32) 74 | features['num_alignments'] = np.array( 75 | [num_alignments] * num_res, dtype=np.int32) 76 | return features 77 | 78 | 79 | 80 | def process(input_fasta_path: str, input_msas: list) -> FeatureDict: 81 | """Runs alignment tools on the input sequence and creates features.""" 82 | with open(input_fasta_path) as f: 83 | input_fasta_str = f.read() 84 | input_seqs, input_desc = parsers.parse_fasta(input_fasta_str) 85 | if len(input_seqs) != 1: 86 | raise ValueError( 87 | f'More than one input sequence found in {input_fasta_path}.') 88 | input_sequence = input_seqs[0] 89 | input_description = input_desc[0] 90 | num_res = len(input_sequence) 91 | 92 | parsed_msas = [] 93 | parsed_delmat = [] 94 | for custom_msa in input_msas: 95 | msa = ''.join([line for line in open(custom_msa)]) 96 | if custom_msa[-3:] == 'sto': 97 | parsed_msa, parsed_deletion_matrix, _ = parsers.parse_stockholm(msa) 98 | elif custom_msa[-3:] == 'a3m': 99 | parsed_msa, parsed_deletion_matrix = parsers.parse_a3m(msa) 100 | else: raise TypeError('Unknown format for input MSA, please make sure ' 101 | 'the MSA files you provide terminates with (and ' 102 | 'are formatted as) .sto or .a3m') 103 | parsed_msas.append(parsed_msa) 104 | parsed_delmat.append(parsed_deletion_matrix) 105 | 106 | sequence_features = make_sequence_features( 107 | sequence=input_sequence, 108 | description=input_description, 109 | num_res=num_res) 110 | 111 | msa_features = make_msa_features( 112 | msas=parsed_msas, deletion_matrices=parsed_delmat) 113 | 114 | return {**sequence_features, **msa_features} 115 | 116 | 117 | ##################MAIN####################### 118 | 119 | 120 | #Get feats 121 | #feature_dict = process(input_fasta_path, input_msas) 122 | 123 | #Write out features as a pickled dictionary. 124 | #features_output_path = os.path.join(outdir, 'msa_features.pkl') 125 | #with open(features_output_path, 'wb') as f: 126 | # pickle.dump(feature_dict, f, protocol=4) 127 | #print('Saved features to',features_output_path) 128 | -------------------------------------------------------------------------------- /src/predict_with_clusters.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import warnings 4 | import pathlib 5 | import pickle 6 | import random 7 | import sys 8 | import time 9 | from typing import Dict, Optional 10 | from typing import NamedTuple 11 | import haiku as hk 12 | import jax 13 | import jax.numpy as jnp 14 | 15 | #Silence tf 16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 17 | import tensorflow.compat.v1 as tf 18 | tf.config.set_visible_devices([], 'GPU') 19 | 20 | import argparse 21 | import pandas as pd 22 | import numpy as np 23 | from collections import Counter 24 | from scipy.special import softmax 25 | import pdb 26 | 27 | #AlphaFold imports 28 | from alphafold.common import protein 29 | from alphafold.common import residue_constants 30 | from alphafold.data import templates 31 | from alphafold.model import data 32 | from alphafold.model import config 33 | from alphafold.model import features 34 | from alphafold.model import modules 35 | 36 | #JAX will preallocate 90% of currently-available GPU memory when the first JAX operation is run. 
37 | #Uncommenting the line below prevents this.
38 | #os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
39 | 
40 | 
41 | parser = argparse.ArgumentParser(description = """Predict using trained weights and a varying number of sampled MSA clusters for more varied predictions.""")
42 | 
43 | parser.add_argument('--feature_dir', nargs=1, type= str, default=sys.stdin, help = 'Path to location of features.')
44 | parser.add_argument('--predict_id', nargs=1, type= str, default=sys.stdin, help = 'Id to predict.')
45 | parser.add_argument('--ckpt_params', nargs=1, type= str, default=sys.stdin, help = 'Params to start from.')
46 | parser.add_argument('--num_recycles', nargs=1, type= int, default=sys.stdin, help = 'Number of recycles to use.')
47 | parser.add_argument('--num_samples_per_cluster', nargs=1, type= int, default=sys.stdin, help = 'Number of samples to use per cluster selection.')
48 | parser.add_argument('--outdir', nargs=1, type= str, default=sys.stdin, help = 'Path to output directory. Include / at the end.')
49 | 
50 | ##############FUNCTIONS##############
51 | ##########INPUT DATA#########
52 | 
53 | def process_features(raw_features, config, random_seed):
54 |   """Processes features to prepare for feeding them into the model.
55 | 
56 |   Args:
57 |     raw_features: The output of the data pipeline either as a dict of NumPy
58 |       arrays or as a tf.train.Example.
59 |     random_seed: The random seed to use when processing the features.
60 | 
61 |   Returns:
62 |     A dict of NumPy feature arrays suitable for feeding into the model.
63 |   """
64 |   return features.np_example_to_features(np_example=raw_features,
65 |                                          config=config,
66 |                                          random_seed=random_seed)
67 | 
68 | 
69 | def load_input_feats(id, feature_dir, config, num_clusters):
70 |   """
71 |   Load all input feats.
72 |   """
73 | 
74 |   #Load raw features
75 |   msa_feature_dict = np.load(feature_dir+'/msa_features.pkl', allow_pickle=True)
76 |   #Process the features on CPU (sample MSA)
77 |   #Set the config to determine the number of clusters
78 |   config.data.eval.max_msa_clusters = num_clusters
79 |   #processed_feature_dict['msa_feat'].shape = num_clusts, L, 49
80 |   processed_feature_dict = process_features(msa_feature_dict, config, np.random.choice(sys.maxsize))
81 | 
82 |   return processed_feature_dict
83 | 
84 | 
85 | 
86 | ##########MODEL#########
87 | 
88 | def predict(config,
89 |             feature_dir,
90 |             predict_ids,
91 |             num_recycles,
92 |             num_samples_per_cluster,
93 |             ckpt_params=None,
94 |             outdir=None):
95 |   """Predict a structure
96 |   """
97 | 
98 |   #The config does not have to be recreated here - the number of MSA
99 |   #clusters can be changed on the existing config object.
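# A minimal sketch of the Haiku pattern used below (illustrative only, not
# part of the pipeline): hk.transform turns a function that builds modules
# into a pair of pure functions, init (creates parameters) and apply (runs
# with given parameters).
#
#   forward = hk.transform(lambda x: hk.Linear(1)(x))
#   params = forward.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
#   out = forward.apply(params, jax.random.PRNGKey(0), jnp.ones((1, 8)))
#
# Here the pre-trained Cfold parameters (ckpt_params) are passed straight to
# apply, so init is never called.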
100 | #Define the forward function 101 | def _forward_fn(batch): 102 | '''Define the forward function - has to be a function for JAX 103 | ''' 104 | model = modules.AlphaFold(config.model) 105 | 106 | return model(batch, 107 | is_training=False, 108 | compute_loss=False, 109 | ensemble_representations=False, 110 | return_representations=True) 111 | 112 | #The forward function is here transformed to apply and init functions which 113 | #can be called during training and initialisation (JAX needs functions) 114 | forward = hk.transform(_forward_fn) 115 | apply_fwd = forward.apply 116 | #Get a random key 117 | rng = jax.random.PRNGKey(42) 118 | 119 | for id in predict_ids: 120 | for num_clusts in [16, 32, 64, 128, 256, 512, 1024, 5120]: 121 | for i in range(num_samples_per_cluster): 122 | if os.path.exists(outdir+'/'+id+'_'+str(num_clusts)+'_'+str(i)+'_pred.pdb'): 123 | print('Prediction',num_clusts, i+1, 'exists...') 124 | continue 125 | 126 | #Load input feats 127 | batch = load_input_feats(id, feature_dir, config, num_clusts) 128 | for key in batch: 129 | batch[key] = np.reshape(batch[key], (1, *batch[key].shape)) 130 | 131 | batch['num_iter_recycling'] = [num_recycles] 132 | ret = apply_fwd(ckpt_params, rng, batch) 133 | #Save structure 134 | save_feats = {'aatype':batch['aatype'], 'residue_index':batch['residue_index']} 135 | result = {'predicted_lddt':ret['predicted_lddt'], 136 | 'structure_module':{'final_atom_positions':ret['structure_module']['final_atom_positions'], 137 | 'final_atom_mask': ret['structure_module']['final_atom_mask'] 138 | }} 139 | save_structure(save_feats, result, id+'_'+str(num_clusts)+'_'+str(i), outdir) 140 | 141 | 142 | 143 | def save_structure(save_feats, result, id, outdir): 144 | """Save prediction 145 | 146 | save_feats = {'aatype':batch['aatype'][0][0], 'residue_index':batch['residue_index'][0][0]} 147 | result = {'predicted_lddt':aux['predicted_lddt'], 148 | 'structure_module':{'final_atom_positions':aux['structure_module']['final_atom_positions'][0], 149 | 'final_atom_mask': aux['structure_module']['final_atom_mask'][0] 150 | }} 151 | save_structure(save_feats, result, step_num, outdir) 152 | 153 | """ 154 | #Define the plDDT bins 155 | bin_width = 1.0 / 50 156 | bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width) 157 | 158 | # Add the predicted LDDT in the b-factor column. 
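# The model emits logits over 50 pLDDT bins per residue; the scalar
# per-residue pLDDT is the expectation over the bin centers:
#   plddt_i = sum_b softmax(logits_i)_b * bin_center_b
# which, with the bin definitions above, lies in [0, 1].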
159 | plddt_per_pos = jnp.sum(jax.nn.softmax(result['predicted_lddt']['logits']) * bin_centers[None, :], axis=-1) 160 | plddt_b_factors = np.repeat(plddt_per_pos[:, None], residue_constants.atom_type_num, axis=-1) 161 | unrelaxed_protein = protein.from_prediction(features=save_feats, result=result, b_factors=plddt_b_factors) 162 | unrelaxed_pdb = protein.to_pdb(unrelaxed_protein) 163 | unrelaxed_pdb_path = os.path.join(outdir+'/', id+'_pred.pdb') 164 | with open(unrelaxed_pdb_path, 'w') as f: 165 | f.write(unrelaxed_pdb) 166 | 167 | 168 | 169 | ##################MAIN####################### 170 | 171 | #Parse args 172 | args = parser.parse_args() 173 | feature_dir = args.feature_dir[0] 174 | predict_id = args.predict_id[0] 175 | ckpt_params = np.load(args.ckpt_params[0] , allow_pickle=True) 176 | num_recycles = args.num_recycles[0] 177 | num_samples_per_cluster = args.num_samples_per_cluster[0] 178 | outdir = args.outdir[0] 179 | #Predict 180 | predict(config.CONFIG, 181 | feature_dir, 182 | [predict_id], 183 | num_recycles, 184 | num_samples_per_cluster, 185 | ckpt_params=ckpt_params, 186 | outdir=outdir) 187 | -------------------------------------------------------------------------------- /src/predict_with_clusters_colab.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import warnings 4 | import pathlib 5 | import pickle 6 | import random 7 | import sys 8 | import time 9 | from typing import Dict, Optional 10 | from typing import NamedTuple 11 | import haiku as hk 12 | import jax 13 | import jax.numpy as jnp 14 | 15 | #Silence tf 16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 17 | import tensorflow.compat.v1 as tf 18 | tf.config.set_visible_devices([], 'GPU') 19 | 20 | import argparse 21 | import pandas as pd 22 | import numpy as np 23 | from collections import Counter 24 | from scipy.special import softmax 25 | import pdb 26 | 27 | #AlphaFold imports 28 | from alphafold.common import protein 29 | from alphafold.common import residue_constants 30 | from alphafold.data import templates 31 | from alphafold.model import data 32 | from alphafold.model import config 33 | from alphafold.model import features 34 | from alphafold.model import modules 35 | 36 | #JAX will preallocate 90% of currently-available GPU memory when the first JAX operation is run. 37 | #This prevents this 38 | #os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' 39 | 40 | 41 | ##############FUNCTIONS############## 42 | ##########INPUT DATA######### 43 | 44 | def process_features(raw_features, config, random_seed): 45 | """Processes features to prepare for feeding them into the model. 46 | 47 | Args: 48 | raw_features: The output of the data pipeline either as a dict of NumPy 49 | arrays or as a tf.train.Example. 50 | random_seed: The random seed to use when processing the features. 51 | 52 | Returns: 53 | A dict of NumPy feature arrays suitable for feeding into the model. 54 | """ 55 | return features.np_example_to_features(np_example=raw_features, 56 | config=config, 57 | random_seed=random_seed) 58 | 59 | 60 | def load_input_feats(id, feature_dir, config, num_clusters): 61 | """ 62 | Load all input feats. 
63 | """ 64 | 65 | #Load raw features 66 | msa_feature_dict = np.load(feature_dir+id+'/msa_features.pkl', allow_pickle=True) 67 | #Process the features on CPU (sample MSA) 68 | #Set the config to determine the number of clusters 69 | config.data.eval.max_msa_clusters = num_clusters 70 | #processed_feature_dict['msa_feat'].shape = num_clusts, L, 49 71 | processed_feature_dict = process_features(msa_feature_dict, config, np.random.choice(sys.maxsize)) 72 | 73 | return processed_feature_dict 74 | 75 | 76 | 77 | ##########MODEL######### 78 | 79 | def predict(feature_dir, 80 | predict_id, 81 | num_recycles, 82 | num_samples_per_cluster, 83 | clusters, 84 | ckpt_params=None, 85 | outdir=None): 86 | """Predict a structure 87 | """ 88 | 89 | 90 | #Does the config have to be updated here? 91 | #No - the clusters can be changed. 92 | #Define the forward function 93 | def _forward_fn(batch): 94 | '''Define the forward function - has to be a function for JAX 95 | ''' 96 | model = modules.AlphaFold(config.CONFIG.model) 97 | 98 | return model(batch, 99 | is_training=False, 100 | compute_loss=False, 101 | ensemble_representations=False, 102 | return_representations=True) 103 | 104 | #The forward function is here transformed to apply and init functions which 105 | #can be called during training and initialisation (JAX needs functions) 106 | forward = hk.transform(_forward_fn) 107 | apply_fwd = forward.apply 108 | #Get a random key 109 | rng = jax.random.PRNGKey(42) 110 | 111 | for num_clusts in clusters: 112 | for i in range(num_samples_per_cluster): 113 | print('Predicting: number of clusters',num_clusts, 'Sample', i+1, '...') 114 | if os.path.exists(outdir+'/'+predict_id+'_'+str(num_clusts)+'_'+str(i)+'_pred.pdb'): 115 | print('Prediction',num_clusts, i+1, 'exists...') 116 | continue 117 | 118 | #Load input feats 119 | batch = load_input_feats(predict_id, feature_dir, config.CONFIG, num_clusts) 120 | for key in batch: 121 | batch[key] = np.reshape(batch[key], (1, *batch[key].shape)) 122 | 123 | batch['num_iter_recycling'] = [num_recycles] 124 | ret = apply_fwd(ckpt_params, rng, batch) 125 | #Save structure 126 | save_feats = {'aatype':batch['aatype'], 'residue_index':batch['residue_index']} 127 | result = {'predicted_lddt':ret['predicted_lddt'], 128 | 'structure_module':{'final_atom_positions':ret['structure_module']['final_atom_positions'], 129 | 'final_atom_mask': ret['structure_module']['final_atom_mask'] 130 | }} 131 | save_structure(save_feats, result, predict_id+'_'+str(num_clusts)+'_'+str(i), outdir) 132 | 133 | 134 | 135 | def save_structure(save_feats, result, id, outdir): 136 | """Save prediction 137 | 138 | save_feats = {'aatype':batch['aatype'][0][0], 'residue_index':batch['residue_index'][0][0]} 139 | result = {'predicted_lddt':aux['predicted_lddt'], 140 | 'structure_module':{'final_atom_positions':aux['structure_module']['final_atom_positions'][0], 141 | 'final_atom_mask': aux['structure_module']['final_atom_mask'][0] 142 | }} 143 | save_structure(save_feats, result, step_num, outdir) 144 | 145 | """ 146 | #Define the plDDT bins 147 | bin_width = 1.0 / 50 148 | bin_centers = np.arange(start=0.5 * bin_width, stop=1.0, step=bin_width) 149 | 150 | # Add the predicted LDDT in the b-factor column. 
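# The per-residue pLDDT is broadcast over all 37 atom slots below, so every
# atom of a residue carries that residue's confidence in the PDB B-factor
# column of the written structure.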
151 | plddt_per_pos = jnp.sum(jax.nn.softmax(result['predicted_lddt']['logits']) * bin_centers[None, :], axis=-1) 152 | plddt_b_factors = np.repeat(plddt_per_pos[:, None], residue_constants.atom_type_num, axis=-1) 153 | unrelaxed_protein = protein.from_prediction(features=save_feats, result=result, b_factors=plddt_b_factors) 154 | unrelaxed_pdb = protein.to_pdb(unrelaxed_protein) 155 | unrelaxed_pdb_path = os.path.join(outdir+'/', id+'_pred.pdb') 156 | with open(unrelaxed_pdb_path, 'w') as f: 157 | f.write(unrelaxed_pdb) 158 | 159 | 160 | 161 | ##################MAIN####################### 162 | 163 | # #Predict 164 | # predict(feature_dir, 165 | # [predict_id], 166 | # num_recycles, 167 | # num_samples_per_cluster, 168 | # ckpt_params=ckpt_params, 169 | # outdir=outdir) 170 | -------------------------------------------------------------------------------- /src/score_closest_ca.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pdb 4 | 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | import numpy as np 8 | from collections import defaultdict 9 | from Bio.SVDSuperimposer import SVDSuperimposer 10 | import shutil 11 | import glob 12 | import argparse 13 | 14 | 15 | parser = argparse.ArgumentParser(description = '''Structurally align all files in a directory to a reference and write new roto-translated pdb files''') 16 | parser.add_argument('--pdbdir', nargs=1, type= str, required=True, help = "Path to directory with PDB files") 17 | parser.add_argument('--outdir', nargs=1, type= str, help = 'Outdir.') 18 | 19 | 20 | #################FUNCTIONS################# 21 | 22 | 23 | def parse_atm_record(line): 24 | '''Get the atm record 25 | ''' 26 | record = defaultdict() 27 | record['name'] = line[0:6].strip() 28 | record['atm_no'] = int(line[6:11]) 29 | record['atm_name'] = line[12:16].strip() 30 | record['atm_alt'] = line[17] 31 | record['res_name'] = line[17:20].strip() 32 | record['chain'] = line[21] 33 | record['res_no'] = int(line[22:26]) 34 | record['insert'] = line[26].strip() 35 | record['resid'] = line[22:29] 36 | record['x'] = float(line[30:38]) 37 | record['y'] = float(line[38:46]) 38 | record['z'] = float(line[46:54]) 39 | record['occ'] = float(line[54:60]) 40 | record['B'] = float(line[60:66]) 41 | 42 | return record 43 | 44 | def read_pdb(pdbfile): 45 | '''Read a pdb file per chain 46 | ''' 47 | pdb_file_info = [] 48 | all_coords = [] 49 | CA_coords = [] 50 | 51 | with open(pdbfile) as file: 52 | for line in file: 53 | if line.startswith('ATOM'): 54 | #Parse line 55 | record = parse_atm_record(line) 56 | #Save line 57 | pdb_file_info.append(line.rstrip()) 58 | #Save coords 59 | all_coords.append([record['x'],record['y'],record['z']]) 60 | if record['atm_name']=='CA': 61 | CA_coords.append([record['x'],record['y'],record['z']]) 62 | 63 | 64 | return pdb_file_info, np.array(all_coords), np.array(CA_coords) 65 | 66 | def score_ca_diff(all_ca_coords): 67 | """Score all CA coords against each other 68 | """ 69 | 70 | num_samples = len(all_ca_coords) 71 | score_mat = np.zeros((num_samples, num_samples)) 72 | for i in range(num_samples): 73 | dmat_i = np.sqrt(np.sum((all_ca_coords[i][:,None]-all_ca_coords[i][None,:])**2,axis=-1)) 74 | score_mat[i,i]=1000 75 | for j in range(i+1, num_samples): 76 | dmat_j = np.sqrt(np.sum((all_ca_coords[j][:,None]-all_ca_coords[j][None,:])**2,axis=-1)) 77 | #Diff 78 | diff = np.mean(np.sqrt((dmat_i-dmat_j)**2)) 79 | score_mat[i,j] = diff 80 | score_mat[j,i] = diff 81 | 
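# score_mat[i, j] is the mean absolute difference between the CA-CA distance
# matrices of conformations i and j - a superposition-free measure of how much
# two sampled structures differ. The diagonal is set to a large value (1000)
# so a conformation is never paired with itself when picking minima.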
82 | #Establish the order from the score mat 83 | order = [] 84 | min_ind = np.unravel_index(score_mat.argmin(), score_mat.shape) 85 | all_inds = np.arange(num_samples) 86 | order.append(min_ind[0]) 87 | ci = min_ind[1] 88 | while len(order)