├── Dockerfile ├── LICENSE ├── README.md ├── alphaflow ├── config.py ├── data │ ├── data_modules.py │ ├── data_pipeline.py │ ├── feature_pipeline.py │ ├── inference.py │ └── input_pipeline.py ├── model │ ├── __init__.py │ ├── alphafold.py │ ├── esmfold.py │ ├── input_stack.py │ ├── layers.py │ ├── tri_self_attn_block.py │ ├── trunk.py │ └── wrapper.py └── utils │ ├── diffusion.py │ ├── logging.py │ ├── loss.py │ ├── misc.py │ ├── parsing.py │ ├── protein.py │ └── tensor_utils.py ├── assets ├── 12l_md_templates.md └── 6uof_A_animation.gif ├── predict.py ├── scripts ├── add_msa_info.py ├── analyze_ensembles.py ├── cluster_chains.py ├── download_atlas.sh ├── mmseqs_query.py ├── mmseqs_search.py ├── mmseqs_search_helper.py ├── prep_atlas.py ├── print_analysis.py └── unpack_mmcif.py ├── splits ├── atlas.csv ├── atlas_test.csv ├── atlas_train.csv ├── atlas_val.csv ├── cameo2022.csv ├── pdb_test.csv └── pdb_test.json └── train.py /Dockerfile: -------------------------------------------------------------------------------- 1 | # Original Copyright 2021 DeepMind Technologies Limited 2 | # Modification Copyright 2022 # Copyright 2021 AlQuraishi Laboratory 3 | # Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | # This container build on the OpenFold container, and installs AlphaFlow. 7 | # At the end, you may wish to download the weights to run ESMFlow, so they are cached in the image. 8 | # It is a large image, about 20GB without weights, 25GB with weights. 9 | # 10 | # OpenFold is quite difficult to get working, as it installs custom torch kernels, so it is used as the base. 11 | # Adapted from https://github.com/aws-solutions-library-samples/aws-batch-arch-for-protein-folding/blob/main/infrastructure/docker/openfold/Dockerfile 12 | # 13 | # To run most recent image (after building), with GPUS, and mounting a directory `outputs` 14 | # docker run --gpus all -v "$(pwd)/outputs:/outputs" -it "$(docker image ls -q | head -n1)" bash 15 | # 16 | # Note that you may need to install nvidia-container-toolkit to run. 
17 | # https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/1.15.0/install-guide.html 18 | # 19 | # Test command to output into mounted directory: 20 | # python predict.py --mode esmfold --input_csv splits/atlas_test.csv --pdb 6o2v_A --weights params/esmflow_md_base_202402.pt --samples 5 --outpdb /outputs 21 | 22 | FROM nvcr.io/nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu18.04 23 | 24 | RUN apt-key del 7fa2af80 25 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub 26 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub 27 | 28 | RUN apt-get update \ 29 | && DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \ 30 | wget \ 31 | libxml2 \ 32 | cuda-minimal-build-11-3 \ 33 | libcusparse-dev-11-3 \ 34 | libcublas-dev-11-3 \ 35 | libcusolver-dev-11-3 \ 36 | git \ 37 | awscli \ 38 | && rm -rf /var/lib/apt/lists/* \ 39 | && apt-get autoremove -y \ 40 | && apt-get clean 41 | 42 | RUN wget -q -P /tmp -O /tmp/miniconda.sh \ 43 | "https://repo.anaconda.com/miniconda/Miniconda3-py39_23.5.2-0-Linux-$(uname -m).sh" \ 44 | && bash /tmp/miniconda.sh -b -p /opt/conda \ 45 | && rm /tmp/miniconda.sh 46 | 47 | ENV PATH /opt/conda/bin:$PATH 48 | 49 | RUN git clone https://github.com/aqlaboratory/openfold.git /opt/openfold \ 50 | && cd /opt/openfold \ 51 | && git checkout 1d878a1203e6d662a209a95f71b90083d5fc079c 52 | 53 | # installing into the base environment since the docker container wont do anything other than run openfold and alphaflow 54 | # RUN conda install -qy conda==4.13.0 \ 55 | RUN conda env update -n base --file /opt/openfold/environment.yml \ 56 | && conda clean --all --force-pkgs-dirs --yes 57 | 58 | RUN wget -q -P /opt/openfold/openfold/resources \ 59 | https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt 60 | 61 | RUN patch -p0 -d /opt/conda/lib/python3.9/site-packages/ < /opt/openfold/lib/openmm.patch 62 | 63 | # Install OpenFold 64 | RUN cd /opt/openfold \ 65 | && pip3 install --upgrade pip --no-cache-dir \ 66 | && python3 setup.py install 67 | 68 | # Install alphaflow 69 | RUN git clone https://github.com/bjing2016/alphaflow.git /opt/alphaflow 70 | 71 | # Install alphaflow packages ~ as defined in README 72 | # torch CUDA version should match your machine 73 | RUN python -m pip install numpy==1.21.2 pandas==1.5.3 && \ 74 | python -m pip install torch==1.13.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html && \ 75 | python -m pip install biopython==1.79 dm-tree==0.1.6 modelcif==0.7 ml-collections==0.1.0 scipy==1.7.3 absl-py einops && \ 76 | python -m pip install pytorch_lightning==2.0.4 fair-esm mdtraj wandb 77 | 78 | WORKDIR /opt/alphaflow 79 | 80 | # Optionally, download weights as part of the image, so the cached image contains them and we don't re-download each time 81 | # ESMFlow and ESM2 82 | # RUN mkdir params && \ 83 | # aws s3 cp s3://alphaflow/params/esmflow_md_base_202402.pt params/esmflow_pdb_md_202402.pt && \ 84 | # mkdir -p /root/.cache/torch/hub/checkpoints && \ 85 | # wget -q -O /root/.cache/torch/hub/checkpoints/esm2_t36_3B_UR50D.pt https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t36_3B_UR50D.pt && \ 86 | # wget -q -O /root/.cache/torch/hub/checkpoints/esm2_t36_3B_UR50D-contact-regression.pt https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t36_3B_UR50D-contact-regression.pt 87 | 
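# Example build/run commands (illustrative, not part of the original recipe; adjust tag and paths):
#   docker build -t alphaflow .
# Instead of baking weights into the image, they can also be mounted at runtime, e.g.:
#   docker run --gpus all -v "$(pwd)/params:/opt/alphaflow/params" -v "$(pwd)/outputs:/outputs" -it alphaflow bash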
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Bowen Jing, Bonnie Berger, Tommi Jaakkola 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AlphaFlow 2 | 3 | AlphaFlow is a modified version of AlphaFold, fine-tuned with a flow matching objective, designed for generative modeling of protein conformational ensembles. In particular, AlphaFlow aims to model: 4 | * Experimental ensembles, i.e, potential conformational states as they would be deposited in the PDB 5 | * Molecular dynamics ensembles at physiological temperatures 6 | 7 | We also provide a similarly fine-tuned version of ESMFold called ESMFlow. Technical details and thorough benchmarking results can be found in our paper, [AlphaFold Meets Flow Matching for Generating Protein Ensembles](https://arxiv.org/abs/2402.04845), by Bowen Jing, Bonnie Berger, Tommi Jaakkola. This repository contains all code, instructions and model weights necessary to run the method. If you have any questions, feel free to open an issue or reach out at bjing@mit.edu. 8 | 9 | **June 2024 update:** We have trained a 12-layer version of AlphaFlow-MD+Templates (base and distilled) which runs 2.5x times faster than the 48-layer version at a [small loss in performance](assets/12l_md_templates.md). We recommend considering this model if reference structures (PDB or AlphaFold) are available and runtime is of high priority. 10 | 11 |

12 | ![Animation of conformations sampled for PDB chain 6uof_A](assets/6uof_A_animation.gif) 13 |

14 | 15 | ## Table of Contents 16 | 1. [Installation](#Installation) 17 | 2. [Model weights](#Model-weights) 18 | 3. [Running inference](#Running-inference) 19 | 4. [Evaluation scripts](#Evaluation-scripts) 20 | 5. [Training](#Training) 21 | 6. [Ensembles](#Ensembles) 22 | 7. [License](#License) 23 | 8. [Citation](#Citation) 24 | 25 | 26 | ## Installation 27 | In an environment with Python 3.9 (for example, `conda create -n alphaflow python=3.9`), run: 28 | ``` 29 | pip install numpy==1.21.2 pandas==1.5.3 30 | pip install torch==1.12.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html 31 | pip install biopython==1.79 dm-tree==0.1.6 modelcif==0.7 ml-collections==0.1.0 scipy==1.7.1 absl-py einops 32 | pip install pytorch_lightning==2.0.4 fair-esm mdtraj==1.9.9 wandb 33 | pip install 'openfold @ git+https://github.com/aqlaboratory/openfold.git@103d037' 34 | ``` 35 | The OpenFold installation requires CUDA 11. If the system has the wrong version, you can install CUDA 11 in the Conda environment: 36 | ``` 37 | conda install nvidia/label/cuda-11.8.0::cuda 38 | conda install nvidia/label/cuda-11.8.0::cuda-cudart-dev 39 | conda install nvidia/label/cuda-11.8.0::libcusparse-dev 40 | conda install nvidia/label/cuda-11.8.0::libcusolver-dev 41 | conda install nvidia/label/cuda-11.8.0::libcublas-dev 42 | ln -s $CONDA_PREFIX/lib/libcudart_static.a $CONDA_PREFIX/lib/libcudart.a 43 | ``` 44 | Then install OpenFold: 45 | ``` 46 | CUDA_HOME=$CONDA_PREFIX pip install 'openfold @ git+https://github.com/aqlaboratory/openfold.git@103d037' 47 | ``` 48 | 49 | ## Model weights 50 | 51 | We provide several versions of AlphaFlow (and similarly named versions of ESMFlow). 52 | 53 | * **AlphaFlow-PDB**—trained on PDB structures to model experimental ensembles from X-ray crystallography or cryo-EM under different conditions 54 | * **AlphaFlow-MD**—trained on all-atom, explicit solvent MD trajectories at 300K 55 | * **AlphaFlow-MD+Templates**—trained to additionally take a PDB structure as input, and models the corresponding MD ensemble at 300K 56 | 57 | For all models, the **distilled** version runs significantly faster at the cost of some loss of accuracy (benchmarked in the paper). 58 | 59 | For AlphaFlow-MD+Templates, the **12l** versions have 12 instead of 48 Evoformer layers and run 2.5x times faster at a [small loss in performance](assets/12l_md_templates.md). 
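As a convenience, here is a minimal Python sketch for fetching one of the checkpoints listed in the tables below; the `params/` destination directory and the chosen checkpoint are illustrative:
```
import os
import urllib.request

# URL taken from the AlphaFlow-PDB (base) row in the table below; any other
# checkpoint URL from the tables can be substituted.
url = "https://storage.googleapis.com/alphaflow/params/alphaflow_pdb_base_202402.pt"
dest = os.path.join("params", os.path.basename(url))

os.makedirs("params", exist_ok=True)
if not os.path.exists(dest):
    urllib.request.urlretrieve(url, dest)
print("checkpoint saved to", dest)
```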
60 | 
61 | ### AlphaFlow models
62 | | Model|Version|Weights|
63 | |:---|:--|:--|
64 | | AlphaFlow-PDB | base | https://storage.googleapis.com/alphaflow/params/alphaflow_pdb_base_202402.pt |
65 | | AlphaFlow-PDB | distilled | https://storage.googleapis.com/alphaflow/params/alphaflow_pdb_distilled_202402.pt |
66 | | AlphaFlow-MD | base | https://storage.googleapis.com/alphaflow/params/alphaflow_md_base_202402.pt |
67 | | AlphaFlow-MD | distilled | https://storage.googleapis.com/alphaflow/params/alphaflow_md_distilled_202402.pt |
68 | | AlphaFlow-MD+Templates | base | https://storage.googleapis.com/alphaflow/params/alphaflow_md_templates_base_202402.pt |
69 | | AlphaFlow-MD+Templates | distilled | https://storage.googleapis.com/alphaflow/params/alphaflow_md_templates_distilled_202402.pt |
70 | | AlphaFlow-MD+Templates | 12l-base | https://storage.googleapis.com/alphaflow/params/alphaflow_12l_md_templates_base_202406.pt |
71 | | AlphaFlow-MD+Templates | 12l-distilled | https://storage.googleapis.com/alphaflow/params/alphaflow_12l_md_templates_distilled_202406.pt |
72 | 
73 | 
74 | ### ESMFlow models
75 | | Model|Version|Weights|
76 | |:---|:--|:--|
77 | | ESMFlow-PDB | base | https://storage.googleapis.com/alphaflow/params/esmflow_pdb_base_202402.pt |
78 | | ESMFlow-PDB | distilled | https://storage.googleapis.com/alphaflow/params/esmflow_pdb_distilled_202402.pt |
79 | | ESMFlow-MD | base | https://storage.googleapis.com/alphaflow/params/esmflow_md_base_202402.pt |
80 | | ESMFlow-MD | distilled | https://storage.googleapis.com/alphaflow/params/esmflow_md_distilled_202402.pt |
81 | | ESMFlow-MD+Templates | base | https://storage.googleapis.com/alphaflow/params/esmflow_md_templates_base_202402.pt |
82 | | ESMFlow-MD+Templates | distilled | https://storage.googleapis.com/alphaflow/params/esmflow_md_templates_distilled_202402.pt |
83 | 
84 | Training checkpoints (from which fine-tuning can be resumed) are available upon request; please reach out if you'd like to collaborate!
85 | 
86 | ## Running inference
87 | 
88 | ### Preparing input files
89 | 
90 | 1. Prepare an input CSV with a `name` and `seqres` entry for each row. See `splits/atlas_test.csv` for examples.
91 | 2. If running an **AlphaFlow** model, prepare an **MSA directory** and place the alignments in `.a3m` format at the following paths: `{alignment_dir}/{name}/a3m/{name}.a3m` (this layout is also illustrated in the sketch after the basic command below). If you don't have the MSAs, there are two ways to generate them:
92 |     1. Query the ColabFold server with `python -m scripts.mmseqs_query --split [PATH] --outdir [DIR]`.
93 |     2. Download UniRef30 and ColabFoldDB according to https://github.com/sokrypton/ColabFold/blob/main/setup_databases.sh and run `python -m scripts.mmseqs_search_helper --split [PATH] --db_dir [DIR] --outdir [DIR]`.
94 | 3. If running an **MD+Templates** model, place the template PDB files into a templates directory with filenames matching the names in the input CSV. The PDB files should include only a single chain with no residue gaps.
95 | 
96 | ### Running the model
97 | 
98 | The basic command for running inference with **AlphaFlow** is:
99 | ```
100 | python predict.py --mode alphafold --input_csv [PATH] --msa_dir [DIR] --weights [PATH] --samples [N] --outpdb [DIR]
101 | ```
102 | If running the **PDB model**, we recommend appending `--self_cond --resample` for improved performance.
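For reference, the following minimal Python sketch sets up the input layout described under "Preparing input files": it writes an input CSV and a single-sequence `.a3m` stub at the expected path. The chain name and sequence are placeholders, and real MSAs should come from the ColabFold server or a local MMseqs2 search as described above.
```
import os
import pandas as pd

msa_dir = "alignments"   # later passed to predict.py as --msa_dir
rows = [{"name": "6o2v_A", "seqres": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"}]  # placeholder sequence

pd.DataFrame(rows).to_csv("input.csv", index=False)

for row in rows:
    a3m = os.path.join(msa_dir, row["name"], "a3m", row["name"] + ".a3m")
    os.makedirs(os.path.dirname(a3m), exist_ok=True)
    # Single-sequence A3M stub; replace with a real alignment for meaningful samples.
    with open(a3m, "w") as f:
        f.write(">" + row["name"] + "\n" + row["seqres"] + "\n")
```
With these files in place, an example invocation would be `python predict.py --mode alphafold --input_csv input.csv --msa_dir alignments --weights params/alphaflow_pdb_base_202402.pt --samples 5 --outpdb outputs`.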
103 | 
104 | The basic command for running inference with **ESMFlow** is:
105 | ```
106 | python predict.py --mode esmfold --input_csv [PATH] --weights [PATH] --samples [N] --outpdb [DIR]
107 | ```
108 | Additional command line arguments for either model:
109 | * Use the `--pdb_id` argument to select (one or more) rows in the CSV. If no argument is specified, inference is run on all rows.
110 | * If running the **MD model with templates**, append `--templates_dir [DIR]`.
111 | * If running any **distilled** model, append the arguments `--noisy_first --no_diffusion`.
112 | * To truncate the inference process for increased precision and reduced diversity, append (for example) `--tmax 0.2 --steps 2`. The default inference settings correspond to `--tmax 1.0 --steps 10`. See Appendix B.1 in the paper for more details.
113 | 
114 | ## Evaluation scripts
115 | 
116 | Our ensemble evaluations may be reproduced via the following steps:
117 | 1. Download the ATLAS dataset by running `bash scripts/download_atlas.sh` from the desired root directory.
118 | 2. Prepare the ensemble directory with a PDB file for each ATLAS target, each with 250 structures (see zipped AlphaFlow ensembles below for examples). **Some results are not directly comparable for evaluations with a different number of structures.**
119 | 3. Run `python -m scripts.analyze_ensembles --atlas_dir [DIR] --pdb_dir [DIR] --num_workers [N]`. This will produce an analysis file named `out.pkl` in the `pdb_dir`.
120 | 4. Run `python -m scripts.print_analysis [PATH] [PATH] ...` with an arbitrary number of paths to `out.pkl` files. A formatted comparison table will be printed.
121 | 
122 | ## Training
123 | 
124 | ### Downloading datasets
125 | 
126 | To download and preprocess the PDB,
127 | 1. Run `aws s3 sync --no-sign-request s3://pdbsnapshots/20230102/pub/pdb/data/structures/divided/mmCIF pdb_mmcif` from the desired directory.
128 | 2. Run `find pdb_mmcif -name '*.gz' | xargs gunzip` to extract the mmCIF files.
129 | 3. From the repository root, run `python -m scripts.unpack_mmcif --mmcif_dir [DIR] --outdir [DIR] --num_workers [N]`. This will preprocess all chains into NPZ files and create a `pdb_mmcif.csv` index.
130 | 4. Download OpenProteinSet with `aws s3 sync --no-sign-request s3://openfold/ openfold` from the desired directory.
131 | 5. Run `python -m scripts.add_msa_info --openfold_dir [DIR]` to produce a `pdb_mmcif_msa.csv` index with OpenProteinSet MSA lookup.
132 | 6. Run `python -m scripts.cluster_chains` to produce a `pdb_clusters` file at 40% sequence similarity (MMseqs2 installation required).
133 | 7. Create MSAs for the PDB validation split (`splits/cameo2022.csv`) according to the instructions in the previous section.
134 | 
135 | To download and preprocess the ATLAS MD trajectory dataset,
136 | 1. Run `bash scripts/download_atlas.sh` from the desired directory.
137 | 2. From the repository root, run `python -m scripts.prep_atlas --atlas_dir [DIR] --outdir [DIR] --num_workers [N]`. This will preprocess the ATLAS trajectories into NPZ files.
138 | 3. Create MSAs for all entries in `splits/atlas.csv` according to the instructions in the previous section (a quick check of the expected MSA layout is sketched below).
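The check referenced in step 3 above can be done with a short script. This sketch assumes `splits/atlas.csv` has the same `name` column as the other split files and that MSAs follow the `{msa_dir}/{name}/a3m/{name}.a3m` convention; the directory name is illustrative.
```
import os
import pandas as pd

msa_dir = "alignments"   # illustrative; use the directory passed as --train_msa_dir
names = pd.read_csv("splits/atlas.csv")["name"]

missing = [n for n in names
           if not os.path.exists(os.path.join(msa_dir, n, "a3m", n + ".a3m"))]
print(len(missing), "of", len(names), "ATLAS entries are missing MSAs")
```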
139 | 140 | ### Running training 141 | 142 | Before running training, download the pretrained AlphaFold and ESMFold weights into the repository root via 143 | ``` 144 | wget https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar 145 | tar -xvf alphafold_params_2022-12-06.tar params_model_1.npz 146 | wget https://dl.fbaipublicfiles.com/fair-esm/models/esmfold_3B_v1.pt 147 | ``` 148 | 149 | The basic command for training AlphaFlow is 150 | ``` 151 | python train.py --lr 5e-4 --noise_prob 0.8 --accumulate_grad 8 --train_epoch_len 80000 --train_cutoff 2018-05-01 --filter_chains \ 152 | --train_data_dir [DIR] \ 153 | --train_msa_dir [DIR] \ 154 | --mmcif_dir [DIR] \ 155 | --val_msa_dir [DIR] \ 156 | --run_name [NAME] [--wandb] 157 | ``` 158 | where the PDB NPZ directory, the OpenProteinSet directory, the PDB mmCIF directory, and the validation MSA directory are specified. This training run produces the AlphaFlow-PDB base version. All other models are built off this checkpoint. 159 | 160 | To continue training on ATLAS, run 161 | ``` 162 | python train.py --normal_validate --sample_train_confs --sample_val_confs --num_val_confs 100 --pdb_chains splits/atlas_train.csv --val_csv splits/atlas_val.csv --self_cond_prob 0.0 --noise_prob 0.9 --val_freq 10 --ckpt_freq 10 \ 163 | --train_data_dir [DIR] \ 164 | --train_msa_dir [DIR] \ 165 | --ckpt [PATH] \ 166 | --run_name [NAME] [--wandb] 167 | ``` 168 | where the ATLAS MSA and NPZ directories and AlphaFlow-PDB checkpoints are specified. 169 | 170 | To instead train on ATLAS with templates, run with the additional arguments `--first_as_template --extra_input --lr 1e-4 --restore_weights_only --extra_input_prob 1.0`. 171 | 172 | **Distillation**: to distill a model, append `--distillation` and supply the `--ckpt [PATH]` of the model to be distilled. For PDB training, we remove `--accumulate_grad 8` and recommend distilling with a shorter `--train_epoch_len 16000`. Note that `--self_cond_prob` and `--noise_prob` will be ignored and can be omitted. 173 | 174 | **ESMFlow**: run the same commands with `--mode esmfold` and `--train_cutoff 2020-05-01`. 175 | 176 | ## Ensembles 177 | 178 | We provide the ensembles sampled from the model which were used for the analyses and results reported in the paper. 
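To inspect one of these ensembles, the sketch below downloads the AlphaFlow-PDB base samples from the table that follows and counts the structures per target with `mdtraj` (already a dependency). It assumes the zip unpacks into per-target multi-model PDB files, with each model loaded as one frame.
```
import io
import os
import urllib.request
import zipfile

import mdtraj

url = "https://storage.googleapis.com/alphaflow/samples/alphaflow_pdb_base_202402.zip"
outdir = "ensembles/alphaflow_pdb_base"
os.makedirs(outdir, exist_ok=True)

with urllib.request.urlopen(url) as resp:
    zipfile.ZipFile(io.BytesIO(resp.read())).extractall(outdir)

# Each multi-model PDB loads as a trajectory; n_frames is the ensemble size.
for root, _, files in os.walk(outdir):
    for fname in sorted(files):
        if fname.endswith(".pdb"):
            traj = mdtraj.load(os.path.join(root, fname))
            print(fname, traj.n_frames, "structures")
```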
179 | 180 | ### AlphaFlow ensembles 181 | | Model|Version|Samples| 182 | |:---|:--|:--| 183 | | AlphaFlow-PDB | base | https://storage.googleapis.com/alphaflow/samples/alphaflow_pdb_base_202402.zip | 184 | | AlphaFlow-PDB | distilled | https://storage.googleapis.com/alphaflow/samples/alphaflow_pdb_distilled_202402.zip | 185 | | AlphaFlow-MD | base | https://storage.googleapis.com/alphaflow/samples/alphaflow_md_base_202402.zip | 186 | | AlphaFlow-MD | distilled | https://storage.googleapis.com/alphaflow/samples/alphaflow_md_distilled_202402.zip | 187 | | AlphaFlow-MD+Templates | base | https://storage.googleapis.com/alphaflow/samples/alphaflow_md_templates_base_202402.zip | 188 | | AlphaFlow-MD+Templates | distilled | https://storage.googleapis.com/alphaflow/samples/alphaflow_md_templates_distilled_202402.zip | 189 | | AlphaFlow-MD+Templates | 12l-base | https://storage.googleapis.com/alphaflow/samples/alphaflow_12l_md_templates_base_202406.zip | 190 | | AlphaFlow-MD+Templates | 12l-distilled | https://storage.googleapis.com/alphaflow/samples/alphaflow_12l_md_templates_distilled_202406.zip | 191 | 192 | 193 | ### ESMFlow ensembles 194 | | Model|Version|Samples| 195 | |:---|:--|:--| 196 | | ESMFlow-PDB | base | https://storage.googleapis.com/alphaflow/samples/esmflow_pdb_base_202402.zip | 197 | | ESMFlow-PDB | distilled | https://storage.googleapis.com/alphaflow/samples/esmflow_pdb_distilled_202402.zip | 198 | | ESMFlow-MD | base | https://storage.googleapis.com/alphaflow/samples/esmflow_md_base_202402.zip | 199 | | ESMFlow-MD | distilled | https://storage.googleapis.com/alphaflow/samples/esmflow_md_distilled_202402.zip | 200 | | ESMFlow-MD+Templates | base | https://storage.googleapis.com/alphaflow/samples/esmflow_md_templates_base_202402.zip | 201 | | ESMFlow-MD+Templates | distilled | https://storage.googleapis.com/alphaflow/samples/esmflow_md_templates_distilled_202402.zip | 202 | 203 | ## License 204 | MIT. Other licenses may apply to third-party source code noted in file headers. 205 | 206 | ## Citation 207 | ``` 208 | @inproceedings{jing2024alphafold, 209 | title={AlphaFold Meets Flow Matching for Generating Protein Ensembles}, 210 | author={Jing, Bowen and Berger, Bonnie and Jaakkola, Tommi}, 211 | year={2024}, 212 | booktitle={Forty-first International Conference on Machine Learning} 213 | } 214 | ``` 215 | -------------------------------------------------------------------------------- /alphaflow/data/feature_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import copy 17 | from typing import Mapping, Tuple, List, Dict, Sequence 18 | 19 | import ml_collections 20 | import numpy as np 21 | import torch 22 | 23 | from . 
import input_pipeline 24 | 25 | 26 | FeatureDict = Mapping[str, np.ndarray] 27 | TensorDict = Dict[str, torch.Tensor] 28 | 29 | 30 | def np_to_tensor_dict( 31 | np_example: Mapping[str, np.ndarray], 32 | features: Sequence[str], 33 | ) -> TensorDict: 34 | """Creates dict of tensors from a dict of NumPy arrays. 35 | 36 | Args: 37 | np_example: A dict of NumPy feature arrays. 38 | features: A list of strings of feature names to be returned in the dataset. 39 | 40 | Returns: 41 | A dictionary of features mapping feature names to features. Only the given 42 | features are returned, all other ones are filtered out. 43 | """ 44 | tensor_dict = { 45 | k: torch.tensor(v) for k, v in np_example.items() if k in features 46 | } 47 | 48 | return tensor_dict 49 | 50 | 51 | def make_data_config( 52 | config: ml_collections.ConfigDict, 53 | mode: str, 54 | num_res: int, 55 | ) -> Tuple[ml_collections.ConfigDict, List[str]]: 56 | cfg = copy.deepcopy(config) 57 | mode_cfg = cfg[mode] 58 | with cfg.unlocked(): 59 | if mode_cfg.crop_size is None: 60 | mode_cfg.crop_size = num_res 61 | 62 | feature_names = cfg.common.unsupervised_features 63 | 64 | if cfg.common.use_templates: 65 | feature_names += cfg.common.template_features 66 | 67 | if cfg[mode].supervised: 68 | feature_names += cfg.supervised.supervised_features 69 | 70 | return cfg, feature_names 71 | 72 | 73 | def np_example_to_features( 74 | np_example: FeatureDict, 75 | config: ml_collections.ConfigDict, 76 | mode: str, 77 | ): 78 | np_example = dict(np_example) 79 | num_res = int(np_example["seq_length"][0]) 80 | cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) 81 | 82 | if "deletion_matrix_int" in np_example: 83 | np_example["deletion_matrix"] = np_example.pop( 84 | "deletion_matrix_int" 85 | ).astype(np.float32) 86 | 87 | tensor_dict = np_to_tensor_dict( 88 | np_example=np_example, features=feature_names 89 | ) 90 | with torch.no_grad(): 91 | features = input_pipeline.process_tensors_from_config( 92 | tensor_dict, 93 | cfg.common, 94 | cfg[mode], 95 | ) 96 | 97 | if mode == "train": 98 | p = torch.rand(1).item() 99 | use_clamped_fape_value = float(p < cfg.supervised.clamp_prob) 100 | features["use_clamped_fape"] = torch.full( 101 | size=[cfg.common.max_recycling_iters + 1], 102 | fill_value=use_clamped_fape_value, 103 | dtype=torch.float32, 104 | ) 105 | else: 106 | features["use_clamped_fape"] = torch.full( 107 | size=[cfg.common.max_recycling_iters + 1], 108 | fill_value=0.0, 109 | dtype=torch.float32, 110 | ) 111 | 112 | return {k: v for k, v in features.items()} 113 | 114 | 115 | class FeaturePipeline: 116 | def __init__( 117 | self, 118 | config: ml_collections.ConfigDict, 119 | ): 120 | self.config = config 121 | 122 | def process_features( 123 | self, 124 | raw_features: FeatureDict, 125 | mode: str = "train", 126 | ) -> FeatureDict: 127 | return np_example_to_features( 128 | np_example=raw_features, 129 | config=self.config, 130 | mode=mode, 131 | ) 132 | -------------------------------------------------------------------------------- /alphaflow/data/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import numpy as np 4 | from openfold.np import residue_constants 5 | from .data_pipeline import DataPipeline 6 | from .feature_pipeline import FeaturePipeline 7 | from openfold.data.data_transforms import make_atom14_masks 8 | import alphaflow.utils.protein as protein 9 | 10 | def seq_to_tensor(seq): 11 | unk_idx = 
residue_constants.restype_order_with_x["X"] 12 | encoded = torch.tensor( 13 | [residue_constants.restype_order_with_x.get(aa, unk_idx) for aa in seq] 14 | ) 15 | return encoded 16 | 17 | class AlphaFoldCSVDataset: 18 | def __init__(self, config, path, mmcif_dir=None, msa_dir=None, templates_dir=None): 19 | super().__init__() 20 | self.pdb_chains = pd.read_csv(path, index_col='name') 21 | self.msa_dir = msa_dir 22 | self.mmcif_dir = mmcif_dir 23 | self.data_pipeline = DataPipeline(template_featurizer=None) 24 | self.feature_pipeline = FeaturePipeline(config) 25 | self.templates_dir = templates_dir 26 | 27 | def __len__(self): 28 | return len(self.pdb_chains) 29 | 30 | def __getitem__(self, idx): 31 | 32 | item = self.pdb_chains.iloc[idx] 33 | 34 | mmcif_feats = self.data_pipeline.process_str(item.seqres, item.name) 35 | if self.templates_dir: 36 | path = f"{self.templates_dir}/{item.name}.pdb" 37 | with open(path) as f: 38 | prot = protein.from_pdb_string(f.read()) 39 | extra_all_atom_positions = prot.atom_positions.astype(np.float32) 40 | 41 | 42 | try: msa_id = item.msa_id 43 | except: msa_id = item.name 44 | msa_features = self.data_pipeline._process_msa_feats(f'{self.msa_dir}/{msa_id}', item.seqres, alignment_index=None) 45 | data = {**mmcif_feats, **msa_features} 46 | 47 | feats = self.feature_pipeline.process_features(data, mode='predict') 48 | if self.templates_dir: 49 | feats['extra_all_atom_positions'] = torch.from_numpy(extra_all_atom_positions) 50 | feats['pseudo_beta_mask'] = torch.ones(len(item.seqres)) 51 | feats['name'] = item.name 52 | feats['seqres'] = item.seqres 53 | make_atom14_masks(feats) 54 | 55 | if self.mmcif_dir is not None: 56 | pdb_id, chain = item.name.split('_') 57 | with open(f"{self.mmcif_dir}/{pdb_id[1:3]}/{pdb_id}.cif") as f: 58 | feats['ref_prot'] = protein.from_mmcif_string(f.read(), chain, name=item.name) 59 | 60 | return feats 61 | 62 | class CSVDataset: 63 | def __init__(self, config, path, mmcif_dir=None, msa_dir=None, templates_dir=None): 64 | super().__init__() 65 | self.pdb_chains = pd.read_csv(path, index_col='name') 66 | self.templates_dir = templates_dir 67 | 68 | def __len__(self): 69 | return self.pdb_chains.shape[0] 70 | 71 | def __getitem__(self, idx): 72 | row = self.pdb_chains.iloc[idx] 73 | batch = { 74 | 'name': row.name, 75 | 'seqres': row.seqres, 76 | 'aatype': seq_to_tensor(row.seqres), 77 | 'residue_index': torch.arange(len(row.seqres)), 78 | 'pseudo_beta_mask': torch.ones(len(row.seqres)), 79 | 'seq_mask': torch.ones(len(row.seqres)) 80 | } 81 | make_atom14_masks(batch) 82 | 83 | if self.templates_dir: 84 | path = f"{self.templates_dir}/{row.name}.pdb" 85 | with open(path) as f: 86 | prot = protein.from_pdb_string(f.read()) 87 | extra_all_atom_positions = prot.atom_positions.astype(np.float32) 88 | batch['extra_all_atom_positions'] = torch.from_numpy(extra_all_atom_positions) 89 | 90 | return batch 91 | -------------------------------------------------------------------------------- /alphaflow/data/input_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import itertools 18 | 19 | from openfold.data import data_transforms 20 | NUM_RES = "num residues placeholder" 21 | NUM_MSA_SEQ = "msa placeholder" 22 | NUM_EXTRA_SEQ = "extra msa placeholder" 23 | NUM_TEMPLATES = "num templates placeholder" 24 | 25 | @data_transforms.curry1 26 | def random_crop_to_size( 27 | protein, 28 | crop_size, 29 | max_templates, 30 | shape_schema, 31 | subsample_templates=False, 32 | seed=None, 33 | ): 34 | """Crop randomly to `crop_size`, or keep as is if shorter than that.""" 35 | # We want each ensemble to be cropped the same way 36 | 37 | g = torch.Generator(device=protein["seq_length"].device) 38 | if seed is not None: 39 | g.manual_seed(seed) 40 | 41 | seq_length = protein["seq_length"] 42 | 43 | if "template_mask" in protein: 44 | num_templates = protein["template_mask"].shape[-1] 45 | else: 46 | num_templates = 0 47 | 48 | # No need to subsample templates if there aren't any 49 | subsample_templates = subsample_templates and num_templates 50 | 51 | num_res_crop_size = min(int(seq_length), crop_size) 52 | 53 | def _randint(lower, upper): 54 | return int(torch.randint( 55 | lower, 56 | upper + 1, 57 | (1,), 58 | device=protein["seq_length"].device, 59 | generator=g, 60 | )[0]) 61 | 62 | if subsample_templates: 63 | templates_crop_start = _randint(0, num_templates) 64 | templates_select_indices = torch.randperm( 65 | num_templates, device=protein["seq_length"].device, generator=g 66 | ) 67 | else: 68 | templates_crop_start = 0 69 | 70 | num_templates_crop_size = min( 71 | num_templates - templates_crop_start, max_templates 72 | ) 73 | 74 | n = seq_length - num_res_crop_size 75 | if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.: 76 | right_anchor = n 77 | else: 78 | x = _randint(0, n) 79 | right_anchor = n - x 80 | num_res_crop_start = _randint(0, right_anchor) 81 | 82 | for k, v in protein.items(): 83 | if k not in shape_schema or ( 84 | "template" not in k and NUM_RES not in shape_schema[k] 85 | ): 86 | continue 87 | 88 | # randomly permute the templates before cropping them. 
89 | if k.startswith("template") and subsample_templates: 90 | v = v[templates_select_indices] 91 | 92 | slices = [] 93 | for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)): 94 | is_num_res = dim_size == NUM_RES 95 | if i == 0 and k.startswith("template"): 96 | crop_size = num_templates_crop_size 97 | crop_start = templates_crop_start 98 | else: 99 | crop_start = num_res_crop_start if is_num_res else 0 100 | crop_size = num_res_crop_size if is_num_res else dim 101 | slices.append(slice(crop_start, crop_start + crop_size)) 102 | protein[k] = v[slices] #### MODIFIED 103 | 104 | protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size) 105 | 106 | return protein 107 | 108 | 109 | 110 | @data_transforms.curry1 111 | def make_fixed_size( 112 | protein, 113 | shape_schema, 114 | num_res=0, 115 | num_templates=0, 116 | ): 117 | 118 | """Guess at the MSA and sequence dimension to make fixed size.""" 119 | pad_size_map = { 120 | NUM_RES: num_res, 121 | NUM_TEMPLATES: num_templates, 122 | } 123 | 124 | for k, v in protein.items(): 125 | # Don't transfer this to the accelerator. 126 | if k == "extra_cluster_assignment": 127 | continue 128 | shape = list(v.shape) 129 | schema = shape_schema[k] 130 | msg = "Rank mismatch between shape and shape schema for" 131 | assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}" 132 | pad_size = [ 133 | pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema) 134 | ] 135 | 136 | padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)] 137 | padding.reverse() 138 | padding = list(itertools.chain(*padding)) 139 | if padding: 140 | protein[k] = torch.nn.functional.pad(v, padding) 141 | protein[k] = torch.reshape(protein[k], pad_size) 142 | 143 | return protein 144 | 145 | 146 | def nonensembled_transform_fns(common_cfg, mode_cfg): 147 | """Input pipeline data transformers that are not ensembled.""" 148 | transforms = [ 149 | data_transforms.cast_to_64bit_ints, 150 | data_transforms.correct_msa_restypes, 151 | data_transforms.squeeze_features, 152 | data_transforms.randomly_replace_msa_with_unknown(0.0), 153 | data_transforms.make_seq_mask, 154 | data_transforms.make_msa_mask, 155 | data_transforms.make_hhblits_profile, 156 | ] 157 | if common_cfg.use_templates: 158 | transforms.extend( 159 | [ 160 | data_transforms.fix_templates_aatype, 161 | data_transforms.make_template_mask, 162 | data_transforms.make_pseudo_beta("template_"), 163 | ] 164 | ) 165 | if common_cfg.use_template_torsion_angles: 166 | transforms.extend( 167 | [ 168 | data_transforms.atom37_to_torsion_angles("template_"), 169 | ] 170 | ) 171 | 172 | transforms.extend( 173 | [ 174 | data_transforms.make_atom14_masks, 175 | ] 176 | ) 177 | 178 | if mode_cfg.supervised: 179 | transforms.extend( 180 | [ 181 | data_transforms.make_atom14_positions, 182 | data_transforms.atom37_to_frames, 183 | data_transforms.atom37_to_torsion_angles(""), 184 | data_transforms.make_pseudo_beta(""), 185 | data_transforms.get_backbone_frames, 186 | data_transforms.get_chi_angles, 187 | ] 188 | ) 189 | 190 | ############### 191 | 192 | if "max_distillation_msa_clusters" in mode_cfg: 193 | transforms.append( 194 | data_transforms.sample_msa_distillation( 195 | mode_cfg.max_distillation_msa_clusters 196 | ) 197 | ) 198 | 199 | if common_cfg.reduce_msa_clusters_by_max_templates: 200 | pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates 201 | else: 202 | pad_msa_clusters = mode_cfg.max_msa_clusters 203 | 204 | max_msa_clusters = pad_msa_clusters 205 | 
max_extra_msa = mode_cfg.max_extra_msa 206 | 207 | msa_seed = None 208 | if(not common_cfg.resample_msa_in_recycling): 209 | msa_seed = ensemble_seed 210 | 211 | transforms.append( 212 | data_transforms.sample_msa( 213 | max_msa_clusters, 214 | keep_extra=True, 215 | seed=msa_seed, 216 | ) 217 | ) 218 | 219 | if "masked_msa" in common_cfg: 220 | # Masked MSA should come *before* MSA clustering so that 221 | # the clustering and full MSA profile do not leak information about 222 | # the masked locations and secret corrupted locations. 223 | transforms.append( 224 | data_transforms.make_masked_msa( 225 | common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction 226 | ) 227 | ) 228 | 229 | if common_cfg.msa_cluster_features: 230 | transforms.append(data_transforms.nearest_neighbor_clusters()) 231 | transforms.append(data_transforms.summarize_clusters()) 232 | 233 | # Crop after creating the cluster profiles. 234 | if max_extra_msa: 235 | transforms.append(data_transforms.crop_extra_msa(max_extra_msa)) 236 | else: 237 | transforms.append(data_transforms.delete_extra_msa) 238 | 239 | transforms.append(data_transforms.make_msa_feat()) 240 | 241 | ################## 242 | crop_feats = dict(common_cfg.feat) 243 | # return transforms 244 | 245 | if mode_cfg.fixed_size: 246 | transforms.append(data_transforms.select_feat(list(crop_feats))) 247 | 248 | transforms.append( 249 | data_transforms.random_crop_to_size( 250 | mode_cfg.crop_size, 251 | mode_cfg.max_templates, 252 | crop_feats, 253 | mode_cfg.subsample_templates, 254 | ) 255 | # random_crop_to_size( 256 | # mode_cfg.crop_size, 257 | # mode_cfg.max_templates, 258 | # crop_feats, 259 | # mode_cfg.subsample_templates, 260 | # ) 261 | ) 262 | transforms.append( 263 | data_transforms.make_fixed_size( 264 | crop_feats, 265 | pad_msa_clusters, 266 | mode_cfg.max_extra_msa, 267 | mode_cfg.crop_size, 268 | mode_cfg.max_templates, 269 | ) 270 | # make_fixed_size( 271 | # crop_feats, 272 | # mode_cfg.crop_size, 273 | # mode_cfg.max_templates, 274 | # ) 275 | ) 276 | ''' 277 | else: 278 | transforms.append( 279 | data_transforms.crop_templates(mode_cfg.max_templates) 280 | ) 281 | 282 | ''' 283 | return transforms 284 | 285 | 286 | def process_tensors_from_config(tensors, common_cfg, mode_cfg): 287 | """Based on the config, apply filters and transformations to the data.""" 288 | 289 | no_templates = True 290 | if("template_aatype" in tensors): 291 | no_templates = tensors["template_aatype"].shape[0] == 0 292 | 293 | 294 | nonensembled = nonensembled_transform_fns( 295 | common_cfg, 296 | mode_cfg, 297 | ) 298 | 299 | tensors = compose(nonensembled)(tensors) 300 | 301 | return tensors 302 | 303 | 304 | 305 | @data_transforms.curry1 306 | def compose(x, fs): 307 | for f in fs: 308 | x = f(x) 309 | return x 310 | 311 | 312 | def map_fn(fun, x): 313 | ensembles = [fun(elem) for elem in x] 314 | features = ensembles[0].keys() 315 | ensembled_dict = {} 316 | for feat in features: 317 | ensembled_dict[feat] = torch.stack( 318 | [dict_i[feat] for dict_i in ensembles], dim=-1 319 | ) 320 | return ensembled_dict 321 | -------------------------------------------------------------------------------- /alphaflow/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bjing2016/alphaflow/02dc03763a016949326c2c741e6e33094f9250fd/alphaflow/model/__init__.py -------------------------------------------------------------------------------- /alphaflow/model/alphafold.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | from openfold.model.embedders import ( 20 | InputEmbedder, 21 | RecyclingEmbedder, 22 | ExtraMSAEmbedder, 23 | ) 24 | from openfold.model.evoformer import EvoformerStack, ExtraMSAStack 25 | from openfold.model.heads import AuxiliaryHeads 26 | from openfold.model.structure_module import StructureModule 27 | 28 | import openfold.np.residue_constants as residue_constants 29 | from openfold.utils.feats import ( 30 | pseudo_beta_fn, 31 | build_extra_msa_feat, 32 | atom14_to_atom37, 33 | ) 34 | from openfold.utils.tensor_utils import add 35 | from .input_stack import InputPairStack 36 | from .layers import GaussianFourierProjection 37 | from openfold.model.primitives import Linear 38 | 39 | 40 | class AlphaFold(nn.Module): 41 | """ 42 | Alphafold 2. 43 | 44 | Implements Algorithm 2 (but with training). 45 | """ 46 | 47 | def __init__(self, config, extra_input=False): 48 | """ 49 | Args: 50 | config: 51 | A dict-like config object (like the one in config.py) 52 | """ 53 | super(AlphaFold, self).__init__() 54 | 55 | self.globals = config.globals 56 | self.config = config.model 57 | self.template_config = self.config.template 58 | self.extra_msa_config = self.config.extra_msa 59 | 60 | # Main trunk + structure module 61 | self.input_embedder = InputEmbedder( 62 | **self.config["input_embedder"], 63 | ) 64 | self.recycling_embedder = RecyclingEmbedder( 65 | **self.config["recycling_embedder"], 66 | ) 67 | 68 | 69 | if(self.extra_msa_config.enabled): 70 | self.extra_msa_embedder = ExtraMSAEmbedder( 71 | **self.extra_msa_config["extra_msa_embedder"], 72 | ) 73 | self.extra_msa_stack = ExtraMSAStack( 74 | **self.extra_msa_config["extra_msa_stack"], 75 | ) 76 | 77 | self.evoformer = EvoformerStack( 78 | **self.config["evoformer_stack"], 79 | ) 80 | self.structure_module = StructureModule( 81 | **self.config["structure_module"], 82 | ) 83 | self.aux_heads = AuxiliaryHeads( 84 | self.config["heads"], 85 | ) 86 | 87 | ################ 88 | self.input_pair_embedding = Linear( 89 | self.config.input_pair_embedder.no_bins, 90 | self.config.evoformer_stack.c_z, 91 | init="final", 92 | ) 93 | self.input_time_projection = GaussianFourierProjection( 94 | embedding_size=self.config.input_pair_embedder.time_emb_dim 95 | ) 96 | self.input_time_embedding = Linear( 97 | self.config.input_pair_embedder.time_emb_dim, 98 | self.config.evoformer_stack.c_z, 99 | init="final", 100 | ) 101 | self.input_pair_stack = InputPairStack(**self.config.input_pair_stack) 102 | self.extra_input = extra_input 103 | if extra_input: 104 | self.extra_input_pair_embedding = Linear( 105 | self.config.input_pair_embedder.no_bins, 106 | self.config.evoformer_stack.c_z, 107 | init="final", 108 | ) 109 | self.extra_input_pair_stack 
= InputPairStack(**self.config.input_pair_stack) 110 | 111 | ################ 112 | 113 | def _get_input_pair_embeddings(self, dists, mask): 114 | 115 | mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) 116 | 117 | lower = torch.linspace( 118 | self.config.input_pair_embedder.min_bin, 119 | self.config.input_pair_embedder.max_bin, 120 | self.config.input_pair_embedder.no_bins, 121 | device=dists.device) 122 | dists = dists.unsqueeze(-1) 123 | inf = self.config.input_pair_embedder.inf 124 | upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) 125 | dgram = ((dists > lower) * (dists < upper)).type(dists.dtype) 126 | 127 | inp_z = self.input_pair_embedding(dgram * mask.unsqueeze(-1)) 128 | inp_z = self.input_pair_stack(inp_z, mask, chunk_size=None) 129 | return inp_z 130 | 131 | def _get_extra_input_pair_embeddings(self, dists, mask): 132 | 133 | mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) 134 | 135 | lower = torch.linspace( 136 | self.config.input_pair_embedder.min_bin, 137 | self.config.input_pair_embedder.max_bin, 138 | self.config.input_pair_embedder.no_bins, 139 | device=dists.device) 140 | dists = dists.unsqueeze(-1) 141 | inf = self.config.input_pair_embedder.inf 142 | upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) 143 | dgram = ((dists > lower) * (dists < upper)).type(dists.dtype) 144 | 145 | inp_z = self.extra_input_pair_embedding(dgram * mask.unsqueeze(-1)) 146 | inp_z = self.extra_input_pair_stack(inp_z, mask, chunk_size=None) 147 | return inp_z 148 | 149 | 150 | def forward(self, batch, prev_outputs=None): 151 | 152 | feats = batch 153 | 154 | # Primary output dictionary 155 | outputs = {} 156 | 157 | # # This needs to be done manually for DeepSpeed's sake 158 | # dtype = next(self.parameters()).dtype 159 | # for k in feats: 160 | # if(feats[k].dtype == torch.float32): 161 | # feats[k] = feats[k].to(dtype=dtype) 162 | 163 | # Grab some data about the input 164 | batch_dims = feats["target_feat"].shape[:-2] 165 | no_batch_dims = len(batch_dims) 166 | n = feats["target_feat"].shape[-2] 167 | n_seq = feats["msa_feat"].shape[-3] 168 | device = feats["target_feat"].device 169 | 170 | # Controls whether the model uses in-place operations throughout 171 | # The dual condition accounts for activation checkpoints 172 | inplace_safe = not (self.training or torch.is_grad_enabled()) 173 | 174 | # Prep some features 175 | seq_mask = feats["seq_mask"] 176 | pair_mask = seq_mask[..., None] * seq_mask[..., None, :] 177 | msa_mask = feats["msa_mask"] 178 | 179 | ## Initialize the MSA and pair representations 180 | 181 | # m: [*, S_c, N, C_m] 182 | # z: [*, N, N, C_z] 183 | m, z = self.input_embedder( 184 | feats["target_feat"], 185 | feats["residue_index"], 186 | feats["msa_feat"], 187 | inplace_safe=inplace_safe, 188 | ) 189 | if prev_outputs is None: 190 | m_1_prev = m.new_zeros((*batch_dims, n, self.config.input_embedder.c_m), requires_grad=False) 191 | # [*, N, N, C_z] 192 | z_prev = z.new_zeros((*batch_dims, n, n, self.config.input_embedder.c_z), requires_grad=False) 193 | # [*, N, 3] 194 | x_prev = z.new_zeros((*batch_dims, n, residue_constants.atom_type_num, 3), requires_grad=False) 195 | 196 | else: 197 | m_1_prev, z_prev, x_prev = prev_outputs['m_1_prev'], prev_outputs['z_prev'], prev_outputs['x_prev'] 198 | 199 | x_prev = pseudo_beta_fn( 200 | feats["aatype"], x_prev, None 201 | ).to(dtype=z.dtype) 202 | 203 | # m_1_prev_emb: [*, N, C_m] 204 | # z_prev_emb: [*, N, N, C_z] 205 | m_1_prev_emb, z_prev_emb = self.recycling_embedder( 206 | m_1_prev, 207 | z_prev, 
208 | x_prev, 209 | inplace_safe=inplace_safe, 210 | ) 211 | 212 | # [*, S_c, N, C_m] 213 | m[..., 0, :, :] += m_1_prev_emb 214 | 215 | # [*, N, N, C_z] 216 | z = add(z, z_prev_emb, inplace=inplace_safe) 217 | 218 | 219 | ####################### 220 | if 'noised_pseudo_beta_dists' in batch: 221 | inp_z = self._get_input_pair_embeddings( 222 | batch['noised_pseudo_beta_dists'], 223 | batch['pseudo_beta_mask'], 224 | ) 225 | inp_z = inp_z + self.input_time_embedding(self.input_time_projection(batch['t']))[:,None,None] 226 | 227 | else: # otherwise DDP complains 228 | B, L = batch['aatype'].shape 229 | inp_z = self._get_input_pair_embeddings( 230 | z.new_zeros(B, L, L), 231 | z.new_zeros(B, L), 232 | ) 233 | inp_z = inp_z + self.input_time_embedding(self.input_time_projection(z.new_zeros(B)))[:,None,None] 234 | 235 | z = add(z, inp_z, inplace=inplace_safe) 236 | 237 | ############################# 238 | if self.extra_input: 239 | if 'extra_all_atom_positions' in batch: 240 | extra_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['extra_all_atom_positions'], None) 241 | extra_pseudo_beta_dists = torch.sum((extra_pseudo_beta.unsqueeze(-2) - extra_pseudo_beta.unsqueeze(-3)) ** 2, dim=-1)**0.5 242 | extra_inp_z = self._get_extra_input_pair_embeddings( 243 | extra_pseudo_beta_dists, 244 | batch['pseudo_beta_mask'], 245 | ) 246 | 247 | else: # otherwise DDP complains 248 | B, L = batch['aatype'].shape 249 | extra_inp_z = self._get_extra_input_pair_embeddings( 250 | z.new_zeros(B, L, L), 251 | z.new_zeros(B, L), 252 | ) * 0.0 253 | 254 | z = add(z, extra_inp_z, inplace=inplace_safe) 255 | ######################## 256 | 257 | # Embed extra MSA features + merge with pairwise embeddings 258 | if self.config.extra_msa.enabled: 259 | # [*, S_e, N, C_e] 260 | a = self.extra_msa_embedder(build_extra_msa_feat(feats)) 261 | 262 | if(self.globals.offload_inference): 263 | # To allow the extra MSA stack (and later the evoformer) to 264 | # offload its inputs, we remove all references to them here 265 | input_tensors = [a, z] 266 | del a, z 267 | 268 | # [*, N, N, C_z] 269 | z = self.extra_msa_stack._forward_offload( 270 | input_tensors, 271 | msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype), 272 | chunk_size=self.globals.chunk_size, 273 | use_lma=self.globals.use_lma, 274 | pair_mask=pair_mask.to(dtype=m.dtype), 275 | _mask_trans=self.config._mask_trans, 276 | ) 277 | 278 | del input_tensors 279 | else: 280 | # [*, N, N, C_z] 281 | 282 | z = self.extra_msa_stack( 283 | a, z, 284 | msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype), 285 | chunk_size=self.globals.chunk_size, 286 | use_lma=self.globals.use_lma, 287 | pair_mask=pair_mask.to(dtype=m.dtype), 288 | inplace_safe=inplace_safe, 289 | _mask_trans=self.config._mask_trans, 290 | ) 291 | 292 | # Run MSA + pair embeddings through the trunk of the network 293 | # m: [*, S, N, C_m] 294 | # z: [*, N, N, C_z] 295 | # s: [*, N, C_s] 296 | if(self.globals.offload_inference): 297 | input_tensors = [m, z] 298 | del m, z 299 | m, z, s = self.evoformer._forward_offload( 300 | input_tensors, 301 | msa_mask=msa_mask.to(dtype=input_tensors[0].dtype), 302 | pair_mask=pair_mask.to(dtype=input_tensors[1].dtype), 303 | chunk_size=self.globals.chunk_size, 304 | use_lma=self.globals.use_lma, 305 | _mask_trans=self.config._mask_trans, 306 | ) 307 | 308 | del input_tensors 309 | else: 310 | m, z, s = self.evoformer( 311 | m, 312 | z, 313 | msa_mask=msa_mask.to(dtype=m.dtype), 314 | pair_mask=pair_mask.to(dtype=z.dtype), 315 | chunk_size=self.globals.chunk_size, 316 | 
use_lma=self.globals.use_lma, 317 | use_flash=self.globals.use_flash, 318 | inplace_safe=inplace_safe, 319 | _mask_trans=self.config._mask_trans, 320 | ) 321 | 322 | outputs["msa"] = m[..., :n_seq, :, :] 323 | outputs["pair"] = z 324 | outputs["single"] = s 325 | 326 | del z 327 | 328 | # Predict 3D structure 329 | outputs["sm"] = self.structure_module( 330 | outputs, 331 | feats["aatype"], 332 | mask=feats["seq_mask"].to(dtype=s.dtype), 333 | inplace_safe=inplace_safe, 334 | _offload_inference=self.globals.offload_inference, 335 | ) 336 | outputs["final_atom_positions"] = atom14_to_atom37( 337 | outputs["sm"]["positions"][-1], feats 338 | ) 339 | outputs["final_atom_mask"] = feats["atom37_atom_exists"] 340 | outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1] 341 | 342 | 343 | outputs.update(self.aux_heads(outputs)) 344 | 345 | 346 | # [*, N, C_m] 347 | outputs['m_1_prev'] = m[..., 0, :, :] 348 | 349 | # [*, N, N, C_z] 350 | outputs['z_prev'] = outputs["pair"] 351 | 352 | # [*, N, 3] 353 | outputs['x_prev'] = outputs["final_atom_positions"] 354 | 355 | return outputs 356 | 357 | # def forward(self, batch): 358 | # """ 359 | # Args: 360 | # batch: 361 | # Dictionary of arguments outlined in Algorithm 2. Keys must 362 | # include the official names of the features in the 363 | # supplement subsection 1.2.9. 364 | 365 | # The final dimension of each input must have length equal to 366 | # the number of recycling iterations. 367 | 368 | # Features (without the recycling dimension): 369 | 370 | # "aatype" ([*, N_res]): 371 | # Contrary to the supplement, this tensor of residue 372 | # indices is not one-hot. 373 | # "target_feat" ([*, N_res, C_tf]) 374 | # One-hot encoding of the target sequence. C_tf is 375 | # config.model.input_embedder.tf_dim. 376 | # "residue_index" ([*, N_res]) 377 | # Tensor whose final dimension consists of 378 | # consecutive indices from 0 to N_res. 379 | # "msa_feat" ([*, N_seq, N_res, C_msa]) 380 | # MSA features, constructed as in the supplement. 381 | # C_msa is config.model.input_embedder.msa_dim. 382 | # "seq_mask" ([*, N_res]) 383 | # 1-D sequence mask 384 | # "msa_mask" ([*, N_seq, N_res]) 385 | # MSA mask 386 | # "pair_mask" ([*, N_res, N_res]) 387 | # 2-D pair mask 388 | # "extra_msa_mask" ([*, N_extra, N_res]) 389 | # Extra MSA mask 390 | # "template_mask" ([*, N_templ]) 391 | # Template mask (on the level of templates, not 392 | # residues) 393 | # "template_aatype" ([*, N_templ, N_res]) 394 | # Tensor of template residue indices (indices greater 395 | # than 19 are clamped to 20 (Unknown)) 396 | # "template_all_atom_positions" 397 | # ([*, N_templ, N_res, 37, 3]) 398 | # Template atom coordinates in atom37 format 399 | # "template_all_atom_mask" ([*, N_templ, N_res, 37]) 400 | # Template atom coordinate mask 401 | # "template_pseudo_beta" ([*, N_templ, N_res, 3]) 402 | # Positions of template carbon "pseudo-beta" atoms 403 | # (i.e. 
C_beta for all residues but glycine, for 404 | # for which C_alpha is used instead) 405 | # "template_pseudo_beta_mask" ([*, N_templ, N_res]) 406 | # Pseudo-beta mask 407 | # """ 408 | # # Initialize recycling embeddings 409 | 410 | # m_1_prev, z_prev, x_prev = None, None, None 411 | # prevs = [m_1_prev, z_prev, x_prev] 412 | 413 | # is_grad_enabled = torch.is_grad_enabled() 414 | 415 | # # Main recycling loop 416 | # num_iters = batch["aatype"].shape[-1] 417 | # for cycle_no in range(num_iters): 418 | # # Select the features for the current recycling cycle 419 | # fetch_cur_batch = lambda t: t[..., cycle_no] 420 | # feats = tensor_tree_map(fetch_cur_batch, batch) 421 | 422 | # # Enable grad iff we're training and it's the final recycling layer 423 | # is_final_iter = cycle_no == (num_iters - 1) 424 | # with torch.set_grad_enabled(is_grad_enabled and is_final_iter): 425 | # if is_final_iter: 426 | # # Sidestep AMP bug (PyTorch issue #65766) 427 | # if torch.is_autocast_enabled(): 428 | # torch.clear_autocast_cache() 429 | 430 | # # Run the next iteration of the model 431 | # outputs, m_1_prev, z_prev, x_prev = self.iteration( 432 | # feats, 433 | # prevs, 434 | # _recycle=(num_iters > 1) 435 | # ) 436 | 437 | # if(not is_final_iter): 438 | # del outputs 439 | # prevs = [m_1_prev, z_prev, x_prev] 440 | # del m_1_prev, z_prev, x_prev 441 | 442 | # # Run auxiliary heads 443 | # outputs.update(self.aux_heads(outputs)) 444 | 445 | # return outputs -------------------------------------------------------------------------------- /alphaflow/model/esmfold.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | import typing as T 6 | from functools import partial 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import LayerNorm 11 | 12 | import esm 13 | from esm import Alphabet 14 | 15 | from alphaflow.utils.misc import ( 16 | categorical_lddt, 17 | batch_encode_sequences, 18 | collate_dense_tensors, 19 | output_to_pdb, 20 | ) 21 | 22 | from .trunk import FoldingTrunk 23 | from .layers import GaussianFourierProjection 24 | from .input_stack import InputPairStack 25 | 26 | from openfold.data.data_transforms import make_atom14_masks 27 | from openfold.np import residue_constants 28 | from openfold.model.heads import PerResidueLDDTCaPredictor 29 | from openfold.model.primitives import Linear 30 | from openfold.utils.feats import atom14_to_atom37, pseudo_beta_fn 31 | 32 | 33 | load_fn = esm.pretrained.load_model_and_alphabet 34 | esm_registry = { 35 | "esm2_8M": partial(load_fn, "esm2_t6_8M_UR50D_500K"), 36 | "esm2_8M_270K": esm.pretrained.esm2_t6_8M_UR50D, 37 | "esm2_35M": partial(load_fn, "esm2_t12_35M_UR50D_500K"), 38 | "esm2_35M_270K": esm.pretrained.esm2_t12_35M_UR50D, 39 | "esm2_150M": partial(load_fn, "esm2_t30_150M_UR50D_500K"), 40 | "esm2_150M_270K": partial(load_fn, "esm2_t30_150M_UR50D_270K"), 41 | "esm2_650M": esm.pretrained.esm2_t33_650M_UR50D, 42 | "esm2_650M_270K": partial(load_fn, "esm2_t33_650M_270K_UR50D"), 43 | "esm2_3B": esm.pretrained.esm2_t36_3B_UR50D, 44 | "esm2_3B_270K": partial(load_fn, "esm2_t36_3B_UR50D_500K"), 45 | "esm2_15B": esm.pretrained.esm2_t48_15B_UR50D, 46 | } 47 | 48 | 49 | class ESMFold(nn.Module): 50 | def __init__(self, cfg, extra_input=False): 51 | super().__init__() 52 | 53 | self.cfg = cfg 54 | cfg = self.cfg 55 | 56 | self.distogram_bins = 64 57 | self.esm, self.esm_dict = esm_registry.get(cfg.esm_type)() 58 | 59 | self.esm.requires_grad_(False) 60 | self.esm.half() 61 | 62 | self.esm_feats = self.esm.embed_dim 63 | self.esm_attns = self.esm.num_layers * self.esm.attention_heads 64 | self.register_buffer("af2_to_esm", ESMFold._af2_to_esm(self.esm_dict).float()) # hack to get EMA working 65 | self.esm_s_combine = nn.Parameter(torch.zeros(self.esm.num_layers + 1)) 66 | 67 | c_s = cfg.trunk.sequence_state_dim 68 | c_z = cfg.trunk.pairwise_state_dim 69 | 70 | self.esm_s_mlp = nn.Sequential( 71 | LayerNorm(self.esm_feats), 72 | nn.Linear(self.esm_feats, c_s), 73 | nn.ReLU(), 74 | nn.Linear(c_s, c_s), 75 | ) 76 | ###################### 77 | self.input_pair_embedding = Linear( 78 | cfg.input_pair_embedder.no_bins, 79 | cfg.trunk.pairwise_state_dim, 80 | init="final", 81 | ) 82 | self.input_time_projection = GaussianFourierProjection( 83 | embedding_size=cfg.input_pair_embedder.time_emb_dim 84 | ) 85 | self.input_time_embedding = Linear( 86 | cfg.input_pair_embedder.time_emb_dim, 87 | cfg.trunk.pairwise_state_dim, 88 | init="final", 89 | ) 90 | torch.nn.init.zeros_(self.input_pair_embedding.weight) 91 | torch.nn.init.zeros_(self.input_pair_embedding.bias) 92 | self.input_pair_stack = InputPairStack(**cfg.input_pair_stack) 93 | 94 | self.extra_input = extra_input 95 | if extra_input: 96 | self.extra_input_pair_embedding = Linear( 97 | cfg.input_pair_embedder.no_bins, 98 | cfg.evoformer_stack.c_z, 99 | init="final", 100 | ) 101 | self.extra_input_pair_stack = InputPairStack(**cfg.input_pair_stack) 102 | 103 | ####################### 104 | 105 | # 0 is padding, N is unknown residues, N + 1 is mask. 
106 | self.n_tokens_embed = residue_constants.restype_num + 3 107 | self.pad_idx = 0 108 | self.unk_idx = self.n_tokens_embed - 2 109 | self.mask_idx = self.n_tokens_embed - 1 110 | self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0) 111 | 112 | self.trunk = FoldingTrunk(cfg.trunk) 113 | 114 | self.distogram_head = nn.Linear(c_z, self.distogram_bins) 115 | # self.ptm_head = nn.Linear(c_z, self.distogram_bins) 116 | # self.lm_head = nn.Linear(c_s, self.n_tokens_embed) 117 | self.lddt_bins = 50 118 | 119 | self.lddt_head = PerResidueLDDTCaPredictor( 120 | no_bins=self.lddt_bins, 121 | c_in=cfg.trunk.structure_module.c_s, 122 | c_hidden=cfg.lddt_head_hid_dim 123 | ) 124 | 125 | @staticmethod 126 | def _af2_to_esm(d: Alphabet): 127 | # Remember that t is shifted from residue_constants by 1 (0 is padding). 128 | esm_reorder = [d.padding_idx] + [ 129 | d.get_idx(v) for v in residue_constants.restypes_with_x 130 | ] 131 | return torch.tensor(esm_reorder) 132 | 133 | def _af2_idx_to_esm_idx(self, aa, mask): 134 | aa = (aa + 1).masked_fill(mask != 1, 0) 135 | return self.af2_to_esm.long()[aa] 136 | 137 | def _compute_language_model_representations( 138 | self, esmaa: torch.Tensor 139 | ) -> torch.Tensor: 140 | """Adds bos/eos tokens for the language model, since the structure module doesn't use these.""" 141 | batch_size = esmaa.size(0) 142 | 143 | bosi, eosi = self.esm_dict.cls_idx, self.esm_dict.eos_idx 144 | bos = esmaa.new_full((batch_size, 1), bosi) 145 | eos = esmaa.new_full((batch_size, 1), self.esm_dict.padding_idx) 146 | esmaa = torch.cat([bos, esmaa, eos], dim=1) 147 | # Use the first padding index as eos during inference. 148 | esmaa[range(batch_size), (esmaa != 1).sum(1)] = eosi 149 | 150 | res = self.esm( 151 | esmaa, 152 | repr_layers=range(self.esm.num_layers + 1), 153 | need_head_weights=self.cfg.use_esm_attn_map, 154 | ) 155 | esm_s = torch.stack( 156 | [v for _, v in sorted(res["representations"].items())], dim=2 157 | ) 158 | esm_s = esm_s[:, 1:-1] # B, L, nLayers, C 159 | esm_z = ( 160 | res["attentions"].permute(0, 4, 3, 1, 2).flatten(3, 4)[:, 1:-1, 1:-1, :] 161 | if self.cfg.use_esm_attn_map 162 | else None 163 | ) 164 | return esm_s, esm_z 165 | 166 | def _mask_inputs_to_esm(self, esmaa, pattern): 167 | new_esmaa = esmaa.clone() 168 | new_esmaa[pattern == 1] = self.esm_dict.mask_idx 169 | return new_esmaa 170 | 171 | def _get_input_pair_embeddings(self, dists, mask): 172 | 173 | mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) 174 | 175 | lower = torch.linspace( 176 | self.cfg.input_pair_embedder.min_bin, 177 | self.cfg.input_pair_embedder.max_bin, 178 | self.cfg.input_pair_embedder.no_bins, 179 | device=dists.device) 180 | dists = dists.unsqueeze(-1) 181 | inf = self.cfg.input_pair_embedder.inf 182 | upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) 183 | dgram = ((dists > lower) * (dists < upper)).type(dists.dtype) 184 | 185 | inp_z = self.input_pair_embedding(dgram * mask.unsqueeze(-1)) 186 | inp_z = self.input_pair_stack(inp_z, mask, chunk_size=None) 187 | return inp_z 188 | 189 | def _get_extra_input_pair_embeddings(self, dists, mask): 190 | 191 | mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) 192 | 193 | lower = torch.linspace( 194 | self.cfg.input_pair_embedder.min_bin, 195 | self.cfg.input_pair_embedder.max_bin, 196 | self.cfg.input_pair_embedder.no_bins, 197 | device=dists.device) 198 | dists = dists.unsqueeze(-1) 199 | inf = self.cfg.input_pair_embedder.inf 200 | upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) 201 | 
dgram = ((dists > lower) * (dists < upper)).type(dists.dtype) 202 | 203 | inp_z = self.extra_input_pair_embedding(dgram * mask.unsqueeze(-1)) 204 | inp_z = self.extra_input_pair_stack(inp_z, mask, chunk_size=None) 205 | return inp_z 206 | 207 | 208 | def forward( 209 | self, 210 | batch, 211 | prev_outputs=None, 212 | ): 213 | """Runs a forward pass given input tokens. Use `model.infer` to 214 | run inference from a sequence. 215 | 216 | Args: 217 | aa (torch.Tensor): Tensor containing indices corresponding to amino acids. Indices match 218 | openfold.np.residue_constants.restype_order_with_x. 219 | mask (torch.Tensor): Binary tensor with 1 meaning position is unmasked and 0 meaning position is masked. 220 | residx (torch.Tensor): Residue indices of amino acids. Will assume contiguous if not provided. 221 | masking_pattern (torch.Tensor): Optional masking to pass to the input. Binary tensor of the same size 222 | as `aa`. Positions with 1 will be masked. ESMFold sometimes produces different samples when 223 | different masks are provided. 224 | num_recycles (int): How many recycle iterations to perform. If None, defaults to training max 225 | recycles, which is 3. 226 | """ 227 | aa = batch['aatype'] 228 | mask = batch['seq_mask'] 229 | residx = batch['residue_index'] 230 | 231 | # === ESM === 232 | 233 | esmaa = self._af2_idx_to_esm_idx(aa, mask) 234 | esm_s, _ = self._compute_language_model_representations(esmaa) 235 | 236 | # Convert esm_s to the precision used by the trunk and 237 | # the structure module. These tensors may be a lower precision if, for example, 238 | # we're running the language model in fp16 precision. 239 | esm_s = esm_s.to(self.esm_s_combine.dtype) 240 | esm_s = esm_s.detach() 241 | 242 | # === preprocessing === 243 | esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2) 244 | s_s_0 = self.esm_s_mlp(esm_s) 245 | s_s_0 += self.embedding(aa) 246 | ####################### 247 | if 'noised_pseudo_beta_dists' in batch: 248 | inp_z = self._get_input_pair_embeddings( 249 | batch['noised_pseudo_beta_dists'], 250 | batch['pseudo_beta_mask'] 251 | ) 252 | inp_z = inp_z + self.input_time_embedding(self.input_time_projection(batch['t']))[:,None,None] 253 | else: # have to run the module, else DDP wont work 254 | B, L = batch['aatype'].shape 255 | inp_z = self._get_input_pair_embeddings( 256 | s_s_0.new_zeros(B, L, L), 257 | batch['pseudo_beta_mask'] * 0.0 258 | ) 259 | inp_z = inp_z + self.input_time_embedding(self.input_time_projection(inp_z.new_zeros(B)))[:,None,None] 260 | ########################## 261 | ############################# 262 | if self.extra_input: 263 | if 'extra_all_atom_positions' in batch: 264 | extra_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['extra_all_atom_positions'], None) 265 | extra_pseudo_beta_dists = torch.sum((extra_pseudo_beta.unsqueeze(-2) - extra_pseudo_beta.unsqueeze(-3)) ** 2, dim=-1)**0.5 266 | extra_inp_z = self._get_extra_input_pair_embeddings( 267 | extra_pseudo_beta_dists, 268 | batch['pseudo_beta_mask'], 269 | ) 270 | 271 | else: # otherwise DDP complains 272 | B, L = batch['aatype'].shape 273 | extra_inp_z = self._get_extra_input_pair_embeddings( 274 | inp_z.new_zeros(B, L, L), 275 | inp_z.new_zeros(B, L), 276 | ) * 0.0 277 | 278 | inp_z = inp_z + extra_inp_z 279 | ######################## 280 | 281 | 282 | 283 | s_z_0 = inp_z 284 | if prev_outputs is not None: 285 | s_s_0 = s_s_0 + self.trunk.recycle_s_norm(prev_outputs['s_s']) 286 | s_z_0 = s_z_0 + self.trunk.recycle_z_norm(prev_outputs['s_z']) 287 | s_z_0 = 
s_z_0 + self.trunk.recycle_disto(FoldingTrunk.distogram( 288 | prev_outputs['sm']["positions"][-1][:, :, :3], 289 | 3.375, 290 | 21.375, 291 | self.trunk.recycle_bins, 292 | )) 293 | 294 | else: 295 | s_s_0 = s_s_0 + self.trunk.recycle_s_norm(torch.zeros_like(s_s_0)) * 0.0 296 | s_z_0 = s_z_0 + self.trunk.recycle_z_norm(torch.zeros_like(s_z_0)) * 0.0 297 | s_z_0 = s_z_0 + self.trunk.recycle_disto(s_z_0.new_zeros(s_z_0.shape[:-2], dtype=torch.long)) * 0.0 298 | 299 | 300 | 301 | structure: dict = self.trunk( 302 | s_s_0, s_z_0, aa, residx, mask, no_recycles=0 # num_recycles 303 | ) 304 | disto_logits = self.distogram_head(structure["s_z"]) 305 | disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2 306 | structure["distogram_logits"] = disto_logits 307 | 308 | ''' 309 | lm_logits = self.lm_head(structure["s_s"]) 310 | structure["lm_logits"] = lm_logits 311 | ''' 312 | 313 | structure["aatype"] = aa 314 | make_atom14_masks(structure) 315 | 316 | for k in [ 317 | "atom14_atom_exists", 318 | "atom37_atom_exists", 319 | ]: 320 | structure[k] *= mask.unsqueeze(-1) 321 | structure["residue_index"] = residx 322 | lddt_head = self.lddt_head(structure['sm']["single"]) 323 | structure["lddt_logits"] = lddt_head 324 | plddt = categorical_lddt(lddt_head, bins=self.lddt_bins) 325 | structure["plddt"] = 100 * plddt 326 | # we predict plDDT between 0 and 1, scale to be between 0 and 100. 327 | 328 | ''' 329 | ptm_logits = self.ptm_head(structure["s_z"]) 330 | seqlen = mask.type(torch.int64).sum(1) 331 | structure["tm_logits"] = ptm_logits 332 | structure["ptm"] = torch.stack( 333 | [ 334 | compute_tm( 335 | batch_ptm_logits[None, :sl, :sl], 336 | max_bins=31, 337 | no_bins=self.distogram_bins, 338 | ) 339 | for batch_ptm_logits, sl in zip(ptm_logits, seqlen) 340 | ] 341 | ) 342 | structure.update( 343 | compute_predicted_aligned_error( 344 | ptm_logits, max_bin=31, no_bins=self.distogram_bins 345 | ) 346 | ) 347 | ''' 348 | 349 | structure["final_atom_positions"] = atom14_to_atom37(structure["sm"]["positions"][-1], batch) 350 | structure["final_affine_tensor"] = structure["sm"]["frames"][-1] 351 | if "name" in batch: structure["name"] = batch["name"] 352 | return structure 353 | 354 | @torch.no_grad() 355 | def infer( 356 | self, 357 | sequences: T.Union[str, T.List[str]], 358 | residx=None, 359 | masking_pattern: T.Optional[torch.Tensor] = None, 360 | num_recycles: T.Optional[int] = None, 361 | residue_index_offset: T.Optional[int] = 512, 362 | chain_linker: T.Optional[str] = "G" * 25, 363 | ): 364 | """Runs a forward pass given input sequences. 365 | 366 | Args: 367 | sequences (Union[str, List[str]]): A list of sequences to make predictions for. Multimers can also be passed in, 368 | each chain should be separated by a ':' token (e.g. "::"). 369 | residx (torch.Tensor): Residue indices of amino acids. Will assume contiguous if not provided. 370 | masking_pattern (torch.Tensor): Optional masking to pass to the input. Binary tensor of the same size 371 | as `aa`. Positions with 1 will be masked. ESMFold sometimes produces different samples when 372 | different masks are provided. 373 | num_recycles (int): How many recycle iterations to perform. If None, defaults to training max 374 | recycles (cfg.trunk.max_recycles), which is 4. 375 | residue_index_offset (int): Residue index separation between chains if predicting a multimer. Has no effect on 376 | single chain predictions. Default: 512. 377 | chain_linker (str): Linker to use between chains if predicting a multimer. 
Has no effect on single chain 378 | predictions. Default: length-25 poly-G ("G" * 25). 379 | """ 380 | if isinstance(sequences, str): 381 | sequences = [sequences] 382 | 383 | aatype, mask, _residx, linker_mask, chain_index = batch_encode_sequences( 384 | sequences, residue_index_offset, chain_linker 385 | ) 386 | 387 | if residx is None: 388 | residx = _residx 389 | elif not isinstance(residx, torch.Tensor): 390 | residx = collate_dense_tensors(residx) 391 | 392 | aatype, mask, residx, linker_mask = map( 393 | lambda x: x.to(self.device), (aatype, mask, residx, linker_mask) 394 | ) 395 | 396 | output = self.forward( 397 | aatype, 398 | mask=mask, 399 | residx=residx, 400 | masking_pattern=masking_pattern, 401 | num_recycles=num_recycles, 402 | ) 403 | 404 | output["atom37_atom_exists"] = output[ 405 | "atom37_atom_exists" 406 | ] * linker_mask.unsqueeze(2) 407 | 408 | output["mean_plddt"] = (output["plddt"] * output["atom37_atom_exists"]).sum(dim=(1, 2)) / output["atom37_atom_exists"].sum(dim=(1, 2)) 409 | output["chain_index"] = chain_index 410 | 411 | return output 412 | 413 | def output_to_pdb(self, output: T.Dict) -> T.List[str]: 414 | """Returns the pbd (file) string from the model given the model output.""" 415 | return output_to_pdb(output) 416 | 417 | def infer_pdbs(self, seqs: T.List[str], *args, **kwargs) -> T.List[str]: 418 | """Returns list of pdb (files) strings from the model given a list of input sequences.""" 419 | output = self.infer(seqs, *args, **kwargs) 420 | return self.output_to_pdb(output) 421 | 422 | def infer_pdb(self, sequence: str, *args, **kwargs) -> str: 423 | """Returns the pdb (file) string from the model given an input sequence.""" 424 | return self.infer_pdbs([sequence], *args, **kwargs)[0] 425 | 426 | def set_chunk_size(self, chunk_size: T.Optional[int]): 427 | # This parameter means the axial attention will be computed 428 | # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2). 429 | # It's equivalent to running a for loop over chunks of the dimension we're iterative over, 430 | # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-lengthed chunks. 431 | # Setting the value to None will return to default behavior, disable chunking. 432 | self.trunk.set_chunk_size(chunk_size) 433 | 434 | @property 435 | def device(self): 436 | return self.esm_s_combine.device 437 | -------------------------------------------------------------------------------- /alphaflow/model/input_stack.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
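For reference, the modified ESMFold.forward above consumes a feature dictionary rather than raw sequences. A minimal sketch of the keys it reads, with illustrative shapes; the noised-distance and time entries are only needed on the flow-matching path:

import torch

B, L = 1, 64
batch = {
    "aatype":          torch.randint(0, 20, (B, L)),            # residue types
    "seq_mask":        torch.ones(B, L),                        # 1 = valid position
    "residue_index":   torch.arange(L)[None].expand(B, L),
    "pseudo_beta_mask": torch.ones(B, L),
    # Optional conditioning consumed by the input pair embedder:
    "noised_pseudo_beta_dists": torch.rand(B, L, L) * 20.0,     # noisy pairwise C-beta distances
    "t":               torch.rand(B),                           # flow/diffusion time in [0, 1]
}
# outputs = model(batch)  # `model` would be an ESMFold(cfg) instance with loaded weights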
15 | from functools import partial 16 | from typing import Optional 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from openfold.model.primitives import LayerNorm 22 | from openfold.model.dropout import ( 23 | DropoutRowwise, 24 | DropoutColumnwise, 25 | ) 26 | from openfold.model.pair_transition import PairTransition 27 | from openfold.model.triangular_attention import ( 28 | TriangleAttentionStartingNode, 29 | TriangleAttentionEndingNode, 30 | ) 31 | from openfold.model.triangular_multiplicative_update import ( 32 | TriangleMultiplicationOutgoing, 33 | TriangleMultiplicationIncoming, 34 | ) 35 | from openfold.utils.checkpointing import checkpoint_blocks 36 | from openfold.utils.chunk_utils import ChunkSizeTuner 37 | from openfold.utils.tensor_utils import add 38 | 39 | 40 | class InputPairStackBlock(nn.Module): 41 | def __init__( 42 | self, 43 | c_t: int, 44 | c_hidden_tri_att: int, 45 | c_hidden_tri_mul: int, 46 | no_heads: int, 47 | pair_transition_n: int, 48 | dropout_rate: float, 49 | inf: float, 50 | **kwargs, 51 | ): 52 | super(InputPairStackBlock, self).__init__() 53 | 54 | self.c_t = c_t 55 | self.c_hidden_tri_att = c_hidden_tri_att 56 | self.c_hidden_tri_mul = c_hidden_tri_mul 57 | self.no_heads = no_heads 58 | self.pair_transition_n = pair_transition_n 59 | self.dropout_rate = dropout_rate 60 | self.inf = inf 61 | 62 | self.dropout_row = DropoutRowwise(self.dropout_rate) 63 | self.dropout_col = DropoutColumnwise(self.dropout_rate) 64 | 65 | self.tri_att_start = TriangleAttentionStartingNode( 66 | self.c_t, 67 | self.c_hidden_tri_att, 68 | self.no_heads, 69 | inf=inf, 70 | ) 71 | self.tri_att_end = TriangleAttentionEndingNode( 72 | self.c_t, 73 | self.c_hidden_tri_att, 74 | self.no_heads, 75 | inf=inf, 76 | ) 77 | 78 | self.tri_mul_out = TriangleMultiplicationOutgoing( 79 | self.c_t, 80 | self.c_hidden_tri_mul, 81 | ) 82 | self.tri_mul_in = TriangleMultiplicationIncoming( 83 | self.c_t, 84 | self.c_hidden_tri_mul, 85 | ) 86 | 87 | self.pair_transition = PairTransition( 88 | self.c_t, 89 | self.pair_transition_n, 90 | ) 91 | 92 | def forward(self, 93 | z: torch.Tensor, 94 | mask: torch.Tensor, 95 | chunk_size: Optional[int] = None, 96 | use_lma: bool = False, 97 | inplace_safe: bool = False, 98 | _mask_trans: bool = True, 99 | _attn_chunk_size: Optional[int] = None, 100 | ): 101 | if(_attn_chunk_size is None): 102 | _attn_chunk_size = chunk_size 103 | 104 | single = z # single_templates[i] 105 | single_mask = mask # single_templates_masks[i] 106 | 107 | single = add(single, 108 | self.dropout_row( 109 | self.tri_att_start( 110 | single, 111 | chunk_size=_attn_chunk_size, 112 | mask=single_mask, 113 | use_lma=use_lma, 114 | inplace_safe=inplace_safe, 115 | ) 116 | ), 117 | inplace_safe, 118 | ) 119 | 120 | single = add(single, 121 | self.dropout_col( 122 | self.tri_att_end( 123 | single, 124 | chunk_size=_attn_chunk_size, 125 | mask=single_mask, 126 | use_lma=use_lma, 127 | inplace_safe=inplace_safe, 128 | ) 129 | ), 130 | inplace_safe, 131 | ) 132 | 133 | tmu_update = self.tri_mul_out( 134 | single, 135 | mask=single_mask, 136 | inplace_safe=inplace_safe, 137 | _add_with_inplace=True, 138 | ) 139 | if(not inplace_safe): 140 | single = single + self.dropout_row(tmu_update) 141 | else: 142 | single = tmu_update 143 | 144 | del tmu_update 145 | 146 | tmu_update = self.tri_mul_in( 147 | single, 148 | mask=single_mask, 149 | inplace_safe=inplace_safe, 150 | _add_with_inplace=True, 151 | ) 152 | if(not inplace_safe): 153 | single = single + self.dropout_row(tmu_update) 154 
| else: 155 | single = tmu_update 156 | 157 | del tmu_update 158 | 159 | single = add(single, 160 | self.pair_transition( 161 | single, 162 | mask=single_mask if _mask_trans else None, 163 | chunk_size=chunk_size, 164 | ), 165 | inplace_safe, 166 | ) 167 | 168 | return single 169 | 170 | 171 | class InputPairStack(nn.Module): 172 | """ 173 | Implements Algorithm 16. 174 | """ 175 | def __init__( 176 | self, 177 | c_t, 178 | c_hidden_tri_att, 179 | c_hidden_tri_mul, 180 | no_blocks, 181 | no_heads, 182 | pair_transition_n, 183 | dropout_rate, 184 | blocks_per_ckpt, 185 | tune_chunk_size: bool = False, 186 | inf=1e9, 187 | **kwargs, 188 | ): 189 | """ 190 | Args: 191 | c_t: 192 | Template embedding channel dimension 193 | c_hidden_tri_att: 194 | Per-head hidden dimension for triangular attention 195 | c_hidden_tri_att: 196 | Hidden dimension for triangular multiplication 197 | no_blocks: 198 | Number of blocks in the stack 199 | pair_transition_n: 200 | Scale of pair transition (Alg. 15) hidden dimension 201 | dropout_rate: 202 | Dropout rate used throughout the stack 203 | blocks_per_ckpt: 204 | Number of blocks per activation checkpoint. None disables 205 | activation checkpointing 206 | """ 207 | super(InputPairStack, self).__init__() 208 | 209 | self.blocks_per_ckpt = blocks_per_ckpt 210 | 211 | self.blocks = nn.ModuleList() 212 | for _ in range(no_blocks): 213 | block = InputPairStackBlock( 214 | c_t=c_t, 215 | c_hidden_tri_att=c_hidden_tri_att, 216 | c_hidden_tri_mul=c_hidden_tri_mul, 217 | no_heads=no_heads, 218 | pair_transition_n=pair_transition_n, 219 | dropout_rate=dropout_rate, 220 | inf=inf, 221 | ) 222 | self.blocks.append(block) 223 | 224 | self.layer_norm = LayerNorm(c_t) 225 | 226 | self.tune_chunk_size = tune_chunk_size 227 | self.chunk_size_tuner = None 228 | if(tune_chunk_size): 229 | self.chunk_size_tuner = ChunkSizeTuner() 230 | 231 | def forward( 232 | self, 233 | t: torch.tensor, 234 | mask: torch.tensor, 235 | chunk_size: int, 236 | use_lma: bool = False, 237 | inplace_safe: bool = False, 238 | _mask_trans: bool = True, 239 | ): 240 | """ 241 | Args: 242 | t: 243 | [*, N_templ, N_res, N_res, C_t] template embedding 244 | mask: 245 | [*, N_templ, N_res, N_res] mask 246 | Returns: 247 | [*, N_templ, N_res, N_res, C_t] template embedding update 248 | """ 249 | 250 | blocks = [ 251 | partial( 252 | b, 253 | mask=mask, 254 | chunk_size=chunk_size, 255 | use_lma=use_lma, 256 | inplace_safe=inplace_safe, 257 | _mask_trans=_mask_trans, 258 | ) 259 | for b in self.blocks 260 | ] 261 | 262 | if(chunk_size is not None and self.chunk_size_tuner is not None): 263 | assert(not self.training) 264 | tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size( 265 | representative_fn=blocks[0], 266 | args=(t.clone(),), 267 | min_chunk_size=chunk_size, 268 | ) 269 | blocks = [ 270 | partial(b, 271 | chunk_size=tuned_chunk_size, 272 | _attn_chunk_size=max(chunk_size, tuned_chunk_size // 4), 273 | ) for b in blocks 274 | ] 275 | 276 | t, = checkpoint_blocks( 277 | blocks=blocks, 278 | args=(t,), 279 | blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, 280 | ) 281 | 282 | t = self.layer_norm(t) 283 | 284 | return t -------------------------------------------------------------------------------- /alphaflow/model/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import typing as T 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from einops import rearrange, repeat 11 | from torch import nn 12 | 13 | class GaussianFourierProjection(nn.Module): 14 | """Gaussian Fourier embeddings for noise levels. 15 | from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/models/layerspp.py#L32 16 | """ 17 | 18 | def __init__(self, embedding_size=256, scale=1.0): 19 | super().__init__() 20 | self.W = nn.Parameter( 21 | torch.randn(embedding_size // 2) * scale, requires_grad=False 22 | ) 23 | 24 | def forward(self, x): 25 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 26 | emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 27 | return emb 28 | 29 | class Attention(nn.Module): 30 | def __init__(self, embed_dim, num_heads, head_width, gated=False): 31 | super().__init__() 32 | assert embed_dim == num_heads * head_width 33 | 34 | self.embed_dim = embed_dim 35 | self.num_heads = num_heads 36 | self.head_width = head_width 37 | 38 | self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False) 39 | self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True) 40 | self.gated = gated 41 | if gated: 42 | self.g_proj = nn.Linear(embed_dim, embed_dim) 43 | torch.nn.init.zeros_(self.g_proj.weight) 44 | torch.nn.init.ones_(self.g_proj.bias) 45 | 46 | self.rescale_factor = self.head_width**-0.5 47 | 48 | torch.nn.init.zeros_(self.o_proj.bias) 49 | 50 | def forward(self, x, mask=None, bias=None, indices=None): 51 | """ 52 | Basic self attention with optional mask and external pairwise bias. 53 | To handle sequences of different lengths, use mask. 54 | 55 | Inputs: 56 | x: batch of input sequneces (.. x L x C) 57 | mask: batch of boolean masks where 1=valid, 0=padding position (.. x L_k). optional. 58 | bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads). optional. 59 | 60 | Outputs: 61 | sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads) 62 | """ 63 | 64 | t = rearrange(self.proj(x), "... l (h c) -> ... h l c", h=self.num_heads) 65 | q, k, v = t.chunk(3, dim=-1) 66 | 67 | q = self.rescale_factor * q 68 | a = torch.einsum("...qc,...kc->...qk", q, k) 69 | 70 | # Add external attention bias. 71 | if bias is not None: 72 | a = a + rearrange(bias, "... lq lk h -> ... h lq lk") 73 | 74 | # Do not attend to padding tokens. 75 | if mask is not None: 76 | mask = repeat( 77 | mask, "... lk -> ... h lq lk", h=self.num_heads, lq=q.shape[-2] 78 | ) 79 | a = a.masked_fill(mask == False, -np.inf) 80 | 81 | a = F.softmax(a, dim=-1) 82 | 83 | y = torch.einsum("...hqk,...hkc->...qhc", a, v) 84 | y = rearrange(y, "... h c -> ... (h c)", h=self.num_heads) 85 | 86 | if self.gated: 87 | y = self.g_proj(x).sigmoid() * y 88 | y = self.o_proj(y) 89 | 90 | return y, rearrange(a, "... lq lk h -> ... h lq lk") 91 | 92 | 93 | class Dropout(nn.Module): 94 | """ 95 | Implementation of dropout with the ability to share the dropout mask 96 | along a particular dimension. 
97 | """ 98 | 99 | def __init__(self, r: float, batch_dim: T.Union[int, T.List[int]]): 100 | super(Dropout, self).__init__() 101 | 102 | self.r = r 103 | if type(batch_dim) == int: 104 | batch_dim = [batch_dim] 105 | self.batch_dim = batch_dim 106 | self.dropout = nn.Dropout(self.r) 107 | 108 | def forward(self, x: torch.Tensor) -> torch.Tensor: 109 | shape = list(x.shape) 110 | if self.batch_dim is not None: 111 | for bd in self.batch_dim: 112 | shape[bd] = 1 113 | return x * self.dropout(x.new_ones(shape)) 114 | 115 | 116 | class SequenceToPair(nn.Module): 117 | def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim): 118 | super().__init__() 119 | 120 | self.layernorm = nn.LayerNorm(sequence_state_dim) 121 | self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True) 122 | self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True) 123 | 124 | torch.nn.init.zeros_(self.proj.bias) 125 | torch.nn.init.zeros_(self.o_proj.bias) 126 | 127 | def forward(self, sequence_state): 128 | """ 129 | Inputs: 130 | sequence_state: B x L x sequence_state_dim 131 | 132 | Output: 133 | pairwise_state: B x L x L x pairwise_state_dim 134 | 135 | Intermediate state: 136 | B x L x L x 2*inner_dim 137 | """ 138 | 139 | assert len(sequence_state.shape) == 3 140 | 141 | s = self.layernorm(sequence_state) 142 | s = self.proj(s) 143 | q, k = s.chunk(2, dim=-1) 144 | 145 | prod = q[:, None, :, :] * k[:, :, None, :] 146 | diff = q[:, None, :, :] - k[:, :, None, :] 147 | 148 | x = torch.cat([prod, diff], dim=-1) 149 | x = self.o_proj(x) 150 | 151 | return x 152 | 153 | 154 | class PairToSequence(nn.Module): 155 | def __init__(self, pairwise_state_dim, num_heads): 156 | super().__init__() 157 | 158 | self.layernorm = nn.LayerNorm(pairwise_state_dim) 159 | self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False) 160 | 161 | def forward(self, pairwise_state): 162 | """ 163 | Inputs: 164 | pairwise_state: B x L x L x pairwise_state_dim 165 | 166 | Output: 167 | pairwise_bias: B x L x L x num_heads 168 | """ 169 | assert len(pairwise_state.shape) == 4 170 | z = self.layernorm(pairwise_state) 171 | pairwise_bias = self.linear(z) 172 | return pairwise_bias 173 | 174 | 175 | class ResidueMLP(nn.Module): 176 | def __init__(self, embed_dim, inner_dim, norm=nn.LayerNorm, dropout=0): 177 | super().__init__() 178 | 179 | self.mlp = nn.Sequential( 180 | norm(embed_dim), 181 | nn.Linear(embed_dim, inner_dim), 182 | nn.ReLU(), 183 | nn.Linear(inner_dim, embed_dim), 184 | nn.Dropout(dropout), 185 | ) 186 | 187 | def forward(self, x): 188 | return x + self.mlp(x) 189 | -------------------------------------------------------------------------------- /alphaflow/model/tri_self_attn_block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
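The Dropout variant above shares a single dropout decision along the dimensions listed in batch_dim; the triangular block that follows uses it with batch_dim 2 and 1 to drop whole rows or columns of the pair representation. A short usage sketch, importing the class from this repository:

import torch
from alphaflow.model.layers import Dropout

torch.manual_seed(0)
x = torch.ones(1, 4, 4, 2)             # (B, L, L, C) pair-style activations
row_drop = Dropout(0.5, batch_dim=2)   # mask has size 1 along dim 2, so it is shared there
y = row_drop(x)
# Each (b, i, :, c) slice is either all zeros or all scaled by 1/(1-p),
# because the same Bernoulli draw is broadcast along dimension 2.
print(y[0, :, :, 0])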
5 | import torch 6 | from openfold.model.triangular_attention import ( 7 | TriangleAttentionEndingNode, 8 | TriangleAttentionStartingNode, 9 | ) 10 | from openfold.model.triangular_multiplicative_update import ( 11 | TriangleMultiplicationIncoming, 12 | TriangleMultiplicationOutgoing, 13 | ) 14 | from torch import nn 15 | 16 | from .layers import ( 17 | Attention, 18 | Dropout, 19 | PairToSequence, 20 | ResidueMLP, 21 | SequenceToPair, 22 | ) 23 | 24 | 25 | class TriangularSelfAttentionBlock(nn.Module): 26 | def __init__( 27 | self, 28 | sequence_state_dim, 29 | pairwise_state_dim, 30 | sequence_head_width, 31 | pairwise_head_width, 32 | dropout=0, 33 | **__kwargs, 34 | ): 35 | super().__init__() 36 | 37 | assert sequence_state_dim % sequence_head_width == 0 38 | assert pairwise_state_dim % pairwise_head_width == 0 39 | sequence_num_heads = sequence_state_dim // sequence_head_width 40 | pairwise_num_heads = pairwise_state_dim // pairwise_head_width 41 | assert sequence_state_dim == sequence_num_heads * sequence_head_width 42 | assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width 43 | assert pairwise_state_dim % 2 == 0 44 | 45 | self.sequence_state_dim = sequence_state_dim 46 | self.pairwise_state_dim = pairwise_state_dim 47 | 48 | self.layernorm_1 = nn.LayerNorm(sequence_state_dim) 49 | 50 | self.sequence_to_pair = SequenceToPair( 51 | sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim 52 | ) 53 | self.pair_to_sequence = PairToSequence(pairwise_state_dim, sequence_num_heads) 54 | 55 | self.seq_attention = Attention( 56 | sequence_state_dim, sequence_num_heads, sequence_head_width, gated=True 57 | ) 58 | self.tri_mul_out = TriangleMultiplicationOutgoing( 59 | pairwise_state_dim, 60 | pairwise_state_dim, 61 | ) 62 | self.tri_mul_in = TriangleMultiplicationIncoming( 63 | pairwise_state_dim, 64 | pairwise_state_dim, 65 | ) 66 | self.tri_att_start = TriangleAttentionStartingNode( 67 | pairwise_state_dim, 68 | pairwise_head_width, 69 | pairwise_num_heads, 70 | inf=1e9, 71 | ) # type: ignore 72 | self.tri_att_end = TriangleAttentionEndingNode( 73 | pairwise_state_dim, 74 | pairwise_head_width, 75 | pairwise_num_heads, 76 | inf=1e9, 77 | ) # type: ignore 78 | 79 | self.mlp_seq = ResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=dropout) 80 | self.mlp_pair = ResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout) 81 | 82 | assert dropout < 0.4 83 | self.drop = nn.Dropout(dropout) 84 | self.row_drop = Dropout(dropout * 2, 2) 85 | self.col_drop = Dropout(dropout * 2, 1) 86 | 87 | torch.nn.init.zeros_(self.tri_mul_in.linear_z.weight) 88 | torch.nn.init.zeros_(self.tri_mul_in.linear_z.bias) 89 | torch.nn.init.zeros_(self.tri_mul_out.linear_z.weight) 90 | torch.nn.init.zeros_(self.tri_mul_out.linear_z.bias) 91 | torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.weight) 92 | torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.bias) 93 | torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.weight) 94 | torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.bias) 95 | 96 | torch.nn.init.zeros_(self.sequence_to_pair.o_proj.weight) 97 | torch.nn.init.zeros_(self.sequence_to_pair.o_proj.bias) 98 | torch.nn.init.zeros_(self.pair_to_sequence.linear.weight) 99 | torch.nn.init.zeros_(self.seq_attention.o_proj.weight) 100 | torch.nn.init.zeros_(self.seq_attention.o_proj.bias) 101 | torch.nn.init.zeros_(self.mlp_seq.mlp[-2].weight) 102 | torch.nn.init.zeros_(self.mlp_seq.mlp[-2].bias) 103 | torch.nn.init.zeros_(self.mlp_pair.mlp[-2].weight) 104 | 
torch.nn.init.zeros_(self.mlp_pair.mlp[-2].bias) 105 | 106 | def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs): 107 | """ 108 | Inputs: 109 | sequence_state: B x L x sequence_state_dim 110 | pairwise_state: B x L x L x pairwise_state_dim 111 | mask: B x L boolean tensor of valid positions 112 | 113 | Output: 114 | sequence_state: B x L x sequence_state_dim 115 | pairwise_state: B x L x L x pairwise_state_dim 116 | """ 117 | assert len(sequence_state.shape) == 3 118 | assert len(pairwise_state.shape) == 4 119 | if mask is not None: 120 | assert len(mask.shape) == 2 121 | 122 | batch_dim, seq_dim, sequence_state_dim = sequence_state.shape 123 | pairwise_state_dim = pairwise_state.shape[3] 124 | assert sequence_state_dim == self.sequence_state_dim 125 | assert pairwise_state_dim == self.pairwise_state_dim 126 | assert batch_dim == pairwise_state.shape[0] 127 | assert seq_dim == pairwise_state.shape[1] 128 | assert seq_dim == pairwise_state.shape[2] 129 | 130 | # Update sequence state 131 | bias = self.pair_to_sequence(pairwise_state) 132 | 133 | # Self attention with bias + mlp. 134 | y = self.layernorm_1(sequence_state) 135 | y, _ = self.seq_attention(y, mask=mask, bias=bias) 136 | sequence_state = sequence_state + self.drop(y) 137 | sequence_state = self.mlp_seq(sequence_state) 138 | 139 | # Update pairwise state 140 | pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state) 141 | 142 | # Axial attention with triangular bias. 143 | tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None 144 | pairwise_state = pairwise_state + self.row_drop( 145 | self.tri_mul_out(pairwise_state, mask=tri_mask) 146 | ) 147 | pairwise_state = pairwise_state + self.col_drop( 148 | self.tri_mul_in(pairwise_state, mask=tri_mask) 149 | ) 150 | pairwise_state = pairwise_state + self.row_drop( 151 | self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size) 152 | ) 153 | pairwise_state = pairwise_state + self.col_drop( 154 | self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size) 155 | ) 156 | 157 | # MLP over pairs. 158 | pairwise_state = self.mlp_pair(pairwise_state) 159 | 160 | return sequence_state, pairwise_state 161 | -------------------------------------------------------------------------------- /alphaflow/model/trunk.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import typing as T 6 | from functools import partial 7 | from openfold.utils.checkpointing import checkpoint_blocks 8 | import torch 9 | import torch.nn as nn 10 | from openfold.model.structure_module import StructureModule 11 | from .tri_self_attn_block import TriangularSelfAttentionBlock 12 | 13 | def get_axial_mask(mask): 14 | """ 15 | Helper to convert B x L mask of valid positions to axial mask used 16 | in row column attentions. 
17 | 18 | Input: 19 | mask: B x L tensor of booleans 20 | 21 | Output: 22 | mask: B x L x L tensor of booleans 23 | """ 24 | 25 | if mask is None: 26 | return None 27 | assert len(mask.shape) == 2 28 | batch_dim, seq_dim = mask.shape 29 | m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim) 30 | m = m.reshape(batch_dim * seq_dim, seq_dim) 31 | return m 32 | 33 | 34 | class RelativePosition(nn.Module): 35 | def __init__(self, bins, pairwise_state_dim): 36 | super().__init__() 37 | self.bins = bins 38 | 39 | # Note an additional offset is used so that the 0th position 40 | # is reserved for masked pairs. 41 | self.embedding = torch.nn.Embedding(2 * bins + 2, pairwise_state_dim) 42 | 43 | def forward(self, residue_index, mask=None): 44 | """ 45 | Input: 46 | residue_index: B x L tensor of indices (dytpe=torch.long) 47 | mask: B x L tensor of booleans 48 | 49 | Output: 50 | pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings 51 | """ 52 | 53 | assert residue_index.dtype == torch.long 54 | if mask is not None: 55 | assert residue_index.shape == mask.shape 56 | 57 | diff = residue_index[:, None, :] - residue_index[:, :, None] 58 | diff = diff.clamp(-self.bins, self.bins) 59 | diff = diff + self.bins + 1 # Add 1 to adjust for padding index. 60 | 61 | if mask is not None: 62 | mask = mask[:, None, :] * mask[:, :, None] 63 | diff[mask == False] = 0 64 | 65 | output = self.embedding(diff) 66 | return output 67 | 68 | 69 | class FoldingTrunk(nn.Module): 70 | def __init__(self, cfg): 71 | super().__init__() 72 | self.cfg = cfg 73 | # assert self.cfg.max_recycles > 0 74 | 75 | c_s = self.cfg.sequence_state_dim 76 | c_z = self.cfg.pairwise_state_dim 77 | 78 | assert c_s % self.cfg.sequence_head_width == 0 79 | assert c_z % self.cfg.pairwise_head_width == 0 80 | block = TriangularSelfAttentionBlock 81 | 82 | self.pairwise_positional_embedding = RelativePosition(self.cfg.position_bins, c_z) 83 | 84 | self.blocks = nn.ModuleList( 85 | [ 86 | block( 87 | sequence_state_dim=c_s, 88 | pairwise_state_dim=c_z, 89 | sequence_head_width=self.cfg.sequence_head_width, 90 | pairwise_head_width=self.cfg.pairwise_head_width, 91 | dropout=self.cfg.dropout, 92 | ) 93 | for i in range(self.cfg.num_blocks) 94 | ] 95 | ) 96 | 97 | self.recycle_bins = 15 98 | self.recycle_s_norm = nn.LayerNorm(c_s) 99 | self.recycle_z_norm = nn.LayerNorm(c_z) 100 | self.recycle_disto = nn.Embedding(self.recycle_bins, c_z) 101 | self.recycle_disto.weight[0].detach().zero_() 102 | 103 | 104 | self.structure_module = StructureModule(**self.cfg.structure_module) # type: ignore 105 | self.trunk2sm_s = nn.Linear(c_s, self.structure_module.c_s) 106 | self.trunk2sm_z = nn.Linear(c_z, self.structure_module.c_z) 107 | 108 | self.chunk_size = self.cfg.chunk_size 109 | 110 | def set_chunk_size(self, chunk_size): 111 | # This parameter means the axial attention will be computed 112 | # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2). 113 | # It's equivalent to running a for loop over chunks of the dimension we're iterative over, 114 | # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-lengthed chunks. 
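The chunk-size machinery mentioned in the comment above trades compute for memory by processing attention over the length dimension in slices. A toy illustration of the idea (not the trunk's actual kernels):

import torch
import torch.nn.functional as F

def chunked_attention(q, k, v, chunk_size=128):
    # Process queries a chunk at a time: each step materialises only a
    # (chunk_size x L) score block instead of the full (L x L) matrix.
    outs = []
    for start in range(0, q.shape[-2], chunk_size):
        scores = q[..., start:start + chunk_size, :] @ k.transpose(-1, -2)
        outs.append(F.softmax(scores, dim=-1) @ v)
    return torch.cat(outs, dim=-2)

q = k = v = torch.randn(1, 512, 32)
full = F.softmax(q @ k.transpose(-1, -2), dim=-1) @ v
assert torch.allclose(chunked_attention(q, k, v), full, atol=1e-5)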
115 | self.chunk_size = chunk_size 116 | 117 | 118 | def _prep_blocks(self, mask, residue_index, chunk_size): 119 | blocks = [ 120 | partial( 121 | b, 122 | mask=mask, 123 | residue_index=residue_index, 124 | chunk_size=chunk_size, 125 | ) 126 | for b in self.blocks 127 | ] 128 | return blocks 129 | def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles: T.Optional[int] = None): 130 | """ 131 | Inputs: 132 | seq_feats: B x L x C tensor of sequence features 133 | pair_feats: B x L x L x C tensor of pair features 134 | residx: B x L long tensor giving the position in the sequence 135 | mask: B x L boolean tensor indicating valid residues 136 | 137 | Output: 138 | predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object 139 | """ 140 | 141 | device = seq_feats.device 142 | s_s_0 = seq_feats 143 | s_z_0 = pair_feats 144 | 145 | 146 | def trunk_iter(s, z, residx, mask): 147 | z = z + self.pairwise_positional_embedding(residx, mask=mask) 148 | blocks = self._prep_blocks(mask=mask, residue_index=residx, chunk_size=self.chunk_size) 149 | 150 | s, z = checkpoint_blocks( 151 | blocks, 152 | args=(s, z), 153 | blocks_per_ckpt=1, 154 | ) 155 | return s, z 156 | 157 | 158 | s_s, s_z = trunk_iter(s_s_0, s_z_0, residx, mask) 159 | 160 | # === Structure module === 161 | structure = {} 162 | 163 | structure["sm"] = self.structure_module( 164 | {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)}, 165 | true_aa, 166 | mask.float(), 167 | ) 168 | 169 | 170 | 171 | assert isinstance(structure, dict) # type: ignore 172 | structure["s_s"] = s_s 173 | structure["s_z"] = s_z 174 | 175 | return structure 176 | 177 | @staticmethod 178 | def distogram(coords, min_bin, max_bin, num_bins): 179 | # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates. 180 | boundaries = torch.linspace( 181 | min_bin, 182 | max_bin, 183 | num_bins - 1, 184 | device=coords.device, 185 | ) 186 | boundaries = boundaries**2 187 | N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)] 188 | # Infer CB coordinates. 
189 | b = CA - N 190 | c = C - CA 191 | a = b.cross(c, dim=-1) 192 | CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA 193 | dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True) 194 | bins = torch.sum(dists > boundaries, dim=-1) # [..., L, L] 195 | return bins 196 | -------------------------------------------------------------------------------- /alphaflow/utils/diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | #https://github.com/scipy/scipy/blob/main/scipy/spatial/transform/_rotation.pyx 5 | def rmsdalign(a, b, weights=None): # alignes B to A # [*, N, 3] 6 | B = a.shape[:-2] 7 | N = a.shape[-2] 8 | if weights == None: 9 | weights = a.new_ones(*B, N) 10 | weights = weights.unsqueeze(-1) 11 | a_mean = (a * weights).sum(-2, keepdims=True) / weights.sum(-2, keepdims=True) 12 | a = a - a_mean 13 | b_mean = (b * weights).sum(-2, keepdims=True) / weights.sum(-2, keepdims=True) 14 | b = b - b_mean 15 | B = torch.einsum('...ji,...jk->...ik', weights * a, b) 16 | u, s, vh = torch.linalg.svd(B) 17 | 18 | # Correct improper rotation if necessary (as in Kabsch algorithm) 19 | ''' 20 | if torch.linalg.det(u @ vh) < 0: 21 | s[-1] = -s[-1] 22 | u[:, -1] = -u[:, -1] 23 | ''' 24 | sgn = torch.sign(torch.linalg.det(u @ vh)) 25 | s[...,-1] *= sgn 26 | u[...,:,-1] *= sgn.unsqueeze(-1) 27 | C = u @ vh # c rotates B to A 28 | return b @ C.mT + a_mean 29 | 30 | def kabsch_rmsd(a, b, weights=None): 31 | B = a.shape[:-2] 32 | N = a.shape[-2] 33 | if weights == None: 34 | weights = a.new_ones(*B, N) 35 | b_aligned = rmsdalign(a, b, weights) 36 | out = torch.square(b_aligned - a).sum(-1) 37 | out = (out * weights).sum(-1) / weights.sum(-1) 38 | return torch.sqrt(out) 39 | 40 | class HarmonicPrior: 41 | def __init__(self, N = 256, a =3/(3.8**2)): 42 | J = torch.zeros(N, N) 43 | for i, j in zip(np.arange(N-1), np.arange(1, N)): 44 | J[i,i] += a 45 | J[j,j] += a 46 | J[i,j] = J[j,i] = -a 47 | D, P = torch.linalg.eigh(J) 48 | D_inv = 1/D 49 | D_inv[0] = 0 50 | self.P, self.D_inv = P, D_inv 51 | self.N = N 52 | 53 | def to(self, device): 54 | self.P = self.P.to(device) 55 | self.D_inv = self.D_inv.to(device) 56 | 57 | def sample(self, batch_dims=()): 58 | return self.P @ (torch.sqrt(self.D_inv)[:,None] * torch.randn(*batch_dims, self.N, 3, device=self.P.device)) 59 | 60 | 61 | ''' 62 | def transition_matrix(N_bins=1000, X_max=5): 63 | bins = torch.linspace(0, X_max, N_bins+1, dtype=torch.float64) 64 | cbins = (bins[1:] + bins[:-1]) / 2 65 | bw = cbins[1] - cbins[0] 66 | mu = 2 / cbins - cbins 67 | idx = torch.arange(N_bins) 68 | mat = torch.zeros((N_bins, N_bins), dtype=torch.float64) 69 | mat[idx, idx] = -2 / bw**2 70 | 71 | mat[idx[1:], idx[:-1]] = mu[idx[:-1]] / 2 / bw + 1 / bw**2 # M_{i+1,i} = -mu[i]/2 72 | mat[idx[:-1], idx[1:]] = -mu[idx[1:]] / 2 / bw + 1 / bw**2 # M_{i+1,i} = mu[i]/2 73 | mat[idx, idx] -= mat.sum(0) # fix edges 74 | 75 | return mat, bins 76 | 77 | _mat, _bins = transition_matrix() 78 | _D, _Q = torch.linalg.eig(_mat) 79 | _Q_inv = torch.linalg.inv(_Q) 80 | _sigmas = torch.from_numpy(np.load('chain_stats.npy')) 81 | 82 | 83 | def add_noise(dists, residue_index, mask, t, device='cpu'): 84 | 85 | sigmas, Q, D, Q_inv, bins = _sigmas.to(device), _Q.to(device), _D.to(device), _Q_inv.to(device), _bins.to(device) 86 | 87 | mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) 88 | # dists = torch.sum((pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2, dim=-1)**0.5 
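The HarmonicPrior defined above samples residue positions from a Gaussian chain whose precision matrix couples neighbouring residues, with the translational zero mode removed. A short usage sketch; the length and batch shape are arbitrary:

import torch
from alphaflow.utils.diffusion import HarmonicPrior

prior = HarmonicPrior(N=128)        # 128 residues, default coupling a = 3 / 3.8**2
x = prior.sample(batch_dims=(4,))   # (4, 128, 3) chain-like noise samples
# The zero eigenvalue's variance is set to 0, so every sample is centred at the origin.
print(x.shape, x.mean(dim=-2).norm(dim=-1).max())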
89 | offsets = torch.abs(residue_index.unsqueeze(-1) - residue_index.unsqueeze(-2)) 90 | sigmas = sigmas[offsets] 91 | ndists = dists / sigmas * mask 92 | 93 | bindists = (ndists.unsqueeze(-1) > bins).sum(-1) 94 | bindists = torch.clamp(bindists, 0, 999) 95 | 96 | P = ((Q*torch.exp(D*t)) @ Q_inv).T # now we have a row stochatic matrix P_ij = P(i -> j) 97 | probs = P.real[bindists] # this is equivalent to left multiplication by basis e_i 98 | 99 | probs = torch.clamp(probs / probs.sum(-1, keepdims=True), 0, 1) 100 | newbindists = Categorical(probs, validate_args=False).sample() 101 | cbins = (bins[1:] + bins[:-1]) / 2 102 | newdists = cbins[newbindists] * mask * sigmas 103 | 104 | return newdists.float() 105 | 106 | def sample_posterior(orig_dists, noisy_dists, residue_index, mask, s, t, device='cpu'): 107 | sigmas, Q, D, Q_inv, bins = _sigmas.to(device), _Q.to(device), _D.to(device), _Q_inv.to(device), _bins.to(device) 108 | mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) 109 | 110 | P_0s = ((Q*torch.exp(D*s)) @ Q_inv).T 111 | P_st = ((Q*torch.exp(D*(t-s))) @ Q_inv).T 112 | 113 | offsets = torch.abs(residue_index.unsqueeze(-1) - residue_index.unsqueeze(-2)) 114 | sigmas = sigmas[offsets] 115 | 116 | orig_ndists = orig_dists / sigmas * mask 117 | orig_bindists = (orig_ndists.unsqueeze(-1) > bins).sum(-1) 118 | orig_bindists = torch.clamp(orig_bindists, 0, 999) 119 | 120 | noisy_ndists = noisy_dists / sigmas * mask 121 | noisy_bindists = (noisy_ndists.unsqueeze(-1) > bins).sum(-1) 122 | noisy_bindists = torch.clamp(noisy_bindists, 0, 999) 123 | 124 | probs = P_0s.real[orig_bindists] * P_st.T.real[noisy_bindists] 125 | probs = torch.clamp(probs / probs.sum(-1, keepdims=True), 0, 1) 126 | newbindists = Categorical(probs, validate_args=False).sample() 127 | cbins = (bins[1:] + bins[:-1]) / 2 128 | newdists = cbins[newbindists] * mask * sigmas 129 | 130 | return newdists.float() 131 | 132 | def sample_prior(residue_index, device='cpu'): 133 | 134 | sigmas, Q, D, Q_inv, bins = _sigmas.to(device), _Q.to(device), _D.to(device), _Q_inv.to(device), _bins.to(device) 135 | B, L = residue_index.shape 136 | probs = Q[:,D.real.argmax()].real 137 | probs = torch.clamp(probs / probs.sum(-1, keepdims=True), 0, 1).broadcast_to(B, L, L, 1000) 138 | 139 | offsets = torch.abs(residue_index.unsqueeze(-1) - residue_index.unsqueeze(-2)) 140 | sigmas = sigmas[offsets] 141 | 142 | newbindists = Categorical(probs, validate_args=False).sample() 143 | 144 | cbins = (bins[1:] + bins[:-1]) / 2 145 | newdists = cbins[newbindists] * sigmas 146 | return newdists.float() 147 | ''' 148 | -------------------------------------------------------------------------------- /alphaflow/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging, socket, os 2 | 3 | model_dir = os.environ.get("MODEL_DIR", "./workdir/default") 4 | 5 | class Rank(logging.Filter): 6 | def filter(self, record): 7 | record.global_rank = os.environ.get("GLOBAL_RANK", 0) 8 | record.local_rank = os.environ.get("LOCAL_RANK", 0) 9 | return True 10 | 11 | 12 | def get_logger(name): 13 | logger = logging.Logger(name) 14 | # logger.addFilter(Rank()) 15 | level = {"crititical": 50, "error": 40, "warning": 30, "info": 20, "debug": 10}[ 16 | os.environ.get("LOGGER_LEVEL", "info") 17 | ] 18 | logger.setLevel(level) 19 | 20 | ch = logging.StreamHandler() 21 | ch.setLevel(logging.INFO) 22 | os.makedirs(model_dir, exist_ok=True) 23 | fh = logging.FileHandler(os.path.join(model_dir, "log.out")) 24 | 
fh.setLevel(logging.DEBUG) 25 | # formatter = logging.Formatter(f'%(asctime)s [{socket.gethostname()}:%(process)d:%(global_rank)s:%(local_rank)s] 26 | # [%(levelname)s] %(message)s') # (%(name)s) 27 | formatter = logging.Formatter( 28 | f"%(asctime)s [{socket.gethostname()}:%(process)d] [%(levelname)s] %(message)s" 29 | ) 30 | ch.setFormatter(formatter) 31 | fh.setFormatter(formatter) 32 | logger.addHandler(ch) 33 | logger.addHandler(fh) 34 | return logger 35 | 36 | def init(): 37 | if bool(int(os.environ.get("WANDB_LOGGING", "0"))): 38 | os.makedirs(model_dir, exist_ok=True) 39 | out_file = open(os.path.join(model_dir, "std.out"), 'ab') 40 | os.dup2(out_file.fileno(), 1) 41 | os.dup2(out_file.fileno(), 2) 42 | 43 | init() -------------------------------------------------------------------------------- /alphaflow/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | import typing as T 6 | import torch 7 | from openfold.np import residue_constants, protein 8 | from openfold.utils.feats import atom14_to_atom37 9 | 10 | def encode_sequence( 11 | seq: str, 12 | residue_index_offset: T.Optional[int] = 512, 13 | chain_linker: T.Optional[str] = "G" * 25, 14 | ) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 15 | if chain_linker is None: 16 | chain_linker = "" 17 | if residue_index_offset is None: 18 | residue_index_offset = 0 19 | 20 | chains = seq.split(":") 21 | seq = chain_linker.join(chains) 22 | 23 | unk_idx = residue_constants.restype_order_with_x["X"] 24 | encoded = torch.tensor( 25 | [residue_constants.restype_order_with_x.get(aa, unk_idx) for aa in seq] 26 | ) 27 | residx = torch.arange(len(encoded)) 28 | 29 | if residue_index_offset > 0: 30 | start = 0 31 | for i, chain in enumerate(chains): 32 | residx[start : start + len(chain) + len(chain_linker)] += ( 33 | i * residue_index_offset 34 | ) 35 | start += len(chain) + len(chain_linker) 36 | 37 | linker_mask = torch.ones_like(encoded, dtype=torch.float32) 38 | chain_index = [] 39 | offset = 0 40 | for i, chain in enumerate(chains): 41 | if i > 0: 42 | chain_index.extend([i - 1] * len(chain_linker)) 43 | chain_index.extend([i] * len(chain)) 44 | offset += len(chain) 45 | linker_mask[offset : offset + len(chain_linker)] = 0 46 | offset += len(chain_linker) 47 | 48 | chain_index = torch.tensor(chain_index, dtype=torch.int64) 49 | 50 | return encoded, residx, linker_mask, chain_index 51 | 52 | 53 | def batch_encode_sequences( 54 | sequences: T.Sequence[str], 55 | residue_index_offset: T.Optional[int] = 512, 56 | chain_linker: T.Optional[str] = "G" * 25, 57 | ) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 58 | 59 | aatype_list = [] 60 | residx_list = [] 61 | linker_mask_list = [] 62 | chain_index_list = [] 63 | for seq in sequences: 64 | aatype_seq, residx_seq, linker_mask_seq, chain_index_seq = encode_sequence( 65 | seq, 66 | residue_index_offset=residue_index_offset, 67 | chain_linker=chain_linker, 68 | ) 69 | aatype_list.append(aatype_seq) 70 | residx_list.append(residx_seq) 71 | linker_mask_list.append(linker_mask_seq) 72 | chain_index_list.append(chain_index_seq) 73 | 74 | aatype = collate_dense_tensors(aatype_list) 75 | mask = collate_dense_tensors( 76 | [aatype.new_ones(len(aatype_seq)) for aatype_seq in aatype_list] 77 | ) 78 | residx = 
collate_dense_tensors(residx_list) 79 | linker_mask = collate_dense_tensors(linker_mask_list) 80 | chain_index_list = collate_dense_tensors(chain_index_list, -1) 81 | 82 | return aatype, mask, residx, linker_mask, chain_index_list 83 | 84 | 85 | def output_to_pdb(output: T.Dict) -> T.List[str]: 86 | """Returns the pbd (file) string from the model given the model output.""" 87 | # atom14_to_atom37 must be called first, as it fails on latest numpy if the 88 | # input is a numpy array. It will work if the input is a torch tensor. 89 | final_atom_positions = atom14_to_atom37(output["positions"][-1], output) 90 | output = {k: v.to("cpu").numpy() for k, v in output.items()} 91 | final_atom_positions = final_atom_positions.cpu().numpy() 92 | final_atom_mask = output["atom37_atom_exists"] 93 | pdbs = [] 94 | for i in range(output["aatype"].shape[0]): 95 | aa = output["aatype"][i] 96 | pred_pos = final_atom_positions[i] 97 | mask = final_atom_mask[i] 98 | resid = output["residue_index"][i] + 1 99 | pred = protein.Protein( 100 | aatype=aa, 101 | atom_positions=pred_pos, 102 | atom_mask=mask, 103 | residue_index=resid, 104 | b_factors=output["plddt"][i], 105 | chain_index=output["chain_index"][i] if "chain_index" in output else None, 106 | ) 107 | pdbs.append(pred) 108 | return pdbs 109 | 110 | 111 | def collate_dense_tensors( 112 | samples: T.List[torch.Tensor], pad_v: float = 0 113 | ) -> torch.Tensor: 114 | """ 115 | Takes a list of tensors with the following dimensions: 116 | [(d_11, ..., d_1K), 117 | (d_21, ..., d_2K), 118 | ..., 119 | (d_N1, ..., d_NK)] 120 | and stack + pads them into a single tensor of: 121 | (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK}) 122 | """ 123 | if len(samples) == 0: 124 | return torch.Tensor() 125 | if len(set(x.dim() for x in samples)) != 1: 126 | raise RuntimeError( 127 | f"Samples has varying dimensions: {[x.dim() for x in samples]}" 128 | ) 129 | (device,) = tuple(set(x.device for x in samples)) # assumes all on same device 130 | max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])] 131 | result = torch.empty( 132 | len(samples), *max_shape, dtype=samples[0].dtype, device=device 133 | ) 134 | result.fill_(pad_v) 135 | for i in range(len(samples)): 136 | result_i = result[i] 137 | t = samples[i] 138 | result_i[tuple(slice(0, k) for k in t.shape)] = t 139 | return result 140 | 141 | 142 | 143 | class CategoricalMixture: 144 | def __init__(self, param, bins=50, start=0, end=1): 145 | # All tensors are of shape ..., bins. 146 | self.logits = param 147 | bins = torch.linspace( 148 | start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype 149 | ) 150 | self.v_bins = (bins[:-1] + bins[1:]) / 2 151 | 152 | def log_prob(self, true): 153 | # Shapes are: 154 | # self.probs: ... x bins 155 | # true : ... 156 | true_index = ( 157 | ( 158 | true.unsqueeze(-1) 159 | - self.v_bins[ 160 | [ 161 | None, 162 | ] 163 | * true.ndim 164 | ] 165 | ) 166 | .abs() 167 | .argmin(-1) 168 | ) 169 | nll = self.logits.log_softmax(-1) 170 | return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1) 171 | 172 | def mean(self): 173 | return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1) 174 | 175 | 176 | def categorical_lddt(logits, bins=50): 177 | # Logits are ..., 37, bins. 
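The CategoricalMixture above turns per-bin logits into an expected value over bin centres, which is how the per-residue pLDDT is read out; a tiny worked version of that reduction, assuming 50 bins on [0, 1]:

import torch

logits = torch.zeros(1, 3, 37, 50)   # (..., atoms, bins) dummy lDDT-head logits
logits[..., -1] = 10.0               # put nearly all probability mass in the top bin
bins = torch.linspace(0, 1, 51)
v_bins = (bins[:-1] + bins[1:]) / 2  # bin centres, as in CategoricalMixture
plddt = logits.softmax(-1) @ v_bins  # expected bin centre, close to 0.99 everywhere
print(plddt.shape, plddt.mean())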
178 | return CategoricalMixture(logits, bins=bins).mean() 179 | 180 | 181 | 182 | 183 | -------------------------------------------------------------------------------- /alphaflow/utils/parsing.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import subprocess, os 3 | 4 | 5 | def parse_train_args(): 6 | parser = ArgumentParser() 7 | 8 | parser.add_argument("--mode", choices=['esmfold', 'alphafold'], default='alphafold') 9 | 10 | ## Trainer settings 11 | parser.add_argument("--ckpt", type=str, default=None) 12 | parser.add_argument("--restore_weights_only", action='store_true') 13 | parser.add_argument("--validate", action='store_true', default=False) 14 | 15 | ## Epoch settings 16 | parser.add_argument("--epochs", type=int, default=100) 17 | parser.add_argument("--train_epoch_len", type=int, default=40000) 18 | parser.add_argument("--limit_batches", type=int, default=None) 19 | parser.add_argument("--batch_size", type=int, default=1) 20 | 21 | ## Optimization settings 22 | parser.add_argument("--num_workers", type=int, default=8) 23 | parser.add_argument("--check_grad", action="store_true") 24 | parser.add_argument("--accumulate_grad", type=int, default=1) 25 | parser.add_argument("--grad_clip", type=float, default=1.) 26 | parser.add_argument("--lr", type=float, default=1e-3) 27 | parser.add_argument("--no_ema", action='store_true') 28 | 29 | ## Training data 30 | parser.add_argument("--train_data_dir", type=str, default='./data') 31 | parser.add_argument("--pdb_chains", type=str, default='./pdb_chains_msa.csv') 32 | parser.add_argument("--train_msa_dir", type=str, default='./msa_dir') 33 | parser.add_argument("--pdb_clusters", type=str, default='./pdb_clusters') 34 | parser.add_argument("--train_cutoff", type=str, default='2021-10-01') 35 | parser.add_argument("--mmcif_dir", type=str, default='./mmcif_dir') 36 | parser.add_argument("--filter_chains", action='store_true') 37 | parser.add_argument("--sample_train_confs", action='store_true') 38 | 39 | ## Validation data 40 | parser.add_argument("--val_csv", type=str, default='splits/cameo2022.csv') 41 | parser.add_argument("--val_samples", type=int, default=5) 42 | parser.add_argument("--val_msa_dir", type=str, default='./alignment_dir') 43 | parser.add_argument("--sample_val_confs", action='store_true') 44 | parser.add_argument("--num_val_confs", type=int, default=None) 45 | parser.add_argument("--normal_validate", action='store_true') 46 | 47 | ## Flow matching 48 | parser.add_argument("--noise_prob", type=float, default=0.5) 49 | parser.add_argument("--self_cond_prob", type=float, default=0.5) 50 | parser.add_argument("--extra_input", action='store_true') 51 | parser.add_argument("--extra_input_prob", type=float, default=0.5) 52 | parser.add_argument("--first_as_template", action='store_true') 53 | parser.add_argument("--distillation", action='store_true') 54 | parser.add_argument("--distill_self_cond", action='store_true') 55 | 56 | ## Logging args 57 | parser.add_argument("--print_freq", type=int, default=100) 58 | parser.add_argument("--val_freq", type=int, default=1) 59 | parser.add_argument("--ckpt_freq", type=int, default=1) 60 | parser.add_argument("--wandb", action="store_true") 61 | parser.add_argument("--run_name", type=str, default="default") 62 | 63 | args = parser.parse_args() 64 | os.environ["MODEL_DIR"] = os.path.join("workdir", args.run_name) 65 | os.environ["WANDB_LOGGING"] = str(int(args.wandb)) 66 | if args.wandb: 67 | if 
subprocess.check_output(["git", "status", "-s"]): 68 | exit() 69 | args.commit = ( 70 | subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip() 71 | ) 72 | 73 | return args 74 | -------------------------------------------------------------------------------- /alphaflow/utils/protein.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Optional 2 | from openfold.np.protein import to_pdb 3 | from openfold.np.protein import from_pdb_string as _from_pdb_string 4 | from openfold.data import mmcif_parsing 5 | from openfold.np import residue_constants 6 | from alphaflow.utils.tensor_utils import tensor_tree_map 7 | import subprocess, tempfile, os, dataclasses 8 | import numpy as np 9 | from Bio import pairwise2 10 | 11 | @dataclasses.dataclass(repr=False) 12 | class Protein: 13 | """Protein structure representation.""" 14 | 15 | # Cartesian coordinates of atoms in angstroms. The atom types correspond to 16 | # residue_constants.atom_types, i.e. the first three are N, CA, CB. 17 | atom_positions: np.ndarray # [num_res, num_atom_type, 3] 18 | 19 | aatype: np.ndarray # [num_res] 20 | seqres: str 21 | name: str 22 | 23 | # Binary float mask to indicate presence of a particular atom. 1.0 if an atom 24 | # is present and 0.0 if not. This should be used for loss masking. 25 | atom_mask: np.ndarray # [num_res, num_atom_type] 26 | 27 | # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. 28 | residue_index: np.ndarray # [num_res] 29 | 30 | # B-factors, or temperature factors, of each residue (in sq. angstroms units), 31 | # representing the displacement of the residue from its ground truth mean 32 | # value. 33 | b_factors: np.ndarray # [num_res, num_atom_type] 34 | 35 | # Chain indices for multi-chain predictions 36 | chain_index: Optional[np.ndarray] = None 37 | 38 | # Optional remark about the protein. 
Included as a comment in output PDB 39 | # files 40 | remark: Optional[str] = None 41 | 42 | # Templates used to generate this protein (prediction-only) 43 | parents: Optional[Sequence[str]] = None 44 | 45 | # Chain corresponding to each parent 46 | parents_chain_index: Optional[Sequence[int]] = None 47 | 48 | def __repr__(self): 49 | ca_pos = residue_constants.atom_order["CA"] 50 | present = int(self.atom_mask[..., ca_pos].sum()) 51 | total = self.atom_mask.shape[0] 52 | return f"Protein(name={self.name} seqres={self.seqres} residues={present}/{total} b_mean={self.b_factors[...,ca_pos].mean()})" 53 | 54 | def present(self): 55 | ca_pos = residue_constants.atom_order["CA"] 56 | return int(self.atom_mask[..., ca_pos].sum()) 57 | 58 | def total(self): 59 | return self.atom_mask.shape[0] 60 | 61 | def output_to_protein(output): 62 | """Returns the pbd (file) string from the model given the model output.""" 63 | output = tensor_tree_map(lambda x: x.cpu().numpy(), output) 64 | final_atom_positions = output['final_atom_positions'] 65 | final_atom_mask = output["atom37_atom_exists"] 66 | pdbs = [] 67 | for i in range(output["aatype"].shape[0]): 68 | unk_idx = residue_constants.restype_order_with_x["X"] 69 | seqres = ''.join( 70 | [residue_constants.restypes[idx] if idx != unk_idx else "X" for idx in output["aatype"][i]] 71 | ) 72 | pred = Protein( 73 | name=output['name'][i], 74 | aatype=output["aatype"][i], 75 | seqres=seqres, 76 | atom_positions=final_atom_positions[i], 77 | atom_mask=final_atom_mask[i], 78 | residue_index=output["residue_index"][i] + 1, 79 | b_factors=np.repeat(output["plddt"][i][...,None], residue_constants.atom_type_num, axis=-1), 80 | chain_index=output["chain_index"][i] if "chain_index" in output else None, 81 | ) 82 | pdbs.append(pred) 83 | return pdbs 84 | 85 | def from_dict(prot): 86 | name = prot['domain_name'].item().decode(encoding='utf-8') 87 | seq = prot['sequence'].item().decode(encoding='utf-8') 88 | return Protein( 89 | name=name, 90 | aatype=np.nonzero(prot["aatype"])[1], 91 | atom_positions=prot["all_atom_positions"], 92 | seqres=seq, 93 | atom_mask=prot["all_atom_mask"], 94 | residue_index=prot['residue_index'] + 1, 95 | b_factors=np.zeros((len(seq), 37)) 96 | ) 97 | 98 | def from_pdb_string(pdb_string, name=''): 99 | prot = _from_pdb_string(pdb_string) 100 | 101 | unk_idx = residue_constants.restype_order_with_x["X"] 102 | seqres = ''.join( 103 | [residue_constants.restypes[idx] if idx != unk_idx else "X" for idx in prot.aatype] 104 | ) 105 | prot = Protein( 106 | **prot.__dict__, 107 | seqres=seqres, 108 | name=name, 109 | ) 110 | return prot 111 | 112 | def from_mmcif_string(mmcif_string, chain, name='', is_author_chain=False): 113 | mmcif_object = mmcif_parsing.parse(file_id = '', mmcif_string=mmcif_string) 114 | 115 | if(mmcif_object.mmcif_object is None): 116 | raise list(mmcif_object.errors.values())[0] 117 | 118 | mmcif_object = mmcif_object.mmcif_object 119 | 120 | atom_coords, atom_mask = mmcif_parsing.get_atom_coords(mmcif_object, chain) 121 | L = atom_coords.shape[0] 122 | seq = mmcif_object.chain_to_seqres[chain] 123 | 124 | unk_idx = residue_constants.restype_order_with_x["X"] 125 | aatype = np.array( 126 | [residue_constants.restype_order_with_x.get(aa, unk_idx) for aa in seq] 127 | ) 128 | prot = Protein( 129 | aatype=aatype, 130 | name=name, 131 | seqres=seq, 132 | atom_positions=atom_coords, 133 | atom_mask=atom_mask, 134 | residue_index=np.arange(L) + 1, 135 | b_factors=np.zeros((L, 37)), # maybe replace 37 later 136 | ) 137 | return prot 
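# Illustrative usage sketch for the helpers in this module. The file names
# 'ref.pdb' and 'pred.pdb' are hypothetical, and global_metrics() below
# additionally assumes the external TMscore binary is on PATH (plus the lddt
# binary when lddt=True is requested).
#
#   ref = from_pdb_string(open('ref.pdb').read(), name='ref')
#   pred = from_pdb_string(open('pred.pdb').read(), name='pred')
#   metrics = global_metrics(ref, pred)   # {'rmsd': ..., 'tm': ..., 'gdt_ts': ..., 'gdt_ha': ...}
#   with open('ensemble.pdb', 'w') as f:
#       f.write(prots_to_pdb([ref, pred]))   # multi-model PDB string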
138 | 139 | 140 | def global_metrics(ref_prot, pred_prot, lddt=False, symmetric=False): 141 | if lddt or symmetric: 142 | ref_prot, pred_prot = align_residue_numbering(ref_prot, pred_prot, mask=symmetric) 143 | 144 | f, ref_path = tempfile.mkstemp(); os.close(f) 145 | f, pred_path = tempfile.mkstemp(); os.close(f) 146 | with open(ref_path, 'w') as f: 147 | f.write(to_pdb(ref_prot)) 148 | with open(pred_path, 'w') as f: 149 | f.write(to_pdb(pred_prot)) 150 | 151 | out = tmscore(ref_path, pred_path) 152 | if lddt: 153 | out['lddt'] = my_lddt_func(ref_path, pred_path) 154 | 155 | os.unlink(ref_path) 156 | os.unlink(pred_path) 157 | return out 158 | 159 | def prots_to_pdb(prots): 160 | ss = '' 161 | for i, prot in enumerate(prots): 162 | ss += f'MODEL {i}\n' 163 | prot = to_pdb(prot) 164 | ss += '\n'.join(prot.split('\n')[1:-2]) 165 | ss += '\nENDMDL\n' 166 | return ss 167 | 168 | def align_residue_numbering(prot1, prot2, mask=False): 169 | prot1 = Protein(**prot1.__dict__) 170 | prot2 = Protein(**prot2.__dict__) 171 | 172 | alignment = pairwise2.align.globalxx(prot1.seqres, prot2.seqres)[0] 173 | prot1.residue_index = np.array([i for i, c in enumerate(alignment.seqA) if c != '-']) 174 | prot2.residue_index = np.array([i for i, c in enumerate(alignment.seqB) if c != '-']) 175 | 176 | if mask: 177 | ca_pos = residue_constants.atom_order["CA"] 178 | mask1 = np.zeros(len(alignment.seqA)) 179 | mask1[prot1.residue_index[prot1.atom_mask[..., ca_pos] == 1]] = 1 180 | mask2 = np.zeros(len(alignment.seqA)) 181 | mask2[prot2.residue_index[prot2.atom_mask[..., ca_pos] == 1]] = 1 182 | 183 | mask = (mask1 == 1) & (mask2 == 1) 184 | 185 | prot1.atom_mask = prot1.atom_mask * mask[prot1.residue_index].reshape(-1, 1) 186 | prot2.atom_mask = prot2.atom_mask * mask[prot2.residue_index].reshape(-1, 1) 187 | 188 | return prot1, prot2 189 | # ca_pos = residue_constants.atom_order["CA"] 190 | # ref_ca = ref_prot.atom_positions[..., ca_pos, :] 191 | # pred_ca = pred_prot.atom_positions[...,ca_pos, :] 192 | # mask = ref_prot.atom_mask[..., ca_pos].astype(bool) 193 | # trans_ca, rms = superimposition._superimpose_np(ref_ca[mask], pred_ca[mask]) 194 | 195 | def tmscore(ref_path, pred_path): 196 | 197 | 198 | out = subprocess.check_output(['TMscore', '-seq', pred_path, ref_path], 199 | stderr=open('/dev/null', 'w')) 200 | 201 | start = out.find(b'RMSD') 202 | end = out.find(b'rotation') 203 | out = out[start:end] 204 | 205 | rmsd, _, tm, _, gdt_ts, gdt_ha, _, _ = out.split(b'\n') 206 | 207 | result = { 208 | 'rmsd': float(rmsd.split(b'=')[-1]), 209 | 'tm': float(tm.split(b'=')[1].split()[0]), 210 | 'gdt_ts': float(gdt_ts.split(b'=')[1].split()[0]), 211 | 'gdt_ha': float(gdt_ha.split(b'=')[1].split()[0]), 212 | } 213 | return result 214 | 215 | def drmsd(prot1, prot2, align=False, eps=1e-10): 216 | ca_pos = residue_constants.atom_order["CA"] 217 | if align: 218 | prot1, prot2 = align_residue_numbering(prot1, prot2) 219 | N = max(prot1.residue_index.max(), prot2.residue_index.max()) 220 | mask1, mask2 = np.zeros(N), np.zeros(N) 221 | mask1[prot1.residue_index - 1] = prot1.atom_mask[:,ca_pos] 222 | mask2[prot2.residue_index - 1] = prot2.atom_mask[:,ca_pos] 223 | pos1, pos2 = np.zeros((N,3)), np.zeros((N,3)) 224 | 225 | pos1[prot1.residue_index - 1] = prot1.atom_positions[:,ca_pos] 226 | pos2[prot2.residue_index - 1] = prot2.atom_positions[:,ca_pos] 227 | 228 | dmat1 = np.sqrt(eps + np.sum((pos1[..., None, :] - pos1[..., None, :, :]) ** 2, axis=-1)) 229 | dmat2 = np.sqrt(eps + np.sum((pos2[..., None, :] - pos2[..., None, 
:, :]) ** 2, axis=-1)) 230 | 231 | dists_to_score = mask1 * mask1[:,None] * mask2 * mask2[:,None] * (1.0 - np.eye(N)) 232 | score = np.square(dmat1 - dmat2) 233 | 234 | return np.sqrt((score * dists_to_score).sum() / dists_to_score.sum()) 235 | 236 | def lddt_ca(prot1, prot2, cutoff=15.0, eps=1e-10, align=False, per_residue=False, symmetric=False): 237 | 238 | ca_pos = residue_constants.atom_order["CA"] 239 | 240 | if align: 241 | prot1, prot2 = align_residue_numbering(prot1, prot2) 242 | N = max(prot1.residue_index.max(), prot2.residue_index.max()) 243 | mask1, mask2 = np.zeros(N), np.zeros(N) 244 | mask1[prot1.residue_index - 1] = prot1.atom_mask[:,ca_pos] 245 | mask2[prot2.residue_index - 1] = prot2.atom_mask[:,ca_pos] 246 | pos1, pos2 = np.zeros((N,3)), np.zeros((N,3)) 247 | 248 | pos1[prot1.residue_index - 1] = prot1.atom_positions[:,ca_pos] 249 | pos2[prot2.residue_index - 1] = prot2.atom_positions[:,ca_pos] 250 | 251 | # mask1, mask2 = prot1.atom_mask[:,ca_pos], prot2.atom_mask[:,ca_pos] 252 | # pos1, pos2 = prot1.atom_positions[:,ca_pos], prot2.atom_positions[:,ca_pos] 253 | 254 | dmat1 = np.sqrt(eps + np.sum((pos1[..., None, :] - pos1[..., None, :, :]) ** 2, axis=-1)) 255 | dmat2 = np.sqrt(eps + np.sum((pos2[..., None, :] - pos2[..., None, :, :]) ** 2, axis=-1)) 256 | 257 | if symmetric: 258 | dists_to_score = (dmat1 < cutoff) | (dmat2 < cutoff) 259 | else: 260 | dists_to_score = (dmat1 < cutoff) 261 | dists_to_score = dists_to_score * mask1 * mask1[:,None] * mask2 * mask2[:,None] * (1.0 - np.eye(N)) 262 | dist_l1 = np.abs(dmat1 - dmat2) 263 | score = (dist_l1[...,None] < np.array([0.5, 1.0, 2.0, 4.0])).mean(-1) 264 | 265 | if per_residue: 266 | score = (dists_to_score * score).sum(-1) / dists_to_score.sum(-1) 267 | return score[prot1.residue_index - 1] 268 | else: 269 | score = (dists_to_score * score).sum() / dists_to_score.sum() 270 | 271 | return score 272 | def my_lddt_func(ref_path, pred_path): 273 | 274 | out = subprocess.check_output(['lddt', '-c', '-x', pred_path, ref_path], 275 | stderr=open('/dev/null', 'w')) 276 | 277 | result = None 278 | for line in out.split(b'\n'): 279 | if b'Global LDDT score' in line: 280 | result = float(line.split(b':')[-1].strip()) 281 | 282 | return result 283 | -------------------------------------------------------------------------------- /alphaflow/utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from functools import partial 17 | from typing import List 18 | 19 | import torch 20 | import torch.nn as nn 21 | 22 | 23 | def add(m1, m2, inplace): 24 | # The first operation in a checkpoint can't be in-place, but it's 25 | # nice to have in-place addition during inference. Thus... 
26 | if(not inplace): 27 | m1 = m1 + m2 28 | else: 29 | m1 += m2 30 | 31 | return m1 32 | 33 | 34 | def permute_final_dims(tensor: torch.Tensor, inds: List[int]): 35 | zero_index = -1 * len(inds) 36 | first_inds = list(range(len(tensor.shape[:zero_index]))) 37 | return tensor.permute(first_inds + [zero_index + i for i in inds]) 38 | 39 | 40 | def flatten_final_dims(t: torch.Tensor, no_dims: int): 41 | return t.reshape(t.shape[:-no_dims] + (-1,)) 42 | 43 | 44 | def masked_mean(mask, value, dim, eps=1e-4): 45 | mask = mask.expand(*value.shape) 46 | return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) 47 | 48 | 49 | def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): 50 | boundaries = torch.linspace( 51 | min_bin, max_bin, no_bins - 1, device=pts.device 52 | ) 53 | dists = torch.sqrt( 54 | torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1) 55 | ) 56 | return torch.bucketize(dists, boundaries) 57 | 58 | 59 | def dict_multimap(fn, dicts): 60 | first = dicts[0] 61 | new_dict = {} 62 | for k, v in first.items(): 63 | all_v = [d[k] for d in dicts] 64 | if type(v) is dict: 65 | new_dict[k] = dict_multimap(fn, all_v) 66 | else: 67 | new_dict[k] = fn(all_v) 68 | 69 | return new_dict 70 | 71 | 72 | def one_hot(x, v_bins): 73 | reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) 74 | diffs = x[..., None] - reshaped_bins 75 | am = torch.argmin(torch.abs(diffs), dim=-1) 76 | return nn.functional.one_hot(am, num_classes=len(v_bins)).float() 77 | 78 | 79 | def batched_gather(data, inds, dim=0, no_batch_dims=0): 80 | ranges = [] 81 | for i, s in enumerate(data.shape[:no_batch_dims]): 82 | r = torch.arange(s) 83 | r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) 84 | ranges.append(r) 85 | 86 | remaining_dims = [ 87 | slice(None) for _ in range(len(data.shape) - no_batch_dims) 88 | ] 89 | remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds 90 | ranges.extend(remaining_dims) 91 | return data[ranges] 92 | 93 | 94 | # With tree_map, a poor man's JAX tree_map 95 | def dict_map(fn, dic, leaf_type): 96 | new_dict = {} 97 | for k, v in dic.items(): 98 | if type(v) is dict: 99 | new_dict[k] = dict_map(fn, v, leaf_type) 100 | else: 101 | new_dict[k] = tree_map(fn, v, leaf_type) 102 | 103 | return new_dict 104 | 105 | 106 | def tree_map(fn, tree, leaf_type): 107 | if isinstance(tree, dict): 108 | return dict_map(fn, tree, leaf_type) 109 | elif isinstance(tree, list): 110 | return [tree_map(fn, x, leaf_type) for x in tree] 111 | elif isinstance(tree, tuple): 112 | return tuple([tree_map(fn, x, leaf_type) for x in tree]) 113 | elif isinstance(tree, leaf_type): 114 | return fn(tree) 115 | else: 116 | return tree 117 | 118 | tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) 119 | -------------------------------------------------------------------------------- /assets/12l_md_templates.md: -------------------------------------------------------------------------------- 1 | ### Comparison of 48-layer and 12-layer AlphaFlow-MD+Templates 2 | | |48l (base)|48l (distilled)|12l (base)|12l (distilled)| 3 | |:-|:-:|:-:|:-:|:-:| 4 | | Pairwise RMSD | 2.18 | 1.73 | 1.94 | 1.40 5 | | Pairwise RMSD $r$ | 0.94 | 0.92 | 0.81 | 0.76 6 | | All-atom RMSF | 1.31 | 1.00 | 1.01 | 0.76 7 | | Global RMSF $r$ | 0.91 | 0.89 | 0.78 | 0.74 8 | | Per-target RMSF $r$ | 0.90 | 0.88 | 0.89 | 0.86 9 | | Root mean $\mathcal{W}_2$-dist | 1.95 | 2.18 | 2.26 | 2.43 10 | | MD PCA $\mathcal{W}_2$-dist | 1.25 | 1.41 | 1.40 | 1.56 11 | | Joint PCA 
$\mathcal{W}_2$-dist | 1.58 | 1.68 | 1.78 | 1.90 12 | | % PC-sim > 0.5 | 44 | 43 | 46 | 39 13 | | Weak contacts $J$ | 0.62 | 0.51 | 0.60 | 0.56 14 | | Transient contacts $J$ | 0.47 | 0.42 | 0.36 | 0.24 15 | | Exposed residue $J$ | 0.50 | 0.47 | 0.47 | 0.44 16 | | Exposed MI matrix $\rho$ | 0.25 | 0.18 | 0.21 | 0.13 17 | | **Runtime (s)** | 38 | 3.8 | 15.2 | 1.56 -------------------------------------------------------------------------------- /assets/6uof_A_animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bjing2016/alphaflow/02dc03763a016949326c2c741e6e33094f9250fd/assets/6uof_A_animation.gif -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser() 3 | parser.add_argument('--input_csv', type=str, default='splits/transporters_only.csv') 4 | parser.add_argument('--templates_dir', type=str, default=None) 5 | parser.add_argument('--msa_dir', type=str, default='./alignment_dir') 6 | parser.add_argument('--mode', choices=['alphafold', 'esmfold'], default='alphafold') 7 | parser.add_argument('--samples', type=int, default=10) 8 | parser.add_argument('--steps', type=int, default=10) 9 | parser.add_argument('--outpdb', type=str, default='./outpdb/default') 10 | parser.add_argument('--weights', type=str, default=None) 11 | parser.add_argument('--ckpt', type=str, default=None) 12 | parser.add_argument('--original_weights', action='store_true') 13 | parser.add_argument('--pdb_id', nargs='*', default=[]) 14 | parser.add_argument('--subsample', type=int, default=None) 15 | parser.add_argument('--resample', action='store_true') 16 | parser.add_argument('--tmax', type=float, default=1.0) 17 | parser.add_argument('--no_diffusion', action='store_true', default=False) 18 | parser.add_argument('--self_cond', action='store_true', default=False) 19 | parser.add_argument('--noisy_first', action='store_true', default=False) 20 | parser.add_argument('--runtime_json', type=str, default=None) 21 | parser.add_argument('--no_overwrite', action='store_true', default=False) 22 | args = parser.parse_args() 23 | 24 | import torch, tqdm, os, wandb, json, time 25 | import pandas as pd 26 | import pytorch_lightning as pl 27 | import numpy as np 28 | from collections import defaultdict 29 | from alphaflow.data.data_modules import collate_fn 30 | from alphaflow.model.wrapper import AlphaFoldWrapper, ESMFoldWrapper 31 | from alphaflow.utils.tensor_utils import tensor_tree_map 32 | import alphaflow.utils.protein as protein 33 | from alphaflow.data.inference import AlphaFoldCSVDataset, CSVDataset 34 | from collections import defaultdict 35 | from openfold.utils.import_weights import import_jax_weights_ 36 | from alphaflow.config import model_config 37 | 38 | from alphaflow.utils.logging import get_logger 39 | logger = get_logger(__name__) 40 | torch.set_float32_matmul_precision("high") 41 | 42 | config = model_config( 43 | 'initial_training', 44 | train=True, 45 | low_prec=True 46 | ) 47 | schedule = np.linspace(args.tmax, 0, args.steps+1) 48 | if args.tmax != 1.0: 49 | schedule = np.array([1.0] + list(schedule)) 50 | loss_cfg = config.loss 51 | data_cfg = config.data 52 | data_cfg.common.use_templates = False 53 | data_cfg.common.max_recycling_iters = 0 54 | 55 | if args.subsample: # https://elifesciences.org/articles/75751#s3 56 | data_cfg.predict.max_msa_clusters = 
args.subsample // 2 57 | data_cfg.predict.max_extra_msa = args.subsample 58 | 59 | @torch.no_grad() 60 | def main(): 61 | 62 | valset = { 63 | 'alphafold': AlphaFoldCSVDataset, 64 | 'esmfold': CSVDataset, 65 | }[args.mode]( 66 | data_cfg, 67 | args.input_csv, 68 | msa_dir=args.msa_dir, 69 | templates_dir=args.templates_dir, 70 | ) 71 | # valset[0] 72 | logger.info("Loading the model") 73 | model_class = {'alphafold': AlphaFoldWrapper, 'esmfold': ESMFoldWrapper}[args.mode] 74 | 75 | if args.weights: 76 | ckpt = torch.load(args.weights, map_location='cpu') 77 | model = model_class(**ckpt['hyper_parameters'], training=False) 78 | model.model.load_state_dict(ckpt['params'], strict=False) 79 | model = model.cuda() 80 | 81 | 82 | elif args.original_weights: 83 | model = model_class(config, None, training=False) 84 | if args.mode == 'esmfold': 85 | path = "esmfold_3B_v1.pt" 86 | model_data = torch.load(path, map_location='cpu') 87 | model_state = model_data["model"] 88 | model.model.load_state_dict(model_state, strict=False) 89 | model = model.to(torch.float).cuda() 90 | 91 | elif args.mode == 'alphafold': 92 | import_jax_weights_(model.model, 'params_model_1.npz', version='model_3') 93 | model = model.cuda() 94 | 95 | else: 96 | model = model_class.load_from_checkpoint(args.ckpt, map_location='cpu') 97 | model.load_ema_weights() 98 | model = model.cuda() 99 | model.eval() 100 | 101 | logger.info("Model has been loaded") 102 | 103 | results = defaultdict(list) 104 | os.makedirs(args.outpdb, exist_ok=True) 105 | runtime = defaultdict(list) 106 | for i, item in enumerate(valset): 107 | if args.pdb_id and item['name'] not in args.pdb_id: 108 | continue 109 | if args.no_overwrite and os.path.exists(f'{args.outpdb}/{item["name"]}.pdb'): 110 | continue 111 | result = [] 112 | for j in tqdm.trange(args.samples): 113 | if args.subsample or args.resample: 114 | item = valset[i] # resample MSA 115 | 116 | batch = collate_fn([item]) 117 | batch = tensor_tree_map(lambda x: x.cuda(), batch) 118 | start = time.time() 119 | prots = model.inference(batch, as_protein=True, noisy_first=args.noisy_first, 120 | no_diffusion=args.no_diffusion, schedule=schedule, self_cond=args.self_cond) 121 | runtime[item['name']].append(time.time() - start) 122 | result.append(prots[-1]) 123 | 124 | 125 | 126 | with open(f'{args.outpdb}/{item["name"]}.pdb', 'w') as f: 127 | f.write(protein.prots_to_pdb(result)) 128 | 129 | if args.runtime_json: 130 | with open(args.runtime_json, 'w') as f: 131 | f.write(json.dumps(dict(runtime))) 132 | if __name__ == "__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /scripts/add_msa_info.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import argparse, os, tqdm 3 | from collections import defaultdict 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--openfold_dir', type=str, required=True) 7 | parser.add_argument('--incsv', type=str, default='pdb_chains.csv') 8 | parser.add_argument('--outcsv', type=str, default='pdb_chains_msa.csv') 9 | args = parser.parse_args() 10 | 11 | msas = os.listdir(f'{args.openfold_dir}/pdb') 12 | 13 | msa_queries = {} 14 | for pdb_id in tqdm.tqdm(msas): 15 | a3m_paths = [ 16 | f'{args.openfold_dir}/pdb/{pdb_id}/a3m/bfd_uniclust_hits.a3m', 17 | f'{args.openfold_dir}/pdb/{pdb_id}/a3m/mgnify_hits.a3m', 18 | f'{args.openfold_dir}/pdb/{pdb_id}/a3m/uniref90_hits.a3m', 19 | ] 20 | for a3m_path in a3m_paths: 21 | if os.path.exists(a3m_path): 22 |
break 23 | with open(a3m_path) as f: 24 | _ = next(f) 25 | msa_queries[pdb_id] = next(f).strip() 26 | 27 | msa_queries = pd.Series(msa_queries) 28 | 29 | df = pd.read_csv(args.incsv) 30 | msa_id = [None]*len(df) 31 | 32 | freqs = defaultdict(int) 33 | done, skipped = 0, 0 34 | for seqres, sub_df in tqdm.tqdm(df.groupby('seqres')): 35 | freqs[len(sub_df)] += 1 36 | found = list(filter(lambda n: n in msa_queries.index, sub_df.name)) 37 | if not found: 38 | skipped += 1 39 | continue 40 | done += 1 41 | if len(found) == 1: 42 | if seqres == msa_queries.loc[found[0]]: 43 | for idx in sub_df.index: msa_id[idx] = found[0] 44 | else: 45 | print('Mismatch', found[0]) 46 | if len(found) != 1: 47 | print('Found multiple', found) 48 | for pdb_id in found: 49 | if seqres == msa_queries.loc[pdb_id]: 50 | for idx in sub_df.index: msa_id[idx] = pdb_id 51 | print('Match', pdb_id) 52 | else: 53 | print('Mismatch', pdb_id) 54 | 55 | df['msa_id'] = msa_id 56 | df[~df.msa_id.isnull()].set_index('name').to_csv(args.outcsv) 57 | 58 | -------------------------------------------------------------------------------- /scripts/analyze_ensembles.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser() 3 | parser.add_argument('--atlas_dir', type=str, required=True) 4 | parser.add_argument('--pdbdir', type=str, required=True) 5 | parser.add_argument('--pdb_id', nargs='*', default=[]) 6 | parser.add_argument('--bb_only', action='store_true') 7 | parser.add_argument('--ca_only', action='store_true') 8 | parser.add_argument('--num_workers', type=int, default=1) 9 | 10 | args = parser.parse_args() 11 | from sklearn.decomposition import PCA 12 | import mdtraj, pickle, tqdm, os 13 | import pandas as pd 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | from multiprocessing import Pool 17 | from scipy.optimize import linear_sum_assignment 18 | 19 | def get_pca(xyz): 20 | traj_reshaped = xyz.reshape(xyz.shape[0], -1) 21 | pca = PCA(n_components=min(traj_reshaped.shape)) 22 | coords = pca.fit_transform(traj_reshaped) 23 | return pca, coords 24 | 25 | def get_rmsds(traj1, traj2, broadcast=False): 26 | n_atoms = traj1.shape[1] 27 | traj1 = traj1.reshape(traj1.shape[0], n_atoms * 3) 28 | traj2 = traj2.reshape(traj2.shape[0], n_atoms * 3) 29 | if broadcast: 30 | traj1, traj2 = traj1[:,None], traj2[None] 31 | distmat = np.square(traj1 - traj2).sum(-1)**0.5 / n_atoms**0.5 * 10 32 | return distmat 33 | 34 | def condense_sidechain_sasas(sasas, top): 35 | assert top.n_residues > 1 36 | 37 | if top.n_atoms != sasas.shape[1]: 38 | raise ValueError( 39 | f"The number of atoms in top ({top.n_atoms}) didn't match the " 40 | f"number of SASAs provided ({sasas.shape[1]}). 
Make sure you " 41 | f"computed atom-level SASAs (mode='atom') and that you've passed " 42 | "the correct topology file and array of SASAs" 43 | ) 44 | 45 | sc_mask = np.array([a.name not in ['CA', 'C', 'N', 'O', 'OXT'] for a in top.atoms]) 46 | res_id = np.array([a.residue.index for a in top.atoms]) 47 | 48 | rsd_sasas = np.zeros((sasas.shape[0], top.n_residues), dtype='float32') 49 | 50 | for i in range(top.n_residues): 51 | rsd_sasas[:, i] = sasas[:, sc_mask & (res_id == i)].sum(1) 52 | return rsd_sasas 53 | 54 | def sasa_mi(sasa): 55 | N, L = sasa.shape 56 | joint_probs = np.zeros((L, L, 2, 2)) 57 | 58 | joint_probs[:,:,1,1] = (sasa[:,:,None] & sasa[:,None,:]).mean(0) 59 | joint_probs[:,:,1,0] = (sasa[:,:,None] & ~sasa[:,None,:]).mean(0) 60 | joint_probs[:,:,0,1] = (~sasa[:,:,None] & sasa[:,None,:]).mean(0) 61 | joint_probs[:,:,0,0] = (~sasa[:,:,None] & ~sasa[:,None,:]).mean(0) 62 | 63 | marginal_probs = np.stack([1-sasa.mean(0), sasa.mean(0)], -1) 64 | indep_probs = marginal_probs[None,:,None,:] * marginal_probs[:,None,:,None] 65 | mi = np.nansum(joint_probs * np.log(joint_probs / indep_probs), (-1, -2)) 66 | mi[np.arange(L), np.arange(L)] = 0 67 | return mi 68 | 69 | 70 | def get_mean_covar(xyz): 71 | mean = xyz.mean(0) 72 | xyz = xyz - mean 73 | covar = (xyz[...,None] * xyz[...,None,:]).mean(0) 74 | return mean, covar 75 | 76 | 77 | def sqrtm(M): 78 | D, P = np.linalg.eig(M) 79 | out = (P * np.sqrt(D[:,None])) @ np.linalg.inv(P) 80 | return out 81 | 82 | 83 | def get_wasserstein(distmat, p=2): 84 | assert distmat.shape[0] == distmat.shape[1] 85 | distmat = distmat ** p 86 | row_ind, col_ind = linear_sum_assignment(distmat) 87 | return distmat[row_ind, col_ind].mean() ** (1/p) 88 | 89 | def align_tops(top1, top2): 90 | names1 = [repr(a) for a in top1.atoms] 91 | names2 = [repr(a) for a in top2.atoms] 92 | 93 | intersection = [nam for nam in names1 if nam in names2] 94 | 95 | mask1 = [names1.index(nam) for nam in intersection] 96 | mask2 = [names2.index(nam) for nam in intersection] 97 | return mask1, mask2 98 | 99 | def main(name): 100 | print('Analyzing', name) 101 | out = {} 102 | ref_aa = mdtraj.load(f'{args.atlas_dir}/{name}/{name}.pdb') 103 | topfile = f'{args.atlas_dir}/{name}/{name}.pdb' 104 | print('Loading reference trajectory') 105 | traj_aa = mdtraj.load(f'{args.atlas_dir}/{name}/{name}_prod_R1_fit.xtc', top=topfile) \ 106 | + mdtraj.load(f'{args.atlas_dir}/{name}/{name}_prod_R2_fit.xtc', top=topfile) \ 107 | + mdtraj.load(f'{args.atlas_dir}/{name}/{name}_prod_R3_fit.xtc', top=topfile) 108 | print(f'Loaded {traj_aa.n_frames} reference frames') 109 | 110 | print('Loading AF2 conformers') 111 | aftraj_aa = mdtraj.load(f'{args.pdbdir}/{name}.pdb') 112 | 113 | print(f'Loaded {aftraj_aa.n_frames} AF2 conformers') 114 | print(f'Reference has {traj_aa.n_atoms} atoms') 115 | print(f'Crystal has {ref_aa.n_atoms} atoms') 116 | print(f'AF has {aftraj_aa.n_atoms} atoms') 117 | 118 | print('Removing hydrogens') 119 | 120 | traj_aa.atom_slice([a.index for a in traj_aa.top.atoms if a.element.symbol != 'H'], True) 121 | ref_aa.atom_slice([a.index for a in ref_aa.top.atoms if a.element.symbol != 'H'], True) 122 | aftraj_aa.atom_slice([a.index for a in aftraj_aa.top.atoms if a.element.symbol != 'H'], True) 123 | 124 | print(f'Reference has {traj_aa.n_atoms} atoms') 125 | print(f'Crystal has {ref_aa.n_atoms} atoms') 126 | print(f'AF has {aftraj_aa.n_atoms} atoms') 127 | 128 | if args.bb_only: 129 | print('Removing sidechains') 130 | aftraj_aa.atom_slice([a.index for a in aftraj_aa.top.atoms 
if a.name in ['CA', 'C', 'N', 'O', 'OXT']], True) 131 | print(f'AF has {aftraj_aa.n_atoms} atoms') 132 | 133 | elif args.ca_only: 134 | print('Removing sidechains') 135 | aftraj_aa.atom_slice([a.index for a in aftraj_aa.top.atoms if a.name == 'CA'], True) 136 | print(f'AF has {aftraj_aa.n_atoms} atoms') 137 | 138 | 139 | refmask, afmask = align_tops(traj_aa.top, aftraj_aa.top) 140 | traj_aa.atom_slice(refmask, True) 141 | ref_aa.atom_slice(refmask, True) 142 | aftraj_aa.atom_slice(afmask, True) 143 | 144 | print(f'Aligned on {aftraj_aa.n_atoms} atoms') 145 | 146 | np.random.seed(137) 147 | RAND1 = np.random.randint(0, traj_aa.n_frames, aftraj_aa.n_frames) 148 | RAND2 = np.random.randint(0, traj_aa.n_frames, aftraj_aa.n_frames) 149 | RAND1K = np.random.randint(0, traj_aa.n_frames, 1000) 150 | 151 | traj_aa.superpose(ref_aa) 152 | aftraj_aa.superpose(ref_aa) 153 | 154 | out['ca_mask'] = ca_mask = [a.index for a in traj_aa.top.atoms if a.name == 'CA'] 155 | traj = traj_aa.atom_slice(ca_mask, False) 156 | ref = ref_aa.atom_slice(ca_mask, False) 157 | aftraj = aftraj_aa.atom_slice(ca_mask, False) 158 | print(f'Sliced {aftraj.n_atoms} C-alphas') 159 | 160 | traj.superpose(ref) 161 | aftraj.superpose(ref) 162 | 163 | 164 | n_atoms = aftraj.n_atoms 165 | 166 | print(f'Doing PCA') 167 | 168 | ref_pca, ref_coords = get_pca(traj.xyz) 169 | af_coords_ref_pca = ref_pca.transform(aftraj.xyz.reshape(aftraj.n_frames, -1)) 170 | seed_coords_ref_pca = ref_pca.transform(ref.xyz.reshape(1, -1)) 171 | 172 | af_pca, af_coords = get_pca(aftraj.xyz) 173 | ref_coords_af_pca = af_pca.transform(traj.xyz.reshape(traj.n_frames, -1)) 174 | seed_coords_af_pca = af_pca.transform(ref.xyz.reshape(1, -1)) 175 | 176 | joint_pca, _ = get_pca(np.concatenate([traj[RAND1].xyz, aftraj.xyz])) 177 | af_coords_joint_pca = joint_pca.transform(aftraj.xyz.reshape(aftraj.n_frames, -1)) 178 | ref_coords_joint_pca = joint_pca.transform(traj.xyz.reshape(traj.n_frames, -1)) 179 | seed_coords_joint_pca = joint_pca.transform(ref.xyz.reshape(1, -1)) 180 | 181 | out['ref_variance'] = ref_pca.explained_variance_ / n_atoms * 100 182 | out['af_variance'] = af_pca.explained_variance_ / n_atoms * 100 183 | out['joint_variance'] = joint_pca.explained_variance_ / n_atoms * 100 184 | 185 | out['af_rmsf'] = mdtraj.rmsf(aftraj_aa, ref_aa) * 10 186 | out['ref_rmsf'] = mdtraj.rmsf(traj_aa, ref_aa) * 10 187 | 188 | print(f'Computing atomic EMD') 189 | ref_mean, ref_covar = get_mean_covar(traj_aa[RAND1K].xyz) 190 | af_mean, af_covar = get_mean_covar(aftraj_aa.xyz) 191 | out['emd_mean'] = (np.square(ref_mean - af_mean).sum(-1) ** 0.5) * 10 192 | try: 193 | out['emd_var'] = (np.trace(ref_covar + af_covar - 2*sqrtm(ref_covar @ af_covar), axis1=1,axis2=2) ** 0.5) * 10 194 | except: 195 | out['emd_var'] = np.trace(ref_covar) ** 0.5 * 10 196 | 197 | 198 | print(f'Analyzing SASA') 199 | sasa_thresh = 0.02 200 | af_sasa = mdtraj.shrake_rupley(aftraj_aa, probe_radius=0.28) 201 | af_sasa = condense_sidechain_sasas(af_sasa, aftraj_aa.top) 202 | ref_sasa = mdtraj.shrake_rupley(traj_aa[RAND1K], probe_radius=0.28) 203 | ref_sasa = condense_sidechain_sasas(ref_sasa, traj_aa.top) 204 | crystal_sasa = mdtraj.shrake_rupley(ref_aa, probe_radius=0.28) 205 | out['crystal_sasa'] = condense_sidechain_sasas(crystal_sasa, ref_aa.top) 206 | 207 | out['ref_sa_prob'] = (ref_sasa > sasa_thresh).mean(0) 208 | out['af_sa_prob'] = (af_sasa > sasa_thresh).mean(0) 209 | out['ref_mi_mat'] = sasa_mi(ref_sasa > sasa_thresh) 210 | out['af_mi_mat'] = sasa_mi(af_sasa > sasa_thresh) 211 | 212 | 
ref_distmat = np.linalg.norm(traj[RAND1].xyz[:,None,:] - traj[RAND1].xyz[:,:,None], axis=-1) 213 | af_distmat = np.linalg.norm(aftraj.xyz[:,None,:] - aftraj.xyz[:,:,None], axis=-1) 214 | 215 | out['ref_contact_prob'] = (ref_distmat < 0.8).mean(0) 216 | out['af_contact_prob'] = (af_distmat < 0.8).mean(0) 217 | out['crystal_distmat'] = np.linalg.norm(ref.xyz[0,None,:] - ref.xyz[0,:,None], axis=-1) 218 | 219 | out['ref_mean_pairwise_rmsd'] = get_rmsds(traj[RAND1].xyz, traj[RAND2].xyz, broadcast=True).mean() 220 | out['af_mean_pairwise_rmsd'] = get_rmsds(aftraj.xyz, aftraj.xyz, broadcast=True).mean() 221 | 222 | out['ref_rms_pairwise_rmsd'] = np.square(get_rmsds(traj[RAND1].xyz, traj[RAND2].xyz, broadcast=True)).mean() ** 0.5 223 | out['af_rms_pairwise_rmsd'] = np.square(get_rmsds(aftraj.xyz, aftraj.xyz, broadcast=True)).mean() ** 0.5 224 | 225 | out['ref_self_mean_pairwise_rmsd'] = get_rmsds(traj[RAND1].xyz, traj[RAND1].xyz, broadcast=True).mean() 226 | out['ref_self_rms_pairwise_rmsd'] = np.square(get_rmsds(traj[RAND1].xyz, traj[RAND1].xyz, broadcast=True)).mean() ** 0.5 227 | 228 | out['cosine_sim'] = (ref_pca.components_[0] * af_pca.components_[0]).sum() 229 | 230 | 231 | def get_emd(ref_coords1, ref_coords2, af_coords, seed_coords, K=None): 232 | if len(ref_coords1.shape) == 3: 233 | ref_coords1 = ref_coords1.reshape(ref_coords1.shape[0], -1) 234 | ref_coords2 = ref_coords2.reshape(ref_coords2.shape[0], -1) 235 | af_coords = af_coords.reshape(af_coords.shape[0], -1) 236 | seed_coords = seed_coords.reshape(seed_coords.shape[0], -1) 237 | if K is not None: 238 | ref_coords1 = ref_coords1[:,:K] 239 | ref_coords2 = ref_coords2[:,:K] 240 | af_coords = af_coords[:,:K] 241 | seed_coords = seed_coords[:,:K] 242 | emd = {} 243 | emd['ref|ref mean'] = (np.square(ref_coords1 - ref_coords1.mean(0)).sum(-1)).mean()**0.5 / n_atoms ** 0.5 * 10 244 | 245 | distmat = np.square(ref_coords1[:,None] - ref_coords2[None]).sum(-1) 246 | distmat = distmat ** 0.5 / n_atoms ** 0.5 * 10 247 | emd['ref|ref2'] = get_wasserstein(distmat) 248 | emd['ref mean|ref2 mean'] = np.square(ref_coords1.mean(0) - ref_coords2.mean(0)).sum() ** 0.5 / n_atoms ** 0.5 * 10 249 | 250 | distmat = np.square(ref_coords1[:,None] - af_coords[None]).sum(-1) 251 | distmat = distmat ** 0.5 / n_atoms ** 0.5 * 10 252 | emd['ref|af'] = get_wasserstein(distmat) 253 | emd['ref mean|af mean'] = np.square(ref_coords1.mean(0) - af_coords.mean(0)).sum() ** 0.5 / n_atoms ** 0.5 * 10 254 | 255 | emd['ref|seed'] = (np.square(ref_coords1 - seed_coords).sum(-1)).mean()**0.5 / n_atoms ** 0.5 * 10 256 | emd['ref mean|seed'] = (np.square(ref_coords1.mean(0) - seed_coords).sum(-1)).mean()**0.5 / n_atoms ** 0.5 * 10 257 | 258 | emd['af|seed'] = (np.square(af_coords - seed_coords).sum(-1)).mean()**0.5 / n_atoms ** 0.5 * 10 259 | emd['af|af mean'] = (np.square(af_coords - af_coords.mean(0)).sum(-1)).mean()**0.5 / n_atoms ** 0.5 * 10 260 | emd['af mean|seed'] = (np.square(af_coords.mean(0) - seed_coords).sum(-1)).mean()**0.5 / n_atoms ** 0.5 * 10 261 | return emd 262 | 263 | K=2 264 | out[f'EMD,ref'] = get_emd(ref_coords[RAND1], ref_coords[RAND2], af_coords_ref_pca, seed_coords_ref_pca, K=K) 265 | out[f'EMD,af2'] = get_emd(ref_coords_af_pca[RAND1], ref_coords_af_pca[RAND2], af_coords, seed_coords_af_pca, K=K) 266 | out[f'EMD,joint'] = get_emd(ref_coords_joint_pca[RAND1], ref_coords_joint_pca[RAND2], af_coords_joint_pca, seed_coords_joint_pca, K=K) 267 | return name, out 268 | 269 | 270 | if args.pdb_id: 271 | pdb_id = args.pdb_id 272 | else: 273 | pdb_id = 
[nam.split('.')[0] for nam in os.listdir(args.pdbdir) if '.pdb' in nam] 274 | 275 | if args.num_workers > 1: 276 | p = Pool(args.num_workers) 277 | p.__enter__() 278 | __map__ = p.imap 279 | else: 280 | __map__ = map 281 | out = dict(tqdm.tqdm(__map__(main, pdb_id), total=len(pdb_id))) 282 | if args.num_workers > 1: 283 | p.__exit__(None, None, None) 284 | 285 | with open(f"{args.pdbdir}/out.pkl", 'wb') as f: 286 | f.write(pickle.dumps(out)) -------------------------------------------------------------------------------- /scripts/cluster_chains.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--chains', type=str, default='pdb_mmcif.csv') 5 | parser.add_argument('--out', type=str, default='pdb_clusters') 6 | parser.add_argument('--thresh', type=float, default=0.4) 7 | parser.add_argument('--mmseqs_path', type=str, default='mmseqs') 8 | args = parser.parse_args() 9 | 10 | import pandas as pd 11 | import os, json, tqdm, random, subprocess, pickle 12 | from collections import defaultdict 13 | from functools import partial 14 | import numpy as np 15 | from Bio.Seq import Seq 16 | from Bio.SeqRecord import SeqRecord 17 | from Bio import SeqIO 18 | from collections import defaultdict 19 | 20 | def main(): 21 | df = pd.read_csv(args.chains, index_col='name') 22 | 23 | sequences = [SeqRecord(Seq(row.seqres), id=name) for name, row in tqdm.tqdm(df.iterrows())] 24 | SeqIO.write(sequences, ".in.fasta", "fasta") 25 | cmd = [args.mmseqs_path, 'easy-cluster', '.in.fasta', '.out', '.tmp', '--min-seq-id', str(args.thresh), '--alignment-mode', '1'] 26 | subprocess.run(cmd)#, stdout=open('/dev/null', 'w')) 27 | f = open('.out_cluster.tsv') 28 | clusters = [] 29 | for line in f: 30 | a, b = line.strip().split() 31 | if a == b: 32 | clusters.append([]) 33 | clusters[-1].append(b) 34 | subprocess.run(['rm', '-r', '.in.fasta', '.tmp', '.out_all_seqs.fasta', '.out_rep_seq.fasta', '.out_cluster.tsv']) 35 | 36 | with open(args.out, 'w') as f: 37 | for clus in clusters: 38 | f.write(' '.join(clus)) 39 | f.write('\n') 40 | 41 | 42 | if __name__ == "__main__": 43 | main() -------------------------------------------------------------------------------- /scripts/download_atlas.sh: -------------------------------------------------------------------------------- 1 | for name in $(cat `dirname $0`/../splits/atlas.csv | grep -v name | awk -F ',' {'print $1'}); do 2 | wget https://www.dsimb.inserm.fr/ATLAS/database/ATLAS/${name}/${name}_protein.zip 3 | mkdir ${name} 4 | unzip ${name}_protein.zip -d ${name} 5 | rm ${name}_protein.zip 6 | done -------------------------------------------------------------------------------- /scripts/mmseqs_query.py: -------------------------------------------------------------------------------- 1 | # https://github.com/sokrypton/ColabFold/blob/main/colabfold/colabfold.py 2 | 3 | import argparse 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--split', type=str, required=True) 6 | parser.add_argument('--outdir', type=str, default='./alignment_dir') 7 | args = parser.parse_args() 8 | 9 | from typing import Tuple, List 10 | import os 11 | import requests 12 | import random 13 | import time 14 | import logging 15 | import tarfile 16 | import pandas as pd 17 | from tqdm import tqdm 18 | logger = logging.getLogger(__name__) 19 | 20 | TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]' 21 | def run_mmseqs2(x, prefix, 
use_env=True, use_filter=True, 22 | use_templates=False, filter=None, use_pairing=False, pairing_strategy="greedy", 23 | host_url="https://api.colabfold.com", 24 | user_agent: str = "") -> Tuple[List[str], List[str]]: 25 | submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa" 26 | 27 | headers = {} 28 | if user_agent != "": 29 | headers['User-Agent'] = user_agent 30 | else: 31 | logger.warning("No user agent specified. Please set a user agent (e.g., 'toolname/version contact@email') to help us debug in case of problems. This warning will become an error in the future.") 32 | 33 | def submit(seqs, mode, N=101): 34 | n, query = N, "" 35 | for seq in seqs: 36 | query += f">{n}\n{seq}\n" 37 | n += 1 38 | 39 | while True: 40 | error_count = 0 41 | try: 42 | # https://requests.readthedocs.io/en/latest/user/advanced/#advanced 43 | # "good practice to set connect timeouts to slightly larger than a multiple of 3" 44 | res = requests.post(f'{host_url}/{submission_endpoint}', data={ 'q': query, 'mode': mode }, timeout=6.02, headers=headers) 45 | except requests.exceptions.Timeout: 46 | logger.warning("Timeout while submitting to MSA server. Retrying...") 47 | continue 48 | except Exception as e: 49 | error_count += 1 50 | logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)") 51 | logger.warning(f"Error: {e}") 52 | time.sleep(5) 53 | if error_count > 5: 54 | raise 55 | continue 56 | break 57 | 58 | try: 59 | out = res.json() 60 | except ValueError: 61 | logger.error(f"Server didn't reply with json: {res.text}") 62 | out = {"status":"ERROR"} 63 | return out 64 | 65 | def status(ID): 66 | while True: 67 | error_count = 0 68 | try: 69 | res = requests.get(f'{host_url}/ticket/{ID}', timeout=6.02, headers=headers) 70 | except requests.exceptions.Timeout: 71 | logger.warning("Timeout while fetching status from MSA server. Retrying...") 72 | continue 73 | except Exception as e: 74 | error_count += 1 75 | logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)") 76 | logger.warning(f"Error: {e}") 77 | time.sleep(5) 78 | if error_count > 5: 79 | raise 80 | continue 81 | break 82 | try: 83 | out = res.json() 84 | except ValueError: 85 | logger.error(f"Server didn't reply with json: {res.text}") 86 | out = {"status":"ERROR"} 87 | return out 88 | 89 | def download(ID, path): 90 | error_count = 0 91 | while True: 92 | try: 93 | res = requests.get(f'{host_url}/result/download/{ID}', timeout=6.02, headers=headers) 94 | except requests.exceptions.Timeout: 95 | logger.warning("Timeout while fetching result from MSA server. Retrying...") 96 | continue 97 | except Exception as e: 98 | error_count += 1 99 | logger.warning(f"Error while fetching result from MSA server. Retrying... 
({error_count}/5)") 100 | logger.warning(f"Error: {e}") 101 | time.sleep(5) 102 | if error_count > 5: 103 | raise 104 | continue 105 | break 106 | with open(path,"wb") as out: out.write(res.content) 107 | 108 | # process input x 109 | seqs = [x] if isinstance(x, str) else x 110 | 111 | # compatibility to old option 112 | if filter is not None: 113 | use_filter = filter 114 | 115 | # setup mode 116 | if use_filter: 117 | mode = "env" if use_env else "all" 118 | else: 119 | mode = "env-nofilter" if use_env else "nofilter" 120 | 121 | if use_pairing: 122 | use_templates = False 123 | use_env = False 124 | mode = "" 125 | # greedy is default, complete was the previous behavior 126 | if pairing_strategy == "greedy": 127 | mode = "pairgreedy" 128 | elif pairing_strategy == "complete": 129 | mode = "paircomplete" 130 | 131 | # define path 132 | path = f"{prefix}_{mode}" 133 | if not os.path.isdir(path): os.mkdir(path) 134 | 135 | # call mmseqs2 api 136 | tar_gz_file = f'{path}/out.tar.gz' 137 | N,REDO = 101,True 138 | 139 | # deduplicate and keep track of order 140 | seqs_unique = [] 141 | #TODO this might be slow for large sets 142 | [seqs_unique.append(x) for x in seqs if x not in seqs_unique] 143 | Ms = [N + seqs_unique.index(seq) for seq in seqs] 144 | # lets do it! 145 | if not os.path.isfile(tar_gz_file): 146 | TIME_ESTIMATE = 150 * len(seqs_unique) 147 | with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar: 148 | while REDO: 149 | pbar.set_description("SUBMIT") 150 | 151 | # Resubmit job until it goes through 152 | out = submit(seqs_unique, mode, N) 153 | while out["status"] in ["UNKNOWN", "RATELIMIT"]: 154 | sleep_time = 5 + random.randint(0, 5) 155 | logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}") 156 | # resubmit 157 | time.sleep(sleep_time) 158 | out = submit(seqs_unique, mode, N) 159 | 160 | if out["status"] == "ERROR": 161 | raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.') 162 | 163 | if out["status"] == "MAINTENANCE": 164 | raise Exception(f'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.') 165 | 166 | # wait for job to finish 167 | ID,TIME = out["id"],0 168 | pbar.set_description(out["status"]) 169 | while out["status"] in ["UNKNOWN","RUNNING","PENDING"]: 170 | t = 5 + random.randint(0,5) 171 | logger.error(f"Sleeping for {t}s. Reason: {out['status']}") 172 | time.sleep(t) 173 | out = status(ID) 174 | pbar.set_description(out["status"]) 175 | if out["status"] == "RUNNING": 176 | TIME += t 177 | pbar.update(n=t) 178 | #if TIME > 900 and out["status"] != "COMPLETE": 179 | # # something failed on the server side, need to resubmit 180 | # N += 1 181 | # break 182 | 183 | if out["status"] == "COMPLETE": 184 | if TIME < TIME_ESTIMATE: 185 | pbar.update(n=(TIME_ESTIMATE-TIME)) 186 | REDO = False 187 | 188 | if out["status"] == "ERROR": 189 | REDO = False 190 | raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. 
If error persists, please try again an hour later.') 191 | 192 | # Download results 193 | download(ID, tar_gz_file) 194 | 195 | # prep list of a3m files 196 | if use_pairing: 197 | a3m_files = [f"{path}/pair.a3m"] 198 | else: 199 | a3m_files = [f"{path}/uniref.a3m"] 200 | if use_env: a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m") 201 | 202 | # extract a3m files 203 | if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files): 204 | with tarfile.open(tar_gz_file) as tar_gz: 205 | tar_gz.extractall(path) 206 | 207 | # templates 208 | if use_templates: 209 | templates = {} 210 | #print("seq\tpdb\tcid\tevalue") 211 | for line in open(f"{path}/pdb70.m8","r"): 212 | p = line.rstrip().split() 213 | M,pdb,qid,e_value = p[0],p[1],p[2],p[10] 214 | M = int(M) 215 | if M not in templates: templates[M] = [] 216 | templates[M].append(pdb) 217 | #if len(templates[M]) <= 20: 218 | # print(f"{int(M)-N}\t{pdb}\t{qid}\t{e_value}") 219 | 220 | template_paths = {} 221 | for k,TMPL in templates.items(): 222 | TMPL_PATH = f"{prefix}_{mode}/templates_{k}" 223 | if not os.path.isdir(TMPL_PATH): 224 | os.mkdir(TMPL_PATH) 225 | TMPL_LINE = ",".join(TMPL[:20]) 226 | response = None 227 | while True: 228 | error_count = 0 229 | try: 230 | # https://requests.readthedocs.io/en/latest/user/advanced/#advanced 231 | # "good practice to set connect timeouts to slightly larger than a multiple of 3" 232 | response = requests.get(f"{host_url}/template/{TMPL_LINE}", stream=True, timeout=6.02, headers=headers) 233 | except requests.exceptions.Timeout: 234 | logger.warning("Timeout while submitting to template server. Retrying...") 235 | continue 236 | except Exception as e: 237 | error_count += 1 238 | logger.warning(f"Error while fetching result from template server. Retrying... 
({error_count}/5)") 239 | logger.warning(f"Error: {e}") 240 | time.sleep(5) 241 | if error_count > 5: 242 | raise 243 | continue 244 | break 245 | with tarfile.open(fileobj=response.raw, mode="r|gz") as tar: 246 | tar.extractall(path=TMPL_PATH) 247 | os.symlink("pdb70_a3m.ffindex", f"{TMPL_PATH}/pdb70_cs219.ffindex") 248 | with open(f"{TMPL_PATH}/pdb70_cs219.ffdata", "w") as f: 249 | f.write("") 250 | template_paths[k] = TMPL_PATH 251 | 252 | # gather a3m lines 253 | a3m_lines = {} 254 | for a3m_file in a3m_files: 255 | update_M,M = True,None 256 | for line in open(a3m_file,"r"): 257 | if len(line) > 0: 258 | if "\x00" in line: 259 | line = line.replace("\x00","") 260 | update_M = True 261 | if line.startswith(">") and update_M: 262 | M = int(line[1:].rstrip()) 263 | update_M = False 264 | if M not in a3m_lines: a3m_lines[M] = [] 265 | a3m_lines[M].append(line) 266 | 267 | # return results 268 | 269 | a3m_lines = ["".join(a3m_lines[n]) for n in Ms] 270 | 271 | if use_templates: 272 | template_paths_ = [] 273 | for n in Ms: 274 | if n not in template_paths: 275 | template_paths_.append(None) 276 | #print(f"{n-N}\tno_templates_found") 277 | else: 278 | template_paths_.append(template_paths[n]) 279 | template_paths = template_paths_ 280 | 281 | 282 | return (a3m_lines, template_paths) if use_templates else a3m_lines 283 | 284 | df = pd.read_csv(args.split, index_col='name') 285 | os.makedirs(args.outdir, exist_ok=True) 286 | 287 | msas = run_mmseqs2(list(df.seqres), prefix='/tmp/', user_agent='bjing2016/alphaflow') 288 | os.system('rm -r /tmp/_env') 289 | 290 | for name, msa in zip(df.index, msas): 291 | os.makedirs(f'{args.outdir}/{name}/a3m/', exist_ok=True) 292 | with open(f'{args.outdir}/{name}/a3m/{name}.a3m', 'w') as f: 293 | f.write(msa) 294 | 295 | -------------------------------------------------------------------------------- /scripts/mmseqs_search_helper.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser() 3 | parser.add_argument('--split', type=str, required=True) 4 | parser.add_argument('--db_dir', type=str, default='./dbbase') 5 | parser.add_argument('--outdir', type=str, default='./alignment_dir') 6 | args = parser.parse_args() 7 | 8 | import pandas as pd 9 | import subprocess, os 10 | 11 | df = pd.read_csv(args.split) 12 | os.makedirs(args.outdir, exist_ok=True) 13 | with open('/tmp/tmp.fasta', 'w') as f: 14 | for _, row in df.iterrows(): 15 | f.write(f'>{row["name"]}\n{row.seqres}\n') 16 | 17 | cmd = f'python -m scripts.mmseqs_search /tmp/tmp.fasta {args.db_dir} {args.outdir}' 18 | os.system(cmd) 19 | 20 | for name in os.listdir(args.outdir): 21 | if '.a3m' not in name: 22 | continue 23 | with open(f'{args.outdir}/{name}') as f: 24 | pdb_id = next(f).strip()[1:] 25 | cmd = f'mkdir -p {args.outdir}/{pdb_id}/a3m' 26 | print(cmd) 27 | os.system(cmd) 28 | cmd = f'mv {args.outdir}/{name} {args.outdir}/{pdb_id}/a3m/{pdb_id}.a3m' 29 | os.system(cmd) 30 | print(cmd) -------------------------------------------------------------------------------- /scripts/prep_atlas.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--split', type=str, default='splits/atlas.csv') 5 | parser.add_argument('--atlas_dir', type=str, required=True) 6 | parser.add_argument('--outdir', type=str, default='./data_atlas') 7 | parser.add_argument('--num_workers', type=int, default=1) 8 | args = 
parser.parse_args() 9 | 10 | import mdtraj, os, tempfile, tqdm 11 | from alphaflow.utils import protein 12 | from openfold.data.data_pipeline import make_protein_features 13 | import pandas as pd 14 | from multiprocessing import Pool 15 | import numpy as np 16 | 17 | os.makedirs(args.outdir, exist_ok=True) 18 | 19 | df = pd.read_csv(args.split, index_col='name') 20 | 21 | def main(): 22 | jobs = [] 23 | for name in df.index: 24 | #if os.path.exists(f'{args.outdir}/{name}.npz'): continue 25 | jobs.append(name) 26 | 27 | if args.num_workers > 1: 28 | p = Pool(args.num_workers) 29 | p.__enter__() 30 | __map__ = p.imap 31 | else: 32 | __map__ = map 33 | for _ in tqdm.tqdm(__map__(do_job, jobs), total=len(jobs)): 34 | pass 35 | if args.num_workers > 1: 36 | p.__exit__(None, None, None) 37 | 38 | def do_job(name): 39 | traj = mdtraj.load(f'{args.atlas_dir}/{name}/{name}_prod_R1_fit.xtc', top=f'{args.atlas_dir}/{name}/{name}.pdb') \ 40 | + mdtraj.load(f'{args.atlas_dir}/{name}/{name}_prod_R2_fit.xtc', top=f'{args.atlas_dir}/{name}/{name}.pdb') \ 41 | + mdtraj.load(f'{args.atlas_dir}/{name}/{name}_prod_R3_fit.xtc', top=f'{args.atlas_dir}/{name}/{name}.pdb') 42 | ref = mdtraj.load(f'{args.atlas_dir}/{name}/{name}.pdb') 43 | traj = ref + traj 44 | f, temp_path = tempfile.mkstemp(); os.close(f) 45 | positions_stacked = [] 46 | for i in tqdm.trange(0, len(traj), 100): 47 | traj[i].save_pdb(temp_path) 48 | 49 | with open(temp_path) as f: 50 | prot = protein.from_pdb_string(f.read()) 51 | pdb_feats = make_protein_features(prot, name) 52 | positions_stacked.append(pdb_feats['all_atom_positions']) 53 | 54 | 55 | pdb_feats['all_atom_positions'] = np.stack(positions_stacked) 56 | print({key: pdb_feats[key].shape for key in pdb_feats}) 57 | np.savez(f"{args.outdir}/{name}.npz", **pdb_feats) 58 | os.unlink(temp_path) 59 | 60 | main() 61 | -------------------------------------------------------------------------------- /scripts/print_analysis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tqdm, sys, pickle, warnings 3 | import pandas as pd 4 | import scipy.stats 5 | 6 | paths = sys.argv[1:] 7 | 8 | def correlations(a, b, prefix=''): 9 | return { 10 | prefix + 'pearson': scipy.stats.pearsonr(a, b)[0], 11 | prefix + 'spearman': scipy.stats.spearmanr(a, b)[0], 12 | prefix + 'kendall': scipy.stats.kendalltau(a, b)[0], 13 | } 14 | 15 | 16 | def analyze_data(data): 17 | mi_mats = {} 18 | df = [] 19 | for name, out in data.items(): 20 | item = { 21 | 'name': name, 22 | 'md_pairwise': out['ref_mean_pairwise_rmsd'], 23 | 'af_pairwise': out['af_mean_pairwise_rmsd'], 24 | 'cosine_sim': abs(out['cosine_sim']), 25 | 'emd_mean': np.square(out['emd_mean']).mean() ** 0.5, 26 | 'emd_var': np.square(out['emd_var']).mean() ** 0.5, 27 | } | correlations(out['af_rmsf'], out['ref_rmsf'], prefix='rmsf_') 28 | if 'EMD,ref' not in out: 29 | out['EMD,ref'] = out['EMD-2,ref'] 30 | out['EMD,af2'] = out['EMD-2,af2'] 31 | out['EMD,joint'] = out['EMD-2,joint'] 32 | for emd_dict, emd_key in [ 33 | (out['EMD,ref'], 'ref'), 34 | (out['EMD,joint'], 'joint') 35 | ]: 36 | item.update({ 37 | emd_key + 'emd': emd_dict['ref|af'], 38 | emd_key + 'emd_tr': emd_dict['ref mean|af mean'], 39 | emd_key + 'emd_int': (emd_dict['ref|af']**2 - emd_dict['ref mean|af mean']**2)**0.5, 40 | }) 41 | 42 | try: 43 | crystal_contact_mask = out['crystal_distmat'] < 0.8 44 | ref_transient_mask = (~crystal_contact_mask) & (out['ref_contact_prob'] > 0.1) 45 | af_transient_mask = 
(~crystal_contact_mask) & (out['af_contact_prob'] > 0.1) 46 | ref_weak_mask = crystal_contact_mask & (out['ref_contact_prob'] < 0.9) 47 | af_weak_mask = crystal_contact_mask & (out['af_contact_prob'] < 0.9) 48 | item.update({ 49 | 'weak_contacts_iou': (ref_weak_mask & af_weak_mask).sum() / (ref_weak_mask | af_weak_mask).sum(), 50 | 'transient_contacts_iou': (ref_transient_mask & af_transient_mask).sum() / (ref_transient_mask | af_transient_mask).sum() 51 | }) 52 | except: 53 | item.update({ 54 | 'weak_contacts_iou': np.nan, 55 | 'transient_contacts_iou': np.nan, 56 | }) 57 | sasa_thresh = 0.02 58 | buried_mask = out['crystal_sasa'][0] < sasa_thresh 59 | ref_sa_mask = (out['ref_sa_prob'] > 0.1) & buried_mask 60 | af_sa_mask = (out['af_sa_prob'] > 0.1) & buried_mask 61 | 62 | item.update({ 63 | 'num_sasa': ref_sa_mask.sum(), 64 | 'sasa_iou': (ref_sa_mask & af_sa_mask).sum() / (ref_sa_mask | af_sa_mask).sum(), 65 | }) 66 | item.update(correlations(out['ref_mi_mat'].flatten(), out['af_mi_mat'].flatten(), prefix='exposon_mi_')) 67 | 68 | df.append(item) 69 | df = pd.DataFrame(df).set_index('name')#.join(val_df) 70 | all_ref_rmsf = np.concatenate([data[name]['ref_rmsf'] for name in df.index]) 71 | all_af_rmsf = np.concatenate([data[name]['af_rmsf'] for name in df.index]) 72 | return all_ref_rmsf, all_af_rmsf, df, data 73 | 74 | datas = {} 75 | 76 | for path in tqdm.tqdm(paths): 77 | with open(path, 'rb') as f: 78 | data = pickle.load(f) 79 | with warnings.catch_warnings(): 80 | warnings.simplefilter("ignore") 81 | datas[path] = analyze_data(data) 82 | 83 | new_df = [] 84 | for key in datas: 85 | ref_rmsf, af_rmsf, df, data = datas[key] 86 | new_df.append({ 87 | 'path': key, 88 | 'count': len(df), 89 | 'MD pairwise RMSD': df.md_pairwise.median(), 90 | 'Pairwise RMSD': df.af_pairwise.median(), 91 | 'Pairwise RMSD r': scipy.stats.pearsonr(df.md_pairwise, df.af_pairwise)[0], 92 | 'MD RMSF': np.median(ref_rmsf), 93 | 'RMSF': np.median(af_rmsf), 94 | 'Global RMSF r': scipy.stats.pearsonr(ref_rmsf, af_rmsf)[0], 95 | 'Per target RMSF r': df.rmsf_pearson.median(), 96 | 'RMWD': np.sqrt(df.emd_mean**2 + df.emd_var**2).median(), 97 | 'RMWD trans': df.emd_mean.median(), 98 | 'RMWD var': df.emd_var.median(), 99 | 'MD PCA W2': df.refemd.median(), 100 | 'Joint PCA W2': df.jointemd.median(), 101 | 'PC sim > 0.5 %': (df.cosine_sim > 0.5).mean() * 100, 102 | 'Weak contacts J': df.weak_contacts_iou.median(), 103 | 'Weak contacts nans': df.weak_contacts_iou.isna().mean(), 104 | 'Transient contacts J': df.transient_contacts_iou.median(), 105 | 'Transient contacts nans': df.transient_contacts_iou.isna().mean(), 106 | 'Exposed residue J': df.sasa_iou.median(), 107 | 'Exposed MI matrix rho': df.exposon_mi_spearman.median(), 108 | }) 109 | 110 | new_df = pd.DataFrame(new_df).set_index('path') 111 | print(new_df.round(2).T) -------------------------------------------------------------------------------- /scripts/unpack_mmcif.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--mmcif_dir', type=str, required=True) 5 | parser.add_argument('--outdir', type=str, default='./data') 6 | parser.add_argument('--outcsv', type=str, default='./pdb_mmcif.csv') 7 | parser.add_argument('--num_workers', type=int, default=15) 8 | args = parser.parse_args() 9 | 10 | import warnings, tqdm, os, io, logging 11 | import pandas as pd 12 | import numpy as np 13 | from multiprocessing import Pool 14 | from 
alphaflow.data.data_pipeline import DataPipeline 15 | from openfold.data import mmcif_parsing 16 | 17 | pipeline = DataPipeline(template_featurizer=None) 18 | 19 | def main(): 20 | dirs = os.listdir(args.mmcif_dir) 21 | files = [os.listdir(f"{args.mmcif_dir}/{dir}") for dir in dirs] 22 | files = sum(files, []) 23 | if args.num_workers > 1: 24 | p = Pool(args.num_workers) 25 | p.__enter__() 26 | __map__ = p.imap 27 | else: 28 | __map__ = map 29 | infos = list(tqdm.tqdm(__map__(unpack_mmcif, files), total=len(files))) 30 | if args.num_workers > 1: 31 | p.__exit__(None, None, None) 32 | info = [] 33 | for inf in infos: 34 | info.extend(inf) 35 | df = pd.DataFrame(info).set_index('name') 36 | df.to_csv(args.outcsv) 37 | 38 | def unpack_mmcif(name): 39 | path = f"{args.mmcif_dir}/{name[1:3]}/{name}" 40 | 41 | with open(path, 'r') as f: 42 | mmcif_string = f.read() 43 | 44 | 45 | mmcif = mmcif_parsing.parse( 46 | file_id=name[:-4], mmcif_string=mmcif_string 47 | ) 48 | if mmcif.mmcif_object is None: 49 | logging.info(f"Could not parse {name}. Skipping...") 50 | return [] 51 | else: 52 | mmcif = mmcif.mmcif_object 53 | 54 | out = [] 55 | for chain, seq in mmcif.chain_to_seqres.items(): 56 | out.append({ 57 | "name": f"{name[:-4]}_{chain}", 58 | "release_date": mmcif.header["release_date"], 59 | "seqres": seq, 60 | "resolution": mmcif.header["resolution"], 61 | }) 62 | 63 | data = pipeline.process_mmcif(mmcif=mmcif, chain_id=chain) 64 | out_dir = f"{args.outdir}/{name[1:3]}" 65 | os.makedirs(out_dir, exist_ok=True) 66 | out_path = f"{out_dir}/{name[:-4]}_{chain}.npz" 67 | np.savez(out_path, **data) 68 | 69 | 70 | return out 71 | 72 | if __name__ == "__main__": 73 | main() -------------------------------------------------------------------------------- /splits/atlas_val.csv: -------------------------------------------------------------------------------- 1 | name,seqres,release_date,msa_id 2 | 6irx_A,VYWDLDIQTNAVIRERAPADHLPPHPEIELQRAQLTTKLRQHYHELCSQREGIEPPRESFNRWLLERKVVDKGLDPLLPSECDPVISPSMFREIMNDIPIRLSRIKYKEEARKLLFKYAEAAKKMIDSRNVTPESRKVVKWNVEDTMNWLRRDHSASKEDYMDRLENLRKQCGPHVASVAKDSVEGICSKIYHISAEYVRRIRQAHLTLLKECNISVDGTESAEVQDRLVYCYPVRLSIPAPPQTRVELHFENDIACLRFKGEMVKVSRGHFNKLELLYRYSCIDDPRFEKFLSRVWCLIKRYQVMFGSGVNEGSGLQGSLPVPVFEALNKQFGVTFECFASPLNCYFKQFCSAFPDIDGFFGSRGPFLSFSPASGSFEANPPFCEELMDAMVTHFEDLLGRSSEPLSFIIFVPEWRDPPTPALTRMEASRFRRHQMTVPAFEHEYRSGSQHICKREEIYYKAIHGTAVIFLQNNAGFAKWEPTTERIQELLAAYK,2018-12-05,6irx_A 3 | 6cka_B,MLYIDEFKEAIDKGYILGDTVAIVRKNGKIFDYVLPHEKVRDDEVVTVERVEEVMVELDKLEHHHHHH,2018-11-14,6cka_B 4 | 6hj6_A,ETGMPLSCSKNNSHHYIFVGNNSGLELTLTNTSLLNHKFCNLSDAHKRAQYDMALMSIVSTFHLSIPNFNQYEAMSCDFNGDNITVQYNLSHASAVDAANHCGTVANGILETFHKFFWSNNIKDAYQLPNQGKLAHCYSTSYQFLIIQNTTWEDHCTFSRPTGTKHHHHHH,2018-10-10,6hj6_A 5 | 6dgk_B,ASRYLTDMTLEEMSRDWFMLMPKQKVAGSLCIRMDQAIMDKNIILKANFSVIFDRLETLILLRAFTEEGAIVGEISPLPSLPGHTDEDVKNAVGVLIGGLEANDNTVRVSETLQRFAWRS,2018-08-15,6dgk_B 6 | 6fc0_B,GPHMQPSAALQSLRSARFLPGIVQDIYPPGIKSPNPALNEAVQKKGRIFKYDVQFLLQFQNVFTEKPSPDFDQQVKALIGD,2018-06-27,6fc0_B 7 | 6dlm_A,GSPRSYLLKELADLSQHLVRLLERLVRESERVVEVLERGEVDEEELKRLEDLHRELEKAVREVRETHREIRERSR,2018-12-19,6dlm_A 8 | 6mdw_A,GPGRPQESLSLVDASWELVDPTPDLQALFVQFNDQFFWGQLEAVEVKWSVRMTLCAGICSYEGKGGMCSIRLSEPLLKLRPRKDLVETLLHEMIHAYLFVTNNDKDREGHGPEFCKHMHRINSLTGANITVYHTFHDEVDEYRRHWWRCNGPCQHRPPYYGYVKRATNREPSAHDYWWAEHQKTCGGTYIKIKE,2019-04-10,6mdw_A 9 | 6bwq_A,TADAVLMIEANLDDQTGEGLGYVMNQLLTAGAYDVFFTPIQMKKDRPATKLTVLGNVNDKDLLTKLILQETTTIGVRYQTWQRTIMQRHFLTVATPYGDVQVKVATYQDIEKKMPEYADCAQLAQQFHIPFRTVYQAALVAVDQLDEEA,2018-06-20,6bwq_A 10 | 
5ydn_A,MNYATVNDLCARYTRTRLDILTRPKTADGQPDDAVAEQALADASAFIDGYLAARFVLPLTVVPSLLKRQCCVVAWFYLNESQPTEQITATYRDTVRWLEQVRDGKTDPG,2018-07-11,5ydn_A 11 | 5w82_E,HMMPSEDYAIWYARATIAALQAAEYRLAMPSASYTAWFTDAVSDKLDKISESLNTLVECVIDKRLAVSVPEPLPVRVENKVQVEVEDEVRVRVENKVDVEVKN,2018-12-26,5w82_E 12 | 6cb7_A,HMDKLRVLYDEFVTISKDNLERETGLSASDVDMDFDLNIFMTLVPVLAAAVCAITPTIEDDKIVTMMKYCSYQSFSFWFLKSGAVVKSVYNKLDYVKKEKFVATFRDMLLNVQTLISLNSMY,2018-12-12,6cb7_A 13 | 5yrv_I,MARVSDYPLANKHPEWVKTATNKTLDDFTLENVLSNKVTAQDMRITPETLRLQASIAKDAGRDRLAMNFERAAELTAVPDDRILEIYNALRPYRSTKEELLAIADDLESRYQAKICAAFVREAATLYVERKKLKGDD,2018-09-19,5yrv_I 14 | 5z51_A,PTLWPQREALKSALQYPALAGPVFDALTVEGFTHPEYAAVRAAIDTAGGTSAGLSGAQWLDMVRQQTTSTVTSALISELGVEAIQVDDDKLPRYIAGVLARLQEVWLGRQIAEVKSKLQRMSPIEQGDEYHALFGDLVAMEAYRRSLLEQASGDDLHHHHHH,2018-11-28,5z51_A 15 | 6a9a_A,MISDSMTVEEIRLHLGLALKEKDFVVDKTGVKTIEIIGASFVADEPFIFGALNDEYIQRELEWYKSKSLFVKDIPGETPKIWQQVASSKGEINSNYGWAIWSEDNYAQYDMCLAELGQNPDSRRGIMIYTRPSMQFDYNKDGMSDFMSTNTVQYLIRDKKINAVVNMRSNDVVFGFRNNYAWQKYVLDKLVSDLNAGDSTRQYKAGSIIWNVGSLHVYSRHFYLVDHWWKTGETHISKKDYVGKYA,2019-01-02,6a9a_A 16 | 6atg_C,GHMNQRNINELKIFVEKAKYYSIKLDAIYNECTGAYNDIMTYSEGTFSDQSKVNQAISIFKKDNKIVNKFKELEKIIEEYKPMFLSKLIDDFAIELDQAVDNDVSNARHVADSYKKLRKSVVLAYIESFDVISSKFVDSKFVEASKKFVNKAKEFVEENDLIALECIVKTIGDMVNDREINSRSRYNNFYKKEADFLGAAVELEGAYKAIKQTL,2018-09-12,6atg_C 17 | 5w4a_B,MDTNKREIVEFLGIRTYFFPNLALYAVNNDELLVSDPNKANSFAAYVFGASDKKPSVDDIVQILFPSGSDSGTILTSMDTLLALGPDFLTEFKKRNQDLARFNLTHDLSILAQGDEDAAKKKLNLMGRKAKLQKTEAAKILAILIKTINSEENYEKFTELSELCGLDLDFDAYVFTKILGLEDEDTADEVEVIRDNFLNRLDQTKPKLADIIRNGP,2018-06-13,5w4a_B 18 | 5zmo_A,GSREAPKTFHRRVGDVRPARRAMGPALHRPVLLLWAIGQAVARAPRLQPWSTTRDAVAPLMEKYGQVEDGVDGVRYPFWALVRDDLWCVEQAEELTLTSRGRRPTLESLNAVDPSAGLREDDYNLLRSQPEAAASAAAGLIARYFHLLPAGLLEDFGLHELLAGRWPDALRP,2018-09-26,5zmo_A 19 | 6bk4_A,ARKDNLAKTIAETAKIREVLIIQNVLNCFNDDQVRSDFLNGENGAKKLENTELELLEKFFIETQTRRPETADDVSFIATAQKSAELFYSTINARPKSFGEVSFEKLRSLFQQIQDSGYLDKYYL,2018-11-07,6bk4_A 20 | 6eu8_A,MDKKDDVKETIKKSKTIDATKYEMWTYVNLETGQTETHRDFSEWHVMKNGKLLETIPAKGSEADIKIKWHIAIHRFDIRTNEGEAIATKETEFSKVTGLPAGDYKKDVEIKDKMLVGFNMADMMKSKFTVAGMAKVNPVLKTWIVENPMGKAPVLSKSVFVVKFKDGSYAKIKFTDATNDKQEKGHVSFNYEFQPK,2018-10-24,6eu8_A 21 | 6mbg_A,SNATGERFHPLLGRRVLDPAPGARCFTGTVSSDRPAYLGEHWVYDAIVVLGVTYLEMALAAASRLLPGNAADTLIVEDVTVWSPLVLRAGAPARLRLRSEDERFEIHSAEPERSDDESAWTRHATGRIARRQLSPDATGQLPRLDGEAVELDAYYERMRIYYGPRLRNIRHLERRGREAIGHVCLQGEEAQESASYELHPALLDACFQCVFALIYAHESHREPFVPLGCARIELRARGVREVRVHLRLHPPRSTDHNQTHTADLRLFDMEGRLVASVDALQLKRA,2018-09-19,6mbg_A 22 | 6bm5_A,GSKKPRTPPVTLEAARDNDFAFDWQAYTPPVAHRLGVQEVEASIETLRNYIDWTPFFMTWSLAGKYPRILEDEVVGVEAQRLFKDANDMLDKLSAEKTLNPRGVVGLFPANRVGDDIEIYRDETRTHVINVSHHLRQQTEKTGFANYCLADFVAPKLSGKADYIGAFAVTGGLEEDALADAFEAQHDDYNKIMVKALADRLAEAFAEYLHERVRKVYWGYAPNENLSNEELIRENYQGIRPAPGYPACPEHTEKATIWELLEVEKHTGMKLTESFAMWPGASVSGWYFSHPDSKYYAVAQIQRDQVEDYARRKGMSVTEVERWLAPNLGYDAD,2018-05-23,6bm5_A 23 | 6f45_D,MAVQGPWVGSSYVAETGQNWASLAANELRVTERPFWISSFIGRSKEEIWEWTGENHSFNKDWLIGELRNRGGTPVVINIRAHQVSYTPGAPLFEFPGDLPNAYITLNIYADIYGRGGTGGVAYLGGNPGGDCIHNWIGNRLRINNQGWICGGGGGGGGFRVGHTEAGGGGGRPLGAGGVSSLNLNGDNATLGAPGRGYQLGNDYAGNGGDVGNPGSASSAEMGGGAAGRAVVGTSPQWINVGNIAGSWL,2018-08-01,6f45_D 24 | 6fub_B,METGNKYIEKRAIDLSRERDPNFFDNPGIPVPECFWFMFKNNVRQDAGTCYSSWKMDMKVGPNWVHIKSDDNCNLSGDFPPGWIVLGKKRPGF,2018-06-13,6fub_B 25 | 6gfx_C,MEAGRPRPVLRSVNSREPSQVIFCNRSPRVVLPVWLNFDGEPQPYPTLPPGTGRRIHSYRGHLWLFRDAGTHDGLLVNQTELFVPSLNVDGQPIFANITLPVYTLKERCLQVVRSLVKPENYRRLDIVRSLYEDLEDHPNVQKDLERLTQERIAHQRMGD,2018-07-11,6gfx_C 26 | 
6a02_A,MAATNMTDNVTLNNDKISGQAWQAMRDIGMSRFELFNGRTQKAEQLAAQAEKLLNDDSTDWNLYVKSDKKAPVEGDHYIRINSSITVAEDYLPAGQKNDAINKANQKMKEGDKKGTIEALKLAGVSVIENQELIPLQQTRKDVTTALSLMNEGKYYQAGLLLKSAQDGIVVDSQSVQLEHHHHHH,2019-02-06,6a02_A 27 | 6c0h_A,SGMKSAKEPTIYQDVDIIRRIQELMVLCSLLPPDGKLREALELALALHEEPALARITPLTNLHPFATKAWLETLWLGEGVSSEEKELVAWQNKSENMGPAIRELKNAEQQSGITLVARLTS,2018-09-05,6c0h_A 28 | 6dnm_A,SHMVVLGGDRDFWLQVGIDPIQIMTGTATFYTLRCYLDDRPIFLGRNGRISVFGSERALARYLADEHDHDLSDLSTYDDIRTAATDGSLAVAVTDDNVYVLSGLVDDFADGPDAVDREQLDLAVELLRDIGDYSEDSAVDKALETTRPLGQLVAYVLDPHSVGKPTAPYAAAVREWEKLERFVESRLRRE,2019-01-23,6dnm_A 29 | 6as3_A,HHHHHHMEKKLSDAQVALVAAWRKYPDLRESLEEAASILSLIVFQAETLSDQANELANYIRRQGLEEAEGACRNIDIMRAKWVEVCGEVNQYGIRVYGDAIDRDVD,2018-08-29,6as3_A 30 | 6bn0_A,QPYNPCKPQEVIDTKCMGPKDCLYPNPDSCTTYIQCVPLDEVGNAKPVVKPCPKGLQWNDNVGKKWCDYPNLSTCPVKT,2018-08-22,6bn0_A 31 | 5wfy_A,MILKILNEIASIGSTKQKQAILEKNKDNELLKRVYRLTYSRGLQYYIKKWPKPGIATQSFGMLTLTDMLDFIEFTLATRKLTGNAAIEELTGYITDGKKDDVEVLRRVMMRDLECGASVSIANKVWPGLEHHHHHH,2018-09-26,5wfy_A 32 | 6hem_A,GPDFSKHLKEETIQIITKASHEHEDKSPETVLQSAIKLEYARLVKLAQEDTPPETDYRLHHVVVYFIQNQAPKKIIEKTLLEQFGDRNLSFDERCHNIMKVAQAKLEMIKPEEVNLEEYEEWHQDYRKFRETTMYLIIGLENFQRESYIDSLLFLICAYQNNKELLSKGLYRGHDEELISHYRRECLLKLNEQAAELFESGEDREVNNGLIIMNEFIVPFLPLLLVDEMEEKDILAVEDMRNRWCSYLGQEMEPHLQEKLTDFLPKLLDCSMEIKSFHEPPKLPSYSTHELCERFARIMLSLS,2019-03-27,6hem_A 33 | 5z1n_A,GPLSGLKKLIPEEGRELIGSVKKIIKRVSNEEKANEMEKNILKILIKVFFYIDSKAIQIGDLAKVDRALRDGFNHLDRAFRYYGVKKAADLVVILEKASTALKEAEQETVTLLTPFFRPHNIQLIRNTFAFLGSLDFFTKVWDDLEIEDDLFLLISALNKYTQIELIY,2018-10-17,5z1n_A 34 | 5zlq_A,GPMGKIALQLKATLENITNLRPVGEDFRWYLKMKCGNCGEISDKWQYIRLMDSVALKGGRGSASMVQKCKLCARENSIEILSSTIKPYNAEDNENFKTIVEFECRGLEPVDFQPQAGFAAEGVESGTAFSDINLQEKDWTDYDEKAQESVGIYEVTHQFVKC,2018-10-10,5zlq_A 35 | 5naz_A,SVAHGFLITRHSQTTDAPQCPQGTLQVYEGFSLLYVQGNKRAHGQDLGTAGSCLRRFSTMPFMFCNINNVCNFASRNDYSYWLSTPEPMPMSMQPLKGQSIQPFISRCAVCEAPAVVIAVHSQTIQIPHCPQGWDSLWIGYSFMMHTSAGAEGSGQALASPGSCLEEFRSAPFIECHGRGTCNYYANSYSFWLATVDVSDMFSKPQSETLKAGDLRTRISRCQVCMKRT,2018-09-12,5naz_A 36 | 6e33_A,GKVKKRLPQAKRACAKCQKDNKKCDDARPCQRCIKAKTDCIDLPRKKRPTGVRRGPYKKLS,2018-10-24,6e33_A 37 | 5ok6_A,MDREPQHEELPGLDSQWRQIENGESGRERPLRAGESWFLVEKHWYKQWEAYVQGGDQDSSTFPGCINNATLFQDEINWRLKEGLVEGEDYVLLPAAAWHYLVSWYGLEHGQPPIERKVIELPNIQKVEVYPVELLLVRHNDLGKSHTVQFSHTDSIGLVLRTARERFLVEPQEDTRLWAKNSEGSLDRLYDTHITVLDAALETGQLIIMETRKKDGTWPSAQLEHHHHHH,2018-08-08,5ok6_A 38 | 6idx_A,GPGSMPPPSDIVKVAIEWPGANAQLLEIDQKRPLASIIKEVCDGWSLPNPEYYTLRYADGPQLYITEQTRSDIKNGTILQLAISPSRAARQLMERTQSSNMETRLDAMKELAKLSADVTFATEFINMDGIIVLTRLVESGTKLLSHYSEMLAFTLTAFLELMDHGIVSWDMVSITFIKQIAGYVSQPMVDVSILQRSLAILESMVLNSQSLYQKIAEEITVGQLISHLQVSNQEIQTYAIALINALFLKAPEDKRQDMANAFAQKHLRSIILNHVIRGNRPIKTEMAHQLYVLQVLTFNLLEERMMTKMDPNDQAQRDIIFELRRIAFDAESDPSNAPGSGTEKRKAMYTKDYKMLGFTNHINPAMDFTQTPPGMLALDNMLYLAKVHQDTYIRIVLENSSREDKHECPFGRSAIELTKMLCEILQVGELPNEGRNDYHPMFFTHDRAFEELFGICIQLLNKTWKEMRATAEDFNKVMQVVREQITRALPSKPNSLDQFKSKLRSLSYSEILRLRQSERMSQDD,2019-01-23,6idx_A 39 | 6e5y_A,SNAMTTILPNLPTGQKVGIAFSGGLDTSAALLWMRQKGAVPYAYTANLGQPDEPDYDEIPRRAMQYGAEAARLVDCRAQLVAEGIAALQAGAFHISTAGLTYFNTTPIGRAVTGTMLVAAMKEDGVNIWGDGSTFKGNDIERFYRYGLLTNPDLKIYKPWLDQTFIDELGGRAEMSEYMRQAGFDYKMSAEKAYSTDSNMLGATHEAKDLELLSAGIRIVQPIMGVAFWQDSVQIKAEEVTVRFEEGQPVALNGVEYADPVELLLEANRIGGRHGLGMSDQIENRIIEAKSRGIYEAPGLALLFIAYERLVTGIHNEDTIEQYRENGRKLGRLLYQGRWFDPQAIMLRETAQRWVARAITGEVTLELRRGNDYSLLNTESANLTYAPERLSMEKVENAPFTPADRIGQLTMRNLDIVDTREKLFTYVKTGLLAPSAGSALPQIKDGKK,2018-08-01,6e5y_A 40 | 6crk_G,MASNNTASIAQARKLVEQLKMEANIDRIKVSKAAADLMAYCEAHAKEDPLLTPVPASENPFREKKFFSAIL,2018-10-24,6crk_G 41 | 
-------------------------------------------------------------------------------- /splits/pdb_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "Q15758_1": [ 3 | "6gct_A", 4 | "6mp6_A", 5 | "6mpb_A", 6 | "6rvx_A", 7 | "7bcq_A", 8 | "7bcs_A", 9 | "7bct_A" 10 | ], 11 | "P41440_3": [ 12 | "7xpz_A", 13 | "7xq0_A", 14 | "7xq1_A", 15 | "7xq2_A", 16 | "8goe_A", 17 | "8gof_A", 18 | "8hii_A", 19 | "8hij_A", 20 | "8hik_A" 21 | ], 22 | "Q16602_1": [ 23 | "6e3y_R", 24 | "6uun_R", 25 | "6uus_R", 26 | "6uva_R", 27 | "7knt_R", 28 | "7knu_R" 29 | ], 30 | "Q01650_1": [ 31 | "6irs_B", 32 | "6irt_B", 33 | "6jmq_A", 34 | "7dsk_B", 35 | "7dsl_B", 36 | "7dsn_B", 37 | "7dsq_B" 38 | ], 39 | "Q9SVY5_1": [ 40 | "6j5t_B", 41 | "6j5u_B", 42 | "6j5w_B", 43 | "6j6i_B" 44 | ], 45 | "Q8H9R8_1": [ 46 | "7sxk_a", 47 | "7sya_a" 48 | ], 49 | "Q24560_1": [ 50 | "6tis_B", 51 | "6tiu_B", 52 | "6tiy_B", 53 | "6tiz_B", 54 | "7quc_B", 55 | "7qud_B", 56 | "7qup_10B" 57 | ], 58 | "Q8U338_1": [ 59 | "7tr6_C", 60 | "7tr8_C", 61 | "7tr9_C", 62 | "7tra_B" 63 | ], 64 | "O92284_1": [ 65 | "6lgw_E", 66 | "6lgx_A" 67 | ], 68 | "P38786_1": [ 69 | "6agb_I", 70 | "6ah3_I", 71 | "6w6v_I", 72 | "7c79_I", 73 | "7c7a_I" 74 | ], 75 | "O43172_1": [ 76 | "6ahd_K", 77 | "6qw6_4B", 78 | "6qx9_4B" 79 | ], 80 | "B1H241_1": [ 81 | "6nmg_A", 82 | "6nmj_A", 83 | "6tyl_A", 84 | "6ukt_A" 85 | ], 86 | "A0QUE0_1": [ 87 | "6yxu_H", 88 | "6yys_H", 89 | "6z11_H" 90 | ], 91 | "Q18BV5_1": [ 92 | "6t9m_AAA", 93 | "6tsb_AAA" 94 | ], 95 | "G1SLW8_1": [ 96 | "6w2s_8", 97 | "6w2t_8", 98 | "6yam_u" 99 | ], 100 | "P68040_1": [ 101 | "7cpu_Sg", 102 | "7cpv_Sg", 103 | "7ls1_I3", 104 | "7ls2_I3" 105 | ], 106 | "Q9BU64_1": [ 107 | "7pb8_O", 108 | "7pkn_O", 109 | "7r5s_O", 110 | "7xho_O" 111 | ], 112 | "Q02256_1": [ 113 | "6n8m_Y", 114 | "6n8n_Y", 115 | "6n8o_Y", 116 | "6rzz_s", 117 | "6s05_s" 118 | ], 119 | "Q1D6H4_1": [ 120 | "7s20_A", 121 | "7s21_A", 122 | "7s2t_A", 123 | "7s4q_A" 124 | ], 125 | "A0A7J7G2E3_1": [ 126 | "7v4i_A", 127 | "7v4j_A", 128 | "7v4k_A", 129 | "7v4l_A" 130 | ], 131 | "A0A498U580_1": [ 132 | "6xgp_A", 133 | "6xgq_A" 134 | ], 135 | "Q5FVI6_1": [ 136 | "6vq6_G", 137 | "6vq7_G", 138 | "6vq8_G" 139 | ], 140 | "P73953_1": [ 141 | "7cye_A", 142 | "7cyf_A", 143 | "7egk_A", 144 | "7egl_A" 145 | ], 146 | "P46976_2": [ 147 | "7q0b_E", 148 | "7q0s_E", 149 | "7q12_E", 150 | "7q13_E", 151 | "7zbn_E", 152 | "8cvx_E", 153 | "8cvy_E", 154 | "8cvz_E" 155 | ], 156 | "Q02939_1": [ 157 | "7k01_2", 158 | "7ml0_2", 159 | "7ml1_2", 160 | "7ml2_2", 161 | "7ml4_2", 162 | "7o4i_2", 163 | "7o4j_2", 164 | "7o4k_2", 165 | "7o4l_2", 166 | "7o72_2", 167 | "7o73_2", 168 | "7o75_2", 169 | "7zs9_2", 170 | "7zsa_2" 171 | ], 172 | "Q8DL32_1": [ 173 | "6hum_A", 174 | "6khi_A", 175 | "6khj_A", 176 | "6l7o_A", 177 | "6l7p_A", 178 | "6nbx_A", 179 | "6nby_A", 180 | "6tjv_A" 181 | ], 182 | "A8IVW2_1": [ 183 | "7n6g_3M", 184 | "7sqc_1F" 185 | ], 186 | "F2NWD3_1": [ 187 | "6wxw_A", 188 | "6wxx_A", 189 | "6wxy_B", 190 | "6xl1_A" 191 | ], 192 | "P03887_1": [ 193 | "7dgz_1", 194 | "7qsd_H", 195 | "7qsk_H", 196 | "7qsl_H", 197 | "7qsm_H", 198 | "7qsn_H", 199 | "7qso_H" 200 | ], 201 | "Q9VJY9_3": [ 202 | "7w0a_B", 203 | "7w0b_B", 204 | "7w0c_B" 205 | ], 206 | "P53044_1": [ 207 | "8dar_H", 208 | "8das_H", 209 | "8dat_H", 210 | "8dau_H", 211 | "8dav_H", 212 | "8daw_H" 213 | ], 214 | "P82251_1": [ 215 | "6li9_B", 216 | "6lid_B", 217 | "6yup_D", 218 | "6yv1_A" 219 | ], 220 | "Q9KN86_1": [ 221 | "6u2a_A", 222 | "6ue4_A" 223 | ], 224 | "P20429_1": [ 225 | 
"6wvj_A", 226 | "6wvk_A", 227 | "6zfb_U" 228 | ], 229 | "P13866_1": [ 230 | "7sl8_A", 231 | "7sla_A", 232 | "7wmv_A" 233 | ], 234 | "Q5Y388_3": [ 235 | "7fd2_B", 236 | "7wc2_B", 237 | "7wco_K" 238 | ], 239 | "P78371_1": [ 240 | "6qb8_B", 241 | "7nvl_B", 242 | "7nvm_B", 243 | "7nvn_B", 244 | "7nvo_B", 245 | "7trg_E", 246 | "7ttn_E", 247 | "7ttt_E", 248 | "7tub_E", 249 | "7wu7_B" 250 | ], 251 | "P16140_1": [ 252 | "7fde_B", 253 | "7tmm_B", 254 | "7tmo_B", 255 | "7tmp_B", 256 | "7tmq_B", 257 | "7tmr_B" 258 | ], 259 | "Q9H1D9_1": [ 260 | "7a6h_P", 261 | "7ae1_P", 262 | "7ae3_P", 263 | "7aea_P", 264 | "7ast_Z", 265 | "7d58_P", 266 | "7d59_P", 267 | "7dn3_P", 268 | "7du2_P", 269 | "7fji_P", 270 | "7fjj_P" 271 | ], 272 | "P41586_2": [ 273 | "6lpb_R", 274 | "6m1h_A", 275 | "6m1i_A", 276 | "6p9y_R", 277 | "8e3x_R" 278 | ], 279 | "Q83VS8_1": [ 280 | "7e8r_A", 281 | "7edb_A" 282 | ], 283 | "P20042_1": [ 284 | "6ybv_s", 285 | "6zmw_s", 286 | "6zp4_4", 287 | "7a09_4" 288 | ], 289 | "P03910_1": [ 290 | "7dgz_4", 291 | "7qsd_M", 292 | "7qsk_M", 293 | "7qsl_M", 294 | "7qsm_M", 295 | "7qsn_M", 296 | "7qso_M" 297 | ], 298 | "Q9F5Q9_1": [ 299 | "6p7r_A", 300 | "6p7t_A", 301 | "6pb9_A" 302 | ], 303 | "Q8DKY0_1": [ 304 | "6hum_D", 305 | "6khi_D", 306 | "6khj_D", 307 | "6l7o_D", 308 | "6l7p_D", 309 | "6nbq_D", 310 | "6nbx_D", 311 | "6nby_D" 312 | ], 313 | "P0CB42_1": [ 314 | "6ima_A", 315 | "6imc_A", 316 | "6ksf_A", 317 | "6l94_A" 318 | ], 319 | "A0A172JI16_1": [ 320 | "7s00_c", 321 | "7s01_c", 322 | "7um0_c" 323 | ], 324 | "Q7K740_1": [ 325 | "6mb3_E", 326 | "6mhg_E", 327 | "7v05_X" 328 | ], 329 | "F0NH89_1": [ 330 | "7pq2_AAA", 331 | "7pq3_AAA", 332 | "7pq6_AAA", 333 | "7pqa_AAA" 334 | ], 335 | "P47871_3": [ 336 | "6lmk_R", 337 | "6lml_R", 338 | "6whc_R", 339 | "6wpw_R", 340 | "7v35_R" 341 | ], 342 | "Q9UU79_1": [ 343 | "8esq_r", 344 | "8esr_r", 345 | "8etc_r", 346 | "8etg_r", 347 | "8eth_r", 348 | "8eti_r", 349 | "8etj_r", 350 | "8eup_r", 351 | "8euy_r", 352 | "8ev3_r" 353 | ], 354 | "G0S654_1": [ 355 | "6pdw_B", 356 | "6pdy_A", 357 | "6pe0_A" 358 | ], 359 | "Q92847_3": [ 360 | "6ko5_A", 361 | "7f9y_R", 362 | "7f9z_R", 363 | "7na7_R", 364 | "7na8_R", 365 | "7w2z_R" 366 | ], 367 | "P33993_1": [ 368 | "6xtx_7", 369 | "7pfo_7", 370 | "7plo_7" 371 | ], 372 | "Q4GZ99_1": [ 373 | "6hiw_CO", 374 | "6hiy_CO", 375 | "6sga_CO", 376 | "6sgb_CO", 377 | "7pua_CO", 378 | "7pub_CO" 379 | ], 380 | "A0A1D5AKC8_2": [ 381 | "7xxa_B", 382 | "7xxg_B", 383 | "7xxj_B" 384 | ], 385 | "P01871_3": [ 386 | "6kxs_A", 387 | "7k0c_A", 388 | "7qdo_A", 389 | "7xt6_B" 390 | ], 391 | "A0A3A8DDU9_1": [ 392 | "7ecw_A", 393 | "7elm_A", 394 | "7eln_A", 395 | "7eqg_B", 396 | "7we6_A" 397 | ], 398 | "Q8DJD9_1": [ 399 | "6hum_H", 400 | "6khi_H", 401 | "6khj_H", 402 | "6l7o_H", 403 | "6l7p_H", 404 | "6nbq_H", 405 | "6nbx_H", 406 | "6nby_H", 407 | "6tjv_H" 408 | ], 409 | "P53985_1": [ 410 | "6lyy_A", 411 | "6lz0_A", 412 | "7cko_A", 413 | "7ckr_A", 414 | "7da5_A", 415 | "7yr5_A" 416 | ], 417 | "P0A1J1_2": [ 418 | "6k3i_AA", 419 | "6k9q_A", 420 | "7cbm_DA", 421 | "7cgb_DL", 422 | "7cgo_DA", 423 | "7e80_DA", 424 | "7e82_DA" 425 | ], 426 | "Q09E27_1": [ 427 | "6ptq_A", 428 | "6ptx_A", 429 | "6pu2_A", 430 | "7jr5_A", 431 | "7jri_A" 432 | ], 433 | "O43614_3": [ 434 | "7l1u_R", 435 | "7l1v_R", 436 | "7sqo_R", 437 | "7sr8_R", 438 | "7xrr_A" 439 | ], 440 | "O18640_1": [ 441 | "6xu6_Ag", 442 | "6xu8_Ag" 443 | ], 444 | "Q4Q703_1": [ 445 | "7aih_BM", 446 | "7am2_BM", 447 | "7ane_BM" 448 | ], 449 | "P61221_1": [ 450 | "6zme_CI", 451 | "6zvj_1", 452 | "7a09_J" 453 
| ], 454 | "P32893_1": [ 455 | "7e2d_K", 456 | "7e8t_K", 457 | "7u05_C" 458 | ], 459 | "P00960_1": [ 460 | "7eiv_A", 461 | "7qcf_A" 462 | ], 463 | "P15056_1": [ 464 | "6uan_B", 465 | "7mfd_A", 466 | "7mff_A", 467 | "7zr0_K", 468 | "7zr5_K" 469 | ], 470 | "P46957_1": [ 471 | "6p1h_B", 472 | "6v93_F", 473 | "7kc0_B", 474 | "7s0t_F" 475 | ], 476 | "Q9LVM1_1": [ 477 | "7n58_A", 478 | "7n59_A", 479 | "7n5a_A", 480 | "7n5b_A" 481 | ], 482 | "Q12874_1": [ 483 | "6ahd_w", 484 | "6qx9_A3", 485 | "7evo_C", 486 | "7onb_N", 487 | "7q3l_9", 488 | "7q4o_9", 489 | "7q4p_9" 490 | ], 491 | "G1SVB0_1": [ 492 | "7o7y_Ag", 493 | "7o7z_Ag", 494 | "7o81_Ag", 495 | "7oyd_HH", 496 | "7syn_I", 497 | "7syp_I", 498 | "7syq_I", 499 | "7syr_I", 500 | "7sys_I", 501 | "7syv_I", 502 | "7syw_I", 503 | "7syx_I", 504 | "7zjw_SS", 505 | "7zjx_SS" 506 | ], 507 | "O05000_1": [ 508 | "7a23_I", 509 | "7a24_I", 510 | "7aqq_N", 511 | "7ar7_N", 512 | "7ar8_N", 513 | "7arb_N" 514 | ], 515 | "P32325_2": [ 516 | "7p5z_G", 517 | "7pt6_9", 518 | "7pt7_9", 519 | "7v3v_I" 520 | ], 521 | "Q06631_1": [ 522 | "6lqp_RY", 523 | "6lqu_RY", 524 | "6zqb_JK", 525 | "6zqc_JK", 526 | "7suk_5" 527 | ], 528 | "B6VNP3_1": [ 529 | "6j0b_A", 530 | "6j0c_A", 531 | "6j0n_h" 532 | ], 533 | "O34693_1": [ 534 | "7aqc_R", 535 | "7aqd_R", 536 | "7as8_0", 537 | "7as9_0", 538 | "7asa_0", 539 | "7ope_0" 540 | ], 541 | "Q57YI7_1": [ 542 | "6hix_AR", 543 | "6yxx_AR", 544 | "6yxy_AR" 545 | ], 546 | "P41895_2": [ 547 | "7mei_Q", 548 | "7mk9_Q", 549 | "7ml0_Q", 550 | "7ml1_Q", 551 | "7ml2_Q", 552 | "7ml4_Q", 553 | "7o4i_Q", 554 | "7o4j_Q", 555 | "7o72_Q", 556 | "7o73_Q", 557 | "7o75_Q", 558 | "7zs9_Q", 559 | "7zsa_Q" 560 | ], 561 | "P37468_2": [ 562 | "6ifs_A", 563 | "6ift_A", 564 | "6ifw_A", 565 | "7v2l_W", 566 | "7v2m_U", 567 | "7v2n_U", 568 | "7v2o_U", 569 | "7v2p_U", 570 | "7v2q_U" 571 | ], 572 | "A0A1D5AKC8_3": [ 573 | "7xxa_A", 574 | "7xxg_A", 575 | "7xxj_A" 576 | ], 577 | "A1L314_2": [ 578 | "6sb3_A", 579 | "8a1d_A", 580 | "8a1s_A" 581 | ], 582 | "A0A8C7BTF2_1": [ 583 | "7f5r_A", 584 | "7w8s_A", 585 | "7wa1_A", 586 | "7wa3_A" 587 | ], 588 | "P47110_1": [ 589 | "6p1h_C", 590 | "6v93_G", 591 | "7kc0_C", 592 | "7s0t_G" 593 | ], 594 | "Q8DMR6_1": [ 595 | "6hum_B", 596 | "6khi_B", 597 | "6khj_B", 598 | "6l7o_B", 599 | "6l7p_B", 600 | "6nbq_B", 601 | "6nbx_B", 602 | "6nby_B", 603 | "6tjv_B" 604 | ], 605 | "A0QT50_1": [ 606 | "6hb0_A", 607 | "6hbd_A", 608 | "6hbm_A", 609 | "6hyh_A" 610 | ], 611 | "Q9JJ59_1": [ 612 | "7v5c_A", 613 | "7v5d_A", 614 | "7vfi_A" 615 | ], 616 | "P41091_1": [ 617 | "6o81_S", 618 | "6o85_S", 619 | "6ybv_t", 620 | "6zmw_t", 621 | "6zp4_Y", 622 | "7a09_Y", 623 | "7f66_S", 624 | "7f67_S", 625 | "7qp7_t" 626 | ], 627 | "P51787_3": [ 628 | "6uzz_A", 629 | "6v00_A", 630 | "6v01_A", 631 | "7xni_A", 632 | "7xnk_A", 633 | "7xnl_A", 634 | "7xnn_B" 635 | ], 636 | "Q9LT15_1": [ 637 | "6h7d_A", 638 | "7aaq_A", 639 | "7aar_A" 640 | ], 641 | "P32241_2": [ 642 | "6vn7_R", 643 | "8e3y_R", 644 | "8e3z_R" 645 | ], 646 | "P69801_1": [ 647 | "6k1h_B", 648 | "7dyr_B" 649 | ], 650 | "Q38854_1": [ 651 | "7bzx_A", 652 | "7bzy_11" 653 | ], 654 | "P17694_1": [ 655 | "7dgz_B", 656 | "7qsd_D", 657 | "7qsk_D", 658 | "7qsl_D", 659 | "7qsm_D", 660 | "7qsn_D", 661 | "7qso_D" 662 | ], 663 | "Q586R9_1": [ 664 | "6hix_AK", 665 | "6yxx_AK", 666 | "6yxy_AK" 667 | ], 668 | "P18708_3": [ 669 | "6ip2_A", 670 | "6mdo_A", 671 | "6mdp_A" 672 | ], 673 | "Q8N0V3_1": [ 674 | "7pnx_a", 675 | "7pny_a", 676 | "7pnz_a", 677 | "7po0_a", 678 | "8csr_6", 679 | "8css_6", 680 | "8cst_6", 681 | 
"8csu_6" 682 | ], 683 | "P29992_1": [ 684 | "6oij_A", 685 | "7rkf_A", 686 | "7try_A" 687 | ], 688 | "Q03431_2": [ 689 | "6fj3_A", 690 | "6nbf_R", 691 | "6nbh_R", 692 | "6nbi_R", 693 | "7vvj_R", 694 | "7vvk_R", 695 | "7vvl_R", 696 | "7vvm_R", 697 | "7vvn_R", 698 | "8ha0_R", 699 | "8haf_R", 700 | "8hao_I" 701 | ] 702 | } -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from alphaflow.utils.parsing import parse_train_args 2 | args = parse_train_args() 3 | 4 | from alphaflow.utils.logging import get_logger 5 | logger = get_logger(__name__) 6 | import torch, tqdm, os, wandb 7 | import pandas as pd 8 | 9 | from functools import partial 10 | import pytorch_lightning as pl 11 | from pytorch_lightning.callbacks import ModelCheckpoint 12 | from openfold.utils.exponential_moving_average import ExponentialMovingAverage 13 | from alphaflow.model.wrapper import ESMFoldWrapper, AlphaFoldWrapper 14 | from openfold.utils.import_weights import import_jax_weights_ 15 | 16 | torch.set_float32_matmul_precision("high") 17 | from alphaflow.config import model_config 18 | from alphaflow.data.data_modules import OpenFoldSingleDataset, OpenFoldBatchCollator, OpenFoldDataset 19 | from alphaflow.data.inference import CSVDataset, AlphaFoldCSVDataset 20 | 21 | config = model_config( 22 | 'initial_training', 23 | train=True, 24 | low_prec=True 25 | ) 26 | 27 | loss_cfg = config.loss 28 | data_cfg = config.data 29 | data_cfg.common.use_templates = False 30 | data_cfg.common.max_recycling_iters = 0 31 | 32 | def load_clusters(path): 33 | cluster_size = [] 34 | with open(args.pdb_clusters) as f: 35 | for line in f: 36 | names = line.split() 37 | for name in names: 38 | cluster_size.append({'name': name, 'cluster_size': len(names)}) 39 | return pd.DataFrame(cluster_size).set_index('name') 40 | 41 | def main(): 42 | 43 | if args.wandb: 44 | wandb.init( 45 | entity=os.environ["WANDB_ENTITY"], 46 | settings=wandb.Settings(start_method="fork"), 47 | project="alphaflow", 48 | name=args.run_name, 49 | config=args, 50 | ) 51 | 52 | logger.info("Loading the chains dataframe") 53 | pdb_chains = pd.read_csv(args.pdb_chains, index_col='name') 54 | 55 | if args.filter_chains: 56 | clusters = load_clusters(args.pdb_clusters) 57 | pdb_chains = pdb_chains.join(clusters) 58 | pdb_chains = pdb_chains[pdb_chains.release_date < args.train_cutoff] 59 | 60 | trainset = OpenFoldSingleDataset( 61 | data_dir = args.train_data_dir, 62 | alignment_dir = args.train_msa_dir, 63 | pdb_chains = pdb_chains, 64 | config = data_cfg, 65 | mode = 'train', 66 | subsample_pos = args.sample_train_confs, 67 | first_as_template = args.first_as_template, 68 | ) 69 | if args.normal_validate: 70 | val_pdb_chains = pd.read_csv(args.val_csv, index_col='name') 71 | valset = OpenFoldSingleDataset( 72 | data_dir = args.train_data_dir, 73 | alignment_dir = args.train_msa_dir, 74 | pdb_chains = val_pdb_chains, 75 | config = data_cfg, 76 | mode = 'train', 77 | subsample_pos = args.sample_val_confs, 78 | num_confs = args.num_val_confs, 79 | first_as_template = args.first_as_template, 80 | ) 81 | else: 82 | valset = AlphaFoldCSVDataset( 83 | data_cfg, 84 | args.val_csv, 85 | mmcif_dir=args.mmcif_dir, 86 | data_dir=args.train_data_dir, 87 | msa_dir=args.val_msa_dir, 88 | ) 89 | if args.filter_chains: 90 | trainset = OpenFoldDataset([trainset], [1.0], args.train_epoch_len) 91 | 92 | val_loader = torch.utils.data.DataLoader( 93 | valset, 94 | 
batch_size=args.batch_size, 95 | collate_fn=OpenFoldBatchCollator(), 96 | num_workers=args.num_workers, 97 | ) 98 | train_loader = torch.utils.data.DataLoader( 99 | trainset, 100 | batch_size=args.batch_size, 101 | collate_fn=OpenFoldBatchCollator(), 102 | num_workers=args.num_workers, 103 | shuffle=not args.filter_chains, 104 | ) 105 | 106 | 107 | trainer = pl.Trainer( 108 | accelerator="gpu", 109 | max_epochs=args.epochs, 110 | limit_train_batches=args.limit_batches or 1.0, 111 | limit_val_batches=args.limit_batches or 1.0, 112 | num_sanity_val_steps=0, 113 | enable_progress_bar=not args.wandb, 114 | gradient_clip_val=args.grad_clip, 115 | callbacks=[ModelCheckpoint( 116 | dirpath=os.environ["MODEL_DIR"], 117 | save_top_k=-1, 118 | every_n_epochs=args.ckpt_freq, 119 | )], 120 | accumulate_grad_batches=args.accumulate_grad, 121 | check_val_every_n_epoch=args.val_freq, 122 | logger=False, 123 | ) 124 | if args.mode == 'esmfold': 125 | model = ESMFoldWrapper(config, args) 126 | if args.ckpt is None: 127 | logger.info("Loading the model") 128 | path = "esmfold_3B_v1.pt" 129 | model_data = torch.load(path) 130 | model_state = model_data["model"] 131 | model.esmfold.load_state_dict(model_state, strict=False) 132 | logger.info("Model has been loaded") 133 | 134 | if not args.no_ema: 135 | model.ema = ExponentialMovingAverage( 136 | model=model.esmfold, decay=config.ema.decay 137 | ) # need to initialize EMA this way at the beginning 138 | elif args.mode == 'alphafold': 139 | model = AlphaFoldWrapper(config, args) 140 | if args.ckpt is None: 141 | logger.info("Loading the model") 142 | import_jax_weights_(model.model, 'params_model_1.npz', version='model_3') 143 | if not args.no_ema: 144 | model.ema = ExponentialMovingAverage( 145 | model=model.model, decay=config.ema.decay 146 | ) # need to initialize EMA this way at the beginning 147 | 148 | if args.restore_weights_only: 149 | model.load_state_dict(torch.load(args.ckpt, map_location='cpu')['state_dict'], strict=False) 150 | args.ckpt = None 151 | if not args.no_ema: 152 | model.ema = ExponentialMovingAverage( 153 | model=model.model, decay=config.ema.decay 154 | ) # need to initialize EMA this way at the beginning 155 | 156 | 157 | if args.validate: 158 | trainer.validate(model, val_loader, ckpt_path=args.ckpt) 159 | else: 160 | trainer.fit(model, train_loader, val_loader, ckpt_path=args.ckpt) 161 | 162 | if __name__ == "__main__": 163 | main() --------------------------------------------------------------------------------