├── .gitignore
├── 6nnw.gif
├── 6nnw.svg
├── AFP.svg
├── LICENSE
├── README.md
├── T1123.svg
├── afprofile.yml
├── data
│   ├── H1141
│   │   ├── H1141.fasta
│   │   └── features.pkl
│   ├── H1144
│   │   ├── H1144.fasta
│   │   └── features.pkl
│   └── T1123
│       ├── T1123.fasta
│       └── features.pkl
├── pip_pkgs.txt
└── src
    ├── .DS_Store
    ├── AFP.sh
    ├── alphafold
    │   ├── __init__.py
    │   ├── common
    │   │   ├── __init__.py
    │   │   ├── confidence.py
    │   │   ├── protein.py
    │   │   ├── protein_test.py
    │   │   ├── residue_constants.py
    │   │   ├── residue_constants_test.py
    │   │   └── testdata
    │   │       └── 2rbg.pdb
    │   ├── data
    │   │   ├── __init__.py
    │   │   ├── feature_processing.py
    │   │   ├── mmcif_parsing.py
    │   │   ├── msa_identifiers.py
    │   │   ├── msa_pairing.py
    │   │   ├── parsers.py
    │   │   ├── pipeline.py
    │   │   ├── pipeline_multimer.py
    │   │   ├── templates.py
    │   │   └── tools
    │   │       ├── __init__.py
    │   │       ├── hhblits.py
    │   │       ├── hhsearch.py
    │   │       ├── hmmbuild.py
    │   │       ├── hmmsearch.py
    │   │       ├── jackhmmer.py
    │   │       ├── kalign.py
    │   │       └── utils.py
    │   ├── model
    │   │   ├── __init__.py
    │   │   ├── all_atom.py
    │   │   ├── all_atom_multimer.py
    │   │   ├── all_atom_test.py
    │   │   ├── common_modules.py
    │   │   ├── config.py
    │   │   ├── data.py
    │   │   ├── features.py
    │   │   ├── folding.py
    │   │   ├── folding_multimer.py
    │   │   ├── geometry
    │   │   │   ├── __init__.py
    │   │   │   ├── rigid_matrix_vector.py
    │   │   │   ├── rotation_matrix.py
    │   │   │   ├── struct_of_array.py
    │   │   │   ├── test_utils.py
    │   │   │   ├── utils.py
    │   │   │   └── vector.py
    │   │   ├── layer_stack.py
    │   │   ├── layer_stack_test.py
    │   │   ├── lddt.py
    │   │   ├── lddt_test.py
    │   │   ├── mapping.py
    │   │   ├── model.py
    │   │   ├── modules.py
    │   │   ├── modules_multimer.py
    │   │   ├── prng.py
    │   │   ├── prng_test.py
    │   │   ├── quat_affine.py
    │   │   ├── quat_affine_test.py
    │   │   ├── r3.py
    │   │   ├── tf
    │   │   │   ├── __init__.py
    │   │   │   ├── data_transforms.py
    │   │   │   ├── input_pipeline.py
    │   │   │   ├── protein_features.py
    │   │   │   ├── protein_features_test.py
    │   │   │   ├── proteins_dataset.py
    │   │   │   ├── shape_helpers.py
    │   │   │   ├── shape_helpers_test.py
    │   │   │   ├── shape_placeholders.py
    │   │   │   └── utils.py
    │   │   └── utils.py
    │   ├── notebooks
    │   │   ├── __init__.py
    │   │   ├── notebook_utils.py
    │   │   └── notebook_utils_test.py
    │   └── relax
    │       ├── __init__.py
    │       ├── amber_minimize.py
    │       ├── amber_minimize_test.py
    │       ├── cleanup.py
    │       ├── cleanup_test.py
    │       ├── relax.py
    │       ├── relax_test.py
    │       ├── testdata
    │       │   ├── model_output.pdb
    │       │   ├── multiple_disulfides_target.pdb
    │       │   ├── with_violations.pdb
    │       │   └── with_violations_casp14.pdb
    │       ├── utils.py
    │       └── utils_test.py
    ├── generate_msas.sh
    ├── obsolete.txt
    ├── run_AFP.py
    └── run_alphafold_msa_template_only.py
/.gitignore:
--------------------------------------------------------------------------------
1 | src/alphafold/model/tf/__pycache__
2 | src/alphafold/model/geometry/__pycache__
3 | src/alphafold/model/__pycache__
4 | src/alphafold/data/tools/__pycache__
5 | src/alphafold/data/__pycache__
6 | src/alphafold/common/__pycache__
7 | src/alphafold/__pycache__/
8 | data/params
9 | data/T1123/*.pdb
10 | data/T1123/metrics_0.0001_20.csv
11 | .DS_Store
12 |
--------------------------------------------------------------------------------
/6nnw.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/AFProfile/a6aaaca40c8598b70c0f8f2768edae7e1847792e/6nnw.gif
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AFProfile
2 | Improved protein complex prediction with AlphaFold-multimer by denoising the MSA profile.
3 | \
4 | \
5 | AFProfile learns a bias to the MSA representation that **improves the predictions** by performing **gradient descent through the AlphaFold-multimer network**. \
6 | We effectively denoise the MSA profile, similar to how a blurry image is sharpened to become clearer. \
7 | This proves to be a highly efficient process, resulting in a 60-fold speedup over AFsample while remaining as efficient as AFM v2.3. \
8 | Read more about it in the paper [here](https://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1012253).
9 |
10 | \
11 |
12 | \
13 | \
14 | AlphaFold2 (including AlphaFold-multimer) is available under the [Apache License, Version 2.0](http://www.apache.org/licenses/LICENSE-2.0) and so is AFProfile, which is a derivative thereof. \
15 | The AlphaFold2 parameters are made available under the terms of the [CC BY 4.0 license](https://creativecommons.org/licenses/by/4.0/legalcode) and have not been modified.
16 | \
17 | **You may not use these files except in compliance with the licenses.**
18 |
19 | ## Optimisation for 6nnw
20 | - Here is an example trajectory for PDBID 6nnw sorted by confidence.
21 |
22 |
23 |
24 | - The final prediction has an MM-score of 0.96 compared to 0.44 using AF-multimer. The [native structure](https://www.rcsb.org/structure/6NNW) is in grey.
25 |
26 |
27 |
28 | The confidence used to denoise the MSA is defined as: \
29 | Confidence = 0.8 * ipTM + 0.2 * pTM \
30 | where ipTM is the predicted [TM-score](https://zhanggroup.org/TM-score/) of the interface and pTM that of the entire complex.
31 |
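For intuition, here is a minimal, self-contained JAX sketch of the optimisation loop (all names are hypothetical stand-ins; the real forward pass is the AlphaFold-multimer network and the real optimiser is Adam, see `src/AFP.sh` and `src/run_AFP.py`):

```
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the frozen AlphaFold-multimer forward pass,
# which in AFProfile maps the biased MSA representation to the
# confidence 0.8 * ipTM + 0.2 * pTM.
def predict_confidence(biased_feats):
    return jnp.tanh(jnp.mean(biased_feats))

def neg_confidence(bias, feats):
    # Maximising the confidence is minimising its negative.
    return -predict_confidence(feats + bias)

feats = jnp.ones((8, 16))      # stand-in for the MSA representation
bias = jnp.zeros_like(feats)   # the only trainable quantity
lr = 0.0001                    # the default learning rate in src/AFP.sh

grad_fn = jax.grad(neg_confidence)
for _ in range(10):
    # Plain gradient descent here; run_AFP.py uses the Adam optimizer.
    bias = bias - lr * grad_fn(bias, feats)
```

The network weights stay frozen throughout; only the bias added to the MSA representation is updated.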
32 | # Setup
33 |
34 | ## Clone this repository
35 | ```
36 | git clone https://github.com/patrickbryant1/AFProfile.git
37 | ```
38 |
39 | ## Get the AlphaFold-multimer parameters
40 | ```
41 | cd AFProfile
42 | mkdir data/params
43 | wget https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar
44 | tar -xvf alphafold_params_2022-03-02.tar
45 | mv params_model_1_multimer_v2.npz data/params/
46 | rm *.npz
47 | rm alphafold_params_2022-03-02.tar
48 | ```
49 |
50 | ## Install the AlphaFold requirements
51 |
52 | Install all packages into a conda environment (requires https://docs.conda.io/en/latest/miniconda.html)
53 | ```
54 | conda env create -f afprofile.yml
55 | wait
56 | conda activate afprofile
57 | pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
58 | pip install --upgrade numpy
59 | ```
60 | If the conda environment doesn't work for you, see "pip_pkgs.txt" for the pip packages to install manually.
61 |
62 | ## Try the test case
63 | Now that you have installed the required packages, you can run a test case on the CASP15 target T1123o
64 | \
65 |
66 | ```
67 | cd src
68 | bash AFP.sh
69 | ```
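If you want to sanity-check the precomputed input features first, the pickle can be inspected directly (a minimal sketch using only the standard library; the exact keys depend on the AlphaFold feature pipeline):

```
import pickle

# Features shipped with this repository's test case.
with open('../data/T1123/features.pkl', 'rb') as f:
    features = pickle.load(f)

# List the feature names and the shape of each array.
for name, value in features.items():
    print(name, getattr(value, 'shape', type(value)))
```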
70 |
71 | ## Install the genetic search programs
72 | - We install the genetic search programs from source, which makes the searches faster.
73 |
74 | *hh-suite*
75 | ```
76 | cd src
77 | mkdir hh-suite
78 | cd hh-suite
79 | wget https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz
80 | tar xvfz hhsuite-3.3.0-SSE2-Linux.tar.gz
81 | cd ..
82 | ```
83 |
84 | *hmmer*
85 | ```
86 | cd src
87 | wget http://eddylab.org/software/hmmer/hmmer.tar.gz
88 | tar -xvzf hmmer.tar.gz
89 | rm hmmer.tar.gz
90 | cd hmmer-*
91 | ./configure
92 | make
93 | cd ..
94 | ```
95 |
96 | *kalign*
97 | ```
98 | wget https://github.com/TimoLassmann/kalign/archive/refs/tags/v3.3.2.tar.gz
99 | tar -zxvf v3.3.2.tar.gz
100 | rm v3.3.2.tar.gz
101 | cd kalign-3.3.2/
102 | ./autogen.sh
103 | bash configure
104 | make
105 | make check
106 | make install
107 | cd ..
108 | ```
109 |
110 |
111 | ## Download all databases for AlphaFold
112 | - If you have already installed AlphaFold, you don't need to do this. You can simply
113 | provide the paths to the existing databases in the runscript.
114 |
115 | *Small BFD: 17 GB*
116 | ```
117 | wget https://storage.googleapis.com/alphafold-databases/reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz
118 | gunzip bfd-first_non_consensus_sequences.fasta.gz
119 | mkdir data/small_bfd
120 | mv bfd-first_non_consensus_sequences.fasta data/small_bfd
121 | # (no rm needed - gunzip already removed the .gz archive)
122 | ```
123 |
124 | *UNIREF90: 67 GB*
125 | ```
126 | wget https://ftp.ebi.ac.uk/pub/databases/uniprot/uniref/uniref90/uniref90.fasta.gz
127 | gunzip uniref90.fasta.gz
128 | mkdir data/uniref90
129 | mv uniref90.fasta data/uniref90/
130 | # (the .gz was removed by gunzip)
131 | ```
132 |
133 | *UNIPROT: 105 GB*
134 | ```
135 | wget https://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_trembl.fasta.gz
136 | wget https://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/knowledgebase/complete/uniprot_sprot.fasta.gz
137 | gunzip uniprot_trembl.fasta.gz
138 | gunzip uniprot_sprot.fasta.gz
139 | mkdir data/uniprot
140 | cat uniprot_sprot.fasta >> uniprot_trembl.fasta
141 | mv uniprot_trembl.fasta data/uniprot/uniprot.fasta
142 | # (the .gz files were removed by gunzip)
143 | rm uniprot_sprot.fasta
144 | ```
145 |
146 | - The following template databases are not used for the predictions themselves, but are needed for the feature processing.
147 |
148 | *PDB SEQRES: 0.2 GB*
149 | ```
150 | wget https://files.rcsb.org/pub/pdb/derived_data/pdb_seqres.txt.gz
151 | gunzip pdb_seqres.txt.gz
152 | mkdir data/pdb_seqres
153 | mv pdb_seqres.txt data/pdb_seqres/
154 | ```
155 |
156 | *MGNIFY: 120 GB*
157 | ```
158 | wget https://storage.googleapis.com/alphafold-databases/v2.3/mgy_clusters_2022_05.fa.gz
159 | gunzip mgy_clusters_2022_05.fa.gz
160 | mkdir data/mgnify
161 | mv mgy_clusters_2022_05.fa data/mgnify/
162 | # (the .gz was removed by gunzip)
163 | ```
164 |
165 | *MMCIF: 238 GB*
166 | - This may take a while...
167 | ```
168 | mkdir -p data/pdb_mmcif/raw
169 | mkdir data/pdb_mmcif/mmcif_files
170 | rsync --recursive --links --perms --times --compress --info=progress2 --delete --port=33444 rsync.rcsb.org::ftp_data/structures/divided/mmCIF/ data/pdb_mmcif/raw
171 |
172 | find data/pdb_mmcif/raw -type f -iname "*.gz" -exec gunzip {} +
173 | # Flatten the per-PDB-code subdirectories into a single directory.
174 | for subdir in data/pdb_mmcif/raw/*
175 | do
176 |   mv "${subdir}/"*.cif data/pdb_mmcif/mmcif_files/
177 | done
178 | find data/pdb_mmcif/raw -type d -empty -delete
179 | ```
180 |
181 | # Citation
182 | Bryant P, Noé F. Improved protein complex prediction with AlphaFold-multimer by denoising the MSA profile. PLoS Comput Biol. 2024;20: e1012253.
183 |
--------------------------------------------------------------------------------
/data/H1141/H1141.fasta:
--------------------------------------------------------------------------------
1 | >H1141,subunit1|_0
2 | GLEKDFLPLYFGWFLTKKSSETLRKAGQVFLEELGNHKAFKKELRHFISGDEPKEKLELVSYFGKRPPGVLHCTTKFCDYKAAGAEEYAQQEVVKRSYGKAFKLSISALFVTPKTAGAQVVLTDQELQLWPSDLDKPSASEGLPPGSRAHVTLGCAADVQPVQTGLDLLDILQQVKGGSQGEAVGELPRGKLYSLGKGRWMLSLTKKMEVKAIFTGYYG
3 | >H1141,subunit2|_0
4 | EVQLEESGGGWVHPGGSLRLSCAASGNVFGVNTMAWYRQAPGKQREQRELVASITDYGTTEYADSVKGRFTISGDNAKATVYLQMNSLKPEDTAVYYCNMDLTVMTATSSLYAYDYWGQGTQVTVSS
5 |
--------------------------------------------------------------------------------
/data/H1141/features.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/AFProfile/a6aaaca40c8598b70c0f8f2768edae7e1847792e/data/H1141/features.pkl
--------------------------------------------------------------------------------
/data/H1144/H1144.fasta:
--------------------------------------------------------------------------------
1 | >H1144,subunit1|_0
2 | GLEKDFLPLYFGWFLTKKSSETLRKAGQVFLEELGNHKAFKKELRHFISGDEPKEKLELVSYFGKRPPGVLHCTTKFCDYKAAGAEEYAQQEVVKRSYGKAFKLSISALFVTPKTAGAQVVLTDQELQLWPSDLDKPSASEGLPPGSRAHVTLGCAADVQPVQTGLDLLDILQQVKGGSQGEAVGELPRGKLYSLGKGRWMLSLTKKMEVKAIFTGYYG
3 | >H1144,subunit2|_0
4 | EVQLEESGGGLVQPGGSLRLSCAASGFTFSSYVMSWVRQAPGKGLEWVSDINSGGSRTYYTDSVKGRFTISRDNAKNTLYLQMNSLKPEDTAVYYCARDSLLSTRYLHTSERGQGTQVTVSS
5 |
--------------------------------------------------------------------------------
/data/H1144/features.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/AFProfile/a6aaaca40c8598b70c0f8f2768edae7e1847792e/data/H1144/features.pkl
--------------------------------------------------------------------------------
/data/T1123/T1123.fasta:
--------------------------------------------------------------------------------
1 | > T1123_0
2 | MHHHHHHHHHHSETTYTGPRSIVTPETPIGPSSYPMTPSSLVLMAGYFSGPEISDNFGKYMPLLFQQNTSKVTFRSGSHTIKIVSMVLVDRLMWLDKHFNQYTNEPDGVFGDVGNVFVDNDNVAKVITMSGSSAPANRGATLMLCRATKNIQTFNFAATVYIPAYKVKDGAGGKDVVLNVAQWEANKTLTYPAIPKDTYFMVVTMGGASFTIQRYVVYNEGIGDGLELPAFWGKYLSQLYGFSWSSPTYACVTWEPIYAEEGIPHR
3 | > T1123_1
4 | MHHHHHHHHHHSETTYTGPRSIVTPETPIGPSSYPMTPSSLVLMAGYFSGPEISDNFGKYMPLLFQQNTSKVTFRSGSHTIKIVSMVLVDRLMWLDKHFNQYTNEPDGVFGDVGNVFVDNDNVAKVITMSGSSAPANRGATLMLCRATKNIQTFNFAATVYIPAYKVKDGAGGKDVVLNVAQWEANKTLTYPAIPKDTYFMVVTMGGASFTIQRYVVYNEGIGDGLELPAFWGKYLSQLYGFSWSSPTYACVTWEPIYAEEGIPHR
5 |
--------------------------------------------------------------------------------
/data/T1123/features.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/AFProfile/a6aaaca40c8598b70c0f8f2768edae7e1847792e/data/T1123/features.pkl
--------------------------------------------------------------------------------
/pip_pkgs.txt:
--------------------------------------------------------------------------------
1 | pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2 | pip install ml-collections==0.1.1
3 | pip install dm-haiku==0.0.11
4 | pip install pandas==1.3.5
5 | pip install biopython==1.81
6 | pip install chex==0.1.5
7 | pip install dm-tree==0.1.8
8 | pip install immutabledict==2.0.0
9 | pip install scipy
10 | pip install tensorflow
11 | pip install numpy==1.21.6
12 |
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/AFProfile/a6aaaca40c8598b70c0f8f2768edae7e1847792e/src/.DS_Store
--------------------------------------------------------------------------------
/src/AFP.sh:
--------------------------------------------------------------------------------
1 |
2 | # Script for predicting with AlphaFold-multimer using directed sampling (i.e. dropout at inference, the procedure implemented by Wallner in CASP15: https://www.biorxiv.org/content/10.1101/2022.12.20.521205v3)
3 | # combined with a denoising process over the MSA profile. The MSA is denoised by gradient descent through AlphaFold-multimer.
4 | # Fill in all the variables below to run the predictions.
5 | # This script assumes that all necessary python packages are in the current path.
6 | # The fasta conventions are the same as for AlphaFold-multimer. See the example files in ../data/
7 |
8 | #Get ID
9 | ID=T1123
10 | FASTA_PATHS=../data/T1123/T1123.fasta
11 | PARAMDIR=../data/ #If v2 is used: Change the _v3 to _v2 in the multimer MODEL_PRESETS in config.py
12 | OUTDIR=../data/
13 | AFDIR=./
14 |
15 | #1. Get MSAs: run generate_msas.sh, which runs run_alphafold_msa_template_only.py - this also produces the features (saved as a pickle).
16 | #For this test case, the features have already been generated and are available here: ../data/T1123/features.pkl
17 |
18 | #2. Learn residuals to improve the confidence: run_AFP.py
19 | #Run AFM
20 | MODEL_PRESET='multimer'
21 | NUM_RECYCLES=20 #Number of recycles
22 | CONFIDENCE_T=0.95 #At what confidence to stop the search
23 | MAX_ITER=500 #Max number of iterations
24 | LR=0.0001 #Learning rate for ADAM optimizer
25 |
26 | #Run
27 | python3 $AFDIR/run_AFP.py --fasta_paths=$FASTA_PATHS \
28 | --data_dir=$PARAMDIR --model_preset=$MODEL_PRESET \
29 | --num_recycles=$NUM_RECYCLES \
30 | --confidence_threshold=$CONFIDENCE_T \
31 | --max_iter=$MAX_ITER \
32 | --learning_rate=$LR \
33 | --output_dir=$OUTDIR \
34 | --feature_dir=$OUTDIR/$ID/
35 |
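36 | #Judging by the ignore patterns in the repository's .gitignore
37 | #(data/T1123/*.pdb and data/T1123/metrics_0.0001_20.csv), the run writes
38 | #predicted structures and a metrics CSV (named after $LR and $NUM_RECYCLES) to $OUTDIR/$ID/.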
--------------------------------------------------------------------------------
/src/alphafold/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """An implementation of the inference pipeline of AlphaFold v2.0."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/common/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Common data types and constants used within Alphafold."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/common/confidence.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functions for processing confidence metrics."""
16 |
17 | from typing import Dict, Optional, Tuple
18 | import numpy as np
19 | import jax.numpy as jnp
20 | import jax
21 | import scipy.special
22 |
23 |
24 | def compute_plddt(logits: np.ndarray) -> np.ndarray:
25 | """Computes per-residue pLDDT from logits.
26 |
27 | Args:
28 | logits: [num_res, num_bins] output from the PredictedLDDTHead.
29 |
30 | Returns:
31 | plddt: [num_res] per-residue pLDDT.
32 | """
33 | num_bins = logits.shape[-1]
34 | bin_width = 1.0 / num_bins
35 | bin_centers = jnp.arange(start=0.5 * bin_width, stop=1.0, step=bin_width)
36 | probs = jax.nn.softmax(logits, axis=-1)
37 | predicted_lddt_ca = jnp.sum(probs * bin_centers[None, :], axis=-1)
38 | return predicted_lddt_ca * 100
39 |
40 |
41 | def _calculate_bin_centers(breaks: np.ndarray):
42 | """Gets the bin centers from the bin edges.
43 |
44 | Args:
45 | breaks: [num_bins - 1] the error bin edges.
46 |
47 | Returns:
48 | bin_centers: [num_bins] the error bin centers.
49 | """
50 | step = (breaks[1] - breaks[0])
51 |
52 | # Add half-step to get the center
53 | bin_centers = breaks + step / 2
54 | # Add a catch-all bin at the end.
55 | bin_centers = jnp.concatenate([bin_centers, jnp.expand_dims(jnp.array(bin_centers[-1] + step),axis=0)],axis=0)
56 | return bin_centers
57 |
58 |
59 | def _calculate_expected_aligned_error(
60 | alignment_confidence_breaks: np.ndarray,
61 | aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
62 | """Calculates expected aligned distance errors for every pair of residues.
63 |
64 | Args:
65 | alignment_confidence_breaks: [num_bins - 1] the error bin edges.
66 | aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted
67 | probs for each error bin, for each pair of residues.
68 |
69 | Returns:
70 | predicted_aligned_error: [num_res, num_res] the expected aligned distance
71 | error for each pair of residues.
72 | max_predicted_aligned_error: The maximum predicted error possible.
73 | """
74 | bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
75 |
76 | # Tuple of expected aligned distance error and max possible error.
77 | return (jnp.sum(aligned_distance_error_probs * bin_centers, axis=-1),
78 | jnp.asarray(bin_centers[-1]))
79 |
80 |
81 | def compute_predicted_aligned_error(
82 | logits: np.ndarray,
83 | breaks: np.ndarray) -> Dict[str, np.ndarray]:
84 | """Computes aligned confidence metrics from logits.
85 |
86 | Args:
87 | logits: [num_res, num_res, num_bins] the logits output from
88 | PredictedAlignedErrorHead.
89 | breaks: [num_bins - 1] the error bin edges.
90 |
91 | Returns:
92 | aligned_confidence_probs: [num_res, num_res, num_bins] the predicted
93 | aligned error probabilities over bins for each residue pair.
94 | predicted_aligned_error: [num_res, num_res] the expected aligned distance
95 | error for each pair of residues.
96 | max_predicted_aligned_error: The maximum predicted error possible.
97 | """
98 | aligned_confidence_probs = jax.nn.softmax(
99 | logits,
100 | axis=-1)
101 | predicted_aligned_error, max_predicted_aligned_error = (
102 | _calculate_expected_aligned_error(
103 | alignment_confidence_breaks=breaks,
104 | aligned_distance_error_probs=aligned_confidence_probs))
105 | return {
106 | 'aligned_confidence_probs': aligned_confidence_probs,
107 | 'predicted_aligned_error': predicted_aligned_error,
108 | 'max_predicted_aligned_error': max_predicted_aligned_error,
109 | }
110 |
111 |
112 | def predicted_tm_score(
113 | logits: np.ndarray,
114 | breaks: np.ndarray,
115 | residue_weights: Optional[np.ndarray] = None,
116 | asym_id: Optional[np.ndarray] = None,
117 | interface: bool = False) -> np.ndarray:
118 | """Computes predicted TM alignment or predicted interface TM alignment score.
119 |
120 | Args:
121 | logits: [num_res, num_res, num_bins] the logits output from
122 | PredictedAlignedErrorHead.
123 | breaks: [num_bins] the error bins.
124 | residue_weights: [num_res] the per residue weights to use for the
125 | expectation.
126 | asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
127 | ipTM calculation, i.e. when interface=True.
128 | interface: If True, interface predicted TM score is computed.
129 |
130 | Returns:
131 | ptm_score: The predicted TM alignment or the predicted iTM score.
132 | """
133 |
134 | # residue_weights has to be in [0, 1], but can be floating-point, i.e. the
135 | # exp. resolved head's probability.
136 | if residue_weights is None:
137 | residue_weights = np.ones(logits.shape[0])
138 |
139 | bin_centers = _calculate_bin_centers(breaks)
140 |
141 | num_res = int(jnp.sum(residue_weights))
142 | # Clip num_res to avoid negative/undefined d0.
143 | clipped_num_res = max(num_res, 19)
144 |
145 | # Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick
146 | # "Scoring function for automated assessment of protein structure template
147 | # quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
148 | d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8
149 |
150 | # Convert logits to probs.
151 | probs = jax.nn.softmax(logits, axis=-1)
152 |
153 | # TM-Score term for every bin.
154 | tm_per_bin = 1. / (1 + jnp.square(bin_centers) / jnp.square(d0))
155 | # E_distances tm(distance).
156 | predicted_tm_term = jnp.sum(probs * tm_per_bin, axis=-1)
157 |
158 | pair_mask = jnp.ones(shape=(num_res, num_res), dtype=bool)
159 | if interface:
160 | pair_mask *= asym_id[:, None] != asym_id[None, :]
161 |
162 | predicted_tm_term *= pair_mask
163 |
164 | pair_residue_weights = pair_mask * (
165 | residue_weights[None, :] * residue_weights[:, None])
166 | normed_residue_mask = pair_residue_weights / (1e-8 + jnp.sum(
167 | pair_residue_weights, axis=-1, keepdims=True))
168 | per_alignment = jnp.sum(predicted_tm_term * normed_residue_mask, axis=-1)
169 | return jnp.asarray(per_alignment[(per_alignment * residue_weights).argmax()])
170 |
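171 | # Example usage (hypothetical inputs): given PAE `logits`, bin `breaks` and a
172 | # per-residue chain id `asym_id` from the model output,
173 | #   ptm = predicted_tm_score(logits, breaks)
174 | #   iptm = predicted_tm_score(logits, breaks, asym_id=asym_id, interface=True)
175 | # give the pTM and ipTM terms of the confidence 0.8 * iptm + 0.2 * ptm
176 | # optimised by AFProfile (see README.md).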
--------------------------------------------------------------------------------
/src/alphafold/common/protein_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for protein."""
16 |
17 | import os
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | from alphafold.common import protein
22 | from alphafold.common import residue_constants
23 | import numpy as np
24 | # Internal import (7716).
25 |
26 | TEST_DATA_DIR = 'alphafold/common/testdata/'
27 |
28 |
29 | class ProteinTest(parameterized.TestCase):
30 |
31 | def _check_shapes(self, prot, num_res):
32 | """Check that the processed shapes are correct."""
33 | num_atoms = residue_constants.atom_type_num
34 | self.assertEqual((num_res, num_atoms, 3), prot.atom_positions.shape)
35 | self.assertEqual((num_res,), prot.aatype.shape)
36 | self.assertEqual((num_res, num_atoms), prot.atom_mask.shape)
37 | self.assertEqual((num_res,), prot.residue_index.shape)
38 | self.assertEqual((num_res,), prot.chain_index.shape)
39 | self.assertEqual((num_res, num_atoms), prot.b_factors.shape)
40 |
41 | @parameterized.named_parameters(
42 | dict(testcase_name='chain_A',
43 | pdb_file='2rbg.pdb', chain_id='A', num_res=282, num_chains=1),
44 | dict(testcase_name='chain_B',
45 | pdb_file='2rbg.pdb', chain_id='B', num_res=282, num_chains=1),
46 | dict(testcase_name='multichain',
47 | pdb_file='2rbg.pdb', chain_id=None, num_res=564, num_chains=2))
48 | def test_from_pdb_str(self, pdb_file, chain_id, num_res, num_chains):
49 | pdb_file = os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
50 | pdb_file)
51 | with open(pdb_file) as f:
52 | pdb_string = f.read()
53 | prot = protein.from_pdb_string(pdb_string, chain_id)
54 | self._check_shapes(prot, num_res)
55 | self.assertGreaterEqual(prot.aatype.min(), 0)
56 | # Allow equal since unknown restypes have index equal to restype_num.
57 | self.assertLessEqual(prot.aatype.max(), residue_constants.restype_num)
58 | self.assertLen(np.unique(prot.chain_index), num_chains)
59 |
60 | def test_to_pdb(self):
61 | with open(
62 | os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
63 | '2rbg.pdb')) as f:
64 | pdb_string = f.read()
65 | prot = protein.from_pdb_string(pdb_string)
66 | pdb_string_reconstr = protein.to_pdb(prot)
67 |
68 | for line in pdb_string_reconstr.splitlines():
69 | self.assertLen(line, 80)
70 |
71 | prot_reconstr = protein.from_pdb_string(pdb_string_reconstr)
72 |
73 | np.testing.assert_array_equal(prot_reconstr.aatype, prot.aatype)
74 | np.testing.assert_array_almost_equal(
75 | prot_reconstr.atom_positions, prot.atom_positions)
76 | np.testing.assert_array_almost_equal(
77 | prot_reconstr.atom_mask, prot.atom_mask)
78 | np.testing.assert_array_equal(
79 | prot_reconstr.residue_index, prot.residue_index)
80 | np.testing.assert_array_equal(
81 | prot_reconstr.chain_index, prot.chain_index)
82 | np.testing.assert_array_almost_equal(
83 | prot_reconstr.b_factors, prot.b_factors)
84 |
85 | def test_ideal_atom_mask(self):
86 | with open(
87 | os.path.join(absltest.get_default_test_srcdir(), TEST_DATA_DIR,
88 | '2rbg.pdb')) as f:
89 | pdb_string = f.read()
90 | prot = protein.from_pdb_string(pdb_string)
91 | ideal_mask = protein.ideal_atom_mask(prot)
92 | non_ideal_residues = set([102] + list(range(127, 286)))
93 | for i, (res, atom_mask) in enumerate(
94 | zip(prot.residue_index, prot.atom_mask)):
95 | if res in non_ideal_residues:
96 | self.assertFalse(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')
97 | else:
98 | self.assertTrue(np.all(atom_mask == ideal_mask[i]), msg=f'{res}')
99 |
100 | def test_too_many_chains(self):
101 | num_res = protein.PDB_MAX_CHAINS + 1
102 | num_atom_type = residue_constants.atom_type_num
103 | with self.assertRaises(ValueError):
104 | _ = protein.Protein(
105 | atom_positions=np.random.random([num_res, num_atom_type, 3]),
106 | aatype=np.random.randint(0, 21, [num_res]),
107 | atom_mask=np.random.randint(0, 2, [num_res]).astype(np.float32),
108 | residue_index=np.arange(1, num_res+1),
109 | chain_index=np.arange(num_res),
110 | b_factors=np.random.uniform(1, 100, [num_res]))
111 |
112 |
113 | if __name__ == '__main__':
114 | absltest.main()
115 |
--------------------------------------------------------------------------------
/src/alphafold/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Data pipeline for model features."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/data/msa_identifiers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for extracting identifiers from MSA sequence descriptions."""
16 |
17 | import dataclasses
18 | import re
19 | from typing import Optional
20 |
21 |
22 | # Sequences coming from UniProtKB database come in the
23 | # `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
24 | # or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
25 | _UNIPROT_PATTERN = re.compile(
26 | r"""
27 | ^
28 | # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
29 | (?:tr|sp)
30 | \|
31 | # A primary accession number of the UniProtKB entry.
32 | (?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
33 | # Occasionally there is a _0 or _1 isoform suffix, which we ignore.
34 | (?:_\d)?
35 | \|
36 | # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
37 | # protein ID code.
38 | (?:[A-Za-z0-9]+)
39 | _
40 | # A mnemonic species identification code.
41 | (?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
42 | # Small BFD uses a final value after an underscore, which we ignore.
43 | (?:_\d+)?
44 | $
45 | """,
46 | re.VERBOSE)
47 |
48 |
49 | @dataclasses.dataclass(frozen=True)
50 | class Identifiers:
51 | species_id: str = ''
52 |
53 |
54 | def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
55 | """Gets species from an msa sequence identifier.
56 |
57 | The sequence identifier has the format specified by
58 | _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
59 | An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`
60 |
61 | Args:
62 | msa_sequence_identifier: a sequence identifier.
63 |
64 | Returns:
65 | An `Identifiers` instance with species_id. These
66 | can be empty in the case where no identifier was found.
67 | """
68 | matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
69 | if matches:
70 | return Identifiers(
71 | species_id=matches.group('SpeciesIdentifier'))
72 | return Identifiers()
73 |
74 |
75 | def _extract_sequence_identifier(description: str) -> Optional[str]:
76 | """Extracts sequence identifier from description. Returns None if no match."""
77 | split_description = description.split()
78 | if split_description:
79 | return split_description[0].partition('/')[0]
80 | else:
81 | return None
82 |
83 |
84 | def get_identifiers(description: str) -> Identifiers:
85 | """Computes extra MSA features from the description."""
86 | sequence_identifier = _extract_sequence_identifier(description)
87 | if sequence_identifier is None:
88 | return Identifiers()
89 | else:
90 | return _parse_sequence_identifier(sequence_identifier)
91 |
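92 | # Example: get_identifiers('tr|A0A146SKV9|A0A146SKV9_FUNHE') returns
93 | # Identifiers(species_id='FUNHE').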
--------------------------------------------------------------------------------
/src/alphafold/data/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Python wrappers for third party tools."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/data/tools/hhblits.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library to run HHblits from Python."""
16 |
17 | import glob
18 | import os
19 | import subprocess
20 | from typing import Any, List, Mapping, Optional, Sequence
21 |
22 | from absl import logging
23 | from alphafold.data.tools import utils
24 | # Internal import (7716).
25 |
26 |
27 | _HHBLITS_DEFAULT_P = 20
28 | _HHBLITS_DEFAULT_Z = 500
29 |
30 |
31 | class HHBlits:
32 | """Python wrapper of the HHblits binary."""
33 |
34 | def __init__(self,
35 | *,
36 | binary_path: str,
37 | databases: Sequence[str],
38 | n_cpu: int = 4,
39 | n_iter: int = 3,
40 | e_value: float = 0.001,
41 | maxseq: int = 1_000_000,
42 | realign_max: int = 100_000,
43 | maxfilt: int = 100_000,
44 | min_prefilter_hits: int = 1000,
45 | all_seqs: bool = False,
46 | alt: Optional[int] = None,
47 | p: int = _HHBLITS_DEFAULT_P,
48 | z: int = _HHBLITS_DEFAULT_Z):
49 | """Initializes the Python HHblits wrapper.
50 |
51 | Args:
52 | binary_path: The path to the HHblits executable.
53 | databases: A sequence of HHblits database paths. This should be the
54 | common prefix for the database files (i.e. up to but not including
55 | _hhm.ffindex etc.)
56 | n_cpu: The number of CPUs to give HHblits.
57 | n_iter: The number of HHblits iterations.
58 | e_value: The E-value, see HHblits docs for more details.
59 | maxseq: The maximum number of rows in an input alignment. Note that this
60 | parameter is only supported in HHBlits version 3.1 and higher.
61 | realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
62 | maxfilt: Max number of hits allowed to pass the 2nd prefilter.
63 | HHblits default: 20000.
64 | min_prefilter_hits: Min number of hits to pass prefilter.
65 | HHblits default: 100.
66 | all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
67 | HHblits default: False.
68 | alt: Show up to this many alternative alignments.
69 | p: Minimum Prob for a hit to be included in the output hhr file.
70 | HHblits default: 20.
71 | z: Hard cap on number of hits reported in the hhr file.
72 | HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.
73 |
74 | Raises:
75 | RuntimeError: If HHblits binary not found within the path.
76 | """
77 | self.binary_path = binary_path
78 | self.databases = databases
79 |
80 | for database_path in self.databases:
81 | if not glob.glob(database_path + '_*'):
82 | logging.error('Could not find HHBlits database %s', database_path)
83 | raise ValueError(f'Could not find HHBlits database {database_path}')
84 |
85 | self.n_cpu = n_cpu
86 | self.n_iter = n_iter
87 | self.e_value = e_value
88 | self.maxseq = maxseq
89 | self.realign_max = realign_max
90 | self.maxfilt = maxfilt
91 | self.min_prefilter_hits = min_prefilter_hits
92 | self.all_seqs = all_seqs
93 | self.alt = alt
94 | self.p = p
95 | self.z = z
96 |
97 | def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
98 | """Queries the database using HHblits."""
99 | with utils.tmpdir_manager() as query_tmp_dir:
100 | a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
101 |
102 | db_cmd = []
103 | for db_path in self.databases:
104 | db_cmd.append('-d')
105 | db_cmd.append(db_path)
106 | cmd = [
107 | self.binary_path,
108 | '-i', input_fasta_path,
109 | '-cpu', str(self.n_cpu),
110 | '-oa3m', a3m_path,
111 | '-o', '/dev/null',
112 | '-n', str(self.n_iter),
113 | '-e', str(self.e_value),
114 | '-maxseq', str(self.maxseq),
115 | '-realign_max', str(self.realign_max),
116 | '-maxfilt', str(self.maxfilt),
117 | '-min_prefilter_hits', str(self.min_prefilter_hits)]
118 | if self.all_seqs:
119 | cmd += ['-all']
120 | if self.alt:
121 | cmd += ['-alt', str(self.alt)]
122 | if self.p != _HHBLITS_DEFAULT_P:
123 | cmd += ['-p', str(self.p)]
124 | if self.z != _HHBLITS_DEFAULT_Z:
125 | cmd += ['-Z', str(self.z)]
126 | cmd += db_cmd
127 |
128 | logging.info('Launching subprocess "%s"', ' '.join(cmd))
129 | process = subprocess.Popen(
130 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
131 |
132 | with utils.timing('HHblits query'):
133 | stdout, stderr = process.communicate()
134 | retcode = process.wait()
135 |
136 | if retcode:
137 | # Logs have a 15k character limit, so log HHblits error line by line.
138 | logging.error('HHblits failed. HHblits stderr begin:')
139 | for error_line in stderr.decode('utf-8').splitlines():
140 | if error_line.strip():
141 | logging.error(error_line.strip())
142 | logging.error('HHblits stderr end')
143 | raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
144 | stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))
145 |
146 | with open(a3m_path) as f:
147 | a3m = f.read()
148 |
149 | raw_output = dict(
150 | a3m=a3m,
151 | output=stdout,
152 | stderr=stderr,
153 | n_iter=self.n_iter,
154 | e_value=self.e_value)
155 | return [raw_output]
156 |
--------------------------------------------------------------------------------
/src/alphafold/data/tools/hhsearch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library to run HHsearch from Python."""
16 |
17 | import glob
18 | import os
19 | import subprocess
20 | from typing import Sequence
21 |
22 | from absl import logging
23 |
24 | from alphafold.data import parsers
25 | from alphafold.data.tools import utils
26 | # Internal import (7716).
27 |
28 |
29 | class HHSearch:
30 | """Python wrapper of the HHsearch binary."""
31 |
32 | def __init__(self,
33 | *,
34 | binary_path: str,
35 | databases: Sequence[str],
36 | maxseq: int = 1_000_000):
37 | """Initializes the Python HHsearch wrapper.
38 |
39 | Args:
40 | binary_path: The path to the HHsearch executable.
41 | databases: A sequence of HHsearch database paths. This should be the
42 | common prefix for the database files (i.e. up to but not including
43 | _hhm.ffindex etc.)
44 | maxseq: The maximum number of rows in an input alignment. Note that this
45 | parameter is only supported in HHBlits version 3.1 and higher.
46 |
47 | Raises:
48 | RuntimeError: If HHsearch binary not found within the path.
49 | """
50 | self.binary_path = binary_path
51 | self.databases = databases
52 | self.maxseq = maxseq
53 |
54 | for database_path in self.databases:
55 | if not glob.glob(database_path + '_*'):
56 | logging.error('Could not find HHsearch database %s', database_path)
57 | raise ValueError(f'Could not find HHsearch database {database_path}')
58 |
59 | @property
60 | def output_format(self) -> str:
61 | return 'hhr'
62 |
63 | @property
64 | def input_format(self) -> str:
65 | return 'a3m'
66 |
67 | def query(self, a3m: str) -> str:
68 | """Queries the database using HHsearch using a given a3m."""
69 | with utils.tmpdir_manager() as query_tmp_dir:
70 | input_path = os.path.join(query_tmp_dir, 'query.a3m')
71 | hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
72 | with open(input_path, 'w') as f:
73 | f.write(a3m)
74 |
75 | db_cmd = []
76 | for db_path in self.databases:
77 | db_cmd.append('-d')
78 | db_cmd.append(db_path)
79 | cmd = [self.binary_path,
80 | '-i', input_path,
81 | '-o', hhr_path,
82 | '-maxseq', str(self.maxseq)
83 | ] + db_cmd
84 |
85 | logging.info('Launching subprocess "%s"', ' '.join(cmd))
86 | process = subprocess.Popen(
87 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
88 | with utils.timing('HHsearch query'):
89 | stdout, stderr = process.communicate()
90 | retcode = process.wait()
91 |
92 | if retcode:
93 | # Stderr is truncated to prevent proto size errors in Beam.
94 | raise RuntimeError(
95 | 'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
96 | stdout.decode('utf-8'), stderr[:100_000].decode('utf-8')))
97 |
98 | with open(hhr_path) as f:
99 | hhr = f.read()
100 | return hhr
101 |
102 | def get_template_hits(self,
103 | output_string: str,
104 | input_sequence: str) -> Sequence[parsers.TemplateHit]:
105 | """Gets parsed template hits from the raw string output by the tool."""
106 | del input_sequence # Used by hmmsearch but not needed for hhsearch.
107 | return parsers.parse_hhr(output_string)
108 |
--------------------------------------------------------------------------------
/src/alphafold/data/tools/hmmbuild.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
16 |
17 | import os
18 | import re
19 | import subprocess
20 |
21 | from absl import logging
22 | from alphafold.data.tools import utils
23 | # Internal import (7716).
24 |
25 |
26 | class Hmmbuild(object):
27 | """Python wrapper of the hmmbuild binary."""
28 |
29 | def __init__(self,
30 | *,
31 | binary_path: str,
32 | singlemx: bool = False):
33 | """Initializes the Python hmmbuild wrapper.
34 |
35 | Args:
36 | binary_path: The path to the hmmbuild executable.
37 | singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
38 | just use a common substitution score matrix.
39 |
40 | Raises:
41 | RuntimeError: If hmmbuild binary not found within the path.
42 | """
43 | self.binary_path = binary_path
44 | self.singlemx = singlemx
45 |
46 | def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
47 | """Builds a HHM for the aligned sequences given as an A3M string.
48 |
49 | Args:
50 | sto: A string with the aligned sequences in the Stockholm format.
51 | model_construction: Whether to use reference annotation in the msa to
52 | determine consensus columns ('hand') or default ('fast').
53 |
54 | Returns:
55 | A string with the profile in the HMM format.
56 |
57 | Raises:
58 | RuntimeError: If hmmbuild fails.
59 | """
60 | return self._build_profile(sto, model_construction=model_construction)
61 |
62 | def build_profile_from_a3m(self, a3m: str) -> str:
63 | """Builds a HHM for the aligned sequences given as an A3M string.
64 |
65 | Args:
66 | a3m: A string with the aligned sequences in the A3M format.
67 |
68 | Returns:
69 | A string with the profile in the HMM format.
70 |
71 | Raises:
72 | RuntimeError: If hmmbuild fails.
73 | """
74 | lines = []
75 | for line in a3m.splitlines():
76 | if not line.startswith('>'):
77 | line = re.sub('[a-z]+', '', line) # Remove inserted residues.
78 | lines.append(line + '\n')
79 | msa = ''.join(lines)
80 | return self._build_profile(msa, model_construction='fast')
81 |
82 | def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
83 | """Builds a HMM for the aligned sequences given as an MSA string.
84 |
85 | Args:
86 | msa: A string with the aligned sequences, in A3M or STO format.
87 | model_construction: Whether to use reference annotation in the msa to
88 | determine consensus columns ('hand') or default ('fast').
89 |
90 | Returns:
91 | A string with the profile in the HMM format.
92 |
93 | Raises:
94 | RuntimeError: If hmmbuild fails.
95 | ValueError: If unspecified arguments are provided.
96 | """
97 | if model_construction not in {'hand', 'fast'}:
98 | raise ValueError(f'Invalid model_construction {model_construction} - only '
99 | 'hand and fast supported.')
100 |
101 | with utils.tmpdir_manager() as query_tmp_dir:
102 | input_query = os.path.join(query_tmp_dir, 'query.msa')
103 | output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')
104 |
105 | with open(input_query, 'w') as f:
106 | f.write(msa)
107 |
108 | cmd = [self.binary_path]
109 | # If adding flags, we have to do so before the output and input:
110 |
111 | if model_construction == 'hand':
112 | cmd.append(f'--{model_construction}')
113 | if self.singlemx:
114 | cmd.append('--singlemx')
115 | cmd.extend([
116 | '--amino',
117 | output_hmm_path,
118 | input_query,
119 | ])
120 |
121 | logging.info('Launching subprocess %s', cmd)
122 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
123 | stderr=subprocess.PIPE)
124 |
125 | with utils.timing('hmmbuild query'):
126 | stdout, stderr = process.communicate()
127 | retcode = process.wait()
128 | logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
129 | stdout.decode('utf-8'), stderr.decode('utf-8'))
130 |
131 | if retcode:
132 | raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
133 | % (stdout.decode('utf-8'), stderr.decode('utf-8')))
134 |
135 | with open(output_hmm_path, encoding='utf-8') as f:
136 | hmm = f.read()
137 |
138 | return hmm
139 |
--------------------------------------------------------------------------------
/src/alphafold/data/tools/hmmsearch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A Python wrapper for hmmsearch - search profile against a sequence db."""
16 |
17 | import os
18 | import subprocess
19 | from typing import Optional, Sequence
20 |
21 | from absl import logging
22 | from alphafold.data import parsers
23 | from alphafold.data.tools import hmmbuild
24 | from alphafold.data.tools import utils
25 | # Internal import (7716).
26 |
27 |
28 | class Hmmsearch(object):
29 | """Python wrapper of the hmmsearch binary."""
30 |
31 | def __init__(self,
32 | *,
33 | binary_path: str,
34 | hmmbuild_binary_path: str,
35 | database_path: str,
36 | flags: Optional[Sequence[str]] = None):
37 | """Initializes the Python hmmsearch wrapper.
38 |
39 | Args:
40 | binary_path: The path to the hmmsearch executable.
41 | hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
42 | an hmm from an input a3m.
43 | database_path: The path to the hmmsearch database (FASTA format).
44 | flags: List of flags to be used by hmmsearch.
45 |
46 | Raises:
47 | RuntimeError: If hmmsearch binary not found within the path.
48 | """
49 | self.binary_path = binary_path
50 | self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
51 | self.database_path = database_path
52 | if flags is None:
53 | # Default hmmsearch run settings.
54 | flags = ['--F1', '0.1',
55 | '--F2', '0.1',
56 | '--F3', '0.1',
57 | '--incE', '100',
58 | '-E', '100',
59 | '--domE', '100',
60 | '--incdomE', '100']
61 | self.flags = flags
62 |
63 | if not os.path.exists(self.database_path):
64 | logging.error('Could not find hmmsearch database %s', database_path)
65 | raise ValueError(f'Could not find hmmsearch database {database_path}')
66 |
67 | @property
68 | def output_format(self) -> str:
69 | return 'sto'
70 |
71 | @property
72 | def input_format(self) -> str:
73 | return 'sto'
74 |
75 | def query(self, msa_sto: str) -> str:
76 | """Queries the database using hmmsearch using a given stockholm msa."""
77 | hmm = self.hmmbuild_runner.build_profile_from_sto(msa_sto,
78 | model_construction='hand')
79 | return self.query_with_hmm(hmm)
80 |
81 | def query_with_hmm(self, hmm: str) -> str:
82 | """Queries the database using hmmsearch using a given hmm."""
83 | with utils.tmpdir_manager() as query_tmp_dir:
84 | hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
85 | out_path = os.path.join(query_tmp_dir, 'output.sto')
86 | with open(hmm_input_path, 'w') as f:
87 | f.write(hmm)
88 |
89 | cmd = [
90 | self.binary_path,
91 | '--noali', # Don't include the alignment in stdout.
92 | '--cpu', '8'
93 | ]
94 | # If adding flags, we have to do so before the output and input:
95 | if self.flags:
96 | cmd.extend(self.flags)
97 | cmd.extend([
98 | '-A', out_path,
99 | hmm_input_path,
100 | self.database_path,
101 | ])
102 |
103 | logging.info('Launching sub-process %s', cmd)
104 | process = subprocess.Popen(
105 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
106 | with utils.timing(
107 | f'hmmsearch ({os.path.basename(self.database_path)}) query'):
108 | stdout, stderr = process.communicate()
109 | retcode = process.wait()
110 |
111 | if retcode:
112 | raise RuntimeError(
113 | 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
114 | stdout.decode('utf-8'), stderr.decode('utf-8')))
115 |
116 | with open(out_path) as f:
117 | out_msa = f.read()
118 |
119 | return out_msa
120 |
121 | def get_template_hits(self,
122 | output_string: str,
123 | input_sequence: str) -> Sequence[parsers.TemplateHit]:
124 | """Gets parsed template hits from the raw string output by the tool."""
125 | a3m_string = parsers.convert_stockholm_to_a3m(output_string,
126 | remove_first_row_gaps=False)
127 | template_hits = parsers.parse_hmmsearch_a3m(
128 | query_sequence=input_sequence,
129 | a3m_string=a3m_string,
130 | skip_first=False)
131 | return template_hits
132 |
--------------------------------------------------------------------------------
/src/alphafold/data/tools/jackhmmer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Library to run Jackhmmer from Python."""
16 |
17 | from concurrent import futures
18 | import glob
19 | import os
20 | import subprocess
21 | from typing import Any, Callable, Mapping, Optional, Sequence
22 | from urllib import request
23 |
24 | from absl import logging
25 |
26 | from alphafold.data import parsers
27 | from alphafold.data.tools import utils
28 | # Internal import (7716).
29 |
30 |
31 | class Jackhmmer:
32 | """Python wrapper of the Jackhmmer binary."""
33 |
34 | def __init__(self,
35 | *,
36 | binary_path: str,
37 | database_path: str,
38 | n_cpu: int = 8,
39 | n_iter: int = 1,
40 | e_value: float = 0.0001,
41 | z_value: Optional[int] = None,
42 | get_tblout: bool = False,
43 | filter_f1: float = 0.0005,
44 | filter_f2: float = 0.00005,
45 | filter_f3: float = 0.0000005,
46 | incdom_e: Optional[float] = None,
47 | dom_e: Optional[float] = None,
48 | num_streamed_chunks: Optional[int] = None,
49 | streaming_callback: Optional[Callable[[int], None]] = None):
50 | """Initializes the Python Jackhmmer wrapper.
51 |
52 | Args:
53 | binary_path: The path to the jackhmmer executable.
54 | database_path: The path to the jackhmmer database (FASTA format).
55 | n_cpu: The number of CPUs to give Jackhmmer.
56 | n_iter: The number of Jackhmmer iterations.
57 | e_value: The E-value, see Jackhmmer docs for more details.
58 | z_value: The Z-value, see Jackhmmer docs for more details.
59 | get_tblout: Whether to save tblout string.
60 | filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
61 | filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
62 | filter_f3: Forward pre-filter, set to >1.0 to turn off.
63 | incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
64 | round.
65 | dom_e: Domain e-value criteria for inclusion in tblout.
66 | num_streamed_chunks: Number of database chunks to stream over.
67 | streaming_callback: Callback function run after each chunk iteration with
68 | the iteration number as argument.
69 | """
70 | self.binary_path = binary_path
71 | self.database_path = database_path
72 | self.num_streamed_chunks = num_streamed_chunks
73 |
74 | if not os.path.exists(self.database_path) and num_streamed_chunks is None:
75 | logging.error('Could not find Jackhmmer database %s', database_path)
76 | raise ValueError(f'Could not find Jackhmmer database {database_path}')
77 |
78 | self.n_cpu = n_cpu
79 | self.n_iter = n_iter
80 | self.e_value = e_value
81 | self.z_value = z_value
82 | self.filter_f1 = filter_f1
83 | self.filter_f2 = filter_f2
84 | self.filter_f3 = filter_f3
85 | self.incdom_e = incdom_e
86 | self.dom_e = dom_e
87 | self.get_tblout = get_tblout
88 | self.streaming_callback = streaming_callback
89 |
90 | def _query_chunk(self,
91 | input_fasta_path: str,
92 | database_path: str,
93 | max_sequences: Optional[int] = None) -> Mapping[str, Any]:
94 | """Queries the database chunk using Jackhmmer."""
95 | with utils.tmpdir_manager() as query_tmp_dir:
96 | sto_path = os.path.join(query_tmp_dir, 'output.sto')
97 |
98 |       # F1/F2/F3 are the expected proportions to pass each of the filtering
99 |       # stages (which get progressively more expensive); reducing these
100 |       # speeds up the pipeline at the expense of sensitivity. They are
101 |       # currently set very low to make querying MGnify run in a reasonable
102 |       # amount of time.
103 | cmd_flags = [
104 | # Don't pollute stdout with Jackhmmer output.
105 | '-o', '/dev/null',
106 | '-A', sto_path,
107 | '--noali',
108 | '--F1', str(self.filter_f1),
109 | '--F2', str(self.filter_f2),
110 | '--F3', str(self.filter_f3),
111 | '--incE', str(self.e_value),
112 | # Report only sequences with E-values <= x in per-sequence output.
113 | '-E', str(self.e_value),
114 | '--cpu', str(self.n_cpu),
115 | '-N', str(self.n_iter)
116 | ]
117 | if self.get_tblout:
118 | tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
119 | cmd_flags.extend(['--tblout', tblout_path])
120 |
121 | if self.z_value:
122 | cmd_flags.extend(['-Z', str(self.z_value)])
123 |
124 | if self.dom_e is not None:
125 | cmd_flags.extend(['--domE', str(self.dom_e)])
126 |
127 | if self.incdom_e is not None:
128 | cmd_flags.extend(['--incdomE', str(self.incdom_e)])
129 |
130 | cmd = [self.binary_path] + cmd_flags + [input_fasta_path,
131 | database_path]
132 |
133 | logging.info('Launching subprocess "%s"', ' '.join(cmd))
134 | process = subprocess.Popen(
135 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
136 | with utils.timing(
137 | f'Jackhmmer ({os.path.basename(database_path)}) query'):
138 | _, stderr = process.communicate()
139 | retcode = process.wait()
140 |
141 | if retcode:
142 | raise RuntimeError(
143 | 'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8'))
144 |
145 | # Get e-values for each target name
146 | tbl = ''
147 | if self.get_tblout:
148 | with open(tblout_path) as f:
149 | tbl = f.read()
150 |
151 | if max_sequences is None:
152 | with open(sto_path) as f:
153 | sto = f.read()
154 | else:
155 | sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)
156 |
157 | raw_output = dict(
158 | sto=sto,
159 | tbl=tbl,
160 | stderr=stderr,
161 | n_iter=self.n_iter,
162 | e_value=self.e_value)
163 |
164 | return raw_output
165 |
166 | def query(self,
167 | input_fasta_path: str,
168 | max_sequences: Optional[int] = None) -> Sequence[Mapping[str, Any]]:
169 | """Queries the database using Jackhmmer."""
170 | if self.num_streamed_chunks is None:
171 | single_chunk_result = self._query_chunk(
172 | input_fasta_path, self.database_path, max_sequences)
173 | return [single_chunk_result]
174 |
175 | db_basename = os.path.basename(self.database_path)
176 | db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
177 | db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}'
178 |
179 | # Remove existing files to prevent OOM
180 | for f in glob.glob(db_local_chunk('[0-9]*')):
181 | try:
182 | os.remove(f)
183 | except OSError:
184 |         logging.error('OSError while deleting %s', f)
185 |
186 | # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
187 | with futures.ThreadPoolExecutor(max_workers=2) as executor:
188 | chunked_output = []
189 | for i in range(1, self.num_streamed_chunks + 1):
190 | # Copy the chunk locally
191 | if i == 1:
192 | future = executor.submit(
193 | request.urlretrieve, db_remote_chunk(i), db_local_chunk(i))
194 | if i < self.num_streamed_chunks:
195 | next_future = executor.submit(
196 | request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1))
197 |
198 | # Run Jackhmmer with the chunk
199 | future.result()
200 | chunked_output.append(self._query_chunk(
201 | input_fasta_path, db_local_chunk(i), max_sequences))
202 |
203 | # Remove the local copy of the chunk
204 | os.remove(db_local_chunk(i))
205 | # Do not set next_future for the last chunk so that this works even for
206 | # databases with only 1 chunk.
207 | if i < self.num_streamed_chunks:
208 | future = next_future
209 | if self.streaming_callback:
210 | self.streaming_callback(i)
211 | return chunked_output
212 |
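
A minimal usage sketch of the wrapper above; the binary and database paths are placeholders for a local installation:

    from alphafold.data.tools import jackhmmer

    # Placeholder paths; point these at your jackhmmer binary and FASTA database.
    runner = jackhmmer.Jackhmmer(
        binary_path='/usr/bin/jackhmmer',
        database_path='/data/uniref90/uniref90.fasta',
        get_tblout=True)
    results = runner.query('query.fasta', max_sequences=10000)
    sto_msa = results[0]['sto']  # Stockholm MSA; one result dict per database chunk.

With num_streamed_chunks unset, query runs a single local search and returns a one-element list.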
--------------------------------------------------------------------------------
/src/alphafold/data/tools/kalign.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A Python wrapper for Kalign."""
16 | import os
17 | import subprocess
18 | from typing import Sequence
19 |
20 | from absl import logging
21 |
22 | from alphafold.data.tools import utils
23 | # Internal import (7716).
24 |
25 |
26 | def _to_a3m(sequences: Sequence[str]) -> str:
27 | """Converts sequences to an a3m file."""
28 | names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
29 | a3m = []
30 | for sequence, name in zip(sequences, names):
31 | a3m.append(u'>' + name + u'\n')
32 | a3m.append(sequence + u'\n')
33 | return ''.join(a3m)
34 |
35 |
36 | class Kalign:
37 | """Python wrapper of the Kalign binary."""
38 |
39 | def __init__(self, *, binary_path: str):
40 | """Initializes the Python Kalign wrapper.
41 |
42 | Args:
43 | binary_path: The path to the Kalign binary.
44 |
45 | Raises:
46 | RuntimeError: If Kalign binary not found within the path.
47 | """
48 | self.binary_path = binary_path
49 |
50 | def align(self, sequences: Sequence[str]) -> str:
51 | """Aligns the sequences and returns the alignment in A3M string.
52 |
53 | Args:
54 | sequences: A list of query sequence strings. The sequences have to be at
55 | least 6 residues long (Kalign requires this). Note that the order in
56 |       which you give the sequences might alter the output slightly as a
57 |       different alignment tree might get constructed.
58 |
59 | Returns:
60 | A string with the alignment in a3m format.
61 |
62 | Raises:
63 | RuntimeError: If Kalign fails.
64 | ValueError: If any of the sequences is less than 6 residues long.
65 | """
66 | logging.info('Aligning %d sequences', len(sequences))
67 |
68 | for s in sequences:
69 | if len(s) < 6:
70 | raise ValueError('Kalign requires all sequences to be at least 6 '
71 | 'residues long. Got %s (%d residues).' % (s, len(s)))
72 |
73 | with utils.tmpdir_manager() as query_tmp_dir:
74 | input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
75 | output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
76 |
77 | with open(input_fasta_path, 'w') as f:
78 | f.write(_to_a3m(sequences))
79 |
80 | cmd = [
81 | self.binary_path,
82 | '-i', input_fasta_path,
83 | '-o', output_a3m_path,
84 | '-format', 'fasta',
85 | ]
86 |
87 | logging.info('Launching subprocess "%s"', ' '.join(cmd))
88 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
89 | stderr=subprocess.PIPE)
90 |
91 | with utils.timing('Kalign query'):
92 | stdout, stderr = process.communicate()
93 | retcode = process.wait()
94 | logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n',
95 | stdout.decode('utf-8'), stderr.decode('utf-8'))
96 |
97 | if retcode:
98 | raise RuntimeError('Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n'
99 | % (stdout.decode('utf-8'), stderr.decode('utf-8')))
100 |
101 | with open(output_a3m_path) as f:
102 | a3m = f.read()
103 |
104 | return a3m
105 |
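
A minimal usage sketch of the wrapper above, with a placeholder binary path:

    from alphafold.data.tools import kalign

    aligner = kalign.Kalign(binary_path='/usr/bin/kalign')  # placeholder path
    # Kalign requires every sequence to be at least 6 residues long.
    a3m = aligner.align(['MKTAYIAKQR', 'MKTAYIAKQW', 'MKTAYLAKQR'])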
--------------------------------------------------------------------------------
/src/alphafold/data/tools/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Common utilities for data pipeline tools."""
15 | import contextlib
16 | import shutil
17 | import tempfile
18 | import time
19 | from typing import Optional
20 |
21 | from absl import logging
22 |
23 |
24 | @contextlib.contextmanager
25 | def tmpdir_manager(base_dir: Optional[str] = None):
26 | """Context manager that deletes a temporary directory on exit."""
27 | tmpdir = tempfile.mkdtemp(dir=base_dir)
28 | try:
29 | yield tmpdir
30 | finally:
31 | shutil.rmtree(tmpdir, ignore_errors=True)
32 |
33 |
34 | @contextlib.contextmanager
35 | def timing(msg: str):
36 | logging.info('Started %s', msg)
37 | tic = time.time()
38 | yield
39 | toc = time.time()
40 | logging.info('Finished %s in %.3f seconds', msg, toc - tic)
41 |
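
The two context managers compose naturally, e.g.:

    import os
    from alphafold.data.tools import utils

    with utils.timing('toy computation'):      # logs 'Started'/'Finished ... in N seconds'
      with utils.tmpdir_manager() as tmp_dir:  # directory is deleted on exit, even on error
        with open(os.path.join(tmp_dir, 'scratch.txt'), 'w') as f:
          f.write('intermediate data')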
--------------------------------------------------------------------------------
/src/alphafold/model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Alphafold model."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/model/all_atom_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for all_atom."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from alphafold.model import all_atom
20 | from alphafold.model import r3
21 | import numpy as np
22 |
23 | L1_CLAMP_DISTANCE = 10
24 |
25 |
26 | def get_identity_rigid(shape):
27 | """Returns identity rigid transform."""
28 |
29 | ones = np.ones(shape)
30 | zeros = np.zeros(shape)
31 | rot = r3.Rots(ones, zeros, zeros,
32 | zeros, ones, zeros,
33 | zeros, zeros, ones)
34 | trans = r3.Vecs(zeros, zeros, zeros)
35 | return r3.Rigids(rot, trans)
36 |
37 |
38 | def get_global_rigid_transform(rot_angle, translation, bcast_dims):
39 | """Returns rigid transform that globally rotates/translates by same amount."""
40 |
41 | rot_angle = np.asarray(rot_angle)
42 | translation = np.asarray(translation)
43 | if bcast_dims:
44 | for _ in range(bcast_dims):
45 | rot_angle = np.expand_dims(rot_angle, 0)
46 | translation = np.expand_dims(translation, 0)
47 | sin_angle = np.sin(np.deg2rad(rot_angle))
48 | cos_angle = np.cos(np.deg2rad(rot_angle))
49 | ones = np.ones_like(sin_angle)
50 | zeros = np.zeros_like(sin_angle)
51 | rot = r3.Rots(ones, zeros, zeros,
52 | zeros, cos_angle, -sin_angle,
53 | zeros, sin_angle, cos_angle)
54 | trans = r3.Vecs(translation[..., 0], translation[..., 1], translation[..., 2])
55 | return r3.Rigids(rot, trans)
56 |
57 |
58 | class AllAtomTest(parameterized.TestCase, absltest.TestCase):
59 |
60 | @parameterized.named_parameters(
61 | ('identity', 0, [0, 0, 0]),
62 | ('rot_90', 90, [0, 0, 0]),
63 | ('trans_10', 0, [0, 0, 10]),
64 | ('rot_174_trans_1', 174, [1, 1, 1]))
65 | def test_frame_aligned_point_error_perfect_on_global_transform(
66 | self, rot_angle, translation):
67 | """Tests global transform between target and preds gives perfect score."""
68 |
69 | # pylint: disable=bad-whitespace
70 | target_positions = np.array(
71 | [[ 21.182, 23.095, 19.731],
72 | [ 22.055, 20.919, 17.294],
73 | [ 24.599, 20.005, 15.041],
74 | [ 25.567, 18.214, 12.166],
75 | [ 28.063, 17.082, 10.043],
76 | [ 28.779, 15.569, 6.985],
77 | [ 30.581, 13.815, 4.612],
78 | [ 29.258, 12.193, 2.296]])
79 | # pylint: enable=bad-whitespace
80 | global_rigid_transform = get_global_rigid_transform(
81 | rot_angle, translation, 1)
82 |
83 | target_positions = r3.vecs_from_tensor(target_positions)
84 | pred_positions = r3.rigids_mul_vecs(
85 | global_rigid_transform, target_positions)
86 | positions_mask = np.ones(target_positions.x.shape[0])
87 |
88 | target_frames = get_identity_rigid(10)
89 | pred_frames = r3.rigids_mul_rigids(global_rigid_transform, target_frames)
90 | frames_mask = np.ones(10)
91 |
92 | fape = all_atom.frame_aligned_point_error(
93 | pred_frames, target_frames, frames_mask, pred_positions,
94 | target_positions, positions_mask, L1_CLAMP_DISTANCE,
95 | L1_CLAMP_DISTANCE, epsilon=0)
96 | self.assertAlmostEqual(fape, 0.)
97 |
98 | @parameterized.named_parameters(
99 | ('identity',
100 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
101 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
102 | 0.),
103 | ('shift_2.5',
104 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
105 | [[2.5, 0, 0], [7.5, 0, 0], [7.5, 0, 0]],
106 | 0.25),
107 | ('shift_5',
108 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
109 | [[5, 0, 0], [10, 0, 0], [15, 0, 0]],
110 | 0.5),
111 | ('shift_10',
112 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
113 | [[10, 0, 0], [15, 0, 0], [0, 0, 0]],
114 | 1.))
115 | def test_frame_aligned_point_error_matches_expected(
116 | self, target_positions, pred_positions, expected_alddt):
117 | """Tests score matches expected."""
118 |
119 | target_frames = get_identity_rigid(2)
120 | pred_frames = target_frames
121 | frames_mask = np.ones(2)
122 |
123 | target_positions = r3.vecs_from_tensor(np.array(target_positions))
124 | pred_positions = r3.vecs_from_tensor(np.array(pred_positions))
125 | positions_mask = np.ones(target_positions.x.shape[0])
126 |
127 | alddt = all_atom.frame_aligned_point_error(
128 | pred_frames, target_frames, frames_mask, pred_positions,
129 | target_positions, positions_mask, L1_CLAMP_DISTANCE,
130 | L1_CLAMP_DISTANCE, epsilon=0)
131 | self.assertAlmostEqual(alddt, expected_alddt)
132 |
133 |
134 | if __name__ == '__main__':
135 | absltest.main()
136 |
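
The expected values in the second test can be checked by hand: with identity frames, frame-aligned point error reduces to the mean clamped point distance divided by the length scale. A quick check of the 'shift_5' case, assuming that standard clamped formulation:

    import numpy as np

    clamp = length_scale = 10.0             # L1_CLAMP_DISTANCE above
    distances = np.array([5.0, 5.0, 5.0])   # per-point errors for 'shift_5'
    fape = np.mean(np.minimum(distances, clamp)) / length_scale
    assert np.isclose(fape, 0.5)            # matches the expected value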
--------------------------------------------------------------------------------
/src/alphafold/model/common_modules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A collection of common Haiku modules for use in protein folding."""
16 | import numbers
17 | from typing import Union, Sequence
18 |
19 | import haiku as hk
20 | import jax.numpy as jnp
21 | import numpy as np
22 |
23 |
24 | # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
25 | TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(.87962566103423978,
26 | dtype=np.float32)
27 |
28 |
29 | def get_initializer_scale(initializer_name, input_shape):
30 | """Get Initializer for weights and scale to multiply activations by."""
31 |
32 | if initializer_name == 'zeros':
33 | w_init = hk.initializers.Constant(0.0)
34 | else:
35 | # fan-in scaling
36 | scale = 1.
37 | for channel_dim in input_shape:
38 | scale /= channel_dim
39 | if initializer_name == 'relu':
40 | scale *= 2
41 |
42 | noise_scale = scale
43 |
44 | stddev = np.sqrt(noise_scale)
45 | # Adjust stddev for truncation.
46 | stddev = stddev / TRUNCATED_NORMAL_STDDEV_FACTOR
47 | w_init = hk.initializers.TruncatedNormal(mean=0.0, stddev=stddev)
48 |
49 | return w_init
50 |
51 |
52 | class Linear(hk.Module):
53 | """Protein folding specific Linear module.
54 |
55 | This differs from the standard Haiku Linear in a few ways:
56 | * It supports inputs and outputs of arbitrary rank
57 | * Initializers are specified by strings
58 | """
59 |
60 | def __init__(self,
61 | num_output: Union[int, Sequence[int]],
62 | initializer: str = 'linear',
63 | num_input_dims: int = 1,
64 | use_bias: bool = True,
65 | bias_init: float = 0.,
66 | precision = None,
67 | name: str = 'linear'):
68 | """Constructs Linear Module.
69 |
70 | Args:
71 | num_output: Number of output channels. Can be tuple when outputting
72 | multiple dimensions.
73 | initializer: What initializer to use, should be one of {'linear', 'relu',
74 | 'zeros'}
75 | num_input_dims: Number of dimensions from the end to project.
76 | use_bias: Whether to include trainable bias
77 | bias_init: Value used to initialize bias.
78 | precision: What precision to use for matrix multiplication, defaults
79 | to None.
80 | name: Name of module, used for name scopes.
81 | """
82 | super().__init__(name=name)
83 | if isinstance(num_output, numbers.Integral):
84 | self.output_shape = (num_output,)
85 | else:
86 | self.output_shape = tuple(num_output)
87 | self.initializer = initializer
88 | self.use_bias = use_bias
89 | self.bias_init = bias_init
90 | self.num_input_dims = num_input_dims
91 | self.num_output_dims = len(self.output_shape)
92 | self.precision = precision
93 |
94 | def __call__(self, inputs):
95 | """Connects Module.
96 |
97 | Args:
98 | inputs: Tensor with at least num_input_dims dimensions.
99 |
100 | Returns:
101 | output of shape [...] + num_output.
102 | """
103 |
104 | num_input_dims = self.num_input_dims
105 |
106 | if self.num_input_dims > 0:
107 | in_shape = inputs.shape[-self.num_input_dims:]
108 | else:
109 | in_shape = ()
110 |
111 | weight_init = get_initializer_scale(self.initializer, in_shape)
112 |
113 | in_letters = 'abcde'[:self.num_input_dims]
114 | out_letters = 'hijkl'[:self.num_output_dims]
115 |
116 | weight_shape = in_shape + self.output_shape
117 | weights = hk.get_parameter('weights', weight_shape, inputs.dtype,
118 | weight_init)
119 |
120 | equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'
121 |
122 | output = jnp.einsum(equation, inputs, weights, precision=self.precision)
123 |
124 | if self.use_bias:
125 | bias = hk.get_parameter('bias', self.output_shape, inputs.dtype,
126 | hk.initializers.Constant(self.bias_init))
127 | output += bias
128 |
129 | return output
130 |
131 |
132 | class LayerNorm(hk.LayerNorm):
133 | """LayerNorm module.
134 |
135 | Equivalent to hk.LayerNorm but with different parameter shapes: they are
136 | always vectors rather than possibly higher-rank tensors. This makes it easier
137 | to change the layout whilst keeping the model weight-compatible.
138 | """
139 |
140 | def __init__(self,
141 | axis,
142 | create_scale: bool,
143 | create_offset: bool,
144 | eps: float = 1e-5,
145 | scale_init=None,
146 | offset_init=None,
147 | use_fast_variance: bool = False,
148 | name=None,
149 | param_axis=None):
150 | super().__init__(
151 | axis=axis,
152 | create_scale=False,
153 | create_offset=False,
154 | eps=eps,
155 | scale_init=None,
156 | offset_init=None,
157 | use_fast_variance=use_fast_variance,
158 | name=name,
159 | param_axis=param_axis)
160 | self._temp_create_scale = create_scale
161 | self._temp_create_offset = create_offset
162 |
163 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
164 | is_bf16 = (x.dtype == jnp.bfloat16)
165 | if is_bf16:
166 | x = x.astype(jnp.float32)
167 |
168 | param_axis = self.param_axis[0] if self.param_axis else -1
169 | param_shape = (x.shape[param_axis],)
170 |
171 | param_broadcast_shape = [1] * x.ndim
172 | param_broadcast_shape[param_axis] = x.shape[param_axis]
173 | scale = None
174 | offset = None
175 | if self._temp_create_scale:
176 | scale = hk.get_parameter(
177 | 'scale', param_shape, x.dtype, init=self.scale_init)
178 | scale = scale.reshape(param_broadcast_shape)
179 |
180 | if self._temp_create_offset:
181 | offset = hk.get_parameter(
182 | 'offset', param_shape, x.dtype, init=self.offset_init)
183 | offset = offset.reshape(param_broadcast_shape)
184 |
185 | out = super().__call__(x, scale=scale, offset=offset)
186 |
187 | if is_bf16:
188 | out = out.astype(jnp.bfloat16)
189 |
190 | return out
191 |
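
A minimal sketch of the Linear module in use; like any Haiku module it must be built inside a function wrapped by hk.transform (shapes and names here are arbitrary):

    import haiku as hk
    import jax
    import jax.numpy as jnp
    from alphafold.model.common_modules import Linear

    def forward(x):
      # Project the last dimension from 16 to 8 channels with fan-in scaling.
      return Linear(num_output=8, initializer='relu', name='example_linear')(x)

    transformed = hk.transform(forward)
    x = jnp.ones((4, 16))
    params = transformed.init(jax.random.PRNGKey(0), x)
    y = transformed.apply(params, None, x)  # shape (4, 8); no rng needed here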
--------------------------------------------------------------------------------
/src/alphafold/model/data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Convenience functions for reading data."""
16 |
17 | import io
18 | import os
19 | from alphafold.model import utils
20 | import haiku as hk
21 | import numpy as np
22 | # Internal import (7716).
23 |
24 |
25 | def get_model_haiku_params(model_name: str, data_dir: str) -> hk.Params:
26 | """Get the Haiku parameters from a model name."""
27 |
28 | path = os.path.join(data_dir, 'params', f'params_{model_name}.npz')
29 |
30 | with open(path, 'rb') as f:
31 | params = np.load(io.BytesIO(f.read()), allow_pickle=False)
32 |
33 | return utils.flat_params_to_haiku(params)
34 |
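
A one-line usage sketch; the model name and data directory are placeholders. The function loads data_dir/params/params_<model_name>.npz and converts the flat arrays into a Haiku parameter tree:

    from alphafold.model import data

    params = data.get_model_haiku_params(
        model_name='model_1_multimer_v3',  # placeholder model name
        data_dir='/path/to/data')          # must contain a params/ subdirectory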
--------------------------------------------------------------------------------
/src/alphafold/model/features.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Code to generate processed features."""
16 | import copy
17 | from typing import List, Mapping, Tuple
18 |
19 | from alphafold.model.tf import input_pipeline
20 | from alphafold.model.tf import proteins_dataset
21 |
22 | import ml_collections
23 | import numpy as np
24 | import tensorflow.compat.v1 as tf
25 |
26 | FeatureDict = Mapping[str, np.ndarray]
27 |
28 |
29 | def make_data_config(
30 | config: ml_collections.ConfigDict,
31 | num_res: int,
32 | ) -> Tuple[ml_collections.ConfigDict, List[str]]:
33 | """Makes a data config for the input pipeline."""
34 | cfg = copy.deepcopy(config.data)
35 |
36 | feature_names = cfg.common.unsupervised_features
37 | if cfg.common.use_templates:
38 | feature_names += cfg.common.template_features
39 |
40 | with cfg.unlocked():
41 | cfg.eval.crop_size = num_res
42 |
43 | return cfg, feature_names
44 |
45 |
46 | def tf_example_to_features(tf_example: tf.train.Example,
47 | config: ml_collections.ConfigDict,
48 | random_seed: int = 0) -> FeatureDict:
49 | """Converts tf_example to numpy feature dictionary."""
50 | num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0])
51 | cfg, feature_names = make_data_config(config, num_res=num_res)
52 |
53 | if 'deletion_matrix_int' in set(tf_example.features.feature):
54 | deletion_matrix_int = (
55 | tf_example.features.feature['deletion_matrix_int'].int64_list.value)
56 | feat = tf.train.Feature(float_list=tf.train.FloatList(
57 | value=map(float, deletion_matrix_int)))
58 | tf_example.features.feature['deletion_matrix'].CopyFrom(feat)
59 | del tf_example.features.feature['deletion_matrix_int']
60 |
61 | tf_graph = tf.Graph()
62 | with tf_graph.as_default(), tf.device('/device:CPU:0'):
63 | tf.compat.v1.set_random_seed(random_seed)
64 | tensor_dict = proteins_dataset.create_tensor_dict(
65 | raw_data=tf_example.SerializeToString(),
66 | features=feature_names)
67 | processed_batch = input_pipeline.process_tensors_from_config(
68 | tensor_dict, cfg)
69 |
70 | tf_graph.finalize()
71 |
72 | with tf.Session(graph=tf_graph) as sess:
73 | features = sess.run(processed_batch)
74 |
75 | return {k: v for k, v in features.items() if v.dtype != 'O'}
76 |
77 |
78 | def np_example_to_features(np_example: FeatureDict,
79 | config: ml_collections.ConfigDict,
80 | random_seed: int = 0) -> FeatureDict:
81 | """Preprocesses NumPy feature dict using TF pipeline."""
82 | np_example = dict(np_example)
83 | num_res = int(np_example['seq_length'][0])
84 | cfg, feature_names = make_data_config(config, num_res=num_res)
85 |
86 | if 'deletion_matrix_int' in np_example:
87 | np_example['deletion_matrix'] = (
88 | np_example.pop('deletion_matrix_int').astype(np.float32))
89 |
90 | tf_graph = tf.Graph()
91 | with tf_graph.as_default(), tf.device('/device:CPU:0'):
92 | tf.compat.v1.set_random_seed(random_seed)
93 | tensor_dict = proteins_dataset.np_to_tensor_dict(
94 | np_example=np_example, features=feature_names)
95 |
96 | processed_batch = input_pipeline.process_tensors_from_config(
97 | tensor_dict, cfg)
98 |
99 | tf_graph.finalize()
100 |
101 | with tf.Session(graph=tf_graph) as sess:
102 | features = sess.run(processed_batch)
103 |
104 | return {k: v for k, v in features.items() if v.dtype != 'O'}
105 |
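
A sketch of the NumPy entry point, assuming a pickled raw feature dict (e.g. a features.pkl produced by the data pipeline) and the model_config helper from alphafold.model.config:

    import pickle
    from alphafold.model import config
    from alphafold.model import features

    cfg = config.model_config('model_1')   # builds the ConfigDict for a model preset
    with open('features.pkl', 'rb') as f:  # placeholder path to raw features
      raw_example = pickle.load(f)
    processed = features.np_example_to_features(raw_example, cfg, random_seed=0)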
--------------------------------------------------------------------------------
/src/alphafold/model/geometry/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Geometry Module."""
15 |
16 | from alphafold.model.geometry import rigid_matrix_vector
17 | from alphafold.model.geometry import rotation_matrix
18 | from alphafold.model.geometry import struct_of_array
19 | from alphafold.model.geometry import vector
20 |
21 | Rot3Array = rotation_matrix.Rot3Array
22 | Rigid3Array = rigid_matrix_vector.Rigid3Array
23 |
24 | StructOfArray = struct_of_array.StructOfArray
25 |
26 | Vec3Array = vector.Vec3Array
27 | square_euclidean_distance = vector.square_euclidean_distance
28 | euclidean_distance = vector.euclidean_distance
29 | dihedral_angle = vector.dihedral_angle
30 | dot = vector.dot
31 | cross = vector.cross
32 |
--------------------------------------------------------------------------------
/src/alphafold/model/geometry/rigid_matrix_vector.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Rigid3Array Transformations represented by a Matrix and a Vector."""
15 |
16 | from __future__ import annotations
17 | from typing import Union
18 |
19 | from alphafold.model.geometry import rotation_matrix
20 | from alphafold.model.geometry import struct_of_array
21 | from alphafold.model.geometry import vector
22 | import jax
23 | import jax.numpy as jnp
24 |
25 | Float = Union[float, jnp.ndarray]
26 |
27 | VERSION = '0.1'
28 |
29 |
30 | @struct_of_array.StructOfArray(same_dtype=True)
31 | class Rigid3Array:
32 | """Rigid Transformation, i.e. element of special euclidean group."""
33 |
34 | rotation: rotation_matrix.Rot3Array
35 | translation: vector.Vec3Array
36 |
37 | def __matmul__(self, other: Rigid3Array) -> Rigid3Array:
38 | new_rotation = self.rotation @ other.rotation
39 | new_translation = self.apply_to_point(other.translation)
40 | return Rigid3Array(new_rotation, new_translation)
41 |
42 | def inverse(self) -> Rigid3Array:
43 | """Return Rigid3Array corresponding to inverse transform."""
44 | inv_rotation = self.rotation.inverse()
45 | inv_translation = inv_rotation.apply_to_point(-self.translation)
46 | return Rigid3Array(inv_rotation, inv_translation)
47 |
48 | def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
49 | """Apply Rigid3Array transform to point."""
50 | return self.rotation.apply_to_point(point) + self.translation
51 |
52 | def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
53 | """Apply inverse Rigid3Array transform to point."""
54 | new_point = point - self.translation
55 | return self.rotation.apply_inverse_to_point(new_point)
56 |
57 | def compose_rotation(self, other_rotation):
58 | rot = self.rotation @ other_rotation
59 | trans = jax.tree_map(lambda x: jnp.broadcast_to(x, rot.shape),
60 | self.translation)
61 | return Rigid3Array(rot, trans)
62 |
63 | @classmethod
64 | def identity(cls, shape, dtype=jnp.float32) -> Rigid3Array:
65 | """Return identity Rigid3Array of given shape."""
66 | return cls(
67 | rotation_matrix.Rot3Array.identity(shape, dtype=dtype),
68 | vector.Vec3Array.zeros(shape, dtype=dtype)) # pytype: disable=wrong-arg-count # trace-all-classes
69 |
70 | def scale_translation(self, factor: Float) -> Rigid3Array:
71 | """Scale translation in Rigid3Array by 'factor'."""
72 | return Rigid3Array(self.rotation, self.translation * factor)
73 |
74 | def to_array(self):
75 | rot_array = self.rotation.to_array()
76 | vec_array = self.translation.to_array()
77 | return jnp.concatenate([rot_array, vec_array[..., None]], axis=-1)
78 |
79 | @classmethod
80 | def from_array(cls, array):
81 | rot = rotation_matrix.Rot3Array.from_array(array[..., :3])
82 | vec = vector.Vec3Array.from_array(array[..., -1])
83 | return cls(rot, vec) # pytype: disable=wrong-arg-count # trace-all-classes
84 |
85 | @classmethod
86 | def from_array4x4(cls, array: jnp.ndarray) -> Rigid3Array:
87 | """Construct Rigid3Array from homogeneous 4x4 array."""
88 | assert array.shape[-1] == 4
89 | assert array.shape[-2] == 4
90 | rotation = rotation_matrix.Rot3Array(
91 | array[..., 0, 0], array[..., 0, 1], array[..., 0, 2],
92 | array[..., 1, 0], array[..., 1, 1], array[..., 1, 2],
93 | array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
94 | )
95 | translation = vector.Vec3Array(
96 | array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
97 | return cls(rotation, translation) # pytype: disable=wrong-arg-count # trace-all-classes
98 |
99 | def __getstate__(self):
100 | return (VERSION, (self.rotation, self.translation))
101 |
102 | def __setstate__(self, state):
103 | version, (rot, trans) = state
104 | del version
105 | object.__setattr__(self, 'rotation', rot)
106 | object.__setattr__(self, 'translation', trans)
107 |
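
A small round-trip sketch: build a transform from a homogeneous 4x4 matrix, apply it to a point, then undo it:

    import jax.numpy as jnp
    from alphafold.model.geometry import rigid_matrix_vector
    from alphafold.model.geometry import vector

    m = jnp.array([[0., -1., 0., 1.],
                   [1.,  0., 0., 2.],
                   [0.,  0., 1., 3.],
                   [0.,  0., 0., 1.]])  # 90-degree rotation about z, then translate
    rigid = rigid_matrix_vector.Rigid3Array.from_array4x4(m)

    point = vector.Vec3Array(jnp.array(1.), jnp.array(0.), jnp.array(0.))
    moved = rigid.apply_to_point(point)         # (1, 0, 0) -> (1, 3, 3)
    back = rigid.apply_inverse_to_point(moved)  # recovers (1, 0, 0)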
--------------------------------------------------------------------------------
/src/alphafold/model/geometry/rotation_matrix.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Rot3Array Matrix Class."""
15 |
16 | from __future__ import annotations
17 | import dataclasses
18 |
19 | from alphafold.model.geometry import struct_of_array
20 | from alphafold.model.geometry import utils
21 | from alphafold.model.geometry import vector
22 | import jax
23 | import jax.numpy as jnp
24 | import numpy as np
25 |
26 | COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']
27 |
28 | VERSION = '0.1'
29 |
30 |
31 | @struct_of_array.StructOfArray(same_dtype=True)
32 | class Rot3Array:
33 |   """Rot3Array matrix in 3-dimensional space implemented as a struct of arrays."""
34 |
35 | xx: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32})
36 | xy: jnp.ndarray
37 | xz: jnp.ndarray
38 | yx: jnp.ndarray
39 | yy: jnp.ndarray
40 | yz: jnp.ndarray
41 | zx: jnp.ndarray
42 | zy: jnp.ndarray
43 | zz: jnp.ndarray
44 |
45 | __array_ufunc__ = None
46 |
47 | def inverse(self) -> Rot3Array:
48 | """Returns inverse of Rot3Array."""
49 | return Rot3Array(self.xx, self.yx, self.zx,
50 | self.xy, self.yy, self.zy,
51 | self.xz, self.yz, self.zz)
52 |
53 | def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
54 | """Applies Rot3Array to point."""
55 | return vector.Vec3Array(
56 | self.xx * point.x + self.xy * point.y + self.xz * point.z,
57 | self.yx * point.x + self.yy * point.y + self.yz * point.z,
58 | self.zx * point.x + self.zy * point.y + self.zz * point.z)
59 |
60 | def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
61 | """Applies inverse Rot3Array to point."""
62 | return self.inverse().apply_to_point(point)
63 |
64 | def __matmul__(self, other: Rot3Array) -> Rot3Array:
65 | """Composes two Rot3Arrays."""
66 | c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
67 | c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
68 | c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
69 | return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
70 |
71 | @classmethod
72 | def identity(cls, shape, dtype=jnp.float32) -> Rot3Array:
73 | """Returns identity of given shape."""
74 | ones = jnp.ones(shape, dtype=dtype)
75 | zeros = jnp.zeros(shape, dtype=dtype)
76 | return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) # pytype: disable=wrong-arg-count # trace-all-classes
77 |
78 | @classmethod
79 | def from_two_vectors(cls, e0: vector.Vec3Array,
80 | e1: vector.Vec3Array) -> Rot3Array:
81 | """Construct Rot3Array from two Vectors.
82 |
83 | Rot3Array is constructed such that in the corresponding frame 'e0' lies on
84 |     the positive x-axis and 'e1' lies in the xy plane with positive sign of y.
85 |
86 | Args:
87 | e0: Vector
88 | e1: Vector
89 | Returns:
90 | Rot3Array
91 | """
92 | # Normalize the unit vector for the x-axis, e0.
93 | e0 = e0.normalized()
94 |     # Make e1 perpendicular to e0.
95 | c = e1.dot(e0)
96 | e1 = (e1 - c * e0).normalized()
97 | # Compute e2 as cross product of e0 and e1.
98 | e2 = e0.cross(e1)
99 | return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) # pytype: disable=wrong-arg-count # trace-all-classes
100 |
101 | @classmethod
102 | def from_array(cls, array: jnp.ndarray) -> Rot3Array:
103 |     """Construct Rot3Array Matrix from array of shape [..., 3, 3]."""
104 | unstacked = utils.unstack(array, axis=-2)
105 | unstacked = sum([utils.unstack(x, axis=-1) for x in unstacked], [])
106 | return cls(*unstacked)
107 |
108 | def to_array(self) -> jnp.ndarray:
109 | """Convert Rot3Array to array of shape [..., 3, 3]."""
110 | return jnp.stack(
111 | [jnp.stack([self.xx, self.xy, self.xz], axis=-1),
112 | jnp.stack([self.yx, self.yy, self.yz], axis=-1),
113 | jnp.stack([self.zx, self.zy, self.zz], axis=-1)],
114 | axis=-2)
115 |
116 | @classmethod
117 | def from_quaternion(cls,
118 | w: jnp.ndarray,
119 | x: jnp.ndarray,
120 | y: jnp.ndarray,
121 | z: jnp.ndarray,
122 | normalize: bool = True,
123 | epsilon: float = 1e-6) -> Rot3Array:
124 | """Construct Rot3Array from components of quaternion."""
125 | if normalize:
126 | inv_norm = jax.lax.rsqrt(jnp.maximum(epsilon, w**2 + x**2 + y**2 + z**2))
127 | w *= inv_norm
128 | x *= inv_norm
129 | y *= inv_norm
130 | z *= inv_norm
131 | xx = 1 - 2 * (jnp.square(y) + jnp.square(z))
132 | xy = 2 * (x * y - w * z)
133 | xz = 2 * (x * z + w * y)
134 | yx = 2 * (x * y + w * z)
135 | yy = 1 - 2 * (jnp.square(x) + jnp.square(z))
136 | yz = 2 * (y * z - w * x)
137 | zx = 2 * (x * z - w * y)
138 | zy = 2 * (y * z + w * x)
139 | zz = 1 - 2 * (jnp.square(x) + jnp.square(y))
140 | return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) # pytype: disable=wrong-arg-count # trace-all-classes
141 |
142 | @classmethod
143 | def random_uniform(cls, key, shape, dtype=jnp.float32) -> Rot3Array:
144 | """Samples uniform random Rot3Array according to Haar Measure."""
145 | quat_array = jax.random.normal(key, tuple(shape) + (4,), dtype=dtype)
146 | quats = utils.unstack(quat_array)
147 | return cls.from_quaternion(*quats)
148 |
149 | def __getstate__(self):
150 | return (VERSION,
151 | [np.asarray(getattr(self, field)) for field in COMPONENTS])
152 |
153 | def __setstate__(self, state):
154 | version, state = state
155 | del version
156 | for i, field in enumerate(COMPONENTS):
157 | object.__setattr__(self, field, state[i])
158 |
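
Two quick checks of from_quaternion: the identity quaternion yields the identity matrix, and a unit quaternion with only a z component yields a half-turn about z:

    import jax.numpy as jnp
    from alphafold.model.geometry import rotation_matrix

    one, zero = jnp.array(1.), jnp.array(0.)

    identity = rotation_matrix.Rot3Array.from_quaternion(one, zero, zero, zero)
    # (w, x, y, z) = (0, 0, 0, 1) rotates 180 degrees about z: diag(-1, -1, 1).
    half_turn = rotation_matrix.Rot3Array.from_quaternion(zero, zero, zero, one)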
--------------------------------------------------------------------------------
/src/alphafold/model/geometry/struct_of_array.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Class decorator to represent (nested) struct of arrays."""
15 |
16 | import dataclasses
17 |
18 | import jax
19 |
20 |
21 | def get_item(instance, key):
22 | sliced = {}
23 | for field in get_array_fields(instance):
24 | num_trailing_dims = field.metadata.get('num_trailing_dims', 0)
25 | this_key = key
26 | if isinstance(key, tuple) and Ellipsis in this_key:
27 | this_key += (slice(None),) * num_trailing_dims
28 | sliced[field.name] = getattr(instance, field.name)[this_key]
29 | return dataclasses.replace(instance, **sliced)
30 |
31 |
32 | @property
33 | def get_shape(instance):
34 | """Returns Shape for given instance of dataclass."""
35 | first_field = dataclasses.fields(instance)[0]
36 | num_trailing_dims = first_field.metadata.get('num_trailing_dims', None)
37 | value = getattr(instance, first_field.name)
38 | if num_trailing_dims:
39 | return value.shape[:-num_trailing_dims]
40 | else:
41 | return value.shape
42 |
43 |
44 | def get_len(instance):
45 | """Returns length for given instance of dataclass."""
46 | shape = instance.shape
47 | if shape:
48 | return shape[0]
49 | else:
50 | raise TypeError('len() of unsized object') # Match jax.numpy behavior.
51 |
52 |
53 | @property
54 | def get_dtype(instance):
55 | """Returns Dtype for given instance of dataclass."""
56 | fields = dataclasses.fields(instance)
57 | sets_dtype = [
58 | field.name for field in fields if field.metadata.get('sets_dtype', False)
59 | ]
60 | if sets_dtype:
61 |     assert len(sets_dtype) == 1, 'at most one field can set dtype'
62 | field_value = getattr(instance, sets_dtype[0])
63 | elif instance.same_dtype:
64 | field_value = getattr(instance, fields[0].name)
65 | else:
66 |     # Should this be a ValueError?
67 |     raise AttributeError('Trying to access dtype on StructOfArray without '
68 |                          'either "same_dtype" or a field setting dtype')
69 |
70 | if hasattr(field_value, 'dtype'):
71 | return field_value.dtype
72 | else:
73 |     # Should this be a ValueError?
74 | raise AttributeError(f'field_value {field_value} does not have dtype')
75 |
76 |
77 | def replace(instance, **kwargs):
78 | return dataclasses.replace(instance, **kwargs)
79 |
80 |
81 | def post_init(instance):
82 | """Validate instance has same shapes & dtypes."""
83 | array_fields = get_array_fields(instance)
84 | arrays = list(get_array_fields(instance, return_values=True).values())
85 | first_field = array_fields[0]
86 |   # The slightly weird construction of checking whether the leaves are
87 |   # actual arrays is needed because e.g. vmap internally relies on being able
88 |   # to construct pytrees with object() as leaves, which would break the
89 |   # checks. As such we only validate the object when the entries in the
90 |   # dataclass are arrays or other dataclasses of arrays.
91 | try:
92 | dtype = instance.dtype
93 | except AttributeError:
94 | dtype = None
95 | if dtype is not None:
96 | first_shape = instance.shape
97 | for array, field in zip(arrays, array_fields):
98 | field_shape = array.shape
99 | num_trailing_dims = field.metadata.get('num_trailing_dims', None)
100 | if num_trailing_dims:
101 | array_shape = array.shape
102 | field_shape = array_shape[:-num_trailing_dims]
103 |         msg = (f'field {field} should have number of trailing dims'
104 |                f' {num_trailing_dims}')
105 | assert len(array_shape) == len(first_shape) + num_trailing_dims, msg
106 | else:
107 | field_shape = array.shape
108 |
109 | shape_msg = (f"Stripped Shape {field_shape} of field {field} doesn't "
110 | f"match shape {first_shape} of field {first_field}")
111 | assert field_shape == first_shape, shape_msg
112 |
113 | field_dtype = array.dtype
114 |
115 | allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', [])
116 | if allowed_metadata_dtypes:
117 | msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}'
118 | assert field_dtype in allowed_metadata_dtypes, msg
119 |
120 | if 'dtype' in field.metadata:
121 | target_dtype = field.metadata['dtype']
122 | else:
123 | target_dtype = dtype
124 |
125 | msg = f'Dtype is {field_dtype} but must be {target_dtype}'
126 | assert field_dtype == target_dtype, msg
127 |
128 |
129 | def flatten(instance):
130 | """Flatten Struct of Array instance."""
131 | array_likes = list(get_array_fields(instance, return_values=True).values())
132 | flat_array_likes = []
133 | inner_treedefs = []
134 | num_arrays = []
135 | for array_like in array_likes:
136 | flat_array_like, inner_treedef = jax.tree_util.tree_flatten(array_like)
137 | inner_treedefs.append(inner_treedef)
138 | flat_array_likes += flat_array_like
139 | num_arrays.append(len(flat_array_like))
140 | metadata = get_metadata_fields(instance, return_values=True)
141 | metadata = type(instance).metadata_cls(**metadata)
142 | return flat_array_likes, (inner_treedefs, metadata, num_arrays)
143 |
144 |
145 | def make_metadata_class(cls):
146 | metadata_fields = get_fields(cls,
147 | lambda x: x.metadata.get('is_metadata', False))
148 | metadata_cls = dataclasses.make_dataclass(
149 | cls_name='Meta' + cls.__name__,
150 | fields=[(field.name, field.type, field) for field in metadata_fields],
151 | frozen=True,
152 | eq=True)
153 | return metadata_cls
154 |
155 |
156 | def get_fields(cls_or_instance, filterfn, return_values=False):
157 | fields = dataclasses.fields(cls_or_instance)
158 | fields = [field for field in fields if filterfn(field)]
159 | if return_values:
160 | return {
161 | field.name: getattr(cls_or_instance, field.name) for field in fields
162 | }
163 | else:
164 | return fields
165 |
166 |
167 | def get_array_fields(cls, return_values=False):
168 | return get_fields(
169 | cls,
170 | lambda x: not x.metadata.get('is_metadata', False),
171 | return_values=return_values)
172 |
173 |
174 | def get_metadata_fields(cls, return_values=False):
175 | return get_fields(
176 | cls,
177 | lambda x: x.metadata.get('is_metadata', False),
178 | return_values=return_values)
179 |
180 |
181 | class StructOfArray:
182 | """Class Decorator for Struct Of Arrays."""
183 |
184 | def __init__(self, same_dtype=True):
185 | self.same_dtype = same_dtype
186 |
187 | def __call__(self, cls):
188 | cls.__array_ufunc__ = None
189 | cls.replace = replace
190 | cls.same_dtype = self.same_dtype
191 | cls.dtype = get_dtype
192 | cls.shape = get_shape
193 | cls.__len__ = get_len
194 | cls.__getitem__ = get_item
195 | cls.__post_init__ = post_init
196 | new_cls = dataclasses.dataclass(cls, frozen=True, eq=False) # pytype: disable=wrong-keyword-args
197 |     # pytree claims to require metadata to be hashable (not sure why),
198 |     # so we make a derived dataclass that can just hold the metadata.
199 | new_cls.metadata_cls = make_metadata_class(new_cls)
200 |
201 | def unflatten(aux, data):
202 | inner_treedefs, metadata, num_arrays = aux
203 | array_fields = [field.name for field in get_array_fields(new_cls)]
204 | value_dict = {}
205 | array_start = 0
206 | for num_array, inner_treedef, array_field in zip(num_arrays,
207 | inner_treedefs,
208 | array_fields):
209 | value_dict[array_field] = jax.tree_util.tree_unflatten(
210 | inner_treedef, data[array_start:array_start + num_array])
211 | array_start += num_array
212 | metadata_fields = get_metadata_fields(new_cls)
213 | for field in metadata_fields:
214 | value_dict[field.name] = getattr(metadata, field.name)
215 |
216 | return new_cls(**value_dict)
217 |
218 | jax.tree_util.register_pytree_node(
219 | nodetype=new_cls, flatten_func=flatten, unflatten_func=unflatten)
220 | return new_cls
221 |
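
A toy use of the decorator (Point2 is an invented example, not an AlphaFold class): the fields become parallel arrays, and instances behave both like arrays (shape, len, slicing) and like JAX pytrees:

    import jax
    import jax.numpy as jnp
    from alphafold.model.geometry import struct_of_array

    @struct_of_array.StructOfArray(same_dtype=True)
    class Point2:
      x: jnp.ndarray
      y: jnp.ndarray

    p = Point2(jnp.zeros(5), jnp.arange(5, dtype=jnp.float32))
    assert p.shape == (5,)                  # shape is taken from the first field
    q = p[1:3]                              # slicing applies to every field
    doubled = jax.tree_util.tree_map(lambda a: 2 * a, p)  # works as a pytree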
--------------------------------------------------------------------------------
/src/alphafold/model/geometry/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Shared utils for tests."""
15 |
16 | import dataclasses
17 |
18 | from alphafold.model.geometry import rigid_matrix_vector
19 | from alphafold.model.geometry import rotation_matrix
20 | from alphafold.model.geometry import vector
21 | import jax.numpy as jnp
22 | import numpy as np
23 |
24 |
25 | def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array,
26 | matrix2: rotation_matrix.Rot3Array):
27 | for field in dataclasses.fields(rotation_matrix.Rot3Array):
28 | field = field.name
29 | np.testing.assert_array_equal(
30 | getattr(matrix1, field), getattr(matrix2, field))
31 |
32 |
33 | def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array,
34 | mat2: rotation_matrix.Rot3Array):
35 | np.testing.assert_array_almost_equal(mat1.to_array(), mat2.to_array(), 6)
36 |
37 |
38 | def assert_array_equal_to_rotation_matrix(array: jnp.ndarray,
39 | matrix: rotation_matrix.Rot3Array):
40 | """Check that array and Matrix match."""
41 | np.testing.assert_array_equal(matrix.xx, array[..., 0, 0])
42 | np.testing.assert_array_equal(matrix.xy, array[..., 0, 1])
43 | np.testing.assert_array_equal(matrix.xz, array[..., 0, 2])
44 | np.testing.assert_array_equal(matrix.yx, array[..., 1, 0])
45 | np.testing.assert_array_equal(matrix.yy, array[..., 1, 1])
46 | np.testing.assert_array_equal(matrix.yz, array[..., 1, 2])
47 | np.testing.assert_array_equal(matrix.zx, array[..., 2, 0])
48 | np.testing.assert_array_equal(matrix.zy, array[..., 2, 1])
49 | np.testing.assert_array_equal(matrix.zz, array[..., 2, 2])
50 |
51 |
52 | def assert_array_close_to_rotation_matrix(array: jnp.ndarray,
53 | matrix: rotation_matrix.Rot3Array):
54 | np.testing.assert_array_almost_equal(matrix.to_array(), array, 6)
55 |
56 |
57 | def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
58 | np.testing.assert_array_equal(vec1.x, vec2.x)
59 | np.testing.assert_array_equal(vec1.y, vec2.y)
60 | np.testing.assert_array_equal(vec1.z, vec2.z)
61 |
62 |
63 | def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array):
64 | np.testing.assert_allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.)
65 | np.testing.assert_allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.)
66 | np.testing.assert_allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.)
67 |
68 |
69 | def assert_array_close_to_vector(array: jnp.ndarray, vec: vector.Vec3Array):
70 | np.testing.assert_allclose(vec.to_array(), array, atol=1e-6, rtol=0.)
71 |
72 |
73 | def assert_array_equal_to_vector(array: jnp.ndarray, vec: vector.Vec3Array):
74 | np.testing.assert_array_equal(vec.to_array(), array)
75 |
76 |
77 | def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
78 | rigid2: rigid_matrix_vector.Rigid3Array):
79 | assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
80 |
81 |
82 | def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array,
83 | rigid2: rigid_matrix_vector.Rigid3Array):
84 | assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2)
85 |
86 |
87 | def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array,
88 | trans: vector.Vec3Array,
89 | rigid: rigid_matrix_vector.Rigid3Array):
90 | assert_rotation_matrix_equal(rot, rigid.rotation)
91 | assert_vectors_equal(trans, rigid.translation)
92 |
93 |
94 | def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array,
95 | trans: vector.Vec3Array,
96 | rigid: rigid_matrix_vector.Rigid3Array):
97 | assert_rotation_matrix_close(rot, rigid.rotation)
98 | assert_vectors_close(trans, rigid.translation)
99 |
--------------------------------------------------------------------------------
/src/alphafold/model/geometry/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utils for geometry library."""
15 |
16 | from typing import List
17 |
18 | import jax.numpy as jnp
19 |
20 |
21 | def unstack(value: jnp.ndarray, axis: int = -1) -> List[jnp.ndarray]:
22 | return [jnp.squeeze(v, axis=axis)
23 | for v in jnp.split(value, value.shape[axis], axis=axis)]
24 |
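
For example, unstack turns an [N, 3] coordinate array into its three component arrays:

    import jax.numpy as jnp
    from alphafold.model.geometry import utils

    xs, ys, zs = utils.unstack(jnp.ones((4, 3)))  # three arrays, each of shape (4,)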
--------------------------------------------------------------------------------
/src/alphafold/model/geometry/vector.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Vec3Array Class."""
15 |
16 | from __future__ import annotations
17 | import dataclasses
18 | from typing import Union
19 |
20 | from alphafold.model.geometry import struct_of_array
21 | from alphafold.model.geometry import utils
22 | import jax
23 | import jax.numpy as jnp
24 | import numpy as np
25 |
26 | Float = Union[float, jnp.ndarray]
27 |
28 | VERSION = '0.1'
29 |
30 |
31 | @struct_of_array.StructOfArray(same_dtype=True)
32 | class Vec3Array:
33 |   """Vec3Array in 3-dimensional space implemented as a struct of arrays.
34 |
35 |   This is done in order to improve performance and precision.
36 |   On TPU, small matrix multiplications are very suboptimal and waste large
37 |   compute resources; furthermore, any matrix multiplication on TPU happens in
38 |   mixed bfloat16/float32 precision, which is often undesirable when handling
39 |   physical coordinates.
40 |   In most cases this will also be faster on CPUs/GPUs since it allows for
41 |   easier use of vector instructions.
42 | """
43 |
44 | x: jnp.ndarray = dataclasses.field(metadata={'dtype': jnp.float32})
45 | y: jnp.ndarray
46 | z: jnp.ndarray
47 |
48 | def __post_init__(self):
49 | if hasattr(self.x, 'dtype'):
50 | assert self.x.dtype == self.y.dtype
51 | assert self.x.dtype == self.z.dtype
52 | assert all([x == y for x, y in zip(self.x.shape, self.y.shape)])
53 | assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])
54 |
55 | def __add__(self, other: Vec3Array) -> Vec3Array:
56 | return jax.tree_map(lambda x, y: x + y, self, other)
57 |
58 | def __sub__(self, other: Vec3Array) -> Vec3Array:
59 | return jax.tree_map(lambda x, y: x - y, self, other)
60 |
61 | def __mul__(self, other: Float) -> Vec3Array:
62 | return jax.tree_map(lambda x: x * other, self)
63 |
64 | def __rmul__(self, other: Float) -> Vec3Array:
65 | return self * other
66 |
67 | def __truediv__(self, other: Float) -> Vec3Array:
68 | return jax.tree_map(lambda x: x / other, self)
69 |
70 | def __neg__(self) -> Vec3Array:
71 | return jax.tree_map(lambda x: -x, self)
72 |
73 | def __pos__(self) -> Vec3Array:
74 | return jax.tree_map(lambda x: x, self)
75 |
76 | def cross(self, other: Vec3Array) -> Vec3Array:
77 | """Compute cross product between 'self' and 'other'."""
78 | new_x = self.y * other.z - self.z * other.y
79 | new_y = self.z * other.x - self.x * other.z
80 | new_z = self.x * other.y - self.y * other.x
81 | return Vec3Array(new_x, new_y, new_z)
82 |
83 | def dot(self, other: Vec3Array) -> Float:
84 | """Compute dot product between 'self' and 'other'."""
85 | return self.x * other.x + self.y * other.y + self.z * other.z
86 |
87 | def norm(self, epsilon: float = 1e-6) -> Float:
88 | """Compute Norm of Vec3Array, clipped to epsilon."""
89 | # To avoid NaN on the backward pass, we must use maximum before the sqrt
90 | norm2 = self.dot(self)
91 | if epsilon:
92 | norm2 = jnp.maximum(norm2, epsilon**2)
93 | return jnp.sqrt(norm2)
94 |
95 | def norm2(self):
96 | return self.dot(self)
97 |
98 | def normalized(self, epsilon: float = 1e-6) -> Vec3Array:
99 | """Return unit vector with optional clipping."""
100 | return self / self.norm(epsilon)
101 |
102 | @classmethod
103 | def zeros(cls, shape, dtype=jnp.float32):
104 | """Return Vec3Array corresponding to zeros of given shape."""
105 | return cls(
106 | jnp.zeros(shape, dtype), jnp.zeros(shape, dtype),
107 | jnp.zeros(shape, dtype)) # pytype: disable=wrong-arg-count # trace-all-classes
108 |
109 | def to_array(self) -> jnp.ndarray:
110 | return jnp.stack([self.x, self.y, self.z], axis=-1)
111 |
112 | @classmethod
113 | def from_array(cls, array):
114 | return cls(*utils.unstack(array))
115 |
116 | def __getstate__(self):
117 | return (VERSION,
118 | [np.asarray(self.x),
119 | np.asarray(self.y),
120 | np.asarray(self.z)])
121 |
122 | def __setstate__(self, state):
123 | version, state = state
124 | del version
125 | for i, letter in enumerate('xyz'):
126 | object.__setattr__(self, letter, state[i])
127 |
128 |
129 | def square_euclidean_distance(vec1: Vec3Array,
130 | vec2: Vec3Array,
131 | epsilon: float = 1e-6) -> Float:
132 | """Computes square of euclidean distance between 'vec1' and 'vec2'.
133 |
134 | Args:
135 | vec1: Vec3Array to compute distance to
136 | vec2: Vec3Array to compute distance from, should be
137 | broadcast compatible with 'vec1'
138 | epsilon: distance is clipped from below to be at least epsilon
139 |
140 | Returns:
141 | Array of square euclidean distances;
142 | shape will be result of broadcasting 'vec1' and 'vec2'
143 | """
144 | difference = vec1 - vec2
145 | distance = difference.dot(difference)
146 | if epsilon:
147 | distance = jnp.maximum(distance, epsilon)
148 | return distance
149 |
150 |
151 | def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
152 | return vector1.dot(vector2)
153 |
154 |
155 | def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
156 | return vector1.cross(vector2)
157 |
158 |
159 | def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
160 | return vector.norm(epsilon)
161 |
162 |
163 | def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
164 | return vector.normalized(epsilon)
165 |
166 |
167 | def euclidean_distance(vec1: Vec3Array,
168 | vec2: Vec3Array,
169 | epsilon: float = 1e-6) -> Float:
170 | """Computes euclidean distance between 'vec1' and 'vec2'.
171 |
172 | Args:
173 | vec1: Vec3Array to compute euclidean distance to
174 | vec2: Vec3Array to compute euclidean distance from, should be
175 | broadcast compatible with 'vec1'
176 | epsilon: distance is clipped from below to be at least epsilon
177 |
178 | Returns:
179 | Array of euclidean distances;
180 | shape will be result of broadcasting 'vec1' and 'vec2'
181 | """
182 | distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
183 | distance = jnp.sqrt(distance_sq)
184 | return distance
185 |
186 |
187 | def dihedral_angle(a: Vec3Array, b: Vec3Array, c: Vec3Array,
188 | d: Vec3Array) -> Float:
189 | """Computes torsion angle for a quadruple of points.
190 |
191 | For points (a, b, c, d), this is the angle between the planes defined by
192 | points (a, b, c) and (b, c, d). It is also known as the dihedral angle.
193 |
194 | Arguments:
195 | a: A Vec3Array of coordinates.
196 | b: A Vec3Array of coordinates.
197 | c: A Vec3Array of coordinates.
198 | d: A Vec3Array of coordinates.
199 |
200 | Returns:
201 | A tensor of angles in radians: [-pi, pi].
202 | """
203 | v1 = a - b
204 | v2 = b - c
205 | v3 = d - c
206 |
207 | c1 = v1.cross(v2)
208 | c2 = v3.cross(v2)
209 | c3 = c2.cross(c1)
210 |
211 | v2_mag = v2.norm()
212 | return jnp.arctan2(c3.dot(v2), v2_mag * c1.dot(c2))
213 |
214 |
215 | def random_gaussian_vector(shape, key, dtype=jnp.float32):
216 | vec_array = jax.random.normal(key, shape + (3,), dtype)
217 | return Vec3Array.from_array(vec_array)
218 |
--------------------------------------------------------------------------------
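A brief usage aside: the struct-of-arrays helpers above compose naturally. The following is a minimal sketch with invented toy coordinates (not part of the repository):

```python
import jax.numpy as jnp

from alphafold.model.geometry import vector

# Four toy points, stored as a (4, 3) coordinate array.
coords = jnp.array([[0., 1., 0.],
                    [0., 0., 0.],
                    [1., 0., 0.],
                    [1., 0., 1.]])
a, b, c, d = (vector.Vec3Array.from_array(coords[i]) for i in range(4))

print(vector.euclidean_distance(a, b))    # |a - b|, clipped at epsilon
print((a - b).cross(c - b).norm())        # magnitude of the (a, b, c) plane normal
print(vector.dihedral_angle(a, b, c, d))  # torsion angle in radians
```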
/src/alphafold/model/lddt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """lDDT protein distance score."""
16 | import jax.numpy as jnp
17 |
18 |
19 | def lddt(predicted_points,
20 | true_points,
21 | true_points_mask,
22 | cutoff=15.,
23 | per_residue=False):
24 | """Measure (approximate) lDDT for a batch of coordinates.
25 |
26 | lDDT reference:
27 | Mariani, V., Biasini, M., Barbato, A. & Schwede, T. lDDT: A local
28 | superposition-free score for comparing protein structures and models using
29 | distance difference tests. Bioinformatics 29, 2722–2728 (2013).
30 |
31 | lDDT is a measure of the difference between the true distance matrix and the
32 | distance matrix of the predicted points. The difference is computed only on
33 | points closer than cutoff *in the true structure*.
34 |
35 | This function does not compute the exact lDDT value that the original paper
36 | describes because it does not include terms for physical feasibility
37 | (e.g. bond length violations). Therefore this is only an approximate
38 | lDDT score.
39 |
40 | Args:
41 | predicted_points: (batch, length, 3) array of predicted 3D points
42 | true_points: (batch, length, 3) array of true 3D points
43 | true_points_mask: (batch, length, 1) binary-valued float array. This mask
44 | should be 1 for points that exist in the true points.
45 | cutoff: Maximum distance for a pair of points to be included
46 |     per_residue: If true, return a score for each residue. Note that the
47 |       overall lDDT is not exactly the mean of the per-residue lDDTs because
48 |       some residues have more contacts than others.
49 |
50 | Returns:
51 | An (approximate, see above) lDDT score in the range 0-1.
52 | """
53 |
54 | assert len(predicted_points.shape) == 3
55 | assert predicted_points.shape[-1] == 3
56 | assert true_points_mask.shape[-1] == 1
57 | assert len(true_points_mask.shape) == 3
58 |
59 | # Compute true and predicted distance matrices.
60 | dmat_true = jnp.sqrt(1e-10 + jnp.sum(
61 | (true_points[:, :, None] - true_points[:, None, :])**2, axis=-1))
62 |
63 | dmat_predicted = jnp.sqrt(1e-10 + jnp.sum(
64 | (predicted_points[:, :, None] -
65 | predicted_points[:, None, :])**2, axis=-1))
66 |
67 | dists_to_score = (
68 | (dmat_true < cutoff).astype(jnp.float32) * true_points_mask *
69 | jnp.transpose(true_points_mask, [0, 2, 1]) *
70 | (1. - jnp.eye(dmat_true.shape[1])) # Exclude self-interaction.
71 | )
72 |
73 |   # Absolute pairwise distance error; unscored pairs are masked out below.
74 | dist_l1 = jnp.abs(dmat_true - dmat_predicted)
75 |
76 | # True lDDT uses a number of fixed bins.
77 | # We ignore the physical plausibility correction to lDDT, though.
78 | score = 0.25 * ((dist_l1 < 0.5).astype(jnp.float32) +
79 | (dist_l1 < 1.0).astype(jnp.float32) +
80 | (dist_l1 < 2.0).astype(jnp.float32) +
81 | (dist_l1 < 4.0).astype(jnp.float32))
82 |
83 | # Normalize over the appropriate axes.
84 | reduce_axes = (-1,) if per_residue else (-2, -1)
85 | norm = 1. / (1e-10 + jnp.sum(dists_to_score, axis=reduce_axes))
86 | score = norm * (1e-10 + jnp.sum(dists_to_score * score, axis=reduce_axes))
87 |
88 | return score
89 |
--------------------------------------------------------------------------------
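To make the binning above concrete, here is a small sketch that calls `lddt` on an invented one-batch example. Shifting one point by 0.9 Å changes its pair distances by 0.9, so each affected pair passes three of the four thresholds (0.5/1/2/4 Å) and scores 0.75:

```python
import numpy as np

from alphafold.model import lddt

# (batch, length, 3) toy CA coordinates.
true_points = np.array([[[0., 0., 0.], [5., 0., 0.], [10., 0., 0.]]],
                       dtype=np.float32)
pred_points = true_points.copy()
pred_points[0, 2, 0] += 0.9  # perturb the last point by 0.9 A
mask = np.ones((1, 3, 1), dtype=np.float32)

per_res = lddt.lddt(pred_points, true_points, mask, per_residue=True)
overall = lddt.lddt(pred_points, true_points, mask, per_residue=False)
print(per_res)   # [[0.875, 0.875, 0.75]]
print(overall)   # [0.8333...] -- not the mean of per_res (contact counts differ)
```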
/src/alphafold/model/lddt_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for lddt."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from alphafold.model import lddt
20 | import numpy as np
21 |
22 |
23 | class LddtTest(parameterized.TestCase, absltest.TestCase):
24 |
25 | @parameterized.named_parameters(
26 | ('same',
27 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
28 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
29 | [1, 1, 1]),
30 | ('all_shifted',
31 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
32 | [[-1, 0, 0], [4, 0, 0], [9, 0, 0]],
33 | [1, 1, 1]),
34 | ('all_rotated',
35 | [[0, 0, 0], [5, 0, 0], [10, 0, 0]],
36 | [[0, 0, 0], [0, 5, 0], [0, 10, 0]],
37 | [1, 1, 1]),
38 | ('half_a_dist',
39 | [[0, 0, 0], [5, 0, 0]],
40 | [[0, 0, 0], [5.5-1e-5, 0, 0]],
41 | [1, 1]),
42 | ('one_a_dist',
43 | [[0, 0, 0], [5, 0, 0]],
44 | [[0, 0, 0], [6-1e-5, 0, 0]],
45 | [0.75, 0.75]),
46 | ('two_a_dist',
47 | [[0, 0, 0], [5, 0, 0]],
48 | [[0, 0, 0], [7-1e-5, 0, 0]],
49 | [0.5, 0.5]),
50 | ('four_a_dist',
51 | [[0, 0, 0], [5, 0, 0]],
52 | [[0, 0, 0], [9-1e-5, 0, 0]],
53 | [0.25, 0.25],),
54 | ('five_a_dist',
55 | [[0, 0, 0], [16-1e-5, 0, 0]],
56 | [[0, 0, 0], [11, 0, 0]],
57 | [0, 0]),
58 | ('no_pairs',
59 | [[0, 0, 0], [20, 0, 0]],
60 | [[0, 0, 0], [25-1e-5, 0, 0]],
61 | [1, 1]),
62 | )
63 | def test_lddt(
64 | self, predicted_pos, true_pos, exp_lddt):
65 | predicted_pos = np.array([predicted_pos], dtype=np.float32)
66 | true_points_mask = np.array([[[1]] * len(true_pos)], dtype=np.float32)
67 | true_pos = np.array([true_pos], dtype=np.float32)
68 | cutoff = 15.0
69 | per_residue = True
70 |
71 | result = lddt.lddt(
72 | predicted_pos, true_pos, true_points_mask, cutoff,
73 | per_residue)
74 |
75 | np.testing.assert_almost_equal(result, [exp_lddt], decimal=4)
76 |
77 |
78 | if __name__ == '__main__':
79 | absltest.main()
80 |
--------------------------------------------------------------------------------
/src/alphafold/model/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Code for constructing the model."""
16 | from typing import Any, Mapping, Optional, Union
17 |
18 | from absl import logging
19 | from alphafold.common import confidence
20 | from alphafold.model import features
21 | from alphafold.model import modules
22 | from alphafold.model import modules_multimer
23 | import haiku as hk
24 | import jax
25 | import ml_collections
26 | import numpy as np
27 | import tensorflow.compat.v1 as tf
28 | import pdb
29 | import tree
30 |
31 |
32 | def get_confidence_metrics(
33 | prediction_result: Mapping[str, Any],
34 | multimer_mode: bool) -> Mapping[str, Any]:
35 | """Post processes prediction_result to get confidence metrics."""
36 | confidence_metrics = {}
37 | confidence_metrics['plddt'] = confidence.compute_plddt(
38 | prediction_result['predicted_lddt']['logits'])
39 | if 'predicted_aligned_error' in prediction_result:
40 | confidence_metrics.update(confidence.compute_predicted_aligned_error(
41 | logits=prediction_result['predicted_aligned_error']['logits'],
42 | breaks=prediction_result['predicted_aligned_error']['breaks']))
43 | confidence_metrics['ptm'] = confidence.predicted_tm_score(
44 | logits=prediction_result['predicted_aligned_error']['logits'],
45 | breaks=prediction_result['predicted_aligned_error']['breaks'],
46 | asym_id=None)
47 | if multimer_mode:
48 | # Compute the ipTM only for the multimer model.
49 | confidence_metrics['iptm'] = confidence.predicted_tm_score(
50 | logits=prediction_result['predicted_aligned_error']['logits'],
51 | breaks=prediction_result['predicted_aligned_error']['breaks'],
52 | asym_id=prediction_result['predicted_aligned_error']['asym_id'],
53 | interface=True)
54 | confidence_metrics['ranking_confidence'] = (
55 | 0.8 * confidence_metrics['iptm'] + 0.2 * confidence_metrics['ptm'])
56 |
57 | if not multimer_mode:
58 | # Monomer models use mean pLDDT for model ranking.
59 | confidence_metrics['ranking_confidence'] = np.mean(
60 | confidence_metrics['plddt'])
61 |
62 | return confidence_metrics
63 |
64 |
65 | class RunModel:
66 | """Container for JAX model."""
67 |
68 | def __init__(self,
69 | config: ml_collections.ConfigDict,
70 | params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None,
71 | is_training=False):
72 | self.config = config
73 | self.params = params
74 | self.multimer_mode = config.model.global_config.multimer_mode
75 | self.is_training = is_training
76 |
77 | if self.multimer_mode:
78 | def _forward_fn(batch):
79 | model = modules_multimer.AlphaFold(self.config.model)
80 | return model(
81 | batch,
82 | is_training=self.is_training)
83 | else:
84 | def _forward_fn(batch):
85 | model = modules.AlphaFold(self.config.model)
86 | return model(
87 | batch,
88 | is_training=self.is_training,
89 | compute_loss=False,
90 | ensemble_representations=True)
91 |
92 | self.apply = jax.jit(hk.transform(_forward_fn).apply)
93 | self.init = jax.jit(hk.transform(_forward_fn).init)
94 |
95 | def init_params(self, feat: features.FeatureDict, random_seed: int = 0):
96 | """Initializes the model parameters.
97 |
98 | If none were provided when this class was instantiated then the parameters
99 | are randomly initialized.
100 |
101 | Args:
102 | feat: A dictionary of NumPy feature arrays as output by
103 | RunModel.process_features.
104 | random_seed: A random seed to use to initialize the parameters if none
105 | were set when this class was initialized.
106 | """
107 | if not self.params:
108 | # Init params randomly.
109 | rng = jax.random.PRNGKey(random_seed)
110 | self.params = hk.data_structures.to_mutable_dict(
111 | self.init(rng, feat))
112 | logging.warning('Initialized parameters randomly')
113 |
114 | def process_features(
115 | self,
116 | raw_features: Union[tf.train.Example, features.FeatureDict],
117 | random_seed: int) -> features.FeatureDict:
118 | """Processes features to prepare for feeding them into the model.
119 |
120 | Args:
121 | raw_features: The output of the data pipeline either as a dict of NumPy
122 | arrays or as a tf.train.Example.
123 | random_seed: The random seed to use when processing the features.
124 |
125 | Returns:
126 | A dict of NumPy feature arrays suitable for feeding into the model.
127 | """
128 |
129 | if self.multimer_mode:
130 | return raw_features
131 |
132 | # Single-chain mode.
133 | if isinstance(raw_features, dict):
134 | return features.np_example_to_features(
135 | np_example=raw_features,
136 | config=self.config,
137 | random_seed=random_seed)
138 | else:
139 | return features.tf_example_to_features(
140 | tf_example=raw_features,
141 | config=self.config,
142 | random_seed=random_seed)
143 |
144 | def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct:
145 | self.init_params(feat)
146 | logging.info('Running eval_shape with shape(feat) = %s',
147 | tree.map_structure(lambda x: x.shape, feat))
148 | shape = jax.eval_shape(self.apply, self.params, jax.random.PRNGKey(0), feat)
149 | logging.info('Output shape was %s', shape)
150 | return shape
151 |
152 | def predict(self,
153 | msa_params,
154 | feat: features.FeatureDict
155 | ) -> Mapping[str, Any]:
156 |     """Makes a prediction by running the model on the provided features.
157 | 
158 |     Args:
159 |       msa_params: The learned bias that is added to the MSA representation
160 |         (passed to the network as the 'msa_feat_bias' feature).
161 |       feat: A dictionary of NumPy feature arrays as output by
162 |         RunModel.process_features.
163 | 
164 |     Returns:
165 |       A dictionary of model outputs.
166 |     """
167 |
168 |     # Add the learned MSA bias to the features.
169 | feat['msa_feat_bias'] = msa_params
170 | result = self.apply(self.params, jax.random.PRNGKey(0), feat)
171 |
172 |     # The block below is disabled in this fork: it blocked on all outputs
173 |     # (to make benchmark timings accurate) and attached the confidence
174 |     # metrics from get_confidence_metrics to the result.
175 | #jax.tree_map(lambda x: x.block_until_ready(), result)
176 | #result.update(
177 | # get_confidence_metrics(result, multimer_mode=self.multimer_mode))
178 | return result
179 |
--------------------------------------------------------------------------------
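Because the metric computation above is commented out in this fork, confidence metrics can be attached by the caller instead. A minimal, hedged sketch with randomly generated logits (shapes and values are invented; a real `prediction_result` comes from `predict`):

```python
import numpy as np

from alphafold.model import model

num_res, num_bins = 50, 50
rng = np.random.default_rng(0)
prediction_result = {
    'predicted_lddt': {'logits': rng.normal(size=(num_res, num_bins))},
}

# Monomer mode: ranking_confidence is the mean pLDDT.
metrics = model.get_confidence_metrics(prediction_result, multimer_mode=False)
print(metrics['plddt'].shape, metrics['ranking_confidence'])
```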
/src/alphafold/model/prng.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A collection of utilities surrounding PRNG usage in protein folding."""
16 |
17 | import haiku as hk
18 | import jax
19 |
20 |
21 | def safe_dropout(*, tensor, safe_key, rate, is_deterministic, is_training):
22 | if is_training and rate != 0.0 and not is_deterministic:
23 | return hk.dropout(safe_key.get(), rate, tensor)
24 | else:
25 | return tensor
26 |
27 |
28 | class SafeKey:
29 | """Safety wrapper for PRNG keys."""
30 |
31 | def __init__(self, key):
32 | self._key = key
33 | self._used = False
34 |
35 | def _assert_not_used(self):
36 | if self._used:
37 | raise RuntimeError('Random key has been used previously.')
38 |
39 | def get(self):
40 | self._assert_not_used()
41 | self._used = True
42 | return self._key
43 |
44 | def split(self, num_keys=2):
45 | self._assert_not_used()
46 | self._used = True
47 | new_keys = jax.random.split(self._key, num_keys)
48 | return jax.tree_map(SafeKey, tuple(new_keys))
49 |
50 | def duplicate(self, num_keys=2):
51 | self._assert_not_used()
52 | self._used = True
53 | return tuple(SafeKey(self._key) for _ in range(num_keys))
54 |
55 |
56 | def _safe_key_flatten(safe_key):
57 | # Flatten transfers "ownership" to the tree
58 | return (safe_key._key,), safe_key._used # pylint: disable=protected-access
59 |
60 |
61 | def _safe_key_unflatten(aux_data, children):
62 | ret = SafeKey(children[0])
63 | ret._used = aux_data # pylint: disable=protected-access
64 | return ret
65 |
66 |
67 | jax.tree_util.register_pytree_node(
68 | SafeKey, _safe_key_flatten, _safe_key_unflatten)
69 |
70 |
--------------------------------------------------------------------------------
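A short sketch of the single-use discipline SafeKey enforces (toy shapes):

```python
import jax

from alphafold.model import prng

safe_key = prng.SafeKey(jax.random.PRNGKey(0))
key_a, key_b = safe_key.split()               # consumes safe_key
noise = jax.random.normal(key_a.get(), (4,))  # consumes key_a

try:
    key_a.get()  # a second use is an error, not silent key reuse
except RuntimeError as err:
    print(err)   # 'Random key has been used previously.'
```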
/src/alphafold/model/prng_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for prng."""
16 |
17 | from absl.testing import absltest
18 | from alphafold.model import prng
19 | import jax
20 |
21 |
22 | class PrngTest(absltest.TestCase):
23 |
24 | def test_key_reuse(self):
25 |
26 | init_key = jax.random.PRNGKey(42)
27 | safe_key = prng.SafeKey(init_key)
28 | _, safe_key = safe_key.split()
29 |
30 | raw_key = safe_key.get()
31 |
32 | self.assertNotEqual(raw_key[0], init_key[0])
33 | self.assertNotEqual(raw_key[1], init_key[1])
34 |
35 | with self.assertRaises(RuntimeError):
36 | safe_key.get()
37 |
38 | with self.assertRaises(RuntimeError):
39 | safe_key.split()
40 |
41 | with self.assertRaises(RuntimeError):
42 | safe_key.duplicate()
43 |
44 |
45 | if __name__ == '__main__':
46 | absltest.main()
47 |
--------------------------------------------------------------------------------
/src/alphafold/model/quat_affine_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for quat_affine."""
16 |
17 | from absl import logging
18 | from absl.testing import absltest
19 | from alphafold.model import quat_affine
20 | import jax
21 | import jax.numpy as jnp
22 | import numpy as np
23 |
24 | VERBOSE = False
25 | np.set_printoptions(precision=3, suppress=True)
26 |
27 | r2t = quat_affine.rot_list_to_tensor
28 | v2t = quat_affine.vec_list_to_tensor
29 |
30 | q2r = lambda q: r2t(quat_affine.quat_to_rot(q))
31 |
32 |
33 | class QuatAffineTest(absltest.TestCase):
34 |
35 | def _assert_check(self, to_check, tol=1e-5):
36 | for k, (correct, generated) in to_check.items():
37 | if VERBOSE:
38 | logging.info(k)
39 | logging.info('Correct %s', correct)
40 | logging.info('Predicted %s', generated)
41 | self.assertLess(np.max(np.abs(correct - generated)), tol)
42 |
43 | def test_conversion(self):
44 | quat = jnp.array([-2., 5., -1., 4.])
45 |
46 | rotation = jnp.array([
47 | [0.26087, 0.130435, 0.956522],
48 | [-0.565217, -0.782609, 0.26087],
49 | [0.782609, -0.608696, -0.130435]])
50 |
51 | translation = jnp.array([1., -3., 4.])
52 | point = jnp.array([0.7, 3.2, -2.9])
53 |
54 | a = quat_affine.QuatAffine(quat, translation, unstack_inputs=True)
55 | true_new_point = jnp.matmul(rotation, point[:, None])[:, 0] + translation
56 |
57 | self._assert_check({
58 | 'rot': (rotation, r2t(a.rotation)),
59 | 'trans': (translation, v2t(a.translation)),
60 | 'point': (true_new_point,
61 | v2t(a.apply_to_point(jnp.moveaxis(point, -1, 0)))),
62 | # Because of the double cover, we must be careful and compare rotations
63 | 'quat': (q2r(a.quaternion),
64 | q2r(quat_affine.rot_to_quat(a.rotation))),
65 |
66 | })
67 |
68 | def test_double_cover(self):
69 | """Test that -q is the same rotation as q."""
70 | rng = jax.random.PRNGKey(42)
71 | keys = jax.random.split(rng)
72 | q = jax.random.normal(keys[0], (2, 4))
73 | trans = jax.random.normal(keys[1], (2, 3))
74 | a1 = quat_affine.QuatAffine(q, trans, unstack_inputs=True)
75 | a2 = quat_affine.QuatAffine(-q, trans, unstack_inputs=True)
76 |
77 | self._assert_check({
78 | 'rot': (r2t(a1.rotation),
79 | r2t(a2.rotation)),
80 | 'trans': (v2t(a1.translation),
81 | v2t(a2.translation)),
82 | })
83 |
84 | def test_homomorphism(self):
85 | rng = jax.random.PRNGKey(42)
86 | keys = jax.random.split(rng, 4)
87 | vec_q1 = jax.random.normal(keys[0], (2, 3))
88 |
89 | q1 = jnp.concatenate([
90 | jnp.ones_like(vec_q1)[:, :1],
91 | vec_q1], axis=-1)
92 |
93 | q2 = jax.random.normal(keys[1], (2, 4))
94 | t1 = jax.random.normal(keys[2], (2, 3))
95 | t2 = jax.random.normal(keys[3], (2, 3))
96 |
97 | a1 = quat_affine.QuatAffine(q1, t1, unstack_inputs=True)
98 | a2 = quat_affine.QuatAffine(q2, t2, unstack_inputs=True)
99 | a21 = a2.pre_compose(jnp.concatenate([vec_q1, t1], axis=-1))
100 |
101 | rng, key = jax.random.split(rng)
102 | x = jax.random.normal(key, (2, 3))
103 | new_x = a21.apply_to_point(jnp.moveaxis(x, -1, 0))
104 | new_x_apply2 = a2.apply_to_point(a1.apply_to_point(jnp.moveaxis(x, -1, 0)))
105 |
106 | self._assert_check({
107 | 'quat': (q2r(quat_affine.quat_multiply(a2.quaternion, a1.quaternion)),
108 | q2r(a21.quaternion)),
109 | 'rot': (jnp.matmul(r2t(a2.rotation), r2t(a1.rotation)),
110 | r2t(a21.rotation)),
111 | 'point': (v2t(new_x_apply2),
112 | v2t(new_x)),
113 | 'inverse': (x, v2t(a21.invert_point(new_x))),
114 | })
115 |
116 | def test_batching(self):
117 | """Test that affine applies batchwise."""
118 | rng = jax.random.PRNGKey(42)
119 | keys = jax.random.split(rng, 3)
120 | q = jax.random.uniform(keys[0], (5, 2, 4))
121 | t = jax.random.uniform(keys[1], (2, 3))
122 | x = jax.random.uniform(keys[2], (5, 1, 3))
123 |
124 | a = quat_affine.QuatAffine(q, t, unstack_inputs=True)
125 | y = v2t(a.apply_to_point(jnp.moveaxis(x, -1, 0)))
126 |
127 | y_list = []
128 | for i in range(5):
129 | for j in range(2):
130 | a_local = quat_affine.QuatAffine(q[i, j], t[j],
131 | unstack_inputs=True)
132 | y_local = v2t(a_local.apply_to_point(jnp.moveaxis(x[i, 0], -1, 0)))
133 | y_list.append(y_local)
134 | y_combine = jnp.reshape(jnp.stack(y_list, axis=0), (5, 2, 3))
135 |
136 | self._assert_check({
137 | 'batch': (y_combine, y),
138 | 'quat': (q2r(a.quaternion),
139 | q2r(quat_affine.rot_to_quat(a.rotation))),
140 | })
141 |
142 | def assertAllClose(self, a, b, rtol=1e-06, atol=1e-06):
143 | self.assertTrue(np.allclose(a, b, rtol=rtol, atol=atol))
144 |
145 | def assertAllEqual(self, a, b):
146 | self.assertTrue(np.all(np.array(a) == np.array(b)))
147 |
148 |
149 | if __name__ == '__main__':
150 | absltest.main()
151 |
--------------------------------------------------------------------------------
/src/alphafold/model/tf/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Alphafold model TensorFlow code."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/model/tf/input_pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Feature pre-processing input pipeline for AlphaFold."""
16 |
17 | from alphafold.model.tf import data_transforms
18 | from alphafold.model.tf import shape_placeholders
19 | import tensorflow.compat.v1 as tf
20 | import tree
21 |
22 | # Pylint gets confused by the curry1 decorator because it changes the number
23 | # of arguments to the function.
24 | # pylint:disable=no-value-for-parameter
25 |
26 |
27 | NUM_RES = shape_placeholders.NUM_RES
28 | NUM_MSA_SEQ = shape_placeholders.NUM_MSA_SEQ
29 | NUM_EXTRA_SEQ = shape_placeholders.NUM_EXTRA_SEQ
30 | NUM_TEMPLATES = shape_placeholders.NUM_TEMPLATES
31 |
32 |
33 | def nonensembled_map_fns(data_config):
34 | """Input pipeline functions which are not ensembled."""
35 | common_cfg = data_config.common
36 |
37 | map_fns = [
38 | data_transforms.correct_msa_restypes,
39 | data_transforms.add_distillation_flag(False),
40 | data_transforms.cast_64bit_ints,
41 | data_transforms.squeeze_features,
42 | # Keep to not disrupt RNG.
43 | data_transforms.randomly_replace_msa_with_unknown(0.0),
44 | data_transforms.make_seq_mask,
45 | data_transforms.make_msa_mask,
46 | # Compute the HHblits profile if it's not set. This has to be run before
47 | # sampling the MSA.
48 | data_transforms.make_hhblits_profile,
49 | data_transforms.make_random_crop_to_size_seed,
50 | ]
51 | if common_cfg.use_templates:
52 | map_fns.extend([
53 | data_transforms.fix_templates_aatype,
54 | data_transforms.make_template_mask,
55 | data_transforms.make_pseudo_beta('template_')
56 | ])
57 | map_fns.extend([
58 | data_transforms.make_atom14_masks,
59 | ])
60 |
61 | return map_fns
62 |
63 |
64 | def ensembled_map_fns(data_config):
65 | """Input pipeline functions that can be ensembled and averaged."""
66 | common_cfg = data_config.common
67 | eval_cfg = data_config.eval
68 |
69 | map_fns = []
70 |
71 | if common_cfg.reduce_msa_clusters_by_max_templates:
72 | pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
73 | else:
74 | pad_msa_clusters = eval_cfg.max_msa_clusters
75 |
76 | max_msa_clusters = pad_msa_clusters
77 | max_extra_msa = common_cfg.max_extra_msa
78 |
79 | map_fns.append(
80 | data_transforms.sample_msa(
81 | max_msa_clusters,
82 | keep_extra=True))
83 |
84 | if 'masked_msa' in common_cfg:
85 | # Masked MSA should come *before* MSA clustering so that
86 | # the clustering and full MSA profile do not leak information about
87 | # the masked locations and secret corrupted locations.
88 | map_fns.append(
89 | data_transforms.make_masked_msa(common_cfg.masked_msa,
90 | eval_cfg.masked_msa_replace_fraction))
91 |
92 | if common_cfg.msa_cluster_features:
93 | map_fns.append(data_transforms.nearest_neighbor_clusters())
94 | map_fns.append(data_transforms.summarize_clusters())
95 |
96 | # Crop after creating the cluster profiles.
97 | if max_extra_msa:
98 | map_fns.append(data_transforms.crop_extra_msa(max_extra_msa))
99 | else:
100 | map_fns.append(data_transforms.delete_extra_msa)
101 |
102 | map_fns.append(data_transforms.make_msa_feat())
103 |
104 | crop_feats = dict(eval_cfg.feat)
105 |
106 | if eval_cfg.fixed_size:
107 | map_fns.append(data_transforms.select_feat(list(crop_feats)))
108 | map_fns.append(data_transforms.random_crop_to_size(
109 | eval_cfg.crop_size,
110 | eval_cfg.max_templates,
111 | crop_feats,
112 | eval_cfg.subsample_templates))
113 | map_fns.append(data_transforms.make_fixed_size(
114 | crop_feats,
115 | pad_msa_clusters,
116 | common_cfg.max_extra_msa,
117 | eval_cfg.crop_size,
118 | eval_cfg.max_templates))
119 | else:
120 | map_fns.append(data_transforms.crop_templates(eval_cfg.max_templates))
121 |
122 | return map_fns
123 |
124 |
125 | def process_tensors_from_config(tensors, data_config):
126 | """Apply filters and maps to an existing dataset, based on the config."""
127 |
128 | def wrap_ensemble_fn(data, i):
129 | """Function to be mapped over the ensemble dimension."""
130 | d = data.copy()
131 | fns = ensembled_map_fns(data_config)
132 | fn = compose(fns)
133 | d['ensemble_index'] = i
134 | return fn(d)
135 |
136 | eval_cfg = data_config.eval
137 | tensors = compose(
138 | nonensembled_map_fns(
139 | data_config))(
140 | tensors)
141 |
142 | tensors_0 = wrap_ensemble_fn(tensors, tf.constant(0))
143 | num_ensemble = eval_cfg.num_ensemble
144 | if data_config.common.resample_msa_in_recycling:
145 | # Separate batch per ensembling & recycling step.
146 | num_ensemble *= data_config.common.num_recycle + 1
147 |
148 | if isinstance(num_ensemble, tf.Tensor) or num_ensemble > 1:
149 | fn_output_signature = tree.map_structure(
150 | tf.TensorSpec.from_tensor, tensors_0)
151 | tensors = tf.map_fn(
152 | lambda x: wrap_ensemble_fn(tensors, x),
153 | tf.range(num_ensemble),
154 | parallel_iterations=1,
155 | fn_output_signature=fn_output_signature)
156 | else:
157 | tensors = tree.map_structure(lambda x: x[None],
158 | tensors_0)
159 | return tensors
160 |
161 |
162 | @data_transforms.curry1
163 | def compose(x, fs):
164 | for f in fs:
165 | x = f(x)
166 | return x
167 |
--------------------------------------------------------------------------------
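`compose` above is curried through `data_transforms.curry1`, so `compose(fns)` yields a one-argument function that threads a feature dict through every transform. A self-contained sketch of the same pattern (the dict transforms are invented stand-ins, and `curry1` is re-implemented here to match how this file uses it):

```python
import functools


def curry1(f):
    """Supply all arguments but the first (same pattern as data_transforms.curry1)."""
    @functools.wraps(f)
    def fc(*args, **kwargs):
        return lambda x: f(x, *args, **kwargs)
    return fc


@curry1
def compose(x, fs):
    for f in fs:
        x = f(x)
    return x


@curry1
def add_flag(d, value):
    return {**d, 'flag': value}


def make_mask(d):
    return {**d, 'mask': [1] * d['length']}


pipeline = compose([add_flag(True), make_mask])
print(pipeline({'length': 3}))  # {'length': 3, 'flag': True, 'mask': [1, 1, 1]}
```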
/src/alphafold/model/tf/protein_features.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Contains descriptions of various protein features."""
16 | import enum
17 | from typing import Dict, Optional, Sequence, Tuple, Union
18 | from alphafold.common import residue_constants
19 | import tensorflow.compat.v1 as tf
20 |
21 | # Type aliases.
22 | FeaturesMetadata = Dict[str, Tuple[tf.dtypes.DType, Sequence[Union[str, int]]]]
23 |
24 |
25 | class FeatureType(enum.Enum):
26 | ZERO_DIM = 0 # Shape [x]
27 | ONE_DIM = 1 # Shape [num_res, x]
28 | TWO_DIM = 2 # Shape [num_res, num_res, x]
29 | MSA = 3 # Shape [msa_length, num_res, x]
30 |
31 |
32 | # Placeholder values that will be replaced with their true value at runtime.
33 | NUM_RES = "num residues placeholder"
34 | NUM_SEQ = "length msa placeholder"
35 | NUM_TEMPLATES = "num templates placeholder"
36 | # Sizes of the protein features, NUM_RES and NUM_SEQ are allowed as placeholders
37 | # to be replaced with the number of residues and the number of sequences in the
38 | # multiple sequence alignment, respectively.
39 |
40 |
41 | FEATURES = {
42 | #### Static features of a protein sequence ####
43 | "aatype": (tf.float32, [NUM_RES, 21]),
44 | "between_segment_residues": (tf.int64, [NUM_RES, 1]),
45 | "deletion_matrix": (tf.float32, [NUM_SEQ, NUM_RES, 1]),
46 | "domain_name": (tf.string, [1]),
47 | "msa": (tf.int64, [NUM_SEQ, NUM_RES, 1]),
48 | "num_alignments": (tf.int64, [NUM_RES, 1]),
49 | "residue_index": (tf.int64, [NUM_RES, 1]),
50 | "seq_length": (tf.int64, [NUM_RES, 1]),
51 | "sequence": (tf.string, [1]),
52 | "all_atom_positions": (tf.float32,
53 | [NUM_RES, residue_constants.atom_type_num, 3]),
54 | "all_atom_mask": (tf.int64, [NUM_RES, residue_constants.atom_type_num]),
55 | "resolution": (tf.float32, [1]),
56 | "template_domain_names": (tf.string, [NUM_TEMPLATES]),
57 | "template_sum_probs": (tf.float32, [NUM_TEMPLATES, 1]),
58 | "template_aatype": (tf.float32, [NUM_TEMPLATES, NUM_RES, 22]),
59 | "template_all_atom_positions": (tf.float32, [
60 | NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 3
61 | ]),
62 | "template_all_atom_masks": (tf.float32, [
63 | NUM_TEMPLATES, NUM_RES, residue_constants.atom_type_num, 1
64 | ]),
65 | }
66 |
67 | FEATURE_TYPES = {k: v[0] for k, v in FEATURES.items()}
68 | FEATURE_SIZES = {k: v[1] for k, v in FEATURES.items()}
69 |
70 |
71 | def register_feature(name: str,
72 | type_: tf.dtypes.DType,
73 |                      shape_: Tuple[Union[str, int], ...]):
74 | """Register extra features used in custom datasets."""
75 | FEATURES[name] = (type_, shape_)
76 | FEATURE_TYPES[name] = type_
77 | FEATURE_SIZES[name] = shape_
78 |
79 |
80 | def shape(feature_name: str,
81 | num_residues: int,
82 | msa_length: int,
83 | num_templates: Optional[int] = None,
84 | features: Optional[FeaturesMetadata] = None):
85 | """Get the shape for the given feature name.
86 |
87 | This is near identical to _get_tf_shape_no_placeholders() but with 2
88 | differences:
89 | * This method does not calculate a single placeholder from the total number of
90 |     elements (e.g. given <NUM_RES, 3> and size := 12, this won't deduce
91 |     NUM_RES must be 4)
92 | * This method will work with tensors
93 |
94 | Args:
95 | feature_name: String identifier for the feature. If the feature name ends
96 | with "_unnormalized", this suffix is stripped off.
97 | num_residues: The number of residues in the current domain - some elements
98 | of the shape can be dynamic and will be replaced by this value.
99 | msa_length: The number of sequences in the multiple sequence alignment, some
100 | elements of the shape can be dynamic and will be replaced by this value.
101 | If the number of alignments is unknown / not read, please pass None for
102 | msa_length.
103 | num_templates (optional): The number of templates in this tfexample.
104 | features: A feature_name to (tf_dtype, shape) lookup; defaults to FEATURES.
105 |
106 | Returns:
107 |     List of ints representing the tensor size.
108 |
109 | Raises:
110 | ValueError: If a feature is requested but no concrete placeholder value is
111 | given.
112 | """
113 | features = features or FEATURES
114 | if feature_name.endswith("_unnormalized"):
115 | feature_name = feature_name[:-13]
116 |
117 | unused_dtype, raw_sizes = features[feature_name]
118 | replacements = {NUM_RES: num_residues,
119 | NUM_SEQ: msa_length}
120 |
121 | if num_templates is not None:
122 | replacements[NUM_TEMPLATES] = num_templates
123 |
124 | sizes = [replacements.get(dimension, dimension) for dimension in raw_sizes]
125 | for dimension in sizes:
126 | if isinstance(dimension, str):
127 | raise ValueError("Could not parse %s (shape: %s) with values: %s" % (
128 | feature_name, raw_sizes, replacements))
129 | return sizes
130 |
--------------------------------------------------------------------------------
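A short sketch of registering a custom feature and resolving its shape; the feature name and embedding width are invented for illustration:

```python
import tensorflow.compat.v1 as tf

from alphafold.model.tf import protein_features

# Hypothetical per-residue embedding feature.
protein_features.register_feature(
    'custom_embedding', tf.float32,
    (protein_features.NUM_RES, 128))

print(protein_features.shape('custom_embedding',
                             num_residues=12, msa_length=24))  # [12, 128]
```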
/src/alphafold/model/tf/protein_features_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for protein_features."""
16 | import uuid
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from alphafold.model.tf import protein_features
21 | import tensorflow.compat.v1 as tf
22 |
23 |
24 | def _random_bytes():
25 | return str(uuid.uuid4()).encode('utf-8')
26 |
27 |
28 | class FeaturesTest(parameterized.TestCase, tf.test.TestCase):
29 |
30 | def setUp(self):
31 | super().setUp()
32 | tf.disable_v2_behavior()
33 |
34 | def testFeatureNames(self):
35 | self.assertEqual(len(protein_features.FEATURE_SIZES),
36 | len(protein_features.FEATURE_TYPES))
37 | sorted_size_names = sorted(protein_features.FEATURE_SIZES.keys())
38 | sorted_type_names = sorted(protein_features.FEATURE_TYPES.keys())
39 | for i, size_name in enumerate(sorted_size_names):
40 | self.assertEqual(size_name, sorted_type_names[i])
41 |
42 | def testReplacement(self):
43 | for name in protein_features.FEATURE_SIZES.keys():
44 | sizes = protein_features.shape(name,
45 | num_residues=12,
46 | msa_length=24,
47 | num_templates=3)
48 | for x in sizes:
49 | self.assertEqual(type(x), int)
50 | self.assertGreater(x, 0)
51 |
52 |
53 | if __name__ == '__main__':
54 | absltest.main()
55 |
--------------------------------------------------------------------------------
/src/alphafold/model/tf/proteins_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Datasets consisting of proteins."""
16 | from typing import Dict, Mapping, Optional, Sequence
17 | from alphafold.model.tf import protein_features
18 | import numpy as np
19 | import tensorflow.compat.v1 as tf
20 |
21 | TensorDict = Dict[str, tf.Tensor]
22 |
23 |
24 | def parse_tfexample(
25 | raw_data: bytes,
26 | features: protein_features.FeaturesMetadata,
27 | key: Optional[str] = None) -> Dict[str, tf.train.Feature]:
28 | """Read a single TF Example proto and return a subset of its features.
29 |
30 | Args:
31 | raw_data: A serialized tf.Example proto.
32 | features: A dictionary of features, mapping string feature names to a tuple
33 | (dtype, shape). This dictionary should be a subset of
34 | protein_features.FEATURES (or the dictionary itself for all features).
35 | key: Optional string with the SSTable key of that tf.Example. This will be
36 | added into features as a 'key' but only if requested in features.
37 |
38 | Returns:
39 | A dictionary of features mapping feature names to features. Only the given
40 | features are returned, all other ones are filtered out.
41 | """
42 | feature_map = {
43 | k: tf.io.FixedLenSequenceFeature(shape=(), dtype=v[0], allow_missing=True)
44 | for k, v in features.items()
45 | }
46 | parsed_features = tf.io.parse_single_example(raw_data, feature_map)
47 | reshaped_features = parse_reshape_logic(parsed_features, features, key=key)
48 |
49 | return reshaped_features
50 |
51 |
52 | def _first(tensor: tf.Tensor) -> tf.Tensor:
53 | """Returns the 1st element - the input can be a tensor or a scalar."""
54 | return tf.reshape(tensor, shape=(-1,))[0]
55 |
56 |
57 | def parse_reshape_logic(
58 | parsed_features: TensorDict,
59 | features: protein_features.FeaturesMetadata,
60 | key: Optional[str] = None) -> TensorDict:
61 | """Transforms parsed serial features to the correct shape."""
62 | # Find out what is the number of sequences and the number of alignments.
63 | num_residues = tf.cast(_first(parsed_features["seq_length"]), dtype=tf.int32)
64 |
65 | if "num_alignments" in parsed_features:
66 | num_msa = tf.cast(_first(parsed_features["num_alignments"]), dtype=tf.int32)
67 | else:
68 | num_msa = 0
69 |
70 | if "template_domain_names" in parsed_features:
71 | num_templates = tf.cast(
72 | tf.shape(parsed_features["template_domain_names"])[0], dtype=tf.int32)
73 | else:
74 | num_templates = 0
75 |
76 | if key is not None and "key" in features:
77 | parsed_features["key"] = [key] # Expand dims from () to (1,).
78 |
79 | # Reshape the tensors according to the sequence length and num alignments.
80 | for k, v in parsed_features.items():
81 | new_shape = protein_features.shape(
82 | feature_name=k,
83 | num_residues=num_residues,
84 | msa_length=num_msa,
85 | num_templates=num_templates,
86 | features=features)
87 | new_shape_size = tf.constant(1, dtype=tf.int32)
88 | for dim in new_shape:
89 | new_shape_size *= tf.cast(dim, tf.int32)
90 |
91 | assert_equal = tf.assert_equal(
92 | tf.size(v), new_shape_size,
93 | name="assert_%s_shape_correct" % k,
94 | message="The size of feature %s (%s) could not be reshaped "
95 | "into %s" % (k, tf.size(v), new_shape))
96 | if "template" not in k:
97 | # Make sure the feature we are reshaping is not empty.
98 | assert_non_empty = tf.assert_greater(
99 | tf.size(v), 0, name="assert_%s_non_empty" % k,
100 | message="The feature %s is not set in the tf.Example. Either do not "
101 | "request the feature or use a tf.Example that has the "
102 | "feature set." % k)
103 | with tf.control_dependencies([assert_non_empty, assert_equal]):
104 | parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)
105 | else:
106 | with tf.control_dependencies([assert_equal]):
107 | parsed_features[k] = tf.reshape(v, new_shape, name="reshape_%s" % k)
108 |
109 | return parsed_features
110 |
111 |
112 | def _make_features_metadata(
113 | feature_names: Sequence[str]) -> protein_features.FeaturesMetadata:
114 | """Makes a feature name to type and shape mapping from a list of names."""
115 | # Make sure these features are always read.
116 | required_features = ["aatype", "sequence", "seq_length"]
117 | feature_names = list(set(feature_names) | set(required_features))
118 |
119 | features_metadata = {name: protein_features.FEATURES[name]
120 | for name in feature_names}
121 | return features_metadata
122 |
123 |
124 | def create_tensor_dict(
125 | raw_data: bytes,
126 | features: Sequence[str],
127 | key: Optional[str] = None,
128 | ) -> TensorDict:
129 | """Creates a dictionary of tensor features.
130 |
131 | Args:
132 | raw_data: A serialized tf.Example proto.
133 | features: A list of strings of feature names to be returned in the dataset.
134 | key: Optional string with the SSTable key of that tf.Example. This will be
135 | added into features as a 'key' but only if requested in features.
136 |
137 | Returns:
138 | A dictionary of features mapping feature names to features. Only the given
139 | features are returned, all other ones are filtered out.
140 | """
141 | features_metadata = _make_features_metadata(features)
142 | return parse_tfexample(raw_data, features_metadata, key)
143 |
144 |
145 | def np_to_tensor_dict(
146 | np_example: Mapping[str, np.ndarray],
147 | features: Sequence[str],
148 | ) -> TensorDict:
149 | """Creates dict of tensors from a dict of NumPy arrays.
150 |
151 | Args:
152 | np_example: A dict of NumPy feature arrays.
153 | features: A list of strings of feature names to be returned in the dataset.
154 |
155 | Returns:
156 | A dictionary of features mapping feature names to features. Only the given
157 | features are returned, all other ones are filtered out.
158 | """
159 | features_metadata = _make_features_metadata(features)
160 | tensor_dict = {k: tf.constant(v) for k, v in np_example.items()
161 | if k in features_metadata}
162 |
163 | # Ensures shapes are as expected. Needed for setting size of empty features
164 | # e.g. when no template hits were found.
165 | tensor_dict = parse_reshape_logic(tensor_dict, features_metadata)
166 | return tensor_dict
167 |
--------------------------------------------------------------------------------
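A hedged sketch of `np_to_tensor_dict` on an invented 4-residue example. Shapes follow the `FEATURES` table in protein_features.py, and the required features ('aatype', 'sequence', 'seq_length') must be present even when only a subset is requested:

```python
import numpy as np
import tensorflow.compat.v1 as tf

from alphafold.model.tf import proteins_dataset

tf.disable_v2_behavior()  # the tf1-compat mode used throughout this package

num_res = 4
np_example = {
    'aatype': np.zeros((num_res, 21), dtype=np.float32),
    'sequence': np.array([b'MKTA']),
    'seq_length': np.full((num_res, 1), num_res, dtype=np.int64),
}
tensor_dict = proteins_dataset.np_to_tensor_dict(np_example,
                                                 features=['aatype'])
print({k: v.shape for k, v in tensor_dict.items()})
```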
/src/alphafold/model/tf/shape_helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utilities for dealing with shapes of TensorFlow tensors."""
16 | import tensorflow.compat.v1 as tf
17 |
18 |
19 | def shape_list(x):
20 | """Return list of dimensions of a tensor, statically where possible.
21 |
22 | Like `x.shape.as_list()` but with tensors instead of `None`s.
23 |
24 | Args:
25 | x: A tensor.
26 | Returns:
27 | A list with length equal to the rank of the tensor. The n-th element of the
28 |     list is an integer when that dimension is statically known; otherwise it
29 |     is the n-th element of `tf.shape(x)`.
30 | """
31 | x = tf.convert_to_tensor(x)
32 |
33 | # If unknown rank, return dynamic shape
34 | if x.get_shape().dims is None:
35 | return tf.shape(x)
36 |
37 | static = x.get_shape().as_list()
38 | shape = tf.shape(x)
39 |
40 | ret = []
41 | for i in range(len(static)):
42 | dim = static[i]
43 | if dim is None:
44 | dim = shape[i]
45 | ret.append(dim)
46 | return ret
47 |
48 |
--------------------------------------------------------------------------------
/src/alphafold/model/tf/shape_helpers_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for shape_helpers."""
16 |
17 | from alphafold.model.tf import shape_helpers
18 | import numpy as np
19 | import tensorflow.compat.v1 as tf
20 |
21 |
22 | class ShapeTest(tf.test.TestCase):
23 |
24 | def setUp(self):
25 | super().setUp()
26 | tf.disable_v2_behavior()
27 |
28 | def test_shape_list(self):
29 | """Test that shape_list can allow for reshaping to dynamic shapes."""
30 | a = tf.zeros([10, 4, 4, 2])
31 | p = tf.placeholder(tf.float32, shape=[None, None, 1, 4, 4])
32 | shape_dyn = shape_helpers.shape_list(p)[:2] + [4, 4]
33 |
34 | b = tf.reshape(a, shape_dyn)
35 | with self.session() as sess:
36 | out = sess.run(b, feed_dict={p: np.ones((20, 1, 1, 4, 4))})
37 |
38 | self.assertAllEqual(out.shape, (20, 1, 4, 4))
39 |
40 |
41 | if __name__ == '__main__':
42 | tf.test.main()
43 |
--------------------------------------------------------------------------------
/src/alphafold/model/tf/shape_placeholders.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Placeholder values for run-time varying dimension sizes."""
16 |
17 | NUM_RES = 'num residues placeholder'
18 | NUM_MSA_SEQ = 'msa placeholder'
19 | NUM_EXTRA_SEQ = 'extra msa placeholder'
20 | NUM_TEMPLATES = 'num templates placeholder'
21 |
--------------------------------------------------------------------------------
/src/alphafold/model/tf/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Shared utilities for various components."""
16 | import tensorflow.compat.v1 as tf
17 |
18 |
19 | def tf_combine_mask(*masks):
20 | """Take the intersection of float-valued masks."""
21 | ret = 1
22 | for m in masks:
23 | ret *= m
24 | return ret
25 |
26 |
27 | class SeedMaker(object):
28 | """Return unique seeds."""
29 |
30 | def __init__(self, initial_seed=0):
31 | self.next_seed = initial_seed
32 |
33 | def __call__(self):
34 | i = self.next_seed
35 | self.next_seed += 1
36 | return i
37 |
38 | seed_maker = SeedMaker()
39 |
40 |
41 | def make_random_seed():
42 | return tf.random.uniform([2],
43 | tf.int32.min,
44 | tf.int32.max,
45 | tf.int32,
46 | seed=seed_maker())
47 |
48 |
--------------------------------------------------------------------------------
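A small sketch of the helpers above (toy masks; run under the tf1-compat mode used throughout this package):

```python
import tensorflow.compat.v1 as tf

from alphafold.model.tf import utils

tf.disable_v2_behavior()

seq_mask = tf.constant([1., 1., 0., 1.])
msa_mask = tf.constant([1., 0., 1., 1.])
both = utils.tf_combine_mask(seq_mask, msa_mask)  # elementwise product

seed = utils.make_random_seed()  # int32[2]; seed_maker makes each call unique

with tf.Session() as sess:
    print(sess.run([both, seed]))  # both == [1., 0., 0., 1.]
```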
/src/alphafold/model/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """A collection of JAX utility functions for use in protein folding."""
16 |
17 | import collections
18 | import contextlib
19 | import functools
20 | import numbers
21 | from typing import Mapping
22 |
23 | import haiku as hk
24 | import jax
25 | import jax.numpy as jnp
26 | import numpy as np
27 |
28 |
29 | def bfloat16_creator(next_creator, shape, dtype, init, context):
30 | """Creates float32 variables when bfloat16 is requested."""
31 | if context.original_dtype == jnp.bfloat16:
32 | dtype = jnp.float32
33 | return next_creator(shape, dtype, init)
34 |
35 |
36 | def bfloat16_getter(next_getter, value, context):
37 | """Casts float32 to bfloat16 when bfloat16 was originally requested."""
38 | if context.original_dtype == jnp.bfloat16:
39 | assert value.dtype == jnp.float32
40 | value = value.astype(jnp.bfloat16)
41 | return next_getter(value)
42 |
43 |
44 | @contextlib.contextmanager
45 | def bfloat16_context():
46 | with hk.custom_creator(bfloat16_creator), hk.custom_getter(bfloat16_getter):
47 | yield
48 |
49 |
50 | def final_init(config):
51 | if config.zero_init:
52 | return 'zeros'
53 | else:
54 | return 'linear'
55 |
56 |
57 | def batched_gather(params, indices, axis=0, batch_dims=0):
58 | """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`."""
59 | take_fn = lambda p, i: jnp.take(p, i, axis=axis, mode='clip')
60 | for _ in range(batch_dims):
61 | take_fn = jax.vmap(take_fn)
62 | return take_fn(params, indices)
63 |
64 |
65 | def mask_mean(mask, value, axis=None, drop_mask_channel=False, eps=1e-10):
66 | """Masked mean."""
67 | if drop_mask_channel:
68 | mask = mask[..., 0]
69 |
70 | mask_shape = mask.shape
71 | value_shape = value.shape
72 |
73 | assert len(mask_shape) == len(value_shape)
74 |
75 | if isinstance(axis, numbers.Integral):
76 | axis = [axis]
77 | elif axis is None:
78 | axis = list(range(len(mask_shape)))
79 | assert isinstance(axis, collections.abc.Iterable), (
80 | 'axis needs to be either an iterable, integer or "None"')
81 |
82 | broadcast_factor = 1.
83 | for axis_ in axis:
84 | value_size = value_shape[axis_]
85 | mask_size = mask_shape[axis_]
86 | if mask_size == 1:
87 | broadcast_factor *= value_size
88 | else:
89 | assert mask_size == value_size
90 |
91 | return (jnp.sum(mask * value, axis=axis) /
92 | (jnp.sum(mask, axis=axis) * broadcast_factor + eps))
93 |
94 |
95 | def flat_params_to_haiku(params: Mapping[str, np.ndarray]) -> hk.Params:
96 | """Convert a dictionary of NumPy arrays to Haiku parameters."""
97 | hk_params = {}
98 | for path, array in params.items():
99 | scope, name = path.split('//')
100 | if scope not in hk_params:
101 | hk_params[scope] = {}
102 | hk_params[scope][name] = jnp.array(array)
103 |
104 | return hk_params
105 |
106 |
107 | def padding_consistent_rng(f):
108 | """Modify any element-wise random function to be consistent with padding.
109 |
110 | Normally if you take a function like jax.random.normal and generate an array,
111 | say of size (10,10), you will get a different set of random numbers to if you
112 | add padding and take the first (10,10) sub-array.
113 |
114 | This function makes a random function that is consistent regardless of the
115 | amount of padding added.
116 |
117 | Note: The padding-consistent function is likely to be slower to compile and
118 | run than the function it is wrapping, but these slowdowns are likely to be
119 | negligible in a large network.
120 |
121 | Args:
122 | f: Any element-wise function that takes (PRNG key, shape) as the first 2
123 | arguments.
124 |
125 | Returns:
126 | An equivalent function to f, that is now consistent for different amounts of
127 | padding.
128 | """
129 | def grid_keys(key, shape):
130 | """Generate a grid of rng keys that is consistent with different padding.
131 |
132 | Generate random keys such that the keys will be identical, regardless of
133 | how much padding is added to any dimension.
134 |
135 | Args:
136 | key: A PRNG key.
137 | shape: The shape of the output array of keys that will be generated.
138 |
139 | Returns:
140 | An array of shape `shape` consisting of random keys.
141 | """
142 | if not shape:
143 | return key
144 | new_keys = jax.vmap(functools.partial(jax.random.fold_in, key))(
145 | jnp.arange(shape[0]))
146 | return jax.vmap(functools.partial(grid_keys, shape=shape[1:]))(new_keys)
147 |
148 | def inner(key, shape, **kwargs):
149 | return jnp.vectorize(
150 | lambda key: f(key, shape=(), **kwargs),
151 | signature='(2)->()')(
152 | grid_keys(key, shape))
153 | return inner
154 |
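
A minimal sketch of the padding-consistency property documented above (not part of the repository), wrapping jax.random.uniform as the element-wise function:

import jax
import jax.numpy as jnp

# Padding-consistent sampling: the same key yields identical values on the
# unpadded sub-array regardless of how much padding is added.
consistent_uniform = padding_consistent_rng(jax.random.uniform)
key = jax.random.PRNGKey(0)

small = consistent_uniform(key, shape=(10, 10))
padded = consistent_uniform(key, shape=(12, 12))  # same key, 2 extra per dim

assert jnp.allclose(small, padded[:10, :10])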
--------------------------------------------------------------------------------
/src/alphafold/notebooks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """AlphaFold Colab notebook."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/notebooks/notebook_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Helper methods for the AlphaFold Colab notebook."""
16 | import json
17 | from typing import Any, Mapping, Optional, Sequence, Tuple
18 |
19 | from alphafold.common import residue_constants
20 | from alphafold.data import parsers
21 | from matplotlib import pyplot as plt
22 | import numpy as np
23 |
24 |
25 | def clean_and_validate_single_sequence(
26 | input_sequence: str, min_length: int, max_length: int) -> str:
27 | """Checks that the input sequence is ok and returns a clean version of it."""
28 |   # Remove all spaces, tabs and newlines; upper-case.
29 | clean_sequence = input_sequence.translate(
30 | str.maketrans('', '', ' \n\t')).upper()
31 | aatypes = set(residue_constants.restypes) # 20 standard aatypes.
32 | if not set(clean_sequence).issubset(aatypes):
33 | raise ValueError(
34 | f'Input sequence contains non-amino acid letters: '
35 | f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard '
36 | 'amino acids as inputs.')
37 | if len(clean_sequence) < min_length:
38 | raise ValueError(
39 | f'Input sequence is too short: {len(clean_sequence)} amino acids, '
40 | f'while the minimum is {min_length}')
41 | if len(clean_sequence) > max_length:
42 | raise ValueError(
43 | f'Input sequence is too long: {len(clean_sequence)} amino acids, while '
44 | f'the maximum is {max_length}. You may be able to run it with the full '
45 | f'AlphaFold system depending on your resources (system memory, '
46 | f'GPU memory).')
47 | return clean_sequence
48 |
49 |
50 | def clean_and_validate_input_sequences(
51 | input_sequences: Sequence[str],
52 | min_sequence_length: int,
53 | max_sequence_length: int) -> Sequence[str]:
54 | """Validates and cleans input sequences."""
55 | sequences = []
56 |
57 | for input_sequence in input_sequences:
58 | if input_sequence.strip():
59 | input_sequence = clean_and_validate_single_sequence(
60 | input_sequence=input_sequence,
61 | min_length=min_sequence_length,
62 | max_length=max_sequence_length)
63 | sequences.append(input_sequence)
64 |
65 | if sequences:
66 | return sequences
67 | else:
68 | raise ValueError('No input amino acid sequence provided, please provide at '
69 | 'least one sequence.')
70 |
71 |
72 | def merge_chunked_msa(
73 | results: Sequence[Mapping[str, Any]],
74 | max_hits: Optional[int] = None
75 | ) -> parsers.Msa:
76 | """Merges chunked database hits together into hits for the full database."""
77 | unsorted_results = []
78 | for chunk_index, chunk in enumerate(results):
79 | msa = parsers.parse_stockholm(chunk['sto'])
80 | e_values_dict = parsers.parse_e_values_from_tblout(chunk['tbl'])
81 |     # Jackhmmer lists sequences as <name>/<start>-<end>.
82 | e_values = [e_values_dict[t.partition('/')[0]] for t in msa.descriptions]
83 | chunk_results = zip(
84 | msa.sequences, msa.deletion_matrix, msa.descriptions, e_values)
85 | if chunk_index != 0:
86 | next(chunk_results) # Only take query (first hit) from the first chunk.
87 | unsorted_results.extend(chunk_results)
88 |
89 | sorted_by_evalue = sorted(unsorted_results, key=lambda x: x[-1])
90 | merged_sequences, merged_deletion_matrix, merged_descriptions, _ = zip(
91 | *sorted_by_evalue)
92 | merged_msa = parsers.Msa(sequences=merged_sequences,
93 | deletion_matrix=merged_deletion_matrix,
94 | descriptions=merged_descriptions)
95 | if max_hits is not None:
96 | merged_msa = merged_msa.truncate(max_seqs=max_hits)
97 |
98 | return merged_msa
99 |
100 |
101 | def show_msa_info(
102 | single_chain_msas: Sequence[parsers.Msa],
103 | sequence_index: int):
104 | """Prints info and shows a plot of the deduplicated single chain MSA."""
105 | full_single_chain_msa = []
106 | for single_chain_msa in single_chain_msas:
107 | full_single_chain_msa.extend(single_chain_msa.sequences)
108 |
109 | # Deduplicate but preserve order (hence can't use set).
110 | deduped_full_single_chain_msa = list(dict.fromkeys(full_single_chain_msa))
111 | total_msa_size = len(deduped_full_single_chain_msa)
112 | print(f'\n{total_msa_size} unique sequences found in total for sequence '
113 | f'{sequence_index}\n')
114 |
115 | aa_map = {res: i for i, res in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZ-')}
116 | msa_arr = np.array(
117 | [[aa_map[aa] for aa in seq] for seq in deduped_full_single_chain_msa])
118 |
119 | plt.figure(figsize=(12, 3))
120 | plt.title(f'Per-Residue Count of Non-Gap Amino Acids in the MSA for Sequence '
121 | f'{sequence_index}')
122 | plt.plot(np.sum(msa_arr != aa_map['-'], axis=0), color='black')
123 | plt.ylabel('Non-Gap Count')
124 | plt.yticks(range(0, total_msa_size + 1, max(1, int(total_msa_size / 3))))
125 | plt.show()
126 |
127 |
128 | def empty_placeholder_template_features(
129 | num_templates: int, num_res: int) -> Mapping[str, np.ndarray]:
130 | return {
131 | 'template_aatype': np.zeros(
132 | (num_templates, num_res,
133 | len(residue_constants.restypes_with_x_and_gap)), dtype=np.float32),
134 | 'template_all_atom_masks': np.zeros(
135 | (num_templates, num_res, residue_constants.atom_type_num),
136 | dtype=np.float32),
137 | 'template_all_atom_positions': np.zeros(
138 | (num_templates, num_res, residue_constants.atom_type_num, 3),
139 | dtype=np.float32),
140 |       'template_domain_names': np.zeros([num_templates], dtype=object),
141 |       'template_sequence': np.zeros([num_templates], dtype=object),
142 | 'template_sum_probs': np.zeros([num_templates], dtype=np.float32),
143 | }
144 |
145 |
146 | def get_pae_json(pae: np.ndarray, max_pae: float) -> str:
147 | """Returns the PAE in the same format as is used in the AFDB.
148 |
149 | Note that the values are presented as floats to 1 decimal place,
150 | whereas AFDB returns integer values.
151 |
152 | Args:
153 | pae: The n_res x n_res PAE array.
154 | max_pae: The maximum possible PAE value.
155 | Returns:
156 | PAE output format as a JSON string.
157 | """
158 | # Check the PAE array is the correct shape.
159 | if (pae.ndim != 2 or pae.shape[0] != pae.shape[1]):
160 | raise ValueError(f'PAE must be a square matrix, got {pae.shape}')
161 |
162 | # Round the predicted aligned errors to 1 decimal place.
163 | rounded_errors = np.round(pae.astype(np.float64), decimals=1)
164 | formatted_output = [{
165 | 'predicted_aligned_error': rounded_errors.tolist(),
166 | 'max_predicted_aligned_error': max_pae
167 | }]
168 | return json.dumps(formatted_output, indent=None, separators=(',', ':'))
169 |
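
A minimal sketch of get_pae_json on a hypothetical 2x2 PAE matrix (not part of the repository); a non-square input raises ValueError:

import numpy as np

pae = np.array([[0.14, 7.89],
                [7.61, 0.02]])
pae_json = get_pae_json(pae, max_pae=31.75)
# '[{"predicted_aligned_error":[[0.1,7.9],[7.6,0.0]],
#    "max_predicted_aligned_error":31.75}]'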
--------------------------------------------------------------------------------
/src/alphafold/relax/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Amber relaxation."""
15 |
--------------------------------------------------------------------------------
/src/alphafold/relax/amber_minimize_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for amber_minimize."""
16 | import os
17 |
18 | from absl.testing import absltest
19 | from alphafold.common import protein
20 | from alphafold.relax import amber_minimize
21 | import numpy as np
22 | # Internal import (7716).
23 |
24 | _USE_GPU = False
25 |
26 |
27 | def _load_test_protein(data_path):
28 | pdb_path = os.path.join(absltest.get_default_test_srcdir(), data_path)
29 | with open(pdb_path, 'r') as f:
30 | return protein.from_pdb_string(f.read())
31 |
32 |
33 | class AmberMinimizeTest(absltest.TestCase):
34 |
35 | def test_multiple_disulfides_target(self):
36 | prot = _load_test_protein(
37 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb'
38 | )
39 | ret = amber_minimize.run_pipeline(prot, max_iterations=10, max_attempts=1,
40 | stiffness=10., use_gpu=_USE_GPU)
41 | self.assertIn('opt_time', ret)
42 | self.assertIn('min_attempts', ret)
43 |
44 | def test_raises_invalid_protein_assertion(self):
45 | prot = _load_test_protein(
46 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb'
47 | )
48 | prot.atom_mask[4, :] = 0
49 | with self.assertRaisesRegex(
50 | ValueError,
51 | 'Amber minimization can only be performed on proteins with well-defined'
52 | ' residues. This protein contains at least one residue with no atoms.'):
53 | amber_minimize.run_pipeline(prot, max_iterations=10,
54 | stiffness=1.,
55 | max_attempts=1,
56 | use_gpu=_USE_GPU)
57 |
58 | def test_iterative_relax(self):
59 | prot = _load_test_protein(
60 | 'alphafold/relax/testdata/with_violations.pdb'
61 | )
62 | violations = amber_minimize.get_violation_metrics(prot)
63 | self.assertGreater(violations['num_residue_violations'], 0)
64 | out = amber_minimize.run_pipeline(
65 | prot=prot, max_outer_iterations=10, stiffness=10., use_gpu=_USE_GPU)
66 | self.assertLess(out['efinal'], out['einit'])
67 | self.assertEqual(0, out['num_residue_violations'])
68 |
69 | def test_find_violations(self):
70 | prot = _load_test_protein(
71 | 'alphafold/relax/testdata/multiple_disulfides_target.pdb'
72 | )
73 | viols, _ = amber_minimize.find_violations(prot)
74 |
75 | expected_between_residues_connection_mask = np.zeros((191,), np.float32)
76 | for residue in (42, 43, 59, 60, 135, 136):
77 | expected_between_residues_connection_mask[residue] = 1.0
78 |
79 | expected_clash_indices = np.array([
80 | [8, 4],
81 | [8, 5],
82 | [13, 3],
83 | [14, 1],
84 | [14, 4],
85 | [26, 4],
86 | [26, 5],
87 | [31, 8],
88 | [31, 10],
89 | [39, 0],
90 | [39, 1],
91 | [39, 2],
92 | [39, 3],
93 | [39, 4],
94 | [42, 5],
95 | [42, 6],
96 | [42, 7],
97 | [42, 8],
98 | [47, 7],
99 | [47, 8],
100 | [47, 9],
101 | [47, 10],
102 | [64, 4],
103 | [85, 5],
104 | [102, 4],
105 | [102, 5],
106 | [109, 13],
107 | [111, 5],
108 | [118, 6],
109 | [118, 7],
110 | [118, 8],
111 | [124, 4],
112 | [124, 5],
113 | [131, 5],
114 | [139, 7],
115 | [147, 4],
116 | [152, 7]], dtype=np.int32)
117 | expected_between_residues_clash_mask = np.zeros([191, 14])
118 | expected_between_residues_clash_mask[expected_clash_indices[:, 0],
119 | expected_clash_indices[:, 1]] += 1
120 | expected_per_atom_violations = np.zeros([191, 14])
121 | np.testing.assert_array_equal(
122 | viols['between_residues']['connections_per_residue_violation_mask'],
123 | expected_between_residues_connection_mask)
124 | np.testing.assert_array_equal(
125 | viols['between_residues']['clashes_per_atom_clash_mask'],
126 | expected_between_residues_clash_mask)
127 | np.testing.assert_array_equal(
128 | viols['within_residues']['per_atom_violations'],
129 | expected_per_atom_violations)
130 |
131 |
132 | if __name__ == '__main__':
133 | absltest.main()
134 |
--------------------------------------------------------------------------------
/src/alphafold/relax/cleanup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations.
16 |
17 | fix_pdb uses a third-party tool. We also support fixing some additional edge
18 | cases like removing chains of length one (see clean_structure).
19 | """
20 | import io
21 |
22 | import pdbfixer
23 | from simtk.openmm import app
24 | from simtk.openmm.app import element
25 |
26 |
27 | def fix_pdb(pdbfile, alterations_info):
28 | """Apply pdbfixer to the contents of a PDB file; return a PDB string result.
29 |
30 | 1) Replaces nonstandard residues.
31 | 2) Removes heterogens (non protein residues) including water.
32 | 3) Adds missing residues and missing atoms within existing residues.
33 | 4) Adds hydrogens assuming pH=7.0.
34 | 5) KeepIds is currently true, so the fixer must keep the existing chain and
35 | residue identifiers. This will fail for some files in wider PDB that have
36 | invalid IDs.
37 |
38 | Args:
39 | pdbfile: Input PDB file handle.
40 | alterations_info: A dict that will store details of changes made.
41 |
42 | Returns:
43 | A PDB string representing the fixed structure.
44 | """
45 | fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
46 | fixer.findNonstandardResidues()
47 | alterations_info['nonstandard_residues'] = fixer.nonstandardResidues
48 | fixer.replaceNonstandardResidues()
49 | _remove_heterogens(fixer, alterations_info, keep_water=False)
50 | fixer.findMissingResidues()
51 | alterations_info['missing_residues'] = fixer.missingResidues
52 | fixer.findMissingAtoms()
53 | alterations_info['missing_heavy_atoms'] = fixer.missingAtoms
54 | alterations_info['missing_terminals'] = fixer.missingTerminals
55 | fixer.addMissingAtoms(seed=0)
56 | fixer.addMissingHydrogens()
57 | out_handle = io.StringIO()
58 | app.PDBFile.writeFile(fixer.topology, fixer.positions, out_handle,
59 | keepIds=True)
60 | return out_handle.getvalue()
61 |
62 |
63 | def clean_structure(pdb_structure, alterations_info):
64 | """Applies additional fixes to an OpenMM structure, to handle edge cases.
65 |
66 | Args:
67 | pdb_structure: An OpenMM structure to modify and fix.
68 | alterations_info: A dict that will store details of changes made.
69 | """
70 | _replace_met_se(pdb_structure, alterations_info)
71 | _remove_chains_of_length_one(pdb_structure, alterations_info)
72 |
73 |
74 | def _remove_heterogens(fixer, alterations_info, keep_water):
75 | """Removes the residues that Pdbfixer considers to be heterogens.
76 |
77 | Args:
78 | fixer: A Pdbfixer instance.
79 | alterations_info: A dict that will store details of changes made.
80 | keep_water: If True, water (HOH) is not considered to be a heterogen.
81 | """
82 | initial_resnames = set()
83 | for chain in fixer.topology.chains():
84 | for residue in chain.residues():
85 | initial_resnames.add(residue.name)
86 | fixer.removeHeterogens(keepWater=keep_water)
87 | final_resnames = set()
88 | for chain in fixer.topology.chains():
89 | for residue in chain.residues():
90 | final_resnames.add(residue.name)
91 | alterations_info['removed_heterogens'] = (
92 | initial_resnames.difference(final_resnames))
93 |
94 |
95 | def _replace_met_se(pdb_structure, alterations_info):
96 | """Replace the Se in any MET residues that were not marked as modified."""
97 | modified_met_residues = []
98 | for res in pdb_structure.iter_residues():
99 | name = res.get_name_with_spaces().strip()
100 | if name == 'MET':
101 | s_atom = res.get_atom('SD')
102 | if s_atom.element_symbol == 'Se':
103 | s_atom.element_symbol = 'S'
104 | s_atom.element = element.get_by_symbol('S')
105 | modified_met_residues.append(s_atom.residue_number)
106 | alterations_info['Se_in_MET'] = modified_met_residues
107 |
108 |
109 | def _remove_chains_of_length_one(pdb_structure, alterations_info):
110 | """Removes chains that correspond to a single amino acid.
111 |
112 | A single amino acid in a chain is both N and C terminus. There is no force
113 | template for this case.
114 |
115 | Args:
116 | pdb_structure: An OpenMM pdb_structure to modify and fix.
117 | alterations_info: A dict that will store details of changes made.
118 | """
119 | removed_chains = {}
120 | for model in pdb_structure.iter_models():
121 | valid_chains = [c for c in model.iter_chains() if len(c) > 1]
122 | invalid_chain_ids = [c.chain_id for c in model.iter_chains() if len(c) <= 1]
123 | model.chains = valid_chains
124 | for chain_id in invalid_chain_ids:
125 | model.chains_by_id.pop(chain_id)
126 | removed_chains[model.number] = invalid_chain_ids
127 | alterations_info['removed_chains'] = removed_chains
128 |
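
A minimal sketch of the cleanup flow (not part of the repository); raw_pdb_string is a placeholder for real PDB file contents:

import io

alterations = {}  # fix_pdb fills this with details of every change it makes
fixed_pdb = fix_pdb(io.StringIO(raw_pdb_string), alterations)
# alterations now holds 'nonstandard_residues', 'removed_heterogens',
# 'missing_residues', 'missing_heavy_atoms' and 'missing_terminals'.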
--------------------------------------------------------------------------------
/src/alphafold/relax/cleanup_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for relax.cleanup."""
16 | import io
17 |
18 | from absl.testing import absltest
19 | from alphafold.relax import cleanup
20 | from simtk.openmm.app.internal import pdbstructure
21 |
22 |
23 | def _pdb_to_structure(pdb_str):
24 | handle = io.StringIO(pdb_str)
25 | return pdbstructure.PdbStructure(handle)
26 |
27 |
28 | def _lines_to_structure(pdb_lines):
29 | return _pdb_to_structure('\n'.join(pdb_lines))
30 |
31 |
32 | class CleanupTest(absltest.TestCase):
33 |
34 | def test_missing_residues(self):
35 | pdb_lines = ['SEQRES 1 C 3 CYS GLY LEU',
36 | 'ATOM 1 N CYS C 1 -12.262 20.115 60.959 1.00 '
37 | '19.08 N',
38 | 'ATOM 2 CA CYS C 1 -11.065 20.934 60.773 1.00 '
39 | '17.23 C',
40 | 'ATOM 3 C CYS C 1 -10.002 20.742 61.844 1.00 '
41 | '15.38 C',
42 | 'ATOM 4 O CYS C 1 -10.284 20.225 62.929 1.00 '
43 | '16.04 O',
44 | 'ATOM 5 N LEU C 3 -7.688 18.700 62.045 1.00 '
45 | '14.75 N',
46 | 'ATOM 6 CA LEU C 3 -7.256 17.320 62.234 1.00 '
47 | '16.81 C',
48 | 'ATOM 7 C LEU C 3 -6.380 16.864 61.070 1.00 '
49 | '16.95 C',
50 | 'ATOM 8 O LEU C 3 -6.551 17.332 59.947 1.00 '
51 | '16.97 O']
52 | input_handle = io.StringIO('\n'.join(pdb_lines))
53 | alterations = {}
54 | result = cleanup.fix_pdb(input_handle, alterations)
55 | structure = _pdb_to_structure(result)
56 | residue_names = [r.get_name() for r in structure.iter_residues()]
57 | self.assertCountEqual(residue_names, ['CYS', 'GLY', 'LEU'])
58 | self.assertCountEqual(alterations['missing_residues'].values(), [['GLY']])
59 |
60 | def test_missing_atoms(self):
61 | pdb_lines = ['SEQRES 1 A 1 PRO',
62 | 'ATOM 1 CA PRO A 1 1.000 1.000 1.000 1.00 '
63 | ' 0.00 C']
64 | input_handle = io.StringIO('\n'.join(pdb_lines))
65 | alterations = {}
66 | result = cleanup.fix_pdb(input_handle, alterations)
67 | structure = _pdb_to_structure(result)
68 | atom_names = [a.get_name() for a in structure.iter_atoms()]
69 | self.assertCountEqual(atom_names, ['N', 'CD', 'HD2', 'HD3', 'CG', 'HG2',
70 | 'HG3', 'CB', 'HB2', 'HB3', 'CA', 'HA',
71 | 'C', 'O', 'H2', 'H3', 'OXT'])
72 | missing_atoms_by_residue = list(alterations['missing_heavy_atoms'].values())
73 | self.assertLen(missing_atoms_by_residue, 1)
74 | atoms_added = [a.name for a in missing_atoms_by_residue[0]]
75 | self.assertCountEqual(atoms_added, ['N', 'CD', 'CG', 'CB', 'C', 'O'])
76 | missing_terminals_by_residue = alterations['missing_terminals']
77 | self.assertLen(missing_terminals_by_residue, 1)
78 | has_missing_terminal = [r.name for r in missing_terminals_by_residue.keys()]
79 | self.assertCountEqual(has_missing_terminal, ['PRO'])
80 | self.assertCountEqual([t for t in missing_terminals_by_residue.values()],
81 | [['OXT']])
82 |
83 | def test_remove_heterogens(self):
84 | pdb_lines = ['SEQRES 1 A 1 GLY',
85 | 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 '
86 | ' 0.00 C',
87 | 'ATOM 2 O HOH A 2 0.000 0.000 0.000 1.00 '
88 | ' 0.00 O']
89 | input_handle = io.StringIO('\n'.join(pdb_lines))
90 | alterations = {}
91 | result = cleanup.fix_pdb(input_handle, alterations)
92 | structure = _pdb_to_structure(result)
93 | self.assertCountEqual([res.get_name() for res in structure.iter_residues()],
94 | ['GLY'])
95 | self.assertEqual(alterations['removed_heterogens'], set(['HOH']))
96 |
97 | def test_fix_nonstandard_residues(self):
98 | pdb_lines = ['SEQRES 1 A 1 DAL',
99 | 'ATOM 1 CA DAL A 1 0.000 0.000 0.000 1.00 '
100 | ' 0.00 C']
101 | input_handle = io.StringIO('\n'.join(pdb_lines))
102 | alterations = {}
103 | result = cleanup.fix_pdb(input_handle, alterations)
104 | structure = _pdb_to_structure(result)
105 | residue_names = [res.get_name() for res in structure.iter_residues()]
106 | self.assertCountEqual(residue_names, ['ALA'])
107 | self.assertLen(alterations['nonstandard_residues'], 1)
108 | original_res, new_name = alterations['nonstandard_residues'][0]
109 | self.assertEqual(original_res.id, '1')
110 | self.assertEqual(new_name, 'ALA')
111 |
112 | def test_replace_met_se(self):
113 | pdb_lines = ['SEQRES 1 A 1 MET',
114 | 'ATOM 1 SD MET A 1 0.000 0.000 0.000 1.00 '
115 | ' 0.00 Se']
116 | structure = _lines_to_structure(pdb_lines)
117 | alterations = {}
118 | cleanup._replace_met_se(structure, alterations)
119 | sd = [a for a in structure.iter_atoms() if a.get_name() == 'SD']
120 | self.assertLen(sd, 1)
121 | self.assertEqual(sd[0].element_symbol, 'S')
122 | self.assertCountEqual(alterations['Se_in_MET'], [sd[0].residue_number])
123 |
124 | def test_remove_chains_of_length_one(self):
125 | pdb_lines = ['SEQRES 1 A 1 GLY',
126 | 'ATOM 1 CA GLY A 1 0.000 0.000 0.000 1.00 '
127 | ' 0.00 C']
128 | structure = _lines_to_structure(pdb_lines)
129 | alterations = {}
130 | cleanup._remove_chains_of_length_one(structure, alterations)
131 | chains = list(structure.iter_chains())
132 | self.assertEmpty(chains)
133 | self.assertCountEqual(alterations['removed_chains'].values(), [['A']])
134 |
135 |
136 | if __name__ == '__main__':
137 | absltest.main()
138 |
--------------------------------------------------------------------------------
/src/alphafold/relax/relax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Amber relaxation."""
16 | from typing import Any, Dict, Sequence, Tuple
17 | from alphafold.common import protein
18 | from alphafold.relax import amber_minimize
19 | from alphafold.relax import utils
20 | import numpy as np
21 |
22 |
23 | class AmberRelaxation(object):
24 | """Amber relaxation."""
25 |
26 | def __init__(self,
27 | *,
28 | max_iterations: int,
29 | tolerance: float,
30 | stiffness: float,
31 | exclude_residues: Sequence[int],
32 | max_outer_iterations: int,
33 | use_gpu: bool):
34 | """Initialize Amber Relaxer.
35 |
36 | Args:
37 | max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
38 | tolerance: kcal/mol, the energy tolerance of L-BFGS.
39 | stiffness: kcal/mol A**2, spring constant of heavy atom restraining
40 | potential.
41 | exclude_residues: Residues to exclude from per-atom restraining.
42 | Zero-indexed.
43 | max_outer_iterations: Maximum number of violation-informed relax
44 | iterations. A value of 1 will run the non-iterative procedure used in
45 | CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes
46 | as soon as there are no violations, hence in most cases this causes no
47 | slowdown. In the worst case we do 20 outer iterations.
48 | use_gpu: Whether to run on GPU.
49 | """
50 |
51 | self._max_iterations = max_iterations
52 | self._tolerance = tolerance
53 | self._stiffness = stiffness
54 | self._exclude_residues = exclude_residues
55 | self._max_outer_iterations = max_outer_iterations
56 | self._use_gpu = use_gpu
57 |
58 | def process(self, *,
59 | prot: protein.Protein
60 | ) -> Tuple[str, Dict[str, Any], Sequence[float]]:
61 | """Runs Amber relax on a prediction, adds hydrogens, returns PDB string."""
62 | out = amber_minimize.run_pipeline(
63 | prot=prot, max_iterations=self._max_iterations,
64 | tolerance=self._tolerance, stiffness=self._stiffness,
65 | exclude_residues=self._exclude_residues,
66 | max_outer_iterations=self._max_outer_iterations,
67 | use_gpu=self._use_gpu)
68 | min_pos = out['pos']
69 | start_pos = out['posinit']
70 | rmsd = np.sqrt(np.sum((start_pos - min_pos)**2) / start_pos.shape[0])
71 | debug_data = {
72 | 'initial_energy': out['einit'],
73 | 'final_energy': out['efinal'],
74 | 'attempts': out['min_attempts'],
75 | 'rmsd': rmsd
76 | }
77 | min_pdb = out['min_pdb']
78 | min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors)
79 | utils.assert_equal_nonterminal_atom_types(
80 | protein.from_pdb_string(min_pdb).atom_mask,
81 | prot.atom_mask)
82 | violations = out['structural_violations'][
83 | 'total_per_residue_violations_mask'].tolist()
84 | return min_pdb, debug_data, violations
85 |
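
A minimal sketch of running the relaxer (not part of the repository); 'prediction.pdb' is a placeholder path, and the settings mirror the non-iterative CASP14 procedure described in the docstring:

from alphafold.common import protein
from alphafold.relax import relax

amber_relaxer = relax.AmberRelaxation(
    max_iterations=0,        # 0 = no cap on L-BFGS iterations
    tolerance=2.39,          # kcal/mol energy tolerance
    stiffness=10.0,          # kcal/mol A**2 heavy-atom restraint
    exclude_residues=[],
    max_outer_iterations=1,  # 1 = single, non-iterative pass
    use_gpu=False)

with open('prediction.pdb') as f:
    prot = protein.from_pdb_string(f.read())
relaxed_pdb, debug_data, violations = amber_relaxer.process(prot=prot)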
--------------------------------------------------------------------------------
/src/alphafold/relax/relax_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for relax."""
16 | import os
17 |
18 | from absl.testing import absltest
19 | from alphafold.common import protein
20 | from alphafold.relax import relax
21 | import numpy as np
22 | # Internal import (7716).
23 |
24 |
25 | class RunAmberRelaxTest(absltest.TestCase):
26 |
27 | def setUp(self):
28 | super().setUp()
29 | self.test_dir = os.path.join(
30 | absltest.get_default_test_srcdir(),
31 | 'alphafold/relax/testdata/')
32 | self.test_config = {
33 | 'max_iterations': 1,
34 | 'tolerance': 2.39,
35 | 'stiffness': 10.0,
36 | 'exclude_residues': [],
37 | 'max_outer_iterations': 1,
38 | 'use_gpu': False}
39 |
40 | def test_process(self):
41 | amber_relax = relax.AmberRelaxation(**self.test_config)
42 |
43 | with open(os.path.join(self.test_dir, 'model_output.pdb')) as f:
44 | test_prot = protein.from_pdb_string(f.read())
45 | pdb_min, debug_info, num_violations = amber_relax.process(prot=test_prot)
46 |
47 | self.assertCountEqual(debug_info.keys(),
48 | set({'initial_energy', 'final_energy',
49 | 'attempts', 'rmsd'}))
50 | self.assertLess(debug_info['final_energy'], debug_info['initial_energy'])
51 | self.assertGreater(debug_info['rmsd'], 0)
52 |
53 | prot_min = protein.from_pdb_string(pdb_min)
54 | # Most protein properties should be unchanged.
55 | np.testing.assert_almost_equal(test_prot.aatype, prot_min.aatype)
56 | np.testing.assert_almost_equal(test_prot.residue_index,
57 | prot_min.residue_index)
58 | # Atom mask and bfactors identical except for terminal OXT of last residue.
59 | np.testing.assert_almost_equal(test_prot.atom_mask[:-1, :],
60 | prot_min.atom_mask[:-1, :])
61 | np.testing.assert_almost_equal(test_prot.b_factors[:-1, :],
62 | prot_min.b_factors[:-1, :])
63 | np.testing.assert_almost_equal(test_prot.atom_mask[:, :-1],
64 | prot_min.atom_mask[:, :-1])
65 | np.testing.assert_almost_equal(test_prot.b_factors[:, :-1],
66 | prot_min.b_factors[:, :-1])
67 | # There are no residues with violations.
68 | np.testing.assert_equal(num_violations, np.zeros_like(num_violations))
69 |
70 | def test_unresolved_violations(self):
71 | amber_relax = relax.AmberRelaxation(**self.test_config)
72 | with open(os.path.join(self.test_dir,
73 | 'with_violations_casp14.pdb')) as f:
74 | test_prot = protein.from_pdb_string(f.read())
75 | _, _, num_violations = amber_relax.process(prot=test_prot)
76 | exp_num_violations = np.array(
77 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
78 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1,
79 | 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
80 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
81 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
82 | 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
83 | 0, 0, 0, 0])
84 | # Check no violations were added. Can't check exactly due to stochasticity.
85 | self.assertTrue(np.all(np.array(num_violations) <= exp_num_violations))
86 |
87 |
88 | if __name__ == '__main__':
89 | absltest.main()
90 |
--------------------------------------------------------------------------------
/src/alphafold/relax/testdata/model_output.pdb:
--------------------------------------------------------------------------------
1 | ATOM 1 C MET A 1 1.921 -46.152 7.786 1.00 4.39 C
2 | ATOM 2 CA MET A 1 1.631 -46.829 9.131 1.00 4.39 C
3 | ATOM 3 CB MET A 1 2.759 -47.768 9.578 1.00 4.39 C
4 | ATOM 4 CE MET A 1 3.466 -49.770 13.198 1.00 4.39 C
5 | ATOM 5 CG MET A 1 2.581 -48.221 11.034 1.00 4.39 C
6 | ATOM 6 H MET A 1 0.234 -48.249 8.549 1.00 4.39 H
7 | ATOM 7 H2 MET A 1 -0.424 -46.789 8.952 1.00 4.39 H
8 | ATOM 8 H3 MET A 1 0.111 -47.796 10.118 1.00 4.39 H
9 | ATOM 9 HA MET A 1 1.628 -46.009 9.849 1.00 4.39 H
10 | ATOM 10 HB2 MET A 1 3.701 -47.225 9.500 1.00 4.39 H
11 | ATOM 11 HB3 MET A 1 2.807 -48.640 8.926 1.00 4.39 H
12 | ATOM 12 HE1 MET A 1 2.747 -50.537 12.910 1.00 4.39 H
13 | ATOM 13 HE2 MET A 1 4.296 -50.241 13.725 1.00 4.39 H
14 | ATOM 14 HE3 MET A 1 2.988 -49.052 13.864 1.00 4.39 H
15 | ATOM 15 HG2 MET A 1 1.791 -48.971 11.083 1.00 4.39 H
16 | ATOM 16 HG3 MET A 1 2.295 -47.368 11.650 1.00 4.39 H
17 | ATOM 17 N MET A 1 0.291 -47.464 9.182 1.00 4.39 N
18 | ATOM 18 O MET A 1 2.091 -44.945 7.799 1.00 4.39 O
19 | ATOM 19 SD MET A 1 4.096 -48.921 11.725 1.00 4.39 S
20 | ATOM 20 C LYS A 2 1.366 -45.033 4.898 1.00 2.92 C
21 | ATOM 21 CA LYS A 2 2.235 -46.242 5.308 1.00 2.92 C
22 | ATOM 22 CB LYS A 2 2.206 -47.314 4.196 1.00 2.92 C
23 | ATOM 23 CD LYS A 2 3.331 -49.342 3.134 1.00 2.92 C
24 | ATOM 24 CE LYS A 2 4.434 -50.403 3.293 1.00 2.92 C
25 | ATOM 25 CG LYS A 2 3.294 -48.395 4.349 1.00 2.92 C
26 | ATOM 26 H LYS A 2 1.832 -47.853 6.656 1.00 2.92 H
27 | ATOM 27 HA LYS A 2 3.248 -45.841 5.355 1.00 2.92 H
28 | ATOM 28 HB2 LYS A 2 1.223 -47.785 4.167 1.00 2.92 H
29 | ATOM 29 HB3 LYS A 2 2.363 -46.812 3.241 1.00 2.92 H
30 | ATOM 30 HD2 LYS A 2 3.524 -48.754 2.237 1.00 2.92 H
31 | ATOM 31 HD3 LYS A 2 2.364 -49.833 3.031 1.00 2.92 H
32 | ATOM 32 HE2 LYS A 2 5.383 -49.891 3.455 1.00 2.92 H
33 | ATOM 33 HE3 LYS A 2 4.225 -51.000 4.180 1.00 2.92 H
34 | ATOM 34 HG2 LYS A 2 3.102 -48.977 5.250 1.00 2.92 H
35 | ATOM 35 HG3 LYS A 2 4.264 -47.909 4.446 1.00 2.92 H
36 | ATOM 36 HZ1 LYS A 2 4.763 -50.747 1.274 1.00 2.92 H
37 | ATOM 37 HZ2 LYS A 2 3.681 -51.785 1.931 1.00 2.92 H
38 | ATOM 38 HZ3 LYS A 2 5.280 -51.965 2.224 1.00 2.92 H
39 | ATOM 39 N LYS A 2 1.907 -46.846 6.629 1.00 2.92 N
40 | ATOM 40 NZ LYS A 2 4.542 -51.286 2.100 1.00 2.92 N
41 | ATOM 41 O LYS A 2 1.882 -44.093 4.312 1.00 2.92 O
42 | ATOM 42 C PHE A 3 -0.511 -42.597 5.624 1.00 4.39 C
43 | ATOM 43 CA PHE A 3 -0.853 -43.933 4.929 1.00 4.39 C
44 | ATOM 44 CB PHE A 3 -2.271 -44.408 5.285 1.00 4.39 C
45 | ATOM 45 CD1 PHE A 3 -3.760 -43.542 3.432 1.00 4.39 C
46 | ATOM 46 CD2 PHE A 3 -4.050 -42.638 5.675 1.00 4.39 C
47 | ATOM 47 CE1 PHE A 3 -4.797 -42.715 2.965 1.00 4.39 C
48 | ATOM 48 CE2 PHE A 3 -5.091 -41.818 5.207 1.00 4.39 C
49 | ATOM 49 CG PHE A 3 -3.382 -43.505 4.788 1.00 4.39 C
50 | ATOM 50 CZ PHE A 3 -5.463 -41.853 3.853 1.00 4.39 C
51 | ATOM 51 H PHE A 3 -0.311 -45.868 5.655 1.00 4.39 H
52 | ATOM 52 HA PHE A 3 -0.817 -43.746 3.856 1.00 4.39 H
53 | ATOM 53 HB2 PHE A 3 -2.353 -44.512 6.367 1.00 4.39 H
54 | ATOM 54 HB3 PHE A 3 -2.432 -45.393 4.848 1.00 4.39 H
55 | ATOM 55 HD1 PHE A 3 -3.255 -44.198 2.739 1.00 4.39 H
56 | ATOM 56 HD2 PHE A 3 -3.768 -42.590 6.716 1.00 4.39 H
57 | ATOM 57 HE1 PHE A 3 -5.083 -42.735 1.923 1.00 4.39 H
58 | ATOM 58 HE2 PHE A 3 -5.604 -41.151 5.885 1.00 4.39 H
59 | ATOM 59 HZ PHE A 3 -6.257 -41.215 3.493 1.00 4.39 H
60 | ATOM 60 N PHE A 3 0.079 -45.027 5.253 1.00 4.39 N
61 | ATOM 61 O PHE A 3 -0.633 -41.541 5.014 1.00 4.39 O
62 | ATOM 62 C LEU A 4 1.598 -40.732 7.042 1.00 4.39 C
63 | ATOM 63 CA LEU A 4 0.367 -41.437 7.633 1.00 4.39 C
64 | ATOM 64 CB LEU A 4 0.628 -41.823 9.104 1.00 4.39 C
65 | ATOM 65 CD1 LEU A 4 -0.319 -42.778 11.228 1.00 4.39 C
66 | ATOM 66 CD2 LEU A 4 -1.300 -40.694 10.309 1.00 4.39 C
67 | ATOM 67 CG LEU A 4 -0.650 -42.027 9.937 1.00 4.39 C
68 | ATOM 68 H LEU A 4 0.163 -43.538 7.292 1.00 4.39 H
69 | ATOM 69 HA LEU A 4 -0.445 -40.712 7.588 1.00 4.39 H
70 | ATOM 70 HB2 LEU A 4 1.213 -41.034 9.576 1.00 4.39 H
71 | ATOM 71 HB3 LEU A 4 1.235 -42.728 9.127 1.00 4.39 H
72 | ATOM 72 HD11 LEU A 4 0.380 -42.191 11.824 1.00 4.39 H
73 | ATOM 73 HD12 LEU A 4 0.127 -43.747 11.002 1.00 4.39 H
74 | ATOM 74 HD13 LEU A 4 -1.230 -42.927 11.808 1.00 4.39 H
75 | ATOM 75 HD21 LEU A 4 -0.606 -40.080 10.883 1.00 4.39 H
76 | ATOM 76 HD22 LEU A 4 -2.193 -40.869 10.909 1.00 4.39 H
77 | ATOM 77 HD23 LEU A 4 -1.593 -40.147 9.413 1.00 4.39 H
78 | ATOM 78 HG LEU A 4 -1.359 -42.630 9.370 1.00 4.39 H
79 | ATOM 79 N LEU A 4 -0.012 -42.638 6.869 1.00 4.39 N
80 | ATOM 80 O LEU A 4 1.655 -39.508 7.028 1.00 4.39 O
81 | ATOM 81 C VAL A 5 3.372 -40.190 4.573 1.00 4.39 C
82 | ATOM 82 CA VAL A 5 3.752 -40.956 5.845 1.00 4.39 C
83 | ATOM 83 CB VAL A 5 4.757 -42.083 5.528 1.00 4.39 C
84 | ATOM 84 CG1 VAL A 5 6.019 -41.568 4.827 1.00 4.39 C
85 | ATOM 85 CG2 VAL A 5 5.199 -42.807 6.810 1.00 4.39 C
86 | ATOM 86 H VAL A 5 2.440 -42.503 6.548 1.00 4.39 H
87 | ATOM 87 HA VAL A 5 4.234 -40.242 6.512 1.00 4.39 H
88 | ATOM 88 HB VAL A 5 4.279 -42.813 4.875 1.00 4.39 H
89 | ATOM 89 HG11 VAL A 5 6.494 -40.795 5.431 1.00 4.39 H
90 | ATOM 90 HG12 VAL A 5 5.770 -41.145 3.853 1.00 4.39 H
91 | ATOM 91 HG13 VAL A 5 6.725 -42.383 4.670 1.00 4.39 H
92 | ATOM 92 HG21 VAL A 5 4.347 -43.283 7.297 1.00 4.39 H
93 | ATOM 93 HG22 VAL A 5 5.933 -43.575 6.568 1.00 4.39 H
94 | ATOM 94 HG23 VAL A 5 5.651 -42.093 7.498 1.00 4.39 H
95 | ATOM 95 N VAL A 5 2.554 -41.501 6.509 1.00 4.39 N
96 | ATOM 96 O VAL A 5 3.937 -39.138 4.297 1.00 4.39 O
97 | TER 96 VAL A 5
98 | END
99 |
--------------------------------------------------------------------------------
/src/alphafold/relax/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utils for minimization."""
16 | import io
17 | from alphafold.common import residue_constants
18 | from Bio import PDB
19 | import numpy as np
20 |
21 |
22 | def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str:
23 | """Overwrites the B-factors in pdb_str with contents of bfactors array.
24 |
25 | Args:
26 | pdb_str: An input PDB string.
27 |     bfactors: A numpy array with shape [n_residues, 37]. We assume that the
28 |       B-factors are per residue; i.e. that the nonzero entries are identical in
29 |       [i, :].
30 |
31 | Returns:
32 | A new PDB string with the B-factors replaced.
33 | """
34 | if bfactors.shape[-1] != residue_constants.atom_type_num:
35 | raise ValueError(
36 | f'Invalid final dimension size for bfactors: {bfactors.shape[-1]}.')
37 |
38 | parser = PDB.PDBParser(QUIET=True)
39 | handle = io.StringIO(pdb_str)
40 | structure = parser.get_structure('', handle)
41 |
42 | curr_resid = ('', '', '')
43 | idx = -1
44 | for atom in structure.get_atoms():
45 | atom_resid = atom.parent.get_id()
46 | if atom_resid != curr_resid:
47 | idx += 1
48 | if idx >= bfactors.shape[0]:
49 | raise ValueError('Index into bfactors exceeds number of residues. '
50 |                          f'B-factors shape: {bfactors.shape}, idx: {idx}.')
51 | curr_resid = atom_resid
52 | atom.bfactor = bfactors[idx, residue_constants.atom_order['CA']]
53 |
54 | new_pdb = io.StringIO()
55 | pdb_io = PDB.PDBIO()
56 | pdb_io.set_structure(structure)
57 | pdb_io.save(new_pdb)
58 | return new_pdb.getvalue()
59 |
60 |
61 | def assert_equal_nonterminal_atom_types(
62 | atom_mask: np.ndarray, ref_atom_mask: np.ndarray):
63 | """Checks that pre- and post-minimized proteins have same atom set."""
64 | # Ignore any terminal OXT atoms which may have been added by minimization.
65 | oxt = residue_constants.atom_order['OXT']
66 |   no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
67 | no_oxt_mask[..., oxt] = False
68 | np.testing.assert_almost_equal(ref_atom_mask[no_oxt_mask],
69 | atom_mask[no_oxt_mask])
70 |
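
A minimal sketch of writing per-residue confidence values into the B-factor column (not part of the repository); pdb_string and plddt (one float per residue) are placeholders:

import numpy as np
from alphafold.common import residue_constants

# Broadcast one value per residue across all 37 atom-type slots, since
# overwrite_b_factors expects a [n_residues, atom_type_num] array.
bfactors = np.repeat(plddt[:, None], residue_constants.atom_type_num, axis=-1)
new_pdb = overwrite_b_factors(pdb_string, bfactors)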
--------------------------------------------------------------------------------
/src/alphafold/relax/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for utils."""
16 |
17 | import os
18 |
19 | from absl.testing import absltest
20 | from alphafold.common import protein
21 | from alphafold.relax import utils
22 | import numpy as np
23 | # Internal import (7716).
24 |
25 |
26 | class UtilsTest(absltest.TestCase):
27 |
28 | def test_overwrite_b_factors(self):
29 | testdir = os.path.join(
30 | absltest.get_default_test_srcdir(),
31 | 'alphafold/relax/testdata/'
32 | 'multiple_disulfides_target.pdb')
33 | with open(testdir) as f:
34 | test_pdb = f.read()
35 | n_residues = 191
36 | bfactors = np.stack([np.arange(0, n_residues)] * 37, axis=-1)
37 |
38 | output_pdb = utils.overwrite_b_factors(test_pdb, bfactors)
39 |
40 | # Check that the atom lines are unchanged apart from the B-factors.
41 |     atom_lines_original = [l for l in test_pdb.split('\n') if l[:4] == 'ATOM']
42 |     atom_lines_new = [l for l in output_pdb.split('\n') if l[:4] == 'ATOM']
43 | for line_original, line_new in zip(atom_lines_original, atom_lines_new):
44 | self.assertEqual(line_original[:60].strip(), line_new[:60].strip())
45 | self.assertEqual(line_original[66:].strip(), line_new[66:].strip())
46 |
47 | # Check B-factors are correctly set for all atoms present.
48 | as_protein = protein.from_pdb_string(output_pdb)
49 | np.testing.assert_almost_equal(
50 | np.where(as_protein.atom_mask > 0, as_protein.b_factors, 0),
51 | np.where(as_protein.atom_mask > 0, bfactors, 0))
52 |
53 |
54 | if __name__ == '__main__':
55 | absltest.main()
56 |
--------------------------------------------------------------------------------
/src/generate_msas.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Script for generating input features with AlphaFold-multimer.
3 | #
4 | # Fill in all the variables below to run the feature generation.
5 | # The reduced database version is used here.
6 | #
7 | # This script assumes that all necessary Python packages are on the current path.
8 | # Usage: generate_msas.sh <ID> <FASTA_DIR> <OUTDIR> (the database directory is read from ${10})
9 |
10 | #Get ID
11 | ID=$1
12 | echo $ID
13 | #Generate input MSAs and templates for AFM
14 | FASTA_DIR=$2
15 | FASTA_PATHS=$FASTA_DIR/$ID'.fasta'
16 | ls $FASTA_PATHS
17 | OUTDIR=$3
18 | #Genetic search
19 | JACKHMMER=./hmmer-3.3.2/src/jackhmmer
20 | HHBLITS=./hh-suite/build/bin/hhblits
21 | HHSEARCH=./hh-suite/build/bin/hhsearch
22 | HMMSEARCH=./hmmer-3.3.2/src/hmmsearch
23 | HMMBUILD=./hmmer-3.3.2/src/hmmbuild
24 | KALIGN=./kalign-3.3.2/src/kalign
25 | #Dbs
26 | DATADIR=${10} #Positional parameters beyond $9 need braces; $10 would expand as ${1}0
27 | UNIREF90=$DATADIR/uniref90/uniref90.fasta
28 | MGNIFY=$DATADIR/mgnify/mgy_clusters_2022_05.fa
29 | #BFD=$DATADIR/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt
30 | SMALL_BFD=$DATADIR/small_bfd/bfd-first_non_consensus_sequences.fasta
31 | UNIREF30=$DATADIR/uniref30/UniRef30_2021_03
32 | UNIPROT=$DATADIR/uniprot/uniprot.fasta
33 | PDB70=$DATADIR/pdb70_from_mmcif_220313
34 | PDBSEQRES=$DATADIR/pdb_seqres/pdb_seqres.txt
35 | MMCIFDIR=$DATADIR/pdb_mmcif/
36 | #Settings
37 | DB_PRESET='reduced_dbs'
38 | MODEL_PRESET='multimer'
39 | MAX_DATE='2034-01-01' #Max template date
40 | OBS_PDBS=./obsolete.txt
41 | AFDIR=./
42 | #Run
43 | python3 $AFDIR/run_alphafold_msa_template_only.py --fasta_paths=$FASTA_PATHS \
44 | --output_dir=$OUTDIR --jackhmmer_binary_path=$JACKHMMER \
45 | --hhblits_binary_path=$HHBLITS --hhsearch_binary_path=$HHSEARCH \
46 | --hmmsearch_binary_path=$HMMSEARCH --hmmbuild_binary_path=$HMMBUILD \
47 | --kalign_binary_path=$KALIGN --uniref90_database_path=$UNIREF90 \
48 | --mgnify_database_path=$MGNIFY --small_bfd_database_path=$SMALL_BFD \
49 | --uniprot_database_path=$UNIPROT \
50 | --pdb_seqres_database_path=$PDBSEQRES \
51 | --template_mmcif_dir=$MMCIFDIR --db_preset=$DB_PRESET --model_preset=$MODEL_PRESET \
52 | --max_template_date=$MAX_DATE --obsolete_pdbs_path=$OBS_PDBS
53 |
--------------------------------------------------------------------------------
/src/obsolete.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/patrickbryant1/AFProfile/a6aaaca40c8598b70c0f8f2768edae7e1847792e/src/obsolete.txt
--------------------------------------------------------------------------------