├── Dockerfile
├── LICENSE
├── README.md
├── alphaflow
│   ├── config.py
│   ├── data
│   │   ├── data_modules.py
│   │   ├── data_pipeline.py
│   │   ├── feature_pipeline.py
│   │   ├── inference.py
│   │   └── input_pipeline.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── alphafold.py
│   │   ├── esmfold.py
│   │   ├── input_stack.py
│   │   ├── layers.py
│   │   ├── tri_self_attn_block.py
│   │   ├── trunk.py
│   │   └── wrapper.py
│   └── utils
│       ├── diffusion.py
│       ├── logging.py
│       ├── loss.py
│       ├── misc.py
│       ├── parsing.py
│       ├── protein.py
│       └── tensor_utils.py
├── assets
│   ├── 12l_md_templates.md
│   └── 6uof_A_animation.gif
├── predict.py
├── scripts
│   ├── add_msa_info.py
│   ├── analyze_ensembles.py
│   ├── cluster_chains.py
│   ├── download_atlas.sh
│   ├── mmseqs_query.py
│   ├── mmseqs_search.py
│   ├── mmseqs_search_helper.py
│   ├── prep_atlas.py
│   ├── print_analysis.py
│   └── unpack_mmcif.py
├── splits
│   ├── atlas.csv
│   ├── atlas_test.csv
│   ├── atlas_train.csv
│   ├── atlas_val.csv
│   ├── cameo2022.csv
│   ├── pdb_test.csv
│   └── pdb_test.json
└── train.py
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Original Copyright 2021 DeepMind Technologies Limited
2 | # Modifications Copyright 2021 AlQuraishi Laboratory
3 | # Modifications Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 | # SPDX-License-Identifier: Apache-2.0
5 |
6 | # This container builds on the OpenFold container and installs AlphaFlow.
7 | # At the end of the build, you may wish to download the weights needed to run ESMFlow, so that they are cached in the image.
8 | # The image is large: about 20 GB without weights, 25 GB with weights.
9 | #
10 | # OpenFold is quite difficult to get working, as it installs custom torch kernels, so it is used as the base.
11 | # Adapted from https://github.com/aws-solutions-library-samples/aws-batch-arch-for-protein-folding/blob/main/infrastructure/docker/openfold/Dockerfile
12 | #
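# To build the image from the repository root (the "alphaflow" tag below is an arbitrary choice):
# docker build -t alphaflow .
#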
13 | # To run the most recent image (after building) with GPUs, mounting a directory `outputs`:
14 | # docker run --gpus all -v "$(pwd)/outputs:/outputs" -it "$(docker image ls -q | head -n1)" bash
15 | #
16 | # Note that you may need to install nvidia-container-toolkit to run.
17 | # https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/1.15.0/install-guide.html
18 | #
19 | # Test command to output into mounted directory:
20 | # python predict.py --mode esmfold --input_csv splits/atlas_test.csv --pdb 6o2v_A --weights params/esmflow_md_base_202402.pt --samples 5 --outpdb /outputs
21 |
22 | FROM nvcr.io/nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu18.04
23 |
24 | RUN apt-key del 7fa2af80
25 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
26 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
27 |
28 | RUN apt-get update \
29 | && DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
30 | wget \
31 | libxml2 \
32 | cuda-minimal-build-11-3 \
33 | libcusparse-dev-11-3 \
34 | libcublas-dev-11-3 \
35 | libcusolver-dev-11-3 \
36 | git \
37 | awscli \
38 | && rm -rf /var/lib/apt/lists/* \
39 | && apt-get autoremove -y \
40 | && apt-get clean
41 |
42 | RUN wget -q -P /tmp -O /tmp/miniconda.sh \
43 | "https://repo.anaconda.com/miniconda/Miniconda3-py39_23.5.2-0-Linux-$(uname -m).sh" \
44 | && bash /tmp/miniconda.sh -b -p /opt/conda \
45 | && rm /tmp/miniconda.sh
46 |
47 | ENV PATH /opt/conda/bin:$PATH
48 |
49 | RUN git clone https://github.com/aqlaboratory/openfold.git /opt/openfold \
50 | && cd /opt/openfold \
51 | && git checkout 1d878a1203e6d662a209a95f71b90083d5fc079c
52 |
53 | # Install into the base environment, since the Docker container won't do anything other than run OpenFold and AlphaFlow
54 | # RUN conda install -qy conda==4.13.0 \
55 | RUN conda env update -n base --file /opt/openfold/environment.yml \
56 | && conda clean --all --force-pkgs-dirs --yes
57 |
58 | RUN wget -q -P /opt/openfold/openfold/resources \
59 | https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
60 |
61 | RUN patch -p0 -d /opt/conda/lib/python3.9/site-packages/ < /opt/openfold/lib/openmm.patch
62 |
63 | # Install OpenFold
64 | RUN cd /opt/openfold \
65 | && pip3 install --upgrade pip --no-cache-dir \
66 | && python3 setup.py install
67 |
68 | # Install alphaflow
69 | RUN git clone https://github.com/bjing2016/alphaflow.git /opt/alphaflow
70 |
71 | # Install AlphaFlow dependencies, as defined in the README
72 | # torch CUDA version should match your machine
73 | RUN python -m pip install numpy==1.21.2 pandas==1.5.3 && \
74 | python -m pip install torch==1.13.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html && \
75 | python -m pip install biopython==1.79 dm-tree==0.1.6 modelcif==0.7 ml-collections==0.1.0 scipy==1.7.3 absl-py einops && \
76 | python -m pip install pytorch_lightning==2.0.4 fair-esm mdtraj wandb
77 |
78 | WORKDIR /opt/alphaflow
79 |
80 | # Optionally, download weights as part of the image, so the cached image contains them and we don't re-download each time
81 | # ESMFlow and ESM2
82 | # RUN mkdir params && \
83 | # aws s3 cp s3://alphaflow/params/esmflow_md_base_202402.pt params/esmflow_pdb_md_202402.pt && \
84 | # mkdir -p /root/.cache/torch/hub/checkpoints && \
85 | # wget -q -O /root/.cache/torch/hub/checkpoints/esm2_t36_3B_UR50D.pt https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t36_3B_UR50D.pt && \
86 | # wget -q -O /root/.cache/torch/hub/checkpoints/esm2_t36_3B_UR50D-contact-regression.pt https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t36_3B_UR50D-contact-regression.pt
87 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Bowen Jing, Bonnie Berger, Tommi Jaakkola
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AlphaFlow
2 |
3 | AlphaFlow is a modified version of AlphaFold, fine-tuned with a flow matching objective, designed for generative modeling of protein conformational ensembles. In particular, AlphaFlow aims to model:
4 | * Experimental ensembles, i.e., potential conformational states as they would be deposited in the PDB
5 | * Molecular dynamics ensembles at physiological temperatures
6 |
7 | We also provide a similarly fine-tuned version of ESMFold called ESMFlow. Technical details and thorough benchmarking results can be found in our paper, [AlphaFold Meets Flow Matching for Generating Protein Ensembles](https://arxiv.org/abs/2402.04845), by Bowen Jing, Bonnie Berger, and Tommi Jaakkola. This repository contains all code, instructions, and model weights necessary to run the method. If you have any questions, feel free to open an issue or reach out at bjing@mit.edu.
8 |
9 | **June 2024 update:** We have trained a 12-layer version of AlphaFlow-MD+Templates (base and distilled) which runs 2.5x faster than the 48-layer version at a [small loss in performance](assets/12l_md_templates.md). We recommend considering this model if reference structures (PDB or AlphaFold) are available and runtime is a high priority.
10 |
11 |
12 |
13 |
14 |
15 | ## Table of Contents
16 | 1. [Installation](#Installation)
17 | 2. [Model weights](#Model-weights)
18 | 3. [Running inference](#Running-inference)
19 | 4. [Evaluation scripts](#Evaluation-scripts)
20 | 5. [Training](#Training)
21 | 6. [Ensembles](#Ensembles)
22 | 7. [License](#License)
23 | 8. [Citation](#Citation)
24 |
25 |
26 | ## Installation
27 | In an environment with Python 3.9 (for example, `conda create -n alphaflow python=3.9`), run:
28 | ```
29 | pip install numpy==1.21.2 pandas==1.5.3
30 | pip install torch==1.12.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html
31 | pip install biopython==1.79 dm-tree==0.1.6 modelcif==0.7 ml-collections==0.1.0 scipy==1.7.1 absl-py einops
32 | pip install pytorch_lightning==2.0.4 fair-esm mdtraj==1.9.9 wandb
33 | pip install 'openfold @ git+https://github.com/aqlaboratory/openfold.git@103d037'
34 | ```
35 | The OpenFold installation requires CUDA 11. If your system provides a different CUDA version, you can install CUDA 11 inside the Conda environment:
36 | ```
37 | conda install nvidia/label/cuda-11.8.0::cuda
38 | conda install nvidia/label/cuda-11.8.0::cuda-cudart-dev
39 | conda install nvidia/label/cuda-11.8.0::libcusparse-dev
40 | conda install nvidia/label/cuda-11.8.0::libcusolver-dev
41 | conda install nvidia/label/cuda-11.8.0::libcublas-dev
42 | ln -s $CONDA_PREFIX/lib/libcudart_static.a $CONDA_PREFIX/lib/libcudart.a
43 | ```
44 | Then install OpenFold:
45 | ```
46 | CUDA_HOME=$CONDA_PREFIX pip install 'openfold @ git+https://github.com/aqlaboratory/openfold.git@103d037'
47 | ```
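
As an optional sanity check (not part of the installation steps above), you can verify that the CUDA-enabled PyTorch build and OpenFold import correctly:
```
python -c "import torch, openfold; print(torch.__version__, torch.cuda.is_available())"
```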
48 |
49 | ## Model weights
50 |
51 | We provide several versions of AlphaFlow (and similarly named versions of ESMFlow).
52 |
53 | * **AlphaFlow-PDB**—trained on PDB structures to model experimental ensembles from X-ray crystallography or cryo-EM under different conditions
54 | * **AlphaFlow-MD**—trained on all-atom, explicit solvent MD trajectories at 300K
55 | * **AlphaFlow-MD+Templates**—trained to additionally take a PDB structure as input and model the corresponding MD ensemble at 300K
56 |
57 | For all models, the **distilled** version runs significantly faster at the cost of some loss of accuracy (benchmarked in the paper).
58 |
59 | For AlphaFlow-MD+Templates, the **12l** versions have 12 instead of 48 Evoformer layers and run 2.5x faster at a [small loss in performance](assets/12l_md_templates.md).
60 |
61 | ### AlphaFlow models
62 | | Model|Version|Weights|
63 | |:---|:--|:--|
64 | | AlphaFlow-PDB | base | https://storage.googleapis.com/alphaflow/params/alphaflow_pdb_base_202402.pt |
65 | | AlphaFlow-PDB | distilled | https://storage.googleapis.com/alphaflow/params/alphaflow_pdb_distilled_202402.pt |
66 | | AlphaFlow-MD | base | https://storage.googleapis.com/alphaflow/params/alphaflow_md_base_202402.pt |
67 | | AlphaFlow-MD | distilled | https://storage.googleapis.com/alphaflow/params/alphaflow_md_distilled_202402.pt |
68 | | AlphaFlow-MD+Templates | base | https://storage.googleapis.com/alphaflow/params/alphaflow_md_templates_base_202402.pt |
69 | | AlphaFlow-MD+Templates | distilled | https://storage.googleapis.com/alphaflow/params/alphaflow_md_templates_distilled_202402.pt |
70 | | AlphaFlow-MD+Templates | 12l-base | https://storage.googleapis.com/alphaflow/params/alphaflow_12l_md_templates_base_202406.pt |
71 | | AlphaFlow-MD+Templates | 12l-distilled | https://storage.googleapis.com/alphaflow/params/alphaflow_12l_md_templates_distilled_202406.pt |
72 |
73 |
74 | ### ESMFlow models
75 | | Model|Version|Weights|
76 | |:---|:--|:--|
77 | | ESMFlow-PDB | base | https://storage.googleapis.com/alphaflow/params/esmflow_pdb_base_202402.pt |
78 | | ESMFlow-PDB | distilled | https://storage.googleapis.com/alphaflow/params/esmflow_pdb_distilled_202402.pt |
79 | | ESMFlow-MD | base | https://storage.googleapis.com/alphaflow/params/esmflow_md_base_202402.pt |
80 | | ESMFlow-MD | distilled | https://storage.googleapis.com/alphaflow/params/esmflow_md_distilled_202402.pt |
81 | | ESMFlow-MD+Templates | base | https://storage.googleapis.com/alphaflow/params/esmflow_md_templates_base_202402.pt |
82 | | ESMFlow-MD+Templates | distilled | https://storage.googleapis.com/alphaflow/params/esmflow_md_templates_distilled_202402.pt |
83 |
84 | Training checkpoints (from which fine-tuning can be resumed) are available upon request; please reach out if you'd like to collaborate!
85 |
86 | ## Running inference
87 |
88 | ### Preparing input files
89 |
90 | 1. Prepare an input CSV with a `name` and `seqres` entry for each row (a minimal example is shown after this list). See `splits/atlas_test.csv` for complete examples.
91 | 2. If running an **AlphaFlow** model, prepare an **MSA directory** and place the alignments in `.a3m` format at the following paths: `{alignment_dir}/{name}/a3m/{name}.a3m`. If you don't have the MSAs, there are two ways to generate them:
92 | 1. Query the ColabFold server with `python -m scripts.mmseqs_query --split [PATH] --outdir [DIR]`.
93 | 2. Download UniRef30 and ColabFoldDB according to https://github.com/sokrypton/ColabFold/blob/main/setup_databases.sh and run `python -m scripts.mmseqs_search_helper --split [PATH] --db_dir [DIR] --outdir [DIR]`.
94 | 3. If running an **MD+Templates** model, place the template PDB files into a templates directory with filenames matching the names in the input CSV. The PDB files should include only a single chain with no residue gaps.
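
A minimal input CSV might look like the following (the name and short sequence are purely illustrative):
```
name,seqres
my_protein,MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ
```
If an MSA was generated under a different name, an optional `msa_id` column can be added to point to the corresponding subdirectory of the MSA directory.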
95 |
96 | ### Running the model
97 |
98 | The basic command for running inference with **AlphaFlow** is:
99 | ```
100 | python predict.py --mode alphafold --input_csv [PATH] --msa_dir [DIR] --weights [PATH] --samples [N] --outpdb [DIR]
101 | ```
102 | If running the **PDB model**, we recommend appending `--self_cond --resample` for improved performance.
103 |
104 | The basic command for running inference with **ESMFlow** is
105 | ```
106 | python predict.py --mode esmfold --input_csv [PATH] --weights [PATH] --samples [N] --outpdb [DIR]
107 | ```
108 | Additional command line arguments for either model (a combined example follows this list):
109 | * Use the `--pdb_id` argument to select (one or more) rows in the CSV. If no argument is specified, inference is run on all rows.
110 | * If running the **MD model with templates**, append `--templates_dir [DIR]`.
111 | * If running any **distilled** model, append the arguments `--noisy_first --no_diffusion`.
112 | * To truncate the inference process for increased precision and reduced diversity, append (for example) `--tmax 0.2 --steps 2`. The default inference settings correspond to `--tmax 1.0 --steps 10`. See Appendix B.1 in the paper for more details.
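
Putting these together, a full command for the distilled AlphaFlow-MD+Templates model might look like the following (directory paths are illustrative):
```
python predict.py --mode alphafold --input_csv splits/atlas_test.csv --msa_dir ./msa_dir \
    --templates_dir ./templates_dir --weights params/alphaflow_md_templates_distilled_202402.pt \
    --samples 10 --noisy_first --no_diffusion --outpdb ./outputs
```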
113 |
114 | ## Evaluation scripts
115 |
116 | Our ensemble evaluations may be reproduced via the following steps (a short sketch of steps 3-4 follows the list):
117 | 1. Download the ATLAS dataset by running `bash scripts/download_atlas.sh` from the desired root directory.
118 | 2. Prepare the ensemble directory with a PDB file for each ATLAS target, each containing 250 structures (see the zipped AlphaFlow ensembles below for examples). **Results from evaluations with a different number of structures are not directly comparable.**
119 | 3. Run `python -m scripts.analyze_ensembles --atlas_dir [DIR] --pdb_dir [DIR] --num_workers [N]`. This will produce an analysis file named `out.pkl` in the `pdb_dir`.
120 | 4. Run `python -m scripts.print_analysis [PATH] [PATH] ...` with an arbitrary number of paths to `out.pkl` files. A formatted comparison table will be printed.
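
As a sketch of steps 3 and 4 with hypothetical directory names:
```
python -m scripts.analyze_ensembles --atlas_dir ./atlas_data --pdb_dir ./alphaflow_pdb_base_ensembles --num_workers 8
python -m scripts.print_analysis ./alphaflow_pdb_base_ensembles/out.pkl ./alphaflow_pdb_distilled_ensembles/out.pkl
```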
121 |
122 | ## Training
123 |
124 | ### Downloading datasets
125 |
126 | To download and preprocess the PDB,
127 | 1. Run `aws s3 sync --no-sign-request s3://pdbsnapshots/20230102/pub/pdb/data/structures/divided/mmCIF pdb_mmcif` from the desired directory.
128 | 2. Run `find pdb_mmcif -name '*.gz' | xargs gunzip` to extract the mmCIF files.
129 | 3. From the repository root, run `python -m scripts.unpack_mmcif --mmcif_dir [DIR] --outdir [DIR] --num_workers [N]`. This will preprocess all chains into NPZ files and create a `pdb_mmcif.csv` index.
130 | 4. Download OpenProteinSet with `aws s3 sync --no-sign-request s3://openfold/ openfold` from the desired directory.
131 | 5. Run `python -m scripts.add_msa_info --openfold_dir [DIR]` to produce a `pdb_mmcif_msa.csv` index with OpenProteinSet MSA lookup.
132 | 6. Run `python -m scripts.cluster_chains` to produce a `pdb_clusters` file at 40% sequence similarity (MMseqs2 installation required).
133 | 7. Create MSAs for the PDB validation split (`splits/cameo2022.csv`) according to the instructions in the previous section.
134 |
135 | To download and preprocess the ATLAS MD trajectory dataset,
136 | 1. Run `bash scripts/download_atlas.sh` from the desired directory.
137 | 2. From the repository root, run `python -m scripts.prep_atlas --atlas_dir [DIR] --outdir [DIR] --num_workers [N]`. This will preprocess the ATLAS trajectories into NPZ files.
138 | 3. Create MSAs for all entries in `splits/atlas.csv` according to the instructions in the previous section.
139 |
140 | ### Running training
141 |
142 | Before running training, download the pretrained AlphaFold and ESMFold weights into the repository root via
143 | ```
144 | wget https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar
145 | tar -xvf alphafold_params_2022-12-06.tar params_model_1.npz
146 | wget https://dl.fbaipublicfiles.com/fair-esm/models/esmfold_3B_v1.pt
147 | ```
148 |
149 | The basic command for training AlphaFlow is
150 | ```
151 | python train.py --lr 5e-4 --noise_prob 0.8 --accumulate_grad 8 --train_epoch_len 80000 --train_cutoff 2018-05-01 --filter_chains \
152 | --train_data_dir [DIR] \
153 | --train_msa_dir [DIR] \
154 | --mmcif_dir [DIR] \
155 | --val_msa_dir [DIR] \
156 | --run_name [NAME] [--wandb]
157 | ```
158 | where the PDB NPZ directory, the OpenProteinSet directory, the PDB mmCIF directory, and the validation MSA directory are specified. This training run produces the AlphaFlow-PDB base version. All other models are built off this checkpoint.
159 |
160 | To continue training on ATLAS, run
161 | ```
162 | python train.py --normal_validate --sample_train_confs --sample_val_confs --num_val_confs 100 --pdb_chains splits/atlas_train.csv --val_csv splits/atlas_val.csv --self_cond_prob 0.0 --noise_prob 0.9 --val_freq 10 --ckpt_freq 10 \
163 | --train_data_dir [DIR] \
164 | --train_msa_dir [DIR] \
165 | --ckpt [PATH] \
166 | --run_name [NAME] [--wandb]
167 | ```
168 | where the ATLAS MSA and NPZ directories and the AlphaFlow-PDB checkpoint are specified.
169 |
170 | To instead train on ATLAS with templates, run with the additional arguments `--first_as_template --extra_input --lr 1e-4 --restore_weights_only --extra_input_prob 1.0`.
171 |
172 | **Distillation**: to distill a model, append `--distillation` and supply the `--ckpt [PATH]` of the model to be distilled. For PDB training, we remove `--accumulate_grad 8` and recommend distilling with a shorter `--train_epoch_len 16000`. Note that `--self_cond_prob` and `--noise_prob` will be ignored and can be omitted.
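
For example, a PDB distillation run might look like the following (a sketch only; the directories are the same as in the base PDB training command above):
```
python train.py --distillation --lr 5e-4 --train_epoch_len 16000 --train_cutoff 2018-05-01 --filter_chains \
    --ckpt [PATH] \
    --train_data_dir [DIR] \
    --train_msa_dir [DIR] \
    --mmcif_dir [DIR] \
    --val_msa_dir [DIR] \
    --run_name [NAME] [--wandb]
```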
173 |
174 | **ESMFlow**: run the same commands with `--mode esmfold` and `--train_cutoff 2020-05-01`.
175 |
176 | ## Ensembles
177 |
178 | We provide the ensembles sampled from each model that were used for the analyses and results reported in the paper.
179 |
180 | ### AlphaFlow ensembles
181 | | Model|Version|Samples|
182 | |:---|:--|:--|
183 | | AlphaFlow-PDB | base | https://storage.googleapis.com/alphaflow/samples/alphaflow_pdb_base_202402.zip |
184 | | AlphaFlow-PDB | distilled | https://storage.googleapis.com/alphaflow/samples/alphaflow_pdb_distilled_202402.zip |
185 | | AlphaFlow-MD | base | https://storage.googleapis.com/alphaflow/samples/alphaflow_md_base_202402.zip |
186 | | AlphaFlow-MD | distilled | https://storage.googleapis.com/alphaflow/samples/alphaflow_md_distilled_202402.zip |
187 | | AlphaFlow-MD+Templates | base | https://storage.googleapis.com/alphaflow/samples/alphaflow_md_templates_base_202402.zip |
188 | | AlphaFlow-MD+Templates | distilled | https://storage.googleapis.com/alphaflow/samples/alphaflow_md_templates_distilled_202402.zip |
189 | | AlphaFlow-MD+Templates | 12l-base | https://storage.googleapis.com/alphaflow/samples/alphaflow_12l_md_templates_base_202406.zip |
190 | | AlphaFlow-MD+Templates | 12l-distilled | https://storage.googleapis.com/alphaflow/samples/alphaflow_12l_md_templates_distilled_202406.zip |
191 |
192 |
193 | ### ESMFlow ensembles
194 | | Model|Version|Samples|
195 | |:---|:--|:--|
196 | | ESMFlow-PDB | base | https://storage.googleapis.com/alphaflow/samples/esmflow_pdb_base_202402.zip |
197 | | ESMFlow-PDB | distilled | https://storage.googleapis.com/alphaflow/samples/esmflow_pdb_distilled_202402.zip |
198 | | ESMFlow-MD | base | https://storage.googleapis.com/alphaflow/samples/esmflow_md_base_202402.zip |
199 | | ESMFlow-MD | distilled | https://storage.googleapis.com/alphaflow/samples/esmflow_md_distilled_202402.zip |
200 | | ESMFlow-MD+Templates | base | https://storage.googleapis.com/alphaflow/samples/esmflow_md_templates_base_202402.zip |
201 | | ESMFlow-MD+Templates | distilled | https://storage.googleapis.com/alphaflow/samples/esmflow_md_templates_distilled_202402.zip |
202 |
203 | ## License
204 | MIT. Other licenses may apply to third-party source code, as noted in file headers.
205 |
206 | ## Citation
207 | ```
208 | @inproceedings{jing2024alphafold,
209 | title={AlphaFold Meets Flow Matching for Generating Protein Ensembles},
210 | author={Jing, Bowen and Berger, Bonnie and Jaakkola, Tommi},
211 | year={2024},
212 | booktitle={Forty-first International Conference on Machine Learning}
213 | }
214 | ```
215 |
--------------------------------------------------------------------------------
/alphaflow/data/feature_pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | # Copyright 2021 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import copy
17 | from typing import Mapping, Tuple, List, Dict, Sequence
18 |
19 | import ml_collections
20 | import numpy as np
21 | import torch
22 |
23 | from . import input_pipeline
24 |
25 |
26 | FeatureDict = Mapping[str, np.ndarray]
27 | TensorDict = Dict[str, torch.Tensor]
28 |
29 |
30 | def np_to_tensor_dict(
31 | np_example: Mapping[str, np.ndarray],
32 | features: Sequence[str],
33 | ) -> TensorDict:
34 | """Creates dict of tensors from a dict of NumPy arrays.
35 |
36 | Args:
37 | np_example: A dict of NumPy feature arrays.
38 | features: A list of strings of feature names to be returned in the dataset.
39 |
40 | Returns:
41 | A dictionary of features mapping feature names to features. Only the given
42 | features are returned, all other ones are filtered out.
43 | """
44 | tensor_dict = {
45 | k: torch.tensor(v) for k, v in np_example.items() if k in features
46 | }
47 |
48 | return tensor_dict
49 |
50 |
51 | def make_data_config(
52 | config: ml_collections.ConfigDict,
53 | mode: str,
54 | num_res: int,
55 | ) -> Tuple[ml_collections.ConfigDict, List[str]]:
56 | cfg = copy.deepcopy(config)
57 | mode_cfg = cfg[mode]
58 | with cfg.unlocked():
59 | if mode_cfg.crop_size is None:
60 | mode_cfg.crop_size = num_res
61 |
62 | feature_names = cfg.common.unsupervised_features
63 |
64 | if cfg.common.use_templates:
65 | feature_names += cfg.common.template_features
66 |
67 | if cfg[mode].supervised:
68 | feature_names += cfg.supervised.supervised_features
69 |
70 | return cfg, feature_names
71 |
72 |
73 | def np_example_to_features(
74 | np_example: FeatureDict,
75 | config: ml_collections.ConfigDict,
76 | mode: str,
77 | ):
78 | np_example = dict(np_example)
79 | num_res = int(np_example["seq_length"][0])
80 | cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
81 |
82 | if "deletion_matrix_int" in np_example:
83 | np_example["deletion_matrix"] = np_example.pop(
84 | "deletion_matrix_int"
85 | ).astype(np.float32)
86 |
87 | tensor_dict = np_to_tensor_dict(
88 | np_example=np_example, features=feature_names
89 | )
90 | with torch.no_grad():
91 | features = input_pipeline.process_tensors_from_config(
92 | tensor_dict,
93 | cfg.common,
94 | cfg[mode],
95 | )
96 |
97 | if mode == "train":
98 | p = torch.rand(1).item()
99 | use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
100 | features["use_clamped_fape"] = torch.full(
101 | size=[cfg.common.max_recycling_iters + 1],
102 | fill_value=use_clamped_fape_value,
103 | dtype=torch.float32,
104 | )
105 | else:
106 | features["use_clamped_fape"] = torch.full(
107 | size=[cfg.common.max_recycling_iters + 1],
108 | fill_value=0.0,
109 | dtype=torch.float32,
110 | )
111 |
112 | return {k: v for k, v in features.items()}
113 |
114 |
115 | class FeaturePipeline:
116 | def __init__(
117 | self,
118 | config: ml_collections.ConfigDict,
119 | ):
120 | self.config = config
121 |
122 | def process_features(
123 | self,
124 | raw_features: FeatureDict,
125 | mode: str = "train",
126 | ) -> FeatureDict:
127 | return np_example_to_features(
128 | np_example=raw_features,
129 | config=self.config,
130 | mode=mode,
131 | )
132 |
--------------------------------------------------------------------------------
/alphaflow/data/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 | import numpy as np
4 | from openfold.np import residue_constants
5 | from .data_pipeline import DataPipeline
6 | from .feature_pipeline import FeaturePipeline
7 | from openfold.data.data_transforms import make_atom14_masks
8 | import alphaflow.utils.protein as protein
9 |
10 | def seq_to_tensor(seq):
11 | unk_idx = residue_constants.restype_order_with_x["X"]
12 | encoded = torch.tensor(
13 | [residue_constants.restype_order_with_x.get(aa, unk_idx) for aa in seq]
14 | )
15 | return encoded
16 |
17 | class AlphaFoldCSVDataset:
18 | def __init__(self, config, path, mmcif_dir=None, msa_dir=None, templates_dir=None):
19 | super().__init__()
20 | self.pdb_chains = pd.read_csv(path, index_col='name')
21 | self.msa_dir = msa_dir
22 | self.mmcif_dir = mmcif_dir
23 | self.data_pipeline = DataPipeline(template_featurizer=None)
24 | self.feature_pipeline = FeaturePipeline(config)
25 | self.templates_dir = templates_dir
26 |
27 | def __len__(self):
28 | return len(self.pdb_chains)
29 |
30 | def __getitem__(self, idx):
31 |
32 | item = self.pdb_chains.iloc[idx]
33 |
34 | mmcif_feats = self.data_pipeline.process_str(item.seqres, item.name)
35 | if self.templates_dir:
36 | path = f"{self.templates_dir}/{item.name}.pdb"
37 | with open(path) as f:
38 | prot = protein.from_pdb_string(f.read())
39 | extra_all_atom_positions = prot.atom_positions.astype(np.float32)
40 |
41 |
42 | try: msa_id = item.msa_id
43 | except: msa_id = item.name
44 | msa_features = self.data_pipeline._process_msa_feats(f'{self.msa_dir}/{msa_id}', item.seqres, alignment_index=None)
45 | data = {**mmcif_feats, **msa_features}
46 |
47 | feats = self.feature_pipeline.process_features(data, mode='predict')
48 | if self.templates_dir:
49 | feats['extra_all_atom_positions'] = torch.from_numpy(extra_all_atom_positions)
50 | feats['pseudo_beta_mask'] = torch.ones(len(item.seqres))
51 | feats['name'] = item.name
52 | feats['seqres'] = item.seqres
53 | make_atom14_masks(feats)
54 |
55 | if self.mmcif_dir is not None:
56 | pdb_id, chain = item.name.split('_')
57 | with open(f"{self.mmcif_dir}/{pdb_id[1:3]}/{pdb_id}.cif") as f:
58 | feats['ref_prot'] = protein.from_mmcif_string(f.read(), chain, name=item.name)
59 |
60 | return feats
61 |
62 | class CSVDataset:
63 | def __init__(self, config, path, mmcif_dir=None, msa_dir=None, templates_dir=None):
64 | super().__init__()
65 | self.pdb_chains = pd.read_csv(path, index_col='name')
66 | self.templates_dir = templates_dir
67 |
68 | def __len__(self):
69 | return self.pdb_chains.shape[0]
70 |
71 | def __getitem__(self, idx):
72 | row = self.pdb_chains.iloc[idx]
73 | batch = {
74 | 'name': row.name,
75 | 'seqres': row.seqres,
76 | 'aatype': seq_to_tensor(row.seqres),
77 | 'residue_index': torch.arange(len(row.seqres)),
78 | 'pseudo_beta_mask': torch.ones(len(row.seqres)),
79 | 'seq_mask': torch.ones(len(row.seqres))
80 | }
81 | make_atom14_masks(batch)
82 |
83 | if self.templates_dir:
84 | path = f"{self.templates_dir}/{row.name}.pdb"
85 | with open(path) as f:
86 | prot = protein.from_pdb_string(f.read())
87 | extra_all_atom_positions = prot.atom_positions.astype(np.float32)
88 | batch['extra_all_atom_positions'] = torch.from_numpy(extra_all_atom_positions)
89 |
90 | return batch
91 |
--------------------------------------------------------------------------------
/alphaflow/data/input_pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | # Copyright 2021 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | import itertools
18 |
19 | from openfold.data import data_transforms
20 | NUM_RES = "num residues placeholder"
21 | NUM_MSA_SEQ = "msa placeholder"
22 | NUM_EXTRA_SEQ = "extra msa placeholder"
23 | NUM_TEMPLATES = "num templates placeholder"
24 |
25 | @data_transforms.curry1
26 | def random_crop_to_size(
27 | protein,
28 | crop_size,
29 | max_templates,
30 | shape_schema,
31 | subsample_templates=False,
32 | seed=None,
33 | ):
34 | """Crop randomly to `crop_size`, or keep as is if shorter than that."""
35 | # We want each ensemble to be cropped the same way
36 |
37 | g = torch.Generator(device=protein["seq_length"].device)
38 | if seed is not None:
39 | g.manual_seed(seed)
40 |
41 | seq_length = protein["seq_length"]
42 |
43 | if "template_mask" in protein:
44 | num_templates = protein["template_mask"].shape[-1]
45 | else:
46 | num_templates = 0
47 |
48 | # No need to subsample templates if there aren't any
49 | subsample_templates = subsample_templates and num_templates
50 |
51 | num_res_crop_size = min(int(seq_length), crop_size)
52 |
53 | def _randint(lower, upper):
54 | return int(torch.randint(
55 | lower,
56 | upper + 1,
57 | (1,),
58 | device=protein["seq_length"].device,
59 | generator=g,
60 | )[0])
61 |
62 | if subsample_templates:
63 | templates_crop_start = _randint(0, num_templates)
64 | templates_select_indices = torch.randperm(
65 | num_templates, device=protein["seq_length"].device, generator=g
66 | )
67 | else:
68 | templates_crop_start = 0
69 |
70 | num_templates_crop_size = min(
71 | num_templates - templates_crop_start, max_templates
72 | )
73 |
74 | n = seq_length - num_res_crop_size
75 | if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
76 | right_anchor = n
77 | else:
78 | x = _randint(0, n)
79 | right_anchor = n - x
80 | num_res_crop_start = _randint(0, right_anchor)
81 |
82 | for k, v in protein.items():
83 | if k not in shape_schema or (
84 | "template" not in k and NUM_RES not in shape_schema[k]
85 | ):
86 | continue
87 |
88 | # randomly permute the templates before cropping them.
89 | if k.startswith("template") and subsample_templates:
90 | v = v[templates_select_indices]
91 |
92 | slices = []
93 | for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
94 | is_num_res = dim_size == NUM_RES
95 | if i == 0 and k.startswith("template"):
96 | crop_size = num_templates_crop_size
97 | crop_start = templates_crop_start
98 | else:
99 | crop_start = num_res_crop_start if is_num_res else 0
100 | crop_size = num_res_crop_size if is_num_res else dim
101 | slices.append(slice(crop_start, crop_start + crop_size))
102 | protein[k] = v[slices] #### MODIFIED
103 |
104 | protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
105 |
106 | return protein
107 |
108 |
109 |
110 | @data_transforms.curry1
111 | def make_fixed_size(
112 | protein,
113 | shape_schema,
114 | num_res=0,
115 | num_templates=0,
116 | ):
117 |
118 | """Guess at the MSA and sequence dimension to make fixed size."""
119 | pad_size_map = {
120 | NUM_RES: num_res,
121 | NUM_TEMPLATES: num_templates,
122 | }
123 |
124 | for k, v in protein.items():
125 | # Don't transfer this to the accelerator.
126 | if k == "extra_cluster_assignment":
127 | continue
128 | shape = list(v.shape)
129 | schema = shape_schema[k]
130 | msg = "Rank mismatch between shape and shape schema for"
131 | assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}"
132 | pad_size = [
133 | pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
134 | ]
135 |
136 | padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
137 | padding.reverse()
138 | padding = list(itertools.chain(*padding))
139 | if padding:
140 | protein[k] = torch.nn.functional.pad(v, padding)
141 | protein[k] = torch.reshape(protein[k], pad_size)
142 |
143 | return protein
144 |
145 |
146 | def nonensembled_transform_fns(common_cfg, mode_cfg):
147 | """Input pipeline data transformers that are not ensembled."""
148 | transforms = [
149 | data_transforms.cast_to_64bit_ints,
150 | data_transforms.correct_msa_restypes,
151 | data_transforms.squeeze_features,
152 | data_transforms.randomly_replace_msa_with_unknown(0.0),
153 | data_transforms.make_seq_mask,
154 | data_transforms.make_msa_mask,
155 | data_transforms.make_hhblits_profile,
156 | ]
157 | if common_cfg.use_templates:
158 | transforms.extend(
159 | [
160 | data_transforms.fix_templates_aatype,
161 | data_transforms.make_template_mask,
162 | data_transforms.make_pseudo_beta("template_"),
163 | ]
164 | )
165 | if common_cfg.use_template_torsion_angles:
166 | transforms.extend(
167 | [
168 | data_transforms.atom37_to_torsion_angles("template_"),
169 | ]
170 | )
171 |
172 | transforms.extend(
173 | [
174 | data_transforms.make_atom14_masks,
175 | ]
176 | )
177 |
178 | if mode_cfg.supervised:
179 | transforms.extend(
180 | [
181 | data_transforms.make_atom14_positions,
182 | data_transforms.atom37_to_frames,
183 | data_transforms.atom37_to_torsion_angles(""),
184 | data_transforms.make_pseudo_beta(""),
185 | data_transforms.get_backbone_frames,
186 | data_transforms.get_chi_angles,
187 | ]
188 | )
189 |
190 | ###############
191 |
192 | if "max_distillation_msa_clusters" in mode_cfg:
193 | transforms.append(
194 | data_transforms.sample_msa_distillation(
195 | mode_cfg.max_distillation_msa_clusters
196 | )
197 | )
198 |
199 | if common_cfg.reduce_msa_clusters_by_max_templates:
200 | pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
201 | else:
202 | pad_msa_clusters = mode_cfg.max_msa_clusters
203 |
204 | max_msa_clusters = pad_msa_clusters
205 | max_extra_msa = mode_cfg.max_extra_msa
206 |
207 | msa_seed = None
208 | if(not common_cfg.resample_msa_in_recycling):
209 | msa_seed = ensemble_seed
210 |
211 | transforms.append(
212 | data_transforms.sample_msa(
213 | max_msa_clusters,
214 | keep_extra=True,
215 | seed=msa_seed,
216 | )
217 | )
218 |
219 | if "masked_msa" in common_cfg:
220 | # Masked MSA should come *before* MSA clustering so that
221 | # the clustering and full MSA profile do not leak information about
222 | # the masked locations and secret corrupted locations.
223 | transforms.append(
224 | data_transforms.make_masked_msa(
225 | common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
226 | )
227 | )
228 |
229 | if common_cfg.msa_cluster_features:
230 | transforms.append(data_transforms.nearest_neighbor_clusters())
231 | transforms.append(data_transforms.summarize_clusters())
232 |
233 | # Crop after creating the cluster profiles.
234 | if max_extra_msa:
235 | transforms.append(data_transforms.crop_extra_msa(max_extra_msa))
236 | else:
237 | transforms.append(data_transforms.delete_extra_msa)
238 |
239 | transforms.append(data_transforms.make_msa_feat())
240 |
241 | ##################
242 | crop_feats = dict(common_cfg.feat)
243 | # return transforms
244 |
245 | if mode_cfg.fixed_size:
246 | transforms.append(data_transforms.select_feat(list(crop_feats)))
247 |
248 | transforms.append(
249 | data_transforms.random_crop_to_size(
250 | mode_cfg.crop_size,
251 | mode_cfg.max_templates,
252 | crop_feats,
253 | mode_cfg.subsample_templates,
254 | )
255 | # random_crop_to_size(
256 | # mode_cfg.crop_size,
257 | # mode_cfg.max_templates,
258 | # crop_feats,
259 | # mode_cfg.subsample_templates,
260 | # )
261 | )
262 | transforms.append(
263 | data_transforms.make_fixed_size(
264 | crop_feats,
265 | pad_msa_clusters,
266 | mode_cfg.max_extra_msa,
267 | mode_cfg.crop_size,
268 | mode_cfg.max_templates,
269 | )
270 | # make_fixed_size(
271 | # crop_feats,
272 | # mode_cfg.crop_size,
273 | # mode_cfg.max_templates,
274 | # )
275 | )
276 | '''
277 | else:
278 | transforms.append(
279 | data_transforms.crop_templates(mode_cfg.max_templates)
280 | )
281 |
282 | '''
283 | return transforms
284 |
285 |
286 | def process_tensors_from_config(tensors, common_cfg, mode_cfg):
287 | """Based on the config, apply filters and transformations to the data."""
288 |
289 | no_templates = True
290 | if("template_aatype" in tensors):
291 | no_templates = tensors["template_aatype"].shape[0] == 0
292 |
293 |
294 | nonensembled = nonensembled_transform_fns(
295 | common_cfg,
296 | mode_cfg,
297 | )
298 |
299 | tensors = compose(nonensembled)(tensors)
300 |
301 | return tensors
302 |
303 |
304 |
305 | @data_transforms.curry1
306 | def compose(x, fs):
307 | for f in fs:
308 | x = f(x)
309 | return x
310 |
311 |
312 | def map_fn(fun, x):
313 | ensembles = [fun(elem) for elem in x]
314 | features = ensembles[0].keys()
315 | ensembled_dict = {}
316 | for feat in features:
317 | ensembled_dict[feat] = torch.stack(
318 | [dict_i[feat] for dict_i in ensembles], dim=-1
319 | )
320 | return ensembled_dict
321 |
--------------------------------------------------------------------------------
/alphaflow/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bjing2016/alphaflow/02dc03763a016949326c2c741e6e33094f9250fd/alphaflow/model/__init__.py
--------------------------------------------------------------------------------
/alphaflow/model/alphafold.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | # Copyright 2021 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 | import torch.nn as nn
18 |
19 | from openfold.model.embedders import (
20 | InputEmbedder,
21 | RecyclingEmbedder,
22 | ExtraMSAEmbedder,
23 | )
24 | from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
25 | from openfold.model.heads import AuxiliaryHeads
26 | from openfold.model.structure_module import StructureModule
27 |
28 | import openfold.np.residue_constants as residue_constants
29 | from openfold.utils.feats import (
30 | pseudo_beta_fn,
31 | build_extra_msa_feat,
32 | atom14_to_atom37,
33 | )
34 | from openfold.utils.tensor_utils import add
35 | from .input_stack import InputPairStack
36 | from .layers import GaussianFourierProjection
37 | from openfold.model.primitives import Linear
38 |
39 |
40 | class AlphaFold(nn.Module):
41 | """
42 | Alphafold 2.
43 |
44 | Implements Algorithm 2 (but with training).
45 | """
46 |
47 | def __init__(self, config, extra_input=False):
48 | """
49 | Args:
50 | config:
51 | A dict-like config object (like the one in config.py)
52 | """
53 | super(AlphaFold, self).__init__()
54 |
55 | self.globals = config.globals
56 | self.config = config.model
57 | self.template_config = self.config.template
58 | self.extra_msa_config = self.config.extra_msa
59 |
60 | # Main trunk + structure module
61 | self.input_embedder = InputEmbedder(
62 | **self.config["input_embedder"],
63 | )
64 | self.recycling_embedder = RecyclingEmbedder(
65 | **self.config["recycling_embedder"],
66 | )
67 |
68 |
69 | if(self.extra_msa_config.enabled):
70 | self.extra_msa_embedder = ExtraMSAEmbedder(
71 | **self.extra_msa_config["extra_msa_embedder"],
72 | )
73 | self.extra_msa_stack = ExtraMSAStack(
74 | **self.extra_msa_config["extra_msa_stack"],
75 | )
76 |
77 | self.evoformer = EvoformerStack(
78 | **self.config["evoformer_stack"],
79 | )
80 | self.structure_module = StructureModule(
81 | **self.config["structure_module"],
82 | )
83 | self.aux_heads = AuxiliaryHeads(
84 | self.config["heads"],
85 | )
86 |
87 | ################
88 | self.input_pair_embedding = Linear(
89 | self.config.input_pair_embedder.no_bins,
90 | self.config.evoformer_stack.c_z,
91 | init="final",
92 | )
93 | self.input_time_projection = GaussianFourierProjection(
94 | embedding_size=self.config.input_pair_embedder.time_emb_dim
95 | )
96 | self.input_time_embedding = Linear(
97 | self.config.input_pair_embedder.time_emb_dim,
98 | self.config.evoformer_stack.c_z,
99 | init="final",
100 | )
101 | self.input_pair_stack = InputPairStack(**self.config.input_pair_stack)
102 | self.extra_input = extra_input
103 | if extra_input:
104 | self.extra_input_pair_embedding = Linear(
105 | self.config.input_pair_embedder.no_bins,
106 | self.config.evoformer_stack.c_z,
107 | init="final",
108 | )
109 | self.extra_input_pair_stack = InputPairStack(**self.config.input_pair_stack)
110 |
111 | ################
112 |
113 | def _get_input_pair_embeddings(self, dists, mask):
114 |
115 | mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
116 |
117 | lower = torch.linspace(
118 | self.config.input_pair_embedder.min_bin,
119 | self.config.input_pair_embedder.max_bin,
120 | self.config.input_pair_embedder.no_bins,
121 | device=dists.device)
122 | dists = dists.unsqueeze(-1)
123 | inf = self.config.input_pair_embedder.inf
124 | upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
125 | dgram = ((dists > lower) * (dists < upper)).type(dists.dtype)
126 |
127 | inp_z = self.input_pair_embedding(dgram * mask.unsqueeze(-1))
128 | inp_z = self.input_pair_stack(inp_z, mask, chunk_size=None)
129 | return inp_z
130 |
131 | def _get_extra_input_pair_embeddings(self, dists, mask):
132 |
133 | mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
134 |
135 | lower = torch.linspace(
136 | self.config.input_pair_embedder.min_bin,
137 | self.config.input_pair_embedder.max_bin,
138 | self.config.input_pair_embedder.no_bins,
139 | device=dists.device)
140 | dists = dists.unsqueeze(-1)
141 | inf = self.config.input_pair_embedder.inf
142 | upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
143 | dgram = ((dists > lower) * (dists < upper)).type(dists.dtype)
144 |
145 | inp_z = self.extra_input_pair_embedding(dgram * mask.unsqueeze(-1))
146 | inp_z = self.extra_input_pair_stack(inp_z, mask, chunk_size=None)
147 | return inp_z
148 |
149 |
150 | def forward(self, batch, prev_outputs=None):
151 |
152 | feats = batch
153 |
154 | # Primary output dictionary
155 | outputs = {}
156 |
157 | # # This needs to be done manually for DeepSpeed's sake
158 | # dtype = next(self.parameters()).dtype
159 | # for k in feats:
160 | # if(feats[k].dtype == torch.float32):
161 | # feats[k] = feats[k].to(dtype=dtype)
162 |
163 | # Grab some data about the input
164 | batch_dims = feats["target_feat"].shape[:-2]
165 | no_batch_dims = len(batch_dims)
166 | n = feats["target_feat"].shape[-2]
167 | n_seq = feats["msa_feat"].shape[-3]
168 | device = feats["target_feat"].device
169 |
170 | # Controls whether the model uses in-place operations throughout
171 | # The dual condition accounts for activation checkpoints
172 | inplace_safe = not (self.training or torch.is_grad_enabled())
173 |
174 | # Prep some features
175 | seq_mask = feats["seq_mask"]
176 | pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
177 | msa_mask = feats["msa_mask"]
178 |
179 | ## Initialize the MSA and pair representations
180 |
181 | # m: [*, S_c, N, C_m]
182 | # z: [*, N, N, C_z]
183 | m, z = self.input_embedder(
184 | feats["target_feat"],
185 | feats["residue_index"],
186 | feats["msa_feat"],
187 | inplace_safe=inplace_safe,
188 | )
189 | if prev_outputs is None:
190 | m_1_prev = m.new_zeros((*batch_dims, n, self.config.input_embedder.c_m), requires_grad=False)
191 | # [*, N, N, C_z]
192 | z_prev = z.new_zeros((*batch_dims, n, n, self.config.input_embedder.c_z), requires_grad=False)
193 | # [*, N, 3]
194 | x_prev = z.new_zeros((*batch_dims, n, residue_constants.atom_type_num, 3), requires_grad=False)
195 |
196 | else:
197 | m_1_prev, z_prev, x_prev = prev_outputs['m_1_prev'], prev_outputs['z_prev'], prev_outputs['x_prev']
198 |
199 | x_prev = pseudo_beta_fn(
200 | feats["aatype"], x_prev, None
201 | ).to(dtype=z.dtype)
202 |
203 | # m_1_prev_emb: [*, N, C_m]
204 | # z_prev_emb: [*, N, N, C_z]
205 | m_1_prev_emb, z_prev_emb = self.recycling_embedder(
206 | m_1_prev,
207 | z_prev,
208 | x_prev,
209 | inplace_safe=inplace_safe,
210 | )
211 |
212 | # [*, S_c, N, C_m]
213 | m[..., 0, :, :] += m_1_prev_emb
214 |
215 | # [*, N, N, C_z]
216 | z = add(z, z_prev_emb, inplace=inplace_safe)
217 |
218 |
219 | #######################
220 | if 'noised_pseudo_beta_dists' in batch:
221 | inp_z = self._get_input_pair_embeddings(
222 | batch['noised_pseudo_beta_dists'],
223 | batch['pseudo_beta_mask'],
224 | )
225 | inp_z = inp_z + self.input_time_embedding(self.input_time_projection(batch['t']))[:,None,None]
226 |
227 | else: # otherwise DDP complains
228 | B, L = batch['aatype'].shape
229 | inp_z = self._get_input_pair_embeddings(
230 | z.new_zeros(B, L, L),
231 | z.new_zeros(B, L),
232 | )
233 | inp_z = inp_z + self.input_time_embedding(self.input_time_projection(z.new_zeros(B)))[:,None,None]
234 |
235 | z = add(z, inp_z, inplace=inplace_safe)
236 |
237 | #############################
238 | if self.extra_input:
239 | if 'extra_all_atom_positions' in batch:
240 | extra_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['extra_all_atom_positions'], None)
241 | extra_pseudo_beta_dists = torch.sum((extra_pseudo_beta.unsqueeze(-2) - extra_pseudo_beta.unsqueeze(-3)) ** 2, dim=-1)**0.5
242 | extra_inp_z = self._get_extra_input_pair_embeddings(
243 | extra_pseudo_beta_dists,
244 | batch['pseudo_beta_mask'],
245 | )
246 |
247 | else: # otherwise DDP complains
248 | B, L = batch['aatype'].shape
249 | extra_inp_z = self._get_extra_input_pair_embeddings(
250 | z.new_zeros(B, L, L),
251 | z.new_zeros(B, L),
252 | ) * 0.0
253 |
254 | z = add(z, extra_inp_z, inplace=inplace_safe)
255 | ########################
256 |
257 | # Embed extra MSA features + merge with pairwise embeddings
258 | if self.config.extra_msa.enabled:
259 | # [*, S_e, N, C_e]
260 | a = self.extra_msa_embedder(build_extra_msa_feat(feats))
261 |
262 | if(self.globals.offload_inference):
263 | # To allow the extra MSA stack (and later the evoformer) to
264 | # offload its inputs, we remove all references to them here
265 | input_tensors = [a, z]
266 | del a, z
267 |
268 | # [*, N, N, C_z]
269 | z = self.extra_msa_stack._forward_offload(
270 | input_tensors,
271 | msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
272 | chunk_size=self.globals.chunk_size,
273 | use_lma=self.globals.use_lma,
274 | pair_mask=pair_mask.to(dtype=m.dtype),
275 | _mask_trans=self.config._mask_trans,
276 | )
277 |
278 | del input_tensors
279 | else:
280 | # [*, N, N, C_z]
281 |
282 | z = self.extra_msa_stack(
283 | a, z,
284 | msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
285 | chunk_size=self.globals.chunk_size,
286 | use_lma=self.globals.use_lma,
287 | pair_mask=pair_mask.to(dtype=m.dtype),
288 | inplace_safe=inplace_safe,
289 | _mask_trans=self.config._mask_trans,
290 | )
291 |
292 | # Run MSA + pair embeddings through the trunk of the network
293 | # m: [*, S, N, C_m]
294 | # z: [*, N, N, C_z]
295 | # s: [*, N, C_s]
296 | if(self.globals.offload_inference):
297 | input_tensors = [m, z]
298 | del m, z
299 | m, z, s = self.evoformer._forward_offload(
300 | input_tensors,
301 | msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
302 | pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
303 | chunk_size=self.globals.chunk_size,
304 | use_lma=self.globals.use_lma,
305 | _mask_trans=self.config._mask_trans,
306 | )
307 |
308 | del input_tensors
309 | else:
310 | m, z, s = self.evoformer(
311 | m,
312 | z,
313 | msa_mask=msa_mask.to(dtype=m.dtype),
314 | pair_mask=pair_mask.to(dtype=z.dtype),
315 | chunk_size=self.globals.chunk_size,
316 | use_lma=self.globals.use_lma,
317 | use_flash=self.globals.use_flash,
318 | inplace_safe=inplace_safe,
319 | _mask_trans=self.config._mask_trans,
320 | )
321 |
322 | outputs["msa"] = m[..., :n_seq, :, :]
323 | outputs["pair"] = z
324 | outputs["single"] = s
325 |
326 | del z
327 |
328 | # Predict 3D structure
329 | outputs["sm"] = self.structure_module(
330 | outputs,
331 | feats["aatype"],
332 | mask=feats["seq_mask"].to(dtype=s.dtype),
333 | inplace_safe=inplace_safe,
334 | _offload_inference=self.globals.offload_inference,
335 | )
336 | outputs["final_atom_positions"] = atom14_to_atom37(
337 | outputs["sm"]["positions"][-1], feats
338 | )
339 | outputs["final_atom_mask"] = feats["atom37_atom_exists"]
340 | outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
341 |
342 |
343 | outputs.update(self.aux_heads(outputs))
344 |
345 |
346 | # [*, N, C_m]
347 | outputs['m_1_prev'] = m[..., 0, :, :]
348 |
349 | # [*, N, N, C_z]
350 | outputs['z_prev'] = outputs["pair"]
351 |
352 | # [*, N, 3]
353 | outputs['x_prev'] = outputs["final_atom_positions"]
354 |
355 | return outputs
356 |
357 | # def forward(self, batch):
358 | # """
359 | # Args:
360 | # batch:
361 | # Dictionary of arguments outlined in Algorithm 2. Keys must
362 | # include the official names of the features in the
363 | # supplement subsection 1.2.9.
364 |
365 | # The final dimension of each input must have length equal to
366 | # the number of recycling iterations.
367 |
368 | # Features (without the recycling dimension):
369 |
370 | # "aatype" ([*, N_res]):
371 | # Contrary to the supplement, this tensor of residue
372 | # indices is not one-hot.
373 | # "target_feat" ([*, N_res, C_tf])
374 | # One-hot encoding of the target sequence. C_tf is
375 | # config.model.input_embedder.tf_dim.
376 | # "residue_index" ([*, N_res])
377 | # Tensor whose final dimension consists of
378 | # consecutive indices from 0 to N_res.
379 | # "msa_feat" ([*, N_seq, N_res, C_msa])
380 | # MSA features, constructed as in the supplement.
381 | # C_msa is config.model.input_embedder.msa_dim.
382 | # "seq_mask" ([*, N_res])
383 | # 1-D sequence mask
384 | # "msa_mask" ([*, N_seq, N_res])
385 | # MSA mask
386 | # "pair_mask" ([*, N_res, N_res])
387 | # 2-D pair mask
388 | # "extra_msa_mask" ([*, N_extra, N_res])
389 | # Extra MSA mask
390 | # "template_mask" ([*, N_templ])
391 | # Template mask (on the level of templates, not
392 | # residues)
393 | # "template_aatype" ([*, N_templ, N_res])
394 | # Tensor of template residue indices (indices greater
395 | # than 19 are clamped to 20 (Unknown))
396 | # "template_all_atom_positions"
397 | # ([*, N_templ, N_res, 37, 3])
398 | # Template atom coordinates in atom37 format
399 | # "template_all_atom_mask" ([*, N_templ, N_res, 37])
400 | # Template atom coordinate mask
401 | # "template_pseudo_beta" ([*, N_templ, N_res, 3])
402 | # Positions of template carbon "pseudo-beta" atoms
403 | # (i.e. C_beta for all residues but glycine, for
404 | # for which C_alpha is used instead)
405 | # "template_pseudo_beta_mask" ([*, N_templ, N_res])
406 | # Pseudo-beta mask
407 | # """
408 | # # Initialize recycling embeddings
409 |
410 | # m_1_prev, z_prev, x_prev = None, None, None
411 | # prevs = [m_1_prev, z_prev, x_prev]
412 |
413 | # is_grad_enabled = torch.is_grad_enabled()
414 |
415 | # # Main recycling loop
416 | # num_iters = batch["aatype"].shape[-1]
417 | # for cycle_no in range(num_iters):
418 | # # Select the features for the current recycling cycle
419 | # fetch_cur_batch = lambda t: t[..., cycle_no]
420 | # feats = tensor_tree_map(fetch_cur_batch, batch)
421 |
422 | # # Enable grad iff we're training and it's the final recycling layer
423 | # is_final_iter = cycle_no == (num_iters - 1)
424 | # with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
425 | # if is_final_iter:
426 | # # Sidestep AMP bug (PyTorch issue #65766)
427 | # if torch.is_autocast_enabled():
428 | # torch.clear_autocast_cache()
429 |
430 | # # Run the next iteration of the model
431 | # outputs, m_1_prev, z_prev, x_prev = self.iteration(
432 | # feats,
433 | # prevs,
434 | # _recycle=(num_iters > 1)
435 | # )
436 |
437 | # if(not is_final_iter):
438 | # del outputs
439 | # prevs = [m_1_prev, z_prev, x_prev]
440 | # del m_1_prev, z_prev, x_prev
441 |
442 | # # Run auxiliary heads
443 | # outputs.update(self.aux_heads(outputs))
444 |
445 | # return outputs
--------------------------------------------------------------------------------
/alphaflow/model/esmfold.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import typing as T
6 | from functools import partial
7 |
8 | import torch
9 | from torch import nn
10 | from torch.nn import LayerNorm
11 |
12 | import esm
13 | from esm import Alphabet
14 |
15 | from alphaflow.utils.misc import (
16 | categorical_lddt,
17 | batch_encode_sequences,
18 | collate_dense_tensors,
19 | output_to_pdb,
20 | )
21 |
22 | from .trunk import FoldingTrunk
23 | from .layers import GaussianFourierProjection
24 | from .input_stack import InputPairStack
25 |
26 | from openfold.data.data_transforms import make_atom14_masks
27 | from openfold.np import residue_constants
28 | from openfold.model.heads import PerResidueLDDTCaPredictor
29 | from openfold.model.primitives import Linear
30 | from openfold.utils.feats import atom14_to_atom37, pseudo_beta_fn
31 |
32 |
33 | load_fn = esm.pretrained.load_model_and_alphabet
34 | esm_registry = {
35 | "esm2_8M": partial(load_fn, "esm2_t6_8M_UR50D_500K"),
36 | "esm2_8M_270K": esm.pretrained.esm2_t6_8M_UR50D,
37 | "esm2_35M": partial(load_fn, "esm2_t12_35M_UR50D_500K"),
38 | "esm2_35M_270K": esm.pretrained.esm2_t12_35M_UR50D,
39 | "esm2_150M": partial(load_fn, "esm2_t30_150M_UR50D_500K"),
40 | "esm2_150M_270K": partial(load_fn, "esm2_t30_150M_UR50D_270K"),
41 | "esm2_650M": esm.pretrained.esm2_t33_650M_UR50D,
42 | "esm2_650M_270K": partial(load_fn, "esm2_t33_650M_270K_UR50D"),
43 | "esm2_3B": esm.pretrained.esm2_t36_3B_UR50D,
44 | "esm2_3B_270K": partial(load_fn, "esm2_t36_3B_UR50D_500K"),
45 | "esm2_15B": esm.pretrained.esm2_t48_15B_UR50D,
46 | }
47 |
48 |
49 | class ESMFold(nn.Module):
50 | def __init__(self, cfg, extra_input=False):
51 | super().__init__()
52 |
53 | self.cfg = cfg
54 | cfg = self.cfg
55 |
56 | self.distogram_bins = 64
57 | self.esm, self.esm_dict = esm_registry.get(cfg.esm_type)()
58 |
59 | self.esm.requires_grad_(False)
60 | self.esm.half()
61 |
62 | self.esm_feats = self.esm.embed_dim
63 | self.esm_attns = self.esm.num_layers * self.esm.attention_heads
64 | self.register_buffer("af2_to_esm", ESMFold._af2_to_esm(self.esm_dict).float()) # hack to get EMA working
65 | self.esm_s_combine = nn.Parameter(torch.zeros(self.esm.num_layers + 1))
66 |
67 | c_s = cfg.trunk.sequence_state_dim
68 | c_z = cfg.trunk.pairwise_state_dim
69 |
70 | self.esm_s_mlp = nn.Sequential(
71 | LayerNorm(self.esm_feats),
72 | nn.Linear(self.esm_feats, c_s),
73 | nn.ReLU(),
74 | nn.Linear(c_s, c_s),
75 | )
76 | ######################
77 | self.input_pair_embedding = Linear(
78 | cfg.input_pair_embedder.no_bins,
79 | cfg.trunk.pairwise_state_dim,
80 | init="final",
81 | )
82 | self.input_time_projection = GaussianFourierProjection(
83 | embedding_size=cfg.input_pair_embedder.time_emb_dim
84 | )
85 | self.input_time_embedding = Linear(
86 | cfg.input_pair_embedder.time_emb_dim,
87 | cfg.trunk.pairwise_state_dim,
88 | init="final",
89 | )
90 | torch.nn.init.zeros_(self.input_pair_embedding.weight)
91 | torch.nn.init.zeros_(self.input_pair_embedding.bias)
92 | self.input_pair_stack = InputPairStack(**cfg.input_pair_stack)
93 |
94 | self.extra_input = extra_input
95 | if extra_input:
96 | self.extra_input_pair_embedding = Linear(
97 | cfg.input_pair_embedder.no_bins,
98 | cfg.evoformer_stack.c_z,
99 | init="final",
100 | )
101 | self.extra_input_pair_stack = InputPairStack(**cfg.input_pair_stack)
102 |
103 | #######################
104 |
105 | # 0 is padding, N is unknown residues, N + 1 is mask.
106 | self.n_tokens_embed = residue_constants.restype_num + 3
107 | self.pad_idx = 0
108 | self.unk_idx = self.n_tokens_embed - 2
109 | self.mask_idx = self.n_tokens_embed - 1
110 | self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)
111 |
112 | self.trunk = FoldingTrunk(cfg.trunk)
113 |
114 | self.distogram_head = nn.Linear(c_z, self.distogram_bins)
115 | # self.ptm_head = nn.Linear(c_z, self.distogram_bins)
116 | # self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
117 | self.lddt_bins = 50
118 |
119 | self.lddt_head = PerResidueLDDTCaPredictor(
120 | no_bins=self.lddt_bins,
121 | c_in=cfg.trunk.structure_module.c_s,
122 | c_hidden=cfg.lddt_head_hid_dim
123 | )
124 |
125 | @staticmethod
126 | def _af2_to_esm(d: Alphabet):
127 | # Remember that t is shifted from residue_constants by 1 (0 is padding).
128 | esm_reorder = [d.padding_idx] + [
129 | d.get_idx(v) for v in residue_constants.restypes_with_x
130 | ]
131 | return torch.tensor(esm_reorder)
132 |
133 | def _af2_idx_to_esm_idx(self, aa, mask):
134 | aa = (aa + 1).masked_fill(mask != 1, 0)
135 | return self.af2_to_esm.long()[aa]
136 |
137 | def _compute_language_model_representations(
138 | self, esmaa: torch.Tensor
139 |     ) -> T.Tuple[torch.Tensor, T.Optional[torch.Tensor]]:
140 | """Adds bos/eos tokens for the language model, since the structure module doesn't use these."""
141 | batch_size = esmaa.size(0)
142 |
143 | bosi, eosi = self.esm_dict.cls_idx, self.esm_dict.eos_idx
144 | bos = esmaa.new_full((batch_size, 1), bosi)
145 | eos = esmaa.new_full((batch_size, 1), self.esm_dict.padding_idx)
146 | esmaa = torch.cat([bos, esmaa, eos], dim=1)
147 | # Use the first padding index as eos during inference.
148 | esmaa[range(batch_size), (esmaa != 1).sum(1)] = eosi
149 |
150 | res = self.esm(
151 | esmaa,
152 | repr_layers=range(self.esm.num_layers + 1),
153 | need_head_weights=self.cfg.use_esm_attn_map,
154 | )
155 | esm_s = torch.stack(
156 | [v for _, v in sorted(res["representations"].items())], dim=2
157 | )
158 | esm_s = esm_s[:, 1:-1] # B, L, nLayers, C
159 | esm_z = (
160 | res["attentions"].permute(0, 4, 3, 1, 2).flatten(3, 4)[:, 1:-1, 1:-1, :]
161 | if self.cfg.use_esm_attn_map
162 | else None
163 | )
164 | return esm_s, esm_z
165 |
166 | def _mask_inputs_to_esm(self, esmaa, pattern):
167 | new_esmaa = esmaa.clone()
168 | new_esmaa[pattern == 1] = self.esm_dict.mask_idx
169 | return new_esmaa
170 |
171 | def _get_input_pair_embeddings(self, dists, mask):
172 |
173 | mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
174 |
175 | lower = torch.linspace(
176 | self.cfg.input_pair_embedder.min_bin,
177 | self.cfg.input_pair_embedder.max_bin,
178 | self.cfg.input_pair_embedder.no_bins,
179 | device=dists.device)
180 | dists = dists.unsqueeze(-1)
181 | inf = self.cfg.input_pair_embedder.inf
182 | upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
183 | dgram = ((dists > lower) * (dists < upper)).type(dists.dtype)
184 |
185 | inp_z = self.input_pair_embedding(dgram * mask.unsqueeze(-1))
186 | inp_z = self.input_pair_stack(inp_z, mask, chunk_size=None)
187 | return inp_z
188 |
189 | def _get_extra_input_pair_embeddings(self, dists, mask):
190 |
191 | mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
192 |
193 | lower = torch.linspace(
194 | self.cfg.input_pair_embedder.min_bin,
195 | self.cfg.input_pair_embedder.max_bin,
196 | self.cfg.input_pair_embedder.no_bins,
197 | device=dists.device)
198 | dists = dists.unsqueeze(-1)
199 | inf = self.cfg.input_pair_embedder.inf
200 | upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
201 | dgram = ((dists > lower) * (dists < upper)).type(dists.dtype)
202 |
203 | inp_z = self.extra_input_pair_embedding(dgram * mask.unsqueeze(-1))
204 | inp_z = self.extra_input_pair_stack(inp_z, mask, chunk_size=None)
205 | return inp_z
206 |
207 |
208 | def forward(
209 | self,
210 | batch,
211 | prev_outputs=None,
212 | ):
213 | """Runs a forward pass given input tokens. Use `model.infer` to
214 | run inference from a sequence.
215 |
216 | Args:
217 |             batch (dict): Feature dict with at least 'aatype' (indices matching
218 |                 openfold.np.residue_constants.restype_order_with_x), 'seq_mask' (binary tensor with
219 |                 1 for valid and 0 for padded positions), 'residue_index', and 'pseudo_beta_mask'.
220 |                 For flow-matching conditioning it may also contain 'noised_pseudo_beta_dists' and
221 |                 the flow time 't'; when the model is built with `extra_input`, it may additionally
222 |                 contain 'extra_all_atom_positions'.
223 |             prev_outputs (dict, optional): Outputs of a previous pass ('s_s', 's_z', 'sm') used as
224 |                 recycling inputs. If None, the recycling embeddings contribute zeros so that all
225 |                 parameters still participate in the computation graph (needed for DDP).
226 |         """
227 | aa = batch['aatype']
228 | mask = batch['seq_mask']
229 | residx = batch['residue_index']
230 |
231 | # === ESM ===
232 |
233 | esmaa = self._af2_idx_to_esm_idx(aa, mask)
234 | esm_s, _ = self._compute_language_model_representations(esmaa)
235 |
236 | # Convert esm_s to the precision used by the trunk and
237 | # the structure module. These tensors may be a lower precision if, for example,
238 | # we're running the language model in fp16 precision.
239 | esm_s = esm_s.to(self.esm_s_combine.dtype)
240 | esm_s = esm_s.detach()
241 |
242 | # === preprocessing ===
243 | esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)
244 | s_s_0 = self.esm_s_mlp(esm_s)
245 | s_s_0 += self.embedding(aa)
246 | #######################
247 | if 'noised_pseudo_beta_dists' in batch:
248 | inp_z = self._get_input_pair_embeddings(
249 | batch['noised_pseudo_beta_dists'],
250 | batch['pseudo_beta_mask']
251 | )
252 | inp_z = inp_z + self.input_time_embedding(self.input_time_projection(batch['t']))[:,None,None]
253 |         else: # have to run the module, else DDP won't work
254 | B, L = batch['aatype'].shape
255 | inp_z = self._get_input_pair_embeddings(
256 | s_s_0.new_zeros(B, L, L),
257 | batch['pseudo_beta_mask'] * 0.0
258 | )
259 | inp_z = inp_z + self.input_time_embedding(self.input_time_projection(inp_z.new_zeros(B)))[:,None,None]
260 | ##########################
261 | #############################
262 | if self.extra_input:
263 | if 'extra_all_atom_positions' in batch:
264 | extra_pseudo_beta = pseudo_beta_fn(batch['aatype'], batch['extra_all_atom_positions'], None)
265 | extra_pseudo_beta_dists = torch.sum((extra_pseudo_beta.unsqueeze(-2) - extra_pseudo_beta.unsqueeze(-3)) ** 2, dim=-1)**0.5
266 | extra_inp_z = self._get_extra_input_pair_embeddings(
267 | extra_pseudo_beta_dists,
268 | batch['pseudo_beta_mask'],
269 | )
270 |
271 | else: # otherwise DDP complains
272 | B, L = batch['aatype'].shape
273 | extra_inp_z = self._get_extra_input_pair_embeddings(
274 | inp_z.new_zeros(B, L, L),
275 | inp_z.new_zeros(B, L),
276 | ) * 0.0
277 |
278 | inp_z = inp_z + extra_inp_z
279 | ########################
280 |
281 |
282 |
283 | s_z_0 = inp_z
284 | if prev_outputs is not None:
285 | s_s_0 = s_s_0 + self.trunk.recycle_s_norm(prev_outputs['s_s'])
286 | s_z_0 = s_z_0 + self.trunk.recycle_z_norm(prev_outputs['s_z'])
287 | s_z_0 = s_z_0 + self.trunk.recycle_disto(FoldingTrunk.distogram(
288 | prev_outputs['sm']["positions"][-1][:, :, :3],
289 | 3.375,
290 | 21.375,
291 | self.trunk.recycle_bins,
292 | ))
293 |
294 | else:
295 | s_s_0 = s_s_0 + self.trunk.recycle_s_norm(torch.zeros_like(s_s_0)) * 0.0
296 | s_z_0 = s_z_0 + self.trunk.recycle_z_norm(torch.zeros_like(s_z_0)) * 0.0
297 | s_z_0 = s_z_0 + self.trunk.recycle_disto(s_z_0.new_zeros(s_z_0.shape[:-2], dtype=torch.long)) * 0.0
298 |
299 |
300 |
301 | structure: dict = self.trunk(
302 | s_s_0, s_z_0, aa, residx, mask, no_recycles=0 # num_recycles
303 | )
304 | disto_logits = self.distogram_head(structure["s_z"])
305 | disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
306 | structure["distogram_logits"] = disto_logits
307 |
308 | '''
309 | lm_logits = self.lm_head(structure["s_s"])
310 | structure["lm_logits"] = lm_logits
311 | '''
312 |
313 | structure["aatype"] = aa
314 | make_atom14_masks(structure)
315 |
316 | for k in [
317 | "atom14_atom_exists",
318 | "atom37_atom_exists",
319 | ]:
320 | structure[k] *= mask.unsqueeze(-1)
321 | structure["residue_index"] = residx
322 | lddt_head = self.lddt_head(structure['sm']["single"])
323 | structure["lddt_logits"] = lddt_head
324 | plddt = categorical_lddt(lddt_head, bins=self.lddt_bins)
325 | structure["plddt"] = 100 * plddt
326 | # we predict plDDT between 0 and 1, scale to be between 0 and 100.
327 |
328 | '''
329 | ptm_logits = self.ptm_head(structure["s_z"])
330 | seqlen = mask.type(torch.int64).sum(1)
331 | structure["tm_logits"] = ptm_logits
332 | structure["ptm"] = torch.stack(
333 | [
334 | compute_tm(
335 | batch_ptm_logits[None, :sl, :sl],
336 | max_bins=31,
337 | no_bins=self.distogram_bins,
338 | )
339 | for batch_ptm_logits, sl in zip(ptm_logits, seqlen)
340 | ]
341 | )
342 | structure.update(
343 | compute_predicted_aligned_error(
344 | ptm_logits, max_bin=31, no_bins=self.distogram_bins
345 | )
346 | )
347 | '''
348 |
349 | structure["final_atom_positions"] = atom14_to_atom37(structure["sm"]["positions"][-1], batch)
350 | structure["final_affine_tensor"] = structure["sm"]["frames"][-1]
351 | if "name" in batch: structure["name"] = batch["name"]
352 | return structure
353 |
354 | @torch.no_grad()
355 | def infer(
356 | self,
357 | sequences: T.Union[str, T.List[str]],
358 | residx=None,
359 | masking_pattern: T.Optional[torch.Tensor] = None,
360 | num_recycles: T.Optional[int] = None,
361 | residue_index_offset: T.Optional[int] = 512,
362 | chain_linker: T.Optional[str] = "G" * 25,
363 | ):
364 | """Runs a forward pass given input sequences.
365 |
366 | Args:
367 | sequences (Union[str, List[str]]): A list of sequences to make predictions for. Multimers can also be passed in,
368 |                 each chain should be separated by a ':' token (e.g. "SEQ1:SEQ2").
369 | residx (torch.Tensor): Residue indices of amino acids. Will assume contiguous if not provided.
370 | masking_pattern (torch.Tensor): Optional masking to pass to the input. Binary tensor of the same size
371 | as `aa`. Positions with 1 will be masked. ESMFold sometimes produces different samples when
372 | different masks are provided.
373 | num_recycles (int): How many recycle iterations to perform. If None, defaults to training max
374 | recycles (cfg.trunk.max_recycles), which is 4.
375 | residue_index_offset (int): Residue index separation between chains if predicting a multimer. Has no effect on
376 | single chain predictions. Default: 512.
377 | chain_linker (str): Linker to use between chains if predicting a multimer. Has no effect on single chain
378 | predictions. Default: length-25 poly-G ("G" * 25).
379 | """
380 | if isinstance(sequences, str):
381 | sequences = [sequences]
382 |
383 | aatype, mask, _residx, linker_mask, chain_index = batch_encode_sequences(
384 | sequences, residue_index_offset, chain_linker
385 | )
386 |
387 | if residx is None:
388 | residx = _residx
389 | elif not isinstance(residx, torch.Tensor):
390 | residx = collate_dense_tensors(residx)
391 |
392 | aatype, mask, residx, linker_mask = map(
393 | lambda x: x.to(self.device), (aatype, mask, residx, linker_mask)
394 | )
395 |
396 |         # Note: this forward pass takes a feature dict; masking_pattern and num_recycles are unused here.
397 |         output = self.forward({
398 |             'aatype': aatype,
399 |             'seq_mask': mask,
400 |             'residue_index': residx,
401 |             'pseudo_beta_mask': mask,
402 |         })
403 |
404 | output["atom37_atom_exists"] = output[
405 | "atom37_atom_exists"
406 | ] * linker_mask.unsqueeze(2)
407 |
408 | output["mean_plddt"] = (output["plddt"] * output["atom37_atom_exists"]).sum(dim=(1, 2)) / output["atom37_atom_exists"].sum(dim=(1, 2))
409 | output["chain_index"] = chain_index
410 |
411 | return output
412 |
413 | def output_to_pdb(self, output: T.Dict) -> T.List[str]:
414 | """Returns the pbd (file) string from the model given the model output."""
415 | return output_to_pdb(output)
416 |
417 | def infer_pdbs(self, seqs: T.List[str], *args, **kwargs) -> T.List[str]:
418 | """Returns list of pdb (files) strings from the model given a list of input sequences."""
419 | output = self.infer(seqs, *args, **kwargs)
420 | return self.output_to_pdb(output)
421 |
422 | def infer_pdb(self, sequence: str, *args, **kwargs) -> str:
423 | """Returns the pdb (file) string from the model given an input sequence."""
424 | return self.infer_pdbs([sequence], *args, **kwargs)[0]
425 |
426 | def set_chunk_size(self, chunk_size: T.Optional[int]):
428 |         # This parameter makes the axial attention computed in a chunked manner,
429 |         # reducing memory usage from roughly O(L^2) to O(L).
430 |         # It is equivalent to running a for loop over chunks of the dimension we are iterating over,
431 |         # where chunk_size is the length of each chunk, so 128 means processing 128 positions at a time.
432 |         # Setting the value to None restores the default behavior and disables chunking.
432 | self.trunk.set_chunk_size(chunk_size)
433 |
434 | @property
435 | def device(self):
436 | return self.esm_s_combine.device
437 |
--------------------------------------------------------------------------------
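
The conditioning path in `_get_input_pair_embeddings` above turns a (noised) pseudo-beta distance matrix into a one-hot distogram before projecting it into the pair representation. A minimal sketch of that binning step, using illustrative bin settings rather than the shipped `cfg.input_pair_embedder` values:

import torch

def distance_to_distogram(dists, min_bin=3.25, max_bin=20.75, no_bins=39, inf=1e8):
    """One-hot distance binning as in ESMFold._get_input_pair_embeddings (illustrative bins)."""
    lower = torch.linspace(min_bin, max_bin, no_bins, device=dists.device)   # left bin edges
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)          # right bin edges
    dists = dists.unsqueeze(-1)                                              # [B, L, L, 1]
    return ((dists > lower) * (dists < upper)).type(dists.dtype)             # [B, L, L, no_bins]

# Two residues whose pseudo-beta atoms are 5 A apart
d = torch.tensor([[[0.0, 5.0], [5.0, 0.0]]])
dgram = distance_to_distogram(d)
print(dgram.shape)          # torch.Size([1, 2, 2, 39])
print(int(dgram.sum()))     # 2: only the off-diagonal pair falls inside a bin
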
/alphaflow/model/input_stack.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | # Copyright 2021 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from functools import partial
16 | from typing import Optional
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from openfold.model.primitives import LayerNorm
22 | from openfold.model.dropout import (
23 | DropoutRowwise,
24 | DropoutColumnwise,
25 | )
26 | from openfold.model.pair_transition import PairTransition
27 | from openfold.model.triangular_attention import (
28 | TriangleAttentionStartingNode,
29 | TriangleAttentionEndingNode,
30 | )
31 | from openfold.model.triangular_multiplicative_update import (
32 | TriangleMultiplicationOutgoing,
33 | TriangleMultiplicationIncoming,
34 | )
35 | from openfold.utils.checkpointing import checkpoint_blocks
36 | from openfold.utils.chunk_utils import ChunkSizeTuner
37 | from openfold.utils.tensor_utils import add
38 |
39 |
40 | class InputPairStackBlock(nn.Module):
41 | def __init__(
42 | self,
43 | c_t: int,
44 | c_hidden_tri_att: int,
45 | c_hidden_tri_mul: int,
46 | no_heads: int,
47 | pair_transition_n: int,
48 | dropout_rate: float,
49 | inf: float,
50 | **kwargs,
51 | ):
52 | super(InputPairStackBlock, self).__init__()
53 |
54 | self.c_t = c_t
55 | self.c_hidden_tri_att = c_hidden_tri_att
56 | self.c_hidden_tri_mul = c_hidden_tri_mul
57 | self.no_heads = no_heads
58 | self.pair_transition_n = pair_transition_n
59 | self.dropout_rate = dropout_rate
60 | self.inf = inf
61 |
62 | self.dropout_row = DropoutRowwise(self.dropout_rate)
63 | self.dropout_col = DropoutColumnwise(self.dropout_rate)
64 |
65 | self.tri_att_start = TriangleAttentionStartingNode(
66 | self.c_t,
67 | self.c_hidden_tri_att,
68 | self.no_heads,
69 | inf=inf,
70 | )
71 | self.tri_att_end = TriangleAttentionEndingNode(
72 | self.c_t,
73 | self.c_hidden_tri_att,
74 | self.no_heads,
75 | inf=inf,
76 | )
77 |
78 | self.tri_mul_out = TriangleMultiplicationOutgoing(
79 | self.c_t,
80 | self.c_hidden_tri_mul,
81 | )
82 | self.tri_mul_in = TriangleMultiplicationIncoming(
83 | self.c_t,
84 | self.c_hidden_tri_mul,
85 | )
86 |
87 | self.pair_transition = PairTransition(
88 | self.c_t,
89 | self.pair_transition_n,
90 | )
91 |
92 | def forward(self,
93 | z: torch.Tensor,
94 | mask: torch.Tensor,
95 | chunk_size: Optional[int] = None,
96 | use_lma: bool = False,
97 | inplace_safe: bool = False,
98 | _mask_trans: bool = True,
99 | _attn_chunk_size: Optional[int] = None,
100 | ):
101 | if(_attn_chunk_size is None):
102 | _attn_chunk_size = chunk_size
103 |
104 | single = z # single_templates[i]
105 | single_mask = mask # single_templates_masks[i]
106 |
107 | single = add(single,
108 | self.dropout_row(
109 | self.tri_att_start(
110 | single,
111 | chunk_size=_attn_chunk_size,
112 | mask=single_mask,
113 | use_lma=use_lma,
114 | inplace_safe=inplace_safe,
115 | )
116 | ),
117 | inplace_safe,
118 | )
119 |
120 | single = add(single,
121 | self.dropout_col(
122 | self.tri_att_end(
123 | single,
124 | chunk_size=_attn_chunk_size,
125 | mask=single_mask,
126 | use_lma=use_lma,
127 | inplace_safe=inplace_safe,
128 | )
129 | ),
130 | inplace_safe,
131 | )
132 |
133 | tmu_update = self.tri_mul_out(
134 | single,
135 | mask=single_mask,
136 | inplace_safe=inplace_safe,
137 | _add_with_inplace=True,
138 | )
139 | if(not inplace_safe):
140 | single = single + self.dropout_row(tmu_update)
141 | else:
142 | single = tmu_update
143 |
144 | del tmu_update
145 |
146 | tmu_update = self.tri_mul_in(
147 | single,
148 | mask=single_mask,
149 | inplace_safe=inplace_safe,
150 | _add_with_inplace=True,
151 | )
152 | if(not inplace_safe):
153 | single = single + self.dropout_row(tmu_update)
154 | else:
155 | single = tmu_update
156 |
157 | del tmu_update
158 |
159 | single = add(single,
160 | self.pair_transition(
161 | single,
162 | mask=single_mask if _mask_trans else None,
163 | chunk_size=chunk_size,
164 | ),
165 | inplace_safe,
166 | )
167 |
168 | return single
169 |
170 |
171 | class InputPairStack(nn.Module):
172 | """
173 | Implements Algorithm 16.
174 | """
175 | def __init__(
176 | self,
177 | c_t,
178 | c_hidden_tri_att,
179 | c_hidden_tri_mul,
180 | no_blocks,
181 | no_heads,
182 | pair_transition_n,
183 | dropout_rate,
184 | blocks_per_ckpt,
185 | tune_chunk_size: bool = False,
186 | inf=1e9,
187 | **kwargs,
188 | ):
189 | """
190 | Args:
191 | c_t:
192 | Template embedding channel dimension
193 | c_hidden_tri_att:
194 | Per-head hidden dimension for triangular attention
195 |             c_hidden_tri_mul:
196 | Hidden dimension for triangular multiplication
197 | no_blocks:
198 | Number of blocks in the stack
199 | pair_transition_n:
200 | Scale of pair transition (Alg. 15) hidden dimension
201 | dropout_rate:
202 | Dropout rate used throughout the stack
203 | blocks_per_ckpt:
204 | Number of blocks per activation checkpoint. None disables
205 | activation checkpointing
206 | """
207 | super(InputPairStack, self).__init__()
208 |
209 | self.blocks_per_ckpt = blocks_per_ckpt
210 |
211 | self.blocks = nn.ModuleList()
212 | for _ in range(no_blocks):
213 | block = InputPairStackBlock(
214 | c_t=c_t,
215 | c_hidden_tri_att=c_hidden_tri_att,
216 | c_hidden_tri_mul=c_hidden_tri_mul,
217 | no_heads=no_heads,
218 | pair_transition_n=pair_transition_n,
219 | dropout_rate=dropout_rate,
220 | inf=inf,
221 | )
222 | self.blocks.append(block)
223 |
224 | self.layer_norm = LayerNorm(c_t)
225 |
226 | self.tune_chunk_size = tune_chunk_size
227 | self.chunk_size_tuner = None
228 | if(tune_chunk_size):
229 | self.chunk_size_tuner = ChunkSizeTuner()
230 |
231 | def forward(
232 | self,
233 | t: torch.tensor,
234 | mask: torch.tensor,
235 | chunk_size: int,
236 | use_lma: bool = False,
237 | inplace_safe: bool = False,
238 | _mask_trans: bool = True,
239 | ):
240 | """
241 | Args:
242 | t:
243 | [*, N_templ, N_res, N_res, C_t] template embedding
244 | mask:
245 | [*, N_templ, N_res, N_res] mask
246 | Returns:
247 | [*, N_templ, N_res, N_res, C_t] template embedding update
248 | """
249 |
250 | blocks = [
251 | partial(
252 | b,
253 | mask=mask,
254 | chunk_size=chunk_size,
255 | use_lma=use_lma,
256 | inplace_safe=inplace_safe,
257 | _mask_trans=_mask_trans,
258 | )
259 | for b in self.blocks
260 | ]
261 |
262 | if(chunk_size is not None and self.chunk_size_tuner is not None):
263 | assert(not self.training)
264 | tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
265 | representative_fn=blocks[0],
266 | args=(t.clone(),),
267 | min_chunk_size=chunk_size,
268 | )
269 | blocks = [
270 | partial(b,
271 | chunk_size=tuned_chunk_size,
272 | _attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
273 | ) for b in blocks
274 | ]
275 |
276 | t, = checkpoint_blocks(
277 | blocks=blocks,
278 | args=(t,),
279 | blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
280 | )
281 |
282 | t = self.layer_norm(t)
283 |
284 | return t
--------------------------------------------------------------------------------
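
A minimal usage sketch of `InputPairStack` (requires OpenFold; the hyperparameters below are small illustrative values, not those from `config.py`):

import torch
from alphaflow.model.input_stack import InputPairStack

stack = InputPairStack(
    c_t=32,                # pair embedding channels
    c_hidden_tri_att=16,   # per-head hidden dim of the triangle attention
    c_hidden_tri_mul=32,   # hidden dim of the triangle multiplicative updates
    no_blocks=2,
    no_heads=4,
    pair_transition_n=2,
    dropout_rate=0.25,
    blocks_per_ckpt=1,     # activation checkpointing granularity (only used in training mode)
)
stack.eval()               # inference mode: dropout off, no activation checkpointing

B, L = 1, 16
z = torch.randn(B, L, L, 32)        # pair embedding, as produced by input_pair_embedding
pair_mask = torch.ones(B, L, L)     # [*, L, L] pair mask (seq_mask outer product in esmfold.py)
out = stack(z, pair_mask, chunk_size=None)
print(out.shape)                    # torch.Size([1, 16, 16, 32]) -- same shape as the input
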
/alphaflow/model/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import typing as T
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn.functional as F
10 | from einops import rearrange, repeat
11 | from torch import nn
12 |
13 | class GaussianFourierProjection(nn.Module):
14 | """Gaussian Fourier embeddings for noise levels.
15 | from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/models/layerspp.py#L32
16 | """
17 |
18 | def __init__(self, embedding_size=256, scale=1.0):
19 | super().__init__()
20 | self.W = nn.Parameter(
21 | torch.randn(embedding_size // 2) * scale, requires_grad=False
22 | )
23 |
24 | def forward(self, x):
25 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
26 | emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
27 | return emb
28 |
29 | class Attention(nn.Module):
30 | def __init__(self, embed_dim, num_heads, head_width, gated=False):
31 | super().__init__()
32 | assert embed_dim == num_heads * head_width
33 |
34 | self.embed_dim = embed_dim
35 | self.num_heads = num_heads
36 | self.head_width = head_width
37 |
38 | self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
39 | self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
40 | self.gated = gated
41 | if gated:
42 | self.g_proj = nn.Linear(embed_dim, embed_dim)
43 | torch.nn.init.zeros_(self.g_proj.weight)
44 | torch.nn.init.ones_(self.g_proj.bias)
45 |
46 | self.rescale_factor = self.head_width**-0.5
47 |
48 | torch.nn.init.zeros_(self.o_proj.bias)
49 |
50 | def forward(self, x, mask=None, bias=None, indices=None):
51 | """
52 | Basic self attention with optional mask and external pairwise bias.
53 | To handle sequences of different lengths, use mask.
54 |
55 | Inputs:
56 |             x: batch of input sequences (.. x L x C)
57 | mask: batch of boolean masks where 1=valid, 0=padding position (.. x L_k). optional.
58 | bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads). optional.
59 |
60 | Outputs:
61 | sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads)
62 | """
63 |
64 | t = rearrange(self.proj(x), "... l (h c) -> ... h l c", h=self.num_heads)
65 | q, k, v = t.chunk(3, dim=-1)
66 |
67 | q = self.rescale_factor * q
68 | a = torch.einsum("...qc,...kc->...qk", q, k)
69 |
70 | # Add external attention bias.
71 | if bias is not None:
72 | a = a + rearrange(bias, "... lq lk h -> ... h lq lk")
73 |
74 | # Do not attend to padding tokens.
75 | if mask is not None:
76 | mask = repeat(
77 | mask, "... lk -> ... h lq lk", h=self.num_heads, lq=q.shape[-2]
78 | )
79 | a = a.masked_fill(mask == False, -np.inf)
80 |
81 | a = F.softmax(a, dim=-1)
82 |
83 | y = torch.einsum("...hqk,...hkc->...qhc", a, v)
84 | y = rearrange(y, "... h c -> ... (h c)", h=self.num_heads)
85 |
86 | if self.gated:
87 | y = self.g_proj(x).sigmoid() * y
88 | y = self.o_proj(y)
89 |
90 | return y, rearrange(a, "... lq lk h -> ... h lq lk")
91 |
92 |
93 | class Dropout(nn.Module):
94 | """
95 | Implementation of dropout with the ability to share the dropout mask
96 | along a particular dimension.
97 | """
98 |
99 | def __init__(self, r: float, batch_dim: T.Union[int, T.List[int]]):
100 | super(Dropout, self).__init__()
101 |
102 | self.r = r
103 | if type(batch_dim) == int:
104 | batch_dim = [batch_dim]
105 | self.batch_dim = batch_dim
106 | self.dropout = nn.Dropout(self.r)
107 |
108 | def forward(self, x: torch.Tensor) -> torch.Tensor:
109 | shape = list(x.shape)
110 | if self.batch_dim is not None:
111 | for bd in self.batch_dim:
112 | shape[bd] = 1
113 | return x * self.dropout(x.new_ones(shape))
114 |
115 |
116 | class SequenceToPair(nn.Module):
117 | def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
118 | super().__init__()
119 |
120 | self.layernorm = nn.LayerNorm(sequence_state_dim)
121 | self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
122 | self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)
123 |
124 | torch.nn.init.zeros_(self.proj.bias)
125 | torch.nn.init.zeros_(self.o_proj.bias)
126 |
127 | def forward(self, sequence_state):
128 | """
129 | Inputs:
130 | sequence_state: B x L x sequence_state_dim
131 |
132 | Output:
133 | pairwise_state: B x L x L x pairwise_state_dim
134 |
135 | Intermediate state:
136 | B x L x L x 2*inner_dim
137 | """
138 |
139 | assert len(sequence_state.shape) == 3
140 |
141 | s = self.layernorm(sequence_state)
142 | s = self.proj(s)
143 | q, k = s.chunk(2, dim=-1)
144 |
145 | prod = q[:, None, :, :] * k[:, :, None, :]
146 | diff = q[:, None, :, :] - k[:, :, None, :]
147 |
148 | x = torch.cat([prod, diff], dim=-1)
149 | x = self.o_proj(x)
150 |
151 | return x
152 |
153 |
154 | class PairToSequence(nn.Module):
155 | def __init__(self, pairwise_state_dim, num_heads):
156 | super().__init__()
157 |
158 | self.layernorm = nn.LayerNorm(pairwise_state_dim)
159 | self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)
160 |
161 | def forward(self, pairwise_state):
162 | """
163 | Inputs:
164 | pairwise_state: B x L x L x pairwise_state_dim
165 |
166 | Output:
167 | pairwise_bias: B x L x L x num_heads
168 | """
169 | assert len(pairwise_state.shape) == 4
170 | z = self.layernorm(pairwise_state)
171 | pairwise_bias = self.linear(z)
172 | return pairwise_bias
173 |
174 |
175 | class ResidueMLP(nn.Module):
176 | def __init__(self, embed_dim, inner_dim, norm=nn.LayerNorm, dropout=0):
177 | super().__init__()
178 |
179 | self.mlp = nn.Sequential(
180 | norm(embed_dim),
181 | nn.Linear(embed_dim, inner_dim),
182 | nn.ReLU(),
183 | nn.Linear(inner_dim, embed_dim),
184 | nn.Dropout(dropout),
185 | )
186 |
187 | def forward(self, x):
188 | return x + self.mlp(x)
189 |
--------------------------------------------------------------------------------
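
`GaussianFourierProjection` embeds the scalar flow-matching time with frozen random sin/cos features; a short sketch (the embedding size is illustrative):

import torch
from alphaflow.model.layers import GaussianFourierProjection

proj = GaussianFourierProjection(embedding_size=8)  # W is a frozen random frequency vector of length 4
t = torch.tensor([0.0, 0.5, 1.0])                   # one time value per batch element
emb = proj(t)                                       # concatenation of sin and cos features
print(emb.shape)                                    # torch.Size([3, 8])
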
/alphaflow/model/tri_self_attn_block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import torch
6 | from openfold.model.triangular_attention import (
7 | TriangleAttentionEndingNode,
8 | TriangleAttentionStartingNode,
9 | )
10 | from openfold.model.triangular_multiplicative_update import (
11 | TriangleMultiplicationIncoming,
12 | TriangleMultiplicationOutgoing,
13 | )
14 | from torch import nn
15 |
16 | from .layers import (
17 | Attention,
18 | Dropout,
19 | PairToSequence,
20 | ResidueMLP,
21 | SequenceToPair,
22 | )
23 |
24 |
25 | class TriangularSelfAttentionBlock(nn.Module):
26 | def __init__(
27 | self,
28 | sequence_state_dim,
29 | pairwise_state_dim,
30 | sequence_head_width,
31 | pairwise_head_width,
32 | dropout=0,
33 | **__kwargs,
34 | ):
35 | super().__init__()
36 |
37 | assert sequence_state_dim % sequence_head_width == 0
38 | assert pairwise_state_dim % pairwise_head_width == 0
39 | sequence_num_heads = sequence_state_dim // sequence_head_width
40 | pairwise_num_heads = pairwise_state_dim // pairwise_head_width
41 | assert sequence_state_dim == sequence_num_heads * sequence_head_width
42 | assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width
43 | assert pairwise_state_dim % 2 == 0
44 |
45 | self.sequence_state_dim = sequence_state_dim
46 | self.pairwise_state_dim = pairwise_state_dim
47 |
48 | self.layernorm_1 = nn.LayerNorm(sequence_state_dim)
49 |
50 | self.sequence_to_pair = SequenceToPair(
51 | sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim
52 | )
53 | self.pair_to_sequence = PairToSequence(pairwise_state_dim, sequence_num_heads)
54 |
55 | self.seq_attention = Attention(
56 | sequence_state_dim, sequence_num_heads, sequence_head_width, gated=True
57 | )
58 | self.tri_mul_out = TriangleMultiplicationOutgoing(
59 | pairwise_state_dim,
60 | pairwise_state_dim,
61 | )
62 | self.tri_mul_in = TriangleMultiplicationIncoming(
63 | pairwise_state_dim,
64 | pairwise_state_dim,
65 | )
66 | self.tri_att_start = TriangleAttentionStartingNode(
67 | pairwise_state_dim,
68 | pairwise_head_width,
69 | pairwise_num_heads,
70 | inf=1e9,
71 | ) # type: ignore
72 | self.tri_att_end = TriangleAttentionEndingNode(
73 | pairwise_state_dim,
74 | pairwise_head_width,
75 | pairwise_num_heads,
76 | inf=1e9,
77 | ) # type: ignore
78 |
79 | self.mlp_seq = ResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=dropout)
80 | self.mlp_pair = ResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout)
81 |
82 | assert dropout < 0.4
83 | self.drop = nn.Dropout(dropout)
84 | self.row_drop = Dropout(dropout * 2, 2)
85 | self.col_drop = Dropout(dropout * 2, 1)
86 |
87 | torch.nn.init.zeros_(self.tri_mul_in.linear_z.weight)
88 | torch.nn.init.zeros_(self.tri_mul_in.linear_z.bias)
89 | torch.nn.init.zeros_(self.tri_mul_out.linear_z.weight)
90 | torch.nn.init.zeros_(self.tri_mul_out.linear_z.bias)
91 | torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.weight)
92 | torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.bias)
93 | torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.weight)
94 | torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.bias)
95 |
96 | torch.nn.init.zeros_(self.sequence_to_pair.o_proj.weight)
97 | torch.nn.init.zeros_(self.sequence_to_pair.o_proj.bias)
98 | torch.nn.init.zeros_(self.pair_to_sequence.linear.weight)
99 | torch.nn.init.zeros_(self.seq_attention.o_proj.weight)
100 | torch.nn.init.zeros_(self.seq_attention.o_proj.bias)
101 | torch.nn.init.zeros_(self.mlp_seq.mlp[-2].weight)
102 | torch.nn.init.zeros_(self.mlp_seq.mlp[-2].bias)
103 | torch.nn.init.zeros_(self.mlp_pair.mlp[-2].weight)
104 | torch.nn.init.zeros_(self.mlp_pair.mlp[-2].bias)
105 |
106 | def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
107 | """
108 | Inputs:
109 | sequence_state: B x L x sequence_state_dim
110 | pairwise_state: B x L x L x pairwise_state_dim
111 | mask: B x L boolean tensor of valid positions
112 |
113 | Output:
114 | sequence_state: B x L x sequence_state_dim
115 | pairwise_state: B x L x L x pairwise_state_dim
116 | """
117 | assert len(sequence_state.shape) == 3
118 | assert len(pairwise_state.shape) == 4
119 | if mask is not None:
120 | assert len(mask.shape) == 2
121 |
122 | batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
123 | pairwise_state_dim = pairwise_state.shape[3]
124 | assert sequence_state_dim == self.sequence_state_dim
125 | assert pairwise_state_dim == self.pairwise_state_dim
126 | assert batch_dim == pairwise_state.shape[0]
127 | assert seq_dim == pairwise_state.shape[1]
128 | assert seq_dim == pairwise_state.shape[2]
129 |
130 | # Update sequence state
131 | bias = self.pair_to_sequence(pairwise_state)
132 |
133 | # Self attention with bias + mlp.
134 | y = self.layernorm_1(sequence_state)
135 | y, _ = self.seq_attention(y, mask=mask, bias=bias)
136 | sequence_state = sequence_state + self.drop(y)
137 | sequence_state = self.mlp_seq(sequence_state)
138 |
139 | # Update pairwise state
140 | pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)
141 |
142 | # Axial attention with triangular bias.
143 | tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
144 | pairwise_state = pairwise_state + self.row_drop(
145 | self.tri_mul_out(pairwise_state, mask=tri_mask)
146 | )
147 | pairwise_state = pairwise_state + self.col_drop(
148 | self.tri_mul_in(pairwise_state, mask=tri_mask)
149 | )
150 | pairwise_state = pairwise_state + self.row_drop(
151 | self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
152 | )
153 | pairwise_state = pairwise_state + self.col_drop(
154 | self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
155 | )
156 |
157 | # MLP over pairs.
158 | pairwise_state = self.mlp_pair(pairwise_state)
159 |
160 | return sequence_state, pairwise_state
161 |
--------------------------------------------------------------------------------
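
The constructor above requires each state dimension to factor into `num_heads * head_width`; a minimal instantiation sketch with illustrative (untrained) dimensions:

import torch
from alphaflow.model.tri_self_attn_block import TriangularSelfAttentionBlock

block = TriangularSelfAttentionBlock(
    sequence_state_dim=64,   # 4 heads x 16 head width
    pairwise_state_dim=32,   # 8 heads x 4 head width
    sequence_head_width=16,
    pairwise_head_width=4,
    dropout=0.0,
)

B, L = 1, 8
s = torch.randn(B, L, 64)        # sequence state
z = torch.randn(B, L, L, 32)     # pairwise state
mask = torch.ones(B, L)          # all positions valid
s, z = block(s, z, mask=mask)    # shapes are preserved: [1, 8, 64] and [1, 8, 8, 32]
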
/alphaflow/model/trunk.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import typing as T
6 | from functools import partial
7 | from openfold.utils.checkpointing import checkpoint_blocks
8 | import torch
9 | import torch.nn as nn
10 | from openfold.model.structure_module import StructureModule
11 | from .tri_self_attn_block import TriangularSelfAttentionBlock
12 |
13 | def get_axial_mask(mask):
14 | """
15 | Helper to convert B x L mask of valid positions to axial mask used
16 | in row column attentions.
17 |
18 | Input:
19 | mask: B x L tensor of booleans
20 |
21 | Output:
22 |         mask: (B * L) x L tensor of booleans
23 | """
24 |
25 | if mask is None:
26 | return None
27 | assert len(mask.shape) == 2
28 | batch_dim, seq_dim = mask.shape
29 | m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
30 | m = m.reshape(batch_dim * seq_dim, seq_dim)
31 | return m
32 |
33 |
34 | class RelativePosition(nn.Module):
35 | def __init__(self, bins, pairwise_state_dim):
36 | super().__init__()
37 | self.bins = bins
38 |
39 | # Note an additional offset is used so that the 0th position
40 | # is reserved for masked pairs.
41 | self.embedding = torch.nn.Embedding(2 * bins + 2, pairwise_state_dim)
42 |
43 | def forward(self, residue_index, mask=None):
44 | """
45 | Input:
46 |             residue_index: B x L tensor of indices (dtype=torch.long)
47 | mask: B x L tensor of booleans
48 |
49 | Output:
50 | pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
51 | """
52 |
53 | assert residue_index.dtype == torch.long
54 | if mask is not None:
55 | assert residue_index.shape == mask.shape
56 |
57 | diff = residue_index[:, None, :] - residue_index[:, :, None]
58 | diff = diff.clamp(-self.bins, self.bins)
59 | diff = diff + self.bins + 1 # Add 1 to adjust for padding index.
60 |
61 | if mask is not None:
62 | mask = mask[:, None, :] * mask[:, :, None]
63 | diff[mask == False] = 0
64 |
65 | output = self.embedding(diff)
66 | return output
67 |
68 |
69 | class FoldingTrunk(nn.Module):
70 | def __init__(self, cfg):
71 | super().__init__()
72 | self.cfg = cfg
73 | # assert self.cfg.max_recycles > 0
74 |
75 | c_s = self.cfg.sequence_state_dim
76 | c_z = self.cfg.pairwise_state_dim
77 |
78 | assert c_s % self.cfg.sequence_head_width == 0
79 | assert c_z % self.cfg.pairwise_head_width == 0
80 | block = TriangularSelfAttentionBlock
81 |
82 | self.pairwise_positional_embedding = RelativePosition(self.cfg.position_bins, c_z)
83 |
84 | self.blocks = nn.ModuleList(
85 | [
86 | block(
87 | sequence_state_dim=c_s,
88 | pairwise_state_dim=c_z,
89 | sequence_head_width=self.cfg.sequence_head_width,
90 | pairwise_head_width=self.cfg.pairwise_head_width,
91 | dropout=self.cfg.dropout,
92 | )
93 | for i in range(self.cfg.num_blocks)
94 | ]
95 | )
96 |
97 | self.recycle_bins = 15
98 | self.recycle_s_norm = nn.LayerNorm(c_s)
99 | self.recycle_z_norm = nn.LayerNorm(c_z)
100 | self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
101 | self.recycle_disto.weight[0].detach().zero_()
102 |
103 |
104 | self.structure_module = StructureModule(**self.cfg.structure_module) # type: ignore
105 | self.trunk2sm_s = nn.Linear(c_s, self.structure_module.c_s)
106 | self.trunk2sm_z = nn.Linear(c_z, self.structure_module.c_z)
107 |
108 | self.chunk_size = self.cfg.chunk_size
109 |
110 | def set_chunk_size(self, chunk_size):
111 |         # This parameter makes the axial attention computed in a chunked manner,
112 |         # reducing memory usage from roughly O(L^2) to O(L).
113 |         # It is equivalent to running a for loop over chunks of the dimension we are iterating over,
114 |         # where chunk_size is the length of each chunk, so 128 means processing 128 positions at a time.
115 | self.chunk_size = chunk_size
116 |
117 |
118 | def _prep_blocks(self, mask, residue_index, chunk_size):
119 | blocks = [
120 | partial(
121 | b,
122 | mask=mask,
123 | residue_index=residue_index,
124 | chunk_size=chunk_size,
125 | )
126 | for b in self.blocks
127 | ]
128 | return blocks
129 | def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles: T.Optional[int] = None):
130 | """
131 | Inputs:
132 | seq_feats: B x L x C tensor of sequence features
133 | pair_feats: B x L x L x C tensor of pair features
134 | residx: B x L long tensor giving the position in the sequence
135 | mask: B x L boolean tensor indicating valid residues
136 |
137 | Output:
138 |             structure: dict holding the structure module outputs under "sm" plus the final trunk states "s_s" and "s_z"
139 | """
140 |
141 | device = seq_feats.device
142 | s_s_0 = seq_feats
143 | s_z_0 = pair_feats
144 |
145 |
146 | def trunk_iter(s, z, residx, mask):
147 | z = z + self.pairwise_positional_embedding(residx, mask=mask)
148 | blocks = self._prep_blocks(mask=mask, residue_index=residx, chunk_size=self.chunk_size)
149 |
150 | s, z = checkpoint_blocks(
151 | blocks,
152 | args=(s, z),
153 | blocks_per_ckpt=1,
154 | )
155 | return s, z
156 |
157 |
158 | s_s, s_z = trunk_iter(s_s_0, s_z_0, residx, mask)
159 |
160 | # === Structure module ===
161 | structure = {}
162 |
163 | structure["sm"] = self.structure_module(
164 | {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
165 | true_aa,
166 | mask.float(),
167 | )
168 |
169 |
170 |
171 | assert isinstance(structure, dict) # type: ignore
172 | structure["s_s"] = s_s
173 | structure["s_z"] = s_z
174 |
175 | return structure
176 |
177 | @staticmethod
178 | def distogram(coords, min_bin, max_bin, num_bins):
179 | # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates.
180 | boundaries = torch.linspace(
181 | min_bin,
182 | max_bin,
183 | num_bins - 1,
184 | device=coords.device,
185 | )
186 | boundaries = boundaries**2
187 | N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
188 | # Infer CB coordinates.
189 | b = CA - N
190 | c = C - CA
191 | a = b.cross(c, dim=-1)
192 | CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
193 | dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
194 | bins = torch.sum(dists > boundaries, dim=-1) # [..., L, L]
195 | return bins
196 |
--------------------------------------------------------------------------------
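
`FoldingTrunk.distogram` infers a virtual C-beta from the N/CA/C backbone frame and returns an integer bin index per residue pair; a small sketch of calling it with the same bin range used for recycling in `esmfold.py`:

import torch
from alphaflow.model.trunk import FoldingTrunk

L = 4
coords = torch.randn(L, 3, 3) * 5.0   # [L, 3 backbone atoms (N, CA, C), 3 xyz], arbitrary geometry
bins = FoldingTrunk.distogram(coords, 3.375, 21.375, 15)   # 15 bins, matching trunk.recycle_bins
print(bins.shape)   # torch.Size([4, 4]); each entry is an integer bin index in [0, 14]
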
/alphaflow/utils/diffusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | #https://github.com/scipy/scipy/blob/main/scipy/spatial/transform/_rotation.pyx
5 | def rmsdalign(a, b, weights=None): # aligns b to a # [*, N, 3]
6 | B = a.shape[:-2]
7 | N = a.shape[-2]
8 |     if weights is None:
9 | weights = a.new_ones(*B, N)
10 | weights = weights.unsqueeze(-1)
11 | a_mean = (a * weights).sum(-2, keepdims=True) / weights.sum(-2, keepdims=True)
12 | a = a - a_mean
13 | b_mean = (b * weights).sum(-2, keepdims=True) / weights.sum(-2, keepdims=True)
14 | b = b - b_mean
15 | B = torch.einsum('...ji,...jk->...ik', weights * a, b)
16 | u, s, vh = torch.linalg.svd(B)
17 |
18 | # Correct improper rotation if necessary (as in Kabsch algorithm)
19 | '''
20 | if torch.linalg.det(u @ vh) < 0:
21 | s[-1] = -s[-1]
22 | u[:, -1] = -u[:, -1]
23 | '''
24 | sgn = torch.sign(torch.linalg.det(u @ vh))
25 | s[...,-1] *= sgn
26 | u[...,:,-1] *= sgn.unsqueeze(-1)
27 | C = u @ vh # c rotates B to A
28 | return b @ C.mT + a_mean
29 |
30 | def kabsch_rmsd(a, b, weights=None):
31 | B = a.shape[:-2]
32 | N = a.shape[-2]
33 |     if weights is None:
34 | weights = a.new_ones(*B, N)
35 | b_aligned = rmsdalign(a, b, weights)
36 | out = torch.square(b_aligned - a).sum(-1)
37 | out = (out * weights).sum(-1) / weights.sum(-1)
38 | return torch.sqrt(out)
39 |
40 | class HarmonicPrior:
41 |     def __init__(self, N=256, a=3/(3.8**2)):
42 | J = torch.zeros(N, N)
43 | for i, j in zip(np.arange(N-1), np.arange(1, N)):
44 | J[i,i] += a
45 | J[j,j] += a
46 | J[i,j] = J[j,i] = -a
47 | D, P = torch.linalg.eigh(J)
48 | D_inv = 1/D
49 | D_inv[0] = 0
50 | self.P, self.D_inv = P, D_inv
51 | self.N = N
52 |
53 | def to(self, device):
54 | self.P = self.P.to(device)
55 | self.D_inv = self.D_inv.to(device)
56 |
57 | def sample(self, batch_dims=()):
58 | return self.P @ (torch.sqrt(self.D_inv)[:,None] * torch.randn(*batch_dims, self.N, 3, device=self.P.device))
59 |
60 |
61 | '''
62 | def transition_matrix(N_bins=1000, X_max=5):
63 | bins = torch.linspace(0, X_max, N_bins+1, dtype=torch.float64)
64 | cbins = (bins[1:] + bins[:-1]) / 2
65 | bw = cbins[1] - cbins[0]
66 | mu = 2 / cbins - cbins
67 | idx = torch.arange(N_bins)
68 | mat = torch.zeros((N_bins, N_bins), dtype=torch.float64)
69 | mat[idx, idx] = -2 / bw**2
70 |
71 | mat[idx[1:], idx[:-1]] = mu[idx[:-1]] / 2 / bw + 1 / bw**2 # M_{i+1,i} = -mu[i]/2
72 | mat[idx[:-1], idx[1:]] = -mu[idx[1:]] / 2 / bw + 1 / bw**2 # M_{i+1,i} = mu[i]/2
73 | mat[idx, idx] -= mat.sum(0) # fix edges
74 |
75 | return mat, bins
76 |
77 | _mat, _bins = transition_matrix()
78 | _D, _Q = torch.linalg.eig(_mat)
79 | _Q_inv = torch.linalg.inv(_Q)
80 | _sigmas = torch.from_numpy(np.load('chain_stats.npy'))
81 |
82 |
83 | def add_noise(dists, residue_index, mask, t, device='cpu'):
84 |
85 | sigmas, Q, D, Q_inv, bins = _sigmas.to(device), _Q.to(device), _D.to(device), _Q_inv.to(device), _bins.to(device)
86 |
87 | mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
88 | # dists = torch.sum((pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2, dim=-1)**0.5
89 | offsets = torch.abs(residue_index.unsqueeze(-1) - residue_index.unsqueeze(-2))
90 | sigmas = sigmas[offsets]
91 | ndists = dists / sigmas * mask
92 |
93 | bindists = (ndists.unsqueeze(-1) > bins).sum(-1)
94 | bindists = torch.clamp(bindists, 0, 999)
95 |
96 |     P = ((Q*torch.exp(D*t)) @ Q_inv).T # now we have a row stochastic matrix P_ij = P(i -> j)
97 | probs = P.real[bindists] # this is equivalent to left multiplication by basis e_i
98 |
99 | probs = torch.clamp(probs / probs.sum(-1, keepdims=True), 0, 1)
100 | newbindists = Categorical(probs, validate_args=False).sample()
101 | cbins = (bins[1:] + bins[:-1]) / 2
102 | newdists = cbins[newbindists] * mask * sigmas
103 |
104 | return newdists.float()
105 |
106 | def sample_posterior(orig_dists, noisy_dists, residue_index, mask, s, t, device='cpu'):
107 | sigmas, Q, D, Q_inv, bins = _sigmas.to(device), _Q.to(device), _D.to(device), _Q_inv.to(device), _bins.to(device)
108 | mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
109 |
110 | P_0s = ((Q*torch.exp(D*s)) @ Q_inv).T
111 | P_st = ((Q*torch.exp(D*(t-s))) @ Q_inv).T
112 |
113 | offsets = torch.abs(residue_index.unsqueeze(-1) - residue_index.unsqueeze(-2))
114 | sigmas = sigmas[offsets]
115 |
116 | orig_ndists = orig_dists / sigmas * mask
117 | orig_bindists = (orig_ndists.unsqueeze(-1) > bins).sum(-1)
118 | orig_bindists = torch.clamp(orig_bindists, 0, 999)
119 |
120 | noisy_ndists = noisy_dists / sigmas * mask
121 | noisy_bindists = (noisy_ndists.unsqueeze(-1) > bins).sum(-1)
122 | noisy_bindists = torch.clamp(noisy_bindists, 0, 999)
123 |
124 | probs = P_0s.real[orig_bindists] * P_st.T.real[noisy_bindists]
125 | probs = torch.clamp(probs / probs.sum(-1, keepdims=True), 0, 1)
126 | newbindists = Categorical(probs, validate_args=False).sample()
127 | cbins = (bins[1:] + bins[:-1]) / 2
128 | newdists = cbins[newbindists] * mask * sigmas
129 |
130 | return newdists.float()
131 |
132 | def sample_prior(residue_index, device='cpu'):
133 |
134 | sigmas, Q, D, Q_inv, bins = _sigmas.to(device), _Q.to(device), _D.to(device), _Q_inv.to(device), _bins.to(device)
135 | B, L = residue_index.shape
136 | probs = Q[:,D.real.argmax()].real
137 | probs = torch.clamp(probs / probs.sum(-1, keepdims=True), 0, 1).broadcast_to(B, L, L, 1000)
138 |
139 | offsets = torch.abs(residue_index.unsqueeze(-1) - residue_index.unsqueeze(-2))
140 | sigmas = sigmas[offsets]
141 |
142 | newbindists = Categorical(probs, validate_args=False).sample()
143 |
144 | cbins = (bins[1:] + bins[:-1]) / 2
145 | newdists = cbins[newbindists] * sigmas
146 | return newdists.float()
147 | '''
148 |
--------------------------------------------------------------------------------
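
A short usage sketch of the harmonic prior and the Kabsch utilities above (chain length and batch size are illustrative):

import torch
from alphaflow.utils.diffusion import HarmonicPrior, rmsdalign, kabsch_rmsd

prior = HarmonicPrior(N=64)          # harmonic-chain prior over 64 residues
prior.to('cpu')
x0 = prior.sample(batch_dims=(2,))   # [2, 64, 3] centered samples
x1 = prior.sample(batch_dims=(2,))

x1_aligned = rmsdalign(x0, x1)       # x1 rigidly superposed onto x0 (translation + proper rotation)
rmsd = kabsch_rmsd(x0, x1)           # [2] RMSD after optimal superposition
print(x1_aligned.shape, rmsd.shape)
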
/alphaflow/utils/logging.py:
--------------------------------------------------------------------------------
1 | import logging, socket, os
2 |
3 | model_dir = os.environ.get("MODEL_DIR", "./workdir/default")
4 |
5 | class Rank(logging.Filter):
6 | def filter(self, record):
7 | record.global_rank = os.environ.get("GLOBAL_RANK", 0)
8 | record.local_rank = os.environ.get("LOCAL_RANK", 0)
9 | return True
10 |
11 |
12 | def get_logger(name):
13 | logger = logging.Logger(name)
14 | # logger.addFilter(Rank())
15 | level = {"crititical": 50, "error": 40, "warning": 30, "info": 20, "debug": 10}[
16 | os.environ.get("LOGGER_LEVEL", "info")
17 | ]
18 | logger.setLevel(level)
19 |
20 | ch = logging.StreamHandler()
21 | ch.setLevel(logging.INFO)
22 | os.makedirs(model_dir, exist_ok=True)
23 | fh = logging.FileHandler(os.path.join(model_dir, "log.out"))
24 | fh.setLevel(logging.DEBUG)
25 | # formatter = logging.Formatter(f'%(asctime)s [{socket.gethostname()}:%(process)d:%(global_rank)s:%(local_rank)s]
26 | # [%(levelname)s] %(message)s') # (%(name)s)
27 | formatter = logging.Formatter(
28 | f"%(asctime)s [{socket.gethostname()}:%(process)d] [%(levelname)s] %(message)s"
29 | )
30 | ch.setFormatter(formatter)
31 | fh.setFormatter(formatter)
32 | logger.addHandler(ch)
33 | logger.addHandler(fh)
34 | return logger
35 |
36 | def init():
37 | if bool(int(os.environ.get("WANDB_LOGGING", "0"))):
38 | os.makedirs(model_dir, exist_ok=True)
39 | out_file = open(os.path.join(model_dir, "std.out"), 'ab')
40 | os.dup2(out_file.fileno(), 1)
41 | os.dup2(out_file.fileno(), 2)
42 |
43 | init()
--------------------------------------------------------------------------------
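
Usage sketch: the logger writes to the console and to `$MODEL_DIR/log.out`, and its level comes from the `LOGGER_LEVEL` environment variable (default `info`). `MODEL_DIR` must be set before the module is imported, since it is read at import time:

import os
os.environ.setdefault("MODEL_DIR", "./workdir/default")   # read at import time by alphaflow.utils.logging

from alphaflow.utils.logging import get_logger

logger = get_logger(__name__)
logger.info("starting run")   # goes to the console handler and to ./workdir/default/log.out
# debug messages are emitted only when LOGGER_LEVEL=debug is set in the environment
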
/alphaflow/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | import typing as T
6 | import torch
7 | from openfold.np import residue_constants, protein
8 | from openfold.utils.feats import atom14_to_atom37
9 |
10 | def encode_sequence(
11 | seq: str,
12 | residue_index_offset: T.Optional[int] = 512,
13 | chain_linker: T.Optional[str] = "G" * 25,
14 | ) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
15 | if chain_linker is None:
16 | chain_linker = ""
17 | if residue_index_offset is None:
18 | residue_index_offset = 0
19 |
20 | chains = seq.split(":")
21 | seq = chain_linker.join(chains)
22 |
23 | unk_idx = residue_constants.restype_order_with_x["X"]
24 | encoded = torch.tensor(
25 | [residue_constants.restype_order_with_x.get(aa, unk_idx) for aa in seq]
26 | )
27 | residx = torch.arange(len(encoded))
28 |
29 | if residue_index_offset > 0:
30 | start = 0
31 | for i, chain in enumerate(chains):
32 | residx[start : start + len(chain) + len(chain_linker)] += (
33 | i * residue_index_offset
34 | )
35 | start += len(chain) + len(chain_linker)
36 |
37 | linker_mask = torch.ones_like(encoded, dtype=torch.float32)
38 | chain_index = []
39 | offset = 0
40 | for i, chain in enumerate(chains):
41 | if i > 0:
42 | chain_index.extend([i - 1] * len(chain_linker))
43 | chain_index.extend([i] * len(chain))
44 | offset += len(chain)
45 | linker_mask[offset : offset + len(chain_linker)] = 0
46 | offset += len(chain_linker)
47 |
48 | chain_index = torch.tensor(chain_index, dtype=torch.int64)
49 |
50 | return encoded, residx, linker_mask, chain_index
51 |
52 |
53 | def batch_encode_sequences(
54 | sequences: T.Sequence[str],
55 | residue_index_offset: T.Optional[int] = 512,
56 | chain_linker: T.Optional[str] = "G" * 25,
57 | ) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
58 |
59 | aatype_list = []
60 | residx_list = []
61 | linker_mask_list = []
62 | chain_index_list = []
63 | for seq in sequences:
64 | aatype_seq, residx_seq, linker_mask_seq, chain_index_seq = encode_sequence(
65 | seq,
66 | residue_index_offset=residue_index_offset,
67 | chain_linker=chain_linker,
68 | )
69 | aatype_list.append(aatype_seq)
70 | residx_list.append(residx_seq)
71 | linker_mask_list.append(linker_mask_seq)
72 | chain_index_list.append(chain_index_seq)
73 |
74 | aatype = collate_dense_tensors(aatype_list)
75 | mask = collate_dense_tensors(
76 | [aatype.new_ones(len(aatype_seq)) for aatype_seq in aatype_list]
77 | )
78 | residx = collate_dense_tensors(residx_list)
79 | linker_mask = collate_dense_tensors(linker_mask_list)
80 | chain_index_list = collate_dense_tensors(chain_index_list, -1)
81 |
82 | return aatype, mask, residx, linker_mask, chain_index_list
83 |
84 |
85 | def output_to_pdb(output: T.Dict) -> T.List[str]:
86 | """Returns the pbd (file) string from the model given the model output."""
87 | # atom14_to_atom37 must be called first, as it fails on latest numpy if the
88 | # input is a numpy array. It will work if the input is a torch tensor.
89 | final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
90 | output = {k: v.to("cpu").numpy() for k, v in output.items()}
91 | final_atom_positions = final_atom_positions.cpu().numpy()
92 | final_atom_mask = output["atom37_atom_exists"]
93 | pdbs = []
94 | for i in range(output["aatype"].shape[0]):
95 | aa = output["aatype"][i]
96 | pred_pos = final_atom_positions[i]
97 | mask = final_atom_mask[i]
98 | resid = output["residue_index"][i] + 1
99 | pred = protein.Protein(
100 | aatype=aa,
101 | atom_positions=pred_pos,
102 | atom_mask=mask,
103 | residue_index=resid,
104 | b_factors=output["plddt"][i],
105 | chain_index=output["chain_index"][i] if "chain_index" in output else None,
106 | )
107 | pdbs.append(pred)
108 | return pdbs
109 |
110 |
111 | def collate_dense_tensors(
112 | samples: T.List[torch.Tensor], pad_v: float = 0
113 | ) -> torch.Tensor:
114 | """
115 | Takes a list of tensors with the following dimensions:
116 | [(d_11, ..., d_1K),
117 | (d_21, ..., d_2K),
118 | ...,
119 | (d_N1, ..., d_NK)]
120 | and stack + pads them into a single tensor of:
121 | (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})
122 | """
123 | if len(samples) == 0:
124 | return torch.Tensor()
125 | if len(set(x.dim() for x in samples)) != 1:
126 | raise RuntimeError(
127 | f"Samples has varying dimensions: {[x.dim() for x in samples]}"
128 | )
129 | (device,) = tuple(set(x.device for x in samples)) # assumes all on same device
130 | max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
131 | result = torch.empty(
132 | len(samples), *max_shape, dtype=samples[0].dtype, device=device
133 | )
134 | result.fill_(pad_v)
135 | for i in range(len(samples)):
136 | result_i = result[i]
137 | t = samples[i]
138 | result_i[tuple(slice(0, k) for k in t.shape)] = t
139 | return result
140 |
141 |
142 |
143 | class CategoricalMixture:
144 | def __init__(self, param, bins=50, start=0, end=1):
145 | # All tensors are of shape ..., bins.
146 | self.logits = param
147 | bins = torch.linspace(
148 | start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype
149 | )
150 | self.v_bins = (bins[:-1] + bins[1:]) / 2
151 |
152 | def log_prob(self, true):
153 | # Shapes are:
154 | # self.probs: ... x bins
155 | # true : ...
156 | true_index = (
157 | (
158 | true.unsqueeze(-1)
159 | - self.v_bins[
160 | [
161 | None,
162 | ]
163 | * true.ndim
164 | ]
165 | )
166 | .abs()
167 | .argmin(-1)
168 | )
169 | nll = self.logits.log_softmax(-1)
170 | return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)
171 |
172 | def mean(self):
173 | return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)
174 |
175 |
176 | def categorical_lddt(logits, bins=50):
177 | # Logits are ..., 37, bins.
178 | return CategoricalMixture(logits, bins=bins).mean()
179 |
180 |
181 |
182 |
183 |
--------------------------------------------------------------------------------
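
A worked sketch of the multimer encoding above: chains passed as "SEQ1:SEQ2" are joined by the 25-residue poly-G linker, the second chain's residue indices are shifted by `residue_index_offset`, and linker positions are zeroed in `linker_mask`:

import torch
from alphaflow.utils.misc import encode_sequence, batch_encode_sequences

aatype, residx, linker_mask, chain_index = encode_sequence("ACD:EFG")
print(aatype.shape)            # torch.Size([31]): 3 residues + 25 linker G's + 3 residues
print(residx[-1].item())       # 542 = 30 + the default residue_index_offset of 512
print(int(linker_mask.sum()))  # 6: only the real residues of the two chains are unmasked
print(chain_index.tolist())    # 0 for chain A and the linker, 1 for chain B

# The batched variant pads to the longest sequence and also returns a padding mask.
aatype, mask, residx, linker_mask, chain_index = batch_encode_sequences(["ACD:EFG", "ACDEFGH"])
print(aatype.shape, mask.shape)  # torch.Size([2, 31]) torch.Size([2, 31])
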
/alphaflow/utils/parsing.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import subprocess, os
3 |
4 |
5 | def parse_train_args():
6 | parser = ArgumentParser()
7 |
8 | parser.add_argument("--mode", choices=['esmfold', 'alphafold'], default='alphafold')
9 |
10 | ## Trainer settings
11 | parser.add_argument("--ckpt", type=str, default=None)
12 | parser.add_argument("--restore_weights_only", action='store_true')
13 | parser.add_argument("--validate", action='store_true', default=False)
14 |
15 | ## Epoch settings
16 | parser.add_argument("--epochs", type=int, default=100)
17 | parser.add_argument("--train_epoch_len", type=int, default=40000)
18 | parser.add_argument("--limit_batches", type=int, default=None)
19 | parser.add_argument("--batch_size", type=int, default=1)
20 |
21 | ## Optimization settings
22 | parser.add_argument("--num_workers", type=int, default=8)
23 | parser.add_argument("--check_grad", action="store_true")
24 | parser.add_argument("--accumulate_grad", type=int, default=1)
25 | parser.add_argument("--grad_clip", type=float, default=1.)
26 | parser.add_argument("--lr", type=float, default=1e-3)
27 | parser.add_argument("--no_ema", action='store_true')
28 |
29 | ## Training data
30 | parser.add_argument("--train_data_dir", type=str, default='./data')
31 | parser.add_argument("--pdb_chains", type=str, default='./pdb_chains_msa.csv')
32 | parser.add_argument("--train_msa_dir", type=str, default='./msa_dir')
33 | parser.add_argument("--pdb_clusters", type=str, default='./pdb_clusters')
34 | parser.add_argument("--train_cutoff", type=str, default='2021-10-01')
35 | parser.add_argument("--mmcif_dir", type=str, default='./mmcif_dir')
36 | parser.add_argument("--filter_chains", action='store_true')
37 | parser.add_argument("--sample_train_confs", action='store_true')
38 |
39 | ## Validation data
40 | parser.add_argument("--val_csv", type=str, default='splits/cameo2022.csv')
41 | parser.add_argument("--val_samples", type=int, default=5)
42 | parser.add_argument("--val_msa_dir", type=str, default='./alignment_dir')
43 | parser.add_argument("--sample_val_confs", action='store_true')
44 | parser.add_argument("--num_val_confs", type=int, default=None)
45 | parser.add_argument("--normal_validate", action='store_true')
46 |
47 | ## Flow matching
48 | parser.add_argument("--noise_prob", type=float, default=0.5)
49 | parser.add_argument("--self_cond_prob", type=float, default=0.5)
50 | parser.add_argument("--extra_input", action='store_true')
51 | parser.add_argument("--extra_input_prob", type=float, default=0.5)
52 | parser.add_argument("--first_as_template", action='store_true')
53 | parser.add_argument("--distillation", action='store_true')
54 | parser.add_argument("--distill_self_cond", action='store_true')
55 |
56 | ## Logging args
57 | parser.add_argument("--print_freq", type=int, default=100)
58 | parser.add_argument("--val_freq", type=int, default=1)
59 | parser.add_argument("--ckpt_freq", type=int, default=1)
60 | parser.add_argument("--wandb", action="store_true")
61 | parser.add_argument("--run_name", type=str, default="default")
62 |
63 | args = parser.parse_args()
64 | os.environ["MODEL_DIR"] = os.path.join("workdir", args.run_name)
65 | os.environ["WANDB_LOGGING"] = str(int(args.wandb))
66 | if args.wandb:
67 | if subprocess.check_output(["git", "status", "-s"]):
68 |             exit("Uncommitted git changes detected; commit or stash before launching a --wandb run.")
69 | args.commit = (
70 | subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
71 | )
72 |
73 | return args
74 |
--------------------------------------------------------------------------------
/alphaflow/utils/protein.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence, Optional
2 | from openfold.np.protein import to_pdb
3 | from openfold.np.protein import from_pdb_string as _from_pdb_string
4 | from openfold.data import mmcif_parsing
5 | from openfold.np import residue_constants
6 | from alphaflow.utils.tensor_utils import tensor_tree_map
7 | import subprocess, tempfile, os, dataclasses
8 | import numpy as np
9 | from Bio import pairwise2
10 |
11 | @dataclasses.dataclass(repr=False)
12 | class Protein:
13 | """Protein structure representation."""
14 |
15 | # Cartesian coordinates of atoms in angstroms. The atom types correspond to
16 |     # residue_constants.atom_types, i.e. the first three are N, CA, C.
17 | atom_positions: np.ndarray # [num_res, num_atom_type, 3]
18 |
19 | aatype: np.ndarray # [num_res]
20 | seqres: str
21 | name: str
22 |
23 | # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
24 | # is present and 0.0 if not. This should be used for loss masking.
25 | atom_mask: np.ndarray # [num_res, num_atom_type]
26 |
27 | # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
28 | residue_index: np.ndarray # [num_res]
29 |
30 | # B-factors, or temperature factors, of each residue (in sq. angstroms units),
31 | # representing the displacement of the residue from its ground truth mean
32 | # value.
33 | b_factors: np.ndarray # [num_res, num_atom_type]
34 |
35 | # Chain indices for multi-chain predictions
36 | chain_index: Optional[np.ndarray] = None
37 |
38 | # Optional remark about the protein. Included as a comment in output PDB
39 | # files
40 | remark: Optional[str] = None
41 |
42 | # Templates used to generate this protein (prediction-only)
43 | parents: Optional[Sequence[str]] = None
44 |
45 | # Chain corresponding to each parent
46 | parents_chain_index: Optional[Sequence[int]] = None
47 |
48 | def __repr__(self):
49 | ca_pos = residue_constants.atom_order["CA"]
50 | present = int(self.atom_mask[..., ca_pos].sum())
51 | total = self.atom_mask.shape[0]
52 | return f"Protein(name={self.name} seqres={self.seqres} residues={present}/{total} b_mean={self.b_factors[...,ca_pos].mean()})"
53 |
54 | def present(self):
55 | ca_pos = residue_constants.atom_order["CA"]
56 | return int(self.atom_mask[..., ca_pos].sum())
57 |
58 | def total(self):
59 | return self.atom_mask.shape[0]
60 |
61 | def output_to_protein(output):
62 |     """Builds a list of Protein objects from the batched model output."""
63 | output = tensor_tree_map(lambda x: x.cpu().numpy(), output)
64 | final_atom_positions = output['final_atom_positions']
65 | final_atom_mask = output["atom37_atom_exists"]
66 | pdbs = []
67 | for i in range(output["aatype"].shape[0]):
68 | unk_idx = residue_constants.restype_order_with_x["X"]
69 | seqres = ''.join(
70 | [residue_constants.restypes[idx] if idx != unk_idx else "X" for idx in output["aatype"][i]]
71 | )
72 | pred = Protein(
73 | name=output['name'][i],
74 | aatype=output["aatype"][i],
75 | seqres=seqres,
76 | atom_positions=final_atom_positions[i],
77 | atom_mask=final_atom_mask[i],
78 | residue_index=output["residue_index"][i] + 1,
79 | b_factors=np.repeat(output["plddt"][i][...,None], residue_constants.atom_type_num, axis=-1),
80 | chain_index=output["chain_index"][i] if "chain_index" in output else None,
81 | )
82 | pdbs.append(pred)
83 | return pdbs
84 |
85 | def from_dict(prot):
86 | name = prot['domain_name'].item().decode(encoding='utf-8')
87 | seq = prot['sequence'].item().decode(encoding='utf-8')
88 | return Protein(
89 | name=name,
90 | aatype=np.nonzero(prot["aatype"])[1],
91 | atom_positions=prot["all_atom_positions"],
92 | seqres=seq,
93 | atom_mask=prot["all_atom_mask"],
94 | residue_index=prot['residue_index'] + 1,
95 | b_factors=np.zeros((len(seq), 37))
96 | )
97 |
98 | def from_pdb_string(pdb_string, name=''):
99 | prot = _from_pdb_string(pdb_string)
100 |
101 | unk_idx = residue_constants.restype_order_with_x["X"]
102 | seqres = ''.join(
103 | [residue_constants.restypes[idx] if idx != unk_idx else "X" for idx in prot.aatype]
104 | )
105 | prot = Protein(
106 | **prot.__dict__,
107 | seqres=seqres,
108 | name=name,
109 | )
110 | return prot
111 |
112 | def from_mmcif_string(mmcif_string, chain, name='', is_author_chain=False):
113 | mmcif_object = mmcif_parsing.parse(file_id = '', mmcif_string=mmcif_string)
114 |
115 | if(mmcif_object.mmcif_object is None):
116 | raise list(mmcif_object.errors.values())[0]
117 |
118 | mmcif_object = mmcif_object.mmcif_object
119 |
120 | atom_coords, atom_mask = mmcif_parsing.get_atom_coords(mmcif_object, chain)
121 | L = atom_coords.shape[0]
122 | seq = mmcif_object.chain_to_seqres[chain]
123 |
124 | unk_idx = residue_constants.restype_order_with_x["X"]
125 | aatype = np.array(
126 | [residue_constants.restype_order_with_x.get(aa, unk_idx) for aa in seq]
127 | )
128 | prot = Protein(
129 | aatype=aatype,
130 | name=name,
131 | seqres=seq,
132 | atom_positions=atom_coords,
133 | atom_mask=atom_mask,
134 | residue_index=np.arange(L) + 1,
135 | b_factors=np.zeros((L, 37)), # maybe replace 37 later
136 | )
137 | return prot
138 |
139 |
140 | def global_metrics(ref_prot, pred_prot, lddt=False, symmetric=False):
141 | if lddt or symmetric:
142 | ref_prot, pred_prot = align_residue_numbering(ref_prot, pred_prot, mask=symmetric)
143 |
144 | f, ref_path = tempfile.mkstemp(); os.close(f)
145 | f, pred_path = tempfile.mkstemp(); os.close(f)
146 | with open(ref_path, 'w') as f:
147 | f.write(to_pdb(ref_prot))
148 | with open(pred_path, 'w') as f:
149 | f.write(to_pdb(pred_prot))
150 |
151 | out = tmscore(ref_path, pred_path)
152 | if lddt:
153 | out['lddt'] = my_lddt_func(ref_path, pred_path)
154 |
155 | os.unlink(ref_path)
156 | os.unlink(pred_path)
157 | return out
158 |
159 | def prots_to_pdb(prots):
160 | ss = ''
161 | for i, prot in enumerate(prots):
162 | ss += f'MODEL {i}\n'
163 | prot = to_pdb(prot)
164 | ss += '\n'.join(prot.split('\n')[1:-2])
165 | ss += '\nENDMDL\n'
166 | return ss
167 |
168 | def align_residue_numbering(prot1, prot2, mask=False):
169 | prot1 = Protein(**prot1.__dict__)
170 | prot2 = Protein(**prot2.__dict__)
171 |
172 | alignment = pairwise2.align.globalxx(prot1.seqres, prot2.seqres)[0]
173 | prot1.residue_index = np.array([i for i, c in enumerate(alignment.seqA) if c != '-'])
174 | prot2.residue_index = np.array([i for i, c in enumerate(alignment.seqB) if c != '-'])
175 |
176 | if mask:
177 | ca_pos = residue_constants.atom_order["CA"]
178 | mask1 = np.zeros(len(alignment.seqA))
179 | mask1[prot1.residue_index[prot1.atom_mask[..., ca_pos] == 1]] = 1
180 | mask2 = np.zeros(len(alignment.seqA))
181 | mask2[prot2.residue_index[prot2.atom_mask[..., ca_pos] == 1]] = 1
182 |
183 | mask = (mask1 == 1) & (mask2 == 1)
184 |
185 | prot1.atom_mask = prot1.atom_mask * mask[prot1.residue_index].reshape(-1, 1)
186 | prot2.atom_mask = prot2.atom_mask * mask[prot2.residue_index].reshape(-1, 1)
187 |
188 | return prot1, prot2
189 | # ca_pos = residue_constants.atom_order["CA"]
190 | # ref_ca = ref_prot.atom_positions[..., ca_pos, :]
191 | # pred_ca = pred_prot.atom_positions[...,ca_pos, :]
192 | # mask = ref_prot.atom_mask[..., ca_pos].astype(bool)
193 | # trans_ca, rms = superimposition._superimpose_np(ref_ca[mask], pred_ca[mask])
194 |
195 | def tmscore(ref_path, pred_path):
196 |
197 |
198 | out = subprocess.check_output(['TMscore', '-seq', pred_path, ref_path],
199 | stderr=open('/dev/null', 'w'))
200 |
201 | start = out.find(b'RMSD')
202 | end = out.find(b'rotation')
203 | out = out[start:end]
204 |
205 | rmsd, _, tm, _, gdt_ts, gdt_ha, _, _ = out.split(b'\n')
206 |
207 | result = {
208 | 'rmsd': float(rmsd.split(b'=')[-1]),
209 | 'tm': float(tm.split(b'=')[1].split()[0]),
210 | 'gdt_ts': float(gdt_ts.split(b'=')[1].split()[0]),
211 | 'gdt_ha': float(gdt_ha.split(b'=')[1].split()[0]),
212 | }
213 | return result
214 |
215 | def drmsd(prot1, prot2, align=False, eps=1e-10):
216 | ca_pos = residue_constants.atom_order["CA"]
217 | if align:
218 | prot1, prot2 = align_residue_numbering(prot1, prot2)
219 | N = max(prot1.residue_index.max(), prot2.residue_index.max())
220 | mask1, mask2 = np.zeros(N), np.zeros(N)
221 | mask1[prot1.residue_index - 1] = prot1.atom_mask[:,ca_pos]
222 | mask2[prot2.residue_index - 1] = prot2.atom_mask[:,ca_pos]
223 | pos1, pos2 = np.zeros((N,3)), np.zeros((N,3))
224 |
225 | pos1[prot1.residue_index - 1] = prot1.atom_positions[:,ca_pos]
226 | pos2[prot2.residue_index - 1] = prot2.atom_positions[:,ca_pos]
227 |
228 | dmat1 = np.sqrt(eps + np.sum((pos1[..., None, :] - pos1[..., None, :, :]) ** 2, axis=-1))
229 | dmat2 = np.sqrt(eps + np.sum((pos2[..., None, :] - pos2[..., None, :, :]) ** 2, axis=-1))
230 |
231 | dists_to_score = mask1 * mask1[:,None] * mask2 * mask2[:,None] * (1.0 - np.eye(N))
232 | score = np.square(dmat1 - dmat2)
233 |
234 | return np.sqrt((score * dists_to_score).sum() / dists_to_score.sum())
235 |
236 | def lddt_ca(prot1, prot2, cutoff=15.0, eps=1e-10, align=False, per_residue=False, symmetric=False):
237 |
238 | ca_pos = residue_constants.atom_order["CA"]
239 |
240 | if align:
241 | prot1, prot2 = align_residue_numbering(prot1, prot2)
242 | N = max(prot1.residue_index.max(), prot2.residue_index.max())
243 | mask1, mask2 = np.zeros(N), np.zeros(N)
244 | mask1[prot1.residue_index - 1] = prot1.atom_mask[:,ca_pos]
245 | mask2[prot2.residue_index - 1] = prot2.atom_mask[:,ca_pos]
246 | pos1, pos2 = np.zeros((N,3)), np.zeros((N,3))
247 |
248 | pos1[prot1.residue_index - 1] = prot1.atom_positions[:,ca_pos]
249 | pos2[prot2.residue_index - 1] = prot2.atom_positions[:,ca_pos]
250 |
251 | # mask1, mask2 = prot1.atom_mask[:,ca_pos], prot2.atom_mask[:,ca_pos]
252 | # pos1, pos2 = prot1.atom_positions[:,ca_pos], prot2.atom_positions[:,ca_pos]
253 |
254 | dmat1 = np.sqrt(eps + np.sum((pos1[..., None, :] - pos1[..., None, :, :]) ** 2, axis=-1))
255 | dmat2 = np.sqrt(eps + np.sum((pos2[..., None, :] - pos2[..., None, :, :]) ** 2, axis=-1))
256 |
257 | if symmetric:
258 | dists_to_score = (dmat1 < cutoff) | (dmat2 < cutoff)
259 | else:
260 | dists_to_score = (dmat1 < cutoff)
261 | dists_to_score = dists_to_score * mask1 * mask1[:,None] * mask2 * mask2[:,None] * (1.0 - np.eye(N))
262 | dist_l1 = np.abs(dmat1 - dmat2)
263 | score = (dist_l1[...,None] < np.array([0.5, 1.0, 2.0, 4.0])).mean(-1)
264 |
265 | if per_residue:
266 | score = (dists_to_score * score).sum(-1) / dists_to_score.sum(-1)
267 | return score[prot1.residue_index - 1]
268 | else:
269 | score = (dists_to_score * score).sum() / dists_to_score.sum()
270 |
271 | return score
272 | def my_lddt_func(ref_path, pred_path):
273 |
274 | out = subprocess.check_output(['lddt', '-c', '-x', pred_path, ref_path],
275 | stderr=open('/dev/null', 'w'))
276 |
277 | result = None
278 | for line in out.split(b'\n'):
279 | if b'Global LDDT score' in line:
280 | result = float(line.split(b':')[-1].strip())
281 |
282 | return result
283 |
--------------------------------------------------------------------------------
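A minimal usage sketch for the helpers in `alphaflow/utils/protein.py` above. It assumes the external `TMscore` binary (and `lddt`, when `lddt=True`) is on `PATH`; the PDB paths are placeholders.

```python
from alphaflow.utils import protein

# Placeholder paths: any two PDB files of the same sequence will do.
with open("reference.pdb") as f:
    ref = protein.from_pdb_string(f.read(), name="ref")
with open("prediction.pdb") as f:
    pred = protein.from_pdb_string(f.read(), name="pred")

# Shells out to the TMscore binary (and to lddt, because lddt=True).
metrics = protein.global_metrics(ref, pred, lddt=True)
print(metrics)  # {'rmsd': ..., 'tm': ..., 'gdt_ts': ..., 'gdt_ha': ..., 'lddt': ...}

# Pure-numpy C-alpha lDDT, no external binaries; assumes the two structures
# already share residue numbering (pass align=True otherwise).
print(protein.lddt_ca(ref, pred))

# Write several conformers into one multi-MODEL PDB file.
with open("ensemble.pdb", "w") as f:
    f.write(protein.prots_to_pdb([ref, pred]))
```
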
/alphaflow/utils/tensor_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | # Copyright 2021 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from functools import partial
17 | from typing import List
18 |
19 | import torch
20 | import torch.nn as nn
21 |
22 |
23 | def add(m1, m2, inplace):
24 | # The first operation in a checkpoint can't be in-place, but it's
25 | # nice to have in-place addition during inference. Thus...
26 | if(not inplace):
27 | m1 = m1 + m2
28 | else:
29 | m1 += m2
30 |
31 | return m1
32 |
33 |
34 | def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
35 | zero_index = -1 * len(inds)
36 | first_inds = list(range(len(tensor.shape[:zero_index])))
37 | return tensor.permute(first_inds + [zero_index + i for i in inds])
38 |
39 |
40 | def flatten_final_dims(t: torch.Tensor, no_dims: int):
41 | return t.reshape(t.shape[:-no_dims] + (-1,))
42 |
43 |
44 | def masked_mean(mask, value, dim, eps=1e-4):
45 | mask = mask.expand(*value.shape)
46 | return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
47 |
48 |
49 | def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
50 | boundaries = torch.linspace(
51 | min_bin, max_bin, no_bins - 1, device=pts.device
52 | )
53 | dists = torch.sqrt(
54 | torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
55 | )
56 | return torch.bucketize(dists, boundaries)
57 |
58 |
59 | def dict_multimap(fn, dicts):
60 | first = dicts[0]
61 | new_dict = {}
62 | for k, v in first.items():
63 | all_v = [d[k] for d in dicts]
64 | if type(v) is dict:
65 | new_dict[k] = dict_multimap(fn, all_v)
66 | else:
67 | new_dict[k] = fn(all_v)
68 |
69 | return new_dict
70 |
71 |
72 | def one_hot(x, v_bins):
73 | reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
74 | diffs = x[..., None] - reshaped_bins
75 | am = torch.argmin(torch.abs(diffs), dim=-1)
76 | return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
77 |
78 |
79 | def batched_gather(data, inds, dim=0, no_batch_dims=0):
80 | ranges = []
81 | for i, s in enumerate(data.shape[:no_batch_dims]):
82 | r = torch.arange(s)
83 | r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
84 | ranges.append(r)
85 |
86 | remaining_dims = [
87 | slice(None) for _ in range(len(data.shape) - no_batch_dims)
88 | ]
89 | remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
90 | ranges.extend(remaining_dims)
91 | return data[ranges]
92 |
93 |
94 | # With tree_map, a poor man's JAX tree_map
95 | def dict_map(fn, dic, leaf_type):
96 | new_dict = {}
97 | for k, v in dic.items():
98 | if type(v) is dict:
99 | new_dict[k] = dict_map(fn, v, leaf_type)
100 | else:
101 | new_dict[k] = tree_map(fn, v, leaf_type)
102 |
103 | return new_dict
104 |
105 |
106 | def tree_map(fn, tree, leaf_type):
107 | if isinstance(tree, dict):
108 | return dict_map(fn, tree, leaf_type)
109 | elif isinstance(tree, list):
110 | return [tree_map(fn, x, leaf_type) for x in tree]
111 | elif isinstance(tree, tuple):
112 | return tuple([tree_map(fn, x, leaf_type) for x in tree])
113 | elif isinstance(tree, leaf_type):
114 | return fn(tree)
115 | else:
116 | return tree
117 |
118 | tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
119 |
--------------------------------------------------------------------------------
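A short sketch of `tensor_tree_map` and `dict_multimap` from `alphaflow/utils/tensor_utils.py` above, mirroring how `predict.py` maps a collated batch across devices; the dictionary contents below are made up for illustration.

```python
import torch
from alphaflow.utils.tensor_utils import tensor_tree_map, dict_multimap

batch = {
    "aatype": torch.zeros(1, 16, dtype=torch.long),
    "coords": {"ca": torch.randn(1, 16, 3)},  # nested dicts are traversed
    "name": ["6uof_A"],                       # non-tensor leaves pass through untouched
}

# Apply a function to every torch.Tensor leaf while preserving the structure,
# e.g. detach and move everything to CPU before saving.
cpu_batch = tensor_tree_map(lambda x: x.detach().cpu(), batch)

# Merge a list of dicts leaf-wise, e.g. stack per-sample outputs along a new axis.
stacked = dict_multimap(torch.stack, [{"x": torch.ones(2)}, {"x": torch.zeros(2)}])
print(stacked["x"].shape)  # torch.Size([2, 2])
```
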
/assets/12l_md_templates.md:
--------------------------------------------------------------------------------
1 | ### Comparison of 48-layer and 12-layer AlphaFlow-MD+Templates
2 | | |48l (base)|48l (distilled)|12l (base)|12l (distilled)|
3 | |:-|:-:|:-:|:-:|:-:|
4 | | Pairwise RMSD | 2.18 | 1.73 | 1.94 | 1.40
5 | | Pairwise RMSD $r$ | 0.94 | 0.92 | 0.81 | 0.76
6 | | All-atom RMSF | 1.31 | 1.00 | 1.01 | 0.76
7 | | Global RMSF $r$ | 0.91 | 0.89 | 0.78 | 0.74
8 | | Per-target RMSF $r$ | 0.90 | 0.88 | 0.89 | 0.86
9 | | Root mean $\mathcal{W}_2$-dist | 1.95 | 2.18 | 2.26 | 2.43
10 | | MD PCA $\mathcal{W}_2$-dist | 1.25 | 1.41 | 1.40 | 1.56
11 | | Joint PCA $\mathcal{W}_2$-dist | 1.58 | 1.68 | 1.78 | 1.90
12 | | % PC-sim > 0.5 | 44 | 43 | 46 | 39
13 | | Weak contacts $J$ | 0.62 | 0.51 | 0.60 | 0.56
14 | | Transient contacts $J$ | 0.47 | 0.42 | 0.36 | 0.24
15 | | Exposed residue $J$ | 0.50 | 0.47 | 0.47 | 0.44
16 | | Exposed MI matrix $\rho$ | 0.25 | 0.18 | 0.21 | 0.13
17 | | **Runtime (s)** | 38 | 3.8 | 15.2 | 1.56
--------------------------------------------------------------------------------
/assets/6uof_A_animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bjing2016/alphaflow/02dc03763a016949326c2c741e6e33094f9250fd/assets/6uof_A_animation.gif
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser()
3 | parser.add_argument('--input_csv', type=str, default='splits/transporters_only.csv')
4 | parser.add_argument('--templates_dir', type=str, default=None)
5 | parser.add_argument('--msa_dir', type=str, default='./alignment_dir')
6 | parser.add_argument('--mode', choices=['alphafold', 'esmfold'], default='alphafold')
7 | parser.add_argument('--samples', type=int, default=10)
8 | parser.add_argument('--steps', type=int, default=10)
9 | parser.add_argument('--outpdb', type=str, default='./outpdb/default')
10 | parser.add_argument('--weights', type=str, default=None)
11 | parser.add_argument('--ckpt', type=str, default=None)
12 | parser.add_argument('--original_weights', action='store_true')
13 | parser.add_argument('--pdb_id', nargs='*', default=[])
14 | parser.add_argument('--subsample', type=int, default=None)
15 | parser.add_argument('--resample', action='store_true')
16 | parser.add_argument('--tmax', type=float, default=1.0)
17 | parser.add_argument('--no_diffusion', action='store_true', default=False)
18 | parser.add_argument('--self_cond', action='store_true', default=False)
19 | parser.add_argument('--noisy_first', action='store_true', default=False)
20 | parser.add_argument('--runtime_json', type=str, default=None)
21 | parser.add_argument('--no_overwrite', action='store_true', default=False)
22 | args = parser.parse_args()
23 |
24 | import torch, tqdm, os, wandb, json, time
25 | import pandas as pd
26 | import pytorch_lightning as pl
27 | import numpy as np
28 | from collections import defaultdict
29 | from alphaflow.data.data_modules import collate_fn
30 | from alphaflow.model.wrapper import AlphaFoldWrapper, ESMFoldWrapper
31 | from alphaflow.utils.tensor_utils import tensor_tree_map
32 | import alphaflow.utils.protein as protein
33 | from alphaflow.data.inference import AlphaFoldCSVDataset, CSVDataset
34 | from collections import defaultdict
35 | from openfold.utils.import_weights import import_jax_weights_
36 | from alphaflow.config import model_config
37 |
38 | from alphaflow.utils.logging import get_logger
39 | logger = get_logger(__name__)
40 | torch.set_float32_matmul_precision("high")
41 |
42 | config = model_config(
43 | 'initial_training',
44 | train=True,
45 | low_prec=True
46 | )
47 | schedule = np.linspace(args.tmax, 0, args.steps+1)
48 | if args.tmax != 1.0:
49 | schedule = np.array([1.0] + list(schedule))
50 | loss_cfg = config.loss
51 | data_cfg = config.data
52 | data_cfg.common.use_templates = False
53 | data_cfg.common.max_recycling_iters = 0
54 |
55 | if args.subsample: # https://elifesciences.org/articles/75751#s3
56 | data_cfg.predict.max_msa_clusters = args.subsample // 2
57 | data_cfg.predict.max_extra_msa = args.subsample
58 |
59 | @torch.no_grad()
60 | def main():
61 |
62 | valset = {
63 | 'alphafold': AlphaFoldCSVDataset,
64 | 'esmfold': CSVDataset,
65 | }[args.mode](
66 | data_cfg,
67 | args.input_csv,
68 | msa_dir=args.msa_dir,
69 | templates_dir=args.templates_dir,
70 | )
71 | # valset[0]
72 | logger.info("Loading the model")
73 | model_class = {'alphafold': AlphaFoldWrapper, 'esmfold': ESMFoldWrapper}[args.mode]
74 |
75 | if args.weights:
76 | ckpt = torch.load(args.weights, map_location='cpu')
77 | model = model_class(**ckpt['hyper_parameters'], training=False)
78 | model.model.load_state_dict(ckpt['params'], strict=False)
79 | model = model.cuda()
80 |
81 |
82 | elif args.original_weights:
83 | model = model_class(config, None, training=False)
84 | if args.mode == 'esmfold':
85 | path = "esmfold_3B_v1.pt"
86 | model_data = torch.load(path, map_location='cpu')
87 | model_state = model_data["model"]
88 | model.model.load_state_dict(model_state, strict=False)
89 | model = model.to(torch.float).cuda()
90 |
91 | elif args.mode == 'alphafold':
92 | import_jax_weights_(model.model, 'params_model_1.npz', version='model_3')
93 | model = model.cuda()
94 |
95 | else:
96 | model = model_class.load_from_checkpoint(args.ckpt, map_location='cpu')
97 | model.load_ema_weights()
98 | model = model.cuda()
99 | model.eval()
100 |
101 | logger.info("Model has been loaded")
102 |
103 | results = defaultdict(list)
104 | os.makedirs(args.outpdb, exist_ok=True)
105 | runtime = defaultdict(list)
106 | for i, item in enumerate(valset):
107 | if args.pdb_id and item['name'] not in args.pdb_id:
108 | continue
109 | if args.no_overwrite and os.path.exists(f'{args.outpdb}/{item["name"]}.pdb'):
110 | continue
111 | result = []
112 | for j in tqdm.trange(args.samples):
113 | if args.subsample or args.resample:
114 | item = valset[i] # resample MSA
115 |
116 | batch = collate_fn([item])
117 | batch = tensor_tree_map(lambda x: x.cuda(), batch)
118 | start = time.time()
119 | prots = model.inference(batch, as_protein=True, noisy_first=args.noisy_first,
120 | no_diffusion=args.no_diffusion, schedule=schedule, self_cond=args.self_cond)
121 | runtime[item['name']].append(time.time() - start)
122 | result.append(prots[-1])
123 |
124 |
125 |
126 | with open(f'{args.outpdb}/{item["name"]}.pdb', 'w') as f:
127 | f.write(protein.prots_to_pdb(result))
128 |
129 | if args.runtime_json:
130 | with open(args.runtime_json, 'w') as f:
131 | f.write(json.dumps(dict(runtime)))
132 | if __name__ == "__main__":
133 | main()
134 |
--------------------------------------------------------------------------------
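A worked illustration (not part of the repository) of the inference-time schedule built near the top of `predict.py`: `--steps` evenly spaced times from `--tmax` down to 0, with a leading 1.0 prepended whenever sampling starts below t = 1.

```python
import numpy as np

def make_schedule(tmax: float, steps: int) -> np.ndarray:
    # Same construction as in predict.py, wrapped in a function for illustration.
    schedule = np.linspace(tmax, 0, steps + 1)
    if tmax != 1.0:
        schedule = np.array([1.0] + list(schedule))
    return schedule

print(make_schedule(1.0, 5))  # [1.  0.8 0.6 0.4 0.2 0. ]
print(make_schedule(0.5, 5))  # [1.  0.5 0.4 0.3 0.2 0.1 0. ]
```
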
/scripts/add_msa_info.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import argparse, os, tqdm
3 | from collections import defaultdict
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument('--openfold_dir', type=str, required=True)
7 | parser.add_argument('--incsv', type=str, default='pdb_chains.csv')
8 | parser.add_argument('--outcsv', type=str, default='pdb_chains_msa.csv')
9 | args = parser.parse_args()
10 |
11 | msas = os.listdir(f'{args.openfold_dir}/pdb')
12 |
13 | msa_queries = {}
14 | for pdb_id in tqdm.tqdm(msas):
15 | a3m_paths = [
16 | f'{args.openfold_dir}/pdb/{pdb_id}/a3m/bfd_uniclust_hits.a3m',
17 | f'{args.openfold_dir}/pdb/{pdb_id}/a3m/mgnify_hits.a3m',
18 | f'{args.openfold_dir}/pdb/{pdb_id}/a3m/uniref90_hits.a3m',
19 | ]
20 | for a3m_path in a3m_paths:
21 | if os.path.exists(a3m_path):
22 | break
23 | with open(a3m_path) as f:
24 | _ = next(f)
25 | msa_queries[pdb_id] = next(f).strip()
26 |
27 | msa_queries = pd.Series(msa_queries)
28 |
29 | df = pd.read_csv(args.incsv)
30 | msa_id = [None]*len(df)
31 |
32 | freqs = defaultdict(int)
33 | done, skipped = 0, 0
34 | for seqres, sub_df in tqdm.tqdm(df.groupby('seqres')):
35 | freqs[len(sub_df)] += 1
36 | found = list(filter(lambda n: n in msa_queries.index, sub_df.name))
37 | if not found:
38 | skipped += 1
39 | continue
40 | done += 1
41 | if len(found) == 1:
42 | if seqres == msa_queries.loc[found[0]]:
43 | for idx in sub_df.index: msa_id[idx] = found[0]
44 | else:
45 | print('Mismatch', found[0])
46 | if len(found) != 1:
47 | print('Found multiple', found)
48 | for pdb_id in found:
49 | if seqres == msa_queries.loc[pdb_id]:
50 |                 for idx in sub_df.index: msa_id[idx] = pdb_id
51 | print('Match', pdb_id)
52 | else:
53 | print('Mismatch', pdb_id)
54 |
55 | df['msa_id'] = msa_id
56 | df[~df.msa_id.isnull()].set_index('name').to_csv(args.outcsv)
57 |
58 |
--------------------------------------------------------------------------------
/scripts/analyze_ensembles.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser()
3 | parser.add_argument('--atlas_dir', type=str, required=True)
4 | parser.add_argument('--pdbdir', type=str, required=True)
5 | parser.add_argument('--pdb_id', nargs='*', default=[])
6 | parser.add_argument('--bb_only', action='store_true')
7 | parser.add_argument('--ca_only', action='store_true')
8 | parser.add_argument('--num_workers', type=int, default=1)
9 |
10 | args = parser.parse_args()
11 | from sklearn.decomposition import PCA
12 | import mdtraj, pickle, tqdm, os
13 | import pandas as pd
14 | import numpy as np
15 | import matplotlib.pyplot as plt
16 | from multiprocessing import Pool
17 | from scipy.optimize import linear_sum_assignment
18 |
19 | def get_pca(xyz):
20 | traj_reshaped = xyz.reshape(xyz.shape[0], -1)
21 | pca = PCA(n_components=min(traj_reshaped.shape))
22 | coords = pca.fit_transform(traj_reshaped)
23 | return pca, coords
24 |
25 | def get_rmsds(traj1, traj2, broadcast=False):
26 | n_atoms = traj1.shape[1]
27 | traj1 = traj1.reshape(traj1.shape[0], n_atoms * 3)
28 | traj2 = traj2.reshape(traj2.shape[0], n_atoms * 3)
29 | if broadcast:
30 | traj1, traj2 = traj1[:,None], traj2[None]
31 | distmat = np.square(traj1 - traj2).sum(-1)**0.5 / n_atoms**0.5 * 10
32 | return distmat
33 |
34 | def condense_sidechain_sasas(sasas, top):
35 | assert top.n_residues > 1
36 |
37 | if top.n_atoms != sasas.shape[1]:
38 |         raise ValueError(
39 | f"The number of atoms in top ({top.n_atoms}) didn't match the "
40 | f"number of SASAs provided ({sasas.shape[1]}). Make sure you "
41 | f"computed atom-level SASAs (mode='atom') and that you've passed "
42 | "the correct topology file and array of SASAs"
43 | )
44 |
45 | sc_mask = np.array([a.name not in ['CA', 'C', 'N', 'O', 'OXT'] for a in top.atoms])
46 | res_id = np.array([a.residue.index for a in top.atoms])
47 |
48 | rsd_sasas = np.zeros((sasas.shape[0], top.n_residues), dtype='float32')
49 |
50 | for i in range(top.n_residues):
51 | rsd_sasas[:, i] = sasas[:, sc_mask & (res_id == i)].sum(1)
52 | return rsd_sasas
53 |
54 | def sasa_mi(sasa):
55 | N, L = sasa.shape
56 | joint_probs = np.zeros((L, L, 2, 2))
57 |
58 | joint_probs[:,:,1,1] = (sasa[:,:,None] & sasa[:,None,:]).mean(0)
59 | joint_probs[:,:,1,0] = (sasa[:,:,None] & ~sasa[:,None,:]).mean(0)
60 | joint_probs[:,:,0,1] = (~sasa[:,:,None] & sasa[:,None,:]).mean(0)
61 | joint_probs[:,:,0,0] = (~sasa[:,:,None] & ~sasa[:,None,:]).mean(0)
62 |
63 | marginal_probs = np.stack([1-sasa.mean(0), sasa.mean(0)], -1)
64 | indep_probs = marginal_probs[None,:,None,:] * marginal_probs[:,None,:,None]
65 | mi = np.nansum(joint_probs * np.log(joint_probs / indep_probs), (-1, -2))
66 | mi[np.arange(L), np.arange(L)] = 0
67 | return mi
68 |
69 |
70 | def get_mean_covar(xyz):
71 | mean = xyz.mean(0)
72 | xyz = xyz - mean
73 | covar = (xyz[...,None] * xyz[...,None,:]).mean(0)
74 | return mean, covar
75 |
76 |
77 | def sqrtm(M):
78 | D, P = np.linalg.eig(M)
79 | out = (P * np.sqrt(D[:,None])) @ np.linalg.inv(P)
80 | return out
81 |
82 |
83 | def get_wasserstein(distmat, p=2):
84 | assert distmat.shape[0] == distmat.shape[1]
85 | distmat = distmat ** p
86 | row_ind, col_ind = linear_sum_assignment(distmat)
87 | return distmat[row_ind, col_ind].mean() ** (1/p)
88 |
89 | def align_tops(top1, top2):
90 | names1 = [repr(a) for a in top1.atoms]
91 | names2 = [repr(a) for a in top2.atoms]
92 |
93 | intersection = [nam for nam in names1 if nam in names2]
94 |
95 | mask1 = [names1.index(nam) for nam in intersection]
96 | mask2 = [names2.index(nam) for nam in intersection]
97 | return mask1, mask2
98 |
99 | def main(name):
100 | print('Analyzing', name)
101 | out = {}
102 | ref_aa = mdtraj.load(f'{args.atlas_dir}/{name}/{name}.pdb')
103 | topfile = f'{args.atlas_dir}/{name}/{name}.pdb'
104 | print('Loading reference trajectory')
105 | traj_aa = mdtraj.load(f'{args.atlas_dir}/{name}/{name}_prod_R1_fit.xtc', top=topfile) \
106 | + mdtraj.load(f'{args.atlas_dir}/{name}/{name}_prod_R2_fit.xtc', top=topfile) \
107 | + mdtraj.load(f'{args.atlas_dir}/{name}/{name}_prod_R3_fit.xtc', top=topfile)
108 | print(f'Loaded {traj_aa.n_frames} reference frames')
109 |
110 | print('Loading AF2 conformers')
111 | aftraj_aa = mdtraj.load(f'{args.pdbdir}/{name}.pdb')
112 |
113 | print(f'Loaded {aftraj_aa.n_frames} AF2 conformers')
114 | print(f'Reference has {traj_aa.n_atoms} atoms')
115 | print(f'Crystal has {ref_aa.n_atoms} atoms')
116 | print(f'AF has {aftraj_aa.n_atoms} atoms')
117 |
118 | print('Removing hydrogens')
119 |
120 | traj_aa.atom_slice([a.index for a in traj_aa.top.atoms if a.element.symbol != 'H'], True)
121 | ref_aa.atom_slice([a.index for a in ref_aa.top.atoms if a.element.symbol != 'H'], True)
122 | aftraj_aa.atom_slice([a.index for a in aftraj_aa.top.atoms if a.element.symbol != 'H'], True)
123 |
124 | print(f'Reference has {traj_aa.n_atoms} atoms')
125 | print(f'Crystal has {ref_aa.n_atoms} atoms')
126 | print(f'AF has {aftraj_aa.n_atoms} atoms')
127 |
128 | if args.bb_only:
129 | print('Removing sidechains')
130 | aftraj_aa.atom_slice([a.index for a in aftraj_aa.top.atoms if a.name in ['CA', 'C', 'N', 'O', 'OXT']], True)
131 | print(f'AF has {aftraj_aa.n_atoms} atoms')
132 |
133 | elif args.ca_only:
134 | print('Removing sidechains')
135 | aftraj_aa.atom_slice([a.index for a in aftraj_aa.top.atoms if a.name == 'CA'], True)
136 | print(f'AF has {aftraj_aa.n_atoms} atoms')
137 |
138 |
139 | refmask, afmask = align_tops(traj_aa.top, aftraj_aa.top)
140 | traj_aa.atom_slice(refmask, True)
141 | ref_aa.atom_slice(refmask, True)
142 | aftraj_aa.atom_slice(afmask, True)
143 |
144 | print(f'Aligned on {aftraj_aa.n_atoms} atoms')
145 |
146 | np.random.seed(137)
147 | RAND1 = np.random.randint(0, traj_aa.n_frames, aftraj_aa.n_frames)
148 | RAND2 = np.random.randint(0, traj_aa.n_frames, aftraj_aa.n_frames)
149 | RAND1K = np.random.randint(0, traj_aa.n_frames, 1000)
150 |
151 | traj_aa.superpose(ref_aa)
152 | aftraj_aa.superpose(ref_aa)
153 |
154 | out['ca_mask'] = ca_mask = [a.index for a in traj_aa.top.atoms if a.name == 'CA']
155 | traj = traj_aa.atom_slice(ca_mask, False)
156 | ref = ref_aa.atom_slice(ca_mask, False)
157 | aftraj = aftraj_aa.atom_slice(ca_mask, False)
158 | print(f'Sliced {aftraj.n_atoms} C-alphas')
159 |
160 | traj.superpose(ref)
161 | aftraj.superpose(ref)
162 |
163 |
164 | n_atoms = aftraj.n_atoms
165 |
166 | print(f'Doing PCA')
167 |
168 | ref_pca, ref_coords = get_pca(traj.xyz)
169 | af_coords_ref_pca = ref_pca.transform(aftraj.xyz.reshape(aftraj.n_frames, -1))
170 | seed_coords_ref_pca = ref_pca.transform(ref.xyz.reshape(1, -1))
171 |
172 | af_pca, af_coords = get_pca(aftraj.xyz)
173 | ref_coords_af_pca = af_pca.transform(traj.xyz.reshape(traj.n_frames, -1))
174 | seed_coords_af_pca = af_pca.transform(ref.xyz.reshape(1, -1))
175 |
176 | joint_pca, _ = get_pca(np.concatenate([traj[RAND1].xyz, aftraj.xyz]))
177 | af_coords_joint_pca = joint_pca.transform(aftraj.xyz.reshape(aftraj.n_frames, -1))
178 | ref_coords_joint_pca = joint_pca.transform(traj.xyz.reshape(traj.n_frames, -1))
179 | seed_coords_joint_pca = joint_pca.transform(ref.xyz.reshape(1, -1))
180 |
181 | out['ref_variance'] = ref_pca.explained_variance_ / n_atoms * 100
182 | out['af_variance'] = af_pca.explained_variance_ / n_atoms * 100
183 | out['joint_variance'] = joint_pca.explained_variance_ / n_atoms * 100
184 |
185 | out['af_rmsf'] = mdtraj.rmsf(aftraj_aa, ref_aa) * 10
186 | out['ref_rmsf'] = mdtraj.rmsf(traj_aa, ref_aa) * 10
187 |
188 | print(f'Computing atomic EMD')
189 | ref_mean, ref_covar = get_mean_covar(traj_aa[RAND1K].xyz)
190 | af_mean, af_covar = get_mean_covar(aftraj_aa.xyz)
191 | out['emd_mean'] = (np.square(ref_mean - af_mean).sum(-1) ** 0.5) * 10
192 | try:
193 | out['emd_var'] = (np.trace(ref_covar + af_covar - 2*sqrtm(ref_covar @ af_covar), axis1=1,axis2=2) ** 0.5) * 10
194 | except:
195 | out['emd_var'] = np.trace(ref_covar) ** 0.5 * 10
196 |
197 |
198 | print(f'Analyzing SASA')
199 | sasa_thresh = 0.02
200 | af_sasa = mdtraj.shrake_rupley(aftraj_aa, probe_radius=0.28)
201 | af_sasa = condense_sidechain_sasas(af_sasa, aftraj_aa.top)
202 | ref_sasa = mdtraj.shrake_rupley(traj_aa[RAND1K], probe_radius=0.28)
203 | ref_sasa = condense_sidechain_sasas(ref_sasa, traj_aa.top)
204 | crystal_sasa = mdtraj.shrake_rupley(ref_aa, probe_radius=0.28)
205 | out['crystal_sasa'] = condense_sidechain_sasas(crystal_sasa, ref_aa.top)
206 |
207 | out['ref_sa_prob'] = (ref_sasa > sasa_thresh).mean(0)
208 | out['af_sa_prob'] = (af_sasa > sasa_thresh).mean(0)
209 | out['ref_mi_mat'] = sasa_mi(ref_sasa > sasa_thresh)
210 | out['af_mi_mat'] = sasa_mi(af_sasa > sasa_thresh)
211 |
212 | ref_distmat = np.linalg.norm(traj[RAND1].xyz[:,None,:] - traj[RAND1].xyz[:,:,None], axis=-1)
213 | af_distmat = np.linalg.norm(aftraj.xyz[:,None,:] - aftraj.xyz[:,:,None], axis=-1)
214 |
215 | out['ref_contact_prob'] = (ref_distmat < 0.8).mean(0)
216 | out['af_contact_prob'] = (af_distmat < 0.8).mean(0)
217 | out['crystal_distmat'] = np.linalg.norm(ref.xyz[0,None,:] - ref.xyz[0,:,None], axis=-1)
218 |
219 | out['ref_mean_pairwise_rmsd'] = get_rmsds(traj[RAND1].xyz, traj[RAND2].xyz, broadcast=True).mean()
220 | out['af_mean_pairwise_rmsd'] = get_rmsds(aftraj.xyz, aftraj.xyz, broadcast=True).mean()
221 |
222 | out['ref_rms_pairwise_rmsd'] = np.square(get_rmsds(traj[RAND1].xyz, traj[RAND2].xyz, broadcast=True)).mean() ** 0.5
223 | out['af_rms_pairwise_rmsd'] = np.square(get_rmsds(aftraj.xyz, aftraj.xyz, broadcast=True)).mean() ** 0.5
224 |
225 | out['ref_self_mean_pairwise_rmsd'] = get_rmsds(traj[RAND1].xyz, traj[RAND1].xyz, broadcast=True).mean()
226 | out['ref_self_rms_pairwise_rmsd'] = np.square(get_rmsds(traj[RAND1].xyz, traj[RAND1].xyz, broadcast=True)).mean() ** 0.5
227 |
228 | out['cosine_sim'] = (ref_pca.components_[0] * af_pca.components_[0]).sum()
229 |
230 |
231 | def get_emd(ref_coords1, ref_coords2, af_coords, seed_coords, K=None):
232 | if len(ref_coords1.shape) == 3:
233 | ref_coords1 = ref_coords1.reshape(ref_coords1.shape[0], -1)
234 | ref_coords2 = ref_coords2.reshape(ref_coords2.shape[0], -1)
235 | af_coords = af_coords.reshape(af_coords.shape[0], -1)
236 | seed_coords = seed_coords.reshape(seed_coords.shape[0], -1)
237 | if K is not None:
238 | ref_coords1 = ref_coords1[:,:K]
239 | ref_coords2 = ref_coords2[:,:K]
240 | af_coords = af_coords[:,:K]
241 | seed_coords = seed_coords[:,:K]
242 | emd = {}
243 | emd['ref|ref mean'] = (np.square(ref_coords1 - ref_coords1.mean(0)).sum(-1)).mean()**0.5 / n_atoms ** 0.5 * 10
244 |
245 | distmat = np.square(ref_coords1[:,None] - ref_coords2[None]).sum(-1)
246 | distmat = distmat ** 0.5 / n_atoms ** 0.5 * 10
247 | emd['ref|ref2'] = get_wasserstein(distmat)
248 | emd['ref mean|ref2 mean'] = np.square(ref_coords1.mean(0) - ref_coords2.mean(0)).sum() ** 0.5 / n_atoms ** 0.5 * 10
249 |
250 | distmat = np.square(ref_coords1[:,None] - af_coords[None]).sum(-1)
251 | distmat = distmat ** 0.5 / n_atoms ** 0.5 * 10
252 | emd['ref|af'] = get_wasserstein(distmat)
253 | emd['ref mean|af mean'] = np.square(ref_coords1.mean(0) - af_coords.mean(0)).sum() ** 0.5 / n_atoms ** 0.5 * 10
254 |
255 | emd['ref|seed'] = (np.square(ref_coords1 - seed_coords).sum(-1)).mean()**0.5 / n_atoms ** 0.5 * 10
256 | emd['ref mean|seed'] = (np.square(ref_coords1.mean(0) - seed_coords).sum(-1)).mean()**0.5 / n_atoms ** 0.5 * 10
257 |
258 | emd['af|seed'] = (np.square(af_coords - seed_coords).sum(-1)).mean()**0.5 / n_atoms ** 0.5 * 10
259 | emd['af|af mean'] = (np.square(af_coords - af_coords.mean(0)).sum(-1)).mean()**0.5 / n_atoms ** 0.5 * 10
260 | emd['af mean|seed'] = (np.square(af_coords.mean(0) - seed_coords).sum(-1)).mean()**0.5 / n_atoms ** 0.5 * 10
261 | return emd
262 |
263 | K=2
264 | out[f'EMD,ref'] = get_emd(ref_coords[RAND1], ref_coords[RAND2], af_coords_ref_pca, seed_coords_ref_pca, K=K)
265 | out[f'EMD,af2'] = get_emd(ref_coords_af_pca[RAND1], ref_coords_af_pca[RAND2], af_coords, seed_coords_af_pca, K=K)
266 | out[f'EMD,joint'] = get_emd(ref_coords_joint_pca[RAND1], ref_coords_joint_pca[RAND2], af_coords_joint_pca, seed_coords_joint_pca, K=K)
267 | return name, out
268 |
269 |
270 | if args.pdb_id:
271 | pdb_id = args.pdb_id
272 | else:
273 | pdb_id = [nam.split('.')[0] for nam in os.listdir(args.pdbdir) if '.pdb' in nam]
274 |
275 | if args.num_workers > 1:
276 | p = Pool(args.num_workers)
277 | p.__enter__()
278 | __map__ = p.imap
279 | else:
280 | __map__ = map
281 | out = dict(tqdm.tqdm(__map__(main, pdb_id), total=len(pdb_id)))
282 | if args.num_workers > 1:
283 | p.__exit__(None, None, None)
284 |
285 | with open(f"{args.pdbdir}/out.pkl", 'wb') as f:
286 | f.write(pickle.dumps(out))
--------------------------------------------------------------------------------
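A self-contained sketch of the matching-based Wasserstein distance used by `get_wasserstein` above (the helper is repeated so the snippet runs on its own); the synthetic ensembles stand in for MD and model conformers projected onto principal components.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def get_wasserstein(distmat, p=2):
    # Optimal one-to-one matching between two equally sized ensembles,
    # then the p-mean of the matched distances (an empirical W_p).
    assert distmat.shape[0] == distmat.shape[1]
    distmat = distmat ** p
    row_ind, col_ind = linear_sum_assignment(distmat)
    return distmat[row_ind, col_ind].mean() ** (1 / p)

rng = np.random.default_rng(0)
md_pcs = rng.normal(size=(8, 2))           # 8 "MD" frames in a 2D PC space
model_pcs = rng.normal(size=(8, 2)) + 1.0  # 8 model samples, shifted by 1

distmat = np.linalg.norm(md_pcs[:, None] - model_pcs[None], axis=-1)
print(get_wasserstein(distmat))
```
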
/scripts/cluster_chains.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('--chains', type=str, default='pdb_mmcif.csv')
5 | parser.add_argument('--out', type=str, default='pdb_clusters')
6 | parser.add_argument('--thresh', type=float, default=0.4)
7 | parser.add_argument('--mmseqs_path', type=str, default='mmseqs')
8 | args = parser.parse_args()
9 |
10 | import pandas as pd
11 | import os, json, tqdm, random, subprocess, pickle
12 | from collections import defaultdict
13 | from functools import partial
14 | import numpy as np
15 | from Bio.Seq import Seq
16 | from Bio.SeqRecord import SeqRecord
17 | from Bio import SeqIO
18 | from collections import defaultdict
19 |
20 | def main():
21 | df = pd.read_csv(args.chains, index_col='name')
22 |
23 | sequences = [SeqRecord(Seq(row.seqres), id=name) for name, row in tqdm.tqdm(df.iterrows())]
24 | SeqIO.write(sequences, ".in.fasta", "fasta")
25 | cmd = [args.mmseqs_path, 'easy-cluster', '.in.fasta', '.out', '.tmp', '--min-seq-id', str(args.thresh), '--alignment-mode', '1']
26 | subprocess.run(cmd)#, stdout=open('/dev/null', 'w'))
27 | f = open('.out_cluster.tsv')
28 | clusters = []
29 | for line in f:
30 | a, b = line.strip().split()
31 | if a == b:
32 | clusters.append([])
33 | clusters[-1].append(b)
34 | subprocess.run(['rm', '-r', '.in.fasta', '.tmp', '.out_all_seqs.fasta', '.out_rep_seq.fasta', '.out_cluster.tsv'])
35 |
36 | with open(args.out, 'w') as f:
37 | for clus in clusters:
38 | f.write(' '.join(clus))
39 | f.write('\n')
40 |
41 |
42 | if __name__ == "__main__":
43 | main()
--------------------------------------------------------------------------------
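The clusters file written by `cluster_chains.py` above is plain text with one cluster per line and space-separated chain names; a small sketch of reading it back (the filename is the script's default `--out`).

```python
# One cluster per line, space-separated chain names.
with open("pdb_clusters") as f:
    clusters = [line.split() for line in f if line.strip()]

print(len(clusters), "clusters;", sum(len(c) for c in clusters), "chains total")
```
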
/scripts/download_atlas.sh:
--------------------------------------------------------------------------------
1 | for name in $(cat `dirname $0`/../splits/atlas.csv | grep -v name | awk -F ',' {'print $1'}); do
2 | wget https://www.dsimb.inserm.fr/ATLAS/database/ATLAS/${name}/${name}_protein.zip
3 | mkdir ${name}
4 | unzip ${name}_protein.zip -d ${name}
5 | rm ${name}_protein.zip
6 | done
--------------------------------------------------------------------------------
/scripts/mmseqs_query.py:
--------------------------------------------------------------------------------
1 | # https://github.com/sokrypton/ColabFold/blob/main/colabfold/colabfold.py
2 |
3 | import argparse
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--split', type=str, required=True)
6 | parser.add_argument('--outdir', type=str, default='./alignment_dir')
7 | args = parser.parse_args()
8 |
9 | from typing import Tuple, List
10 | import os
11 | import requests
12 | import random
13 | import time
14 | import logging
15 | import tarfile
16 | import pandas as pd
17 | from tqdm import tqdm
18 | logger = logging.getLogger(__name__)
19 |
20 | TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'
21 | def run_mmseqs2(x, prefix, use_env=True, use_filter=True,
22 | use_templates=False, filter=None, use_pairing=False, pairing_strategy="greedy",
23 | host_url="https://api.colabfold.com",
24 | user_agent: str = "") -> Tuple[List[str], List[str]]:
25 | submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"
26 |
27 | headers = {}
28 | if user_agent != "":
29 | headers['User-Agent'] = user_agent
30 | else:
31 | logger.warning("No user agent specified. Please set a user agent (e.g., 'toolname/version contact@email') to help us debug in case of problems. This warning will become an error in the future.")
32 |
33 | def submit(seqs, mode, N=101):
34 | n, query = N, ""
35 | for seq in seqs:
36 | query += f">{n}\n{seq}\n"
37 | n += 1
38 |
39 |         error_count = 0
40 |         while True:
41 | try:
42 | # https://requests.readthedocs.io/en/latest/user/advanced/#advanced
43 | # "good practice to set connect timeouts to slightly larger than a multiple of 3"
44 | res = requests.post(f'{host_url}/{submission_endpoint}', data={ 'q': query, 'mode': mode }, timeout=6.02, headers=headers)
45 | except requests.exceptions.Timeout:
46 | logger.warning("Timeout while submitting to MSA server. Retrying...")
47 | continue
48 | except Exception as e:
49 | error_count += 1
50 | logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)")
51 | logger.warning(f"Error: {e}")
52 | time.sleep(5)
53 | if error_count > 5:
54 | raise
55 | continue
56 | break
57 |
58 | try:
59 | out = res.json()
60 | except ValueError:
61 | logger.error(f"Server didn't reply with json: {res.text}")
62 | out = {"status":"ERROR"}
63 | return out
64 |
65 | def status(ID):
66 |         error_count = 0
67 |         while True:
68 | try:
69 | res = requests.get(f'{host_url}/ticket/{ID}', timeout=6.02, headers=headers)
70 | except requests.exceptions.Timeout:
71 | logger.warning("Timeout while fetching status from MSA server. Retrying...")
72 | continue
73 | except Exception as e:
74 | error_count += 1
75 | logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)")
76 | logger.warning(f"Error: {e}")
77 | time.sleep(5)
78 | if error_count > 5:
79 | raise
80 | continue
81 | break
82 | try:
83 | out = res.json()
84 | except ValueError:
85 | logger.error(f"Server didn't reply with json: {res.text}")
86 | out = {"status":"ERROR"}
87 | return out
88 |
89 | def download(ID, path):
90 | error_count = 0
91 | while True:
92 | try:
93 | res = requests.get(f'{host_url}/result/download/{ID}', timeout=6.02, headers=headers)
94 | except requests.exceptions.Timeout:
95 | logger.warning("Timeout while fetching result from MSA server. Retrying...")
96 | continue
97 | except Exception as e:
98 | error_count += 1
99 | logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)")
100 | logger.warning(f"Error: {e}")
101 | time.sleep(5)
102 | if error_count > 5:
103 | raise
104 | continue
105 | break
106 | with open(path,"wb") as out: out.write(res.content)
107 |
108 | # process input x
109 | seqs = [x] if isinstance(x, str) else x
110 |
111 | # compatibility to old option
112 | if filter is not None:
113 | use_filter = filter
114 |
115 | # setup mode
116 | if use_filter:
117 | mode = "env" if use_env else "all"
118 | else:
119 | mode = "env-nofilter" if use_env else "nofilter"
120 |
121 | if use_pairing:
122 | use_templates = False
123 | use_env = False
124 | mode = ""
125 | # greedy is default, complete was the previous behavior
126 | if pairing_strategy == "greedy":
127 | mode = "pairgreedy"
128 | elif pairing_strategy == "complete":
129 | mode = "paircomplete"
130 |
131 | # define path
132 | path = f"{prefix}_{mode}"
133 | if not os.path.isdir(path): os.mkdir(path)
134 |
135 | # call mmseqs2 api
136 | tar_gz_file = f'{path}/out.tar.gz'
137 | N,REDO = 101,True
138 |
139 | # deduplicate and keep track of order
140 | seqs_unique = []
141 | #TODO this might be slow for large sets
142 | [seqs_unique.append(x) for x in seqs if x not in seqs_unique]
143 | Ms = [N + seqs_unique.index(seq) for seq in seqs]
144 | # lets do it!
145 | if not os.path.isfile(tar_gz_file):
146 | TIME_ESTIMATE = 150 * len(seqs_unique)
147 | with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
148 | while REDO:
149 | pbar.set_description("SUBMIT")
150 |
151 | # Resubmit job until it goes through
152 | out = submit(seqs_unique, mode, N)
153 | while out["status"] in ["UNKNOWN", "RATELIMIT"]:
154 | sleep_time = 5 + random.randint(0, 5)
155 | logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
156 | # resubmit
157 | time.sleep(sleep_time)
158 | out = submit(seqs_unique, mode, N)
159 |
160 | if out["status"] == "ERROR":
161 | raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')
162 |
163 | if out["status"] == "MAINTENANCE":
164 | raise Exception(f'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.')
165 |
166 | # wait for job to finish
167 | ID,TIME = out["id"],0
168 | pbar.set_description(out["status"])
169 | while out["status"] in ["UNKNOWN","RUNNING","PENDING"]:
170 | t = 5 + random.randint(0,5)
171 | logger.error(f"Sleeping for {t}s. Reason: {out['status']}")
172 | time.sleep(t)
173 | out = status(ID)
174 | pbar.set_description(out["status"])
175 | if out["status"] == "RUNNING":
176 | TIME += t
177 | pbar.update(n=t)
178 | #if TIME > 900 and out["status"] != "COMPLETE":
179 | # # something failed on the server side, need to resubmit
180 | # N += 1
181 | # break
182 |
183 | if out["status"] == "COMPLETE":
184 | if TIME < TIME_ESTIMATE:
185 | pbar.update(n=(TIME_ESTIMATE-TIME))
186 | REDO = False
187 |
188 | if out["status"] == "ERROR":
189 | REDO = False
190 | raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')
191 |
192 | # Download results
193 | download(ID, tar_gz_file)
194 |
195 | # prep list of a3m files
196 | if use_pairing:
197 | a3m_files = [f"{path}/pair.a3m"]
198 | else:
199 | a3m_files = [f"{path}/uniref.a3m"]
200 | if use_env: a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")
201 |
202 | # extract a3m files
203 | if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
204 | with tarfile.open(tar_gz_file) as tar_gz:
205 | tar_gz.extractall(path)
206 |
207 | # templates
208 | if use_templates:
209 | templates = {}
210 | #print("seq\tpdb\tcid\tevalue")
211 | for line in open(f"{path}/pdb70.m8","r"):
212 | p = line.rstrip().split()
213 | M,pdb,qid,e_value = p[0],p[1],p[2],p[10]
214 | M = int(M)
215 | if M not in templates: templates[M] = []
216 | templates[M].append(pdb)
217 | #if len(templates[M]) <= 20:
218 | # print(f"{int(M)-N}\t{pdb}\t{qid}\t{e_value}")
219 |
220 | template_paths = {}
221 | for k,TMPL in templates.items():
222 | TMPL_PATH = f"{prefix}_{mode}/templates_{k}"
223 | if not os.path.isdir(TMPL_PATH):
224 | os.mkdir(TMPL_PATH)
225 | TMPL_LINE = ",".join(TMPL[:20])
226 | response = None
227 | while True:
228 | error_count = 0
229 | try:
230 | # https://requests.readthedocs.io/en/latest/user/advanced/#advanced
231 | # "good practice to set connect timeouts to slightly larger than a multiple of 3"
232 | response = requests.get(f"{host_url}/template/{TMPL_LINE}", stream=True, timeout=6.02, headers=headers)
233 | except requests.exceptions.Timeout:
234 | logger.warning("Timeout while submitting to template server. Retrying...")
235 | continue
236 | except Exception as e:
237 | error_count += 1
238 | logger.warning(f"Error while fetching result from template server. Retrying... ({error_count}/5)")
239 | logger.warning(f"Error: {e}")
240 | time.sleep(5)
241 | if error_count > 5:
242 | raise
243 | continue
244 | break
245 | with tarfile.open(fileobj=response.raw, mode="r|gz") as tar:
246 | tar.extractall(path=TMPL_PATH)
247 | os.symlink("pdb70_a3m.ffindex", f"{TMPL_PATH}/pdb70_cs219.ffindex")
248 | with open(f"{TMPL_PATH}/pdb70_cs219.ffdata", "w") as f:
249 | f.write("")
250 | template_paths[k] = TMPL_PATH
251 |
252 | # gather a3m lines
253 | a3m_lines = {}
254 | for a3m_file in a3m_files:
255 | update_M,M = True,None
256 | for line in open(a3m_file,"r"):
257 | if len(line) > 0:
258 | if "\x00" in line:
259 | line = line.replace("\x00","")
260 | update_M = True
261 | if line.startswith(">") and update_M:
262 | M = int(line[1:].rstrip())
263 | update_M = False
264 | if M not in a3m_lines: a3m_lines[M] = []
265 | a3m_lines[M].append(line)
266 |
267 | # return results
268 |
269 | a3m_lines = ["".join(a3m_lines[n]) for n in Ms]
270 |
271 | if use_templates:
272 | template_paths_ = []
273 | for n in Ms:
274 | if n not in template_paths:
275 | template_paths_.append(None)
276 | #print(f"{n-N}\tno_templates_found")
277 | else:
278 | template_paths_.append(template_paths[n])
279 | template_paths = template_paths_
280 |
281 |
282 | return (a3m_lines, template_paths) if use_templates else a3m_lines
283 |
284 | df = pd.read_csv(args.split, index_col='name')
285 | os.makedirs(args.outdir, exist_ok=True)
286 |
287 | msas = run_mmseqs2(list(df.seqres), prefix='/tmp/', user_agent='bjing2016/alphaflow')
288 | os.system('rm -r /tmp/_env')
289 |
290 | for name, msa in zip(df.index, msas):
291 | os.makedirs(f'{args.outdir}/{name}/a3m/', exist_ok=True)
292 | with open(f'{args.outdir}/{name}/a3m/{name}.a3m', 'w') as f:
293 | f.write(msa)
294 |
295 |
--------------------------------------------------------------------------------
/scripts/mmseqs_search_helper.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser()
3 | parser.add_argument('--split', type=str, required=True)
4 | parser.add_argument('--db_dir', type=str, default='./dbbase')
5 | parser.add_argument('--outdir', type=str, default='./alignment_dir')
6 | args = parser.parse_args()
7 |
8 | import pandas as pd
9 | import subprocess, os
10 |
11 | df = pd.read_csv(args.split)
12 | os.makedirs(args.outdir, exist_ok=True)
13 | with open('/tmp/tmp.fasta', 'w') as f:
14 | for _, row in df.iterrows():
15 | f.write(f'>{row["name"]}\n{row.seqres}\n')
16 |
17 | cmd = f'python -m scripts.mmseqs_search /tmp/tmp.fasta {args.db_dir} {args.outdir}'
18 | os.system(cmd)
19 |
20 | for name in os.listdir(args.outdir):
21 | if '.a3m' not in name:
22 | continue
23 | with open(f'{args.outdir}/{name}') as f:
24 | pdb_id = next(f).strip()[1:]
25 | cmd = f'mkdir -p {args.outdir}/{pdb_id}/a3m'
26 | print(cmd)
27 | os.system(cmd)
28 | cmd = f'mv {args.outdir}/{name} {args.outdir}/{pdb_id}/a3m/{pdb_id}.a3m'
29 | os.system(cmd)
30 | print(cmd)
--------------------------------------------------------------------------------
/scripts/prep_atlas.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('--split', type=str, default='splits/atlas.csv')
5 | parser.add_argument('--atlas_dir', type=str, required=True)
6 | parser.add_argument('--outdir', type=str, default='./data_atlas')
7 | parser.add_argument('--num_workers', type=int, default=1)
8 | args = parser.parse_args()
9 |
10 | import mdtraj, os, tempfile, tqdm
11 | from alphaflow.utils import protein
12 | from openfold.data.data_pipeline import make_protein_features
13 | import pandas as pd
14 | from multiprocessing import Pool
15 | import numpy as np
16 |
17 | os.makedirs(args.outdir, exist_ok=True)
18 |
19 | df = pd.read_csv(args.split, index_col='name')
20 |
21 | def main():
22 | jobs = []
23 | for name in df.index:
24 | #if os.path.exists(f'{args.outdir}/{name}.npz'): continue
25 | jobs.append(name)
26 |
27 | if args.num_workers > 1:
28 | p = Pool(args.num_workers)
29 | p.__enter__()
30 | __map__ = p.imap
31 | else:
32 | __map__ = map
33 | for _ in tqdm.tqdm(__map__(do_job, jobs), total=len(jobs)):
34 | pass
35 | if args.num_workers > 1:
36 | p.__exit__(None, None, None)
37 |
38 | def do_job(name):
39 | traj = mdtraj.load(f'{args.atlas_dir}/{name}/{name}_prod_R1_fit.xtc', top=f'{args.atlas_dir}/{name}/{name}.pdb') \
40 | + mdtraj.load(f'{args.atlas_dir}/{name}/{name}_prod_R2_fit.xtc', top=f'{args.atlas_dir}/{name}/{name}.pdb') \
41 | + mdtraj.load(f'{args.atlas_dir}/{name}/{name}_prod_R3_fit.xtc', top=f'{args.atlas_dir}/{name}/{name}.pdb')
42 | ref = mdtraj.load(f'{args.atlas_dir}/{name}/{name}.pdb')
43 | traj = ref + traj
44 | f, temp_path = tempfile.mkstemp(); os.close(f)
45 | positions_stacked = []
46 | for i in tqdm.trange(0, len(traj), 100):
47 | traj[i].save_pdb(temp_path)
48 |
49 | with open(temp_path) as f:
50 | prot = protein.from_pdb_string(f.read())
51 | pdb_feats = make_protein_features(prot, name)
52 | positions_stacked.append(pdb_feats['all_atom_positions'])
53 |
54 |
55 | pdb_feats['all_atom_positions'] = np.stack(positions_stacked)
56 | print({key: pdb_feats[key].shape for key in pdb_feats})
57 | np.savez(f"{args.outdir}/{name}.npz", **pdb_feats)
58 | os.unlink(temp_path)
59 |
60 | main()
61 |
--------------------------------------------------------------------------------
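A short sketch of consuming the per-target `.npz` files written by `prep_atlas.py` above; the path and target name are placeholders. Relative to standard OpenFold protein features, the only difference is that `all_atom_positions` gains a leading frame axis from the stacking step (the crystal structure first, then every 100th frame of the concatenated trajectory).

```python
import numpy as np

# Placeholder path: --outdir defaults to ./data_atlas and files are named <target>.npz.
feats = dict(np.load("data_atlas/6uof_A.npz"))

print(sorted(feats.keys()))
# all_atom_positions has shape (num_frames, num_res, 37, 3); the remaining
# features describe the single sequence as usual.
print(feats["all_atom_positions"].shape)
```
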
/scripts/print_analysis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tqdm, sys, pickle, warnings
3 | import pandas as pd
4 | import scipy.stats
5 |
6 | paths = sys.argv[1:]
7 |
8 | def correlations(a, b, prefix=''):
9 | return {
10 | prefix + 'pearson': scipy.stats.pearsonr(a, b)[0],
11 | prefix + 'spearman': scipy.stats.spearmanr(a, b)[0],
12 | prefix + 'kendall': scipy.stats.kendalltau(a, b)[0],
13 | }
14 |
15 |
16 | def analyze_data(data):
17 | mi_mats = {}
18 | df = []
19 | for name, out in data.items():
20 | item = {
21 | 'name': name,
22 | 'md_pairwise': out['ref_mean_pairwise_rmsd'],
23 | 'af_pairwise': out['af_mean_pairwise_rmsd'],
24 | 'cosine_sim': abs(out['cosine_sim']),
25 | 'emd_mean': np.square(out['emd_mean']).mean() ** 0.5,
26 | 'emd_var': np.square(out['emd_var']).mean() ** 0.5,
27 | } | correlations(out['af_rmsf'], out['ref_rmsf'], prefix='rmsf_')
28 | if 'EMD,ref' not in out:
29 | out['EMD,ref'] = out['EMD-2,ref']
30 | out['EMD,af2'] = out['EMD-2,af2']
31 | out['EMD,joint'] = out['EMD-2,joint']
32 | for emd_dict, emd_key in [
33 | (out['EMD,ref'], 'ref'),
34 | (out['EMD,joint'], 'joint')
35 | ]:
36 | item.update({
37 | emd_key + 'emd': emd_dict['ref|af'],
38 | emd_key + 'emd_tr': emd_dict['ref mean|af mean'],
39 | emd_key + 'emd_int': (emd_dict['ref|af']**2 - emd_dict['ref mean|af mean']**2)**0.5,
40 | })
41 |
42 | try:
43 | crystal_contact_mask = out['crystal_distmat'] < 0.8
44 | ref_transient_mask = (~crystal_contact_mask) & (out['ref_contact_prob'] > 0.1)
45 | af_transient_mask = (~crystal_contact_mask) & (out['af_contact_prob'] > 0.1)
46 | ref_weak_mask = crystal_contact_mask & (out['ref_contact_prob'] < 0.9)
47 | af_weak_mask = crystal_contact_mask & (out['af_contact_prob'] < 0.9)
48 | item.update({
49 | 'weak_contacts_iou': (ref_weak_mask & af_weak_mask).sum() / (ref_weak_mask | af_weak_mask).sum(),
50 | 'transient_contacts_iou': (ref_transient_mask & af_transient_mask).sum() / (ref_transient_mask | af_transient_mask).sum()
51 | })
52 | except:
53 | item.update({
54 | 'weak_contacts_iou': np.nan,
55 | 'transient_contacts_iou': np.nan,
56 | })
57 | sasa_thresh = 0.02
58 | buried_mask = out['crystal_sasa'][0] < sasa_thresh
59 | ref_sa_mask = (out['ref_sa_prob'] > 0.1) & buried_mask
60 | af_sa_mask = (out['af_sa_prob'] > 0.1) & buried_mask
61 |
62 | item.update({
63 | 'num_sasa': ref_sa_mask.sum(),
64 | 'sasa_iou': (ref_sa_mask & af_sa_mask).sum() / (ref_sa_mask | af_sa_mask).sum(),
65 | })
66 | item.update(correlations(out['ref_mi_mat'].flatten(), out['af_mi_mat'].flatten(), prefix='exposon_mi_'))
67 |
68 | df.append(item)
69 | df = pd.DataFrame(df).set_index('name')#.join(val_df)
70 | all_ref_rmsf = np.concatenate([data[name]['ref_rmsf'] for name in df.index])
71 | all_af_rmsf = np.concatenate([data[name]['af_rmsf'] for name in df.index])
72 | return all_ref_rmsf, all_af_rmsf, df, data
73 |
74 | datas = {}
75 |
76 | for path in tqdm.tqdm(paths):
77 | with open(path, 'rb') as f:
78 | data = pickle.load(f)
79 | with warnings.catch_warnings():
80 | warnings.simplefilter("ignore")
81 | datas[path] = analyze_data(data)
82 |
83 | new_df = []
84 | for key in datas:
85 | ref_rmsf, af_rmsf, df, data = datas[key]
86 | new_df.append({
87 | 'path': key,
88 | 'count': len(df),
89 | 'MD pairwise RMSD': df.md_pairwise.median(),
90 | 'Pairwise RMSD': df.af_pairwise.median(),
91 | 'Pairwise RMSD r': scipy.stats.pearsonr(df.md_pairwise, df.af_pairwise)[0],
92 | 'MD RMSF': np.median(ref_rmsf),
93 | 'RMSF': np.median(af_rmsf),
94 | 'Global RMSF r': scipy.stats.pearsonr(ref_rmsf, af_rmsf)[0],
95 | 'Per target RMSF r': df.rmsf_pearson.median(),
96 | 'RMWD': np.sqrt(df.emd_mean**2 + df.emd_var**2).median(),
97 | 'RMWD trans': df.emd_mean.median(),
98 | 'RMWD var': df.emd_var.median(),
99 | 'MD PCA W2': df.refemd.median(),
100 | 'Joint PCA W2': df.jointemd.median(),
101 | 'PC sim > 0.5 %': (df.cosine_sim > 0.5).mean() * 100,
102 | 'Weak contacts J': df.weak_contacts_iou.median(),
103 | 'Weak contacts nans': df.weak_contacts_iou.isna().mean(),
104 | 'Transient contacts J': df.transient_contacts_iou.median(),
105 | 'Transient contacts nans': df.transient_contacts_iou.isna().mean(),
106 | 'Exposed residue J': df.sasa_iou.median(),
107 | 'Exposed MI matrix rho': df.exposon_mi_spearman.median(),
108 | })
109 |
110 | new_df = pd.DataFrame(new_df).set_index('path')
111 | print(new_df.round(2).T)
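112 |
113 | # Example invocation (file names are placeholders): each positional argument is a
114 | # pickle mapping target names to the per-target metric dictionaries consumed by
115 | # analyze_data() above, presumably written by scripts/analyze_ensembles.py. One
116 | # summary column is printed per input path, so runs can be compared side by side:
117 | #
118 | #   python scripts/print_analysis.py analysis_alphaflow.pkl analysis_esmflow.pkl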
--------------------------------------------------------------------------------
/scripts/unpack_mmcif.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('--mmcif_dir', type=str, required=True)
5 | parser.add_argument('--outdir', type=str, default='./data')
6 | parser.add_argument('--outcsv', type=str, default='./pdb_mmcif.csv')
7 | parser.add_argument('--num_workers', type=int, default=15)
8 | args = parser.parse_args()
9 |
10 | import warnings, tqdm, os, io, logging
11 | import pandas as pd
12 | import numpy as np
13 | from multiprocessing import Pool
14 | from alphaflow.data.data_pipeline import DataPipeline
15 | from openfold.data import mmcif_parsing
16 |
17 | pipeline = DataPipeline(template_featurizer=None)
18 |
19 | def main():
20 | dirs = os.listdir(args.mmcif_dir)
21 | files = [os.listdir(f"{args.mmcif_dir}/{dir}") for dir in dirs]
22 | files = sum(files, [])
23 | if args.num_workers > 1:
24 | p = Pool(args.num_workers)
25 | p.__enter__()
26 | __map__ = p.imap
27 | else:
28 | __map__ = map
29 | infos = list(tqdm.tqdm(__map__(unpack_mmcif, files), total=len(files)))
30 | if args.num_workers > 1:
31 | p.__exit__(None, None, None)
32 | info = []
33 | for inf in infos:
34 | info.extend(inf)
35 | df = pd.DataFrame(info).set_index('name')
36 | df.to_csv(args.outcsv)
37 |
38 | def unpack_mmcif(name):
39 |     path = f"{args.mmcif_dir}/{name[1:3]}/{name}"  # mmCIF files are grouped in two-character subdirectories
40 |
41 | with open(path, 'r') as f:
42 | mmcif_string = f.read()
43 |
44 |
45 | mmcif = mmcif_parsing.parse(
46 | file_id=name[:-4], mmcif_string=mmcif_string
47 | )
48 | if mmcif.mmcif_object is None:
49 | logging.info(f"Could not parse {name}. Skipping...")
50 | return []
51 | else:
52 | mmcif = mmcif.mmcif_object
53 |
54 | out = []
55 | for chain, seq in mmcif.chain_to_seqres.items():
56 | out.append({
57 | "name": f"{name[:-4]}_{chain}",
58 | "release_date": mmcif.header["release_date"],
59 | "seqres": seq,
60 | "resolution": mmcif.header["resolution"],
61 | })
62 |
63 | data = pipeline.process_mmcif(mmcif=mmcif, chain_id=chain)
64 | out_dir = f"{args.outdir}/{name[1:3]}"
65 | os.makedirs(out_dir, exist_ok=True)
66 | out_path = f"{out_dir}/{name[:-4]}_{chain}.npz"
67 | np.savez(out_path, **data)
68 |
69 |
70 | return out
71 |
72 | if __name__ == "__main__":
73 | main()
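74 |
75 | # Example invocation (paths are placeholders). --mmcif_dir is expected to contain
76 | # the usual two-character subdirectories of the PDB mmCIF archive, since files are
77 | # located via name[1:3] above; one .npz of chain features is written under
78 | # --outdir and per-chain metadata is collected in --outcsv:
79 | #
80 | #   python scripts/unpack_mmcif.py --mmcif_dir ./pdb_mmcif --outdir ./pdb_npz \
81 | #       --outcsv ./pdb_mmcif.csv --num_workers 16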
--------------------------------------------------------------------------------
/splits/atlas_val.csv:
--------------------------------------------------------------------------------
1 | name,seqres,release_date,msa_id
2 | 6irx_A,VYWDLDIQTNAVIRERAPADHLPPHPEIELQRAQLTTKLRQHYHELCSQREGIEPPRESFNRWLLERKVVDKGLDPLLPSECDPVISPSMFREIMNDIPIRLSRIKYKEEARKLLFKYAEAAKKMIDSRNVTPESRKVVKWNVEDTMNWLRRDHSASKEDYMDRLENLRKQCGPHVASVAKDSVEGICSKIYHISAEYVRRIRQAHLTLLKECNISVDGTESAEVQDRLVYCYPVRLSIPAPPQTRVELHFENDIACLRFKGEMVKVSRGHFNKLELLYRYSCIDDPRFEKFLSRVWCLIKRYQVMFGSGVNEGSGLQGSLPVPVFEALNKQFGVTFECFASPLNCYFKQFCSAFPDIDGFFGSRGPFLSFSPASGSFEANPPFCEELMDAMVTHFEDLLGRSSEPLSFIIFVPEWRDPPTPALTRMEASRFRRHQMTVPAFEHEYRSGSQHICKREEIYYKAIHGTAVIFLQNNAGFAKWEPTTERIQELLAAYK,2018-12-05,6irx_A
3 | 6cka_B,MLYIDEFKEAIDKGYILGDTVAIVRKNGKIFDYVLPHEKVRDDEVVTVERVEEVMVELDKLEHHHHHH,2018-11-14,6cka_B
4 | 6hj6_A,ETGMPLSCSKNNSHHYIFVGNNSGLELTLTNTSLLNHKFCNLSDAHKRAQYDMALMSIVSTFHLSIPNFNQYEAMSCDFNGDNITVQYNLSHASAVDAANHCGTVANGILETFHKFFWSNNIKDAYQLPNQGKLAHCYSTSYQFLIIQNTTWEDHCTFSRPTGTKHHHHHH,2018-10-10,6hj6_A
5 | 6dgk_B,ASRYLTDMTLEEMSRDWFMLMPKQKVAGSLCIRMDQAIMDKNIILKANFSVIFDRLETLILLRAFTEEGAIVGEISPLPSLPGHTDEDVKNAVGVLIGGLEANDNTVRVSETLQRFAWRS,2018-08-15,6dgk_B
6 | 6fc0_B,GPHMQPSAALQSLRSARFLPGIVQDIYPPGIKSPNPALNEAVQKKGRIFKYDVQFLLQFQNVFTEKPSPDFDQQVKALIGD,2018-06-27,6fc0_B
7 | 6dlm_A,GSPRSYLLKELADLSQHLVRLLERLVRESERVVEVLERGEVDEEELKRLEDLHRELEKAVREVRETHREIRERSR,2018-12-19,6dlm_A
8 | 6mdw_A,GPGRPQESLSLVDASWELVDPTPDLQALFVQFNDQFFWGQLEAVEVKWSVRMTLCAGICSYEGKGGMCSIRLSEPLLKLRPRKDLVETLLHEMIHAYLFVTNNDKDREGHGPEFCKHMHRINSLTGANITVYHTFHDEVDEYRRHWWRCNGPCQHRPPYYGYVKRATNREPSAHDYWWAEHQKTCGGTYIKIKE,2019-04-10,6mdw_A
9 | 6bwq_A,TADAVLMIEANLDDQTGEGLGYVMNQLLTAGAYDVFFTPIQMKKDRPATKLTVLGNVNDKDLLTKLILQETTTIGVRYQTWQRTIMQRHFLTVATPYGDVQVKVATYQDIEKKMPEYADCAQLAQQFHIPFRTVYQAALVAVDQLDEEA,2018-06-20,6bwq_A
10 | 5ydn_A,MNYATVNDLCARYTRTRLDILTRPKTADGQPDDAVAEQALADASAFIDGYLAARFVLPLTVVPSLLKRQCCVVAWFYLNESQPTEQITATYRDTVRWLEQVRDGKTDPG,2018-07-11,5ydn_A
11 | 5w82_E,HMMPSEDYAIWYARATIAALQAAEYRLAMPSASYTAWFTDAVSDKLDKISESLNTLVECVIDKRLAVSVPEPLPVRVENKVQVEVEDEVRVRVENKVDVEVKN,2018-12-26,5w82_E
12 | 6cb7_A,HMDKLRVLYDEFVTISKDNLERETGLSASDVDMDFDLNIFMTLVPVLAAAVCAITPTIEDDKIVTMMKYCSYQSFSFWFLKSGAVVKSVYNKLDYVKKEKFVATFRDMLLNVQTLISLNSMY,2018-12-12,6cb7_A
13 | 5yrv_I,MARVSDYPLANKHPEWVKTATNKTLDDFTLENVLSNKVTAQDMRITPETLRLQASIAKDAGRDRLAMNFERAAELTAVPDDRILEIYNALRPYRSTKEELLAIADDLESRYQAKICAAFVREAATLYVERKKLKGDD,2018-09-19,5yrv_I
14 | 5z51_A,PTLWPQREALKSALQYPALAGPVFDALTVEGFTHPEYAAVRAAIDTAGGTSAGLSGAQWLDMVRQQTTSTVTSALISELGVEAIQVDDDKLPRYIAGVLARLQEVWLGRQIAEVKSKLQRMSPIEQGDEYHALFGDLVAMEAYRRSLLEQASGDDLHHHHHH,2018-11-28,5z51_A
15 | 6a9a_A,MISDSMTVEEIRLHLGLALKEKDFVVDKTGVKTIEIIGASFVADEPFIFGALNDEYIQRELEWYKSKSLFVKDIPGETPKIWQQVASSKGEINSNYGWAIWSEDNYAQYDMCLAELGQNPDSRRGIMIYTRPSMQFDYNKDGMSDFMSTNTVQYLIRDKKINAVVNMRSNDVVFGFRNNYAWQKYVLDKLVSDLNAGDSTRQYKAGSIIWNVGSLHVYSRHFYLVDHWWKTGETHISKKDYVGKYA,2019-01-02,6a9a_A
16 | 6atg_C,GHMNQRNINELKIFVEKAKYYSIKLDAIYNECTGAYNDIMTYSEGTFSDQSKVNQAISIFKKDNKIVNKFKELEKIIEEYKPMFLSKLIDDFAIELDQAVDNDVSNARHVADSYKKLRKSVVLAYIESFDVISSKFVDSKFVEASKKFVNKAKEFVEENDLIALECIVKTIGDMVNDREINSRSRYNNFYKKEADFLGAAVELEGAYKAIKQTL,2018-09-12,6atg_C
17 | 5w4a_B,MDTNKREIVEFLGIRTYFFPNLALYAVNNDELLVSDPNKANSFAAYVFGASDKKPSVDDIVQILFPSGSDSGTILTSMDTLLALGPDFLTEFKKRNQDLARFNLTHDLSILAQGDEDAAKKKLNLMGRKAKLQKTEAAKILAILIKTINSEENYEKFTELSELCGLDLDFDAYVFTKILGLEDEDTADEVEVIRDNFLNRLDQTKPKLADIIRNGP,2018-06-13,5w4a_B
18 | 5zmo_A,GSREAPKTFHRRVGDVRPARRAMGPALHRPVLLLWAIGQAVARAPRLQPWSTTRDAVAPLMEKYGQVEDGVDGVRYPFWALVRDDLWCVEQAEELTLTSRGRRPTLESLNAVDPSAGLREDDYNLLRSQPEAAASAAAGLIARYFHLLPAGLLEDFGLHELLAGRWPDALRP,2018-09-26,5zmo_A
19 | 6bk4_A,ARKDNLAKTIAETAKIREVLIIQNVLNCFNDDQVRSDFLNGENGAKKLENTELELLEKFFIETQTRRPETADDVSFIATAQKSAELFYSTINARPKSFGEVSFEKLRSLFQQIQDSGYLDKYYL,2018-11-07,6bk4_A
20 | 6eu8_A,MDKKDDVKETIKKSKTIDATKYEMWTYVNLETGQTETHRDFSEWHVMKNGKLLETIPAKGSEADIKIKWHIAIHRFDIRTNEGEAIATKETEFSKVTGLPAGDYKKDVEIKDKMLVGFNMADMMKSKFTVAGMAKVNPVLKTWIVENPMGKAPVLSKSVFVVKFKDGSYAKIKFTDATNDKQEKGHVSFNYEFQPK,2018-10-24,6eu8_A
21 | 6mbg_A,SNATGERFHPLLGRRVLDPAPGARCFTGTVSSDRPAYLGEHWVYDAIVVLGVTYLEMALAAASRLLPGNAADTLIVEDVTVWSPLVLRAGAPARLRLRSEDERFEIHSAEPERSDDESAWTRHATGRIARRQLSPDATGQLPRLDGEAVELDAYYERMRIYYGPRLRNIRHLERRGREAIGHVCLQGEEAQESASYELHPALLDACFQCVFALIYAHESHREPFVPLGCARIELRARGVREVRVHLRLHPPRSTDHNQTHTADLRLFDMEGRLVASVDALQLKRA,2018-09-19,6mbg_A
22 | 6bm5_A,GSKKPRTPPVTLEAARDNDFAFDWQAYTPPVAHRLGVQEVEASIETLRNYIDWTPFFMTWSLAGKYPRILEDEVVGVEAQRLFKDANDMLDKLSAEKTLNPRGVVGLFPANRVGDDIEIYRDETRTHVINVSHHLRQQTEKTGFANYCLADFVAPKLSGKADYIGAFAVTGGLEEDALADAFEAQHDDYNKIMVKALADRLAEAFAEYLHERVRKVYWGYAPNENLSNEELIRENYQGIRPAPGYPACPEHTEKATIWELLEVEKHTGMKLTESFAMWPGASVSGWYFSHPDSKYYAVAQIQRDQVEDYARRKGMSVTEVERWLAPNLGYDAD,2018-05-23,6bm5_A
23 | 6f45_D,MAVQGPWVGSSYVAETGQNWASLAANELRVTERPFWISSFIGRSKEEIWEWTGENHSFNKDWLIGELRNRGGTPVVINIRAHQVSYTPGAPLFEFPGDLPNAYITLNIYADIYGRGGTGGVAYLGGNPGGDCIHNWIGNRLRINNQGWICGGGGGGGGFRVGHTEAGGGGGRPLGAGGVSSLNLNGDNATLGAPGRGYQLGNDYAGNGGDVGNPGSASSAEMGGGAAGRAVVGTSPQWINVGNIAGSWL,2018-08-01,6f45_D
24 | 6fub_B,METGNKYIEKRAIDLSRERDPNFFDNPGIPVPECFWFMFKNNVRQDAGTCYSSWKMDMKVGPNWVHIKSDDNCNLSGDFPPGWIVLGKKRPGF,2018-06-13,6fub_B
25 | 6gfx_C,MEAGRPRPVLRSVNSREPSQVIFCNRSPRVVLPVWLNFDGEPQPYPTLPPGTGRRIHSYRGHLWLFRDAGTHDGLLVNQTELFVPSLNVDGQPIFANITLPVYTLKERCLQVVRSLVKPENYRRLDIVRSLYEDLEDHPNVQKDLERLTQERIAHQRMGD,2018-07-11,6gfx_C
26 | 6a02_A,MAATNMTDNVTLNNDKISGQAWQAMRDIGMSRFELFNGRTQKAEQLAAQAEKLLNDDSTDWNLYVKSDKKAPVEGDHYIRINSSITVAEDYLPAGQKNDAINKANQKMKEGDKKGTIEALKLAGVSVIENQELIPLQQTRKDVTTALSLMNEGKYYQAGLLLKSAQDGIVVDSQSVQLEHHHHHH,2019-02-06,6a02_A
27 | 6c0h_A,SGMKSAKEPTIYQDVDIIRRIQELMVLCSLLPPDGKLREALELALALHEEPALARITPLTNLHPFATKAWLETLWLGEGVSSEEKELVAWQNKSENMGPAIRELKNAEQQSGITLVARLTS,2018-09-05,6c0h_A
28 | 6dnm_A,SHMVVLGGDRDFWLQVGIDPIQIMTGTATFYTLRCYLDDRPIFLGRNGRISVFGSERALARYLADEHDHDLSDLSTYDDIRTAATDGSLAVAVTDDNVYVLSGLVDDFADGPDAVDREQLDLAVELLRDIGDYSEDSAVDKALETTRPLGQLVAYVLDPHSVGKPTAPYAAAVREWEKLERFVESRLRRE,2019-01-23,6dnm_A
29 | 6as3_A,HHHHHHMEKKLSDAQVALVAAWRKYPDLRESLEEAASILSLIVFQAETLSDQANELANYIRRQGLEEAEGACRNIDIMRAKWVEVCGEVNQYGIRVYGDAIDRDVD,2018-08-29,6as3_A
30 | 6bn0_A,QPYNPCKPQEVIDTKCMGPKDCLYPNPDSCTTYIQCVPLDEVGNAKPVVKPCPKGLQWNDNVGKKWCDYPNLSTCPVKT,2018-08-22,6bn0_A
31 | 5wfy_A,MILKILNEIASIGSTKQKQAILEKNKDNELLKRVYRLTYSRGLQYYIKKWPKPGIATQSFGMLTLTDMLDFIEFTLATRKLTGNAAIEELTGYITDGKKDDVEVLRRVMMRDLECGASVSIANKVWPGLEHHHHHH,2018-09-26,5wfy_A
32 | 6hem_A,GPDFSKHLKEETIQIITKASHEHEDKSPETVLQSAIKLEYARLVKLAQEDTPPETDYRLHHVVVYFIQNQAPKKIIEKTLLEQFGDRNLSFDERCHNIMKVAQAKLEMIKPEEVNLEEYEEWHQDYRKFRETTMYLIIGLENFQRESYIDSLLFLICAYQNNKELLSKGLYRGHDEELISHYRRECLLKLNEQAAELFESGEDREVNNGLIIMNEFIVPFLPLLLVDEMEEKDILAVEDMRNRWCSYLGQEMEPHLQEKLTDFLPKLLDCSMEIKSFHEPPKLPSYSTHELCERFARIMLSLS,2019-03-27,6hem_A
33 | 5z1n_A,GPLSGLKKLIPEEGRELIGSVKKIIKRVSNEEKANEMEKNILKILIKVFFYIDSKAIQIGDLAKVDRALRDGFNHLDRAFRYYGVKKAADLVVILEKASTALKEAEQETVTLLTPFFRPHNIQLIRNTFAFLGSLDFFTKVWDDLEIEDDLFLLISALNKYTQIELIY,2018-10-17,5z1n_A
34 | 5zlq_A,GPMGKIALQLKATLENITNLRPVGEDFRWYLKMKCGNCGEISDKWQYIRLMDSVALKGGRGSASMVQKCKLCARENSIEILSSTIKPYNAEDNENFKTIVEFECRGLEPVDFQPQAGFAAEGVESGTAFSDINLQEKDWTDYDEKAQESVGIYEVTHQFVKC,2018-10-10,5zlq_A
35 | 5naz_A,SVAHGFLITRHSQTTDAPQCPQGTLQVYEGFSLLYVQGNKRAHGQDLGTAGSCLRRFSTMPFMFCNINNVCNFASRNDYSYWLSTPEPMPMSMQPLKGQSIQPFISRCAVCEAPAVVIAVHSQTIQIPHCPQGWDSLWIGYSFMMHTSAGAEGSGQALASPGSCLEEFRSAPFIECHGRGTCNYYANSYSFWLATVDVSDMFSKPQSETLKAGDLRTRISRCQVCMKRT,2018-09-12,5naz_A
36 | 6e33_A,GKVKKRLPQAKRACAKCQKDNKKCDDARPCQRCIKAKTDCIDLPRKKRPTGVRRGPYKKLS,2018-10-24,6e33_A
37 | 5ok6_A,MDREPQHEELPGLDSQWRQIENGESGRERPLRAGESWFLVEKHWYKQWEAYVQGGDQDSSTFPGCINNATLFQDEINWRLKEGLVEGEDYVLLPAAAWHYLVSWYGLEHGQPPIERKVIELPNIQKVEVYPVELLLVRHNDLGKSHTVQFSHTDSIGLVLRTARERFLVEPQEDTRLWAKNSEGSLDRLYDTHITVLDAALETGQLIIMETRKKDGTWPSAQLEHHHHHH,2018-08-08,5ok6_A
38 | 6idx_A,GPGSMPPPSDIVKVAIEWPGANAQLLEIDQKRPLASIIKEVCDGWSLPNPEYYTLRYADGPQLYITEQTRSDIKNGTILQLAISPSRAARQLMERTQSSNMETRLDAMKELAKLSADVTFATEFINMDGIIVLTRLVESGTKLLSHYSEMLAFTLTAFLELMDHGIVSWDMVSITFIKQIAGYVSQPMVDVSILQRSLAILESMVLNSQSLYQKIAEEITVGQLISHLQVSNQEIQTYAIALINALFLKAPEDKRQDMANAFAQKHLRSIILNHVIRGNRPIKTEMAHQLYVLQVLTFNLLEERMMTKMDPNDQAQRDIIFELRRIAFDAESDPSNAPGSGTEKRKAMYTKDYKMLGFTNHINPAMDFTQTPPGMLALDNMLYLAKVHQDTYIRIVLENSSREDKHECPFGRSAIELTKMLCEILQVGELPNEGRNDYHPMFFTHDRAFEELFGICIQLLNKTWKEMRATAEDFNKVMQVVREQITRALPSKPNSLDQFKSKLRSLSYSEILRLRQSERMSQDD,2019-01-23,6idx_A
39 | 6e5y_A,SNAMTTILPNLPTGQKVGIAFSGGLDTSAALLWMRQKGAVPYAYTANLGQPDEPDYDEIPRRAMQYGAEAARLVDCRAQLVAEGIAALQAGAFHISTAGLTYFNTTPIGRAVTGTMLVAAMKEDGVNIWGDGSTFKGNDIERFYRYGLLTNPDLKIYKPWLDQTFIDELGGRAEMSEYMRQAGFDYKMSAEKAYSTDSNMLGATHEAKDLELLSAGIRIVQPIMGVAFWQDSVQIKAEEVTVRFEEGQPVALNGVEYADPVELLLEANRIGGRHGLGMSDQIENRIIEAKSRGIYEAPGLALLFIAYERLVTGIHNEDTIEQYRENGRKLGRLLYQGRWFDPQAIMLRETAQRWVARAITGEVTLELRRGNDYSLLNTESANLTYAPERLSMEKVENAPFTPADRIGQLTMRNLDIVDTREKLFTYVKTGLLAPSAGSALPQIKDGKK,2018-08-01,6e5y_A
40 | 6crk_G,MASNNTASIAQARKLVEQLKMEANIDRIKVSKAAADLMAYCEAHAKEDPLLTPVPASENPFREKKFFSAIL,2018-10-24,6crk_G
41 |
--------------------------------------------------------------------------------
/splits/pdb_test.json:
--------------------------------------------------------------------------------
1 | {
2 | "Q15758_1": [
3 | "6gct_A",
4 | "6mp6_A",
5 | "6mpb_A",
6 | "6rvx_A",
7 | "7bcq_A",
8 | "7bcs_A",
9 | "7bct_A"
10 | ],
11 | "P41440_3": [
12 | "7xpz_A",
13 | "7xq0_A",
14 | "7xq1_A",
15 | "7xq2_A",
16 | "8goe_A",
17 | "8gof_A",
18 | "8hii_A",
19 | "8hij_A",
20 | "8hik_A"
21 | ],
22 | "Q16602_1": [
23 | "6e3y_R",
24 | "6uun_R",
25 | "6uus_R",
26 | "6uva_R",
27 | "7knt_R",
28 | "7knu_R"
29 | ],
30 | "Q01650_1": [
31 | "6irs_B",
32 | "6irt_B",
33 | "6jmq_A",
34 | "7dsk_B",
35 | "7dsl_B",
36 | "7dsn_B",
37 | "7dsq_B"
38 | ],
39 | "Q9SVY5_1": [
40 | "6j5t_B",
41 | "6j5u_B",
42 | "6j5w_B",
43 | "6j6i_B"
44 | ],
45 | "Q8H9R8_1": [
46 | "7sxk_a",
47 | "7sya_a"
48 | ],
49 | "Q24560_1": [
50 | "6tis_B",
51 | "6tiu_B",
52 | "6tiy_B",
53 | "6tiz_B",
54 | "7quc_B",
55 | "7qud_B",
56 | "7qup_10B"
57 | ],
58 | "Q8U338_1": [
59 | "7tr6_C",
60 | "7tr8_C",
61 | "7tr9_C",
62 | "7tra_B"
63 | ],
64 | "O92284_1": [
65 | "6lgw_E",
66 | "6lgx_A"
67 | ],
68 | "P38786_1": [
69 | "6agb_I",
70 | "6ah3_I",
71 | "6w6v_I",
72 | "7c79_I",
73 | "7c7a_I"
74 | ],
75 | "O43172_1": [
76 | "6ahd_K",
77 | "6qw6_4B",
78 | "6qx9_4B"
79 | ],
80 | "B1H241_1": [
81 | "6nmg_A",
82 | "6nmj_A",
83 | "6tyl_A",
84 | "6ukt_A"
85 | ],
86 | "A0QUE0_1": [
87 | "6yxu_H",
88 | "6yys_H",
89 | "6z11_H"
90 | ],
91 | "Q18BV5_1": [
92 | "6t9m_AAA",
93 | "6tsb_AAA"
94 | ],
95 | "G1SLW8_1": [
96 | "6w2s_8",
97 | "6w2t_8",
98 | "6yam_u"
99 | ],
100 | "P68040_1": [
101 | "7cpu_Sg",
102 | "7cpv_Sg",
103 | "7ls1_I3",
104 | "7ls2_I3"
105 | ],
106 | "Q9BU64_1": [
107 | "7pb8_O",
108 | "7pkn_O",
109 | "7r5s_O",
110 | "7xho_O"
111 | ],
112 | "Q02256_1": [
113 | "6n8m_Y",
114 | "6n8n_Y",
115 | "6n8o_Y",
116 | "6rzz_s",
117 | "6s05_s"
118 | ],
119 | "Q1D6H4_1": [
120 | "7s20_A",
121 | "7s21_A",
122 | "7s2t_A",
123 | "7s4q_A"
124 | ],
125 | "A0A7J7G2E3_1": [
126 | "7v4i_A",
127 | "7v4j_A",
128 | "7v4k_A",
129 | "7v4l_A"
130 | ],
131 | "A0A498U580_1": [
132 | "6xgp_A",
133 | "6xgq_A"
134 | ],
135 | "Q5FVI6_1": [
136 | "6vq6_G",
137 | "6vq7_G",
138 | "6vq8_G"
139 | ],
140 | "P73953_1": [
141 | "7cye_A",
142 | "7cyf_A",
143 | "7egk_A",
144 | "7egl_A"
145 | ],
146 | "P46976_2": [
147 | "7q0b_E",
148 | "7q0s_E",
149 | "7q12_E",
150 | "7q13_E",
151 | "7zbn_E",
152 | "8cvx_E",
153 | "8cvy_E",
154 | "8cvz_E"
155 | ],
156 | "Q02939_1": [
157 | "7k01_2",
158 | "7ml0_2",
159 | "7ml1_2",
160 | "7ml2_2",
161 | "7ml4_2",
162 | "7o4i_2",
163 | "7o4j_2",
164 | "7o4k_2",
165 | "7o4l_2",
166 | "7o72_2",
167 | "7o73_2",
168 | "7o75_2",
169 | "7zs9_2",
170 | "7zsa_2"
171 | ],
172 | "Q8DL32_1": [
173 | "6hum_A",
174 | "6khi_A",
175 | "6khj_A",
176 | "6l7o_A",
177 | "6l7p_A",
178 | "6nbx_A",
179 | "6nby_A",
180 | "6tjv_A"
181 | ],
182 | "A8IVW2_1": [
183 | "7n6g_3M",
184 | "7sqc_1F"
185 | ],
186 | "F2NWD3_1": [
187 | "6wxw_A",
188 | "6wxx_A",
189 | "6wxy_B",
190 | "6xl1_A"
191 | ],
192 | "P03887_1": [
193 | "7dgz_1",
194 | "7qsd_H",
195 | "7qsk_H",
196 | "7qsl_H",
197 | "7qsm_H",
198 | "7qsn_H",
199 | "7qso_H"
200 | ],
201 | "Q9VJY9_3": [
202 | "7w0a_B",
203 | "7w0b_B",
204 | "7w0c_B"
205 | ],
206 | "P53044_1": [
207 | "8dar_H",
208 | "8das_H",
209 | "8dat_H",
210 | "8dau_H",
211 | "8dav_H",
212 | "8daw_H"
213 | ],
214 | "P82251_1": [
215 | "6li9_B",
216 | "6lid_B",
217 | "6yup_D",
218 | "6yv1_A"
219 | ],
220 | "Q9KN86_1": [
221 | "6u2a_A",
222 | "6ue4_A"
223 | ],
224 | "P20429_1": [
225 | "6wvj_A",
226 | "6wvk_A",
227 | "6zfb_U"
228 | ],
229 | "P13866_1": [
230 | "7sl8_A",
231 | "7sla_A",
232 | "7wmv_A"
233 | ],
234 | "Q5Y388_3": [
235 | "7fd2_B",
236 | "7wc2_B",
237 | "7wco_K"
238 | ],
239 | "P78371_1": [
240 | "6qb8_B",
241 | "7nvl_B",
242 | "7nvm_B",
243 | "7nvn_B",
244 | "7nvo_B",
245 | "7trg_E",
246 | "7ttn_E",
247 | "7ttt_E",
248 | "7tub_E",
249 | "7wu7_B"
250 | ],
251 | "P16140_1": [
252 | "7fde_B",
253 | "7tmm_B",
254 | "7tmo_B",
255 | "7tmp_B",
256 | "7tmq_B",
257 | "7tmr_B"
258 | ],
259 | "Q9H1D9_1": [
260 | "7a6h_P",
261 | "7ae1_P",
262 | "7ae3_P",
263 | "7aea_P",
264 | "7ast_Z",
265 | "7d58_P",
266 | "7d59_P",
267 | "7dn3_P",
268 | "7du2_P",
269 | "7fji_P",
270 | "7fjj_P"
271 | ],
272 | "P41586_2": [
273 | "6lpb_R",
274 | "6m1h_A",
275 | "6m1i_A",
276 | "6p9y_R",
277 | "8e3x_R"
278 | ],
279 | "Q83VS8_1": [
280 | "7e8r_A",
281 | "7edb_A"
282 | ],
283 | "P20042_1": [
284 | "6ybv_s",
285 | "6zmw_s",
286 | "6zp4_4",
287 | "7a09_4"
288 | ],
289 | "P03910_1": [
290 | "7dgz_4",
291 | "7qsd_M",
292 | "7qsk_M",
293 | "7qsl_M",
294 | "7qsm_M",
295 | "7qsn_M",
296 | "7qso_M"
297 | ],
298 | "Q9F5Q9_1": [
299 | "6p7r_A",
300 | "6p7t_A",
301 | "6pb9_A"
302 | ],
303 | "Q8DKY0_1": [
304 | "6hum_D",
305 | "6khi_D",
306 | "6khj_D",
307 | "6l7o_D",
308 | "6l7p_D",
309 | "6nbq_D",
310 | "6nbx_D",
311 | "6nby_D"
312 | ],
313 | "P0CB42_1": [
314 | "6ima_A",
315 | "6imc_A",
316 | "6ksf_A",
317 | "6l94_A"
318 | ],
319 | "A0A172JI16_1": [
320 | "7s00_c",
321 | "7s01_c",
322 | "7um0_c"
323 | ],
324 | "Q7K740_1": [
325 | "6mb3_E",
326 | "6mhg_E",
327 | "7v05_X"
328 | ],
329 | "F0NH89_1": [
330 | "7pq2_AAA",
331 | "7pq3_AAA",
332 | "7pq6_AAA",
333 | "7pqa_AAA"
334 | ],
335 | "P47871_3": [
336 | "6lmk_R",
337 | "6lml_R",
338 | "6whc_R",
339 | "6wpw_R",
340 | "7v35_R"
341 | ],
342 | "Q9UU79_1": [
343 | "8esq_r",
344 | "8esr_r",
345 | "8etc_r",
346 | "8etg_r",
347 | "8eth_r",
348 | "8eti_r",
349 | "8etj_r",
350 | "8eup_r",
351 | "8euy_r",
352 | "8ev3_r"
353 | ],
354 | "G0S654_1": [
355 | "6pdw_B",
356 | "6pdy_A",
357 | "6pe0_A"
358 | ],
359 | "Q92847_3": [
360 | "6ko5_A",
361 | "7f9y_R",
362 | "7f9z_R",
363 | "7na7_R",
364 | "7na8_R",
365 | "7w2z_R"
366 | ],
367 | "P33993_1": [
368 | "6xtx_7",
369 | "7pfo_7",
370 | "7plo_7"
371 | ],
372 | "Q4GZ99_1": [
373 | "6hiw_CO",
374 | "6hiy_CO",
375 | "6sga_CO",
376 | "6sgb_CO",
377 | "7pua_CO",
378 | "7pub_CO"
379 | ],
380 | "A0A1D5AKC8_2": [
381 | "7xxa_B",
382 | "7xxg_B",
383 | "7xxj_B"
384 | ],
385 | "P01871_3": [
386 | "6kxs_A",
387 | "7k0c_A",
388 | "7qdo_A",
389 | "7xt6_B"
390 | ],
391 | "A0A3A8DDU9_1": [
392 | "7ecw_A",
393 | "7elm_A",
394 | "7eln_A",
395 | "7eqg_B",
396 | "7we6_A"
397 | ],
398 | "Q8DJD9_1": [
399 | "6hum_H",
400 | "6khi_H",
401 | "6khj_H",
402 | "6l7o_H",
403 | "6l7p_H",
404 | "6nbq_H",
405 | "6nbx_H",
406 | "6nby_H",
407 | "6tjv_H"
408 | ],
409 | "P53985_1": [
410 | "6lyy_A",
411 | "6lz0_A",
412 | "7cko_A",
413 | "7ckr_A",
414 | "7da5_A",
415 | "7yr5_A"
416 | ],
417 | "P0A1J1_2": [
418 | "6k3i_AA",
419 | "6k9q_A",
420 | "7cbm_DA",
421 | "7cgb_DL",
422 | "7cgo_DA",
423 | "7e80_DA",
424 | "7e82_DA"
425 | ],
426 | "Q09E27_1": [
427 | "6ptq_A",
428 | "6ptx_A",
429 | "6pu2_A",
430 | "7jr5_A",
431 | "7jri_A"
432 | ],
433 | "O43614_3": [
434 | "7l1u_R",
435 | "7l1v_R",
436 | "7sqo_R",
437 | "7sr8_R",
438 | "7xrr_A"
439 | ],
440 | "O18640_1": [
441 | "6xu6_Ag",
442 | "6xu8_Ag"
443 | ],
444 | "Q4Q703_1": [
445 | "7aih_BM",
446 | "7am2_BM",
447 | "7ane_BM"
448 | ],
449 | "P61221_1": [
450 | "6zme_CI",
451 | "6zvj_1",
452 | "7a09_J"
453 | ],
454 | "P32893_1": [
455 | "7e2d_K",
456 | "7e8t_K",
457 | "7u05_C"
458 | ],
459 | "P00960_1": [
460 | "7eiv_A",
461 | "7qcf_A"
462 | ],
463 | "P15056_1": [
464 | "6uan_B",
465 | "7mfd_A",
466 | "7mff_A",
467 | "7zr0_K",
468 | "7zr5_K"
469 | ],
470 | "P46957_1": [
471 | "6p1h_B",
472 | "6v93_F",
473 | "7kc0_B",
474 | "7s0t_F"
475 | ],
476 | "Q9LVM1_1": [
477 | "7n58_A",
478 | "7n59_A",
479 | "7n5a_A",
480 | "7n5b_A"
481 | ],
482 | "Q12874_1": [
483 | "6ahd_w",
484 | "6qx9_A3",
485 | "7evo_C",
486 | "7onb_N",
487 | "7q3l_9",
488 | "7q4o_9",
489 | "7q4p_9"
490 | ],
491 | "G1SVB0_1": [
492 | "7o7y_Ag",
493 | "7o7z_Ag",
494 | "7o81_Ag",
495 | "7oyd_HH",
496 | "7syn_I",
497 | "7syp_I",
498 | "7syq_I",
499 | "7syr_I",
500 | "7sys_I",
501 | "7syv_I",
502 | "7syw_I",
503 | "7syx_I",
504 | "7zjw_SS",
505 | "7zjx_SS"
506 | ],
507 | "O05000_1": [
508 | "7a23_I",
509 | "7a24_I",
510 | "7aqq_N",
511 | "7ar7_N",
512 | "7ar8_N",
513 | "7arb_N"
514 | ],
515 | "P32325_2": [
516 | "7p5z_G",
517 | "7pt6_9",
518 | "7pt7_9",
519 | "7v3v_I"
520 | ],
521 | "Q06631_1": [
522 | "6lqp_RY",
523 | "6lqu_RY",
524 | "6zqb_JK",
525 | "6zqc_JK",
526 | "7suk_5"
527 | ],
528 | "B6VNP3_1": [
529 | "6j0b_A",
530 | "6j0c_A",
531 | "6j0n_h"
532 | ],
533 | "O34693_1": [
534 | "7aqc_R",
535 | "7aqd_R",
536 | "7as8_0",
537 | "7as9_0",
538 | "7asa_0",
539 | "7ope_0"
540 | ],
541 | "Q57YI7_1": [
542 | "6hix_AR",
543 | "6yxx_AR",
544 | "6yxy_AR"
545 | ],
546 | "P41895_2": [
547 | "7mei_Q",
548 | "7mk9_Q",
549 | "7ml0_Q",
550 | "7ml1_Q",
551 | "7ml2_Q",
552 | "7ml4_Q",
553 | "7o4i_Q",
554 | "7o4j_Q",
555 | "7o72_Q",
556 | "7o73_Q",
557 | "7o75_Q",
558 | "7zs9_Q",
559 | "7zsa_Q"
560 | ],
561 | "P37468_2": [
562 | "6ifs_A",
563 | "6ift_A",
564 | "6ifw_A",
565 | "7v2l_W",
566 | "7v2m_U",
567 | "7v2n_U",
568 | "7v2o_U",
569 | "7v2p_U",
570 | "7v2q_U"
571 | ],
572 | "A0A1D5AKC8_3": [
573 | "7xxa_A",
574 | "7xxg_A",
575 | "7xxj_A"
576 | ],
577 | "A1L314_2": [
578 | "6sb3_A",
579 | "8a1d_A",
580 | "8a1s_A"
581 | ],
582 | "A0A8C7BTF2_1": [
583 | "7f5r_A",
584 | "7w8s_A",
585 | "7wa1_A",
586 | "7wa3_A"
587 | ],
588 | "P47110_1": [
589 | "6p1h_C",
590 | "6v93_G",
591 | "7kc0_C",
592 | "7s0t_G"
593 | ],
594 | "Q8DMR6_1": [
595 | "6hum_B",
596 | "6khi_B",
597 | "6khj_B",
598 | "6l7o_B",
599 | "6l7p_B",
600 | "6nbq_B",
601 | "6nbx_B",
602 | "6nby_B",
603 | "6tjv_B"
604 | ],
605 | "A0QT50_1": [
606 | "6hb0_A",
607 | "6hbd_A",
608 | "6hbm_A",
609 | "6hyh_A"
610 | ],
611 | "Q9JJ59_1": [
612 | "7v5c_A",
613 | "7v5d_A",
614 | "7vfi_A"
615 | ],
616 | "P41091_1": [
617 | "6o81_S",
618 | "6o85_S",
619 | "6ybv_t",
620 | "6zmw_t",
621 | "6zp4_Y",
622 | "7a09_Y",
623 | "7f66_S",
624 | "7f67_S",
625 | "7qp7_t"
626 | ],
627 | "P51787_3": [
628 | "6uzz_A",
629 | "6v00_A",
630 | "6v01_A",
631 | "7xni_A",
632 | "7xnk_A",
633 | "7xnl_A",
634 | "7xnn_B"
635 | ],
636 | "Q9LT15_1": [
637 | "6h7d_A",
638 | "7aaq_A",
639 | "7aar_A"
640 | ],
641 | "P32241_2": [
642 | "6vn7_R",
643 | "8e3y_R",
644 | "8e3z_R"
645 | ],
646 | "P69801_1": [
647 | "6k1h_B",
648 | "7dyr_B"
649 | ],
650 | "Q38854_1": [
651 | "7bzx_A",
652 | "7bzy_11"
653 | ],
654 | "P17694_1": [
655 | "7dgz_B",
656 | "7qsd_D",
657 | "7qsk_D",
658 | "7qsl_D",
659 | "7qsm_D",
660 | "7qsn_D",
661 | "7qso_D"
662 | ],
663 | "Q586R9_1": [
664 | "6hix_AK",
665 | "6yxx_AK",
666 | "6yxy_AK"
667 | ],
668 | "P18708_3": [
669 | "6ip2_A",
670 | "6mdo_A",
671 | "6mdp_A"
672 | ],
673 | "Q8N0V3_1": [
674 | "7pnx_a",
675 | "7pny_a",
676 | "7pnz_a",
677 | "7po0_a",
678 | "8csr_6",
679 | "8css_6",
680 | "8cst_6",
681 | "8csu_6"
682 | ],
683 | "P29992_1": [
684 | "6oij_A",
685 | "7rkf_A",
686 | "7try_A"
687 | ],
688 | "Q03431_2": [
689 | "6fj3_A",
690 | "6nbf_R",
691 | "6nbh_R",
692 | "6nbi_R",
693 | "7vvj_R",
694 | "7vvk_R",
695 | "7vvl_R",
696 | "7vvm_R",
697 | "7vvn_R",
698 | "8ha0_R",
699 | "8haf_R",
700 | "8hao_I"
701 | ]
702 | }
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from alphaflow.utils.parsing import parse_train_args
2 | args = parse_train_args()
3 |
4 | from alphaflow.utils.logging import get_logger
5 | logger = get_logger(__name__)
6 | import torch, tqdm, os, wandb
7 | import pandas as pd
8 |
9 | from functools import partial
10 | import pytorch_lightning as pl
11 | from pytorch_lightning.callbacks import ModelCheckpoint
12 | from openfold.utils.exponential_moving_average import ExponentialMovingAverage
13 | from alphaflow.model.wrapper import ESMFoldWrapper, AlphaFoldWrapper
14 | from openfold.utils.import_weights import import_jax_weights_
15 |
16 | torch.set_float32_matmul_precision("high")
17 | from alphaflow.config import model_config
18 | from alphaflow.data.data_modules import OpenFoldSingleDataset, OpenFoldBatchCollator, OpenFoldDataset
19 | from alphaflow.data.inference import CSVDataset, AlphaFoldCSVDataset
20 |
21 | config = model_config(
22 | 'initial_training',
23 | train=True,
24 | low_prec=True
25 | )
26 |
27 | loss_cfg = config.loss
28 | data_cfg = config.data
29 | data_cfg.common.use_templates = False
30 | data_cfg.common.max_recycling_iters = 0
31 |
32 | def load_clusters(path):
33 | cluster_size = []
34 |     with open(path) as f:
35 | for line in f:
36 | names = line.split()
37 | for name in names:
38 | cluster_size.append({'name': name, 'cluster_size': len(names)})
39 | return pd.DataFrame(cluster_size).set_index('name')
40 |
41 | def main():
42 |
43 | if args.wandb:
44 | wandb.init(
45 | entity=os.environ["WANDB_ENTITY"],
46 | settings=wandb.Settings(start_method="fork"),
47 | project="alphaflow",
48 | name=args.run_name,
49 | config=args,
50 | )
51 |
52 | logger.info("Loading the chains dataframe")
53 | pdb_chains = pd.read_csv(args.pdb_chains, index_col='name')
54 |
55 | if args.filter_chains:
56 | clusters = load_clusters(args.pdb_clusters)
57 | pdb_chains = pdb_chains.join(clusters)
58 | pdb_chains = pdb_chains[pdb_chains.release_date < args.train_cutoff]
59 |
60 | trainset = OpenFoldSingleDataset(
61 | data_dir = args.train_data_dir,
62 | alignment_dir = args.train_msa_dir,
63 | pdb_chains = pdb_chains,
64 | config = data_cfg,
65 | mode = 'train',
66 | subsample_pos = args.sample_train_confs,
67 | first_as_template = args.first_as_template,
68 | )
69 | if args.normal_validate:
70 | val_pdb_chains = pd.read_csv(args.val_csv, index_col='name')
71 | valset = OpenFoldSingleDataset(
72 | data_dir = args.train_data_dir,
73 | alignment_dir = args.train_msa_dir,
74 | pdb_chains = val_pdb_chains,
75 | config = data_cfg,
76 | mode = 'train',
77 | subsample_pos = args.sample_val_confs,
78 | num_confs = args.num_val_confs,
79 | first_as_template = args.first_as_template,
80 | )
81 | else:
82 | valset = AlphaFoldCSVDataset(
83 | data_cfg,
84 | args.val_csv,
85 | mmcif_dir=args.mmcif_dir,
86 | data_dir=args.train_data_dir,
87 | msa_dir=args.val_msa_dir,
88 | )
89 | if args.filter_chains:
90 | trainset = OpenFoldDataset([trainset], [1.0], args.train_epoch_len)
91 |
92 | val_loader = torch.utils.data.DataLoader(
93 | valset,
94 | batch_size=args.batch_size,
95 | collate_fn=OpenFoldBatchCollator(),
96 | num_workers=args.num_workers,
97 | )
98 | train_loader = torch.utils.data.DataLoader(
99 | trainset,
100 | batch_size=args.batch_size,
101 | collate_fn=OpenFoldBatchCollator(),
102 | num_workers=args.num_workers,
103 | shuffle=not args.filter_chains,
104 | )
105 |
106 |
107 | trainer = pl.Trainer(
108 | accelerator="gpu",
109 | max_epochs=args.epochs,
110 | limit_train_batches=args.limit_batches or 1.0,
111 | limit_val_batches=args.limit_batches or 1.0,
112 | num_sanity_val_steps=0,
113 | enable_progress_bar=not args.wandb,
114 | gradient_clip_val=args.grad_clip,
115 | callbacks=[ModelCheckpoint(
116 | dirpath=os.environ["MODEL_DIR"],
117 | save_top_k=-1,
118 | every_n_epochs=args.ckpt_freq,
119 | )],
120 | accumulate_grad_batches=args.accumulate_grad,
121 | check_val_every_n_epoch=args.val_freq,
122 | logger=False,
123 | )
124 | if args.mode == 'esmfold':
125 | model = ESMFoldWrapper(config, args)
126 | if args.ckpt is None:
127 | logger.info("Loading the model")
128 | path = "esmfold_3B_v1.pt"
129 | model_data = torch.load(path)
130 | model_state = model_data["model"]
131 | model.esmfold.load_state_dict(model_state, strict=False)
132 | logger.info("Model has been loaded")
133 |
134 | if not args.no_ema:
135 | model.ema = ExponentialMovingAverage(
136 | model=model.esmfold, decay=config.ema.decay
137 | ) # need to initialize EMA this way at the beginning
138 | elif args.mode == 'alphafold':
139 | model = AlphaFoldWrapper(config, args)
140 | if args.ckpt is None:
141 | logger.info("Loading the model")
142 |             import_jax_weights_(model.model, 'params_model_1.npz', version='model_3')
143 | if not args.no_ema:
144 | model.ema = ExponentialMovingAverage(
145 | model=model.model, decay=config.ema.decay
146 | ) # need to initialize EMA this way at the beginning
147 |
148 | if args.restore_weights_only:
149 | model.load_state_dict(torch.load(args.ckpt, map_location='cpu')['state_dict'], strict=False)
150 | args.ckpt = None
151 | if not args.no_ema:
152 | model.ema = ExponentialMovingAverage(
153 | model=model.model, decay=config.ema.decay
154 | ) # need to initialize EMA this way at the beginning
155 |
156 |
157 | if args.validate:
158 | trainer.validate(model, val_loader, ckpt_path=args.ckpt)
159 | else:
160 | trainer.fit(model, train_loader, val_loader, ckpt_path=args.ckpt)
161 |
162 | if __name__ == "__main__":
163 | main()
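164 |
165 | # Example invocation (a sketch only: the concrete flag names are defined in
166 | # alphaflow/utils/parsing.py and are assumed here to mirror the args.* attributes
167 | # used above). MODEL_DIR must point at a directory for checkpoints, WANDB_ENTITY
168 | # must be set when --wandb is used, and the pretrained weights (esmfold_3B_v1.pt
169 | # or params_model_1.npz) are expected in the working directory:
170 | #
171 | #   MODEL_DIR=./ckpts python train.py --mode esmfold \
172 | #       --train_data_dir ./pdb_npz --train_msa_dir ./msas \
173 | #       --pdb_chains ./pdb_mmcif.csv --val_csv splits/cameo2022.csv --run_name test_run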
--------------------------------------------------------------------------------