├── .gitignore
├── 6nnw.gif
├── 6nnw.svg
├── AFP.svg
├── LICENSE
├── README.md
├── T1123.svg
├── afprofile.yml
├── data
├── H1141
│   ├── H1141.fasta
│   └── features.pkl
├── H1144
│   ├── H1144.fasta
│   └── features.pkl
└── T1123
│   ├── T1123.fasta
│   └── features.pkl
├── pip_pkgs.txt
└── src
├── .DS_Store
├── AFP.sh
├── alphafold
├── __init__.py
├── common
│   ├── __init__.py
│   ├── confidence.py
│   ├── protein.py
│   ├── protein_test.py
│   ├── residue_constants.py
│   ├── residue_constants_test.py
│   └── testdata
│   │   └── 2rbg.pdb
├── data
│   ├── __init__.py
│   ├── feature_processing.py
│   ├── mmcif_parsing.py
│   ├── msa_identifiers.py
│   ├── msa_pairing.py
│   ├── parsers.py
│   ├── pipeline.py
│   ├── pipeline_multimer.py
│   ├── templates.py
│   └── tools
│   │   ├── __init__.py
│   │   ├── hhblits.py
│   │   ├── hhsearch.py
│   │   ├── hmmbuild.py
│   │   ├── hmmsearch.py
│   │   ├── jackhmmer.py
│   │   ├── kalign.py
│   │   └── utils.py
├── model
│   ├── __init__.py
│   ├── all_atom.py
│   ├── all_atom_multimer.py
│   ├── all_atom_test.py
│   ├── common_modules.py
│   ├── config.py
│   ├── data.py
│   ├── features.py
│   ├── folding.py
│   ├── folding_multimer.py
│   ├── geometry
│   │   ├── __init__.py
│   │   ├── rigid_matrix_vector.py
│   │   ├── rotation_matrix.py
│   │   ├── struct_of_array.py
│   │   ├── test_utils.py
│   │   ├── utils.py
│   │   └── vector.py
│   ├── layer_stack.py
│   ├── layer_stack_test.py
│   ├── lddt.py
│   ├── lddt_test.py
│   ├── mapping.py
│   ├── model.py
│   ├── modules.py
│   ├── modules_multimer.py
│   ├── prng.py
│   ├── prng_test.py
│   ├── quat_affine.py
│   ├── quat_affine_test.py
│   ├── r3.py
│   ├── tf
│   │   ├── __init__.py
│   │   ├── data_transforms.py
│   │   ├── input_pipeline.py
│   │   ├── protein_features.py
│   │   ├── protein_features_test.py
│   │   ├── proteins_dataset.py
│   │   ├── shape_helpers.py
│   │   ├── shape_helpers_test.py
│   │   ├── shape_placeholders.py
│   │   └── utils.py
│   └── utils.py
├── notebooks
│   ├── __init__.py
│   ├── notebook_utils.py
│   └── notebook_utils_test.py
└── relax
│   ├── __init__.py
│   ├── amber_minimize.py
│   ├── amber_minimize_test.py
│   ├── cleanup.py
│   ├── cleanup_test.py
│   ├── relax.py
│   ├── relax_test.py
│   ├── testdata
│   ├── model_output.pdb
│   ├── multiple_disulfides_target.pdb
│   ├── with_violations.pdb
│   └── with_violations_casp14.pdb
│   ├── utils.py
│   └── utils_test.py
├── generate_msas.sh
├── obsolete.txt
├── run_AFP.py
└── run_alphafold_msa_template_only.py

/.gitignore:
--------------------------------------------------------------------------------
1 | src/alphafold/model/tf/__pycache__
2 | src/alphafold/model/geometry/__pycache__
3 | src/alphafold/model/__pycache__
4 | src/alphafold/data/tools/__pycache__
5 | src/alphafold/data/__pycache__
6 | src/alphafold/common/__pycache__
7 | src/alphafold/__pycache__/
8 | data/params
9 | data/T1123/*.pdb
10 | data/T1123/metrics_0.0001_20.csv
11 | .DS_Store
12 | 
--------------------------------------------------------------------------------
/6nnw.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/AFProfile/a6aaaca40c8598b70c0f8f2768edae7e1847792e/6nnw.gif
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AFProfile
2 | Improved protein complex prediction with AlphaFold-multimer by denoising the MSA profile.
3 | \
4 | \
5 | AFProfile learns a bias to the MSA representation that **improves the predictions** by performing **gradient descent through the AlphaFold-multimer network**. \
6 | We effectively denoise the MSA profile, similar to how a blurry image would be sharpened to become clearer. \
7 | This proves to be a highly efficient process, resulting in a 60-fold speedup compared to AFsample while being as efficient as AFM v2.3. \
8 | Read more about it in the paper [here](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1012253).
9 | 
10 | \
11 | 
12 | \
13 | \
14 | AlphaFold2 (including AlphaFold-multimer) is available under the [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0) and so is AFProfile, which is a derivative thereof. \
15 | The AlphaFold2 parameters are made available under the terms of the [CC BY 4.0 license](https://creativecommons.org/licenses/by/4.0/legalcode) and have not been modified.
16 | \
17 | **You may not use these files except in compliance with the licenses.**
18 | 
19 | ## Optimisation for 6nnw
20 | - Here is an example trajectory for PDBID 6nnw sorted by confidence.
21 | 
22 | ![Optimisation trajectory for 6nnw](./6nnw.gif)
23 | 
24 | - The final prediction has an MMscore of 0.96 compared to 0.44 using AF-multimer. The [native structure](https://www.rcsb.org/structure/6NNW) is in grey.
25 | 
26 | ![Final prediction compared to the native structure](./6nnw.svg)
27 | 
28 | The confidence used to denoise the MSA is defined as: \
29 | Confidence = 0.8 iptm + 0.2 ptm \
30 | where iptm is the predicted [TM-score](https://zhanggroup.org/TM-score/) in the interface and ptm that of the entire complex.
31 | 
32 | # Setup
33 | 
34 | ## Clone this repository
35 | ```
36 | git clone https://github.com/patrickbryant1/AFProfile.git
37 | ```
38 | 
39 | ## Get the AlphaFold-multimer parameters
40 | ```
41 | cd AFProfile
42 | mkdir data/params
43 | wget https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar
44 | tar -xvf alphafold_params_2022-03-02.tar
45 | mv params_model_1_multimer_v2.npz data/params/
46 | rm *.npz
47 | rm alphafold_params_2022-03-02.tar
48 | ```
49 | 
50 | ## Install the AlphaFold requirements
51 | 
52 | Install all packages into a conda environment (requires [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
53 | ```
54 | conda env create -f afprofile.yml
55 | wait
56 | conda activate afprofile
57 | pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
58 | pip install --upgrade numpy
59 | ```
60 | If the conda environment doesn't work for you, see "pip_pkgs.txt".
61 | 
62 | ## Try the test case
63 | Now that you have installed the required packages, you can run a test case on the CASP15 target T1123o.
64 | \
65 | 
66 | ```
67 | cd src
68 | bash AFP.sh
69 | ```
70 | 
71 | ## Install the genetic search programs
72 | - We install the genetic search programs from source. This will make the searches faster.
73 | 
74 | *hh-suite*
75 | ```
76 | cd src
77 | mkdir hh-suite
78 | cd hh-suite
79 | wget https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz
80 | tar xvfz hhsuite-3.3.0-SSE2-Linux.tar.gz
81 | cd ..
82 | ```
83 | 
84 | *hmmer*
85 | ```
86 | cd src
87 | wget http://eddylab.org/software/hmmer/hmmer.tar.gz
88 | tar -xvzf hmmer.tar.gz
89 | rm hmmer.tar.gz
90 | cd hmmer-*
91 | ./configure
92 | make
93 | cd ..
94 | ```
95 | 
96 | *kalign*
97 | ```
98 | wget https://github.com/TimoLassmann/kalign/archive/refs/tags/v3.3.2.tar.gz
99 | tar -zxvf v3.3.2.tar.gz
100 | rm v3.3.2.tar.gz
101 | cd kalign-3.3.2/
102 | ./autogen.sh
103 | bash configure
104 | make
105 | make check
106 | make install
107 | cd ..
108 | ```
109 | 
110 | 
111 | ## Download all databases for AlphaFold
112 | - If you have already installed AlphaFold, you don't need to do this - you can simply
113 | provide the paths to the existing databases in the runscript.
114 | 
115 | *Small BFD: 17 GB*
116 | ```
117 | wget https://storage.googleapis.com/alphafold-databases/reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz
118 | gunzip bfd-first_non_consensus_sequences.fasta.gz
119 | mkdir data/small_bfd
120 | mv bfd-first_non_consensus_sequences.fasta data/small_bfd
121 | # gunzip already removed the .gz archive - nothing left to clean up
122 | ```
123 | 
124 | *UNIREF90: 67 GB*
125 | ```
126 | wget https://ftp.ebi.ac.uk/pub/databases/uniprot/uniref/uniref90/uniref90.fasta.gz
127 | gunzip uniref90.fasta.gz
128 | mkdir data/uniref90
129 | mv uniref90.fasta data/uniref90/
130 | # gunzip already removed the .gz archive
131 | ```
132 | 
133 | *UNIPROT: 105 GB*
134 | ```
135 | wget https://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz
136 | wget https://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz
137 | gunzip uniprot_trembl.fasta.gz
138 | gunzip uniprot_sprot.fasta.gz
139 | mkdir data/uniprot
140 | cat uniprot_sprot.fasta >> uniprot_trembl.fasta
141 | mv uniprot_trembl.fasta data/uniprot/uniprot.fasta
142 | # gunzip already removed the .gz archives
143 | rm uniprot_sprot.fasta
144 | ```
145 | 
146 | - The following template databases are not used for the predictions, but are needed for the feature processing.
147 | 
148 | *PDB SEQRES: 0.2 GB*
149 | ```
150 | wget https://files.rcsb.org/pub/pdb/derived_data/pdb_seqres.txt.gz
151 | gunzip pdb_seqres.txt.gz
152 | mkdir pdb_seqres
153 | mv pdb_seqres.txt pdb_seqres/
154 | ```
155 | 
156 | *MGNIFY: 120 GB*
157 | ```
158 | wget https://storage.googleapis.com/alphafold-databases/v2.3/mgy_clusters_2022_05.fa.gz
159 | gunzip mgy_clusters_2022_05.fa.gz
160 | mkdir mgnify
161 | mv mgy_clusters_2022_05.fa mgnify/
162 | # gunzip already removed the .gz archive
163 | ```
164 | 
165 | *MMCIF: 238 GB*
166 | - This may take a while...
167 | ```
168 | mkdir -p data/pdb_mmcif/raw
169 | mkdir data/pdb_mmcif/mmcif_files
170 | rsync --recursive --links --perms --times --compress --info=progress2 --delete --port=33444 rsync.rcsb.org::ftp_data/structures/divided/mmCIF/ data/pdb_mmcif/raw
171 | 
172 | find data/pdb_mmcif/raw -type f -iname "*.gz" -exec gunzip {} +
173 | find data/pdb_mmcif/raw -type d -empty -delete
174 | for subdir in data/pdb_mmcif/raw/*
175 | do
176 |   mv "${subdir}/"*.cif data/pdb_mmcif/mmcif_files/
177 | done
178 | find data/pdb_mmcif/raw -type d -empty -delete
179 | ```
180 | 
181 | # Citation
182 | Bryant P, Noé F. Improved protein complex prediction with AlphaFold-multimer by denoising the MSA profile. PLoS Comput Biol. 2024;20: e1012253.
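
The optimisation itself is small: the network parameters stay frozen, and only a bias (residuals) to the MSA representation is learned by gradient descent on the negative confidence. The sketch below is illustrative pseudocode only, not the actual implementation (that lives in src/run_AFP.py): `run_af_multimer`, `features` and `msa_shape` are placeholders, and optax is an assumed choice of Adam implementation (src/AFP.sh only specifies ADAM, the learning rate, the iteration cap and the confidence threshold).

```python
import jax
import jax.numpy as jnp
import optax  # Assumed optimiser library; AFP.sh only specifies ADAM + the LR.

def confidence(iptm, ptm):
  return 0.8 * iptm + 0.2 * ptm  # As defined above.

def loss_fn(msa_bias, features):
  # run_af_multimer is a placeholder for one forward pass through
  # AlphaFold-multimer with msa_bias added to the MSA representation.
  iptm, ptm = run_af_multimer(features, msa_bias)
  return -confidence(iptm, ptm)  # Descending on -confidence maximises it.

msa_bias = jnp.zeros(msa_shape)  # msa_shape: shape of the MSA representation.
optimizer = optax.adam(1e-4)     # LR=0.0001 in src/AFP.sh.
opt_state = optimizer.init(msa_bias)
for step in range(500):          # MAX_ITER=500 in src/AFP.sh.
  loss, grads = jax.value_and_grad(loss_fn)(msa_bias, features)
  updates, opt_state = optimizer.update(grads, opt_state)
  msa_bias = optax.apply_updates(msa_bias, updates)
  if -loss >= 0.95:              # CONFIDENCE_T=0.95 in src/AFP.sh.
    break
```

Because only the small bias is optimised while the parameters stay fixed, far fewer forward passes are needed than with large-scale sampling, which is where the speedup over AFsample comes from.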
183 | -------------------------------------------------------------------------------- /data/H1141/H1141.fasta: -------------------------------------------------------------------------------- 1 | >H1141,subunit1|_0 2 | GLEKDFLPLYFGWFLTKKSSETLRKAGQVFLEELGNHKAFKKELRHFISGDEPKEKLELVSYFGKRPPGVLHCTTKFCDYKAAGAEEYAQQEVVKRSYGKAFKLSISALFVTPKTAGAQVVLTDQELQLWPSDLDKPSASEGLPPGSRAHVTLGCAADVQPVQTGLDLLDILQQVKGGSQGEAVGELPRGKLYSLGKGRWMLSLTKKMEVKAIFTGYYG 3 | >H1141,subunit2|_0 4 | EVQLEESGGGWVHPGGSLRLSCAASGNVFGVNTMAWYRQAPGKQREQRELVASITDYGTTEYADSVKGRFTISGDNAKATVYLQMNSLKPEDTAVYYCNMDLTVMTATSSLYAYDYWGQGTQVTVSS 5 | -------------------------------------------------------------------------------- /data/H1141/features.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/AFProfile/a6aaaca40c8598b70c0f8f2768edae7e1847792e/data/H1141/features.pkl -------------------------------------------------------------------------------- /data/H1144/H1144.fasta: -------------------------------------------------------------------------------- 1 | >H1144,subunit1|_0 2 | GLEKDFLPLYFGWFLTKKSSETLRKAGQVFLEELGNHKAFKKELRHFISGDEPKEKLELVSYFGKRPPGVLHCTTKFCDYKAAGAEEYAQQEVVKRSYGKAFKLSISALFVTPKTAGAQVVLTDQELQLWPSDLDKPSASEGLPPGSRAHVTLGCAADVQPVQTGLDLLDILQQVKGGSQGEAVGELPRGKLYSLGKGRWMLSLTKKMEVKAIFTGYYG 3 | >H1144,subunit2|_0 4 | EVQLEESGGGLVQPGGSLRLSCAASGFTFSSYVMSWVRQAPGKGLEWVSDINSGGSRTYYTDSVKGRFTISRDNAKNTLYLQMNSLKPEDTAVYYCARDSLLSTRYLHTSERGQGTQVTVSS 5 | -------------------------------------------------------------------------------- /data/H1144/features.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/AFProfile/a6aaaca40c8598b70c0f8f2768edae7e1847792e/data/H1144/features.pkl -------------------------------------------------------------------------------- /data/T1123/T1123.fasta: -------------------------------------------------------------------------------- 1 | > T1123_0 2 | MHHHHHHHHHHSETTYTGPRSIVTPETPIGPSSYPMTPSSLVLMAGYFSGPEISDNFGKYMPLLFQQNTSKVTFRSGSHTIKIVSMVLVDRLMWLDKHFNQYTNEPDGVFGDVGNVFVDNDNVAKVITMSGSSAPANRGATLMLCRATKNIQTFNFAATVYIPAYKVKDGAGGKDVVLNVAQWEANKTLTYPAIPKDTYFMVVTMGGASFTIQRYVVYNEGIGDGLELPAFWGKYLSQLYGFSWSSPTYACVTWEPIYAEEGIPHR 3 | > T1123_1 4 | MHHHHHHHHHHSETTYTGPRSIVTPETPIGPSSYPMTPSSLVLMAGYFSGPEISDNFGKYMPLLFQQNTSKVTFRSGSHTIKIVSMVLVDRLMWLDKHFNQYTNEPDGVFGDVGNVFVDNDNVAKVITMSGSSAPANRGATLMLCRATKNIQTFNFAATVYIPAYKVKDGAGGKDVVLNVAQWEANKTLTYPAIPKDTYFMVVTMGGASFTIQRYVVYNEGIGDGLELPAFWGKYLSQLYGFSWSSPTYACVTWEPIYAEEGIPHR 5 | -------------------------------------------------------------------------------- /data/T1123/features.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/AFProfile/a6aaaca40c8598b70c0f8f2768edae7e1847792e/data/T1123/features.pkl -------------------------------------------------------------------------------- /pip_pkgs.txt: -------------------------------------------------------------------------------- 1 | pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 2 | pip install ml-collections==0.1.1 3 | pip install dm-haiku==0.0.11 4 | pip install pandas==1.3.5 5 | pip install biopython==1.81 6 | pip install chex==0.1.5 7 | pip install dm-tree==0.1.8 8 | pip install immutabledict==2.0.0 9 | pip install scipy 10 | pip install tensorflow 11 | pip install numpy==1.21.6 12 | 
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/AFProfile/a6aaaca40c8598b70c0f8f2768edae7e1847792e/src/.DS_Store
--------------------------------------------------------------------------------
/src/AFP.sh:
--------------------------------------------------------------------------------
1 | 
2 | # Script for predicting with AlphaFold-multimer using directed sampling (i.e. dropout at inference; the procedure implemented by Wallner in CASP15: https://www.biorxiv.org/content/10.1101/2022.12.20.521205v3)
3 | # combined with a denoising process over the profile. The MSA is denoised by gradient descent through AlphaFold-multimer.
4 | # Fill in all the variables below to run the predictions.
5 | # This script assumes that all necessary Python packages are available on the current path.
6 | # The fasta conventions are the same as for AlphaFold-multimer. See the example files in ../data/
7 | 
8 | # Get ID
9 | ID=T1123
10 | FASTA_PATHS=../data/T1123/T1123.fasta
11 | PARAMDIR=../data/ # If v2 is used: change the _v3 to _v2 in the multimer MODEL_PRESETS in config.py
12 | OUTDIR=../data/
13 | AFDIR=./
14 | 
15 | # 1. Get MSAs: run generate_msas.sh, which runs run_alphafold_msa_template_only.py - this also produces the features (saved as a pickle)
16 | # For this test case, the features have already been generated and are available here: ../data/T1123/features.pkl
17 | 
18 | # 2. Learn residuals to improve the confidence: run_AFP.py
19 | # Run AFM
20 | MODEL_PRESET='multimer'
21 | NUM_RECYCLES=20 # Number of recycles
22 | CONFIDENCE_T=0.95 # At what confidence to stop the search
23 | MAX_ITER=500 # Max number of iterations
24 | LR=0.0001 # Learning rate for the ADAM optimizer
25 | 
26 | # Run
27 | python3 $AFDIR/run_AFP.py --fasta_paths=$FASTA_PATHS \
28 | --data_dir=$PARAMDIR --model_preset=$MODEL_PRESET \
29 | --num_recycles=$NUM_RECYCLES \
30 | --confidence_threshold=$CONFIDENCE_T \
31 | --max_iter=$MAX_ITER \
32 | --learning_rate=$LR \
33 | --output_dir=$OUTDIR \
34 | --feature_dir=$OUTDIR/$ID/
--------------------------------------------------------------------------------
/src/alphafold/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """An implementation of the inference pipeline of AlphaFold v2.0."""
15 | 
--------------------------------------------------------------------------------
/src/alphafold/common/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Common data types and constants used within Alphafold.""" 15 | -------------------------------------------------------------------------------- /src/alphafold/common/confidence.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Functions for processing confidence metrics.""" 16 | 17 | from typing import Dict, Optional, Tuple 18 | import numpy as np 19 | import jax.numpy as jnp 20 | import jax 21 | import scipy.special 22 | 23 | 24 | def compute_plddt(logits: np.ndarray) -> np.ndarray: 25 | """Computes per-residue pLDDT from logits. 26 | 27 | Args: 28 | logits: [num_res, num_bins] output from the PredictedLDDTHead. 29 | 30 | Returns: 31 | plddt: [num_res] per-residue pLDDT. 32 | """ 33 | num_bins = logits.shape[-1] 34 | bin_width = 1.0 / num_bins 35 | bin_centers = jnp.arange(start=0.5 * bin_width, stop=1.0, step=bin_width) 36 | probs = jax.nn.softmax(logits, axis=-1) 37 | predicted_lddt_ca = jnp.sum(probs * bin_centers[None, :], axis=-1) 38 | return predicted_lddt_ca * 100 39 | 40 | 41 | def _calculate_bin_centers(breaks: np.ndarray): 42 | """Gets the bin centers from the bin edges. 43 | 44 | Args: 45 | breaks: [num_bins - 1] the error bin edges. 46 | 47 | Returns: 48 | bin_centers: [num_bins] the error bin centers. 49 | """ 50 | step = (breaks[1] - breaks[0]) 51 | 52 | # Add half-step to get the center 53 | bin_centers = breaks + step / 2 54 | # Add a catch-all bin at the end. 55 | bin_centers = jnp.concatenate([bin_centers, jnp.expand_dims(jnp.array(bin_centers[-1] + step),axis=0)],axis=0) 56 | return bin_centers 57 | 58 | 59 | def _calculate_expected_aligned_error( 60 | alignment_confidence_breaks: np.ndarray, 61 | aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 62 | """Calculates expected aligned distance errors for every pair of residues. 63 | 64 | Args: 65 | alignment_confidence_breaks: [num_bins - 1] the error bin edges. 66 | aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted 67 | probs for each error bin, for each pair of residues. 68 | 69 | Returns: 70 | predicted_aligned_error: [num_res, num_res] the expected aligned distance 71 | error for each pair of residues. 72 | max_predicted_aligned_error: The maximum predicted error possible. 73 | """ 74 | bin_centers = _calculate_bin_centers(alignment_confidence_breaks) 75 | 76 | # Tuple of expected aligned distance error and max possible error. 
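  # (I.e. the probability-weighted sum over bin centers; the maximum possible
  # error is the center of the final catch-all bin.)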
77 | return (jnp.sum(aligned_distance_error_probs * bin_centers, axis=-1), 78 | jnp.asarray(bin_centers[-1])) 79 | 80 | 81 | def compute_predicted_aligned_error( 82 | logits: np.ndarray, 83 | breaks: np.ndarray) -> Dict[str, np.ndarray]: 84 | """Computes aligned confidence metrics from logits. 85 | 86 | Args: 87 | logits: [num_res, num_res, num_bins] the logits output from 88 | PredictedAlignedErrorHead. 89 | breaks: [num_bins - 1] the error bin edges. 90 | 91 | Returns: 92 | aligned_confidence_probs: [num_res, num_res, num_bins] the predicted 93 | aligned error probabilities over bins for each residue pair. 94 | predicted_aligned_error: [num_res, num_res] the expected aligned distance 95 | error for each pair of residues. 96 | max_predicted_aligned_error: The maximum predicted error possible. 97 | """ 98 | aligned_confidence_probs = jax.nn.softmax( 99 | logits, 100 | axis=-1) 101 | predicted_aligned_error, max_predicted_aligned_error = ( 102 | _calculate_expected_aligned_error( 103 | alignment_confidence_breaks=breaks, 104 | aligned_distance_error_probs=aligned_confidence_probs)) 105 | return { 106 | 'aligned_confidence_probs': aligned_confidence_probs, 107 | 'predicted_aligned_error': predicted_aligned_error, 108 | 'max_predicted_aligned_error': max_predicted_aligned_error, 109 | } 110 | 111 | 112 | def predicted_tm_score( 113 | logits: np.ndarray, 114 | breaks: np.ndarray, 115 | residue_weights: Optional[np.ndarray] = None, 116 | asym_id: Optional[np.ndarray] = None, 117 | interface: bool = False) -> np.ndarray: 118 | """Computes predicted TM alignment or predicted interface TM alignment score. 119 | 120 | Args: 121 | logits: [num_res, num_res, num_bins] the logits output from 122 | PredictedAlignedErrorHead. 123 | breaks: [num_bins] the error bins. 124 | residue_weights: [num_res] the per residue weights to use for the 125 | expectation. 126 | asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for 127 | ipTM calculation, i.e. when interface=True. 128 | interface: If True, interface predicted TM score is computed. 129 | 130 | Returns: 131 | ptm_score: The predicted TM alignment or the predicted iTM score. 132 | """ 133 | 134 | # residue_weights has to be in [0, 1], but can be floating-point, i.e. the 135 | # exp. resolved head's probability. 136 | if residue_weights is None: 137 | residue_weights = np.ones(logits.shape[0]) 138 | 139 | bin_centers = _calculate_bin_centers(breaks) 140 | 141 | num_res = int(jnp.sum(residue_weights)) 142 | # Clip num_res to avoid negative/undefined d0. 143 | clipped_num_res = max(num_res, 19) 144 | 145 | # Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick 146 | # "Scoring function for automated assessment of protein structure template 147 | # quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf 148 | d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8 149 | 150 | # Convert logits to probs. 151 | probs = jax.nn.softmax(logits, axis=-1) 152 | 153 | # TM-Score term for every bin. 154 | tm_per_bin = 1. / (1 + jnp.square(bin_centers) / jnp.square(d0)) 155 | # E_distances tm(distance). 
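  # I.e. the expected TM-score kernel value for every residue pair.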
156 | predicted_tm_term = jnp.sum(probs * tm_per_bin, axis=-1) 157 | 158 | pair_mask = jnp.ones(shape=(num_res, num_res), dtype=bool) 159 | if interface: 160 | pair_mask *= asym_id[:, None] != asym_id[None, :] 161 | 162 | predicted_tm_term *= pair_mask 163 | 164 | pair_residue_weights = pair_mask * ( 165 | residue_weights[None, :] * residue_weights[:, None]) 166 | normed_residue_mask = pair_residue_weights / (1e-8 + jnp.sum( 167 | pair_residue_weights, axis=-1, keepdims=True)) 168 | per_alignment = jnp.sum(predicted_tm_term * normed_residue_mask, axis=-1) 169 | return jnp.asarray(per_alignment[(per_alignment * residue_weights).argmax()]) 170 | -------------------------------------------------------------------------------- /src/alphafold/common/protein_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for protein.""" 16 | 17 | import os 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | from alphafold.common import protein 22 | from alphafold.common import residue_constants 23 | import numpy as np 24 | # Internal import (7716). 25 | 26 | TEST_DATA_DIR = 'alphafold/common/testdata/' 27 | 28 | 29 | class ProteinTest(parameterized.TestCase): 30 | 31 | def _check_shapes(self, prot, num_res): 32 | """Check that the processed shapes are correct.""" 33 | num_atoms = residue_constants.atom_type_num 34 | self.assertEqual((num_res, num_atoms, 3), prot.atom_positions.shape) 35 | self.assertEqual((num_res,), prot.aatype.shape) 36 | self.assertEqual((num_res, num_atoms), prot.atom_mask.shape) 37 | self.assertEqual((num_res,), prot.residue_index.shape) 38 | self.assertEqual((num_res,), prot.chain_index.shape) 39 | self.assertEqual((num_res, num_atoms), prot.b_factors.shape) 40 | 41 | @parameterized.named_parameters( 42 | dict(testcase_name='chain_A', 43 | pdb_file='2rbg.pdb', chain_id='A', num_res=282, num_chains=1), 44 | dict(testcase_name='chain_B', 45 | pdb_file='2rbg.pdb', chain_id='B', num_res=282, num_chains=1), 46 | dict(testcase_name='multichain', 47 | pdb_file='2rbg.pdb', chain_id=None, num_res=564, num_chains=2)) 48 | def test_from_pdb_str(self, pdb_file, chain_id, num_res, num_chains): 49 | pdb_file = os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR, 50 | pdb_file) 51 | with open(pdb_file) as f: 52 | pdb_string = f.read() 53 | prot = protein.from_pdb_string(pdb_string, chain_id) 54 | self._check_shapes(prot, num_res) 55 | self.assertGreaterEqual(prot.aatype.min(), 0) 56 | # Allow equal since unknown restypes have index equal to restype_num. 
57 | self.assertLessEqual(prot.aatype.max(), residue_constants.restype_num) 58 | self.assertLen(np.unique(prot.chain_index), num_chains) 59 | 60 | def test_to_pdb(self): 61 | with open( 62 | os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR, 63 | '2rbg.pdb')) as f: 64 | pdb_string = f.read() 65 | prot = protein.from_pdb_string(pdb_string) 66 | pdb_string_reconstr = protein.to_pdb(prot) 67 | 68 | for line in pdb_string_reconstr.splitlines(): 69 | self.assertLen(line, 80) 70 | 71 | prot_reconstr = protein.from_pdb_string(pdb_string_reconstr) 72 | 73 | np.testing.assert_array_equal(prot_reconstr.aatype, prot.aatype) 74 | np.testing.assert_array_almost_equal( 75 | prot_reconstr.atom_positions, prot.atom_positions) 76 | np.testing.assert_array_almost_equal( 77 | prot_reconstr.atom_mask, prot.atom_mask) 78 | np.testing.assert_array_equal( 79 | prot_reconstr.residue_index, prot.residue_index) 80 | np.testing.assert_array_equal( 81 | prot_reconstr.chain_index, prot.chain_index) 82 | np.testing.assert_array_almost_equal( 83 | prot_reconstr.b_factors, prot.b_factors) 84 | 85 | def test_ideal_atom_mask(self): 86 | with open( 87 | os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR, 88 | '2rbg.pdb')) as f: 89 | pdb_string = f.read() 90 | prot = protein.from_pdb_string(pdb_string) 91 | ideal_mask = protein.ideal_atom_mask(prot) 92 | non_ideal_residues = set([102] + list(range(127, 286))) 93 | for i, (res, atom_mask) in enumerate( 94 | zip(prot.residue_index, prot.atom_mask)): 95 | if res in non_ideal_residues: 96 | self.assertFalse(np.all(atom_mask == ideal_mask[i]), msg=f'{res}') 97 | else: 98 | self.assertTrue(np.all(atom_mask == ideal_mask[i]), msg=f'{res}') 99 | 100 | def test_too_many_chains(self): 101 | num_res = protein.PDB_MAX_CHAINS + 1 102 | num_atom_type = residue_constants.atom_type_num 103 | with self.assertRaises(ValueError): 104 | _ = protein.Protein( 105 | atom_positions=np.random.random([num_res, num_atom_type, 3]), 106 | aatype=np.random.randint(0, 21, [num_res]), 107 | atom_mask=np.random.randint(0, 2, [num_res]).astype(np.float32), 108 | residue_index=np.arange(1, num_res+1), 109 | chain_index=np.arange(num_res), 110 | b_factors=np.random.uniform(1, 100, [num_res])) 111 | 112 | 113 | if __name__ == '__main__': 114 | absltest.main() 115 | -------------------------------------------------------------------------------- /src/alphafold/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Data pipeline for model features.""" 15 | -------------------------------------------------------------------------------- /src/alphafold/data/msa_identifiers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for extracting identifiers from MSA sequence descriptions.""" 16 | 17 | import dataclasses 18 | import re 19 | from typing import Optional 20 | 21 | 22 | # Sequences coming from UniProtKB database come in the 23 | # `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE` 24 | # or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively). 25 | _UNIPROT_PATTERN = re.compile( 26 | r""" 27 | ^ 28 | # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot 29 | (?:tr|sp) 30 | \| 31 | # A primary accession number of the UniProtKB entry. 32 | (?P[A-Za-z0-9]{6,10}) 33 | # Occasionally there is a _0 or _1 isoform suffix, which we ignore. 34 | (?:_\d)? 35 | \| 36 | # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic 37 | # protein ID code. 38 | (?:[A-Za-z0-9]+) 39 | _ 40 | # A mnemonic species identification code. 41 | (?P([A-Za-z0-9]){1,5}) 42 | # Small BFD uses a final value after an underscore, which we ignore. 43 | (?:_\d+)? 44 | $ 45 | """, 46 | re.VERBOSE) 47 | 48 | 49 | @dataclasses.dataclass(frozen=True) 50 | class Identifiers: 51 | species_id: str = '' 52 | 53 | 54 | def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: 55 | """Gets species from an msa sequence identifier. 56 | 57 | The sequence identifier has the format specified by 58 | _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN. 59 | An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE` 60 | 61 | Args: 62 | msa_sequence_identifier: a sequence identifier. 63 | 64 | Returns: 65 | An `Identifiers` instance with species_id. These 66 | can be empty in the case where no identifier was found. 67 | """ 68 | matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip()) 69 | if matches: 70 | return Identifiers( 71 | species_id=matches.group('SpeciesIdentifier')) 72 | return Identifiers() 73 | 74 | 75 | def _extract_sequence_identifier(description: str) -> Optional[str]: 76 | """Extracts sequence identifier from description. 
Returns None if no match.""" 77 | split_description = description.split() 78 | if split_description: 79 | return split_description[0].partition('/')[0] 80 | else: 81 | return None 82 | 83 | 84 | def get_identifiers(description: str) -> Identifiers: 85 | """Computes extra MSA features from the description.""" 86 | sequence_identifier = _extract_sequence_identifier(description) 87 | if sequence_identifier is None: 88 | return Identifiers() 89 | else: 90 | return _parse_sequence_identifier(sequence_identifier) 91 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Python wrappers for third party tools.""" 15 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/hhblits.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library to run HHblits from Python.""" 16 | 17 | import glob 18 | import os 19 | import subprocess 20 | from typing import Any, List, Mapping, Optional, Sequence 21 | 22 | from absl import logging 23 | from alphafold.data.tools import utils 24 | # Internal import (7716). 25 | 26 | 27 | _HHBLITS_DEFAULT_P = 20 28 | _HHBLITS_DEFAULT_Z = 500 29 | 30 | 31 | class HHBlits: 32 | """Python wrapper of the HHblits binary.""" 33 | 34 | def __init__(self, 35 | *, 36 | binary_path: str, 37 | databases: Sequence[str], 38 | n_cpu: int = 4, 39 | n_iter: int = 3, 40 | e_value: float = 0.001, 41 | maxseq: int = 1_000_000, 42 | realign_max: int = 100_000, 43 | maxfilt: int = 100_000, 44 | min_prefilter_hits: int = 1000, 45 | all_seqs: bool = False, 46 | alt: Optional[int] = None, 47 | p: int = _HHBLITS_DEFAULT_P, 48 | z: int = _HHBLITS_DEFAULT_Z): 49 | """Initializes the Python HHblits wrapper. 50 | 51 | Args: 52 | binary_path: The path to the HHblits executable. 53 | databases: A sequence of HHblits database paths. This should be the 54 | common prefix for the database files (i.e. up to but not including 55 | _hhm.ffindex etc.) 56 | n_cpu: The number of CPUs to give HHblits. 57 | n_iter: The number of HHblits iterations. 
58 | e_value: The E-value, see HHblits docs for more details. 59 | maxseq: The maximum number of rows in an input alignment. Note that this 60 | parameter is only supported in HHBlits version 3.1 and higher. 61 | realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. 62 | maxfilt: Max number of hits allowed to pass the 2nd prefilter. 63 | HHblits default: 20000. 64 | min_prefilter_hits: Min number of hits to pass prefilter. 65 | HHblits default: 100. 66 | all_seqs: Return all sequences in the MSA / Do not filter the result MSA. 67 | HHblits default: False. 68 | alt: Show up to this many alternative alignments. 69 | p: Minimum Prob for a hit to be included in the output hhr file. 70 | HHblits default: 20. 71 | z: Hard cap on number of hits reported in the hhr file. 72 | HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. 73 | 74 | Raises: 75 | RuntimeError: If HHblits binary not found within the path. 76 | """ 77 | self.binary_path = binary_path 78 | self.databases = databases 79 | 80 | for database_path in self.databases: 81 | if not glob.glob(database_path + '_*'): 82 | logging.error('Could not find HHBlits database %s', database_path) 83 | raise ValueError(f'Could not find HHBlits database {database_path}') 84 | 85 | self.n_cpu = n_cpu 86 | self.n_iter = n_iter 87 | self.e_value = e_value 88 | self.maxseq = maxseq 89 | self.realign_max = realign_max 90 | self.maxfilt = maxfilt 91 | self.min_prefilter_hits = min_prefilter_hits 92 | self.all_seqs = all_seqs 93 | self.alt = alt 94 | self.p = p 95 | self.z = z 96 | 97 | def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]: 98 | """Queries the database using HHblits.""" 99 | with utils.tmpdir_manager() as query_tmp_dir: 100 | a3m_path = os.path.join(query_tmp_dir, 'output.a3m') 101 | 102 | db_cmd = [] 103 | for db_path in self.databases: 104 | db_cmd.append('-d') 105 | db_cmd.append(db_path) 106 | cmd = [ 107 | self.binary_path, 108 | '-i', input_fasta_path, 109 | '-cpu', str(self.n_cpu), 110 | '-oa3m', a3m_path, 111 | '-o', '/dev/null', 112 | '-n', str(self.n_iter), 113 | '-e', str(self.e_value), 114 | '-maxseq', str(self.maxseq), 115 | '-realign_max', str(self.realign_max), 116 | '-maxfilt', str(self.maxfilt), 117 | '-min_prefilter_hits', str(self.min_prefilter_hits)] 118 | if self.all_seqs: 119 | cmd += ['-all'] 120 | if self.alt: 121 | cmd += ['-alt', str(self.alt)] 122 | if self.p != _HHBLITS_DEFAULT_P: 123 | cmd += ['-p', str(self.p)] 124 | if self.z != _HHBLITS_DEFAULT_Z: 125 | cmd += ['-Z', str(self.z)] 126 | cmd += db_cmd 127 | 128 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 129 | process = subprocess.Popen( 130 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 131 | 132 | with utils.timing('HHblits query'): 133 | stdout, stderr = process.communicate() 134 | retcode = process.wait() 135 | 136 | if retcode: 137 | # Logs have a 15k character limit, so log HHblits error line by line. 138 | logging.error('HHblits failed. 
HHblits stderr begin:') 139 | for error_line in stderr.decode('utf-8').splitlines(): 140 | if error_line.strip(): 141 | logging.error(error_line.strip()) 142 | logging.error('HHblits stderr end') 143 | raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % ( 144 | stdout.decode('utf-8'), stderr[:500_000].decode('utf-8'))) 145 | 146 | with open(a3m_path) as f: 147 | a3m = f.read() 148 | 149 | raw_output = dict( 150 | a3m=a3m, 151 | output=stdout, 152 | stderr=stderr, 153 | n_iter=self.n_iter, 154 | e_value=self.e_value) 155 | return [raw_output] 156 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/hhsearch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Library to run HHsearch from Python.""" 16 | 17 | import glob 18 | import os 19 | import subprocess 20 | from typing import Sequence 21 | 22 | from absl import logging 23 | 24 | from alphafold.data import parsers 25 | from alphafold.data.tools import utils 26 | # Internal import (7716). 27 | 28 | 29 | class HHSearch: 30 | """Python wrapper of the HHsearch binary.""" 31 | 32 | def __init__(self, 33 | *, 34 | binary_path: str, 35 | databases: Sequence[str], 36 | maxseq: int = 1_000_000): 37 | """Initializes the Python HHsearch wrapper. 38 | 39 | Args: 40 | binary_path: The path to the HHsearch executable. 41 | databases: A sequence of HHsearch database paths. This should be the 42 | common prefix for the database files (i.e. up to but not including 43 | _hhm.ffindex etc.) 44 | maxseq: The maximum number of rows in an input alignment. Note that this 45 | parameter is only supported in HHBlits version 3.1 and higher. 46 | 47 | Raises: 48 | RuntimeError: If HHsearch binary not found within the path. 
49 | """ 50 | self.binary_path = binary_path 51 | self.databases = databases 52 | self.maxseq = maxseq 53 | 54 | for database_path in self.databases: 55 | if not glob.glob(database_path + '_*'): 56 | logging.error('Could not find HHsearch database %s', database_path) 57 | raise ValueError(f'Could not find HHsearch database {database_path}') 58 | 59 | @property 60 | def output_format(self) -> str: 61 | return 'hhr' 62 | 63 | @property 64 | def input_format(self) -> str: 65 | return 'a3m' 66 | 67 | def query(self, a3m: str) -> str: 68 | """Queries the database using HHsearch using a given a3m.""" 69 | with utils.tmpdir_manager() as query_tmp_dir: 70 | input_path = os.path.join(query_tmp_dir, 'query.a3m') 71 | hhr_path = os.path.join(query_tmp_dir, 'output.hhr') 72 | with open(input_path, 'w') as f: 73 | f.write(a3m) 74 | 75 | db_cmd = [] 76 | for db_path in self.databases: 77 | db_cmd.append('-d') 78 | db_cmd.append(db_path) 79 | cmd = [self.binary_path, 80 | '-i', input_path, 81 | '-o', hhr_path, 82 | '-maxseq', str(self.maxseq) 83 | ] + db_cmd 84 | 85 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 86 | process = subprocess.Popen( 87 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 88 | with utils.timing('HHsearch query'): 89 | stdout, stderr = process.communicate() 90 | retcode = process.wait() 91 | 92 | if retcode: 93 | # Stderr is truncated to prevent proto size errors in Beam. 94 | raise RuntimeError( 95 | 'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % ( 96 | stdout.decode('utf-8'), stderr[:100_000].decode('utf-8'))) 97 | 98 | with open(hhr_path) as f: 99 | hhr = f.read() 100 | return hhr 101 | 102 | def get_template_hits(self, 103 | output_string: str, 104 | input_sequence: str) -> Sequence[parsers.TemplateHit]: 105 | """Gets parsed template hits from the raw string output by the tool.""" 106 | del input_sequence # Used by hmmseach but not needed for hhsearch. 107 | return parsers.parse_hhr(output_string) 108 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/hmmbuild.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A Python wrapper for hmmbuild - construct HMM profiles from MSA.""" 16 | 17 | import os 18 | import re 19 | import subprocess 20 | 21 | from absl import logging 22 | from alphafold.data.tools import utils 23 | # Internal import (7716). 24 | 25 | 26 | class Hmmbuild(object): 27 | """Python wrapper of the hmmbuild binary.""" 28 | 29 | def __init__(self, 30 | *, 31 | binary_path: str, 32 | singlemx: bool = False): 33 | """Initializes the Python hmmbuild wrapper. 34 | 35 | Args: 36 | binary_path: The path to the hmmbuild executable. 37 | singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to 38 | just use a common substitution score matrix. 
39 | 
40 |     Raises:
41 |       RuntimeError: If hmmbuild binary not found within the path.
42 |     """
43 |     self.binary_path = binary_path
44 |     self.singlemx = singlemx
45 | 
46 |   def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
47 |     """Builds an HMM for the aligned sequences given as a Stockholm string.
48 | 
49 |     Args:
50 |       sto: A string with the aligned sequences in the Stockholm format.
51 |       model_construction: Whether to use reference annotation in the msa to
52 |         determine consensus columns ('hand') or default ('fast').
53 | 
54 |     Returns:
55 |       A string with the profile in the HMM format.
56 | 
57 |     Raises:
58 |       RuntimeError: If hmmbuild fails.
59 |     """
60 |     return self._build_profile(sto, model_construction=model_construction)
61 | 
62 |   def build_profile_from_a3m(self, a3m: str) -> str:
63 |     """Builds an HMM for the aligned sequences given as an A3M string.
64 | 
65 |     Args:
66 |       a3m: A string with the aligned sequences in the A3M format.
67 | 
68 |     Returns:
69 |       A string with the profile in the HMM format.
70 | 
71 |     Raises:
72 |       RuntimeError: If hmmbuild fails.
73 |     """
74 |     lines = []
75 |     for line in a3m.splitlines():
76 |       if not line.startswith('>'):
77 |         line = re.sub('[a-z]+', '', line)  # Remove inserted residues.
78 |       lines.append(line + '\n')
79 |     msa = ''.join(lines)
80 |     return self._build_profile(msa, model_construction='fast')
81 | 
82 |   def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
83 |     """Builds an HMM for the aligned sequences given as an MSA string.
84 | 
85 |     Args:
86 |       msa: A string with the aligned sequences, in A3M or STO format.
87 |       model_construction: Whether to use reference annotation in the msa to
88 |         determine consensus columns ('hand') or default ('fast').
89 | 
90 |     Returns:
91 |       A string with the profile in the HMM format.
92 | 
93 |     Raises:
94 |       RuntimeError: If hmmbuild fails.
95 |       ValueError: If an unsupported model_construction is provided.
96 | """ 97 | if model_construction not in {'hand', 'fast'}: 98 | raise ValueError(f'Invalid model_construction {model_construction} - only' 99 | 'hand and fast supported.') 100 | 101 | with utils.tmpdir_manager() as query_tmp_dir: 102 | input_query = os.path.join(query_tmp_dir, 'query.msa') 103 | output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm') 104 | 105 | with open(input_query, 'w') as f: 106 | f.write(msa) 107 | 108 | cmd = [self.binary_path] 109 | # If adding flags, we have to do so before the output and input: 110 | 111 | if model_construction == 'hand': 112 | cmd.append(f'--{model_construction}') 113 | if self.singlemx: 114 | cmd.append('--singlemx') 115 | cmd.extend([ 116 | '--amino', 117 | output_hmm_path, 118 | input_query, 119 | ]) 120 | 121 | logging.info('Launching subprocess %s', cmd) 122 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, 123 | stderr=subprocess.PIPE) 124 | 125 | with utils.timing('hmmbuild query'): 126 | stdout, stderr = process.communicate() 127 | retcode = process.wait() 128 | logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n', 129 | stdout.decode('utf-8'), stderr.decode('utf-8')) 130 | 131 | if retcode: 132 | raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n' 133 | % (stdout.decode('utf-8'), stderr.decode('utf-8'))) 134 | 135 | with open(output_hmm_path, encoding='utf-8') as f: 136 | hmm = f.read() 137 | 138 | return hmm 139 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/hmmsearch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A Python wrapper for hmmsearch - search profile against a sequence db.""" 16 | 17 | import os 18 | import subprocess 19 | from typing import Optional, Sequence 20 | 21 | from absl import logging 22 | from alphafold.data import parsers 23 | from alphafold.data.tools import hmmbuild 24 | from alphafold.data.tools import utils 25 | # Internal import (7716). 26 | 27 | 28 | class Hmmsearch(object): 29 | """Python wrapper of the hmmsearch binary.""" 30 | 31 | def __init__(self, 32 | *, 33 | binary_path: str, 34 | hmmbuild_binary_path: str, 35 | database_path: str, 36 | flags: Optional[Sequence[str]] = None): 37 | """Initializes the Python hmmsearch wrapper. 38 | 39 | Args: 40 | binary_path: The path to the hmmsearch executable. 41 | hmmbuild_binary_path: The path to the hmmbuild executable. Used to build 42 | an hmm from an input a3m. 43 | database_path: The path to the hmmsearch database (FASTA format). 44 | flags: List of flags to be used by hmmsearch. 45 | 46 | Raises: 47 | RuntimeError: If hmmsearch binary not found within the path. 48 | """ 49 | self.binary_path = binary_path 50 | self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path) 51 | self.database_path = database_path 52 | if flags is None: 53 | # Default hmmsearch run settings. 
54 | flags = ['--F1', '0.1', 55 | '--F2', '0.1', 56 | '--F3', '0.1', 57 | '--incE', '100', 58 | '-E', '100', 59 | '--domE', '100', 60 | '--incdomE', '100'] 61 | self.flags = flags 62 | 63 | if not os.path.exists(self.database_path): 64 | logging.error('Could not find hmmsearch database %s', database_path) 65 | raise ValueError(f'Could not find hmmsearch database {database_path}') 66 | 67 | @property 68 | def output_format(self) -> str: 69 | return 'sto' 70 | 71 | @property 72 | def input_format(self) -> str: 73 | return 'sto' 74 | 75 | def query(self, msa_sto: str) -> str: 76 | """Queries the database using hmmsearch using a given stockholm msa.""" 77 | hmm = self.hmmbuild_runner.build_profile_from_sto(msa_sto, 78 | model_construction='hand') 79 | return self.query_with_hmm(hmm) 80 | 81 | def query_with_hmm(self, hmm: str) -> str: 82 | """Queries the database using hmmsearch using a given hmm.""" 83 | with utils.tmpdir_manager() as query_tmp_dir: 84 | hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') 85 | out_path = os.path.join(query_tmp_dir, 'output.sto') 86 | with open(hmm_input_path, 'w') as f: 87 | f.write(hmm) 88 | 89 | cmd = [ 90 | self.binary_path, 91 | '--noali', # Don't include the alignment in stdout. 92 | '--cpu', '8' 93 | ] 94 | # If adding flags, we have to do so before the output and input: 95 | if self.flags: 96 | cmd.extend(self.flags) 97 | cmd.extend([ 98 | '-A', out_path, 99 | hmm_input_path, 100 | self.database_path, 101 | ]) 102 | 103 | logging.info('Launching sub-process %s', cmd) 104 | process = subprocess.Popen( 105 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 106 | with utils.timing( 107 | f'hmmsearch ({os.path.basename(self.database_path)}) query'): 108 | stdout, stderr = process.communicate() 109 | retcode = process.wait() 110 | 111 | if retcode: 112 | raise RuntimeError( 113 | 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % ( 114 | stdout.decode('utf-8'), stderr.decode('utf-8'))) 115 | 116 | with open(out_path) as f: 117 | out_msa = f.read() 118 | 119 | return out_msa 120 | 121 | def get_template_hits(self, 122 | output_string: str, 123 | input_sequence: str) -> Sequence[parsers.TemplateHit]: 124 | """Gets parsed template hits from the raw string output by the tool.""" 125 | a3m_string = parsers.convert_stockholm_to_a3m(output_string, 126 | remove_first_row_gaps=False) 127 | template_hits = parsers.parse_hmmsearch_a3m( 128 | query_sequence=input_sequence, 129 | a3m_string=a3m_string, 130 | skip_first=False) 131 | return template_hits 132 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/jackhmmer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 
15 | """Library to run Jackhmmer from Python."""
16 | 
17 | from concurrent import futures
18 | import glob
19 | import os
20 | import subprocess
21 | from typing import Any, Callable, Mapping, Optional, Sequence
22 | from urllib import request
23 | 
24 | from absl import logging
25 | 
26 | from alphafold.data import parsers
27 | from alphafold.data.tools import utils
28 | # Internal import (7716).
29 | 
30 | 
31 | class Jackhmmer:
32 |   """Python wrapper of the Jackhmmer binary."""
33 | 
34 |   def __init__(self,
35 |                *,
36 |                binary_path: str,
37 |                database_path: str,
38 |                n_cpu: int = 8,
39 |                n_iter: int = 1,
40 |                e_value: float = 0.0001,
41 |                z_value: Optional[int] = None,
42 |                get_tblout: bool = False,
43 |                filter_f1: float = 0.0005,
44 |                filter_f2: float = 0.00005,
45 |                filter_f3: float = 0.0000005,
46 |                incdom_e: Optional[float] = None,
47 |                dom_e: Optional[float] = None,
48 |                num_streamed_chunks: Optional[int] = None,
49 |                streaming_callback: Optional[Callable[[int], None]] = None):
50 |     """Initializes the Python Jackhmmer wrapper.
51 | 
52 |     Args:
53 |       binary_path: The path to the jackhmmer executable.
54 |       database_path: The path to the jackhmmer database (FASTA format).
55 |       n_cpu: The number of CPUs to give Jackhmmer.
56 |       n_iter: The number of Jackhmmer iterations.
57 |       e_value: The E-value, see Jackhmmer docs for more details.
58 |       z_value: The Z-value, see Jackhmmer docs for more details.
59 |       get_tblout: Whether to save tblout string.
60 |       filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
61 |       filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
62 |       filter_f3: Forward pre-filter, set to >1.0 to turn off.
63 |       incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
64 |         round.
65 |       dom_e: Domain e-value criteria for inclusion in tblout.
66 |       num_streamed_chunks: Number of database chunks to stream over.
67 |       streaming_callback: Callback function run after each chunk iteration with
68 |         the iteration number as argument.
69 |     """
70 |     self.binary_path = binary_path
71 |     self.database_path = database_path
72 |     self.num_streamed_chunks = num_streamed_chunks
73 | 
74 |     if not os.path.exists(self.database_path) and num_streamed_chunks is None:
75 |       logging.error('Could not find Jackhmmer database %s', database_path)
76 |       raise ValueError(f'Could not find Jackhmmer database {database_path}')
77 | 
78 |     self.n_cpu = n_cpu
79 |     self.n_iter = n_iter
80 |     self.e_value = e_value
81 |     self.z_value = z_value
82 |     self.filter_f1 = filter_f1
83 |     self.filter_f2 = filter_f2
84 |     self.filter_f3 = filter_f3
85 |     self.incdom_e = incdom_e
86 |     self.dom_e = dom_e
87 |     self.get_tblout = get_tblout
88 |     self.streaming_callback = streaming_callback
89 | 
90 |   def _query_chunk(self,
91 |                    input_fasta_path: str,
92 |                    database_path: str,
93 |                    max_sequences: Optional[int] = None) -> Mapping[str, Any]:
94 |     """Queries the database chunk using Jackhmmer."""
95 |     with utils.tmpdir_manager() as query_tmp_dir:
96 |       sto_path = os.path.join(query_tmp_dir, 'output.sto')
97 | 
98 |       # The F1/F2/F3 are the expected proportion to pass each of the filtering
99 |       # stages (which get progressively more expensive); reducing these
100 |       # speeds up the pipeline at the expense of sensitivity. They are
101 |       # currently set very low to make querying Mgnify run in a reasonable
102 |       # amount of time.
103 |       cmd_flags = [
104 |           # Don't pollute stdout with Jackhmmer output.
105 | '-o', '/dev/null', 106 | '-A', sto_path, 107 | '--noali', 108 | '--F1', str(self.filter_f1), 109 | '--F2', str(self.filter_f2), 110 | '--F3', str(self.filter_f3), 111 | '--incE', str(self.e_value), 112 | # Report only sequences with E-values <= x in per-sequence output. 113 | '-E', str(self.e_value), 114 | '--cpu', str(self.n_cpu), 115 | '-N', str(self.n_iter) 116 | ] 117 | if self.get_tblout: 118 | tblout_path = os.path.join(query_tmp_dir, 'tblout.txt') 119 | cmd_flags.extend(['--tblout', tblout_path]) 120 | 121 | if self.z_value: 122 | cmd_flags.extend(['-Z', str(self.z_value)]) 123 | 124 | if self.dom_e is not None: 125 | cmd_flags.extend(['--domE', str(self.dom_e)]) 126 | 127 | if self.incdom_e is not None: 128 | cmd_flags.extend(['--incdomE', str(self.incdom_e)]) 129 | 130 | cmd = [self.binary_path] + cmd_flags + [input_fasta_path, 131 | database_path] 132 | 133 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 134 | process = subprocess.Popen( 135 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 136 | with utils.timing( 137 | f'Jackhmmer ({os.path.basename(database_path)}) query'): 138 | _, stderr = process.communicate() 139 | retcode = process.wait() 140 | 141 | if retcode: 142 | raise RuntimeError( 143 | 'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8')) 144 | 145 | # Get e-values for each target name 146 | tbl = '' 147 | if self.get_tblout: 148 | with open(tblout_path) as f: 149 | tbl = f.read() 150 | 151 | if max_sequences is None: 152 | with open(sto_path) as f: 153 | sto = f.read() 154 | else: 155 | sto = parsers.truncate_stockholm_msa(sto_path, max_sequences) 156 | 157 | raw_output = dict( 158 | sto=sto, 159 | tbl=tbl, 160 | stderr=stderr, 161 | n_iter=self.n_iter, 162 | e_value=self.e_value) 163 | 164 | return raw_output 165 | 166 | def query(self, 167 | input_fasta_path: str, 168 | max_sequences: Optional[int] = None) -> Sequence[Mapping[str, Any]]: 169 | """Queries the database using Jackhmmer.""" 170 | if self.num_streamed_chunks is None: 171 | single_chunk_result = self._query_chunk( 172 | input_fasta_path, self.database_path, max_sequences) 173 | return [single_chunk_result] 174 | 175 | db_basename = os.path.basename(self.database_path) 176 | db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}' 177 | db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}' 178 | 179 | # Remove existing files to prevent OOM 180 | for f in glob.glob(db_local_chunk('[0-9]*')): 181 | try: 182 | os.remove(f) 183 | except OSError: 184 | print(f'OSError while deleting {f}') 185 | 186 | # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk 187 | with futures.ThreadPoolExecutor(max_workers=2) as executor: 188 | chunked_output = [] 189 | for i in range(1, self.num_streamed_chunks + 1): 190 | # Copy the chunk locally 191 | if i == 1: 192 | future = executor.submit( 193 | request.urlretrieve, db_remote_chunk(i), db_local_chunk(i)) 194 | if i < self.num_streamed_chunks: 195 | next_future = executor.submit( 196 | request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1)) 197 | 198 | # Run Jackhmmer with the chunk 199 | future.result() 200 | chunked_output.append(self._query_chunk( 201 | input_fasta_path, db_local_chunk(i), max_sequences)) 202 | 203 | # Remove the local copy of the chunk 204 | os.remove(db_local_chunk(i)) 205 | # Do not set next_future for the last chunk so that this works even for 206 | # databases with only 1 chunk. 
207 | if i < self.num_streamed_chunks: 208 | future = next_future 209 | if self.streaming_callback: 210 | self.streaming_callback(i) 211 | return chunked_output 212 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/kalign.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A Python wrapper for Kalign.""" 16 | import os 17 | import subprocess 18 | from typing import Sequence 19 | 20 | from absl import logging 21 | 22 | from alphafold.data.tools import utils 23 | # Internal import (7716). 24 | 25 | 26 | def _to_a3m(sequences: Sequence[str]) -> str: 27 | """Converts sequences to an a3m-formatted string.""" 28 | names = ['sequence %d' % i for i in range(1, len(sequences) + 1)] 29 | a3m = [] 30 | for sequence, name in zip(sequences, names): 31 | a3m.append(u'>' + name + u'\n') 32 | a3m.append(sequence + u'\n') 33 | return ''.join(a3m) 34 | 35 | 36 | class Kalign: 37 | """Python wrapper of the Kalign binary.""" 38 | 39 | def __init__(self, *, binary_path: str): 40 | """Initializes the Python Kalign wrapper. 41 | 42 | Args: 43 | binary_path: The path to the Kalign binary. 44 | 45 | Raises: 46 | RuntimeError: If Kalign binary not found within the path. 47 | """ 48 | self.binary_path = binary_path 49 | 50 | def align(self, sequences: Sequence[str]) -> str: 51 | """Aligns the sequences and returns the alignment as an A3M string. 52 | 53 | Args: 54 | sequences: A list of query sequence strings. The sequences have to be at 55 | least 6 residues long (Kalign requires this). Note that the order in 56 | which you give the sequences might alter the output slightly as 57 | a different alignment tree might get constructed. 58 | 59 | Returns: 60 | A string with the alignment in a3m format. 61 | 62 | Raises: 63 | RuntimeError: If Kalign fails. 64 | ValueError: If any of the sequences is less than 6 residues long. 65 | """ 66 | logging.info('Aligning %d sequences', len(sequences)) 67 | 68 | for s in sequences: 69 | if len(s) < 6: 70 | raise ValueError('Kalign requires all sequences to be at least 6 ' 71 | 'residues long. Got %s (%d residues).'
% (s, len(s))) 72 | 73 | with utils.tmpdir_manager() as query_tmp_dir: 74 | input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta') 75 | output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m') 76 | 77 | with open(input_fasta_path, 'w') as f: 78 | f.write(_to_a3m(sequences)) 79 | 80 | cmd = [ 81 | self.binary_path, 82 | '-i', input_fasta_path, 83 | '-o', output_a3m_path, 84 | '-format', 'fasta', 85 | ] 86 | 87 | logging.info('Launching subprocess "%s"', ' '.join(cmd)) 88 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, 89 | stderr=subprocess.PIPE) 90 | 91 | with utils.timing('Kalign query'): 92 | stdout, stderr = process.communicate() 93 | retcode = process.wait() 94 | logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n', 95 | stdout.decode('utf-8'), stderr.decode('utf-8')) 96 | 97 | if retcode: 98 | raise RuntimeError('Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' 99 | % (stdout.decode('utf-8'), stderr.decode('utf-8'))) 100 | 101 | with open(output_a3m_path) as f: 102 | a3m = f.read() 103 | 104 | return a3m 105 | -------------------------------------------------------------------------------- /src/alphafold/data/tools/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Common utilities for data pipeline tools.""" 15 | import contextlib 16 | import shutil 17 | import tempfile 18 | import time 19 | from typing import Optional 20 | 21 | from absl import logging 22 | 23 | 24 | @contextlib.contextmanager 25 | def tmpdir_manager(base_dir: Optional[str] = None): 26 | """Context manager that deletes a temporary directory on exit.""" 27 | tmpdir = tempfile.mkdtemp(dir=base_dir) 28 | try: 29 | yield tmpdir 30 | finally: 31 | shutil.rmtree(tmpdir, ignore_errors=True) 32 | 33 | 34 | @contextlib.contextmanager 35 | def timing(msg: str): 36 | logging.info('Started %s', msg) 37 | tic = time.time() 38 | yield 39 | toc = time.time() 40 | logging.info('Finished %s in %.3f seconds', msg, toc - tic) 41 | -------------------------------------------------------------------------------- /src/alphafold/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
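# Illustrative use of the data-pipeline tool utilities above (a minimal
# sketch, not from the original source):
#
#   from alphafold.data.tools import utils
#
#   with utils.tmpdir_manager() as tmp_dir, utils.timing('toy example'):
#       ...  # write intermediate files into tmp_dir; the directory is
#            # deleted when the block exits and the elapsed time is logged.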
14 | """Alphafold model.""" 15 | -------------------------------------------------------------------------------- /src/alphafold/model/all_atom_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for all_atom.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from alphafold.model import all_atom 20 | from alphafold.model import r3 21 | import numpy as np 22 | 23 | L1_CLAMP_DISTANCE = 10 24 | 25 | 26 | def get_identity_rigid(shape): 27 | """Returns identity rigid transform.""" 28 | 29 | ones = np.ones(shape) 30 | zeros = np.zeros(shape) 31 | rot = r3.Rots(ones, zeros, zeros, 32 | zeros, ones, zeros, 33 | zeros, zeros, ones) 34 | trans = r3.Vecs(zeros, zeros, zeros) 35 | return r3.Rigids(rot, trans) 36 | 37 | 38 | def get_global_rigid_transform(rot_angle, translation, bcast_dims): 39 | """Returns rigid transform that globally rotates/translates by same amount.""" 40 | 41 | rot_angle = np.asarray(rot_angle) 42 | translation = np.asarray(translation) 43 | if bcast_dims: 44 | for _ in range(bcast_dims): 45 | rot_angle = np.expand_dims(rot_angle, 0) 46 | translation = np.expand_dims(translation, 0) 47 | sin_angle = np.sin(np.deg2rad(rot_angle)) 48 | cos_angle = np.cos(np.deg2rad(rot_angle)) 49 | ones = np.ones_like(sin_angle) 50 | zeros = np.zeros_like(sin_angle) 51 | rot = r3.Rots(ones, zeros, zeros, 52 | zeros, cos_angle, -sin_angle, 53 | zeros, sin_angle, cos_angle) 54 | trans = r3.Vecs(translation[..., 0], translation[..., 1], translation[..., 2]) 55 | return r3.Rigids(rot, trans) 56 | 57 | 58 | class AllAtomTest(parameterized.TestCase, absltest.TestCase): 59 | 60 | @parameterized.named_parameters( 61 | ('identity', 0, [0, 0, 0]), 62 | ('rot_90', 90, [0, 0, 0]), 63 | ('trans_10', 0, [0, 0, 10]), 64 | ('rot_174_trans_1', 174, [1, 1, 1])) 65 | def test_frame_aligned_point_error_perfect_on_global_transform( 66 | self, rot_angle, translation): 67 | """Tests global transform between target and preds gives perfect score.""" 68 | 69 | # pylint: disable=bad-whitespace 70 | target_positions = np.array( 71 | [[ 21.182, 23.095, 19.731], 72 | [ 22.055, 20.919, 17.294], 73 | [ 24.599, 20.005, 15.041], 74 | [ 25.567, 18.214, 12.166], 75 | [ 28.063, 17.082, 10.043], 76 | [ 28.779, 15.569, 6.985], 77 | [ 30.581, 13.815, 4.612], 78 | [ 29.258, 12.193, 2.296]]) 79 | # pylint: enable=bad-whitespace 80 | global_rigid_transform = get_global_rigid_transform( 81 | rot_angle, translation, 1) 82 | 83 | target_positions = r3.vecs_from_tensor(target_positions) 84 | pred_positions = r3.rigids_mul_vecs( 85 | global_rigid_transform, target_positions) 86 | positions_mask = np.ones(target_positions.x.shape[0]) 87 | 88 | target_frames = get_identity_rigid(10) 89 | pred_frames = r3.rigids_mul_rigids(global_rigid_transform, target_frames) 90 | frames_mask = np.ones(10) 91 | 92 | fape = 
all_atom.frame_aligned_point_error( 93 | pred_frames, target_frames, frames_mask, pred_positions, 94 | target_positions, positions_mask, L1_CLAMP_DISTANCE, 95 | L1_CLAMP_DISTANCE, epsilon=0) 96 | self.assertAlmostEqual(fape, 0.) 97 | 98 | @parameterized.named_parameters( 99 | ('identity', 100 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 101 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 102 | 0.), 103 | ('shift_2.5', 104 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 105 | [[2.5, 0, 0], [7.5, 0, 0], [7.5, 0, 0]], 106 | 0.25), 107 | ('shift_5', 108 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 109 | [[5, 0, 0], [10, 0, 0], [15, 0, 0]], 110 | 0.5), 111 | ('shift_10', 112 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 113 | [[10, 0, 0], [15, 0, 0], [0, 0, 0]], 114 | 1.)) 115 | def test_frame_aligned_point_error_matches_expected( 116 | self, target_positions, pred_positions, expected_alddt): 117 | """Tests score matches expected.""" 118 | 119 | target_frames = get_identity_rigid(2) 120 | pred_frames = target_frames 121 | frames_mask = np.ones(2) 122 | 123 | target_positions = r3.vecs_from_tensor(np.array(target_positions)) 124 | pred_positions = r3.vecs_from_tensor(np.array(pred_positions)) 125 | positions_mask = np.ones(target_positions.x.shape[0]) 126 | 127 | alddt = all_atom.frame_aligned_point_error( 128 | pred_frames, target_frames, frames_mask, pred_positions, 129 | target_positions, positions_mask, L1_CLAMP_DISTANCE, 130 | L1_CLAMP_DISTANCE, epsilon=0) 131 | self.assertAlmostEqual(alddt, expected_alddt) 132 | 133 | 134 | if __name__ == '__main__': 135 | absltest.main() 136 | -------------------------------------------------------------------------------- /src/alphafold/model/common_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A collection of common Haiku modules for use in protein folding.""" 16 | import numbers 17 | from typing import Union, Sequence 18 | 19 | import haiku as hk 20 | import jax.numpy as jnp 21 | import numpy as np 22 | 23 | 24 | # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) 25 | TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978, 26 | dtype=np.float32) 27 | 28 | 29 | def get_initializer_scale(initializer_name, input_shape): 30 | """Get Initializer for weights and scale to multiply activations by.""" 31 | 32 | if initializer_name == 'zeros': 33 | w_init = hk.initializers.Constant(0.0) 34 | else: 35 | # fan-in scaling 36 | scale = 1. 37 | for channel_dim in input_shape: 38 | scale /= channel_dim 39 | if initializer_name == 'relu': 40 | scale *= 2 41 | 42 | noise_scale = scale 43 | 44 | stddev = np.sqrt(noise_scale) 45 | # Adjust stddev for truncation. 46 | stddev = stddev / TRUNCATED_NORMAL_STDDEV_FACTOR 47 | w_init = hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev) 48 | 49 | return w_init 50 | 51 | 52 | class Linear(hk.Module): 53 | """Protein folding specific Linear module. 
54 | 55 | This differs from the standard Haiku Linear in a few ways: 56 | * It supports inputs and outputs of arbitrary rank 57 | * Initializers are specified by strings 58 | """ 59 | 60 | def __init__(self, 61 | num_output: Union[int, Sequence[int]], 62 | initializer: str = 'linear', 63 | num_input_dims: int = 1, 64 | use_bias: bool = True, 65 | bias_init: float = 0., 66 | precision = None, 67 | name: str = 'linear'): 68 | """Constructs Linear Module. 69 | 70 | Args: 71 | num_output: Number of output channels. Can be tuple when outputting 72 | multiple dimensions. 73 | initializer: What initializer to use, should be one of {'linear', 'relu', 74 | 'zeros'} 75 | num_input_dims: Number of dimensions from the end to project. 76 | use_bias: Whether to include trainable bias. 77 | bias_init: Value used to initialize bias. 78 | precision: What precision to use for matrix multiplication, defaults 79 | to None. 80 | name: Name of module, used for name scopes. 81 | """ 82 | super().__init__(name=name) 83 | if isinstance(num_output, numbers.Integral): 84 | self.output_shape = (num_output,) 85 | else: 86 | self.output_shape = tuple(num_output) 87 | self.initializer = initializer 88 | self.use_bias = use_bias 89 | self.bias_init = bias_init 90 | self.num_input_dims = num_input_dims 91 | self.num_output_dims = len(self.output_shape) 92 | self.precision = precision 93 | 94 | def __call__(self, inputs): 95 | """Connects Module. 96 | 97 | Args: 98 | inputs: Tensor with at least num_input_dims dimensions. 99 | 100 | Returns: 101 | output of shape [...] + num_output. 102 | """ 103 | 104 | num_input_dims = self.num_input_dims 105 | 106 | if self.num_input_dims > 0: 107 | in_shape = inputs.shape[-self.num_input_dims:] 108 | else: 109 | in_shape = () 110 | 111 | weight_init = get_initializer_scale(self.initializer, in_shape) 112 | 113 | in_letters = 'abcde'[:self.num_input_dims] 114 | out_letters = 'hijkl'[:self.num_output_dims] 115 | 116 | weight_shape = in_shape + self.output_shape 117 | weights = hk.get_parameter('weights', weight_shape, inputs.dtype, 118 | weight_init) 119 | 120 | equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}' 121 | 122 | output = jnp.einsum(equation, inputs, weights, precision=self.precision) 123 | 124 | if self.use_bias: 125 | bias = hk.get_parameter('bias', self.output_shape, inputs.dtype, 126 | hk.initializers.Constant(self.bias_init)) 127 | output += bias 128 | 129 | return output 130 | 131 | 132 | class LayerNorm(hk.LayerNorm): 133 | """LayerNorm module. 134 | 135 | Equivalent to hk.LayerNorm but with different parameter shapes: they are 136 | always vectors rather than possibly higher-rank tensors. This makes it easier 137 | to change the layout whilst keeping the model weight-compatible
138 | """ 139 | 140 | def __init__(self, 141 | axis, 142 | create_scale: bool, 143 | create_offset: bool, 144 | eps: float = 1e-5, 145 | scale_init=None, 146 | offset_init=None, 147 | use_fast_variance: bool = False, 148 | name=None, 149 | param_axis=None): 150 | super().__init__( 151 | axis=axis, 152 | create_scale=False, 153 | create_offset=False, 154 | eps=eps, 155 | scale_init=None, 156 | offset_init=None, 157 | use_fast_variance=use_fast_variance, 158 | name=name, 159 | param_axis=param_axis) 160 | self._temp_create_scale = create_scale 161 | self._temp_create_offset = create_offset 162 | 163 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 164 | is_bf16 = (x.dtype == jnp.bfloat16) 165 | if is_bf16: 166 | x = x.astype(jnp.float32) 167 | 168 | param_axis = self.param_axis[0] if self.param_axis else -1 169 | param_shape = (x.shape[param_axis],) 170 | 171 | param_broadcast_shape = [1] * x.ndim 172 | param_broadcast_shape[param_axis] = x.shape[param_axis] 173 | scale = None 174 | offset = None 175 | if self._temp_create_scale: 176 | scale = hk.get_parameter( 177 | 'scale', param_shape, x.dtype, init=self.scale_init) 178 | scale = scale.reshape(param_broadcast_shape) 179 | 180 | if self._temp_create_offset: 181 | offset = hk.get_parameter( 182 | 'offset', param_shape, x.dtype, init=self.offset_init) 183 | offset = offset.reshape(param_broadcast_shape) 184 | 185 | out = super().__call__(x, scale=scale, offset=offset) 186 | 187 | if is_bf16: 188 | out = out.astype(jnp.bfloat16) 189 | 190 | return out 191 | -------------------------------------------------------------------------------- /src/alphafold/model/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Convenience functions for reading data.""" 16 | 17 | import io 18 | import os 19 | from alphafold.model import utils 20 | import haiku as hk 21 | import numpy as np 22 | # Internal import (7716). 23 | 24 | 25 | def get_model_haiku_params(model_name: str, data_dir: str) -> hk.Params: 26 | """Get the Haiku parameters from a model name.""" 27 | 28 | path = os.path.join(data_dir, 'params', f'params_{model_name}.npz') 29 | 30 | with open(path, 'rb') as f: 31 | params = np.load(io.BytesIO(f.read()), allow_pickle=False) 32 | 33 | return utils.flat_params_to_haiku(params) 34 | -------------------------------------------------------------------------------- /src/alphafold/model/features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Code to generate processed features.""" 16 | import copy 17 | from typing import List, Mapping, Tuple 18 | 19 | from alphafold.model.tf import input_pipeline 20 | from alphafold.model.tf import proteins_dataset 21 | 22 | import ml_collections 23 | import numpy as np 24 | import tensorflow.compat.v1 as tf 25 | 26 | FeatureDict = Mapping[str, np.ndarray] 27 | 28 | 29 | def make_data_config( 30 | config: ml_collections.ConfigDict, 31 | num_res: int, 32 | ) -> Tuple[ml_collections.ConfigDict, List[str]]: 33 | """Makes a data config for the input pipeline.""" 34 | cfg = copy.deepcopy(config.data) 35 | 36 | feature_names = cfg.common.unsupervised_features 37 | if cfg.common.use_templates: 38 | feature_names += cfg.common.template_features 39 | 40 | with cfg.unlocked(): 41 | cfg.eval.crop_size = num_res 42 | 43 | return cfg, feature_names 44 | 45 | 46 | def tf_example_to_features(tf_example: tf.train.Example, 47 | config: ml_collections.ConfigDict, 48 | random_seed: int = 0) -> FeatureDict: 49 | """Converts tf_example to numpy feature dictionary.""" 50 | num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0]) 51 | cfg, feature_names = make_data_config(config, num_res=num_res) 52 | 53 | if 'deletion_matrix_int' in set(tf_example.features.feature): 54 | deletion_matrix_int = ( 55 | tf_example.features.feature['deletion_matrix_int'].int64_list.value) 56 | feat = tf.train.Feature(float_list=tf.train.FloatList( 57 | value=map(float, deletion_matrix_int))) 58 | tf_example.features.feature['deletion_matrix'].CopyFrom(feat) 59 | del tf_example.features.feature['deletion_matrix_int'] 60 | 61 | tf_graph = tf.Graph() 62 | with tf_graph.as_default(), tf.device('/device:CPU:0'): 63 | tf.compat.v1.set_random_seed(random_seed) 64 | tensor_dict = proteins_dataset.create_tensor_dict( 65 | raw_data=tf_example.SerializeToString(), 66 | features=feature_names) 67 | processed_batch = input_pipeline.process_tensors_from_config( 68 | tensor_dict, cfg) 69 | 70 | tf_graph.finalize() 71 | 72 | with tf.Session(graph=tf_graph) as sess: 73 | features = sess.run(processed_batch) 74 | 75 | return {k: v for k, v in features.items() if v.dtype != 'O'} 76 | 77 | 78 | def np_example_to_features(np_example: FeatureDict, 79 | config: ml_collections.ConfigDict, 80 | random_seed: int = 0) -> FeatureDict: 81 | """Preprocesses NumPy feature dict using TF pipeline.""" 82 | np_example = dict(np_example) 83 | num_res = int(np_example['seq_length'][0]) 84 | cfg, feature_names = make_data_config(config, num_res=num_res) 85 | 86 | if 'deletion_matrix_int' in np_example: 87 | np_example['deletion_matrix'] = ( 88 | np_example.pop('deletion_matrix_int').astype(np.float32)) 89 | 90 | tf_graph = tf.Graph() 91 | with tf_graph.as_default(), tf.device('/device:CPU:0'): 92 | tf.compat.v1.set_random_seed(random_seed) 93 | tensor_dict = proteins_dataset.np_to_tensor_dict( 94 | np_example=np_example, features=feature_names) 95 | 96 | processed_batch = input_pipeline.process_tensors_from_config( 97 | tensor_dict, cfg) 98 | 99 | tf_graph.finalize() 100 | 101 | with 
tf.Session(graph=tf_graph) as sess: 102 | features = sess.run(processed_batch) 103 | 104 | return {k: v for k, v in features.items() if v.dtype != 'O'} 105 | -------------------------------------------------------------------------------- /src/alphafold/model/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Geometry Module.""" 15 | 16 | from alphafold.model.geometry import rigid_matrix_vector 17 | from alphafold.model.geometry import rotation_matrix 18 | from alphafold.model.geometry import struct_of_array 19 | from alphafold.model.geometry import vector 20 | 21 | Rot3Array = rotation_matrix.Rot3Array 22 | Rigid3Array = rigid_matrix_vector.Rigid3Array 23 | 24 | StructOfArray = struct_of_array.StructOfArray 25 | 26 | Vec3Array = vector.Vec3Array 27 | square_euclidean_distance = vector.square_euclidean_distance 28 | euclidean_distance = vector.euclidean_distance 29 | dihedral_angle = vector.dihedral_angle 30 | dot = vector.dot 31 | cross = vector.cross 32 | -------------------------------------------------------------------------------- /src/alphafold/model/geometry/rigid_matrix_vector.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Rigid3Array Transformations represented by a Matrix and a Vector.""" 15 | 16 | from __future__ import annotations 17 | from typing import Union 18 | 19 | from alphafold.model.geometry import rotation_matrix 20 | from alphafold.model.geometry import struct_of_array 21 | from alphafold.model.geometry import vector 22 | import jax 23 | import jax.numpy as jnp 24 | 25 | Float = Union[float, jnp.ndarray] 26 | 27 | VERSION = '0.1' 28 | 29 | 30 | @struct_of_array.StructOfArray(same_dtype=True) 31 | class Rigid3Array: 32 | """Rigid Transformation, i.e. 
element of special euclidean group.""" 33 | 34 | rotation: rotation_matrix.Rot3Array 35 | translation: vector.Vec3Array 36 | 37 | def __matmul__(self, other: Rigid3Array) -> Rigid3Array: 38 | new_rotation = self.rotation @ other.rotation 39 | new_translation = self.apply_to_point(other.translation) 40 | return Rigid3Array(new_rotation, new_translation) 41 | 42 | def inverse(self) -> Rigid3Array: 43 | """Return Rigid3Array corresponding to inverse transform.""" 44 | inv_rotation = self.rotation.inverse() 45 | inv_translation = inv_rotation.apply_to_point(-self.translation) 46 | return Rigid3Array(inv_rotation, inv_translation) 47 | 48 | def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: 49 | """Apply Rigid3Array transform to point.""" 50 | return self.rotation.apply_to_point(point) + self.translation 51 | 52 | def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: 53 | """Apply inverse Rigid3Array transform to point.""" 54 | new_point = point - self.translation 55 | return self.rotation.apply_inverse_to_point(new_point) 56 | 57 | def compose_rotation(self, other_rotation): 58 | rot = self.rotation @ other_rotation 59 | trans = jax.tree_map(lambda x: jnp.broadcast_to(x, rot.shape), 60 | self.translation) 61 | return Rigid3Array(rot, trans) 62 | 63 | @classmethod 64 | def identity(cls, shape, dtype=jnp.float32) -> Rigid3Array: 65 | """Return identity Rigid3Array of given shape.""" 66 | return cls( 67 | rotation_matrix.Rot3Array.identity(shape, dtype=dtype), 68 | vector.Vec3Array.zeros(shape, dtype=dtype)) # pytype: disable=wrong-arg-count # trace-all-classes 69 | 70 | def scale_translation(self, factor: Float) -> Rigid3Array: 71 | """Scale translation in Rigid3Array by 'factor'.""" 72 | return Rigid3Array(self.rotation, self.translation * factor) 73 | 74 | def to_array(self): 75 | rot_array = self.rotation.to_array() 76 | vec_array = self.translation.to_array() 77 | return jnp.concatenate([rot_array, vec_array[..., None]], axis=-1) 78 | 79 | @classmethod 80 | def from_array(cls, array): 81 | rot = rotation_matrix.Rot3Array.from_array(array[..., :3]) 82 | vec = vector.Vec3Array.from_array(array[..., -1]) 83 | return cls(rot, vec) # pytype: disable=wrong-arg-count # trace-all-classes 84 | 85 | @classmethod 86 | def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array: 87 | """Construct Rigid3Array from homogeneous 4x4 array.""" 88 | assert array.shape[-1] == 4 89 | assert array.shape[-2] == 4 90 | rotation = rotation_matrix.Rot3Array( 91 | array[..., 0, 0], array[..., 0, 1], array[..., 0, 2], 92 | array[..., 1, 0], array[..., 1, 1], array[..., 1, 2], 93 | array[..., 2, 0], array[..., 2, 1], array[..., 2, 2] 94 | ) 95 | translation = vector.Vec3Array( 96 | array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]) 97 | return cls(rotation, translation) # pytype: disable=wrong-arg-count # trace-all-classes 98 | 99 | def __getstate__(self): 100 | return (VERSION, (self.rotation, self.translation)) 101 | 102 | def __setstate__(self, state): 103 | version, (rot, trans) = state 104 | del version 105 | object.__setattr__(self, 'rotation', rot) 106 | object.__setattr__(self, 'translation', trans) 107 | -------------------------------------------------------------------------------- /src/alphafold/model/geometry/rotation_matrix.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this 
file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Rot3Array Matrix Class.""" 15 | 16 | from __future__ import annotations 17 | import dataclasses 18 | 19 | from alphafold.model.geometry import struct_of_array 20 | from alphafold.model.geometry import utils 21 | from alphafold.model.geometry import vector 22 | import jax 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz'] 27 | 28 | VERSION = '0.1' 29 | 30 | 31 | @struct_of_array.StructOfArray(same_dtype=True) 32 | class Rot3Array: 33 | """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.""" 34 | 35 | xx: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32}) 36 | xy: jnp.ndarray 37 | xz: jnp.ndarray 38 | yx: jnp.ndarray 39 | yy: jnp.ndarray 40 | yz: jnp.ndarray 41 | zx: jnp.ndarray 42 | zy: jnp.ndarray 43 | zz: jnp.ndarray 44 | 45 | __array_ufunc__ = None 46 | 47 | def inverse(self) -> Rot3Array: 48 | """Returns inverse of Rot3Array.""" 49 | return Rot3Array(self.xx, self.yx, self.zx, 50 | self.xy, self.yy, self.zy, 51 | self.xz, self.yz, self.zz) 52 | 53 | def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: 54 | """Applies Rot3Array to point.""" 55 | return vector.Vec3Array( 56 | self.xx * point.x + self.xy * point.y + self.xz * point.z, 57 | self.yx * point.x + self.yy * point.y + self.yz * point.z, 58 | self.zx * point.x + self.zy * point.y + self.zz * point.z) 59 | 60 | def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: 61 | """Applies inverse Rot3Array to point.""" 62 | return self.inverse().apply_to_point(point) 63 | 64 | def __matmul__(self, other: Rot3Array) -> Rot3Array: 65 | """Composes two Rot3Arrays.""" 66 | c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx)) 67 | c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy)) 68 | c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz)) 69 | return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) 70 | 71 | @classmethod 72 | def identity(cls, shape, dtype=jnp.float32) -> Rot3Array: 73 | """Returns identity of given shape.""" 74 | ones = jnp.ones(shape, dtype=dtype) 75 | zeros = jnp.zeros(shape, dtype=dtype) 76 | return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) # pytype: disable=wrong-arg-count # trace-all-classes 77 | 78 | @classmethod 79 | def from_two_vectors(cls, e0: vector.Vec3Array, 80 | e1: vector.Vec3Array) -> Rot3Array: 81 | """Construct Rot3Array from two Vectors. 82 | 83 | Rot3Array is constructed such that in the corresponding frame 'e0' lies on 84 | the positive x-Axis and 'e1' lies in the xy plane with positive sign of y. 85 | 86 | Args: 87 | e0: Vector 88 | e1: Vector 89 | Returns: 90 | Rot3Array 91 | """ 92 | # Normalize the unit vector for the x-axis, e0. 93 | e0 = e0.normalized() 94 | # make e1 perpendicular to e0. 95 | c = e1.dot(e0) 96 | e1 = (e1 - c * e0).normalized() 97 | # Compute e2 as cross product of e0 and e1. 
98 | e2 = e0.cross(e1) 99 | return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) # pytype: disable=wrong-arg-count # trace-all-classes 100 | 101 | @classmethod 102 | def from_array(cls, array: jnp.ndarray) -> Rot3Array: 103 | """Construct Rot3Array Matrix from array of shape. [..., 3, 3].""" 104 | unstacked = utils.unstack(array, axis=-2) 105 | unstacked = sum([utils.unstack(x, axis=-1) for x in unstacked], []) 106 | return cls(*unstacked) 107 | 108 | def to_array(self) -> jnp.ndarray: 109 | """Convert Rot3Array to array of shape [..., 3, 3].""" 110 | return jnp.stack( 111 | [jnp.stack([self.xx, self.xy, self.xz], axis=-1), 112 | jnp.stack([self.yx, self.yy, self.yz], axis=-1), 113 | jnp.stack([self.zx, self.zy, self.zz], axis=-1)], 114 | axis=-2) 115 | 116 | @classmethod 117 | def from_quaternion(cls, 118 | w: jnp.ndarray, 119 | x: jnp.ndarray, 120 | y: jnp.ndarray, 121 | z: jnp.ndarray, 122 | normalize: bool = True, 123 | epsilon: float = 1e-6) -> Rot3Array: 124 | """Construct Rot3Array from components of quaternion.""" 125 | if normalize: 126 | inv_norm = jax.lax.rsqrt(jnp.maximum(epsilon, w**2 + x**2 + y**2 + z**2)) 127 | w *= inv_norm 128 | x *= inv_norm 129 | y *= inv_norm 130 | z *= inv_norm 131 | xx = 1 - 2 * (jnp.square(y) + jnp.square(z)) 132 | xy = 2 * (x * y - w * z) 133 | xz = 2 * (x * z + w * y) 134 | yx = 2 * (x * y + w * z) 135 | yy = 1 - 2 * (jnp.square(x) + jnp.square(z)) 136 | yz = 2 * (y * z - w * x) 137 | zx = 2 * (x * z - w * y) 138 | zy = 2 * (y * z + w * x) 139 | zz = 1 - 2 * (jnp.square(x) + jnp.square(y)) 140 | return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) # pytype: disable=wrong-arg-count # trace-all-classes 141 | 142 | @classmethod 143 | def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array: 144 | """Samples uniform random Rot3Array according to Haar Measure.""" 145 | quat_array = jax.random.normal(key, tuple(shape) + (4,), dtype=dtype) 146 | quats = utils.unstack(quat_array) 147 | return cls.from_quaternion(*quats) 148 | 149 | def __getstate__(self): 150 | return (VERSION, 151 | [np.asarray(getattr(self, field)) for field in COMPONENTS]) 152 | 153 | def __setstate__(self, state): 154 | version, state = state 155 | del version 156 | for i, field in enumerate(COMPONENTS): 157 | object.__setattr__(self, field, state[i]) 158 | -------------------------------------------------------------------------------- /src/alphafold/model/geometry/struct_of_array.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
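# Illustrative use of Rot3Array above (a minimal sketch, not from the
# original source):
#
#   import jax.numpy as jnp
#   from alphafold.model.geometry import rotation_matrix, vector
#
#   e0 = vector.Vec3Array(jnp.array(1.), jnp.array(0.), jnp.array(0.))
#   e1 = vector.Vec3Array(jnp.array(1.), jnp.array(1.), jnp.array(0.))
#   rot = rotation_matrix.Rot3Array.from_two_vectors(e0, e1)
#   # The columns of 'rot' form a right-handed orthonormal frame: e0 maps to
#   # the +x axis of the frame and e1 lies in its xy plane with positive y.
#   coords = rot.apply_inverse_to_point(e1)  # e1 expressed in that frame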
14 | """Class decorator to represent (nested) struct of arrays.""" 15 | 16 | import dataclasses 17 | 18 | import jax 19 | 20 | 21 | def get_item(instance, key): 22 | sliced = {} 23 | for field in get_array_fields(instance): 24 | num_trailing_dims = field.metadata.get('num_trailing_dims', 0) 25 | this_key = key 26 | if isinstance(key, tuple) and Ellipsis in this_key: 27 | this_key += (slice(None),) * num_trailing_dims 28 | sliced[field.name] = getattr(instance, field.name)[this_key] 29 | return dataclasses.replace(instance, **sliced) 30 | 31 | 32 | @property 33 | def get_shape(instance): 34 | """Returns Shape for given instance of dataclass.""" 35 | first_field = dataclasses.fields(instance)[0] 36 | num_trailing_dims = first_field.metadata.get('num_trailing_dims', None) 37 | value = getattr(instance, first_field.name) 38 | if num_trailing_dims: 39 | return value.shape[:-num_trailing_dims] 40 | else: 41 | return value.shape 42 | 43 | 44 | def get_len(instance): 45 | """Returns length for given instance of dataclass.""" 46 | shape = instance.shape 47 | if shape: 48 | return shape[0] 49 | else: 50 | raise TypeError('len() of unsized object') # Match jax.numpy behavior. 51 | 52 | 53 | @property 54 | def get_dtype(instance): 55 | """Returns Dtype for given instance of dataclass.""" 56 | fields = dataclasses.fields(instance) 57 | sets_dtype = [ 58 | field.name for field in fields if field.metadata.get('sets_dtype', False) 59 | ] 60 | if sets_dtype: 61 | assert len(sets_dtype) == 1, 'at most field can set dtype' 62 | field_value = getattr(instance, sets_dtype[0]) 63 | elif instance.same_dtype: 64 | field_value = getattr(instance, fields[0].name) 65 | else: 66 | # Should this be Value Error? 67 | raise AttributeError('Trying to access Dtype on Struct of Array without' 68 | 'either "same_dtype" or field setting dtype') 69 | 70 | if hasattr(field_value, 'dtype'): 71 | return field_value.dtype 72 | else: 73 | # Should this be Value Error? 74 | raise AttributeError(f'field_value {field_value} does not have dtype') 75 | 76 | 77 | def replace(instance, **kwargs): 78 | return dataclasses.replace(instance, **kwargs) 79 | 80 | 81 | def post_init(instance): 82 | """Validate instance has same shapes & dtypes.""" 83 | array_fields = get_array_fields(instance) 84 | arrays = list(get_array_fields(instance, return_values=True).values()) 85 | first_field = array_fields[0] 86 | # These slightly weird constructions about checking whether the leaves are 87 | # actual arrays is since e.g. vmap internally relies on being able to 88 | # construct pytree's with object() as leaves, this would break the checking 89 | # as such we are only validating the object when the entries in the dataclass 90 | # Are arrays or other dataclasses of arrays. 
91 | try: 92 | dtype = instance.dtype 93 | except AttributeError: 94 | dtype = None 95 | if dtype is not None: 96 | first_shape = instance.shape 97 | for array, field in zip(arrays, array_fields): 98 | field_shape = array.shape 99 | num_trailing_dims = field.metadata.get('num_trailing_dims', None) 100 | if num_trailing_dims: 101 | array_shape = array.shape 102 | field_shape = array_shape[:-num_trailing_dims] 103 | msg = (f'field {field} should have number of trailing dims ' 104 | f'{num_trailing_dims}') 105 | assert len(array_shape) == len(first_shape) + num_trailing_dims, msg 106 | else: 107 | field_shape = array.shape 108 | 109 | shape_msg = (f"Stripped Shape {field_shape} of field {field} doesn't " 110 | f"match shape {first_shape} of field {first_field}") 111 | assert field_shape == first_shape, shape_msg 112 | 113 | field_dtype = array.dtype 114 | 115 | allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', []) 116 | if allowed_metadata_dtypes: 117 | msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}' 118 | assert field_dtype in allowed_metadata_dtypes, msg 119 | 120 | if 'dtype' in field.metadata: 121 | target_dtype = field.metadata['dtype'] 122 | else: 123 | target_dtype = dtype 124 | 125 | msg = f'Dtype is {field_dtype} but must be {target_dtype}' 126 | assert field_dtype == target_dtype, msg 127 | 128 | 129 | def flatten(instance): 130 | """Flatten Struct of Array instance.""" 131 | array_likes = list(get_array_fields(instance, return_values=True).values()) 132 | flat_array_likes = [] 133 | inner_treedefs = [] 134 | num_arrays = [] 135 | for array_like in array_likes: 136 | flat_array_like, inner_treedef = jax.tree_util.tree_flatten(array_like) 137 | inner_treedefs.append(inner_treedef) 138 | flat_array_likes += flat_array_like 139 | num_arrays.append(len(flat_array_like)) 140 | metadata = get_metadata_fields(instance, return_values=True) 141 | metadata = type(instance).metadata_cls(**metadata) 142 | return flat_array_likes, (inner_treedefs, metadata, num_arrays) 143 | 144 | 145 | def make_metadata_class(cls): 146 | metadata_fields = get_fields(cls, 147 | lambda x: x.metadata.get('is_metadata', False)) 148 | metadata_cls = dataclasses.make_dataclass( 149 | cls_name='Meta' + cls.__name__, 150 | fields=[(field.name, field.type, field) for field in metadata_fields], 151 | frozen=True, 152 | eq=True) 153 | return metadata_cls 154 | 155 | 156 | def get_fields(cls_or_instance, filterfn, return_values=False): 157 | fields = dataclasses.fields(cls_or_instance) 158 | fields = [field for field in fields if filterfn(field)] 159 | if return_values: 160 | return { 161 | field.name: getattr(cls_or_instance, field.name) for field in fields 162 | } 163 | else: 164 | return fields 165 | 166 | 167 | def get_array_fields(cls, return_values=False): 168 | return get_fields( 169 | cls, 170 | lambda x: not x.metadata.get('is_metadata', False), 171 | return_values=return_values) 172 | 173 | 174 | def get_metadata_fields(cls, return_values=False): 175 | return get_fields( 176 | cls, 177 | lambda x: x.metadata.get('is_metadata', False), 178 | return_values=return_values) 179 | 180 | 181 | class StructOfArray: 182 | """Class Decorator for Struct Of Arrays.""" 183 | 184 | def __init__(self, same_dtype=True): 185 | self.same_dtype = same_dtype 186 | 187 | def __call__(self, cls): 188 | cls.__array_ufunc__ = None 189 | cls.replace = replace 190 | cls.same_dtype = self.same_dtype 191 | cls.dtype = get_dtype 192 | cls.shape = get_shape 193 | cls.__len__ = get_len 194 | cls.__getitem__
= get_item 195 | cls.__post_init__ = post_init 196 | new_cls = dataclasses.dataclass(cls, frozen=True, eq=False) # pytype: disable=wrong-keyword-args 197 | # Pytree claims to require metadata to be hashable (not sure why), so we 198 | # make a derived dataclass that can just hold the metadata. 199 | new_cls.metadata_cls = make_metadata_class(new_cls) 200 | 201 | def unflatten(aux, data): 202 | inner_treedefs, metadata, num_arrays = aux 203 | array_fields = [field.name for field in get_array_fields(new_cls)] 204 | value_dict = {} 205 | array_start = 0 206 | for num_array, inner_treedef, array_field in zip(num_arrays, 207 | inner_treedefs, 208 | array_fields): 209 | value_dict[array_field] = jax.tree_util.tree_unflatten( 210 | inner_treedef, data[array_start:array_start + num_array]) 211 | array_start += num_array 212 | metadata_fields = get_metadata_fields(new_cls) 213 | for field in metadata_fields: 214 | value_dict[field.name] = getattr(metadata, field.name) 215 | 216 | return new_cls(**value_dict) 217 | 218 | jax.tree_util.register_pytree_node( 219 | nodetype=new_cls, flatten_func=flatten, unflatten_func=unflatten) 220 | return new_cls 221 | -------------------------------------------------------------------------------- /src/alphafold/model/geometry/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
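# Illustrative use of the StructOfArray decorator above (a minimal sketch,
# not from the original source; 'Point2' is a hypothetical example class):
#
#   import jax
#   import jax.numpy as jnp
#   from alphafold.model.geometry import struct_of_array
#
#   @struct_of_array.StructOfArray(same_dtype=True)
#   class Point2:
#     x: jnp.ndarray
#     y: jnp.ndarray
#
#   p = Point2(jnp.zeros(4), jnp.ones(4))
#   p.shape, p.dtype, len(p)                               # (4,), float32, 4
#   doubled = jax.tree_util.tree_map(lambda a: 2. * a, p)  # works as a pytree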
14 | """Shared utils for tests.""" 15 | 16 | import dataclasses 17 | 18 | from alphafold.model.geometry import rigid_matrix_vector 19 | from alphafold.model.geometry import rotation_matrix 20 | from alphafold.model.geometry import vector 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | 25 | def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array, 26 | matrix2: rotation_matrix.Rot3Array): 27 | for field in dataclasses.fields(rotation_matrix.Rot3Array): 28 | field = field.name 29 | np.testing.assert_array_equal( 30 | getattr(matrix1, field), getattr(matrix2, field)) 31 | 32 | 33 | def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array, 34 | mat2: rotation_matrix.Rot3Array): 35 | np.testing.assert_array_almost_equal(mat1.to_array(), mat2.to_array(), 6) 36 | 37 | 38 | def assert_array_equal_to_rotation_matrix(array: jnp.ndarray, 39 | matrix: rotation_matrix.Rot3Array): 40 | """Check that array and Matrix match.""" 41 | np.testing.assert_array_equal(matrix.xx, array[..., 0, 0]) 42 | np.testing.assert_array_equal(matrix.xy, array[..., 0, 1]) 43 | np.testing.assert_array_equal(matrix.xz, array[..., 0, 2]) 44 | np.testing.assert_array_equal(matrix.yx, array[..., 1, 0]) 45 | np.testing.assert_array_equal(matrix.yy, array[..., 1, 1]) 46 | np.testing.assert_array_equal(matrix.yz, array[..., 1, 2]) 47 | np.testing.assert_array_equal(matrix.zx, array[..., 2, 0]) 48 | np.testing.assert_array_equal(matrix.zy, array[..., 2, 1]) 49 | np.testing.assert_array_equal(matrix.zz, array[..., 2, 2]) 50 | 51 | 52 | def assert_array_close_to_rotation_matrix(array: jnp.ndarray, 53 | matrix: rotation_matrix.Rot3Array): 54 | np.testing.assert_array_almost_equal(matrix.to_array(), array, 6) 55 | 56 | 57 | def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array): 58 | np.testing.assert_array_equal(vec1.x, vec2.x) 59 | np.testing.assert_array_equal(vec1.y, vec2.y) 60 | np.testing.assert_array_equal(vec1.z, vec2.z) 61 | 62 | 63 | def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array): 64 | np.testing.assert_allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.) 65 | np.testing.assert_allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.) 66 | np.testing.assert_allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.) 67 | 68 | 69 | def assert_array_close_to_vector(array: jnp.ndarray, vec: vector.Vec3Array): 70 | np.testing.assert_allclose(vec.to_array(), array, atol=1e-6, rtol=0.) 
71 | 72 | 73 | def assert_array_equal_to_vector(array: jnp.ndarray, vec: vector.Vec3Array): 74 | np.testing.assert_array_equal(vec.to_array(), array) 75 | 76 | 77 | def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, 78 | rigid2: rigid_matrix_vector.Rigid3Array): 79 | assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2) 80 | 81 | 82 | def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, 83 | rigid2: rigid_matrix_vector.Rigid3Array): 84 | assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2) 85 | 86 | 87 | def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array, 88 | trans: vector.Vec3Array, 89 | rigid: rigid_matrix_vector.Rigid3Array): 90 | assert_rotation_matrix_equal(rot, rigid.rotation) 91 | assert_vectors_equal(trans, rigid.translation) 92 | 93 | 94 | def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array, 95 | trans: vector.Vec3Array, 96 | rigid: rigid_matrix_vector.Rigid3Array): 97 | assert_rotation_matrix_close(rot, rigid.rotation) 98 | assert_vectors_close(trans, rigid.translation) 99 | -------------------------------------------------------------------------------- /src/alphafold/model/geometry/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utils for geometry library.""" 15 | 16 | from typing import List 17 | 18 | import jax.numpy as jnp 19 | 20 | 21 | def unstack(value: jnp.ndarray, axis: int = -1) -> List[jnp.ndarray]: 22 | return [jnp.squeeze(v, axis=axis) 23 | for v in jnp.split(value, value.shape[axis], axis=axis)] 24 | -------------------------------------------------------------------------------- /src/alphafold/model/geometry/vector.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
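# Illustrative behaviour of utils.unstack above (a minimal sketch, not from
# the original source):
#
#   import jax.numpy as jnp
#   from alphafold.model.geometry import utils
#
#   cols = utils.unstack(jnp.ones((5, 3)))          # 3 arrays of shape (5,)
#   rows = utils.unstack(jnp.ones((5, 3)), axis=0)  # 5 arrays of shape (3,)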
14 | """Vec3Array Class.""" 15 | 16 | from __future__ import annotations 17 | import dataclasses 18 | from typing import Union 19 | 20 | from alphafold.model.geometry import struct_of_array 21 | from alphafold.model.geometry import utils 22 | import jax 23 | import jax.numpy as jnp 24 | import numpy as np 25 | 26 | Float = Union[float, jnp.ndarray] 27 | 28 | VERSION = '0.1' 29 | 30 | 31 | @struct_of_array.StructOfArray(same_dtype=True) 32 | class Vec3Array: 33 | """Vec3Array in 3 dimensional Space implemented as struct of arrays. 34 | 35 | This is done in order to improve performance and precision. 36 | On TPU small matrix multiplications are very suboptimal and will waste large 37 | compute ressources, furthermore any matrix multiplication on tpu happen in 38 | mixed bfloat16/float32 precision, which is often undesirable when handling 39 | physical coordinates. 40 | In most cases this will also be faster on cpu's/gpu's since it allows for 41 | easier use of vector instructions. 42 | """ 43 | 44 | x: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32}) 45 | y: jnp.ndarray 46 | z: jnp.ndarray 47 | 48 | def __post_init__(self): 49 | if hasattr(self.x, 'dtype'): 50 | assert self.x.dtype == self.y.dtype 51 | assert self.x.dtype == self.z.dtype 52 | assert all([x == y for x, y in zip(self.x.shape, self.y.shape)]) 53 | assert all([x == z for x, z in zip(self.x.shape, self.z.shape)]) 54 | 55 | def __add__(self, other: Vec3Array) -> Vec3Array: 56 | return jax.tree_map(lambda x, y: x + y, self, other) 57 | 58 | def __sub__(self, other: Vec3Array) -> Vec3Array: 59 | return jax.tree_map(lambda x, y: x - y, self, other) 60 | 61 | def __mul__(self, other: Float) -> Vec3Array: 62 | return jax.tree_map(lambda x: x * other, self) 63 | 64 | def __rmul__(self, other: Float) -> Vec3Array: 65 | return self * other 66 | 67 | def __truediv__(self, other: Float) -> Vec3Array: 68 | return jax.tree_map(lambda x: x / other, self) 69 | 70 | def __neg__(self) -> Vec3Array: 71 | return jax.tree_map(lambda x: -x, self) 72 | 73 | def __pos__(self) -> Vec3Array: 74 | return jax.tree_map(lambda x: x, self) 75 | 76 | def cross(self, other: Vec3Array) -> Vec3Array: 77 | """Compute cross product between 'self' and 'other'.""" 78 | new_x = self.y * other.z - self.z * other.y 79 | new_y = self.z * other.x - self.x * other.z 80 | new_z = self.x * other.y - self.y * other.x 81 | return Vec3Array(new_x, new_y, new_z) 82 | 83 | def dot(self, other: Vec3Array) -> Float: 84 | """Compute dot product between 'self' and 'other'.""" 85 | return self.x * other.x + self.y * other.y + self.z * other.z 86 | 87 | def norm(self, epsilon: float = 1e-6) -> Float: 88 | """Compute Norm of Vec3Array, clipped to epsilon.""" 89 | # To avoid NaN on the backward pass, we must use maximum before the sqrt 90 | norm2 = self.dot(self) 91 | if epsilon: 92 | norm2 = jnp.maximum(norm2, epsilon**2) 93 | return jnp.sqrt(norm2) 94 | 95 | def norm2(self): 96 | return self.dot(self) 97 | 98 | def normalized(self, epsilon: float = 1e-6) -> Vec3Array: 99 | """Return unit vector with optional clipping.""" 100 | return self / self.norm(epsilon) 101 | 102 | @classmethod 103 | def zeros(cls, shape, dtype=jnp.float32): 104 | """Return Vec3Array corresponding to zeros of given shape.""" 105 | return cls( 106 | jnp.zeros(shape, dtype), jnp.zeros(shape, dtype), 107 | jnp.zeros(shape, dtype)) # pytype: disable=wrong-arg-count # trace-all-classes 108 | 109 | def to_array(self) -> jnp.ndarray: 110 | return jnp.stack([self.x, self.y, self.z], axis=-1) 111 | 112 | 
@classmethod 113 | def from_array(cls, array): 114 | return cls(*utils.unstack(array)) 115 | 116 | def __getstate__(self): 117 | return (VERSION, 118 | [np.asarray(self.x), 119 | np.asarray(self.y), 120 | np.asarray(self.z)]) 121 | 122 | def __setstate__(self, state): 123 | version, state = state 124 | del version 125 | for i, letter in enumerate('xyz'): 126 | object.__setattr__(self, letter, state[i]) 127 | 128 | 129 | def square_euclidean_distance(vec1: Vec3Array, 130 | vec2: Vec3Array, 131 | epsilon: float = 1e-6) -> Float: 132 | """Computes square of euclidean distance between 'vec1' and 'vec2'. 133 | 134 | Args: 135 | vec1: Vec3Array to compute distance to 136 | vec2: Vec3Array to compute distance from, should be 137 | broadcast compatible with 'vec1' 138 | epsilon: distance is clipped from below to be at least epsilon 139 | 140 | Returns: 141 | Array of square euclidean distances; 142 | shape will be result of broadcasting 'vec1' and 'vec2' 143 | """ 144 | difference = vec1 - vec2 145 | distance = difference.dot(difference) 146 | if epsilon: 147 | distance = jnp.maximum(distance, epsilon) 148 | return distance 149 | 150 | 151 | def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float: 152 | return vector1.dot(vector2) 153 | 154 | 155 | def cross(vector1: Vec3Array, vector2: Vec3Array) -> Float: 156 | return vector1.cross(vector2) 157 | 158 | 159 | def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float: 160 | return vector.norm(epsilon) 161 | 162 | 163 | def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array: 164 | return vector.normalized(epsilon) 165 | 166 | 167 | def euclidean_distance(vec1: Vec3Array, 168 | vec2: Vec3Array, 169 | epsilon: float = 1e-6) -> Float: 170 | """Computes euclidean distance between 'vec1' and 'vec2'. 171 | 172 | Args: 173 | vec1: Vec3Array to compute euclidean distance to 174 | vec2: Vec3Array to compute euclidean distance from, should be 175 | broadcast compatible with 'vec1' 176 | epsilon: distance is clipped from below to be at least epsilon 177 | 178 | Returns: 179 | Array of euclidean distances; 180 | shape will be result of broadcasting 'vec1' and 'vec2' 181 | """ 182 | distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2) 183 | distance = jnp.sqrt(distance_sq) 184 | return distance 185 | 186 | 187 | def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array, 188 | d: Vec3Array) -> Float: 189 | """Computes torsion angle for a quadruple of points. 190 | 191 | For points (a, b, c, d), this is the angle between the planes defined by 192 | points (a, b, c) and (b, c, d). It is also known as the dihedral angle. 193 | 194 | Arguments: 195 | a: A Vec3Array of coordinates. 196 | b: A Vec3Array of coordinates. 197 | c: A Vec3Array of coordinates. 198 | d: A Vec3Array of coordinates. 199 | 200 | Returns: 201 | A tensor of angles in radians: [-pi, pi]. 
202 | """ 203 | v1 = a - b 204 | v2 = b - c 205 | v3 = d - c 206 | 207 | c1 = v1.cross(v2) 208 | c2 = v3.cross(v2) 209 | c3 = c2.cross(c1) 210 | 211 | v2_mag = v2.norm() 212 | return jnp.arctan2(c3.dot(v2), v2_mag * c1.dot(c2)) 213 | 214 | 215 | def random_gaussian_vector(shape, key, dtype=jnp.float32): 216 | vec_array = jax.random.normal(key, shape + (3,), dtype) 217 | return Vec3Array.from_array(vec_array) 218 | -------------------------------------------------------------------------------- /src/alphafold/model/lddt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """lDDT protein distance score.""" 16 | import jax.numpy as jnp 17 | 18 | 19 | def lddt(predicted_points, 20 | true_points, 21 | true_points_mask, 22 | cutoff=15., 23 | per_residue=False): 24 | """Measure (approximate) lDDT for a batch of coordinates. 25 | 26 | lDDT reference: 27 | Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local 28 | superposition-free score for comparing protein structures and models using 29 | distance difference tests. Bioinformatics 29, 2722–2728 (2013). 30 | 31 | lDDT is a measure of the difference between the true distance matrix and the 32 | distance matrix of the predicted points. The difference is computed only on 33 | points closer than cutoff *in the true structure*. 34 | 35 | This function does not compute the exact lDDT value that the original paper 36 | describes because it does not include terms for physical feasibility 37 | (e.g. bond length violations). Therefore this is only an approximate 38 | lDDT score. 39 | 40 | Args: 41 | predicted_points: (batch, length, 3) array of predicted 3D points 42 | true_points: (batch, length, 3) array of true 3D points 43 | true_points_mask: (batch, length, 1) binary-valued float array. This mask 44 | should be 1 for points that exist in the true points. 45 | cutoff: Maximum distance for a pair of points to be included 46 | per_residue: If true, return score for each residue. Note that the overall 47 | lDDT is not exactly the mean of the per_residue lDDT's because some 48 | residues have more contacts than others. 49 | 50 | Returns: 51 | An (approximate, see above) lDDT score in the range 0-1. 52 | """ 53 | 54 | assert len(predicted_points.shape) == 3 55 | assert predicted_points.shape[-1] == 3 56 | assert true_points_mask.shape[-1] == 1 57 | assert len(true_points_mask.shape) == 3 58 | 59 | # Compute true and predicted distance matrices. 60 | dmat_true = jnp.sqrt(1e-10 + jnp.sum( 61 | (true_points[:, :, None] - true_points[:, None, :])**2, axis=-1)) 62 | 63 | dmat_predicted = jnp.sqrt(1e-10 + jnp.sum( 64 | (predicted_points[:, :, None] - 65 | predicted_points[:, None, :])**2, axis=-1)) 66 | 67 | dists_to_score = ( 68 | (dmat_true < cutoff).astype(jnp.float32) * true_points_mask * 69 | jnp.transpose(true_points_mask, [0, 2, 1]) * 70 | (1. 
- jnp.eye(dmat_true.shape[1])) # Exclude self-interaction. 71 | ) 72 | 73 | # Shift unscored distances to be far away. 74 | dist_l1 = jnp.abs(dmat_true - dmat_predicted) 75 | 76 | # True lDDT uses a number of fixed bins. 77 | # We ignore the physical plausibility correction to lDDT, though. 78 | score = 0.25 * ((dist_l1 < 0.5).astype(jnp.float32) + 79 | (dist_l1 < 1.0).astype(jnp.float32) + 80 | (dist_l1 < 2.0).astype(jnp.float32) + 81 | (dist_l1 < 4.0).astype(jnp.float32)) 82 | 83 | # Normalize over the appropriate axes. 84 | reduce_axes = (-1,) if per_residue else (-2, -1) 85 | norm = 1. / (1e-10 + jnp.sum(dists_to_score, axis=reduce_axes)) 86 | score = norm * (1e-10 + jnp.sum(dists_to_score * score, axis=reduce_axes)) 87 | 88 | return score 89 | -------------------------------------------------------------------------------- /src/alphafold/model/lddt_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for lddt.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from alphafold.model import lddt 20 | import numpy as np 21 | 22 | 23 | class LddtTest(parameterized.TestCase, absltest.TestCase): 24 | 25 | @parameterized.named_parameters( 26 | ('same', 27 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 28 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 29 | [1, 1, 1]), 30 | ('all_shifted', 31 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 32 | [[-1, 0, 0], [4, 0, 0], [9, 0, 0]], 33 | [1, 1, 1]), 34 | ('all_rotated', 35 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]], 36 | [[0, 0, 0], [0, 5, 0], [0, 10, 0]], 37 | [1, 1, 1]), 38 | ('half_a_dist', 39 | [[0, 0, 0], [5, 0, 0]], 40 | [[0, 0, 0], [5.5-1e-5, 0, 0]], 41 | [1, 1]), 42 | ('one_a_dist', 43 | [[0, 0, 0], [5, 0, 0]], 44 | [[0, 0, 0], [6-1e-5, 0, 0]], 45 | [0.75, 0.75]), 46 | ('two_a_dist', 47 | [[0, 0, 0], [5, 0, 0]], 48 | [[0, 0, 0], [7-1e-5, 0, 0]], 49 | [0.5, 0.5]), 50 | ('four_a_dist', 51 | [[0, 0, 0], [5, 0, 0]], 52 | [[0, 0, 0], [9-1e-5, 0, 0]], 53 | [0.25, 0.25],), 54 | ('five_a_dist', 55 | [[0, 0, 0], [16-1e-5, 0, 0]], 56 | [[0, 0, 0], [11, 0, 0]], 57 | [0, 0]), 58 | ('no_pairs', 59 | [[0, 0, 0], [20, 0, 0]], 60 | [[0, 0, 0], [25-1e-5, 0, 0]], 61 | [1, 1]), 62 | ) 63 | def test_lddt( 64 | self, predicted_pos, true_pos, exp_lddt): 65 | predicted_pos = np.array([predicted_pos], dtype=np.float32) 66 | true_points_mask = np.array([[[1]] * len(true_pos)], dtype=np.float32) 67 | true_pos = np.array([true_pos], dtype=np.float32) 68 | cutoff = 15.0 69 | per_residue = True 70 | 71 | result = lddt.lddt( 72 | predicted_pos, true_pos, true_points_mask, cutoff, 73 | per_residue) 74 | 75 | np.testing.assert_almost_equal(result, [exp_lddt], decimal=4) 76 | 77 | 78 | if __name__ == '__main__': 79 | absltest.main() 80 | -------------------------------------------------------------------------------- /src/alphafold/model/model.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Code for constructing the model.""" 16 | from typing import Any, Mapping, Optional, Union 17 | 18 | from absl import logging 19 | from alphafold.common import confidence 20 | from alphafold.model import features 21 | from alphafold.model import modules 22 | from alphafold.model import modules_multimer 23 | import haiku as hk 24 | import jax 25 | import ml_collections 26 | import numpy as np 27 | import tensorflow.compat.v1 as tf 28 | import pdb 29 | import tree 30 | 31 | 32 | def get_confidence_metrics( 33 | prediction_result: Mapping[str, Any], 34 | multimer_mode: bool) -> Mapping[str, Any]: 35 | """Post processes prediction_result to get confidence metrics.""" 36 | confidence_metrics = {} 37 | confidence_metrics['plddt'] = confidence.compute_plddt( 38 | prediction_result['predicted_lddt']['logits']) 39 | if 'predicted_aligned_error' in prediction_result: 40 | confidence_metrics.update(confidence.compute_predicted_aligned_error( 41 | logits=prediction_result['predicted_aligned_error']['logits'], 42 | breaks=prediction_result['predicted_aligned_error']['breaks'])) 43 | confidence_metrics['ptm'] = confidence.predicted_tm_score( 44 | logits=prediction_result['predicted_aligned_error']['logits'], 45 | breaks=prediction_result['predicted_aligned_error']['breaks'], 46 | asym_id=None) 47 | if multimer_mode: 48 | # Compute the ipTM only for the multimer model. 49 | confidence_metrics['iptm'] = confidence.predicted_tm_score( 50 | logits=prediction_result['predicted_aligned_error']['logits'], 51 | breaks=prediction_result['predicted_aligned_error']['breaks'], 52 | asym_id=prediction_result['predicted_aligned_error']['asym_id'], 53 | interface=True) 54 | confidence_metrics['ranking_confidence'] = ( 55 | 0.8 * confidence_metrics['iptm'] + 0.2 * confidence_metrics['ptm']) 56 | 57 | if not multimer_mode: 58 | # Monomer models use mean pLDDT for model ranking. 
59 | confidence_metrics['ranking_confidence'] = np.mean( 60 | confidence_metrics['plddt']) 61 | 62 | return confidence_metrics 63 | 64 | 65 | class RunModel: 66 | """Container for JAX model.""" 67 | 68 | def __init__(self, 69 | config: ml_collections.ConfigDict, 70 | params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None, 71 | is_training=False): 72 | self.config = config 73 | self.params = params 74 | self.multimer_mode = config.model.global_config.multimer_mode 75 | self.is_training = is_training 76 | 77 | if self.multimer_mode: 78 | def _forward_fn(batch): 79 | model = modules_multimer.AlphaFold(self.config.model) 80 | return model( 81 | batch, 82 | is_training=self.is_training) 83 | else: 84 | def _forward_fn(batch): 85 | model = modules.AlphaFold(self.config.model) 86 | return model( 87 | batch, 88 | is_training=self.is_training, 89 | compute_loss=False, 90 | ensemble_representations=True) 91 | 92 | self.apply = jax.jit(hk.transform(_forward_fn).apply) 93 | self.init = jax.jit(hk.transform(_forward_fn).init) 94 | 95 | def init_params(self, feat: features.FeatureDict, random_seed: int = 0): 96 | """Initializes the model parameters. 97 | 98 | If none were provided when this class was instantiated then the parameters 99 | are randomly initialized. 100 | 101 | Args: 102 | feat: A dictionary of NumPy feature arrays as output by 103 | RunModel.process_features. 104 | random_seed: A random seed to use to initialize the parameters if none 105 | were set when this class was initialized. 106 | """ 107 | if not self.params: 108 | # Init params randomly. 109 | rng = jax.random.PRNGKey(random_seed) 110 | self.params = hk.data_structures.to_mutable_dict( 111 | self.init(rng, feat)) 112 | logging.warning('Initialized parameters randomly') 113 | 114 | def process_features( 115 | self, 116 | raw_features: Union[tf.train.Example, features.FeatureDict], 117 | random_seed: int) -> features.FeatureDict: 118 | """Processes features to prepare for feeding them into the model. 119 | 120 | Args: 121 | raw_features: The output of the data pipeline either as a dict of NumPy 122 | arrays or as a tf.train.Example. 123 | random_seed: The random seed to use when processing the features. 124 | 125 | Returns: 126 | A dict of NumPy feature arrays suitable for feeding into the model. 127 | """ 128 | 129 | if self.multimer_mode: 130 | return raw_features 131 | 132 | # Single-chain mode. 133 | if isinstance(raw_features, dict): 134 | return features.np_example_to_features( 135 | np_example=raw_features, 136 | config=self.config, 137 | random_seed=random_seed) 138 | else: 139 | return features.tf_example_to_features( 140 | tf_example=raw_features, 141 | config=self.config, 142 | random_seed=random_seed) 143 | 144 | def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct: 145 | self.init_params(feat) 146 | logging.info('Running eval_shape with shape(feat) = %s', 147 | tree.map_structure(lambda x: x.shape, feat)) 148 | shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0), feat) 149 | logging.info('Output shape was %s', shape) 150 | return shape 151 | 152 | def predict(self, 153 | msa_params, 154 | feat: features.FeatureDict 155 | ) -> Mapping[str, Any]: 156 | """Makes a prediction by inferencing the model on the provided features. 157 | 158 | Args: 159 | feat: A dictionary of NumPy feature arrays as output by 160 | RunModel.process_features. 161 | random_seed: The random seed to use when running the model. In the 162 | multimer model this controls the MSA sampling. 
163 | 164 | Returns: 165 | A dictionary of model outputs. 166 | """ 167 | 168 | #Add the msa params 169 | feat['msa_feat_bias'] = msa_params 170 | result = self.apply(self.params, jax.random.PRNGKey(0), feat) 171 | 172 | # This block is to ensure benchmark timings are accurate. Some blocking is 173 | # already happening when computing get_confidence_metrics, and this ensures 174 | # all outputs are blocked on. 175 | #jax.tree_map(lambda x: x.block_until_ready(), result) 176 | #result.update( 177 | # get_confidence_metrics(result, multimer_mode=self.multimer_mode)) 178 | return result 179 | -------------------------------------------------------------------------------- /src/alphafold/model/prng.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A collection of utilities surrounding PRNG usage in protein folding.""" 16 | 17 | import haiku as hk 18 | import jax 19 | 20 | 21 | def safe_dropout(*, tensor, safe_key, rate, is_deterministic, is_training): 22 | if is_training and rate != 0.0 and not is_deterministic: 23 | return hk.dropout(safe_key.get(), rate, tensor) 24 | else: 25 | return tensor 26 | 27 | 28 | class SafeKey: 29 | """Safety wrapper for PRNG keys.""" 30 | 31 | def __init__(self, key): 32 | self._key = key 33 | self._used = False 34 | 35 | def _assert_not_used(self): 36 | if self._used: 37 | raise RuntimeError('Random key has been used previously.') 38 | 39 | def get(self): 40 | self._assert_not_used() 41 | self._used = True 42 | return self._key 43 | 44 | def split(self, num_keys=2): 45 | self._assert_not_used() 46 | self._used = True 47 | new_keys = jax.random.split(self._key, num_keys) 48 | return jax.tree_map(SafeKey, tuple(new_keys)) 49 | 50 | def duplicate(self, num_keys=2): 51 | self._assert_not_used() 52 | self._used = True 53 | return tuple(SafeKey(self._key) for _ in range(num_keys)) 54 | 55 | 56 | def _safe_key_flatten(safe_key): 57 | # Flatten transfers "ownership" to the tree 58 | return (safe_key._key,), safe_key._used # pylint: disable=protected-access 59 | 60 | 61 | def _safe_key_unflatten(aux_data, children): 62 | ret = SafeKey(children[0]) 63 | ret._used = aux_data # pylint: disable=protected-access 64 | return ret 65 | 66 | 67 | jax.tree_util.register_pytree_node( 68 | SafeKey, _safe_key_flatten, _safe_key_unflatten) 69 | 70 | -------------------------------------------------------------------------------- /src/alphafold/model/prng_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for prng.""" 16 | 17 | from absl.testing import absltest 18 | from alphafold.model import prng 19 | import jax 20 | 21 | 22 | class PrngTest(absltest.TestCase): 23 | 24 | def test_key_reuse(self): 25 | 26 | init_key = jax.random.PRNGKey(42) 27 | safe_key = prng.SafeKey(init_key) 28 | _, safe_key = safe_key.split() 29 | 30 | raw_key = safe_key.get() 31 | 32 | self.assertNotEqual(raw_key[0], init_key[0]) 33 | self.assertNotEqual(raw_key[1], init_key[1]) 34 | 35 | with self.assertRaises(RuntimeError): 36 | safe_key.get() 37 | 38 | with self.assertRaises(RuntimeError): 39 | safe_key.split() 40 | 41 | with self.assertRaises(RuntimeError): 42 | safe_key.duplicate() 43 | 44 | 45 | if __name__ == '__main__': 46 | absltest.main() 47 | -------------------------------------------------------------------------------- /src/alphafold/model/quat_affine_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for quat_affine.""" 16 | 17 | from absl import logging 18 | from absl.testing import absltest 19 | from alphafold.model import quat_affine 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | VERBOSE = False 25 | np.set_printoptions(precision=3, suppress=True) 26 | 27 | r2t = quat_affine.rot_list_to_tensor 28 | v2t = quat_affine.vec_list_to_tensor 29 | 30 | q2r = lambda q: r2t(quat_affine.quat_to_rot(q)) 31 | 32 | 33 | class QuatAffineTest(absltest.TestCase): 34 | 35 | def _assert_check(self, to_check, tol=1e-5): 36 | for k, (correct, generated) in to_check.items(): 37 | if VERBOSE: 38 | logging.info(k) 39 | logging.info('Correct %s', correct) 40 | logging.info('Predicted %s', generated) 41 | self.assertLess(np.max(np.abs(correct - generated)), tol) 42 | 43 | def test_conversion(self): 44 | quat = jnp.array([-2., 5., -1., 4.]) 45 | 46 | rotation = jnp.array([ 47 | [0.26087, 0.130435, 0.956522], 48 | [-0.565217, -0.782609, 0.26087], 49 | [0.782609, -0.608696, -0.130435]]) 50 | 51 | translation = jnp.array([1., -3., 4.]) 52 | point = jnp.array([0.7, 3.2, -2.9]) 53 | 54 | a = quat_affine.QuatAffine(quat, translation, unstack_inputs=True) 55 | true_new_point = jnp.matmul(rotation, point[:, None])[:, 0] + translation 56 | 57 | self._assert_check({ 58 | 'rot': (rotation, r2t(a.rotation)), 59 | 'trans': (translation, v2t(a.translation)), 60 | 'point': (true_new_point, 61 | v2t(a.apply_to_point(jnp.moveaxis(point, -1, 0)))), 62 | # Because of the double cover, we must be careful and compare rotations 63 | 'quat': (q2r(a.quaternion), 64 | q2r(quat_affine.rot_to_quat(a.rotation))), 65 | 66 | }) 67 | 68 | def test_double_cover(self): 69 | """Test that -q is the same rotation as q.""" 70 | rng = jax.random.PRNGKey(42) 71 | keys = jax.random.split(rng) 72 | q = jax.random.normal(keys[0], (2, 4)) 73 | trans = jax.random.normal(keys[1], (2, 3)) 74 | a1 = quat_affine.QuatAffine(q, trans, unstack_inputs=True) 75 | a2 = quat_affine.QuatAffine(-q, trans, unstack_inputs=True) 76 | 77 | self._assert_check({ 78 | 'rot': (r2t(a1.rotation), 79 | r2t(a2.rotation)), 80 | 'trans': (v2t(a1.translation), 81 | v2t(a2.translation)), 82 | }) 83 | 84 | def test_homomorphism(self): 85 | rng = jax.random.PRNGKey(42) 86 | keys = jax.random.split(rng, 4) 87 | vec_q1 = jax.random.normal(keys[0], (2, 3)) 88 | 89 | q1 = jnp.concatenate([ 90 | jnp.ones_like(vec_q1)[:, :1], 91 | vec_q1], axis=-1) 92 | 93 | q2 = jax.random.normal(keys[1], (2, 4)) 94 | t1 = jax.random.normal(keys[2], (2, 3)) 95 | t2 = jax.random.normal(keys[3], (2, 3)) 96 | 97 | a1 = quat_affine.QuatAffine(q1, t1, unstack_inputs=True) 98 | a2 = quat_affine.QuatAffine(q2, t2, unstack_inputs=True) 99 | a21 = a2.pre_compose(jnp.concatenate([vec_q1, t1], axis=-1)) 100 | 101 | rng, key = jax.random.split(rng) 102 | x = jax.random.normal(key, (2, 3)) 103 | new_x = a21.apply_to_point(jnp.moveaxis(x, -1, 0)) 104 | new_x_apply2 = a2.apply_to_point(a1.apply_to_point(jnp.moveaxis(x, -1, 0))) 105 | 106 | self._assert_check({ 107 | 'quat': (q2r(quat_affine.quat_multiply(a2.quaternion, a1.quaternion)), 108 | q2r(a21.quaternion)), 109 | 'rot': (jnp.matmul(r2t(a2.rotation), r2t(a1.rotation)), 110 | r2t(a21.rotation)), 111 | 'point': (v2t(new_x_apply2), 112 | v2t(new_x)), 113 | 'inverse': (x, v2t(a21.invert_point(new_x))), 114 | }) 115 | 116 | def test_batching(self): 117 | """Test that affine applies batchwise.""" 118 | rng = jax.random.PRNGKey(42) 119 | keys = jax.random.split(rng, 3) 120 | q = jax.random.uniform(keys[0], (5, 2, 4)) 121 
| t = jax.random.uniform(keys[1], (2, 3)) 122 | x = jax.random.uniform(keys[2], (5, 1, 3)) 123 | 124 | a = quat_affine.QuatAffine(q, t, unstack_inputs=True) 125 | y = v2t(a.apply_to_point(jnp.moveaxis(x, -1, 0))) 126 | 127 | y_list = [] 128 | for i in range(5): 129 | for j in range(2): 130 | a_local = quat_affine.QuatAffine(q[i, j], t[j], 131 | unstack_inputs=True) 132 | y_local = v2t(a_local.apply_to_point(jnp.moveaxis(x[i, 0], -1, 0))) 133 | y_list.append(y_local) 134 | y_combine = jnp.reshape(jnp.stack(y_list, axis=0), (5, 2, 3)) 135 | 136 | self._assert_check({ 137 | 'batch': (y_combine, y), 138 | 'quat': (q2r(a.quaternion), 139 | q2r(quat_affine.rot_to_quat(a.rotation))), 140 | }) 141 | 142 | def assertAllClose(self, a, b, rtol=1e-06, atol=1e-06): 143 | self.assertTrue(np.allclose(a, b, rtol=rtol, atol=atol)) 144 | 145 | def assertAllEqual(self, a, b): 146 | self.assertTrue(np.all(np.array(a) == np.array(b))) 147 | 148 | 149 | if __name__ == '__main__': 150 | absltest.main() 151 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Alphafold model TensorFlow code.""" 15 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/input_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Feature pre-processing input pipeline for AlphaFold.""" 16 | 17 | from alphafold.model.tf import data_transforms 18 | from alphafold.model.tf import shape_placeholders 19 | import tensorflow.compat.v1 as tf 20 | import tree 21 | 22 | # Pylint gets confused by the curry1 decorator because it changes the number 23 | # of arguments to the function. 
24 | # pylint:disable=no-value-for-parameter 25 | 26 | 27 | NUM_RES = shape_placeholders.NUM_RES 28 | NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ 29 | NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ 30 | NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES 31 | 32 | 33 | def nonensembled_map_fns(data_config): 34 | """Input pipeline functions which are not ensembled.""" 35 | common_cfg = data_config.common 36 | 37 | map_fns = [ 38 | data_transforms.correct_msa_restypes, 39 | data_transforms.add_distillation_flag(False), 40 | data_transforms.cast_64bit_ints, 41 | data_transforms.squeeze_features, 42 | # Keep to not disrupt RNG. 43 | data_transforms.randomly_replace_msa_with_unknown(0.0), 44 | data_transforms.make_seq_mask, 45 | data_transforms.make_msa_mask, 46 | # Compute the HHblits profile if it's not set. This has to be run before 47 | # sampling the MSA. 48 | data_transforms.make_hhblits_profile, 49 | data_transforms.make_random_crop_to_size_seed, 50 | ] 51 | if common_cfg.use_templates: 52 | map_fns.extend([ 53 | data_transforms.fix_templates_aatype, 54 | data_transforms.make_template_mask, 55 | data_transforms.make_pseudo_beta('template_') 56 | ]) 57 | map_fns.extend([ 58 | data_transforms.make_atom14_masks, 59 | ]) 60 | 61 | return map_fns 62 | 63 | 64 | def ensembled_map_fns(data_config): 65 | """Input pipeline functions that can be ensembled and averaged.""" 66 | common_cfg = data_config.common 67 | eval_cfg = data_config.eval 68 | 69 | map_fns = [] 70 | 71 | if common_cfg.reduce_msa_clusters_by_max_templates: 72 | pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates 73 | else: 74 | pad_msa_clusters = eval_cfg.max_msa_clusters 75 | 76 | max_msa_clusters = pad_msa_clusters 77 | max_extra_msa = common_cfg.max_extra_msa 78 | 79 | map_fns.append( 80 | data_transforms.sample_msa( 81 | max_msa_clusters, 82 | keep_extra=True)) 83 | 84 | if 'masked_msa' in common_cfg: 85 | # Masked MSA should come *before* MSA clustering so that 86 | # the clustering and full MSA profile do not leak information about 87 | # the masked locations and secret corrupted locations. 88 | map_fns.append( 89 | data_transforms.make_masked_msa(common_cfg.masked_msa, 90 | eval_cfg.masked_msa_replace_fraction)) 91 | 92 | if common_cfg.msa_cluster_features: 93 | map_fns.append(data_transforms.nearest_neighbor_clusters()) 94 | map_fns.append(data_transforms.summarize_clusters()) 95 | 96 | # Crop after creating the cluster profiles. 
97 | if max_extra_msa: 98 | map_fns.append(data_transforms.crop_extra_msa(max_extra_msa)) 99 | else: 100 | map_fns.append(data_transforms.delete_extra_msa) 101 | 102 | map_fns.append(data_transforms.make_msa_feat()) 103 | 104 | crop_feats = dict(eval_cfg.feat) 105 | 106 | if eval_cfg.fixed_size: 107 | map_fns.append(data_transforms.select_feat(list(crop_feats))) 108 | map_fns.append(data_transforms.random_crop_to_size( 109 | eval_cfg.crop_size, 110 | eval_cfg.max_templates, 111 | crop_feats, 112 | eval_cfg.subsample_templates)) 113 | map_fns.append(data_transforms.make_fixed_size( 114 | crop_feats, 115 | pad_msa_clusters, 116 | common_cfg.max_extra_msa, 117 | eval_cfg.crop_size, 118 | eval_cfg.max_templates)) 119 | else: 120 | map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates)) 121 | 122 | return map_fns 123 | 124 | 125 | def process_tensors_from_config(tensors, data_config): 126 | """Apply filters and maps to an existing dataset, based on the config.""" 127 | 128 | def wrap_ensemble_fn(data, i): 129 | """Function to be mapped over the ensemble dimension.""" 130 | d = data.copy() 131 | fns = ensembled_map_fns(data_config) 132 | fn = compose(fns) 133 | d['ensemble_index'] = i 134 | return fn(d) 135 | 136 | eval_cfg = data_config.eval 137 | tensors = compose( 138 | nonensembled_map_fns( 139 | data_config))( 140 | tensors) 141 | 142 | tensors_0 = wrap_ensemble_fn(tensors, tf.constant(0)) 143 | num_ensemble = eval_cfg.num_ensemble 144 | if data_config.common.resample_msa_in_recycling: 145 | # Separate batch per ensembling & recycling step. 146 | num_ensemble *= data_config.common.num_recycle + 1 147 | 148 | if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1: 149 | fn_output_signature = tree.map_structure( 150 | tf.TensorSpec.from_tensor, tensors_0) 151 | tensors = tf.map_fn( 152 | lambda x: wrap_ensemble_fn(tensors, x), 153 | tf.range(num_ensemble), 154 | parallel_iterations=1, 155 | fn_output_signature=fn_output_signature) 156 | else: 157 | tensors = tree.map_structure(lambda x: x[None], 158 | tensors_0) 159 | return tensors 160 | 161 | 162 | @data_transforms.curry1 163 | def compose(x, fs): 164 | for f in fs: 165 | x = f(x) 166 | return x 167 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/protein_features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Contains descriptions of various protein features.""" 16 | import enum 17 | from typing import Dict, Optional, Sequence, Tuple, Union 18 | from alphafold.common import residue_constants 19 | import tensorflow.compat.v1 as tf 20 | 21 | # Type aliases. 
22 | FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]] 23 | 24 | 25 | class FeatureType(enum.Enum): 26 | ZERO_DIM = 0 # Shape [x] 27 | ONE_DIM = 1 # Shape [num_res, x] 28 | TWO_DIM = 2 # Shape [num_res, num_res, x] 29 | MSA = 3 # Shape [msa_length, num_res, x] 30 | 31 | 32 | # Placeholder values that will be replaced with their true value at runtime. 33 | NUM_RES = "num residues placeholder" 34 | NUM_SEQ = "length msa placeholder" 35 | NUM_TEMPLATES = "num templates placeholder" 36 | # Sizes of the protein features, NUM_RES and NUM_SEQ are allowed as placeholders 37 | # to be replaced with the number of residues and the number of sequences in the 38 | # multiple sequence alignment, respectively. 39 | 40 | 41 | FEATURES = { 42 | #### Static features of a protein sequence #### 43 | "aatype": (tf.float32, [NUM_RES, 21]), 44 | "between_segment_residues": (tf.int64, [NUM_RES, 1]), 45 | "deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]), 46 | "domain_name": (tf.string, [1]), 47 | "msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]), 48 | "num_alignments": (tf.int64, [NUM_RES, 1]), 49 | "residue_index": (tf.int64, [NUM_RES, 1]), 50 | "seq_length": (tf.int64, [NUM_RES, 1]), 51 | "sequence": (tf.string, [1]), 52 | "all_atom_positions": (tf.float32, 53 | [NUM_RES, residue_constants.atom_type_num, 3]), 54 | "all_atom_mask": (tf.int64, [NUM_RES, residue_constants.atom_type_num]), 55 | "resolution": (tf.float32, [1]), 56 | "template_domain_names": (tf.string, [NUM_TEMPLATES]), 57 | "template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]), 58 | "template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]), 59 | "template_all_atom_positions": (tf.float32, [ 60 | NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3 61 | ]), 62 | "template_all_atom_masks": (tf.float32, [ 63 | NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1 64 | ]), 65 | } 66 | 67 | FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()} 68 | FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()} 69 | 70 | 71 | def register_feature(name: str, 72 | type_: tf.dtypes.DType, 73 | shape_: Tuple[Union[str, int]]): 74 | """Register extra features used in custom datasets.""" 75 | FEATURES[name] = (type_, shape_) 76 | FEATURE_TYPES[name] = type_ 77 | FEATURE_SIZES[name] = shape_ 78 | 79 | 80 | def shape(feature_name: str, 81 | num_residues: int, 82 | msa_length: int, 83 | num_templates: Optional[int] = None, 84 | features: Optional[FeaturesMetadata] = None): 85 | """Get the shape for the given feature name. 86 | 87 | This is near identical to _get_tf_shape_no_placeholders() but with 2 88 | differences: 89 | * This method does not calculate a single placeholder from the total number of 90 | elements (eg given and size := 12, this won't deduce NUM_RES 91 | must be 4) 92 | * This method will work with tensors 93 | 94 | Args: 95 | feature_name: String identifier for the feature. If the feature name ends 96 | with "_unnormalized", this suffix is stripped off. 97 | num_residues: The number of residues in the current domain - some elements 98 | of the shape can be dynamic and will be replaced by this value. 99 | msa_length: The number of sequences in the multiple sequence alignment, some 100 | elements of the shape can be dynamic and will be replaced by this value. 101 | If the number of alignments is unknown / not read, please pass None for 102 | msa_length. 103 | num_templates (optional): The number of templates in this tfexample. 104 | features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES. 
105 | 106 | Returns: 107 | List of ints representation the tensor size. 108 | 109 | Raises: 110 | ValueError: If a feature is requested but no concrete placeholder value is 111 | given. 112 | """ 113 | features = features or FEATURES 114 | if feature_name.endswith("_unnormalized"): 115 | feature_name = feature_name[:-13] 116 | 117 | unused_dtype, raw_sizes = features[feature_name] 118 | replacements = {NUM_RES: num_residues, 119 | NUM_SEQ: msa_length} 120 | 121 | if num_templates is not None: 122 | replacements[NUM_TEMPLATES] = num_templates 123 | 124 | sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes] 125 | for dimension in sizes: 126 | if isinstance(dimension, str): 127 | raise ValueError("Could not parse %s (shape: %s) with values: %s" % ( 128 | feature_name, raw_sizes, replacements)) 129 | return sizes 130 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/protein_features_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for protein_features.""" 16 | import uuid 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from alphafold.model.tf import protein_features 21 | import tensorflow.compat.v1 as tf 22 | 23 | 24 | def _random_bytes(): 25 | return str(uuid.uuid4()).encode('utf-8') 26 | 27 | 28 | class FeaturesTest(parameterized.TestCase, tf.test.TestCase): 29 | 30 | def setUp(self): 31 | super().setUp() 32 | tf.disable_v2_behavior() 33 | 34 | def testFeatureNames(self): 35 | self.assertEqual(len(protein_features.FEATURE_SIZES), 36 | len(protein_features.FEATURE_TYPES)) 37 | sorted_size_names = sorted(protein_features.FEATURE_SIZES.keys()) 38 | sorted_type_names = sorted(protein_features.FEATURE_TYPES.keys()) 39 | for i, size_name in enumerate(sorted_size_names): 40 | self.assertEqual(size_name, sorted_type_names[i]) 41 | 42 | def testReplacement(self): 43 | for name in protein_features.FEATURE_SIZES.keys(): 44 | sizes = protein_features.shape(name, 45 | num_residues=12, 46 | msa_length=24, 47 | num_templates=3) 48 | for x in sizes: 49 | self.assertEqual(type(x), int) 50 | self.assertGreater(x, 0) 51 | 52 | 53 | if __name__ == '__main__': 54 | absltest.main() 55 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/proteins_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Datasets consisting of proteins.""" 16 | from typing import Dict, Mapping, Optional, Sequence 17 | from alphafold.model.tf import protein_features 18 | import numpy as np 19 | import tensorflow.compat.v1 as tf 20 | 21 | TensorDict = Dict[str, tf.Tensor] 22 | 23 | 24 | def parse_tfexample( 25 | raw_data: bytes, 26 | features: protein_features.FeaturesMetadata, 27 | key: Optional[str] = None) -> Dict[str, tf.train.Feature]: 28 | """Read a single TF Example proto and return a subset of its features. 29 | 30 | Args: 31 | raw_data: A serialized tf.Example proto. 32 | features: A dictionary of features, mapping string feature names to a tuple 33 | (dtype, shape). This dictionary should be a subset of 34 | protein_features.FEATURES (or the dictionary itself for all features). 35 | key: Optional string with the SSTable key of that tf.Example. This will be 36 | added into features as a 'key' but only if requested in features. 37 | 38 | Returns: 39 | A dictionary of features mapping feature names to features. Only the given 40 | features are returned, all other ones are filtered out. 41 | """ 42 | feature_map = { 43 | k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True) 44 | for k, v in features.items() 45 | } 46 | parsed_features = tf.io.parse_single_example(raw_data, feature_map) 47 | reshaped_features = parse_reshape_logic(parsed_features, features, key=key) 48 | 49 | return reshaped_features 50 | 51 | 52 | def _first(tensor: tf.Tensor) -> tf.Tensor: 53 | """Returns the 1st element - the input can be a tensor or a scalar.""" 54 | return tf.reshape(tensor, shape=(-1,))[0] 55 | 56 | 57 | def parse_reshape_logic( 58 | parsed_features: TensorDict, 59 | features: protein_features.FeaturesMetadata, 60 | key: Optional[str] = None) -> TensorDict: 61 | """Transforms parsed serial features to the correct shape.""" 62 | # Find out what is the number of sequences and the number of alignments. 63 | num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32) 64 | 65 | if "num_alignments" in parsed_features: 66 | num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32) 67 | else: 68 | num_msa = 0 69 | 70 | if "template_domain_names" in parsed_features: 71 | num_templates = tf.cast( 72 | tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32) 73 | else: 74 | num_templates = 0 75 | 76 | if key is not None and "key" in features: 77 | parsed_features["key"] = [key] # Expand dims from () to (1,). 78 | 79 | # Reshape the tensors according to the sequence length and num alignments. 
80 | for k, v in parsed_features.items(): 81 | new_shape = protein_features.shape( 82 | feature_name=k, 83 | num_residues=num_residues, 84 | msa_length=num_msa, 85 | num_templates=num_templates, 86 | features=features) 87 | new_shape_size = tf.constant(1, dtype=tf.int32) 88 | for dim in new_shape: 89 | new_shape_size *= tf.cast(dim, tf.int32) 90 | 91 | assert_equal = tf.assert_equal( 92 | tf.size(v), new_shape_size, 93 | name="assert_%s_shape_correct" % k, 94 | message="The size of feature %s (%s) could not be reshaped " 95 | "into %s" % (k, tf.size(v), new_shape)) 96 | if "template" not in k: 97 | # Make sure the feature we are reshaping is not empty. 98 | assert_non_empty = tf.assert_greater( 99 | tf.size(v), 0, name="assert_%s_non_empty" % k, 100 | message="The feature %s is not set in the tf.Example. Either do not " 101 | "request the feature or use a tf.Example that has the " 102 | "feature set." % k) 103 | with tf.control_dependencies([assert_non_empty, assert_equal]): 104 | parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) 105 | else: 106 | with tf.control_dependencies([assert_equal]): 107 | parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k) 108 | 109 | return parsed_features 110 | 111 | 112 | def _make_features_metadata( 113 | feature_names: Sequence[str]) -> protein_features.FeaturesMetadata: 114 | """Makes a feature name to type and shape mapping from a list of names.""" 115 | # Make sure these features are always read. 116 | required_features = ["aatype", "sequence", "seq_length"] 117 | feature_names = list(set(feature_names) | set(required_features)) 118 | 119 | features_metadata = {name: protein_features.FEATURES[name] 120 | for name in feature_names} 121 | return features_metadata 122 | 123 | 124 | def create_tensor_dict( 125 | raw_data: bytes, 126 | features: Sequence[str], 127 | key: Optional[str] = None, 128 | ) -> TensorDict: 129 | """Creates a dictionary of tensor features. 130 | 131 | Args: 132 | raw_data: A serialized tf.Example proto. 133 | features: A list of strings of feature names to be returned in the dataset. 134 | key: Optional string with the SSTable key of that tf.Example. This will be 135 | added into features as a 'key' but only if requested in features. 136 | 137 | Returns: 138 | A dictionary of features mapping feature names to features. Only the given 139 | features are returned, all other ones are filtered out. 140 | """ 141 | features_metadata = _make_features_metadata(features) 142 | return parse_tfexample(raw_data, features_metadata, key) 143 | 144 | 145 | def np_to_tensor_dict( 146 | np_example: Mapping[str, np.ndarray], 147 | features: Sequence[str], 148 | ) -> TensorDict: 149 | """Creates dict of tensors from a dict of NumPy arrays. 150 | 151 | Args: 152 | np_example: A dict of NumPy feature arrays. 153 | features: A list of strings of feature names to be returned in the dataset. 154 | 155 | Returns: 156 | A dictionary of features mapping feature names to features. Only the given 157 | features are returned, all other ones are filtered out. 158 | """ 159 | features_metadata = _make_features_metadata(features) 160 | tensor_dict = {k: tf.constant(v) for k, v in np_example.items() 161 | if k in features_metadata} 162 | 163 | # Ensures shapes are as expected. Needed for setting size of empty features 164 | # e.g. when no template hits were found. 
165 | tensor_dict = parse_reshape_logic(tensor_dict, features_metadata) 166 | return tensor_dict 167 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/shape_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for dealing with shapes of TensorFlow tensors.""" 16 | import tensorflow.compat.v1 as tf 17 | 18 | 19 | def shape_list(x): 20 | """Return list of dimensions of a tensor, statically where possible. 21 | 22 | Like `x.shape.as_list()` but with tensors instead of `None`s. 23 | 24 | Args: 25 | x: A tensor. 26 | Returns: 27 | A list with length equal to the rank of the tensor. The n-th element of the 28 | list is an integer when that dimension is statically known otherwise it is 29 | the n-th element of `tf.shape(x)`. 30 | """ 31 | x = tf.convert_to_tensor(x) 32 | 33 | # If unknown rank, return dynamic shape 34 | if x.get_shape().dims is None: 35 | return tf.shape(x) 36 | 37 | static = x.get_shape().as_list() 38 | shape = tf.shape(x) 39 | 40 | ret = [] 41 | for i in range(len(static)): 42 | dim = static[i] 43 | if dim is None: 44 | dim = shape[i] 45 | ret.append(dim) 46 | return ret 47 | 48 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/shape_helpers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for shape_helpers.""" 16 | 17 | from alphafold.model.tf import shape_helpers 18 | import numpy as np 19 | import tensorflow.compat.v1 as tf 20 | 21 | 22 | class ShapeTest(tf.test.TestCase): 23 | 24 | def setUp(self): 25 | super().setUp() 26 | tf.disable_v2_behavior() 27 | 28 | def test_shape_list(self): 29 | """Test that shape_list can allow for reshaping to dynamic shapes.""" 30 | a = tf.zeros([10, 4, 4, 2]) 31 | p = tf.placeholder(tf.float32, shape=[None, None, 1, 4, 4]) 32 | shape_dyn = shape_helpers.shape_list(p)[:2] + [4, 4] 33 | 34 | b = tf.reshape(a, shape_dyn) 35 | with self.session() as sess: 36 | out = sess.run(b, feed_dict={p: np.ones((20, 1, 1, 4, 4))}) 37 | 38 | self.assertAllEqual(out.shape, (20, 1, 4, 4)) 39 | 40 | 41 | if __name__ == '__main__': 42 | tf.test.main() 43 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/shape_placeholders.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Placeholder values for run-time varying dimension sizes.""" 16 | 17 | NUM_RES = 'num residues placeholder' 18 | NUM_MSA_SEQ = 'msa placeholder' 19 | NUM_EXTRA_SEQ = 'extra msa placeholder' 20 | NUM_TEMPLATES = 'num templates placeholder' 21 | -------------------------------------------------------------------------------- /src/alphafold/model/tf/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Shared utilities for various components.""" 16 | import tensorflow.compat.v1 as tf 17 | 18 | 19 | def tf_combine_mask(*masks): 20 | """Take the intersection of float-valued masks.""" 21 | ret = 1 22 | for m in masks: 23 | ret *= m 24 | return ret 25 | 26 | 27 | class SeedMaker(object): 28 | """Return unique seeds.""" 29 | 30 | def __init__(self, initial_seed=0): 31 | self.next_seed = initial_seed 32 | 33 | def __call__(self): 34 | i = self.next_seed 35 | self.next_seed += 1 36 | return i 37 | 38 | seed_maker = SeedMaker() 39 | 40 | 41 | def make_random_seed(): 42 | return tf.random.uniform([2], 43 | tf.int32.min, 44 | tf.int32.max, 45 | tf.int32, 46 | seed=seed_maker()) 47 | 48 | -------------------------------------------------------------------------------- /src/alphafold/model/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A collection of JAX utility functions for use in protein folding.""" 16 | 17 | import collections 18 | import contextlib 19 | import functools 20 | import numbers 21 | from typing import Mapping 22 | 23 | import haiku as hk 24 | import jax 25 | import jax.numpy as jnp 26 | import numpy as np 27 | 28 | 29 | def bfloat16_creator(next_creator, shape, dtype, init, context): 30 | """Creates float32 variables when bfloat16 is requested.""" 31 | if context.original_dtype == jnp.bfloat16: 32 | dtype = jnp.float32 33 | return next_creator(shape, dtype, init) 34 | 35 | 36 | def bfloat16_getter(next_getter, value, context): 37 | """Casts float32 to bfloat16 when bfloat16 was originally requested.""" 38 | if context.original_dtype == jnp.bfloat16: 39 | assert value.dtype == jnp.float32 40 | value = value.astype(jnp.bfloat16) 41 | return next_getter(value) 42 | 43 | 44 | @contextlib.contextmanager 45 | def bfloat16_context(): 46 | with hk.custom_creator(bfloat16_creator), hk.custom_getter(bfloat16_getter): 47 | yield 48 | 49 | 50 | def final_init(config): 51 | if config.zero_init: 52 | return 'zeros' 53 | else: 54 | return 'linear' 55 | 56 | 57 | def batched_gather(params, indices, axis=0, batch_dims=0): 58 | """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`.""" 59 | take_fn = lambda p, i: jnp.take(p, i, axis=axis, mode='clip') 60 | for _ in range(batch_dims): 61 | take_fn = jax.vmap(take_fn) 62 | return take_fn(params, indices) 63 | 64 | 65 | def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10): 66 | """Masked mean.""" 67 | if drop_mask_channel: 68 | mask = mask[..., 0] 69 | 70 | mask_shape = mask.shape 71 | value_shape = value.shape 72 | 73 | assert len(mask_shape) == len(value_shape) 74 | 75 | if isinstance(axis, numbers.Integral): 76 | axis = [axis] 77 | elif axis is None: 78 | axis = list(range(len(mask_shape))) 79 | assert isinstance(axis, collections.abc.Iterable), ( 80 | 'axis needs to be either an iterable, integer or 
"None"') 81 | 82 | broadcast_factor = 1. 83 | for axis_ in axis: 84 | value_size = value_shape[axis_] 85 | mask_size = mask_shape[axis_] 86 | if mask_size == 1: 87 | broadcast_factor *= value_size 88 | else: 89 | assert mask_size == value_size 90 | 91 | return (jnp.sum(mask * value, axis=axis) / 92 | (jnp.sum(mask, axis=axis) * broadcast_factor + eps)) 93 | 94 | 95 | def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params: 96 | """Convert a dictionary of NumPy arrays to Haiku parameters.""" 97 | hk_params = {} 98 | for path, array in params.items(): 99 | scope, name = path.split('//') 100 | if scope not in hk_params: 101 | hk_params[scope] = {} 102 | hk_params[scope][name] = jnp.array(array) 103 | 104 | return hk_params 105 | 106 | 107 | def padding_consistent_rng(f): 108 | """Modify any element-wise random function to be consistent with padding. 109 | 110 | Normally if you take a function like jax.random.normal and generate an array, 111 | say of size (10,10), you will get a different set of random numbers to if you 112 | add padding and take the first (10,10) sub-array. 113 | 114 | This function makes a random function that is consistent regardless of the 115 | amount of padding added. 116 | 117 | Note: The padding-consistent function is likely to be slower to compile and 118 | run than the function it is wrapping, but these slowdowns are likely to be 119 | negligible in a large network. 120 | 121 | Args: 122 | f: Any element-wise function that takes (PRNG key, shape) as the first 2 123 | arguments. 124 | 125 | Returns: 126 | An equivalent function to f, that is now consistent for different amounts of 127 | padding. 128 | """ 129 | def grid_keys(key, shape): 130 | """Generate a grid of rng keys that is consistent with different padding. 131 | 132 | Generate random keys such that the keys will be identical, regardless of 133 | how much padding is added to any dimension. 134 | 135 | Args: 136 | key: A PRNG key. 137 | shape: The shape of the output array of keys that will be generated. 138 | 139 | Returns: 140 | An array of shape `shape` consisting of random keys. 141 | """ 142 | if not shape: 143 | return key 144 | new_keys = jax.vmap(functools.partial(jax.random.fold_in, key))( 145 | jnp.arange(shape[0])) 146 | return jax.vmap(functools.partial(grid_keys, shape=shape[1:]))(new_keys) 147 | 148 | def inner(key, shape, **kwargs): 149 | return jnp.vectorize( 150 | lambda key: f(key, shape=(), **kwargs), 151 | signature='(2)->()')( 152 | grid_keys(key, shape)) 153 | return inner 154 | -------------------------------------------------------------------------------- /src/alphafold/notebooks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """AlphaFold Colab notebook.""" 15 | -------------------------------------------------------------------------------- /src/alphafold/notebooks/notebook_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Helper methods for the AlphaFold Colab notebook.""" 16 | import json 17 | from typing import Any, Mapping, Optional, Sequence, Tuple 18 | 19 | from alphafold.common import residue_constants 20 | from alphafold.data import parsers 21 | from matplotlib import pyplot as plt 22 | import numpy as np 23 | 24 | 25 | def clean_and_validate_single_sequence( 26 | input_sequence: str, min_length: int, max_length: int) -> str: 27 | """Checks that the input sequence is ok and returns a clean version of it.""" 28 | # Remove all whitespaces, tabs and end lines; upper-case. 29 | clean_sequence = input_sequence.translate( 30 | str.maketrans('', '', ' \n\t')).upper() 31 | aatypes = set(residue_constants.restypes) # 20 standard aatypes. 32 | if not set(clean_sequence).issubset(aatypes): 33 | raise ValueError( 34 | f'Input sequence contains non-amino acid letters: ' 35 | f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard ' 36 | 'amino acids as inputs.') 37 | if len(clean_sequence) < min_length: 38 | raise ValueError( 39 | f'Input sequence is too short: {len(clean_sequence)} amino acids, ' 40 | f'while the minimum is {min_length}') 41 | if len(clean_sequence) > max_length: 42 | raise ValueError( 43 | f'Input sequence is too long: {len(clean_sequence)} amino acids, while ' 44 | f'the maximum is {max_length}. You may be able to run it with the full ' 45 | f'AlphaFold system depending on your resources (system memory, ' 46 | f'GPU memory).') 47 | return clean_sequence 48 | 49 | 50 | def clean_and_validate_input_sequences( 51 | input_sequences: Sequence[str], 52 | min_sequence_length: int, 53 | max_sequence_length: int) -> Sequence[str]: 54 | """Validates and cleans input sequences.""" 55 | sequences = [] 56 | 57 | for input_sequence in input_sequences: 58 | if input_sequence.strip(): 59 | input_sequence = clean_and_validate_single_sequence( 60 | input_sequence=input_sequence, 61 | min_length=min_sequence_length, 62 | max_length=max_sequence_length) 63 | sequences.append(input_sequence) 64 | 65 | if sequences: 66 | return sequences 67 | else: 68 | raise ValueError('No input amino acid sequence provided, please provide at ' 69 | 'least one sequence.') 70 | 71 | 72 | def merge_chunked_msa( 73 | results: Sequence[Mapping[str, Any]], 74 | max_hits: Optional[int] = None 75 | ) -> parsers.Msa: 76 | """Merges chunked database hits together into hits for the full database.""" 77 | unsorted_results = [] 78 | for chunk_index, chunk in enumerate(results): 79 | msa = parsers.parse_stockholm(chunk['sto']) 80 | e_values_dict = parsers.parse_e_values_from_tblout(chunk['tbl']) 81 | # Jackhmmer lists sequences as /-. 
82 | e_values = [e_values_dict[t.partition('/')[0]] for t in msa.descriptions] 83 | chunk_results = zip( 84 | msa.sequences, msa.deletion_matrix, msa.descriptions, e_values) 85 | if chunk_index != 0: 86 | next(chunk_results) # Only take query (first hit) from the first chunk. 87 | unsorted_results.extend(chunk_results) 88 | 89 | sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[-1]) 90 | merged_sequences, merged_deletion_matrix, merged_descriptions, _ = zip( 91 | *sorted_by_evalue) 92 | merged_msa = parsers.Msa(sequences=merged_sequences, 93 | deletion_matrix=merged_deletion_matrix, 94 | descriptions=merged_descriptions) 95 | if max_hits is not None: 96 | merged_msa = merged_msa.truncate(max_seqs=max_hits) 97 | 98 | return merged_msa 99 | 100 | 101 | def show_msa_info( 102 | single_chain_msas: Sequence[parsers.Msa], 103 | sequence_index: int): 104 | """Prints info and shows a plot of the deduplicated single chain MSA.""" 105 | full_single_chain_msa = [] 106 | for single_chain_msa in single_chain_msas: 107 | full_single_chain_msa.extend(single_chain_msa.sequences) 108 | 109 | # Deduplicate but preserve order (hence can't use set). 110 | deduped_full_single_chain_msa = list(dict.fromkeys(full_single_chain_msa)) 111 | total_msa_size = len(deduped_full_single_chain_msa) 112 | print(f'\n{total_msa_size} unique sequences found in total for sequence ' 113 | f'{sequence_index}\n') 114 | 115 | aa_map = {res: i for i, res in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')} 116 | msa_arr = np.array( 117 | [[aa_map[aa] for aa in seq] for seq in deduped_full_single_chain_msa]) 118 | 119 | plt.figure(figsize=(12, 3)) 120 | plt.title(f'Per-Residue Count of Non-Gap Amino Acids in the MSA for Sequence ' 121 | f'{sequence_index}') 122 | plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), color='black') 123 | plt.ylabel('Non-Gap Count') 124 | plt.yticks(range(0, total_msa_size + 1, max(1, int(total_msa_size / 3)))) 125 | plt.show() 126 | 127 | 128 | def empty_placeholder_template_features( 129 | num_templates: int, num_res: int) -> Mapping[str, np.ndarray]: 130 | return { 131 | 'template_aatype': np.zeros( 132 | (num_templates, num_res, 133 | len(residue_constants.restypes_with_x_and_gap)), dtype=np.float32), 134 | 'template_all_atom_masks': np.zeros( 135 | (num_templates, num_res, residue_constants.atom_type_num), 136 | dtype=np.float32), 137 | 'template_all_atom_positions': np.zeros( 138 | (num_templates, num_res, residue_constants.atom_type_num, 3), 139 | dtype=np.float32), 140 | 'template_domain_names': np.zeros([num_templates], dtype=object), 141 | 'template_sequence': np.zeros([num_templates], dtype=object), 142 | 'template_sum_probs': np.zeros([num_templates], dtype=np.float32), 143 | } 144 | 145 | 146 | def get_pae_json(pae: np.ndarray, max_pae: float) -> str: 147 | """Returns the PAE in the same format as is used in the AFDB. 148 | 149 | Note that the values are presented as floats to 1 decimal place, 150 | whereas AFDB returns integer values. 151 | 152 | Args: 153 | pae: The n_res x n_res PAE array. 154 | max_pae: The maximum possible PAE value. 155 | Returns: 156 | PAE output format as a JSON string. 157 | """ 158 | # Check the PAE array is the correct shape. 159 | if (pae.ndim != 2 or pae.shape[0] != pae.shape[1]): 160 | raise ValueError(f'PAE must be a square matrix, got {pae.shape}') 161 | 162 | # Round the predicted aligned errors to 1 decimal place. 
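# Casting to float64 before rounding keeps the JSON compact: rounding float32 values directly would serialise with long decimal expansions after tolist().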
163 | rounded_errors = np.round(pae.astype(np.float64), decimals=1) 164 | formatted_output = [{ 165 | 'predicted_aligned_error': rounded_errors.tolist(), 166 | 'max_predicted_aligned_error': max_pae 167 | }] 168 | return json.dumps(formatted_output, indent=None, separators=(',', ':')) 169 | -------------------------------------------------------------------------------- /src/alphafold/relax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Amber relaxation.""" 15 | -------------------------------------------------------------------------------- /src/alphafold/relax/amber_minimize_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for amber_minimize.""" 16 | import os 17 | 18 | from absl.testing import absltest 19 | from alphafold.common import protein 20 | from alphafold.relax import amber_minimize 21 | import numpy as np 22 | # Internal import (7716). 23 | 24 | _USE_GPU = False 25 | 26 | 27 | def _load_test_protein(data_path): 28 | pdb_path = os.path.join(absltest.get_default_test_srcdir(), data_path) 29 | with open(pdb_path, 'r') as f: 30 | return protein.from_pdb_string(f.read()) 31 | 32 | 33 | class AmberMinimizeTest(absltest.TestCase): 34 | 35 | def test_multiple_disulfides_target(self): 36 | prot = _load_test_protein( 37 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb' 38 | ) 39 | ret = amber_minimize.run_pipeline(prot, max_iterations=10, max_attempts=1, 40 | stiffness=10., use_gpu=_USE_GPU) 41 | self.assertIn('opt_time', ret) 42 | self.assertIn('min_attempts', ret) 43 | 44 | def test_raises_invalid_protein_assertion(self): 45 | prot = _load_test_protein( 46 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb' 47 | ) 48 | prot.atom_mask[4, :] = 0 49 | with self.assertRaisesRegex( 50 | ValueError, 51 | 'Amber minimization can only be performed on proteins with well-defined' 52 | ' residues. 
This protein contains at least one residue with no atoms.'): 53 | amber_minimize.run_pipeline(prot, max_iterations=10, 54 | stiffness=1., 55 | max_attempts=1, 56 | use_gpu=_USE_GPU) 57 | 58 | def test_iterative_relax(self): 59 | prot = _load_test_protein( 60 | 'alphafold/relax/testdata/with_violations.pdb' 61 | ) 62 | violations = amber_minimize.get_violation_metrics(prot) 63 | self.assertGreater(violations['num_residue_violations'], 0) 64 | out = amber_minimize.run_pipeline( 65 | prot=prot, max_outer_iterations=10, stiffness=10., use_gpu=_USE_GPU) 66 | self.assertLess(out['efinal'], out['einit']) 67 | self.assertEqual(0, out['num_residue_violations']) 68 | 69 | def test_find_violations(self): 70 | prot = _load_test_protein( 71 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb' 72 | ) 73 | viols, _ = amber_minimize.find_violations(prot) 74 | 75 | expected_between_residues_connection_mask = np.zeros((191,), np.float32) 76 | for residue in (42, 43, 59, 60, 135, 136): 77 | expected_between_residues_connection_mask[residue] = 1.0 78 | 79 | expected_clash_indices = np.array([ 80 | [8, 4], 81 | [8, 5], 82 | [13, 3], 83 | [14, 1], 84 | [14, 4], 85 | [26, 4], 86 | [26, 5], 87 | [31, 8], 88 | [31, 10], 89 | [39, 0], 90 | [39, 1], 91 | [39, 2], 92 | [39, 3], 93 | [39, 4], 94 | [42, 5], 95 | [42, 6], 96 | [42, 7], 97 | [42, 8], 98 | [47, 7], 99 | [47, 8], 100 | [47, 9], 101 | [47, 10], 102 | [64, 4], 103 | [85, 5], 104 | [102, 4], 105 | [102, 5], 106 | [109, 13], 107 | [111, 5], 108 | [118, 6], 109 | [118, 7], 110 | [118, 8], 111 | [124, 4], 112 | [124, 5], 113 | [131, 5], 114 | [139, 7], 115 | [147, 4], 116 | [152, 7]], dtype=np.int32) 117 | expected_between_residues_clash_mask = np.zeros([191, 14]) 118 | expected_between_residues_clash_mask[expected_clash_indices[:, 0], 119 | expected_clash_indices[:, 1]] += 1 120 | expected_per_atom_violations = np.zeros([191, 14]) 121 | np.testing.assert_array_equal( 122 | viols['between_residues']['connections_per_residue_violation_mask'], 123 | expected_between_residues_connection_mask) 124 | np.testing.assert_array_equal( 125 | viols['between_residues']['clashes_per_atom_clash_mask'], 126 | expected_between_residues_clash_mask) 127 | np.testing.assert_array_equal( 128 | viols['within_residues']['per_atom_violations'], 129 | expected_per_atom_violations) 130 | 131 | 132 | if __name__ == '__main__': 133 | absltest.main() 134 | -------------------------------------------------------------------------------- /src/alphafold/relax/cleanup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations. 16 | 17 | fix_pdb uses a third-party tool. We also support fixing some additional edge 18 | cases like removing chains of length one (see clean_structure). 
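Example (a minimal sketch; the input path is hypothetical):

  alterations = {}
  with open('prediction.pdb') as f:
    fixed_pdb_str = fix_pdb(f, alterations)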
19 | """ 20 | import io 21 | 22 | import pdbfixer 23 | from simtk.openmm import app 24 | from simtk.openmm.app import element 25 | 26 | 27 | def fix_pdb(pdbfile, alterations_info): 28 | """Apply pdbfixer to the contents of a PDB file; return a PDB string result. 29 | 30 | 1) Replaces nonstandard residues. 31 | 2) Removes heterogens (non protein residues) including water. 32 | 3) Adds missing residues and missing atoms within existing residues. 33 | 4) Adds hydrogens assuming pH=7.0. 34 | 5) KeepIds is currently true, so the fixer must keep the existing chain and 35 | residue identifiers. This will fail for some files in wider PDB that have 36 | invalid IDs. 37 | 38 | Args: 39 | pdbfile: Input PDB file handle. 40 | alterations_info: A dict that will store details of changes made. 41 | 42 | Returns: 43 | A PDB string representing the fixed structure. 44 | """ 45 | fixer = pdbfixer.PDBFixer(pdbfile=pdbfile) 46 | fixer.findNonstandardResidues() 47 | alterations_info['nonstandard_residues'] = fixer.nonstandardResidues 48 | fixer.replaceNonstandardResidues() 49 | _remove_heterogens(fixer, alterations_info, keep_water=False) 50 | fixer.findMissingResidues() 51 | alterations_info['missing_residues'] = fixer.missingResidues 52 | fixer.findMissingAtoms() 53 | alterations_info['missing_heavy_atoms'] = fixer.missingAtoms 54 | alterations_info['missing_terminals'] = fixer.missingTerminals 55 | fixer.addMissingAtoms(seed=0) 56 | fixer.addMissingHydrogens() 57 | out_handle = io.StringIO() 58 | app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle, 59 | keepIds=True) 60 | return out_handle.getvalue() 61 | 62 | 63 | def clean_structure(pdb_structure, alterations_info): 64 | """Applies additional fixes to an OpenMM structure, to handle edge cases. 65 | 66 | Args: 67 | pdb_structure: An OpenMM structure to modify and fix. 68 | alterations_info: A dict that will store details of changes made. 69 | """ 70 | _replace_met_se(pdb_structure, alterations_info) 71 | _remove_chains_of_length_one(pdb_structure, alterations_info) 72 | 73 | 74 | def _remove_heterogens(fixer, alterations_info, keep_water): 75 | """Removes the residues that Pdbfixer considers to be heterogens. 76 | 77 | Args: 78 | fixer: A Pdbfixer instance. 79 | alterations_info: A dict that will store details of changes made. 80 | keep_water: If True, water (HOH) is not considered to be a heterogen. 
81 | """ 82 | initial_resnames = set() 83 | for chain in fixer.topology.chains(): 84 | for residue in chain.residues(): 85 | initial_resnames.add(residue.name) 86 | fixer.removeHeterogens(keepWater=keep_water) 87 | final_resnames = set() 88 | for chain in fixer.topology.chains(): 89 | for residue in chain.residues(): 90 | final_resnames.add(residue.name) 91 | alterations_info['removed_heterogens'] = ( 92 | initial_resnames.difference(final_resnames)) 93 | 94 | 95 | def _replace_met_se(pdb_structure, alterations_info): 96 | """Replace the Se in any MET residues that were not marked as modified.""" 97 | modified_met_residues = [] 98 | for res in pdb_structure.iter_residues(): 99 | name = res.get_name_with_spaces().strip() 100 | if name == 'MET': 101 | s_atom = res.get_atom('SD') 102 | if s_atom.element_symbol == 'Se': 103 | s_atom.element_symbol = 'S' 104 | s_atom.element = element.get_by_symbol('S') 105 | modified_met_residues.append(s_atom.residue_number) 106 | alterations_info['Se_in_MET'] = modified_met_residues 107 | 108 | 109 | def _remove_chains_of_length_one(pdb_structure, alterations_info): 110 | """Removes chains that correspond to a single amino acid. 111 | 112 | A single amino acid in a chain is both N and C terminus. There is no force 113 | template for this case. 114 | 115 | Args: 116 | pdb_structure: An OpenMM pdb_structure to modify and fix. 117 | alterations_info: A dict that will store details of changes made. 118 | """ 119 | removed_chains = {} 120 | for model in pdb_structure.iter_models(): 121 | valid_chains = [c for c in model.iter_chains() if len(c) > 1] 122 | invalid_chain_ids = [c.chain_id for c in model.iter_chains() if len(c) <= 1] 123 | model.chains = valid_chains 124 | for chain_id in invalid_chain_ids: 125 | model.chains_by_id.pop(chain_id) 126 | removed_chains[model.number] = invalid_chain_ids 127 | alterations_info['removed_chains'] = removed_chains 128 | -------------------------------------------------------------------------------- /src/alphafold/relax/cleanup_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for relax.cleanup.""" 16 | import io 17 | 18 | from absl.testing import absltest 19 | from alphafold.relax import cleanup 20 | from simtk.openmm.app.internal import pdbstructure 21 | 22 | 23 | def _pdb_to_structure(pdb_str): 24 | handle = io.StringIO(pdb_str) 25 | return pdbstructure.PdbStructure(handle) 26 | 27 | 28 | def _lines_to_structure(pdb_lines): 29 | return _pdb_to_structure('\n'.join(pdb_lines)) 30 | 31 | 32 | class CleanupTest(absltest.TestCase): 33 | 34 | def test_missing_residues(self): 35 | pdb_lines = ['SEQRES 1 C 3 CYS GLY LEU', 36 | 'ATOM 1 N CYS C 1 -12.262 20.115 60.959 1.00 ' 37 | '19.08 N', 38 | 'ATOM 2 CA CYS C 1 -11.065 20.934 60.773 1.00 ' 39 | '17.23 C', 40 | 'ATOM 3 C CYS C 1 -10.002 20.742 61.844 1.00 ' 41 | '15.38 C', 42 | 'ATOM 4 O CYS C 1 -10.284 20.225 62.929 1.00 ' 43 | '16.04 O', 44 | 'ATOM 5 N LEU C 3 -7.688 18.700 62.045 1.00 ' 45 | '14.75 N', 46 | 'ATOM 6 CA LEU C 3 -7.256 17.320 62.234 1.00 ' 47 | '16.81 C', 48 | 'ATOM 7 C LEU C 3 -6.380 16.864 61.070 1.00 ' 49 | '16.95 C', 50 | 'ATOM 8 O LEU C 3 -6.551 17.332 59.947 1.00 ' 51 | '16.97 O'] 52 | input_handle = io.StringIO('\n'.join(pdb_lines)) 53 | alterations = {} 54 | result = cleanup.fix_pdb(input_handle, alterations) 55 | structure = _pdb_to_structure(result) 56 | residue_names = [r.get_name() for r in structure.iter_residues()] 57 | self.assertCountEqual(residue_names, ['CYS', 'GLY', 'LEU']) 58 | self.assertCountEqual(alterations['missing_residues'].values(), [['GLY']]) 59 | 60 | def test_missing_atoms(self): 61 | pdb_lines = ['SEQRES 1 A 1 PRO', 62 | 'ATOM 1 CA PRO A 1 1.000 1.000 1.000 1.00 ' 63 | ' 0.00 C'] 64 | input_handle = io.StringIO('\n'.join(pdb_lines)) 65 | alterations = {} 66 | result = cleanup.fix_pdb(input_handle, alterations) 67 | structure = _pdb_to_structure(result) 68 | atom_names = [a.get_name() for a in structure.iter_atoms()] 69 | self.assertCountEqual(atom_names, ['N', 'CD', 'HD2', 'HD3', 'CG', 'HG2', 70 | 'HG3', 'CB', 'HB2', 'HB3', 'CA', 'HA', 71 | 'C', 'O', 'H2', 'H3', 'OXT']) 72 | missing_atoms_by_residue = list(alterations['missing_heavy_atoms'].values()) 73 | self.assertLen(missing_atoms_by_residue, 1) 74 | atoms_added = [a.name for a in missing_atoms_by_residue[0]] 75 | self.assertCountEqual(atoms_added, ['N', 'CD', 'CG', 'CB', 'C', 'O']) 76 | missing_terminals_by_residue = alterations['missing_terminals'] 77 | self.assertLen(missing_terminals_by_residue, 1) 78 | has_missing_terminal = [r.name for r in missing_terminals_by_residue.keys()] 79 | self.assertCountEqual(has_missing_terminal, ['PRO']) 80 | self.assertCountEqual([t for t in missing_terminals_by_residue.values()], 81 | [['OXT']]) 82 | 83 | def test_remove_heterogens(self): 84 | pdb_lines = ['SEQRES 1 A 1 GLY', 85 | 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 ' 86 | ' 0.00 C', 87 | 'ATOM 2 O HOH A 2 0.000 0.000 0.000 1.00 ' 88 | ' 0.00 O'] 89 | input_handle = io.StringIO('\n'.join(pdb_lines)) 90 | alterations = {} 91 | result = cleanup.fix_pdb(input_handle, alterations) 92 | structure = _pdb_to_structure(result) 93 | self.assertCountEqual([res.get_name() for res in structure.iter_residues()], 94 | ['GLY']) 95 | self.assertEqual(alterations['removed_heterogens'], set(['HOH'])) 96 | 97 | def test_fix_nonstandard_residues(self): 98 | pdb_lines = ['SEQRES 1 A 1 DAL', 99 | 'ATOM 1 CA DAL A 1 0.000 0.000 0.000 1.00 ' 100 | ' 0.00 C'] 101 | input_handle = io.StringIO('\n'.join(pdb_lines)) 102 | alterations = {} 103 | result = cleanup.fix_pdb(input_handle, alterations) 104 | structure = 
_pdb_to_structure(result) 105 | residue_names = [res.get_name() for res in structure.iter_residues()] 106 | self.assertCountEqual(residue_names, ['ALA']) 107 | self.assertLen(alterations['nonstandard_residues'], 1) 108 | original_res, new_name = alterations['nonstandard_residues'][0] 109 | self.assertEqual(original_res.id, '1') 110 | self.assertEqual(new_name, 'ALA') 111 | 112 | def test_replace_met_se(self): 113 | pdb_lines = ['SEQRES 1 A 1 MET', 114 | 'ATOM 1 SD MET A 1 0.000 0.000 0.000 1.00 ' 115 | ' 0.00 Se'] 116 | structure = _lines_to_structure(pdb_lines) 117 | alterations = {} 118 | cleanup._replace_met_se(structure, alterations) 119 | sd = [a for a in structure.iter_atoms() if a.get_name() == 'SD'] 120 | self.assertLen(sd, 1) 121 | self.assertEqual(sd[0].element_symbol, 'S') 122 | self.assertCountEqual(alterations['Se_in_MET'], [sd[0].residue_number]) 123 | 124 | def test_remove_chains_of_length_one(self): 125 | pdb_lines = ['SEQRES 1 A 1 GLY', 126 | 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 ' 127 | ' 0.00 C'] 128 | structure = _lines_to_structure(pdb_lines) 129 | alterations = {} 130 | cleanup._remove_chains_of_length_one(structure, alterations) 131 | chains = list(structure.iter_chains()) 132 | self.assertEmpty(chains) 133 | self.assertCountEqual(alterations['removed_chains'].values(), [['A']]) 134 | 135 | 136 | if __name__ == '__main__': 137 | absltest.main() 138 | -------------------------------------------------------------------------------- /src/alphafold/relax/relax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Amber relaxation.""" 16 | from typing import Any, Dict, Sequence, Tuple 17 | from alphafold.common import protein 18 | from alphafold.relax import amber_minimize 19 | from alphafold.relax import utils 20 | import numpy as np 21 | 22 | 23 | class AmberRelaxation(object): 24 | """Amber relaxation.""" 25 | 26 | def __init__(self, 27 | *, 28 | max_iterations: int, 29 | tolerance: float, 30 | stiffness: float, 31 | exclude_residues: Sequence[int], 32 | max_outer_iterations: int, 33 | use_gpu: bool): 34 | """Initialize Amber Relaxer. 35 | 36 | Args: 37 | max_iterations: Maximum number of L-BFGS iterations. 0 means no max. 38 | tolerance: kcal/mol, the energy tolerance of L-BFGS. 39 | stiffness: kcal/mol A**2, spring constant of heavy atom restraining 40 | potential. 41 | exclude_residues: Residues to exclude from per-atom restraining. 42 | Zero-indexed. 43 | max_outer_iterations: Maximum number of violation-informed relax 44 | iterations. A value of 1 will run the non-iterative procedure used in 45 | CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes 46 | as soon as there are no violations, hence in most cases this causes no 47 | slowdown. In the worst case we do 20 outer iterations. 48 | use_gpu: Whether to run on GPU. 
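Example (hypothetical values, mirroring the test configuration in relax_test.py):

      relaxer = AmberRelaxation(
          max_iterations=1, tolerance=2.39, stiffness=10.0,
          exclude_residues=[], max_outer_iterations=1, use_gpu=False)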
49 | """ 50 | 51 | self._max_iterations = max_iterations 52 | self._tolerance = tolerance 53 | self._stiffness = stiffness 54 | self._exclude_residues = exclude_residues 55 | self._max_outer_iterations = max_outer_iterations 56 | self._use_gpu = use_gpu 57 | 58 | def process(self, *, 59 | prot: protein.Protein 60 | ) -> Tuple[str, Dict[str, Any], Sequence[float]]: 61 | """Runs Amber relax on a prediction, adds hydrogens, returns PDB string.""" 62 | out = amber_minimize.run_pipeline( 63 | prot=prot, max_iterations=self._max_iterations, 64 | tolerance=self._tolerance, stiffness=self._stiffness, 65 | exclude_residues=self._exclude_residues, 66 | max_outer_iterations=self._max_outer_iterations, 67 | use_gpu=self._use_gpu) 68 | min_pos = out['pos'] 69 | start_pos = out['posinit'] 70 | rmsd = np.sqrt(np.sum((start_pos - min_pos)**2) / start_pos.shape[0]) 71 | debug_data = { 72 | 'initial_energy': out['einit'], 73 | 'final_energy': out['efinal'], 74 | 'attempts': out['min_attempts'], 75 | 'rmsd': rmsd 76 | } 77 | min_pdb = out['min_pdb'] 78 | min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors) 79 | utils.assert_equal_nonterminal_atom_types( 80 | protein.from_pdb_string(min_pdb).atom_mask, 81 | prot.atom_mask) 82 | violations = out['structural_violations'][ 83 | 'total_per_residue_violations_mask'].tolist() 84 | return min_pdb, debug_data, violations 85 | -------------------------------------------------------------------------------- /src/alphafold/relax/relax_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for relax.""" 16 | import os 17 | 18 | from absl.testing import absltest 19 | from alphafold.common import protein 20 | from alphafold.relax import relax 21 | import numpy as np 22 | # Internal import (7716). 23 | 24 | 25 | class RunAmberRelaxTest(absltest.TestCase): 26 | 27 | def setUp(self): 28 | super().setUp() 29 | self.test_dir = os.path.join( 30 | absltest.get_default_test_srcdir(), 31 | 'alphafold/relax/testdata/') 32 | self.test_config = { 33 | 'max_iterations': 1, 34 | 'tolerance': 2.39, 35 | 'stiffness': 10.0, 36 | 'exclude_residues': [], 37 | 'max_outer_iterations': 1, 38 | 'use_gpu': False} 39 | 40 | def test_process(self): 41 | amber_relax = relax.AmberRelaxation(**self.test_config) 42 | 43 | with open(os.path.join(self.test_dir, 'model_output.pdb')) as f: 44 | test_prot = protein.from_pdb_string(f.read()) 45 | pdb_min, debug_info, num_violations = amber_relax.process(prot=test_prot) 46 | 47 | self.assertCountEqual(debug_info.keys(), 48 | set({'initial_energy', 'final_energy', 49 | 'attempts', 'rmsd'})) 50 | self.assertLess(debug_info['final_energy'], debug_info['initial_energy']) 51 | self.assertGreater(debug_info['rmsd'], 0) 52 | 53 | prot_min = protein.from_pdb_string(pdb_min) 54 | # Most protein properties should be unchanged. 
55 | np.testing.assert_almost_equal(test_prot.aatype, prot_min.aatype) 56 | np.testing.assert_almost_equal(test_prot.residue_index, 57 | prot_min.residue_index) 58 | # Atom mask and bfactors identical except for terminal OXT of last residue. 59 | np.testing.assert_almost_equal(test_prot.atom_mask[:-1, :], 60 | prot_min.atom_mask[:-1, :]) 61 | np.testing.assert_almost_equal(test_prot.b_factors[:-1, :], 62 | prot_min.b_factors[:-1, :]) 63 | np.testing.assert_almost_equal(test_prot.atom_mask[:, :-1], 64 | prot_min.atom_mask[:, :-1]) 65 | np.testing.assert_almost_equal(test_prot.b_factors[:, :-1], 66 | prot_min.b_factors[:, :-1]) 67 | # There are no residues with violations. 68 | np.testing.assert_equal(num_violations, np.zeros_like(num_violations)) 69 | 70 | def test_unresolved_violations(self): 71 | amber_relax = relax.AmberRelaxation(**self.test_config) 72 | with open(os.path.join(self.test_dir, 73 | 'with_violations_casp14.pdb')) as f: 74 | test_prot = protein.from_pdb_string(f.read()) 75 | _, _, num_violations = amber_relax.process(prot=test_prot) 76 | exp_num_violations = np.array( 77 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 78 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 79 | 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 80 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 81 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 82 | 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 83 | 0, 0, 0, 0]) 84 | # Check no violations were added. Can't check exactly due to stochasticity. 85 | self.assertTrue(np.all(np.array(num_violations) <= exp_num_violations)) 86 | 87 | 88 | if __name__ == '__main__': 89 | absltest.main() 90 | -------------------------------------------------------------------------------- /src/alphafold/relax/testdata/model_output.pdb: -------------------------------------------------------------------------------- 1 | ATOM 1 C MET A 1 1.921 -46.152 7.786 1.00 4.39 C 2 | ATOM 2 CA MET A 1 1.631 -46.829 9.131 1.00 4.39 C 3 | ATOM 3 CB MET A 1 2.759 -47.768 9.578 1.00 4.39 C 4 | ATOM 4 CE MET A 1 3.466 -49.770 13.198 1.00 4.39 C 5 | ATOM 5 CG MET A 1 2.581 -48.221 11.034 1.00 4.39 C 6 | ATOM 6 H MET A 1 0.234 -48.249 8.549 1.00 4.39 H 7 | ATOM 7 H2 MET A 1 -0.424 -46.789 8.952 1.00 4.39 H 8 | ATOM 8 H3 MET A 1 0.111 -47.796 10.118 1.00 4.39 H 9 | ATOM 9 HA MET A 1 1.628 -46.009 9.849 1.00 4.39 H 10 | ATOM 10 HB2 MET A 1 3.701 -47.225 9.500 1.00 4.39 H 11 | ATOM 11 HB3 MET A 1 2.807 -48.640 8.926 1.00 4.39 H 12 | ATOM 12 HE1 MET A 1 2.747 -50.537 12.910 1.00 4.39 H 13 | ATOM 13 HE2 MET A 1 4.296 -50.241 13.725 1.00 4.39 H 14 | ATOM 14 HE3 MET A 1 2.988 -49.052 13.864 1.00 4.39 H 15 | ATOM 15 HG2 MET A 1 1.791 -48.971 11.083 1.00 4.39 H 16 | ATOM 16 HG3 MET A 1 2.295 -47.368 11.650 1.00 4.39 H 17 | ATOM 17 N MET A 1 0.291 -47.464 9.182 1.00 4.39 N 18 | ATOM 18 O MET A 1 2.091 -44.945 7.799 1.00 4.39 O 19 | ATOM 19 SD MET A 1 4.096 -48.921 11.725 1.00 4.39 S 20 | ATOM 20 C LYS A 2 1.366 -45.033 4.898 1.00 2.92 C 21 | ATOM 21 CA LYS A 2 2.235 -46.242 5.308 1.00 2.92 C 22 | ATOM 22 CB LYS A 2 2.206 -47.314 4.196 1.00 2.92 C 23 | ATOM 23 CD LYS A 2 3.331 -49.342 3.134 1.00 2.92 C 24 | ATOM 24 CE LYS A 2 4.434 -50.403 3.293 1.00 2.92 C 25 | ATOM 25 CG LYS A 2 3.294 -48.395 4.349 1.00 2.92 C 26 | ATOM 26 H LYS A 2 1.832 -47.853 6.656 1.00 2.92 H 27 | ATOM 27 HA LYS A 2 3.248 -45.841 5.355 1.00 2.92 H 28 | ATOM 28 HB2 LYS A 2 1.223 
-47.785 4.167 1.00 2.92 H 29 | ATOM 29 HB3 LYS A 2 2.363 -46.812 3.241 1.00 2.92 H 30 | ATOM 30 HD2 LYS A 2 3.524 -48.754 2.237 1.00 2.92 H 31 | ATOM 31 HD3 LYS A 2 2.364 -49.833 3.031 1.00 2.92 H 32 | ATOM 32 HE2 LYS A 2 5.383 -49.891 3.455 1.00 2.92 H 33 | ATOM 33 HE3 LYS A 2 4.225 -51.000 4.180 1.00 2.92 H 34 | ATOM 34 HG2 LYS A 2 3.102 -48.977 5.250 1.00 2.92 H 35 | ATOM 35 HG3 LYS A 2 4.264 -47.909 4.446 1.00 2.92 H 36 | ATOM 36 HZ1 LYS A 2 4.763 -50.747 1.274 1.00 2.92 H 37 | ATOM 37 HZ2 LYS A 2 3.681 -51.785 1.931 1.00 2.92 H 38 | ATOM 38 HZ3 LYS A 2 5.280 -51.965 2.224 1.00 2.92 H 39 | ATOM 39 N LYS A 2 1.907 -46.846 6.629 1.00 2.92 N 40 | ATOM 40 NZ LYS A 2 4.542 -51.286 2.100 1.00 2.92 N 41 | ATOM 41 O LYS A 2 1.882 -44.093 4.312 1.00 2.92 O 42 | ATOM 42 C PHE A 3 -0.511 -42.597 5.624 1.00 4.39 C 43 | ATOM 43 CA PHE A 3 -0.853 -43.933 4.929 1.00 4.39 C 44 | ATOM 44 CB PHE A 3 -2.271 -44.408 5.285 1.00 4.39 C 45 | ATOM 45 CD1 PHE A 3 -3.760 -43.542 3.432 1.00 4.39 C 46 | ATOM 46 CD2 PHE A 3 -4.050 -42.638 5.675 1.00 4.39 C 47 | ATOM 47 CE1 PHE A 3 -4.797 -42.715 2.965 1.00 4.39 C 48 | ATOM 48 CE2 PHE A 3 -5.091 -41.818 5.207 1.00 4.39 C 49 | ATOM 49 CG PHE A 3 -3.382 -43.505 4.788 1.00 4.39 C 50 | ATOM 50 CZ PHE A 3 -5.463 -41.853 3.853 1.00 4.39 C 51 | ATOM 51 H PHE A 3 -0.311 -45.868 5.655 1.00 4.39 H 52 | ATOM 52 HA PHE A 3 -0.817 -43.746 3.856 1.00 4.39 H 53 | ATOM 53 HB2 PHE A 3 -2.353 -44.512 6.367 1.00 4.39 H 54 | ATOM 54 HB3 PHE A 3 -2.432 -45.393 4.848 1.00 4.39 H 55 | ATOM 55 HD1 PHE A 3 -3.255 -44.198 2.739 1.00 4.39 H 56 | ATOM 56 HD2 PHE A 3 -3.768 -42.590 6.716 1.00 4.39 H 57 | ATOM 57 HE1 PHE A 3 -5.083 -42.735 1.923 1.00 4.39 H 58 | ATOM 58 HE2 PHE A 3 -5.604 -41.151 5.885 1.00 4.39 H 59 | ATOM 59 HZ PHE A 3 -6.257 -41.215 3.493 1.00 4.39 H 60 | ATOM 60 N PHE A 3 0.079 -45.027 5.253 1.00 4.39 N 61 | ATOM 61 O PHE A 3 -0.633 -41.541 5.014 1.00 4.39 O 62 | ATOM 62 C LEU A 4 1.598 -40.732 7.042 1.00 4.39 C 63 | ATOM 63 CA LEU A 4 0.367 -41.437 7.633 1.00 4.39 C 64 | ATOM 64 CB LEU A 4 0.628 -41.823 9.104 1.00 4.39 C 65 | ATOM 65 CD1 LEU A 4 -0.319 -42.778 11.228 1.00 4.39 C 66 | ATOM 66 CD2 LEU A 4 -1.300 -40.694 10.309 1.00 4.39 C 67 | ATOM 67 CG LEU A 4 -0.650 -42.027 9.937 1.00 4.39 C 68 | ATOM 68 H LEU A 4 0.163 -43.538 7.292 1.00 4.39 H 69 | ATOM 69 HA LEU A 4 -0.445 -40.712 7.588 1.00 4.39 H 70 | ATOM 70 HB2 LEU A 4 1.213 -41.034 9.576 1.00 4.39 H 71 | ATOM 71 HB3 LEU A 4 1.235 -42.728 9.127 1.00 4.39 H 72 | ATOM 72 HD11 LEU A 4 0.380 -42.191 11.824 1.00 4.39 H 73 | ATOM 73 HD12 LEU A 4 0.127 -43.747 11.002 1.00 4.39 H 74 | ATOM 74 HD13 LEU A 4 -1.230 -42.927 11.808 1.00 4.39 H 75 | ATOM 75 HD21 LEU A 4 -0.606 -40.080 10.883 1.00 4.39 H 76 | ATOM 76 HD22 LEU A 4 -2.193 -40.869 10.909 1.00 4.39 H 77 | ATOM 77 HD23 LEU A 4 -1.593 -40.147 9.413 1.00 4.39 H 78 | ATOM 78 HG LEU A 4 -1.359 -42.630 9.370 1.00 4.39 H 79 | ATOM 79 N LEU A 4 -0.012 -42.638 6.869 1.00 4.39 N 80 | ATOM 80 O LEU A 4 1.655 -39.508 7.028 1.00 4.39 O 81 | ATOM 81 C VAL A 5 3.372 -40.190 4.573 1.00 4.39 C 82 | ATOM 82 CA VAL A 5 3.752 -40.956 5.845 1.00 4.39 C 83 | ATOM 83 CB VAL A 5 4.757 -42.083 5.528 1.00 4.39 C 84 | ATOM 84 CG1 VAL A 5 6.019 -41.568 4.827 1.00 4.39 C 85 | ATOM 85 CG2 VAL A 5 5.199 -42.807 6.810 1.00 4.39 C 86 | ATOM 86 H VAL A 5 2.440 -42.503 6.548 1.00 4.39 H 87 | ATOM 87 HA VAL A 5 4.234 -40.242 6.512 1.00 4.39 H 88 | ATOM 88 HB VAL A 5 4.279 -42.813 4.875 1.00 4.39 H 89 | ATOM 89 HG11 VAL A 5 6.494 -40.795 5.431 1.00 4.39 H 90 | ATOM 90 HG12 VAL A 5 5.770 -41.145 3.853 
1.00 4.39 H 91 | ATOM 91 HG13 VAL A 5 6.725 -42.383 4.670 1.00 4.39 H 92 | ATOM 92 HG21 VAL A 5 4.347 -43.283 7.297 1.00 4.39 H 93 | ATOM 93 HG22 VAL A 5 5.933 -43.575 6.568 1.00 4.39 H 94 | ATOM 94 HG23 VAL A 5 5.651 -42.093 7.498 1.00 4.39 H 95 | ATOM 95 N VAL A 5 2.554 -41.501 6.509 1.00 4.39 N 96 | ATOM 96 O VAL A 5 3.937 -39.138 4.297 1.00 4.39 O 97 | TER 96 VAL A 5 98 | END 99 | -------------------------------------------------------------------------------- /src/alphafold/relax/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utils for minimization.""" 16 | import io 17 | from alphafold.common import residue_constants 18 | from Bio import PDB 19 | import numpy as np 20 | 21 | 22 | def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str: 23 | """Overwrites the B-factors in pdb_str with contents of bfactors array. 24 | 25 | Args: 26 | pdb_str: An input PDB string. 27 | bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the 28 | B-factors are per residue; i.e. that the nonzero entries are identical in 29 | [0, i, :]. 30 | 31 | Returns: 32 | A new PDB string with the B-factors replaced. 33 | """ 34 | if bfactors.shape[-1] != residue_constants.atom_type_num: 35 | raise ValueError( 36 | f'Invalid final dimension size for bfactors: {bfactors.shape[-1]}.') 37 | 38 | parser = PDB.PDBParser(QUIET=True) 39 | handle = io.StringIO(pdb_str) 40 | structure = parser.get_structure('', handle) 41 | 42 | curr_resid = ('', '', '') 43 | idx = -1 44 | for atom in structure.get_atoms(): 45 | atom_resid = atom.parent.get_id() 46 | if atom_resid != curr_resid: 47 | idx += 1 48 | if idx >= bfactors.shape[0]: 49 | raise ValueError('Index into bfactors exceeds number of residues. ' 50 | f'B-factors shape: {bfactors.shape}, idx: {idx}.') 51 | curr_resid = atom_resid 52 | atom.bfactor = bfactors[idx, residue_constants.atom_order['CA']] 53 | 54 | new_pdb = io.StringIO() 55 | pdb_io = PDB.PDBIO() 56 | pdb_io.set_structure(structure) 57 | pdb_io.save(new_pdb) 58 | return new_pdb.getvalue() 59 | 60 | 61 | def assert_equal_nonterminal_atom_types( 62 | atom_mask: np.ndarray, ref_atom_mask: np.ndarray): 63 | """Checks that pre- and post-minimized proteins have same atom set.""" 64 | # Ignore any terminal OXT atoms which may have been added by minimization. 
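# The mask built below is True for every atom slot except the OXT column, so only terminal oxygens are excluded from the comparison.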
65 | oxt = residue_constants.atom_order['OXT'] 66 | no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool) 67 | no_oxt_mask[..., oxt] = False 68 | np.testing.assert_almost_equal(ref_atom_mask[no_oxt_mask], 69 | atom_mask[no_oxt_mask]) 70 | -------------------------------------------------------------------------------- /src/alphafold/relax/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for utils.""" 16 | 17 | import os 18 | 19 | from absl.testing import absltest 20 | from alphafold.common import protein 21 | from alphafold.relax import utils 22 | import numpy as np 23 | # Internal import (7716). 24 | 25 | 26 | class UtilsTest(absltest.TestCase): 27 | 28 | def test_overwrite_b_factors(self): 29 | testdir = os.path.join( 30 | absltest.get_default_test_srcdir(), 31 | 'alphafold/relax/testdata/' 32 | 'multiple_disulfides_target.pdb') 33 | with open(testdir) as f: 34 | test_pdb = f.read() 35 | n_residues = 191 36 | bfactors = np.stack([np.arange(0, n_residues)] * 37, axis=-1) 37 | 38 | output_pdb = utils.overwrite_b_factors(test_pdb, bfactors) 39 | 40 | # Check that the atom lines are unchanged apart from the B-factors. 41 | atom_lines_original = [l for l in test_pdb.split('\n') if l[:4] == ('ATOM')] 42 | atom_lines_new = [l for l in output_pdb.split('\n') if l[:4] == ('ATOM')] 43 | for line_original, line_new in zip(atom_lines_original, atom_lines_new): 44 | self.assertEqual(line_original[:60].strip(), line_new[:60].strip()) 45 | self.assertEqual(line_original[66:].strip(), line_new[66:].strip()) 46 | 47 | # Check B-factors are correctly set for all atoms present. 48 | as_protein = protein.from_pdb_string(output_pdb) 49 | np.testing.assert_almost_equal( 50 | np.where(as_protein.atom_mask > 0, as_protein.b_factors, 0), 51 | np.where(as_protein.atom_mask > 0, bfactors, 0)) 52 | 53 | 54 | if __name__ == '__main__': 55 | absltest.main() 56 | -------------------------------------------------------------------------------- /src/generate_msas.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #Script for generating input features with AlphaFold-multimer 3 | 4 | #Fill in all the variables below to run the feature generation. 5 | #The reduced database version is used here. 6 | 7 | #This script assumes that all python packages necessary are in the current path. 
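#Example invocation (hypothetical paths - adjust to your setup):
#bash generate_msas.sh T1123 ../data/T1123 ../data/T1123/msas /path/to/genetic_dbs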
8 | """ 9 | 10 | #Get ID 11 | ID=$1 12 | echo $ID 13 | #Generate input MSAs and templates for AFM 14 | FASTA_DIR=$2 15 | FASTA_PATHS=$FASTA_DIR/$ID'.fasta' 16 | ls $FASTA_PATHS 17 | OUTDIR=$3 18 | #Genetic search 19 | JACKHMMER=./hmmer-3.3.2/src/jackhmmer 20 | HHBLITS=./hh-suite/build/bin/hhblits 21 | HHSEARCH=./hh-suite/build/bin/hhsearch 22 | HMMSEARCH=./hmmer-3.3.2/src/hmmsearch 23 | HMMBUILD=./hmmer-3.3.2/src/hmmbuild 24 | KALIGN=./kalign-3.3.2/src/kalign 25 | #Dbs 26 | DATADIR=$10 27 | UNIREF90=$DATADIR/uniref90/uniref90.fasta 28 | MGNIFY=$DATADIR/mgnify/mgy_clusters_2022_05.fa 29 | #BFD=$DATADIR/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt 30 | SMALL_BFD=$DATADIR/small_bfd/bfd-first_non_consensus_sequences.fasta 31 | UNIREF30=$DATADIR/uniref30/UniRef30_2021_03 32 | UNIPROT=$DATADIR/uniprot/uniprot.fasta 33 | PDB70=$DATADIR/pdb70_from_mmcif_220313 34 | PDBSEQRES=$DATADIR/pdb_seqres/pdb_seqres.txt 35 | MMCIFDIR=$DATADIR/pdb_mmcif/ 36 | #Settings 37 | DB_PRESET='reduced_dbs' 38 | MODEL_PRESET='multimer' 39 | MAX_DATE='2034-01-01' #Max template date 40 | OBS_PDBS=./obsolete.txt 41 | AFDIR=./ 42 | #Run 43 | python3 $AFDIR/run_alphafold_msa_template_only.py --fasta_paths=$FASTA_PATHS \ 44 | --output_dir=$OUTDIR --jackhmmer_binary_path=$JACKHMMER \ 45 | --hhblits_binary_path=$HHBLITS --hhsearch_binary_path=$HHSEARCH \ 46 | --hmmsearch_binary_path=$HMMSEARCH --hmmbuild_binary_path=$HMMBUILD \ 47 | --kalign_binary_path=$KALIGN --uniref90_database_path=$UNIREF90 \ 48 | --mgnify_database_path=$MGNIFY --small_bfd_database_path=$SMALL_BFD \ 49 | --uniprot_database_path=$UNIPROT \ 50 | --pdb_seqres_database_path=$PDBSEQRES \ 51 | --template_mmcif_dir=$MMCIFDIR --db_preset=$DB_PRESET --model_preset=$MODEL_PRESET \ 52 | --max_template_date=$MAX_DATE --obsolete_pdbs_path=$OBS_PDBS 53 | -------------------------------------------------------------------------------- /src/obsolete.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/patrickbryant1/AFProfile/a6aaaca40c8598b70c0f8f2768edae7e1847792e/src/obsolete.txt --------------------------------------------------------------------------------