├── README.md ├── data.py ├── engine_train.py ├── environment.yml ├── example ├── 1gl5.fasta ├── 1gl5_af │ ├── 1gl5_double.csv │ └── 1gl5_single.csv ├── 1gl5_esm │ ├── 1gl5_double.csv │ └── 1gl5_single.csv └── 1gl5_msa │ └── 1gl5.a3m ├── figs └── model.jpeg ├── main_train.py ├── metrics.py ├── misc.py ├── modeling ├── backbone.py ├── criterion.py ├── module.py ├── mutate_everything.py └── utils.py ├── openfold ├── .github │ └── workflows │ │ ├── docker-image.yml │ │ └── undefined_names.yml ├── .gitignore ├── CITATION.cff ├── Dockerfile ├── LICENSE ├── README.md ├── deepspeed_config.json ├── environment.yml ├── imgs │ └── of_banner.png ├── lib │ └── openmm.patch ├── notebooks │ ├── OpenFold.ipynb │ └── environment.yml ├── openfold │ ├── __init__.py │ ├── config.py │ ├── data │ │ ├── __init__.py │ │ ├── data_modules.py │ │ ├── data_pipeline.py │ │ ├── data_transforms.py │ │ ├── errors.py │ │ ├── feature_pipeline.py │ │ ├── input_pipeline.py │ │ ├── mmcif_parsing.py │ │ ├── parsers.py │ │ ├── templates.py │ │ └── tools │ │ │ ├── __init__.py │ │ │ ├── hhblits.py │ │ │ ├── hhsearch.py │ │ │ ├── jackhmmer.py │ │ │ ├── kalign.py │ │ │ └── utils.py │ ├── model │ │ ├── __init__.py │ │ ├── dropout.py │ │ ├── embedders.py │ │ ├── evoformer.py │ │ ├── heads.py │ │ ├── model.py │ │ ├── msa.py │ │ ├── outer_product_mean.py │ │ ├── pair_transition.py │ │ ├── primitives.py │ │ ├── structure_module.py │ │ ├── template.py │ │ ├── torchscript.py │ │ ├── triangular_attention.py │ │ └── triangular_multiplicative_update.py │ ├── np │ │ ├── __init__.py │ │ ├── protein.py │ │ ├── relax │ │ │ ├── __init__.py │ │ │ ├── amber_minimize.py │ │ │ ├── cleanup.py │ │ │ ├── relax.py │ │ │ └── utils.py │ │ └── residue_constants.py │ ├── resources │ │ └── __init__.py │ └── utils │ │ ├── __init__.py │ │ ├── argparse.py │ │ ├── callbacks.py │ │ ├── checkpointing.py │ │ ├── chunk_utils.py │ │ ├── exponential_moving_average.py │ │ ├── feats.py │ │ ├── import_weights.py │ │ ├── kernel │ │ ├── __init__.py 
│ │ ├── attention_core.py │ │ └── csrc │ │ │ ├── compat.h │ │ │ ├── softmax_cuda.cpp │ │ │ ├── softmax_cuda_kernel.cu │ │ │ └── softmax_cuda_stub.cpp │ │ ├── logger.py │ │ ├── loss.py │ │ ├── lr_schedulers.py │ │ ├── precision_utils.py │ │ ├── rigid_utils.py │ │ ├── script_utils.py │ │ ├── seed.py │ │ ├── superimposition.py │ │ ├── suppress_output.py │ │ ├── tensor_utils.py │ │ ├── trace_utils.py │ │ └── validation_metrics.py ├── run_pretrained_openfold.py ├── scripts │ ├── activate_conda_env.sh │ ├── alignment_db_scripts │ │ ├── create_alignment_db.py │ │ └── unify_alignment_db_indices.py │ ├── build_deepspeed_config.py │ ├── colabfold_search.sh │ ├── convert_of_weights_to_jax.py │ ├── data_dir_to_fasta.py │ ├── deactivate_conda_env.sh │ ├── download_alphafold_dbs.sh │ ├── download_alphafold_params.sh │ ├── download_bfd.sh │ ├── download_cameo.py │ ├── download_colabfold_envdb.sh │ ├── download_mgnify.sh │ ├── download_mmseqs_dbs.sh │ ├── download_openfold_params.sh │ ├── download_openfold_params_gdrive.sh │ ├── download_openfold_params_huggingface.sh │ ├── download_pdb70.sh │ ├── download_pdb_mmcif.sh │ ├── download_roda_pdbs.sh │ ├── download_small_bfd.sh │ ├── download_uniclust30.sh │ ├── download_uniref30.sh │ ├── download_uniref90.sh │ ├── flatten_roda.sh │ ├── generate_alphafold_feature_dict.py │ ├── generate_chain_data_cache.py │ ├── generate_mmcif_cache.py │ ├── install_hh_suite.sh │ ├── install_third_party_dependencies.sh │ ├── precompute_alignments.py │ ├── precompute_alignments_mmseqs.py │ ├── prep_mmseqs_dbs.sh │ ├── prep_proteinnet_msas.py │ ├── run_unit_tests.sh │ ├── slurm_scripts │ │ └── run_uniclust30_search.sh │ ├── unpack_proteinnet.py │ ├── utils.py │ ├── vars.sh │ └── zero_to_fp32.py ├── setup.py ├── tests │ ├── __init__.py │ ├── compare_utils.py │ ├── config.py │ ├── data_utils.py │ ├── test_data_pipeline.py │ ├── test_data_transforms.py │ ├── test_embedders.py │ ├── test_evoformer.py │ ├── test_feats.py │ ├── test_import_weights.py │ ├── 
test_kernels.py │ ├── test_loss.py │ ├── test_model.py │ ├── test_msa.py │ ├── test_outer_product_mean.py │ ├── test_pair_transition.py │ ├── test_primitives.py │ ├── test_structure_module.py │ ├── test_template.py │ ├── test_triangular_attention.py │ ├── test_triangular_multiplicative_update.py │ └── test_utils.py ├── thread_sequence.py └── train_openfold.py └── test.py /environment.yml: -------------------------------------------------------------------------------- 1 | name: mutate_everything 2 | channels: 3 | - conda-forge 4 | - bioconda 5 | - pytorch 6 | dependencies: 7 | - conda-forge::python=3.7 8 | - conda-forge::setuptools=59.5.0 9 | - conda-forge::pip 10 | - conda-forge::openmm=7.5.1 11 | - conda-forge::pdbfixer 12 | - conda-forge::cudatoolkit==11.3.* 13 | - conda-forge::cudatoolkit-dev==11.3.* 14 | - bioconda::hmmer==3.3.2 15 | - bioconda::hhsuite==3.3.0 16 | - bioconda::kalign2==2.04 17 | - pytorch::pytorch=1.12.* 18 | - pip: 19 | - biopython==1.79 20 | - deepspeed==0.5.10 21 | - dm-tree==0.1.6 22 | - ml-collections==0.1.0 23 | - numpy==1.21.2 24 | - PyYAML==5.4.1 25 | - requests==2.26.0 26 | - scipy==1.7.1 27 | - tqdm==4.62.2 28 | - typing-extensions==3.10.0.2 29 | - wandb==0.12.21 30 | - modelcif==0.7 31 | - git+https://github.com/NVIDIA/dllogger.git 32 | - pandas 33 | - ipdb 34 | - easydict 35 | - matplotlib 36 | - fair-esm 37 | - einops 38 | - pandas 39 | - biopython 40 | - scikit-learn 41 | - gdown 42 | - awscli 43 | -------------------------------------------------------------------------------- /example/1gl5.fasta: -------------------------------------------------------------------------------- 1 | >1gl5 2 | SEIVVAMYDFQATEAHDLRLERGQEYIILEKNDLHWWRARDKYGSEGYIPSNYVTGKK 3 | -------------------------------------------------------------------------------- /figs/model.jpeg: -------------------------------------------------------------------------------- 
def loss_single_double(pred: dict, ddg_dense1, ddg_dense2, batch, args, train) -> dict:
    """Compute Huber losses for single- and double-mutation ddG predictions.

    Target entries equal to 999 are treated as "unknown" and excluded from
    the losses.

    Args:
        pred: Model outputs. ``pred['mut1_ddg']`` is scored against
            ``ddg_dense1``; ``pred['mut2_ddg']`` (optional) is scored against
            ``ddg_dense2``.
        ddg_dense1: Dense single-mutation targets, same shape as
            ``pred['mut1_ddg']``.
        ddg_dense2: Dense double-mutation targets, batch-first. NOTE: when
            subsampling is active (training with a positive ratio) surplus
            destabilizing entries are overwritten with 999 *in place*.
        batch: Unused by this function; kept for interface compatibility.
        args: Namespace providing ``double_subsample_destabilizing_ratio``,
            ``lambda_single``, ``multi_dec``, ``lambda_pos`` and
            ``lambda_double``.
        train: Whether this is a training step; subsampling only happens
            when it is truthy.

    Returns:
        Dict with scalar tensors ``'loss1'`` and ``'loss2'``. A loss with no
        usable targets (or, for ``'loss2'``, no ``'mut2_ddg'`` prediction) is
        a zero that is still derived from a prediction tensor.
    """
    unknown = 999  # sentinel marking an unlabeled target entry

    # Sample fewer destabilizing (ddG > 0) double mutations so that they do
    # not swamp the stabilizing ones.
    stbl_ratio = args.double_subsample_destabilizing_ratio
    if stbl_ratio > 0 and train:
        for b in range(len(ddg_dense2)):
            ddg_dense_b = ddg_dense2[b]
            destbl_inds = ((ddg_dense_b != unknown) & (ddg_dense_b > 0)).nonzero()
            n_destbl = len(destbl_inds)
            # +1 keeps the destabilizing budget positive when there are no
            # stabilizing entries at all.
            n_stbl = (ddg_dense_b < 0).sum().item() + 1
            if n_destbl < stbl_ratio * n_stbl:
                continue
            mask_inds = np.random.choice(
                n_destbl, n_destbl - int(stbl_ratio * n_stbl), replace=False
            )
            # One index column per dimension -> advanced-indexing tuple.
            ddg_dense2[b][destbl_inds[mask_inds].split(1, dim=1)] = unknown

    # Mask unknown values.
    unknown_mask1 = ddg_dense1 == unknown
    unknown_mask2 = ddg_dense2 == unknown

    losses = {}

    if unknown_mask1.all():
        # Multiplying by zero (rather than returning 0.) keeps the result a
        # tensor computed from the prediction.
        losses['loss1'] = 0. * pred['mut1_ddg'].sum()
    else:
        known1 = ~unknown_mask1
        losses['loss1'] = (
            F.huber_loss(pred['mut1_ddg'][known1], ddg_dense1[known1])
            * args.lambda_single
        )

    # A model without a double-mutation output cannot be scored on doubles,
    # even if the batch carries double labels. Check this first so that the
    # labelled case does not fail with a KeyError on pred['mut2_ddg'].
    if 'mut2_ddg' not in pred:
        losses['loss2'] = losses['loss1'] * 0.
        return losses

    if args.multi_dec == 'epistasis':
        # Unknown if any of (ddg_ij, ddg_i, ddg_j) are unknown.
        unknown_mask2 |= (
            unknown_mask1[:, None, None, :, :] | unknown_mask1[:, :, :, None, None]
        )

    known2 = ~unknown_mask2
    if known2.any():
        pos_mask = (ddg_dense2 <= 0)[known2]
        # Each known entry is weighted by 1 / (number of known entries in its
        # batch element); the concatenation order matches boolean-mask
        # selection because the batch is the leading dimension.
        weight2 = torch.cat(
            [n_b.new_ones(int(n_b)) / n_b for n_b in known2.flatten(1, -1).sum(1)]
        )
        # Entries with ddG <= 0 are re-weighted by lambda_pos.
        weight2 *= 1 + pos_mask * (args.lambda_pos - 1)
        losses2 = F.huber_loss(
            pred['mut2_ddg'][known2], ddg_dense2[known2], reduction='none'
        )
        loss2 = (losses2 * weight2).sum() / weight2.sum()
        losses['loss2'] = loss2 * args.lambda_double
    else:
        losses['loss2'] = 0. * pred['mut2_ddg'].sum()
    return losses
def mem_inputs_to_device(batch, device, args):
    """Move the backbone inputs contained in ``batch`` onto ``device``.

    Args:
        batch: Mapping holding ``'af_inputs'`` (an iterable of dicts of
            tensors) for the ``'af'`` backbone, or ``'tokens'`` (a tensor)
            for any backbone whose name contains ``'esm'``.
        device: Target passed to ``Tensor.to``.
        args: Namespace providing ``backbone``.

    Returns:
        A list of dicts of tensors for ``'af'``, or a single tensor for ESM
        backbones.

    Raises:
        ValueError: If ``args.backbone`` is neither ``'af'`` nor an ESM name
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if args.backbone == 'af':
        return [
            {k: v.to(device, non_blocking=True) for k, v in feats.items()}
            for feats in batch['af_inputs']
        ]
    if 'esm' in args.backbone:
        return batch['tokens'].to(device, non_blocking=True)
    raise ValueError(f"unsupported backbone: {args.backbone!r}")


class FFNLayer(nn.Module):
    """Residual feed-forward block with a LayerNorm applied before or after.

    Args:
        d_model: Size of the input and output feature dimension.
        dim_feedforward: Size of the hidden layer.
        dropout: Dropout probability, used after the activation and on the
            residual branch.
        activation: Name understood by ``_get_activation_fn``.
        normalize_before: If true use the pre-norm variant (``forward_pre``),
            otherwise the post-norm variant (``forward_post``).
    """

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        # NOTE(review): F.glu halves its input's last dimension, so "glu"
        # looks incompatible with linear2's in_features -- confirm before use.
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt):
        """Post-norm variant: feed-forward, residual add, then LayerNorm."""
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt):
        """Pre-norm variant: LayerNorm, feed-forward, then residual add."""
        tgt2 = self.norm(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt):
        """Apply the block using the variant chosen by ``normalize_before``."""
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)


def _get_activation_fn(activation):
    """Return an activation function given a string.

    Raises:
        RuntimeError: If ``activation`` is not one of "relu", "gelu", "glu".
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    # The message lists every accepted name (it previously omitted "glu").
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
--count --select=E9,F63,F7,F82 --show-source --statistics 12 | -------------------------------------------------------------------------------- /openfold/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | __pycache__/ 3 | *.egg-info 4 | build 5 | dist 6 | 7 | # files from script downloads 8 | data 9 | openfold/resources/ 10 | tests/test_data/ 11 | -------------------------------------------------------------------------------- /openfold/CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | preferred-citation: 3 | authors: 4 | - family-names: "Ahdritz" 5 | given-names: "Gustaf" 6 | orcid: https://orcid.org/0000-0001-8283-5324 7 | - family-names: "Bouatta" 8 | given-names: "Nazim" 9 | orcid: https://orcid.org/0000-0002-6524-874X 10 | - family-names: "Kadyan" 11 | given-names: "Sachin" 12 | orcid: https://orcid.org/0000-0002-6079-7627 13 | - family-names: "Xia" 14 | given-names: "Qinghui" 15 | - family-names: "Gerecke" 16 | given-names: "William" 17 | orcid: https://orcid.org/0000-0002-9777-6192 18 | - family-names: "O'Donnell" 19 | given-names: "Timothy J" 20 | orcid: https://orcid.org/0000-0002-9949-069X 21 | - family-names: "Berenberg" 22 | given-names: "Daniel" 23 | orcid: https://orcid.org/0000-0003-4631-0947 24 | - family-names: "Fisk" 25 | given-names: "Ian" 26 | - family-names: "Zanichelli" 27 | given-names: "Niccolò" 28 | orcid: https://orcid.org/0000-0002-3093-3587 29 | - family-names: "Zhang" 30 | given-names: "Bo" 31 | orcid: https://orcid.org/0000-0002-9714-2827 32 | - family-names: "Nowaczynski" 33 | given-names: "Arkadiusz" 34 | orcid: https://orcid.org/0000-0002-3351-9584 35 | - family-names: "Wang" 36 | given-names: "Bei" 37 | orcid: https://orcid.org/0000-0003-4942-9652 38 | - family-names: "Stepniewska-Dziubinska" 39 | given-names: "Marta M" 40 | orcid: https://orcid.org/0000-0003-4942-9652 41 | - family-names: "Zhang" 42 | 
given-names: "Shang" 43 | orcid: https://orcid.org/0000-0003-0759-2080 44 | - family-names: "Ojewole" 45 | given-names: "Adegoke" 46 | orcid: https://orcid.org/0000-0003-2661-4388 47 | - family-names: "Guney" 48 | given-names: "Murat Efe" 49 | - family-names: "Biderman" 50 | given-names: "Stella" 51 | orcid: https://orcid.org/0000-0001-8228-1042 52 | - family-names: "Watkins" 53 | given-names: "Andrew M" 54 | orcid: https://orcid.org/0000-0003-1617-1720 55 | - family-names: "Ra" 56 | given-names: "Stephen" 57 | orcid: https://orcid.org/0000-0002-2820-0050 58 | - family-names: "Lorenzo" 59 | given-names: "Pablo Ribalta" 60 | orcid: https://orcid.org/0000-0002-3657-8053 61 | - family-names: "Nivon" 62 | given-names: "Lucas" 63 | - family-names: "Weitzner" 64 | given-names: "Brian" 65 | orcid: https://orcid.org/0000-0002-1909-0961 66 | - family-names: "Ban" 67 | given-names: "Yih-En" 68 | orcid: https://orcid.org/0000-0003-3698-3574 69 | - family-names: "Ban" 70 | given-names: "Yih-En Andrew" 71 | orcid: https://orcid.org/0000-0003-3698-3574 72 | - family-names: "Sorger" 73 | given-names: "Peter K" 74 | orcid: https://orcid.org/0000-0002-3364-1838 75 | - family-names: "Mostaque" 76 | given-names: "Emad" 77 | - family-names: "Zhang" 78 | given-names: "Zhao" 79 | orcid: https://orcid.org/0000-0001-5921-0035 80 | - family-names: "Bonneau" 81 | given-names: "Richard" 82 | orcid: https://orcid.org/0000-0003-4354-7906 83 | - family-names: "AlQuraishi" 84 | given-names: "Mohammed" 85 | orcid: https://orcid.org/0000-0001-6817-1322 86 | title: "OpenFold: Retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization" 87 | type: article 88 | doi: 10.1101/2022.11.20.517210 89 | doi: 10.1101/2022.11.20.517210 90 | date-released: 2021-11-12 91 | url: "https://doi.org/10.1101/2022.11.20.517210" 92 | -------------------------------------------------------------------------------- /openfold/Dockerfile: 
-------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu18.04 2 | 3 | # metainformation 4 | LABEL org.opencontainers.image.version = "1.0.0" 5 | LABEL org.opencontainers.image.authors = "Gustaf Ahdritz" 6 | LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfold" 7 | LABEL org.opencontainers.image.licenses = "Apache License 2.0" 8 | LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04" 9 | 10 | RUN apt-key del 7fa2af80 11 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub 12 | RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub 13 | 14 | RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git 15 | RUN wget -P /tmp \ 16 | "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \ 17 | && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \ 18 | && rm /tmp/Miniconda3-latest-Linux-x86_64.sh 19 | ENV PATH /opt/conda/bin:$PATH 20 | 21 | COPY environment.yml /opt/openfold/environment.yml 22 | 23 | # installing into the base environment since the docker container wont do anything other than run openfold 24 | RUN conda env update -n base --file /opt/openfold/environment.yml && conda clean --all 25 | 26 | COPY openfold /opt/openfold/openfold 27 | COPY scripts /opt/openfold/scripts 28 | COPY run_pretrained_openfold.py /opt/openfold/run_pretrained_openfold.py 29 | COPY train_openfold.py /opt/openfold/train_openfold.py 30 | COPY setup.py /opt/openfold/setup.py 31 | COPY lib/openmm.patch /opt/openfold/lib/openmm.patch 32 | RUN wget -q -P /opt/openfold/openfold/resources \ 33 | 
https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt 34 | RUN patch -p0 -d /opt/conda/lib/python3.7/site-packages/ < /opt/openfold/lib/openmm.patch 35 | WORKDIR /opt/openfold 36 | RUN python3 setup.py install 37 | -------------------------------------------------------------------------------- /openfold/deepspeed_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": false, 4 | "min_loss_scale": 1 5 | }, 6 | "amp": { 7 | "enabled": false, 8 | "opt_level": "O2" 9 | }, 10 | "bfloat16": { 11 | "enabled": true 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "cpu_offload": true, 16 | "contiguous_gradients": true 17 | }, 18 | "activation_checkpointing": { 19 | "partition_activations": true, 20 | "cpu_checkpointing": false, 21 | "profile": false 22 | }, 23 | "gradient_clipping": 0.1 24 | } 25 | -------------------------------------------------------------------------------- /openfold/environment.yml: -------------------------------------------------------------------------------- 1 | name: openfold_venv 2 | channels: 3 | - conda-forge 4 | - bioconda 5 | - pytorch 6 | dependencies: 7 | - conda-forge::python=3.7 8 | - conda-forge::setuptools=59.5.0 9 | - conda-forge::pip 10 | - conda-forge::openmm=7.5.1 11 | - conda-forge::pdbfixer 12 | - conda-forge::cudatoolkit==11.3.* 13 | - bioconda::hmmer==3.3.2 14 | - bioconda::hhsuite==3.3.0 15 | - bioconda::kalign2==2.04 16 | - pytorch::pytorch=1.12.* 17 | - pip: 18 | - biopython==1.79 19 | - deepspeed==0.5.10 20 | - dm-tree==0.1.6 21 | - ml-collections==0.1.0 22 | - numpy==1.21.2 23 | - PyYAML==5.4.1 24 | - requests==2.26.0 25 | - scipy==1.7.1 26 | - tqdm==4.62.2 27 | - typing-extensions==3.10.0.2 28 | - pytorch_lightning==1.5.10 29 | - wandb==0.12.21 30 | - modelcif==0.7 31 | - git+https://github.com/NVIDIA/dllogger.git 32 | 
-------------------------------------------------------------------------------- /openfold/imgs/of_banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/MutateEverything/c380ea9fa185b770df25127e8aabbccbd212a074/openfold/imgs/of_banner.png -------------------------------------------------------------------------------- /openfold/lib/openmm.patch: -------------------------------------------------------------------------------- 1 | Index: simtk/openmm/app/topology.py 2 | =================================================================== 3 | --- simtk.orig/openmm/app/topology.py 4 | +++ simtk/openmm/app/topology.py 5 | @@ -356,19 +356,35 @@ 6 | def isCyx(res): 7 | names = [atom.name for atom in res._atoms] 8 | return 'SG' in names and 'HG' not in names 9 | + # This function is used to prevent multiple di-sulfide bonds from being 10 | + # assigned to a given atom. This is a DeepMind modification. 11 | + def isDisulfideBonded(atom): 12 | + for b in self._bonds: 13 | + if (atom in b and b[0].name == 'SG' and 14 | + b[1].name == 'SG'): 15 | + return True 16 | + 17 | + return False 18 | 19 | cyx = [res for res in self.residues() if res.name == 'CYS' and isCyx(res)] 20 | atomNames = [[atom.name for atom in res._atoms] for res in cyx] 21 | for i in range(len(cyx)): 22 | sg1 = cyx[i]._atoms[atomNames[i].index('SG')] 23 | pos1 = positions[sg1.index] 24 | + candidate_distance, candidate_atom = 0.3*nanometers, None 25 | for j in range(i): 26 | sg2 = cyx[j]._atoms[atomNames[j].index('SG')] 27 | pos2 = positions[sg2.index] 28 | delta = [x-y for (x,y) in zip(pos1, pos2)] 29 | distance = sqrt(delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2]) 30 | - if distance < 0.3*nanometers: 31 | - self.addBond(sg1, sg2) 32 | + if distance < candidate_distance and not isDisulfideBonded(sg2): 33 | + candidate_distance = distance 34 | + candidate_atom = sg2 35 | + # Assign bond to closest pair. 
36 | + if candidate_atom: 37 | + self.addBond(sg1, candidate_atom) 38 | + 39 | + 40 | 41 | class Chain(object): 42 | """A Chain object represents a chain within a Topology.""" 43 | -------------------------------------------------------------------------------- /openfold/notebooks/environment.yml: -------------------------------------------------------------------------------- 1 | name: openfold_venv 2 | channels: 3 | - conda-forge 4 | - bioconda 5 | dependencies: 6 | - conda-forge::openmm=7.5.1 7 | - conda-forge::pdbfixer 8 | - bioconda::hmmer==3.3.2 9 | - bioconda::hhsuite==3.3.0 10 | - bioconda::kalign2==2.04 11 | - pip: 12 | - biopython==1.79 13 | - dm-tree==0.1.6 14 | - ml-collections==0.1.0 15 | - PyYAML==5.4.1 16 | - requests==2.26.0 17 | - typing-extensions==3.10.0.2 18 | -------------------------------------------------------------------------------- /openfold/openfold/__init__.py: -------------------------------------------------------------------------------- 1 | from . import model 2 | from . import utils 3 | from . import np 4 | from . import resources 5 | 6 | __all__ = ["model", "utils", "np", "data", "resources"] 7 | -------------------------------------------------------------------------------- /openfold/openfold/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jozhang97/MutateEverything/c380ea9fa185b770df25127e8aabbccbd212a074/openfold/openfold/data/__init__.py -------------------------------------------------------------------------------- /openfold/openfold/data/errors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """General-purpose errors used throughout the data pipeline""" 17 | class Error(Exception): 18 | """Base class for exceptions.""" 19 | 20 | 21 | class MultipleChainsError(Error): 22 | """An error indicating that multiple chains were found for a given ID.""" 23 | -------------------------------------------------------------------------------- /openfold/openfold/data/feature_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
FeatureDict = Mapping[str, np.ndarray]
TensorDict = Dict[str, torch.Tensor]


def np_to_tensor_dict(
    np_example: Mapping[str, np.ndarray],
    features: Sequence[str],
) -> TensorDict:
    """Creates dict of tensors from a dict of NumPy arrays.

    Args:
        np_example: A dict of NumPy feature arrays.
        features: A list of strings of feature names to be returned in the dataset.

    Returns:
        A dictionary of features mapping feature names to features. Only the given
        features are returned, all other ones are filtered out.
    """
    return {
        k: torch.tensor(v) for k, v in np_example.items() if k in features
    }


def make_data_config(
    config: ml_collections.ConfigDict,
    mode: str,
    num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
    """Build the per-example data config and the list of feature names.

    Args:
        config: Data config; it is deep-copied and never modified.
        mode: Key of the mode-specific sub-config (e.g. "train").
        num_res: Number of residues; used as the crop size when the mode
            config leaves ``crop_size`` as None.

    Returns:
        ``(cfg, feature_names)`` where ``cfg`` is the adjusted copy and
        ``feature_names`` holds the unsupervised features, followed by the
        template features when ``cfg.common.use_templates`` is set and the
        supervised features when ``cfg[mode].supervised`` is set.
    """
    cfg = copy.deepcopy(config)
    mode_cfg = cfg[mode]
    with cfg.unlocked():
        if mode_cfg.crop_size is None:
            mode_cfg.crop_size = num_res

    # Start from a fresh list: extending cfg.common.unsupervised_features
    # itself would leave template/supervised names in the returned config.
    feature_names = list(cfg.common.unsupervised_features)

    if cfg.common.use_templates:
        feature_names += cfg.common.template_features

    if cfg[mode].supervised:
        feature_names += cfg.supervised.supervised_features

    return cfg, feature_names


def np_example_to_features(
    np_example: FeatureDict,
    config: ml_collections.ConfigDict,
    mode: str,
):
    """Convert a raw NumPy feature dict into processed model features.

    Args:
        np_example: Raw features; must contain ``"seq_length"``. An optional
            ``"deletion_matrix_int"`` entry is converted to a float32
            ``"deletion_matrix"``.
        config: Data config handed to ``make_data_config``.
        mode: Mode key; ``"train"`` additionally samples whether the clamped
            FAPE flag is set.

    Returns:
        A new dict with the processed features plus ``"use_clamped_fape"``,
        a float32 tensor with ``max_recycling_iters + 1`` entries.
    """
    np_example = dict(np_example)  # shallow copy; the caller's dict is left intact
    num_res = int(np_example["seq_length"][0])
    cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)

    if "deletion_matrix_int" in np_example:
        np_example["deletion_matrix"] = np_example.pop(
            "deletion_matrix_int"
        ).astype(np.float32)

    tensor_dict = np_to_tensor_dict(
        np_example=np_example, features=feature_names
    )
    with torch.no_grad():
        features = input_pipeline.process_tensors_from_config(
            tensor_dict,
            cfg.common,
            cfg[mode],
        )

    if mode == "train":
        # Random draw only in training, exactly as before, so RNG
        # consumption in other modes is unchanged.
        p = torch.rand(1).item()
        use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
    else:
        use_clamped_fape_value = 0.0

    features["use_clamped_fape"] = torch.full(
        size=[cfg.common.max_recycling_iters + 1],
        fill_value=use_clamped_fape_value,
        dtype=torch.float32,
    )

    return dict(features)


class FeaturePipeline:
    """Thin wrapper binding a data config to ``np_example_to_features``."""

    def __init__(
        self,
        config: ml_collections.ConfigDict,
    ):
        self.config = config

    def process_features(
        self,
        raw_features: FeatureDict,
        mode: str = "train",
    ) -> FeatureDict:
        """Process ``raw_features`` for ``mode`` with the stored config."""
        return np_example_to_features(
            np_example=raw_features,
            config=self.config,
            mode=mode,
        )
class HHSearch:
    """Python wrapper of the HHsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        databases: Sequence[str],
        n_cpu: int = 2,
        maxseq: int = 1_000_000,
    ):
        """Initializes the Python HHsearch wrapper.

        Args:
            binary_path: The path to the HHsearch executable.
            databases: A sequence of HHsearch database paths. This should be the
                common prefix for the database files (i.e. up to but not including
                _hhm.ffindex etc.)
            n_cpu: The number of CPUs to use
            maxseq: The maximum number of rows in an input alignment. Note that this
                parameter is only supported in HHBlits version 3.1 and higher.

        Raises:
            ValueError: If no file matching ``<database_path>_*`` exists for
                one of the given databases.
        """
        self.binary_path = binary_path
        self.databases = databases
        self.n_cpu = n_cpu
        self.maxseq = maxseq

        for database_path in self.databases:
            if not glob.glob(database_path + "_*"):
                logging.error(
                    "Could not find HHsearch database %s", database_path
                )
                raise ValueError(
                    f"Could not find HHsearch database {database_path}"
                )

    @staticmethod
    def _decode_output(data: bytes) -> str:
        """Decode subprocess output as UTF-8 without ever raising.

        Invalid or truncated byte sequences become U+FFFD so that building
        an error message cannot itself fail with ``UnicodeDecodeError``.
        """
        return data.decode("utf-8", errors="replace")

    def query(self, a3m: str) -> str:
        """Queries the database using HHsearch using a given a3m.

        Args:
            a3m: The query alignment, written verbatim to a temporary file.

        Returns:
            The contents of the ``.hhr`` file produced by HHsearch.

        Raises:
            RuntimeError: If the HHsearch process exits with a non-zero code.
        """
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            input_path = os.path.join(query_tmp_dir, "query.a3m")
            hhr_path = os.path.join(query_tmp_dir, "output.hhr")
            with open(input_path, "w") as f:
                f.write(a3m)

            db_cmd = []
            for db_path in self.databases:
                db_cmd.append("-d")
                db_cmd.append(db_path)
            cmd = [
                self.binary_path,
                "-i",
                input_path,
                "-o",
                hhr_path,
                "-maxseq",
                str(self.maxseq),
                "-cpu",
                str(self.n_cpu),
            ] + db_cmd

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing("HHsearch query"):
                stdout, stderr = process.communicate()
                # communicate() has already waited for the process to exit.
                retcode = process.returncode

            if retcode:
                # Stderr is truncated to prevent proto size errors in Beam.
                # The byte-level cut may split a multi-byte character, hence
                # the lossy-safe decoding.
                raise RuntimeError(
                    "HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (
                        self._decode_output(stdout),
                        self._decode_output(stderr[:100_000]),
                    )
                )

            with open(hhr_path) as f:
                hhr = f.read()
        return hhr
class Kalign:
    """Thin Python front end for the Kalign command-line binary."""

    def __init__(self, *, binary_path: str):
        """Remember where the Kalign executable lives.

        Args:
            binary_path: The path to the Kalign binary. The path is only
                stored here; it is first used when ``align`` is called.
        """
        self.binary_path = binary_path

    def align(self, sequences: Sequence[str]) -> str:
        """Align the sequences with Kalign and return the resulting text.

        Args:
            sequences: Query sequence strings, each at least 6 residues long
                (Kalign requires this). The order of the sequences might
                alter the output slightly as a different alignment tree
                might get constructed.

        Returns:
            The contents of the output file written by Kalign.

        Raises:
            RuntimeError: If Kalign exits with a non-zero code.
            ValueError: If any of the sequences is less than 6 residues long.
        """
        logging.info("Aligning %d sequences", len(sequences))

        # Report the first offending sequence, scanning in input order.
        too_short = next((seq for seq in sequences if len(seq) < 6), None)
        if too_short is not None:
            raise ValueError(
                "Kalign requires all sequences to be at least 6 "
                "residues long. Got %s (%d residues)."
                % (too_short, len(too_short))
            )

        with utils.tmpdir_manager(base_dir="/tmp") as workdir:
            fasta_file = os.path.join(workdir, "input.fasta")
            aligned_file = os.path.join(workdir, "output.a3m")

            with open(fasta_file, "w") as handle:
                handle.write(_to_a3m(sequences))

            argv = [
                self.binary_path,
                "-i",
                fasta_file,
                "-o",
                aligned_file,
                "-format",
                "fasta",
            ]

            logging.info('Launching subprocess "%s"', " ".join(argv))
            proc = subprocess.Popen(
                argv, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("Kalign query"):
                out, err = proc.communicate()
                status = proc.wait()
                out_text = out.decode("utf-8")
                err_text = err.decode("utf-8")
                logging.info(
                    "Kalign stdout:\n%s\n\nstderr:\n%s\n", out_text, err_text
                )

            if status:
                raise RuntimeError(
                    "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (out_text, err_text)
                )

            with open(aligned_file) as handle:
                return handle.read()
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
    """Yield a fresh temporary directory and remove it when the block ends.

    Args:
        base_dir: Parent directory for the temporary directory; ``None``
            lets ``tempfile.mkdtemp`` pick its default location.

    The directory is removed even if the body raises; errors during removal
    are ignored.
    """
    scratch = tempfile.mkdtemp(dir=base_dir)
    try:
        yield scratch
    finally:
        shutil.rmtree(scratch, ignore_errors=True)


@contextlib.contextmanager
def timing(msg: str):
    """Log a start message, run the body, then log the elapsed wall time.

    The "Finished" line is only logged when the body completes without
    raising.
    """
    logging.info("Started %s", msg)
    started = time.perf_counter()
    yield
    elapsed = time.perf_counter() - started
    logging.info("Finished %s in %.3f seconds", msg, elapsed)


def to_date(s: str):
    """Build a ``datetime.datetime`` from a string laid out like YYYY-MM-DD.

    Only the character positions are used (0-3 year, 5-6 month, 8-9 day);
    the separator characters and anything after position 9 are ignored.
    """
    year, month, day = (
        int(s[start:stop]) for start, stop in ((0, 4), (5, 7), (8, 10))
    )
    return datetime.datetime(year=year, month=month, day=day)
class Dropout(nn.Module):
    """
    Implementation of dropout with the ability to share the dropout mask
    along a particular dimension.

    If not in training mode, this module computes the identity function.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
        Args:
            r:
                Dropout rate
            batch_dim:
                Dimension(s) along which the dropout mask is shared
        """
        super().__init__()

        self.r = r
        # Accept a single dimension as shorthand for a one-element list.
        if isinstance(batch_dim, int):
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                Tensor to which dropout is applied. Can have any shape
                compatible with self.batch_dim
        Returns:
            In training mode, a new tensor equal to ``x`` multiplied by the
            shared dropout mask. Outside training mode, ``x`` itself.
        """
        # Outside training the inner nn.Dropout leaves the all-ones mask
        # untouched, so the result is x; skip the mask allocation entirely.
        if not self.dropout.training:
            return x

        # The mask has size 1 along every shared dimension so that it
        # broadcasts (i.e. is identical) across those dimensions.
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        mask = self.dropout(x.new_ones(shape))

        # Out-of-place multiply: an in-place `x *= mask` would silently
        # modify the caller's tensor and raises for leaf tensors that
        # require grad.
        return x * mask


class DropoutRowwise(Dropout):
    """
    Convenience class for rowwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-3)


class DropoutColumnwise(Dropout):
    """
    Convenience class for columnwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-2)
class OuterProductMean(nn.Module):
    """
    Implements Algorithm 10.
    """

    def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
        """
        Args:
            c_m:
                MSA embedding channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Hidden channel dimension
            eps:
                Constant added to the mask-overlap count before dividing
        """
        super().__init__()

        self.c_m, self.c_z, self.c_hidden = c_m, c_z, c_hidden
        self.eps = eps

        self.layer_norm = nn.LayerNorm(c_m)
        self.linear_1 = Linear(c_m, c_hidden)
        self.linear_2 = Linear(c_m, c_hidden)
        self.linear_out = Linear(c_hidden ** 2, c_z, init="final")

    def _opm(self, a, b):
        """Outer product over channels, summed over the shared axis, then
        projected to C_z."""
        # [*, N_res, N_res, C, C]
        product = torch.einsum("...bac,...dae->...bdce", a, b)

        # [*, N_res, N_res, C * C]
        flattened = product.reshape(product.shape[:-2] + (-1,))

        # [*, N_res, N_res, C_z]
        return self.linear_out(flattened)

    @torch.jit.ignore
    def _chunk(self,
        a: torch.Tensor,
        b: torch.Tensor,
        chunk_size: int
    ) -> torch.Tensor:
        """Chunked version of ``_opm`` applied per leading-batch element."""
        # Since the "batch dim" in this case is not a true batch dimension
        # (in that the shape of the output depends on it), we need to
        # iterate over it ourselves
        flat_a = a.reshape((-1,) + a.shape[-3:])
        flat_b = b.reshape((-1,) + b.shape[-3:])
        pieces = [
            chunk_layer(
                partial(self._opm, b=b_item),
                {"a": a_item},
                chunk_size=chunk_size,
                no_batch_dims=1,
            )
            for a_item, b_item in zip(flat_a, flat_b)
        ]

        # For some cursed reason making this distinction saves memory
        if len(pieces) == 1:
            stacked = pieces[0].unsqueeze(0)
        else:
            stacked = torch.stack(pieces, dim=0)

        # Restore the original leading batch dimensions.
        return stacked.reshape(a.shape[:-3] + stacked.shape[1:])

    def _forward(self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                If given, the outer product is computed through ``_chunk``
            inplace_safe:
                If True, the final normalisation divides in place
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        # [*, N_seq, N_res, C_m]
        normed = self.layer_norm(m)

        # [*, N_seq, N_res, 1]
        mask = mask.unsqueeze(-1)

        # [*, N_res, N_seq, C] after masking and swapping the two axes
        a = (self.linear_1(normed) * mask).transpose(-2, -3)
        b = (self.linear_2(normed) * mask).transpose(-2, -3)

        # Drop the normalised activations as soon as both projections exist.
        del normed

        if chunk_size is None:
            outer = self._opm(a, b)
        else:
            outer = self._chunk(a, b, chunk_size)

        # [*, N_res, N_res, 1]
        norm = torch.einsum("...abc,...adc->...bdc", mask, mask) + self.eps

        # [*, N_res, N_res, C_z]
        if inplace_safe:
            outer /= norm
        else:
            outer = outer / norm

        return outer

    def forward(self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """Run ``_forward``; when fp16 is enabled, do so with autocast
        disabled and ``m`` cast to float32."""
        if not is_fp16_enabled():
            return self._forward(m, mask, chunk_size, inplace_safe)

        with torch.cuda.amp.autocast(enabled=False):
            return self._forward(m.float(), mask, chunk_size, inplace_safe)
class PairTransition(nn.Module):
    """
    Implements Algorithm 15.
    """

    def __init__(self, c_z, n):
        """
        Args:
            c_z:
                Pair transition channel dimension
            n:
                Factor by which c_z is multiplied to obtain hidden channel
                dimension
        """
        super().__init__()

        self.c_z, self.n = c_z, n
        hidden_dim = self.n * self.c_z

        self.layer_norm = LayerNorm(self.c_z)
        self.linear_1 = Linear(self.c_z, hidden_dim, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(hidden_dim, c_z, init="final")

    def _transition(self, z, mask):
        """LayerNorm -> Linear -> ReLU -> Linear, then zero masked entries."""
        # [*, N_res, N_res, C_hidden]
        hidden = self.relu(self.linear_1(self.layer_norm(z)))

        # [*, N_res, N_res, C_z]
        return self.linear_2(hidden) * mask

    @torch.jit.ignore
    def _chunk(self,
        z: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        """Apply ``_transition`` through ``chunk_layer``, treating every
        dimension except the last two as a batch dimension."""
        layer_inputs = {"z": z, "mask": mask}
        return chunk_layer(
            self._transition,
            layer_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(z.shape[:-2]),
        )

    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] pair embedding
            mask:
                [*, N_res, N_res] pair mask; all ones when omitted
            chunk_size:
                If given, the transition is evaluated through ``_chunk``
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        # DISCREPANCY: DeepMind forgets to apply the mask in this module.
        pair_mask = z.new_ones(z.shape[:-1]) if mask is None else mask

        # [*, N_res, N_res, 1]
        pair_mask = pair_mask.unsqueeze(-1)

        if chunk_size is None:
            return self._transition(z=z, mask=pair_mask)
        return self._chunk(z, pair_mask, chunk_size)
class TriangleAttention(nn.Module):
    """Attention over the pair representation with a learned per-head bias.

    With ``starting=False`` the two residue axes are swapped before and
    after the attention call.
    """

    def __init__(
        self, c_in, c_hidden, no_heads, starting=True, inf=1e9
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
            starting:
                If False, the input is transposed over its two residue
                axes before attention and transposed back afterwards
            inf:
                Magnitude of the negative bias applied at masked positions
        """
        super().__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

    @torch.jit.ignore
    def _chunk(self,
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """Run the attention module through ``chunk_layer``.

        When ``inplace_safe`` is set, ``x`` is handed to ``chunk_layer`` as
        its ``_out`` argument.
        """
        attend = partial(
            self.mha,
            use_memory_efficient_kernel=use_memory_efficient_kernel,
            use_lma=use_lma
        )
        attention_inputs = {"q_x": x, "kv_x": x, "biases": biases}
        destination = x if inplace_safe else None

        return chunk_layer(
            attend,
            attention_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
            _out=destination,
        )

    def forward(self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        use_memory_efficient_kernel: bool = False,
        use_lma: bool = False,
        inplace_safe: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
            mask:
                [*, I, J] mask; all ones when omitted
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J]
            mask = x.new_ones(x.shape[:-1])

        swap_axes = not self.starting
        if swap_axes:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # [*, I, 1, 1, J]; 0 where mask == 1, -inf-like where mask == 0
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # [*, H, I, J] -> [*, 1, H, I, J]
        triangle_bias = permute_final_dims(
            self.linear(x), (2, 0, 1)
        ).unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        if chunk_size is None:
            x = self.mha(
                q_x=x,
                kv_x=x,
                biases=biases,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_lma=use_lma
            )
        else:
            x = self._chunk(
                x,
                biases,
                chunk_size,
                use_memory_efficient_kernel=use_memory_efficient_kernel,
                use_lma=use_lma,
                inplace_safe=inplace_safe,
            )

        if swap_axes:
            x = x.transpose(-2, -3)

        return x


# Implements Algorithm 13
TriangleAttentionStartingNode = TriangleAttention


class TriangleAttentionEndingNode(TriangleAttention):
    """
    Implements Algorithm 14.
    """
    __init__ = partialmethod(TriangleAttention.__init__, starting=False)
def fix_pdb(pdbfile, alterations_info):
    """Apply pdbfixer to the contents of a PDB file; return a PDB string result.

    1) Replaces nonstandard residues.
    2) Removes heterogens (non protein residues) including water.
    3) Adds missing residues and missing atoms within existing residues.
    4) Adds hydrogens assuming pH=7.0 (no pH is passed explicitly here, so
       this relies on pdbfixer's default -- TODO confirm).
    5) KeepIds is currently true, so the fixer must keep the existing chain and
       residue identifiers. This will fail for some files in wider PDB that have
       invalid IDs.

    Args:
      pdbfile: Input PDB file handle.
      alterations_info: A dict that will store details of changes made. This
        function writes the keys "nonstandard_residues", "missing_residues",
        "missing_heavy_atoms" and "missing_terminals"; the heterogen-removal
        helper it calls adds "removed_heterogens".

    Returns:
      A PDB string representing the fixed structure.
    """
    fixer = pdbfixer.PDBFixer(pdbfile=pdbfile)
    # Each find* call populates an attribute on the fixer; the findings are
    # recorded in alterations_info before the matching fix-up step runs.
    fixer.findNonstandardResidues()
    alterations_info["nonstandard_residues"] = fixer.nonstandardResidues
    fixer.replaceNonstandardResidues()
    # Water is removed together with the other heterogens.
    _remove_heterogens(fixer, alterations_info, keep_water=False)
    fixer.findMissingResidues()
    alterations_info["missing_residues"] = fixer.missingResidues
    fixer.findMissingAtoms()
    alterations_info["missing_heavy_atoms"] = fixer.missingAtoms
    alterations_info["missing_terminals"] = fixer.missingTerminals
    # A fixed seed is passed, presumably to make atom placement reproducible.
    fixer.addMissingAtoms(seed=0)
    fixer.addMissingHydrogens()
    # Serialise to an in-memory buffer, keeping the original chain/residue IDs.
    out_handle = io.StringIO()
    app.PDBFile.writeFile(
        fixer.topology, fixer.positions, out_handle, keepIds=True
    )
    return out_handle.getvalue()
72 | 73 | Args: 74 | pdb_structure: An OpenMM structure to modify and fix. 75 | alterations_info: A dict that will store details of changes made. 76 | """ 77 | _replace_met_se(pdb_structure, alterations_info) 78 | _remove_chains_of_length_one(pdb_structure, alterations_info) 79 | 80 | 81 | def _remove_heterogens(fixer, alterations_info, keep_water): 82 | """Removes the residues that Pdbfixer considers to be heterogens. 83 | 84 | Args: 85 | fixer: A Pdbfixer instance. 86 | alterations_info: A dict that will store details of changes made. 87 | keep_water: If True, water (HOH) is not considered to be a heterogen. 88 | """ 89 | initial_resnames = set() 90 | for chain in fixer.topology.chains(): 91 | for residue in chain.residues(): 92 | initial_resnames.add(residue.name) 93 | fixer.removeHeterogens(keepWater=keep_water) 94 | final_resnames = set() 95 | for chain in fixer.topology.chains(): 96 | for residue in chain.residues(): 97 | final_resnames.add(residue.name) 98 | alterations_info["removed_heterogens"] = initial_resnames.difference( 99 | final_resnames 100 | ) 101 | 102 | 103 | def _replace_met_se(pdb_structure, alterations_info): 104 | """Replace the Se in any MET residues that were not marked as modified.""" 105 | modified_met_residues = [] 106 | for res in pdb_structure.iter_residues(): 107 | name = res.get_name_with_spaces().strip() 108 | if name == "MET": 109 | s_atom = res.get_atom("SD") 110 | if s_atom.element_symbol == "Se": 111 | s_atom.element_symbol = "S" 112 | s_atom.element = element.get_by_symbol("S") 113 | modified_met_residues.append(s_atom.residue_number) 114 | alterations_info["Se_in_MET"] = modified_met_residues 115 | 116 | 117 | def _remove_chains_of_length_one(pdb_structure, alterations_info): 118 | """Removes chains that correspond to a single amino acid. 119 | 120 | A single amino acid in a chain is both N and C terminus. There is no force 121 | template for this case. 
122 | 123 | Args: 124 | pdb_structure: An OpenMM pdb_structure to modify and fix. 125 | alterations_info: A dict that will store details of changes made. 126 | """ 127 | removed_chains = {} 128 | for model in pdb_structure.iter_models(): 129 | valid_chains = [c for c in model.iter_chains() if len(c) > 1] 130 | invalid_chain_ids = [ 131 | c.chain_id for c in model.iter_chains() if len(c) <= 1 132 | ] 133 | model.chains = valid_chains 134 | for chain_id in invalid_chain_ids: 135 | model.chains_by_id.pop(chain_id) 136 | removed_chains[model.number] = invalid_chain_ids 137 | alterations_info["removed_chains"] = removed_chains 138 | -------------------------------------------------------------------------------- /openfold/openfold/np/relax/relax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class AmberRelaxation(object):
    """Amber relaxation."""

    def __init__(
        self,
        *,
        max_iterations: int,
        tolerance: float,
        stiffness: float,
        exclude_residues: Sequence[int],
        max_outer_iterations: int,
        use_gpu: bool,
    ):
        """Store the settings that are forwarded to the minimization pipeline.

        Args:
          max_iterations: Maximum number of L-BFGS iterations. 0 means no max.
          tolerance: kcal/mol, the energy tolerance of L-BFGS.
          stiffness: kcal/mol A**2, spring constant of heavy atom restraining
            potential.
          exclude_residues: Residues to exclude from per-atom restraining.
            Zero-indexed.
          max_outer_iterations: Maximum number of violation-informed relax
            iterations. A value of 1 will run the non-iterative procedure used
            in CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax
            finishes as soon as there are no violations, hence in most cases
            this causes no slowdown. In the worst case we do 20 outer
            iterations.
          use_gpu: Whether to run on GPU
        """
        self._max_iterations = max_iterations
        self._tolerance = tolerance
        self._stiffness = stiffness
        self._exclude_residues = exclude_residues
        self._max_outer_iterations = max_outer_iterations
        self._use_gpu = use_gpu

    def process(
        self, *, prot: protein.Protein, cif_output: bool
    ) -> Tuple[str, Dict[str, Any], np.ndarray]:
        """Runs Amber relax on a prediction, adds hydrogens, returns PDB string.

        Returns:
          A tuple of the relaxed structure as text (ModelCIF when cif_output
          is true, PDB otherwise), a dict of debug values (initial and final
          energy, number of attempts, and the displacement between start and
          relaxed positions), and the per-residue violations mask reported by
          the minimization pipeline.
        """
        result = amber_minimize.run_pipeline(
            prot=prot,
            max_iterations=self._max_iterations,
            tolerance=self._tolerance,
            stiffness=self._stiffness,
            exclude_residues=self._exclude_residues,
            max_outer_iterations=self._max_outer_iterations,
            use_gpu=self._use_gpu,
        )
        relaxed_pos = result["pos"]
        initial_pos = result["posinit"]

        # Squared displacement summed over everything, averaged over the
        # first axis of the position array.
        squared_displacement = np.sum((initial_pos - relaxed_pos) ** 2)
        rmsd = np.sqrt(squared_displacement / initial_pos.shape[0])
        debug_data = {
            "initial_energy": result["einit"],
            "final_energy": result["efinal"],
            "attempts": result["min_attempts"],
            "rmsd": rmsd,
        }

        # Rebuild the PDB text from the cleaned input, then graft on the
        # relaxed coordinates and the original B-factors.
        relaxed_pdb = utils.overwrite_pdb_coordinates(
            amber_minimize.clean_protein(prot), relaxed_pos
        )
        relaxed_pdb = utils.overwrite_b_factors(relaxed_pdb, prot.b_factors)
        utils.assert_equal_nonterminal_atom_types(
            protein.from_pdb_string(relaxed_pdb).atom_mask, prot.atom_mask
        )
        violations = result["structural_violations"][
            "total_per_residue_violations_mask"
        ]

        relaxed_pdb = protein.add_pdb_headers(prot, relaxed_pdb)
        if not cif_output:
            return relaxed_pdb, debug_data, violations

        # TODO the model cif will be missing some metadata like headers (PARENTs and
        # REMARK with some details of the run, like num of recycles)
        relaxed_cif = protein.to_modelcif(protein.from_pdb_string(relaxed_pdb))
        return relaxed_cif, debug_data, violations
/openfold/openfold/np/relax/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utils for minimization.""" 17 | import io 18 | from openfold.np import residue_constants 19 | from Bio import PDB 20 | import numpy as np 21 | try: 22 | # openmm >= 7.6 23 | from openmm import app as openmm_app 24 | from openmm.app.internal.pdbstructure import PdbStructure 25 | except ImportError: 26 | # openmm < 7.6 (requires DeepMind patch) 27 | from simtk.openmm import app as openmm_app 28 | from simtk.openmm.app.internal.pdbstructure import PdbStructure 29 | 30 | 31 | def overwrite_pdb_coordinates(pdb_str: str, pos) -> str: 32 | pdb_file = io.StringIO(pdb_str) 33 | structure = PdbStructure(pdb_file) 34 | topology = openmm_app.PDBFile(structure).getTopology() 35 | with io.StringIO() as f: 36 | openmm_app.PDBFile.writeFile(topology, pos, f) 37 | return f.getvalue() 38 | 39 | 40 | def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str: 41 | """Overwrites the B-factors in pdb_str with contents of bfactors array. 42 | 43 | Args: 44 | pdb_str: An input PDB string. 45 | bfactors: A numpy array with shape [1, n_residues, 37]. We assume that the 46 | B-factors are per residue; i.e. that the nonzero entries are identical in 47 | [0, i, :]. 
def assert_equal_nonterminal_atom_types(
    atom_mask: np.ndarray, ref_atom_mask: np.ndarray
):
    """Checks that pre- and post-minimized proteins have same atom set.

    Args:
        atom_mask: Atom mask of the minimized protein; its last axis is
            indexed by atom type.
        ref_atom_mask: Atom mask of the reference (pre-minimization)
            protein, indexed with a boolean mask of `atom_mask`'s shape.

    Raises:
        AssertionError: If the two masks differ anywhere outside the OXT
            column.
    """
    # Ignore any terminal OXT atoms which may have been added by minimization.
    oxt = residue_constants.atom_order["OXT"]
    # The builtin `bool` is used because the `np.bool` alias was removed in
    # NumPy 1.24 and raises AttributeError there.
    no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
    no_oxt_mask[..., oxt] = False
    np.testing.assert_almost_equal(
        ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]
    )
class EarlyStoppingVerbose(EarlyStopping):
    """
        EarlyStopping variant that stays quiet until it decides to stop.

        The stock callback's verbose mode is too chatty; this subclass only
        emits the stopping reason, and only once stopping is requested.
    """
    # NOTE(review): the misspelled method name is intentional here -- it has
    # to match the parent-class hook it overrides. Confirm against the
    # installed pytorch_lightning version before "fixing" the spelling.
    def _evalute_stopping_criteria(self, *args, **kwargs):
        verdict = super()._evalute_stopping_criteria(*args, **kwargs)
        stop_now, why = verdict
        if stop_now:
            rank_zero_info(f"{why}\n")

        return stop_now, why
def get_checkpoint_fn():
    """Pick the activation-checkpointing function to use.

    Returns:
        `deepspeed.checkpointing.checkpoint` when DeepSpeed is installed and
        its checkpointing has been configured, otherwise
        `torch.utils.checkpoint.checkpoint`.
    """
    # `deepspeed` is only imported when it is installed, so the installed
    # flag must be tested first to keep the attribute access safe.
    if deepspeed_is_installed and deepspeed.checkpointing.is_configured():
        return deepspeed.checkpointing.checkpoint
    return torch.utils.checkpoint.checkpoint
class ExponentialMovingAverage:
    """
    Keeps an exponentially decayed running average of a module's state.

    Every call to `update` moves each stored tensor `copy` towards the
    corresponding live tensor `param`:

        `copy = decay * copy + (1 - decay) * param`

    with `decay` taken from this object.
    """

    def __init__(self, model: nn.Module, decay: float):
        """
        Args:
            model:
                A torch.nn.Module whose state dict is to be tracked
            decay:
                Weight (usually close to 1.) given to the stored average
                in the formula above
        """
        super(ExponentialMovingAverage, self).__init__()

        def _detached_copy(t):
            return t.clone().detach()

        self.params = tensor_tree_map(_detached_copy, model.state_dict())
        self.decay = decay
        self.device = next(model.parameters()).device

    def to(self, device):
        """Move every stored tensor to `device` and remember the device."""
        def _move(t):
            return t.to(device)

        self.params = tensor_tree_map(_move, self.params)
        self.device = device

    def _update_state_dict_(self, update, state_dict):
        """Blend `update` into `state_dict` in place, recursing into
        non-tensor values."""
        with torch.no_grad():
            for key, incoming in update.items():
                tracked = state_dict[key]
                if isinstance(incoming, torch.Tensor):
                    # tracked <- tracked - (1 - decay) * (tracked - incoming),
                    # which equals decay * tracked + (1 - decay) * incoming.
                    delta = tracked - incoming
                    delta *= 1 - self.decay
                    tracked -= delta
                else:
                    self._update_state_dict_(incoming, tracked)

    def update(self, model: torch.nn.Module) -> None:
        """
        Fold the current state dict of `model` into the stored averages.
        The module must have the same structure as the one this object was
        created from.
        """
        self._update_state_dict_(model.state_dict(), self.params)

    def load_state_dict(self, state_dict: OrderedDict) -> None:
        """Restore stored tensors (cloned) and the decay from `state_dict`."""
        incoming = state_dict["params"]
        for name, tensor in incoming.items():
            self.params[name] = tensor.clone()
        self.decay = state_dict["decay"]

    def state_dict(self) -> OrderedDict:
        """Return the stored tensors and decay as an OrderedDict."""
        snapshot = OrderedDict()
        snapshot["params"] = self.params
        snapshot["decay"] = self.decay
        return snapshot
class AttentionCoreFunction(torch.autograd.Function):
    """Custom autograd function for attention built on the
    `attn_core_inplace_cuda` extension.

    The logits tensor is overwritten in place by the extension in both
    passes, so no separate buffer is allocated for the kernel's result.
    NOTE(review): judging by the extension's binding names, `forward_` is an
    in-place softmax and `backward_` its in-place gradient -- confirm against
    the kernel source.
    """
    @staticmethod
    def forward(ctx, q, k, v, bias_1=None, bias_2=None):
        """Compute the attention output for `q`, `k`, `v`.

        Args:
            ctx: Autograd context used to stash tensors for `backward`.
            q, k, v: Query, key and value tensors; `q`'s dtype must be in
                SUPPORTED_DTYPES.
            bias_1, bias_2: Optional tensors added to the logits. `bias_2`
                may only be given together with `bias_1`.

        Returns:
            The matrix product of the kernel-processed logits and `v`.

        Raises:
            ValueError: If `bias_2` is given without `bias_1`, or `q` has an
                unsupported dtype.
        """
        if(bias_1 is None and bias_2 is not None):
            raise ValueError("bias_1 must be specified before bias_2")
        if(q.dtype not in SUPPORTED_DTYPES):
            raise ValueError("Unsupported datatype")

        q = q.contiguous()
        k = k.contiguous()

        # [*, H, Q, K]
        # Note: no scaling factor is applied to the logits in this function.
        attention_logits = torch.matmul(
            q, k.transpose(-1, -2),
        )

        if(bias_1 is not None):
            attention_logits += bias_1
        if(bias_2 is not None):
            attention_logits += bias_2

        # The kernel receives the logits flattened to (rows, cols): rows is
        # the product of all leading dimensions, cols is the last dimension.
        attn_core_inplace_cuda.forward_(
            attention_logits,
            reduce(mul, attention_logits.shape[:-1]),
            attention_logits.shape[-1],
        )

        o = torch.matmul(attention_logits, v)

        # Only the bias shapes are kept; backward uses them to decide which
        # dimensions to sum over.
        ctx.bias_1_shape = bias_1.shape if bias_1 is not None else None
        ctx.bias_2_shape = bias_2.shape if bias_2 is not None else None
        ctx.save_for_backward(q, k, v, attention_logits)

        return o

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for `q`, `k`, `v` and the optional biases.

        Returns:
            A 5-tuple (grad_q, grad_k, grad_v, grad_bias_1, grad_bias_2);
            the bias gradients are None when the matching bias was not
            passed to `forward`.
        """
        q, k, v, attention_logits = ctx.saved_tensors
        grad_q = grad_k = grad_v = grad_bias_1 = grad_bias_2 = None

        # Must be computed before the kernel call below, which overwrites
        # `attention_logits` in place.
        grad_v = torch.matmul(
            attention_logits.transpose(-1, -2),
            grad_output
        )

        # NOTE(review): after this call `attention_logits` presumably holds
        # the gradient with respect to the pre-softmax logits -- confirm
        # against the kernel.
        attn_core_inplace_cuda.backward_(
            attention_logits,
            grad_output.contiguous(),
            v.contiguous(), # v is implicitly transposed in the kernel
            reduce(mul, attention_logits.shape[:-1]),
            attention_logits.shape[-1],
            grad_output.shape[-1],
        )

        # Each bias gradient is the logit gradient summed over the
        # dimensions where the bias had size 1.
        if(ctx.bias_1_shape is not None):
            grad_bias_1 = torch.sum(
                attention_logits,
                dim=tuple(i for i,d in enumerate(ctx.bias_1_shape) if d == 1),
                keepdim=True,
            )

        if(ctx.bias_2_shape is not None):
            grad_bias_2 = torch.sum(
                attention_logits,
                dim=tuple(i for i,d in enumerate(ctx.bias_2_shape) if d == 1),
                keepdim=True,
            )

        grad_q = torch.matmul(
            attention_logits, k
        )
        grad_k = torch.matmul(
            q.transpose(-1, -2), attention_logits,
        ).transpose(-1, -2)

        return grad_q, grad_k, grad_v, grad_bias_1, grad_bias_2
// Forward declarations of the two kernels exposed to Python below. Their
// definitions live in another translation unit (presumably the CUDA source
// or the CPU stub -- confirm against the build configuration).

// In-place forward pass over `input`, described to the kernel as a
// `rows` x `cols` matrix.
void attn_softmax_inplace_forward_(
    at::Tensor input,
    long long rows, int cols
);
// In-place backward pass; `cols_output` and `cols_values` give the
// trailing sizes of `output` and `values` respectively.
void attn_softmax_inplace_backward_(
    at::Tensor output,
    at::Tensor d_ov,
    at::Tensor values,
    long long rows,
    int cols_output,
    int cols_values
);


// Python bindings: the extension module exposes the two functions above
// as `forward_` and `backward_`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward_",
        &attn_softmax_inplace_forward_,
        "Softmax forward (CUDA)"
    );
    m.def(
        "backward_",
        &attn_softmax_inplace_backward_,
        "Softmax backward (CUDA)"
    );
}
14 | 15 | // modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp 16 | 17 | #include 18 | 19 | void attn_softmax_inplace_forward_( 20 | at::Tensor input, 21 | long long rows, int cols 22 | ) 23 | { 24 | throw std::runtime_error("attn_softmax_inplace_forward_ not implemented on CPU"); 25 | }; 26 | void attn_softmax_inplace_backward_( 27 | at::Tensor output, 28 | at::Tensor d_ov, 29 | at::Tensor values, 30 | long long rows, 31 | int cols_output, 32 | int cols_values 33 | ) 34 | { 35 | throw std::runtime_error("attn_softmax_inplace_backward_ not implemented on CPU"); 36 | }; -------------------------------------------------------------------------------- /openfold/openfold/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def is_main_process():
    """Return True when LOCAL_RANK is 0, or is not set at all."""
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    return local_rank == 0
class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """ Implements the learning rate schedule defined in the AlphaFold 2
        supplement. A linear warmup is followed by a plateau at the maximum
        learning rate and then exponential decay.

        Note that the initial learning rate of the optimizer in question is
        ignored; use this class' base_lr parameter to specify the starting
        point of the warmup.

        The schedule as a function of the step number `s`:

        * `s < warmup_no_steps`: linear interpolation from `base_lr`
          (at step 0) to `max_lr` (at step `warmup_no_steps`)
        * `warmup_no_steps <= s <= start_decay_after_n_steps`: `max_lr`
        * afterwards: `max_lr * decay_factor ** e`, where `e` starts at 1
          and grows by one every `decay_every_n_steps` steps
    """
    def __init__(self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        base_lr: float = 0.,
        max_lr: float = 0.001,
        warmup_no_steps: int = 1000,
        start_decay_after_n_steps: int = 50000,
        decay_every_n_steps: int = 50000,
        decay_factor: float = 0.95,
    ):
        """
        Raises:
            ValueError: If a step count is negative, or the warmup is
                longer than `start_decay_after_n_steps`.
        """
        step_counts = {
            "warmup_no_steps": warmup_no_steps,
            "start_decay_after_n_steps": start_decay_after_n_steps,
        }

        for k, v in step_counts.items():
            if v < 0:
                raise ValueError(f"{k} must be nonnegative")

        if warmup_no_steps > start_decay_after_n_steps:
            raise ValueError(
                "warmup_no_steps must not exceed start_decay_after_n_steps"
            )

        # These must be set before the parent constructor runs: it takes an
        # initial step, which calls get_lr().
        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.verbose = verbose
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_no_steps = warmup_no_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.decay_every_n_steps = decay_every_n_steps
        self.decay_factor = decay_factor

        # `verbose` is deprecated in recent PyTorch releases, so it is only
        # forwarded when the caller actually enabled it. The parent's own
        # default already corresponds to verbose=False.
        parent_kwargs = {"last_epoch": last_epoch}
        if verbose:
            parent_kwargs["verbose"] = verbose
        super(AlphaFoldLRScheduler, self).__init__(optimizer, **parent_kwargs)

    def state_dict(self):
        """Return every attribute except the optimizer."""
        return {k: v for k, v in self.__dict__.items() if k != "optimizer"}

    def load_state_dict(self, state_dict):
        """Restore attributes from a dict produced by `state_dict`."""
        self.__dict__.update(state_dict)

    def get_lr(self):
        """Compute the learning rate for the current step.

        Returns:
            A list with one (identical) learning rate per parameter group.

        Raises:
            RuntimeError: If called outside of `step()`.
        """
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        if step_no < self.warmup_no_steps:
            # Strict comparison: with warmup_no_steps == 0 this branch is
            # never taken, so the division below cannot be by zero. At
            # step_no == warmup_no_steps the plateau branch yields max_lr,
            # which is also where the interpolation ends.
            progress = step_no / self.warmup_no_steps
            # Interpolate towards max_lr rather than adding max_lr on top
            # of base_lr, which would overshoot the plateau whenever
            # base_lr is nonzero. Unchanged for the default base_lr of 0.
            lr = self.base_lr + progress * (self.max_lr - self.base_lr)
        elif step_no > self.start_decay_after_n_steps:
            steps_since_decay = step_no - self.start_decay_after_n_steps
            exp = (steps_since_decay // self.decay_every_n_steps) + 1
            lr = self.max_lr * (self.decay_factor ** exp)
        else:  # plateau
            lr = self.max_lr

        return [lr for _ in self.optimizer.param_groups]
def is_fp16_enabled():
    """Report whether autocast is enabled with float16 as the GPU dtype."""
    # Autocast world: the dtype is checked first, and the enabled flag is
    # only consulted when the autocast GPU dtype is float16.
    return (
        torch.get_autocast_gpu_dtype() == torch.float16
        and torch.is_autocast_enabled()
    )
def _superimpose_np(reference, coords):
    """
        Aligns coordinates to a reference with Bio's SVDSuperimposer.

        Args:
            reference:
                [N, 3] reference array
            coords:
                [N, 3] array to be moved onto the reference
        Returns:
            A tuple of the [N, 3] transformed coords and the RMSD reported
            by the superimposer.
    """
    aligner = SVDSuperimposer()
    aligner.set(reference, coords)
    aligner.run()
    transformed = aligner.get_transformed()
    rms = aligner.get_rms()
    return transformed, rms
class SuppressStdout:
    """Context manager that sends `sys.stdout` to the null device.

    On entry the current `sys.stdout` is saved and replaced by a handle to
    the platform's null device; on exit the saved stream is restored and the
    null-device handle is closed. Exceptions raised inside the block are not
    suppressed.
    """

    def __enter__(self):
        # Imported locally because this module only imports logging and sys
        # at the top level.
        import os

        self.stdout = sys.stdout
        # os.devnull instead of a hard-coded "/dev/null" so this also works
        # on platforms without that path.
        self._dev_null = open(os.devnull, "w")
        sys.stdout = self._dev_null

    def __exit__(self, typ, value, traceback):
        # Close the handle opened in __enter__, not whatever sys.stdout is
        # now: the guarded code may have rebound sys.stdout, and closing
        # that would hit the wrong stream and leak the null handle.
        sys.stdout = self.stdout
        self._dev_null.close()
def masked_mean(mask, value, dim, eps=1e-4):
    """Mean of `value` over `dim`, weighted by `mask`.

    Args:
        mask: Weights, expanded to `value`'s shape before use.
        value: Tensor to average.
        dim: Dimension(s) to reduce over.
        eps: Added to the summed mask so an all-zero mask does not divide
            by zero.

    Returns:
        sum(mask * value) / (eps + sum(mask)), both sums taken over `dim`.
    """
    weights = mask.expand(*value.shape)
    weighted_total = torch.sum(weights * value, dim=dim)
    weight_total = torch.sum(weights, dim=dim)
    return weighted_total / (eps + weight_total)
def batched_gather(data, inds, dim=0, no_batch_dims=0):
    """Gather entries of `data` along `dim`, independently per batch element.

    The first `no_batch_dims` dimensions of `data` are treated as batch
    dimensions and indexed with broadcast `arange` tensors; dimension `dim`
    is indexed with `inds`; all other dimensions are kept whole.

    Args:
        data: Tensor to gather from.
        inds: Integer index tensor applied along `dim`. Its leading
            dimensions are expected to line up with the batch dimensions
            of `data`.
        dim: Dimension of `data` to index. Negative values count from the
            end.
        no_batch_dims: Number of leading batch dimensions.

    Returns:
        The result of advanced indexing `data` as described above.
    """
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        r = torch.arange(s)
        # Shape the range so that it is the only non-singleton axis, at
        # position i, and broadcasts against `inds`.
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    remaining_dims = [
        slice(None) for _ in range(len(data.shape) - no_batch_dims)
    ]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple: PyTorch deprecates multidimensional indexing with
    # a non-tuple sequence and will interpret a list as a single tensor
    # index.
    return data[tuple(ranges)]
"License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | 16 | 17 | def drmsd(structure_1, structure_2, mask=None): 18 | def prep_d(structure): 19 | d = structure[..., :, None, :] - structure[..., None, :, :] 20 | d = d ** 2 21 | d = torch.sqrt(torch.sum(d, dim=-1)) 22 | return d 23 | 24 | d1 = prep_d(structure_1) 25 | d2 = prep_d(structure_2) 26 | 27 | drmsd = d1 - d2 28 | drmsd = drmsd ** 2 29 | if(mask is not None): 30 | drmsd = drmsd * (mask[..., None] * mask[..., None, :]) 31 | drmsd = torch.sum(drmsd, dim=(-1, -2)) 32 | n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1) 33 | drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.) 
34 | drmsd = torch.sqrt(drmsd) 35 | 36 | return drmsd 37 | 38 | 39 | def drmsd_np(structure_1, structure_2, mask=None): 40 | structure_1 = torch.tensor(structure_1) 41 | structure_2 = torch.tensor(structure_2) 42 | if(mask is not None): 43 | mask = torch.tensor(mask) 44 | 45 | return drmsd(structure_1, structure_2, mask) 46 | 47 | 48 | def gdt(p1, p2, mask, cutoffs): 49 | n = torch.sum(mask, dim=-1) 50 | 51 | p1 = p1.float() 52 | p2 = p2.float() 53 | distances = torch.sqrt(torch.sum((p1 - p2)**2, dim=-1)) 54 | scores = [] 55 | for c in cutoffs: 56 | score = torch.sum((distances <= c) * mask, dim=-1) / n 57 | score = torch.mean(score) 58 | scores.append(score) 59 | 60 | return sum(scores) / len(scores) 61 | 62 | 63 | def gdt_ts(p1, p2, mask): 64 | return gdt(p1, p2, mask, [1., 2., 4., 8.]) 65 | 66 | 67 | def gdt_ha(p1, p2, mask): 68 | return gdt(p1, p2, mask, [0.5, 1., 2., 4.]) 69 | 70 | -------------------------------------------------------------------------------- /openfold/scripts/activate_conda_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source scripts/vars.sh 4 | 5 | source lib/conda/etc/profile.d/conda.sh 6 | conda activate $ENV_NAME 7 | -------------------------------------------------------------------------------- /openfold/scripts/alignment_db_scripts/create_alignment_db.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | 6 | def main(args): 7 | db_path = os.path.join(args.output_db_path, f"{args.output_db_name}.db") 8 | index_path = os.path.join( 9 | args.output_db_path, f"{args.output_db_name}.index" 10 | ) 11 | db_fp = open(db_path, "wb") 12 | index = {} 13 | db_offset = 0 14 | for chain_alignment_dir in os.listdir(args.alignment_dir): 15 | cad_path = os.path.join(args.alignment_dir, chain_alignment_dir) 16 | for f in os.listdir(cad_path): 17 | f_path = os.path.join(cad_path, f) 18 | with 
open(f_path, "rb") as fp: 19 | file_bytes = fp.read() 20 | 21 | l = len(file_bytes) 22 | file_list = index.setdefault(chain_alignment_dir, []) 23 | file_list.append((f, db_offset, l)) 24 | 25 | db_fp.write(file_bytes) 26 | db_offset += l 27 | 28 | db_fp.close() 29 | 30 | with open(index_path, "w") as fp: 31 | json.dump(index, fp) 32 | 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument( 38 | "alignment_dir", type=str, 39 | help="""Path to precomputed alignment directory, with one subdirectory 40 | per chain.""" 41 | ) 42 | parser.add_argument("output_db_path", type=str) 43 | parser.add_argument("output_db_name", type=str) 44 | 45 | args = parser.parse_args() 46 | 47 | main(args) 48 | -------------------------------------------------------------------------------- /openfold/scripts/alignment_db_scripts/unify_alignment_db_indices.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | 6 | """ Unifies databases created with create_alignment_db.py """ 7 | 8 | 9 | def main(args): 10 | super_index = {} 11 | for f in os.listdir(args.alignment_db_dir): 12 | if(not os.path.splitext(f)[-1] == ".index"): 13 | continue 14 | 15 | with open(os.path.join(args.alignment_db_dir, f), "r") as fp: 16 | index = json.load(fp) 17 | 18 | db_name = f"{os.path.splitext(f)[0]}.db" 19 | 20 | for k in index: 21 | super_index[k] = { 22 | "db": db_name, 23 | "files": index[k], 24 | } 25 | 26 | with open(os.path.join(args.output_dir, "super.index"), "w") as fp: 27 | json.dump(super_index, fp) 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("alignment_db_dir", type=str, help="Path to directory containing alignment_dbs") 33 | parser.add_argument("output_dir", type=str, help="Path in which to output super index") 34 | 35 | args = parser.parse_args() 36 | 37 | main(args) 38 | 
-------------------------------------------------------------------------------- /openfold/scripts/colabfold_search.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copied from colabfold.mmseqs.com 3 | 4 | MMSEQS="$1" 5 | QUERY="$2" 6 | DBBASE="$3" 7 | BASE="$4" 8 | DB1="$5" 9 | DB2="$6" 10 | DB3="$7" 11 | USE_ENV="${8:-1}" 12 | USE_TEMPLATES="${9:-0}" 13 | FILTER="${10:-1}" 14 | INDEX=${11:-1} 15 | DB_LOAD_MODE="${12:-2}" 16 | EXPAND_EVAL=inf 17 | ALIGN_EVAL=10 18 | DIFF=3000 19 | QSC=-20.0 20 | MAX_ACCEPT=1000000 21 | if [ "${FILTER}" = "1" ]; then 22 | # 0.1 was not used in benchmarks due to POSIX shell bug in line above 23 | # EXPAND_EVAL=0.1 24 | ALIGN_EVAL=10 25 | QSC=0.8 26 | MAX_ACCEPT=100000 27 | fi 28 | if [ "${INDEX}" = "1" ]; then 29 | SEQ=".idx" 30 | ALN=".idx" 31 | IDX=".idx" 32 | else 33 | SEQ="_seq" 34 | ALN="_aln" 35 | IDX="" 36 | export MMSEQS_IGNORE_INDEX=1 37 | fi 38 | export MMSEQS_CALL_DEPTH=1 39 | SEARCH_PARAM="--num-iterations 3 --db-load-mode ${DB_LOAD_MODE} -a -s 8 -e 0.1 --max-seqs 10000" 40 | FILTER_PARAM="--filter-msa ${FILTER} --filter-min-enable 1000 --diff ${DIFF} --qid 0.0,0.2,0.4,0.6,0.8,1.0 --qsc 0 --max-seq-id 0.95" 41 | EXPAND_PARAM="--expansion-mode 0 -e ${EXPAND_EVAL} --expand-filter-clusters ${FILTER} --max-seq-id 0.95" 42 | mkdir -p "${BASE}" 43 | "${MMSEQS}" createdb "${QUERY}" "${BASE}/qdb" 44 | "${MMSEQS}" search "${BASE}/qdb" "${DBBASE}/${DB1}" "${BASE}/res" "${BASE}/tmp" $SEARCH_PARAM 45 | "${MMSEQS}" expandaln "${BASE}/qdb" "${DBBASE}/${DB1}${SEQ}" "${BASE}/res" "${DBBASE}/${DB1}${ALN}" "${BASE}/res_exp" --db-load-mode ${DB_LOAD_MODE} ${EXPAND_PARAM} 46 | "${MMSEQS}" mvdb "${BASE}/tmp/latest/profile_1" "${BASE}/prof_res" 47 | "${MMSEQS}" lndb "${BASE}/qdb_h" "${BASE}/prof_res_h" 48 | "${MMSEQS}" align "${BASE}/prof_res" "${DBBASE}/${DB1}${SEQ}" "${BASE}/res_exp" "${BASE}/res_exp_realign" --db-load-mode ${DB_LOAD_MODE} -e ${ALIGN_EVAL} --max-accept 
${MAX_ACCEPT} --alt-ali 10 -a 49 | "${MMSEQS}" filterresult "${BASE}/qdb" "${DBBASE}/${DB1}${SEQ}" "${BASE}/res_exp_realign" "${BASE}/res_exp_realign_filter" --db-load-mode ${DB_LOAD_MODE} --qid 0 --qsc $QSC --diff 0 --max-seq-id 1.0 --filter-min-enable 100 50 | "${MMSEQS}" result2msa "${BASE}/qdb" "${DBBASE}/${DB1}${SEQ}" "${BASE}/res_exp_realign_filter" "${BASE}/uniref.a3m" --msa-format-mode 6 --db-load-mode ${DB_LOAD_MODE} ${FILTER_PARAM} 51 | "${MMSEQS}" rmdb "${BASE}/res_exp_realign" 52 | "${MMSEQS}" rmdb "${BASE}/res_exp" 53 | "${MMSEQS}" rmdb "${BASE}/res" 54 | "${MMSEQS}" rmdb "${BASE}/res_exp_realign_filter" 55 | if [ "${USE_TEMPLATES}" = "1" ]; then 56 | "${MMSEQS}" search "${BASE}/prof_res" "${DBBASE}/${DB2}" "${BASE}/res_pdb" "${BASE}/tmp" --db-load-mode ${DB_LOAD_MODE} -s 7.5 -a -e 0.1 57 | "${MMSEQS}" convertalis "${BASE}/prof_res" "${DBBASE}/${DB2}${IDX}" "${BASE}/res_pdb" "${BASE}/${DB2}.m8" --format-output query,target,fident,alnlen,mismatch,gapopen,qstart,qend,tstart,tend,evalue,bits,cigar --db-load-mode ${DB_LOAD_MODE} 58 | "${MMSEQS}" rmdb "${BASE}/res_pdb" 59 | fi 60 | if [ "${USE_ENV}" = "1" ]; then 61 | "${MMSEQS}" search "${BASE}/prof_res" "${DBBASE}/${DB3}" "${BASE}/res_env" "${BASE}/tmp" $SEARCH_PARAM 62 | "${MMSEQS}" expandaln "${BASE}/prof_res" "${DBBASE}/${DB3}${SEQ}" "${BASE}/res_env" "${DBBASE}/${DB3}${ALN}" "${BASE}/res_env_exp" -e ${EXPAND_EVAL} --expansion-mode 0 --db-load-mode ${DB_LOAD_MODE} 63 | "${MMSEQS}" align "${BASE}/tmp/latest/profile_1" "${DBBASE}/${DB3}${SEQ}" "${BASE}/res_env_exp" "${BASE}/res_env_exp_realign" --db-load-mode ${DB_LOAD_MODE} -e ${ALIGN_EVAL} --max-accept ${MAX_ACCEPT} --alt-ali 10 -a 64 | "${MMSEQS}" filterresult "${BASE}/qdb" "${DBBASE}/${DB3}${SEQ}" "${BASE}/res_env_exp_realign" "${BASE}/res_env_exp_realign_filter" --db-load-mode ${DB_LOAD_MODE} --qid 0 --qsc $QSC --diff 0 --max-seq-id 1.0 --filter-min-enable 100 65 | "${MMSEQS}" result2msa "${BASE}/qdb" "${DBBASE}/${DB3}${SEQ}" 
"${BASE}/res_env_exp_realign_filter" "${BASE}/bfd.mgnify30.metaeuk30.smag30.a3m" --msa-format-mode 6 --db-load-mode ${DB_LOAD_MODE} ${FILTER_PARAM} 66 | "${MMSEQS}" rmdb "${BASE}/res_env_exp_realign_filter" 67 | "${MMSEQS}" rmdb "${BASE}/res_env_exp_realign" 68 | "${MMSEQS}" rmdb "${BASE}/res_env_exp" 69 | "${MMSEQS}" rmdb "${BASE}/res_env" 70 | fi 71 | "${MMSEQS}" rmdb "${BASE}/qdb" 72 | "${MMSEQS}" rmdb "${BASE}/qdb_h" 73 | "${MMSEQS}" rmdb "${BASE}/res" 74 | rm -f -- "${BASE}/prof_res"* 75 | rm -rf -- "${BASE}/tmp" 76 | -------------------------------------------------------------------------------- /openfold/scripts/convert_of_weights_to_jax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be 16 | # used to run inference using DeepMind's JAX code. 
import argparse

import numpy as np
import torch

from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import (
    Param,
    ParamType,
    generate_translation_dict,
    process_translation_dict,
)
from openfold.utils.tensor_utils import tree_map


def reshape_fn(of_param, af_weight):
    """Convert one OpenFold parameter into the layout of ``af_weight``.

    The transformation is selected by ``of_param.param_type``; the MHA/OPM
    variants additionally reshape to ``af_weight.shape``. Stacked parameters
    are first joined with ``torch.stack``.
    """
    transformations = {
        ParamType.LinearWeight: lambda w: w.transpose(-1, -2),
        ParamType.LinearWeightMHA: lambda w: w.transpose(-1, -2).reshape(af_weight.shape),
        ParamType.LinearMHAOutputWeight: lambda w: w.transpose(-1, -2).reshape(af_weight.shape),
        ParamType.LinearBiasMHA: lambda w: w.reshape(af_weight.shape),
        ParamType.LinearWeightOPM: lambda w: w.transpose(-1, -2).reshape(af_weight.shape),
        ParamType.Other: lambda w: w,
    }

    if of_param.stacked:
        of_weight = torch.stack([torch.Tensor(p) for p in of_param.param])
    else:
        of_weight = torch.Tensor(of_param.param)

    return transformations[of_param.param_type](of_weight)


def transfer(of_dict, af_weight_template):
    """Copy the weights in ``of_dict`` into ``af_weight_template`` in place.

    Nested dicts are walked recursively; each leaf is reshaped with
    ``reshape_fn`` and written into the matching template array.
    """
    for k, value in of_dict.items():
        if isinstance(value, dict):
            transfer(value, af_weight_template[k])
        else:
            reshaped = reshape_fn(value, af_weight_template[k])
            reshaped = reshaped.detach().numpy()
            np.copyto(af_weight_template[k], reshaped)


def main(args):
    """Load an OpenFold checkpoint and save it as an AlphaFold-style ``.npz``."""
    # The conversion runs on NumPy arrays, so load on the CPU; this also
    # lets checkpoints saved from a GPU be converted on CPU-only machines.
    d = torch.load(args.of_pt_path, map_location="cpu")

    config = model_config(args.config_preset)
    model = AlphaFold(config)
    model.load_state_dict(d)

    translation = generate_translation_dict(model, args.config_preset)
    translation = process_translation_dict(translation)

    # Only the template entries that have an OpenFold counterpart are kept;
    # the archive is closed as soon as they have been read.
    with np.load(args.template_npz_path) as npz:
        af_weight_template = {
            k: v for k, v in npz.items() if k in translation
        }
    af_weight_template = tree_map(
        lambda arr: arr * 0, af_weight_template, np.ndarray
    )

    transfer(translation, af_weight_template)

    np.savez(args.out_path, **af_weight_template)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "of_pt_path", type=str, help="Path to OpenFold .pt checkpoint file"
    )
    parser.add_argument(
        "config_preset", type=str, help="The corresponding config preset"
    )
    parser.add_argument(
        "out_path", type=str, help="Path for output .npz file"
    )
    parser.add_argument(
        "--template_npz_path",
        type=str,
        default="openfold/resources/params/params_model_1_ptm.npz",
        help="""Path to an AlphaFold checkpoint w/ a superset of the OF
                checkpoint's parameters. params_model_1_ptm.npz always works.
             """
    )

    args = parser.parse_args()

    main(args)

# NOTE: this span is part of a multi-file dump. File separators and embedded
# shell scripts are kept behind "# " so that the span remains valid Python.
# --------------------------------------------------------------------------------
# /openfold/scripts/data_dir_to_fasta.py:
# --------------------------------------------------------------------------------
import argparse
import logging
import os

from openfold.data import mmcif_parsing
from openfold.np import protein, residue_constants


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError(
        f"Expected a boolean value, got {value!r}"
    )


def main(args):
    """Write the sequences of all ``.cif``/``.core`` files in a directory as FASTA.

    mmCIF files contribute one record per chain (``<BASENAME>_<chain>``);
    ``.core`` files contribute a single record named after the file.
    Files with any other extension are ignored. Unparseable mmCIF files are
    skipped with a warning unless ``args.raise_errors`` is set.
    """
    fasta = []
    for fname in os.listdir(args.data_dir):
        basename, ext = os.path.splitext(fname)
        basename = basename.upper()
        fpath = os.path.join(args.data_dir, fname)
        if ext == ".cif":
            with open(fpath, 'r') as fp:
                mmcif_str = fp.read()

            mmcif = mmcif_parsing.parse(
                file_id=basename, mmcif_string=mmcif_str
            )
            if mmcif.mmcif_object is None:
                logging.warning(f'Failed to parse {fname}...')
                if args.raise_errors:
                    raise list(mmcif.errors.values())[0]
                continue

            mmcif = mmcif.mmcif_object
            for chain, seq in mmcif.chain_to_seqres.items():
                chain_id = '_'.join([basename, chain])
                fasta.append(f">{chain_id}")
                fasta.append(seq)
        elif ext == ".core":
            with open(fpath, 'r') as fp:
                core_str = fp.read()

            core_protein = protein.from_proteinnet_string(core_str)
            seq = ''.join(
                residue_constants.restypes_with_x[aa]
                for aa in core_protein.aatype
            )
            fasta.append(f">{basename}")
            fasta.append(seq)

    with open(args.output_path, "w") as fp:
        fp.write('\n'.join(fasta))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "data_dir", type=str,
        help="Path to a directory containing mmCIF or .core files"
    )
    parser.add_argument(
        "output_path", type=str,
        help="Path to output FASTA file"
    )
    # ``type=bool`` would turn every non-empty string (including "False")
    # into True; parse the text explicitly and also accept the bare flag.
    parser.add_argument(
        "--raise_errors", type=_str2bool, nargs="?", const=True,
        default=False,
        help="Whether to crash on parsing errors"
    )

    args = parser.parse_args()

    main(args)

# --------------------------------------------------------------------------------
# /openfold/scripts/deactivate_conda_env.sh:
# --------------------------------------------------------------------------------
# (shell script, preserved verbatim behind "# ")
# #!/bin/bash
#
# conda deactivate
#
# --------------------------------------------------------------------------------
# /openfold/scripts/download_alphafold_dbs.sh:
# --------------------------------------------------------------------------------
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips all required data for AlphaFold. 18 | # 19 | # Usage: bash download_all_data.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | DOWNLOAD_MODE="${2:-full_dbs}" # Default mode to full_dbs. 34 | if [[ "${DOWNLOAD_MODE}" != full_dbs && "${DOWNLOAD_MODE}" != reduced_dbs ]] 35 | then 36 | echo "DOWNLOAD_MODE ${DOWNLOAD_MODE} not recognized." 37 | exit 1 38 | fi 39 | 40 | SCRIPT_DIR="$(dirname "$(realpath "$0")")" 41 | 42 | if [[ "${DOWNLOAD_MODE}" = full_dbs ]] ; then 43 | echo "Downloading BFD..." 44 | bash "${SCRIPT_DIR}/download_bfd.sh" "${DOWNLOAD_DIR}" 45 | else 46 | echo "Downloading Small BFD..." 47 | bash "${SCRIPT_DIR}/download_small_bfd.sh" "${DOWNLOAD_DIR}" 48 | fi 49 | 50 | echo "Downloading MGnify..." 51 | bash "${SCRIPT_DIR}/download_mgnify.sh" "${DOWNLOAD_DIR}" 52 | 53 | echo "Downloading PDB70..." 54 | bash "${SCRIPT_DIR}/download_pdb70.sh" "${DOWNLOAD_DIR}" 55 | 56 | echo "Downloading PDB mmCIF files..." 57 | bash "${SCRIPT_DIR}/download_pdb_mmcif.sh" "${DOWNLOAD_DIR}" 58 | 59 | echo "Downloading Uniclust30..." 60 | bash "${SCRIPT_DIR}/download_uniclust30.sh" "${DOWNLOAD_DIR}" 61 | 62 | echo "Downloading Uniref90..." 63 | bash "${SCRIPT_DIR}/download_uniref90.sh" "${DOWNLOAD_DIR}" 64 | 65 | echo "All data downloaded." 
66 | -------------------------------------------------------------------------------- /openfold/scripts/download_alphafold_params.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips the AlphaFold parameters. 18 | # 19 | # Usage: bash download_alphafold_params.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 
29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | ROOT_DIR="${DOWNLOAD_DIR}/params" 34 | SOURCE_URL="https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar" 35 | BASENAME=$(basename "${SOURCE_URL}") 36 | 37 | mkdir --parents "${ROOT_DIR}" 38 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" 39 | tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \ 40 | --directory="${ROOT_DIR}" --preserve-permissions 41 | rm "${ROOT_DIR}/${BASENAME}" 42 | -------------------------------------------------------------------------------- /openfold/scripts/download_bfd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips the BFD database for AlphaFold. 18 | # 19 | # Usage: bash download_bfd.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | ROOT_DIR="${DOWNLOAD_DIR}/bfd" 34 | # Mirror of: 35 | # https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz. 
# NOTE: this span is part of a multi-file dump. The lines down to the first
# separator are the final commands of download_bfd.sh; they and the file
# separators are kept behind "# " so that the span remains valid Python.
# SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz"
# BASENAME=$(basename "${SOURCE_URL}")
#
# mkdir --parents "${ROOT_DIR}"
# aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}"
# tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \
#   --directory="${ROOT_DIR}"
# rm "${ROOT_DIR}/${BASENAME}"
#
# --------------------------------------------------------------------------------
# /openfold/scripts/download_cameo.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import json
import os
import re
import requests

from openfold.data import mmcif_parsing


VALID_PERIODS = [
    "1-year",
    "6-months",
    "3-months",
    "1-month",
    "1-week",
]

# Seconds to wait for a server before giving up on a request.
_REQUEST_TIMEOUT = 60


def generate_url(period, end_date):
    """Build the CAMEO target-list URL for ``period`` ending at ``end_date``."""
    # The host part carries no trailing slash: joining with "/" would
    # otherwise produce "https://www.cameo3d.org//modeling/...".
    return '/'.join([
        "https://www.cameo3d.org",
        "modeling",
        "targets",
        period,
        "ajax",
        f"?to_date={end_date}",
    ])


def _fetch_text(url):
    """GET ``url`` and return the body, raising on HTTP errors or timeouts."""
    response = requests.get(url, timeout=_REQUEST_TIMEOUT)
    response.raise_for_status()
    return response.text


def main(args):
    """Download the CAMEO targets of a period as mmCIF and FASTA files.

    mmCIF files go to ``<output_dir>/data_dir`` and one FASTA file per chain
    to ``<output_dir>/fasta_dir``. Chains longer than ``args.max_seqlen`` are
    skipped when that limit is positive.
    """
    data_dir_path = os.path.join(args.output_dir, "data_dir")
    fasta_dir_path = os.path.join(args.output_dir, "fasta_dir")

    os.makedirs(data_dir_path, exist_ok=True)
    os.makedirs(fasta_dir_path, exist_ok=True)

    url = generate_url(args.period, args.end_date)
    parsed_data = json.loads(_fetch_text(url))

    chain_data = parsed_data["aaData"]
    for chain in chain_data:
        pdb_id = chain["pdbid"]
        chain_id = chain["pdbid_chain"]

        pdb_url = f"https://files.rcsb.org/view/{pdb_id.upper()}.cif"
        pdb_file = _fetch_text(pdb_url)

        parsed_cif = mmcif_parsing.parse(
            file_id=pdb_id, mmcif_string=pdb_file
        )
        mmcif_object = parsed_cif.mmcif_object
        if mmcif_object is None:
            raise list(parsed_cif.errors.values())[0]

        seq = mmcif_object.chain_to_seqres[chain_id]

        if args.max_seqlen > 0 and len(seq) > args.max_seqlen:
            continue

        fasta_file = '\n'.join([
            f">{pdb_id}_{chain_id}",
            seq,
        ])

        fasta_filename = f"{pdb_id}_{chain_id}.fasta"
        with open(os.path.join(fasta_dir_path, fasta_filename), "w") as fp:
            fp.write(fasta_file)

        cif_filename = f"{pdb_id}.cif"
        with open(os.path.join(data_dir_path, cif_filename), "w") as fp:
            fp.write(pdb_file)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "period", type=str,
        help=f"""The length of the period from which to draw CAMEO proteins.
             Choose from {VALID_PERIODS}"""
    )
    parser.add_argument(
        "end_date", type=str,
        help="The date marking the end of the period (YYYY-MM-DD)"
    )
    parser.add_argument("output_dir")
    parser.add_argument(
        "--max_seqlen", type=int, default=700,
        help="The maximum length in residues of downloaded proteins (or -1)"
    )

    args = parser.parse_args()

    if args.period not in VALID_PERIODS:
        raise ValueError(f"Invalid period. Choose from {VALID_PERIODS}")

    date_regex = re.compile("^[0-9]{4}-[0-9]{2}-[0-9]{2}$")
    if not date_regex.match(args.end_date):
        raise ValueError(f"Invalid end_date: {args.end_date}. Use YYYY-MM-DD format")

    main(args)

# --------------------------------------------------------------------------------
# /openfold/scripts/download_colabfold_envdb.sh:
# --------------------------------------------------------------------------------
#!/bin/bash
#
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads the ColabFold environmental database archive (not extracted here). 18 | # 19 | # Usage: bash download_colabfold_envdb.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | ROOT_DIR="${DOWNLOAD_DIR}" 34 | SOURCE_URL="http://wwwuser.gwdg.de/~compbiol/colabfold/colabfold_envdb_202108.tar.gz" 35 | BASENAME=$(basename "${SOURCE_URL}") 36 | 37 | mkdir --parents "${ROOT_DIR}" 38 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" -x 4 --check-certificate=false 39 | -------------------------------------------------------------------------------- /openfold/scripts/download_mgnify.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips the MGnify database for AlphaFold. 18 | # 19 | # Usage: bash download_mgnify.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | ROOT_DIR="${DOWNLOAD_DIR}/mgnify" 34 | # Mirror of: 35 | # ftp://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2018_12/mgy_clusters.fa.gz 36 | SOURCE_URL="https://storage.googleapis.com/alphafold-databases/casp14_versions/mgy_clusters_2018_12.fa.gz" 37 | BASENAME=$(basename "${SOURCE_URL}") 38 | 39 | mkdir --parents "${ROOT_DIR}" 40 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" 41 | gunzip "${ROOT_DIR}/${BASENAME}" 42 | -------------------------------------------------------------------------------- /openfold/scripts/download_mmseqs_dbs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 AlQuraishi Laboratory 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips all required data for AlphaFold. 
18 | # 19 | # Usage: bash download_mmseqs_dbs.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | DOWNLOAD_MODE="${2:-full_dbs}" # Default mode to full_dbs. 34 | if [[ "${DOWNLOAD_MODE}" != full_dbs && "${DOWNLOAD_MODE}" != reduced_dbs ]] 35 | then 36 | echo "DOWNLOAD_MODE ${DOWNLOAD_MODE} not recognized." 37 | exit 1 38 | fi 39 | 40 | SCRIPT_DIR="$(dirname "$(realpath "$0")")" 41 | 42 | echo "Downloading Uniref30..." 43 | bash "${SCRIPT_DIR}/download_uniref30.sh" "${DOWNLOAD_DIR}" 44 | 45 | echo "Downloading ColabFold's environmental database..." 46 | bash "${SCRIPT_DIR}/download_colabfold_envdb.sh" "${DOWNLOAD_DIR}" 47 | 48 | echo "All data downloaded." 49 | -------------------------------------------------------------------------------- /openfold/scripts/download_openfold_params.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads OpenFold parameters.
#
# Usage: bash download_openfold_params.sh /path/to/download/directory
set -e

if [[ $# -eq 0 ]]; then
    echo "Error: download directory must be provided as an input argument."
    exit 1
fi

if ! command -v aws &> /dev/null ; then
    echo "Error: aws could not be found. Please install aws."
    exit 1
fi

DOWNLOAD_DIR="${1}/openfold_params"
mkdir -p "${DOWNLOAD_DIR}"
aws s3 cp --no-sign-request --region us-east-1 s3://openfold/openfold_params/ "${DOWNLOAD_DIR}" --recursive

# NOTE: this span is part of a multi-file dump; file separators are kept as
# comments so that the span remains valid shell.
# --------------------------------------------------------------------------------
# /openfold/scripts/download_openfold_params_gdrive.sh:
# --------------------------------------------------------------------------------
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads and unzips OpenFold parameters from Google Drive. Alternative to
# the HuggingFace version.
#
# Usage: bash download_openfold_params_gdrive.sh /path/to/download/directory
set -e

if [[ $# -eq 0 ]]; then
    echo "Error: download directory must be provided as an input argument."
    exit 1
fi

FILE_ID="1GVzZA2nbdBbz6TKydvzquhfELJ3Movnb"
FILENAME="openfold_params_07_22.tar.gz"

# download_from_gdrive FILE_ID OUT_DIR
#
# Fetches the Google Drive landing page for FILE_ID, extracts the confirmation
# token and the file name from it, downloads the file into OUT_DIR and prints
# the resulting path on stdout. The function is called inside $(...), so its
# variables never leak into the caller; they are declared local regardless.
download_from_gdrive() {
    local file_id out_dir msg confirm filename filepath
    file_id="$1"
    out_dir="$2"
    msg=$(wget \
        --quiet \
        --save-cookies /tmp/cookies_$$.txt \
        --keep-session-cookies \
        --no-check-certificate \
        "https://docs.google.com/uc?export=download&id=${file_id}" \
        -O- \
    )
    # $msg is left unquoted on purpose: word splitting collapses the page
    # onto a single line before sed sees it.
    confirm=$(echo $msg | sed -rn "s/.*confirm=([0-9A-Za-z_]+).*/\1\n/p")
    # NOTE(review): as written, the leading ".*" is greedy and leaves the
    # capture group empty; the pattern looks truncated -- confirm against the
    # upstream script before relying on the filename check below.
    filename=$(echo $msg | sed -e "s/.*\(.*\)<\/a> (.*/\1/")
    filepath="${out_dir}/${filename}"
    wget \
        --quiet \
        --load-cookies /tmp/cookies_$$.txt \
        "https://docs.google.com/uc?export=download&confirm=${confirm}&id=${file_id}" \
        -O "${filepath}"
    rm /tmp/cookies_$$.txt
    echo "${filepath}"
}

DOWNLOAD_DIR="$1"
mkdir -p "${DOWNLOAD_DIR}"
DOWNLOAD_PATH=$(download_from_gdrive "${FILE_ID}" "${DOWNLOAD_DIR}")

DOWNLOAD_FILENAME=$(basename "${DOWNLOAD_PATH}")
# The right-hand side is quoted so that it is compared literally instead of
# being interpreted as a glob pattern.
if [[ "${FILENAME}" != "${DOWNLOAD_FILENAME}" ]]; then
    echo "Error: Downloaded filename ${DOWNLOAD_FILENAME} does not match expected filename ${FILENAME}"
    rm -f -- "${DOWNLOAD_PATH}"
    # A bare "exit" would return the status of the rm above (0) and report
    # success to the caller.
    exit 1
fi

tar --extract --verbose --file="${DOWNLOAD_PATH}" \
    --directory="${DOWNLOAD_DIR}" --preserve-permissions
rm "${DOWNLOAD_PATH}"

# --------------------------------------------------------------------------------
# /openfold/scripts/download_openfold_params_huggingface.sh:
# --------------------------------------------------------------------------------
#!/bin/bash
#
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips OpenFold parameters. 18 | # 19 | # Usage: bash download_openfold_params_huggingface.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | URL="https://huggingface.co/nz/OpenFold" 28 | 29 | DOWNLOAD_DIR="${1}/openfold_params/" 30 | mkdir -p "${DOWNLOAD_DIR}" 31 | git clone $URL "${DOWNLOAD_DIR}" 32 | rm -rf "${DOWNLOAD_DIR}/.git" 33 | -------------------------------------------------------------------------------- /openfold/scripts/download_pdb70.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips the PDB70 database for AlphaFold. 
18 | # 19 | # Usage: bash download_pdb70.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | ROOT_DIR="${DOWNLOAD_DIR}/pdb70" 34 | SOURCE_URL="http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/old-releases/pdb70_from_mmcif_200401.tar.gz" 35 | BASENAME=$(basename "${SOURCE_URL}") 36 | 37 | mkdir --parents "${ROOT_DIR}" 38 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" --check-certificate=false 39 | tar --extract --verbose --file="${ROOT_DIR}/${BASENAME}" \ 40 | --directory="${ROOT_DIR}" 41 | rm "${ROOT_DIR}/${BASENAME}" 42 | -------------------------------------------------------------------------------- /openfold/scripts/download_pdb_mmcif.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads, unzips and flattens the PDB database for AlphaFold. 
#!/bin/bash
#
# Copyright 2021 AlQuraishi Laboratories
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads .cif files matching the RODA alignments. Prints, on stdout, the
# RODA alignment directories for which no .cif file could be found.
#
# Usage: ./download_roda_pdbs.sh <out_dir> <roda_alignment_dir>
#   out_dir:            must not exist yet; receives the flattened .cif files
#   roda_alignment_dir: directory whose immediate subdirectories are named
#                       <pdb_id>_<...>
if [[ $# != 2 ]]; then
    echo "usage: ./download_roda_pdbs.sh <out_dir> <roda_alignment_dir>" >&2
    exit 1
fi

OUT_DIR="$1"
RODA_ALIGNMENT_DIR="$2"

if [[ -d "${OUT_DIR}" ]]; then
    echo "${OUT_DIR} already exists. Download failed..." >&2
    exit 1
fi

SERVER=snapshotrsync.rcsb.org # RCSB server name
PORT=873                      # port RCSB server is using

# stdout of this script is the list of missing alignments, so rsync's verbose
# listing is discarded while its errors stay on stderr. (The previous
# `2>&1 > /dev/null` ordering routed rsync's stderr INTO our stdout.)
rsync -rlpt -v -z --delete --port="${PORT}" \
    "${SERVER}::20220103/pub/pdb/data/structures/divided/mmCIF/" \
    "${OUT_DIR}" > /dev/null

# Collect the full file list before moving anything so the tree is not
# modified while find is still walking it; NUL delimiters keep paths with
# whitespace intact.
mapfile -d '' -t DOWNLOADED < <(find "${OUT_DIR}" -mindepth 2 -type f -print0)
for f in "${DOWNLOADED[@]}"; do
    mv "${f}" "${OUT_DIR}"
    gunzip "${OUT_DIR}/$(basename "${f}")"
done

# Remove the now-empty two-letter subdirectories (and any symlinks).
find "${OUT_DIR}" -mindepth 1 \( -type d -o -type l \) -delete

mapfile -d '' -t ALIGNMENT_DIRS < <(
    find "${RODA_ALIGNMENT_DIR}" -mindepth 1 -maxdepth 1 -type d -print0
)
for d in "${ALIGNMENT_DIRS[@]}"; do
    BASENAME=$(basename "${d}")
    PDB_ID="${BASENAME%%_*}"   # text before the first underscore
    CIF_PATH="${OUT_DIR}/${PDB_ID}.cif"
    if [[ ! -f "${CIF_PATH}" ]]; then
        echo "${d}"
    fi
done
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips the Small BFD database for AlphaFold. 18 | # 19 | # Usage: bash download_small_bfd.sh /path/to/download/directory 20 | set -e 21 | 22 | if [[ $# -eq 0 ]]; then 23 | echo "Error: download directory must be provided as an input argument." 24 | exit 1 25 | fi 26 | 27 | if ! command -v aria2c &> /dev/null ; then 28 | echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)." 29 | exit 1 30 | fi 31 | 32 | DOWNLOAD_DIR="$1" 33 | ROOT_DIR="${DOWNLOAD_DIR}/small_bfd" 34 | SOURCE_URL="https://storage.googleapis.com/alphafold-databases/reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz" 35 | BASENAME=$(basename "${SOURCE_URL}") 36 | 37 | mkdir --parents "${ROOT_DIR}" 38 | aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" 39 | pushd "${ROOT_DIR}" 40 | gunzip "${ROOT_DIR}/${BASENAME}" 41 | popd 42 | -------------------------------------------------------------------------------- /openfold/scripts/download_uniclust30.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
#!/bin/bash
#
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Downloads the ColabFold UniRef30 (2103) archive and removes its gzip layer.
# The archive is placed directly in the download directory (no subdirectory
# is created) and is only gunzipped, not untarred, so a .tar file is left
# behind.
#
# Usage: bash download_uniref30.sh /path/to/download/directory
set -e

if [[ $# -eq 0 ]]; then
    echo "Error: download directory must be provided as an input argument."
    exit 1
fi

if ! command -v aria2c &> /dev/null ; then
    echo "Error: aria2c could not be found. Please install aria2c (sudo apt install aria2)."
    exit 1
fi

DOWNLOAD_DIR="$1"
# Unlike the sibling download scripts, the root is the download directory itself.
ROOT_DIR="${DOWNLOAD_DIR}"
SOURCE_URL="http://wwwuser.gwdg.de/~compbiol/colabfold/uniref30_2103.tar.gz"
BASENAME=$(basename "${SOURCE_URL}")

mkdir --parents "${ROOT_DIR}"
# -x 4: allow up to four connections to the server for this download.
aria2c "${SOURCE_URL}" --dir="${ROOT_DIR}" -x 4 --check-certificate=false
# Strips only the .gz suffix; the resulting .tar is not extracted here.
gunzip "${ROOT_DIR}/${BASENAME}"
#!/usr/bin/env bash
#
# Flattens a downloaded RODA database into the format expected by OpenFold
# Args:
#     roda_dir:
#         The path to the database you want to flatten. E.g. "roda/pdb"
#         or "roda/uniclust30". Note that, to save space, this script
#         will empty this directory.
#     output_dir:
#         The directory in which to construct the reformatted data
#
# Layout produced:
#     <output_dir>/data/                    contents of every "pdb"/"cif" subdir
#     <output_dir>/alignments/<chain_dir>/  contents of every other subdir
#
# NOTE: this script uses [[ ... ]], which is a bash construct, so it must be
# run by bash rather than a plain POSIX sh.

if [[ $# != 2 ]]; then
    echo "usage: ./flatten_roda.sh <roda_dir> <output_dir>"
    exit 1
fi

RODA_DIR="$1"
OUTPUT_DIR="$2"

DATA_DIR="${OUTPUT_DIR}/data"
ALIGNMENT_DIR="${OUTPUT_DIR}/alignments"

mkdir -p "${DATA_DIR}"
mkdir -p "${ALIGNMENT_DIR}"

# Globs instead of $(ls ...): names containing whitespace stay intact.
for CHAIN_DIR_PATH in "${RODA_DIR}"/*; do
    # Skips stray files and the literal pattern left by an empty directory.
    [[ -d "${CHAIN_DIR_PATH}" ]] || continue
    chain_dir=$(basename "${CHAIN_DIR_PATH}")

    for SUBDIR_PATH in "${CHAIN_DIR_PATH}"/*; do
        [[ -d "${SUBDIR_PATH}" ]] || continue
        subdir=$(basename "${SUBDIR_PATH}")

        if [[ $subdir = "pdb" ]] || [[ $subdir = "cif" ]]; then
            DEST_DIR="${DATA_DIR}"
        else
            DEST_DIR="${ALIGNMENT_DIR}/${chain_dir}"
            mkdir -p "${DEST_DIR}"
        fi
        mv "${SUBDIR_PATH}"/* "${DEST_DIR}"
    done
done

# Drop the data directory again if nothing was moved into it. Numeric
# comparison tolerates the padding some `wc` implementations emit.
NO_DATA_FILES=$(find "${DATA_DIR}" -type f | wc -l)
if [[ "${NO_DATA_FILES}" -eq 0 ]]; then
    rm -rf "${DATA_DIR}"
fi
def parse_file(
    f,
    args,
    chain_cluster_size_dict
):
    """Build the chain-data cache entries for one structure file.

    Args:
        f: File name (not a path) inside ``args.data_dir``.
        args: Namespace providing ``data_dir``.
        chain_cluster_size_dict: Optional mapping from upper-cased chain
            name to cluster size. When given, each entry gets a
            ``"cluster_size"`` key (``-1`` if the chain is not listed).

    Returns:
        A dict of cache entries:

        * ``.cif``: one entry per chain, keyed ``"<file_id>_<chain_id>"``,
          holding ``release_date``, ``seq`` and ``resolution``.
        * ``.pdb``: a single entry keyed by ``file_id`` holding ``seq`` and
          ``resolution`` (fixed at ``0.``).
        * an mmCIF file that fails to parse, or any other extension: ``{}``.
    """
    file_id, ext = os.path.splitext(f)

    # Default result, so an unsupported extension yields an empty mapping
    # instead of an UnboundLocalError at the final return.
    out = {}

    if ext == ".cif":
        with open(os.path.join(args.data_dir, f), "r") as fp:
            mmcif_string = fp.read()
        mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
        if mmcif.mmcif_object is None:
            logging.info(f"Could not parse {f}. Skipping...")
            return {}
        mmcif = mmcif.mmcif_object

        for chain_id, seq in mmcif.chain_to_seqres.items():
            full_name = "_".join([file_id, chain_id])
            local_data = {
                "release_date": mmcif.header["release_date"],
                "seq": seq,
                "resolution": mmcif.header["resolution"],
            }
            if chain_cluster_size_dict is not None:
                local_data["cluster_size"] = chain_cluster_size_dict.get(
                    full_name.upper(), -1
                )
            out[full_name] = local_data
    elif ext == ".pdb":
        with open(os.path.join(args.data_dir, f), "r") as fp:
            pdb_string = fp.read()

        protein_object = protein.from_pdb_string(pdb_string, None)

        chain_dict = {
            "seq": residue_constants.aatype_to_str_sequence(
                protein_object.aatype,
            ),
            "resolution": 0.,
        }

        if chain_cluster_size_dict is not None:
            # The entry is keyed by file_id, so that is also the name looked
            # up in the cluster table. (This branch previously referenced
            # `full_name`, which is only defined in the .cif branch.)
            chain_dict["cluster_size"] = chain_cluster_size_dict.get(
                file_id.upper(), -1
            )

        out = {file_id: chain_dict}
    else:
        logging.info(f"Unsupported file extension for {f}. Skipping...")

    return out
tqdm(total=len(files)) as pbar: 95 | for d in p.imap_unordered(fn, files, chunksize=args.chunksize): 96 | data.update(d) 97 | pbar.update() 98 | 99 | with open(args.output_path, "w") as fp: 100 | fp.write(json.dumps(data, indent=4)) 101 | 102 | 103 | if __name__ == "__main__": 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument( 106 | "data_dir", type=str, help="Directory containing mmCIF or PDB files" 107 | ) 108 | parser.add_argument( 109 | "output_path", type=str, help="Path for .json output" 110 | ) 111 | parser.add_argument( 112 | "--cluster_file", type=str, default=None, 113 | help=( 114 | "Path to a cluster file (e.g. PDB40), one cluster " 115 | "({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. " 116 | "Chains not in this cluster file will NOT be filtered by cluster " 117 | "size." 118 | ) 119 | ) 120 | parser.add_argument( 121 | "--no_workers", type=int, default=4, 122 | help="Number of workers to use for parsing" 123 | ) 124 | parser.add_argument( 125 | "--chunksize", type=int, default=10, 126 | help="How many files should be distributed to each worker at a time" 127 | ) 128 | 129 | args = parser.parse_args() 130 | 131 | main(args) 132 | -------------------------------------------------------------------------------- /openfold/scripts/generate_mmcif_cache.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import partial 3 | import json 4 | import logging 5 | from multiprocessing import Pool 6 | import os 7 | 8 | import sys 9 | sys.path.append(".") # an innocent hack to get this to run from the top level 10 | 11 | from tqdm import tqdm 12 | 13 | from openfold.data.mmcif_parsing import parse 14 | 15 | 16 | def parse_file(f, args): 17 | with open(os.path.join(args.mmcif_dir, f), "r") as fp: 18 | mmcif_string = fp.read() 19 | file_id = os.path.splitext(f)[0] 20 | mmcif = parse(file_id=file_id, mmcif_string=mmcif_string) 21 | if mmcif.mmcif_object is None: 22 | 
def main(args):
    """Parse every mmCIF file in ``args.mmcif_dir`` and write a JSON cache.

    Files are parsed in a pool of ``args.no_workers`` processes, handed out
    ``args.chunksize`` at a time. The per-file dicts returned by
    ``parse_file`` are merged and written, indented, to ``args.output_path``.
    """
    # Match on the real extension: a substring test would also pick up names
    # such as "1abc.cif.gz", which cannot be read as mmCIF text.
    files = [
        f for f in os.listdir(args.mmcif_dir)
        if os.path.splitext(f)[-1] == ".cif"
    ]
    fn = partial(parse_file, args=args)
    data = {}
    with Pool(processes=args.no_workers) as p:
        with tqdm(total=len(files)) as pbar:
            # Results arrive in completion order; order does not matter
            # because each one is merged into the dict by key.
            for d in p.imap_unordered(fn, files, chunksize=args.chunksize):
                data.update(d)
                pbar.update()

    with open(args.output_path, "w") as fp:
        json.dump(data, fp, indent=4)
\ 7 | && make -j 4 && make install \ 8 | && ln -sf /opt/hhsuite/bin/* /usr/bin \ 9 | && popd \ 10 | && rm -rf /tmp/hh-suite 11 | -------------------------------------------------------------------------------- /openfold/scripts/install_third_party_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CONDA_INSTALL_URL=${CONDA_INSTALL_URL:-"https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh"} 3 | 4 | source scripts/vars.sh 5 | 6 | # Install Miniconda locally 7 | rm -rf lib/conda 8 | rm -f /tmp/Miniconda3-latest-Linux-x86_64.sh 9 | wget -P /tmp \ 10 | "${CONDA_INSTALL_URL}" \ 11 | && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p lib/conda \ 12 | && rm /tmp/Miniconda3-latest-Linux-x86_64.sh 13 | 14 | # Grab conda-only packages 15 | export PATH=lib/conda/bin:$PATH 16 | lib/conda/bin/python3 -m pip install nvidia-pyindex 17 | conda env create --name=${ENV_NAME} -f environment.yml 18 | source scripts/activate_conda_env.sh 19 | 20 | echo "Attempting to install FlashAttention" 21 | git clone https://github.com/HazyResearch/flash-attention 22 | CUR_DIR=$PWD 23 | cd flash-attention 24 | git checkout 5b838a8bef 25 | python3 setup.py install 26 | cd $CUR_DIR 27 | 28 | # Install DeepMind's OpenMM patch 29 | OPENFOLD_DIR=$PWD 30 | pushd lib/conda/envs/$ENV_NAME/lib/python3.7/site-packages/ \ 31 | && patch -p0 < $OPENFOLD_DIR/lib/openmm.patch \ 32 | && popd 33 | 34 | # Download folding resources 35 | wget --no-check-certificate -P openfold/resources \ 36 | https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt 37 | 38 | # Certain tests need access to this file 39 | mkdir -p tests/test_data/alphafold/common 40 | ln -rs openfold/resources/stereo_chemical_props.txt tests/test_data/alphafold/common 41 | 42 | echo "Downloading OpenFold parameters..." 
43 | bash scripts/download_openfold_params.sh openfold/resources 44 | 45 | echo "Downloading AlphaFold parameters..." 46 | bash scripts/download_alphafold_params.sh openfold/resources 47 | 48 | # Decompress test data 49 | gunzip tests/test_data/sample_feats.pickle.gz 50 | -------------------------------------------------------------------------------- /openfold/scripts/prep_mmseqs_dbs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2021 AlQuraishi Laboratory 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Downloads and unzips all required data for AlphaFold. 
def main(args):
    """Sort ProteinNet MSAs into per-chain directories under ``args.out_dir``.

    An MSA file is used only if its stem has the form ``<pdb>_<x>_<chain>``
    (exactly three ``_``-separated fields) and ``args.mmcif_dir`` contains a
    file whose stem equals ``<pdb>``, compared case-insensitively. Each such
    file is copied (``args.copy`` truthy) or moved to
    ``<out_dir>/<PDB>_<CHAIN>/<original file name>``.

    Args:
        args: Namespace with ``msa_dir``, ``mmcif_dir``, ``out_dir``,
            ``copy`` and ``max_count`` (``None`` means no limit).
    """
    max_count = args.max_count if args.max_count is not None else -1

    # Membership set instead of a merge pointer over two sorted lists: the
    # pointer ran past the end of the mmCIF list (IndexError) whenever an MSA
    # id sorted after every mmCIF id or the mmCIF directory was empty.
    mmcif_ids = {
        os.path.splitext(f)[0].upper() for f in os.listdir(args.mmcif_dir)
    }

    count = 0
    # Sorted so that max_count selects a deterministic prefix.
    for f in sorted(os.listdir(args.msa_dir)):
        if count == max_count:
            break

        spl = os.path.splitext(f)[0].upper().split('_')
        if len(spl) != 3:
            continue

        pdb_id, _, chain_id = spl

        # Only consider files with matching mmCIF files
        if pdb_id not in mmcif_ids:
            continue

        dirname = os.path.join(args.out_dir, '_'.join([pdb_id, chain_id]))
        os.makedirs(dirname, exist_ok=True)
        path = os.path.join(args.msa_dir, f)
        dest = os.path.join(dirname, f)
        if args.copy:
            shutil.copyfile(path, dest)
        else:
            # shutil.move also works when out_dir is on another filesystem,
            # where os.rename raises OSError.
            shutil.move(path, dest)

        count += 1
#!/bin/bash
#
# Runs the Python unit tests, forwarding all arguments to `unittest`.
# Exits with unittest's own status so callers (CI, other scripts) can detect
# failures.

# Must be exported: a plain assignment is not passed on to the python3 child
# process unless the variable was already in the environment.
export CUDA_VISIBLE_DEVICES="0"

python3 -m unittest "$@"
STATUS=$?

if [[ ${STATUS} -ne 0 ]]; then
    echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
fi

# Previously `cmd || echo ...` made the script exit 0 even when tests failed.
exit "${STATUS}"
5 | 6 | set -e 7 | 8 | if [[ $# != 3 ]]; then 9 | echo "usage: ./run_uniclust30_search.sh " 10 | exit 11 | fi 12 | 13 | UNICLUST_PATH=$1 14 | SCRATCH_DIR_BN=$2 15 | OUT_DIR=$3 16 | 17 | CPUS_PER_TASK=4 18 | MAX_SIZE=10000000000 # 10GB 19 | 20 | SCRATCH_DIR="${SCRATCH_DIR_BN}_${SLURM_NODEID}" 21 | 22 | mkdir -p ${SCRATCH_DIR} 23 | mkdir -p ${OUT_DIR} 24 | 25 | # copy database to local ssd 26 | DB_BN=$(basename $UNICLUST_PATH) 27 | DB_DIR="/dev/shm/uniclust30" 28 | mkdir -p $DB_DIR 29 | cp ${UNICLUST_PATH}*.ff* $DB_DIR 30 | DB="${DB_DIR}/${DB_BN}" 31 | 32 | for f in $(ls $OUT_DIR/*.zip) 33 | do 34 | zipinfo -1 $f '*/' | awk -F/ '{print $(NF-1)}' >> ${DB_DIR}/already_searched.txt 35 | done 36 | 37 | python3 filter_ffindex.py ${DB}_a3m.ffindex ${DB_DIR}/already_searched.txt ${DB_DIR}/filtered_a3m.ffindex 38 | 39 | TARGET="${DB}_a3m_${SLURM_NODEID}.ffindex" 40 | split -n "l/$((SLURM_NODEID + 1))/${SLURM_JOB_NUM_NODES}" "${DB_DIR}/filtered_a3m.ffindex" > $TARGET 41 | 42 | open_sem() { 43 | mkfifo pipe-$$ 44 | exec 3<>pipe-$$ 45 | rm pipe-$$ 46 | local i=$1 47 | for ((;i>0;i--)); do 48 | printf %s 000 >&3 49 | done 50 | } 51 | 52 | # run the given command asynchronously and pop/push tokens 53 | run_with_lock() { 54 | local x 55 | # this read waits until there is something to read 56 | read -u 3 -n 3 x && ((0==x)) || exit $x 57 | ( 58 | ( "$@"; ) 59 | # push the return code of the command to the semaphore 60 | printf '%.3d' $? 
>&3 61 | )& 62 | } 63 | 64 | task() { 65 | dd if="${DB}_a3m.ffdata" ibs=1 skip="${OFF}" count="${LEN}" status=none | \ 66 | hhblits -i stdin \ 67 | -oa3m "${SCRATCH_DIR}/${KEY}/uniclust30.a3m" \ 68 | -v 0 \ 69 | -o /dev/null \ 70 | -cpu $CPUS_PER_TASK \ 71 | -d $DB \ 72 | -n 3 \ 73 | -e 0.001 74 | } 75 | 76 | zip_or_not() { 77 | SIZE=$(du -hbs $SCRATCH_DIR | sed 's/|/ /' | awk '{print $1}') 78 | #if [[ "$SIZE" -gt "$MAX_SIZE" ]] 79 | if [[ "2" -gt "1" ]] 80 | then 81 | wait 82 | RANDOM_NAME=$(cat /dev/urandom | tr -cd 'a-f0-9' | head -c 32) 83 | zip -r "${OUT_DIR}/${RANDOM_NAME}.zip" $SCRATCH_DIR 84 | find $SCRATCH_DIR -mindepth 1 -type d -exec rm -rf {} + 85 | fi 86 | } 87 | 88 | N=$(($(nproc) / ${CPUS_PER_TASK})) 89 | open_sem $N 90 | while read -r KEY OFF LEN; do 91 | PROT_DIR="${SCRATCH_DIR}/${KEY}" 92 | 93 | if [[ -d $PROT_DIR ]] 94 | then 95 | continue 96 | fi 97 | 98 | mkdir -p $PROT_DIR 99 | run_with_lock task "${KEY}" "${OFF}" "${LEN}" 100 | zip_or_not 101 | done < $TARGET 102 | 103 | wait 104 | 105 | zip_or_not 106 | 107 | wait 108 | -------------------------------------------------------------------------------- /openfold/scripts/unpack_proteinnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | 6 | def _write_file(args, file_in_progress): 7 | file_id = file_in_progress[1] 8 | fname = file_id.upper() + ".core" 9 | fpath = os.path.join(args.output_dir, fname) 10 | with open(fpath, "w") as fp: 11 | fp.write('\n'.join(file_in_progress)) 12 | 13 | 14 | def main(args): 15 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 16 | 17 | with open(args.proteinnet_file, "r") as fp: 18 | proteinnet_string = fp.readlines() 19 | 20 | file_in_progress = [] 21 | for line in proteinnet_string: 22 | if(line == "[ID]\n"): 23 | if(len(file_in_progress) > 0): 24 | _write_file(args, file_in_progress) 25 | file_in_progress = [] 26 | 27 | 
def get_nvidia_cc():
    """
    Query the CUDA driver for the Compute Capability of GPU 0.

    Returns a ``(compute_capability, error_message)`` pair. On success,
    ``compute_capability`` is a ``(major, minor)`` tuple of ints and
    ``error_message`` is None. On failure, ``compute_capability`` is None
    and ``error_message`` is a human-readable string.

    Adapted from a script by Jan Schlüter:
    https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
    """
    CUDA_SUCCESS = 0

    libnames = [
        'libcuda.so',
        'libcuda.dylib',
        'cuda.dll',
        '/usr/local/cuda/compat/libcuda.so',  # For Docker
    ]
    # Use the first driver library that loads; the for/else falls through
    # to the error return only if none of them could be opened.
    for libname in libnames:
        try:
            cuda = ctypes.CDLL(libname)
        except OSError:
            continue
        else:
            break
    else:
        return None, "Could not load any of: " + ' '.join(libnames)

    def _describe(result, call_name):
        # cuGetErrorString may leave the pointer NULL for unknown codes, in
        # which case ``.value`` is None and must not be decoded.
        error_str = ctypes.c_char_p()
        cuda.cuGetErrorString(result, ctypes.byref(error_str))
        if error_str.value:
            return error_str.value.decode()
        return "Unknown error: %s returned %d" % (call_name, result)

    result = cuda.cuInit(0)
    if result != CUDA_SUCCESS:
        return None, _describe(result, "cuInit")

    n_gpus = ctypes.c_int()
    result = cuda.cuDeviceGetCount(ctypes.byref(n_gpus))
    if result != CUDA_SUCCESS:
        return None, _describe(result, "cuDeviceGetCount")

    if n_gpus.value < 1:
        return None, "No GPUs detected"

    device = ctypes.c_int()
    result = cuda.cuDeviceGet(ctypes.byref(device), 0)
    if result != CUDA_SUCCESS:
        return None, _describe(result, "cuDeviceGet")

    cc_major = ctypes.c_int()
    cc_minor = ctypes.c_int()
    status = cuda.cuDeviceComputeCapability(
        ctypes.byref(cc_major), ctypes.byref(cc_minor), device
    )
    if status != CUDA_SUCCESS:
        return None, "Compute Capability not found"

    return (cc_major.value, cc_minor.value), None
-------------------------------------------------------------------------------- /openfold/scripts/vars.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ENV_NAME=openfold_venv 4 | -------------------------------------------------------------------------------- /openfold/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
import os
from setuptools import setup, Extension, find_packages
import subprocess

import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME

from scripts.utils import get_nvidia_cc


version_dependent_macros = [
    '-DVERSION_GE_1_1',
    '-DVERSION_GE_1_3',
    '-DVERSION_GE_1_5',
]

extra_cuda_flags = [
    '-std=c++14',
    '-maxrregcount=50',
    '-U__CUDA_NO_HALF_OPERATORS__',
    '-U__CUDA_NO_HALF_CONVERSIONS__',
    '--expt-relaxed-constexpr',
    '--expt-extended-lambda'
]


def get_cuda_bare_metal_version(cuda_dir):
    """Return ``(raw_nvcc_output, major, minor)`` for the nvcc under *cuda_dir*.

    When *cuda_dir* is None or this torch build reports no CUDA version, the
    sentinel ``(None, -1, 0)`` is returned instead. In the CUDA case ``major``
    and ``minor`` are strings parsed from the token that follows ``release``
    in ``nvcc -V``; only the first character of the minor token is kept.
    """
    if cuda_dir is None or torch.version.cuda is None:
        print("CUDA is not found, cpu version is installed")
        return None, -1, 0
    else:
        raw_output = subprocess.check_output(
            [os.path.join(cuda_dir, "bin", "nvcc"), "-V"],
            universal_newlines=True,
        )
        output = raw_output.split()
        release_idx = output.index("release") + 1
        release = output[release_idx].split(".")
        bare_metal_major = release[0]
        bare_metal_minor = release[1][0]

        return raw_output, bare_metal_major, bare_metal_minor


# Default set of architectures to build for when no GPU can be queried.
compute_capabilities = set([
    (3, 7), # K80, e.g.
    (5, 2), # Titan X
    (6, 1), # GeForce 1000-series
])

compute_capabilities.add((7, 0))
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11:
    compute_capabilities.add((8, 0))

# If the driver reports the capability of GPU 0, build for that one only.
compute_capability, _ = get_nvidia_cc()
if compute_capability is not None:
    compute_capabilities = set([compute_capability])

cc_flag = []
for major, minor in list(compute_capabilities):
    cc_flag.extend([
        '-gencode',
        f'arch=compute_{major}{minor},code=sm_{major}{minor}',
    ])

extra_cuda_flags += cc_flag

# -1 is the "no CUDA" sentinel from get_cuda_bare_metal_version.
if bare_metal_major != -1:
    modules = [CUDAExtension(
        name="attn_core_inplace_cuda",
        sources=[
            "openfold/utils/kernel/csrc/softmax_cuda.cpp",
            "openfold/utils/kernel/csrc/softmax_cuda_kernel.cu",
        ],
        include_dirs=[
            os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                'openfold/utils/kernel/csrc/'
            )
        ],
        extra_compile_args={
            'cxx': ['-O3'] + version_dependent_macros,
            'nvcc': (
                ['-O3', '--use_fast_math'] +
                version_dependent_macros +
                extra_cuda_flags
            ),
        }
    )]
else:
    modules = [CppExtension(
        name="attn_core_inplace_cuda",
        sources=[
            "openfold/utils/kernel/csrc/softmax_cuda.cpp",
            "openfold/utils/kernel/csrc/softmax_cuda_stub.cpp",
        ],
        extra_compile_args={
            'cxx': ['-O3'],
        }
    )]

setup(
    name='openfold',
    version='1.0.1',
    description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
    author='Gustaf Ahdritz & DeepMind',
    author_email='gahdritz@gmail.com',
    license='Apache License, Version 2.0',
    url='https://github.com/aqlaboratory/openfold',
    packages=find_packages(exclude=["tests", "scripts"]),
    include_package_data=True,
    package_data={
        "openfold": ['utils/kernel/csrc/*'],
        "": ["resources/stereo_chemical_props.txt"]
    },
    ext_modules=modules,
    cmdclass={'build_ext': BuildExtension},
    classifiers=[
        'License :: OSI Approved :: Apache Software License',
        'Operating System :: POSIX :: Linux',
        # The comma must sit outside the quotes; inside, this entry and the
        # next were silently concatenated into one invalid classifier.
        'Programming Language :: Python :: 3.7',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
)
def alphafold_is_installed():
    """Return True if an ``alphafold`` package can be located for import.

    Only the import machinery is consulted (``find_spec``); the package
    itself is not imported here.
    """
    # ``import importlib`` alone does not guarantee that the
    # ``importlib.util`` submodule is loaded, so import it explicitly rather
    # than relying on some other module having done so as a side effect.
    import importlib.util

    return importlib.util.find_spec("alphafold") is not None
def fetch_alphafold_module_weights(weight_path):
    """Return the pretrained AlphaFold parameters under *weight_path*.

    Selects every entry of the original weight archive whose key contains
    *weight_path* as a substring. If *weight_path* contains a ``/``, the
    leading path components (everything before its last non-empty component)
    are stripped from the selected keys. The result is passed through
    AlphaFold's ``flat_params_to_haiku``.

    Raises:
        ImportError: if the module-level ``alphafold`` name (set up by
            ``import_alphafold``) or its ``model.utils`` submodule is not
            available.
    """
    orig_weights = _get_orig_weights()
    params = {k: v for k, v in orig_weights.items() if weight_path in k}
    if "/" in weight_path:
        spl = weight_path.split("/")
        # Tolerate a trailing slash in weight_path.
        spl = spl if len(spl[-1]) != 0 else spl[:-1]
        prefix = "/".join(spl[:-1]) + "/"
        _remove_key_prefix(params, prefix)

    # Resolve the converter separately so only "alphafold is not set up"
    # becomes ImportError; errors raised by the conversion itself propagate
    # unchanged instead of being mislabelled by a bare ``except``.
    try:
        flat_params_to_haiku = alphafold.model.utils.flat_params_to_haiku  # noqa
    except (NameError, AttributeError) as err:
        raise ImportError(
            "Make sure to call import_alphafold before running this function"
        ) from err
    return flat_params_to_haiku(params)
14 | "c_m": 256, 15 | "c_z": 128, 16 | "c_s": 384, 17 | "c_t": 64, 18 | "c_e": 64, 19 | } 20 | ) 21 | 22 | config = mlc.ConfigDict( 23 | { 24 | "data": { 25 | "common": { 26 | "masked_msa": { 27 | "profile_prob": 0.1, 28 | "same_prob": 0.1, 29 | "uniform_prob": 0.1, 30 | }, 31 | } 32 | } 33 | } 34 | ) 35 | -------------------------------------------------------------------------------- /openfold/tests/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def _seeded_rotations(count):
    """Return *count* distinct rotations drawn from a fixed seed (42).

    Drawing them in one call matters: calling ``Rotation.random`` once per
    element with the same integer seed returns the same rotation every time.
    """
    return Rotation.random(count, random_state=42)


def random_affines_vector(dim):
    """Build random 7-vectors (quaternion + translation) of shape ``(*dim, 7)``.

    The first four components are a unit quaternion as returned by SciPy's
    ``Rotation.as_quat``; the last three are a translation drawn uniformly
    from ``[0, 1)`` via the global NumPy RNG. Rotations are reproducible
    (fixed seed); translations are not, unless the caller seeds NumPy.

    Args:
        dim: Iterable of ints giving the leading (batch) dimensions.

    Returns:
        A ``float32`` array of shape ``(*dim, 7)``.
    """
    prod_dim = 1
    for d in dim:
        prod_dim *= d

    affines = np.zeros((prod_dim, 7)).astype(np.float32)

    if prod_dim > 0:
        affines[:, :4] = _seeded_rotations(prod_dim).as_quat()
        affines[:, 4:] = np.random.rand(prod_dim, 3)

    return affines.reshape(*dim, 7)


def random_affines_4x4(dim):
    """Build random homogeneous transforms of shape ``(*dim, 4, 4)``.

    The upper-left 3x3 block is a rotation matrix, the last column holds a
    translation drawn uniformly from ``[0, 1)`` via the global NumPy RNG, and
    the bottom row is ``[0, 0, 0, 1]``. Rotations are reproducible (fixed
    seed) and differ from one transform to the next.

    Args:
        dim: Iterable of ints giving the leading (batch) dimensions.

    Returns:
        A ``float32`` array of shape ``(*dim, 4, 4)``.
    """
    prod_dim = 1
    for d in dim:
        prod_dim *= d

    affines = np.zeros((prod_dim, 4, 4)).astype(np.float32)

    if prod_dim > 0:
        affines[:, :3, :3] = _seeded_rotations(prod_dim).as_matrix()
        affines[:, :3, 3] = np.random.rand(prod_dim, 3)

    affines[:, 3, 3] = 1

    return affines.reshape(*dim, 4, 4)
class TestDataPipeline(unittest.TestCase):
    """Checks OpenFold's FASTA feature pipeline against stored AlphaFold features."""

    @compare_utils.skip_unless_alphafold_installed()
    def test_fasta_compare(self):
        # AlphaFold computes alignments and features in a single, very slow
        # step, so its feature dict is precomputed (with
        # scripts/generate_alphafold_feature_dict.py and the default
        # databases) and loaded from disk here.
        with open("tests/test_data/alphafold_feature_dict.pickle", "rb") as fp:
            reference = pickle.load(fp)

        pipeline = DataPipeline(
            template_featurizer=TemplateHitFeaturizer(
                mmcif_dir="tests/test_data/mmcifs",
                max_template_date="2021-12-20",
                max_hits=20,
                kalign_binary_path=shutil.which("kalign"),
                _zero_center_positions=False,
            ),
        )

        produced = pipeline.process_fasta(
            "tests/test_data/short.fasta",
            "tests/test_data/alignments"
        )

        # Expose the mask under the plural key as well, so the generic
        # key-by-key comparison at the end can look it up.
        produced["template_all_atom_masks"] = produced["template_all_atom_mask"]

        already_checked = []

        # The two pipelines order MSA rows slightly differently; the checks
        # below are written to be insensitive to that.
        msa_ref = reference["msa"]
        msa_new = produced["msa"]

        # The first row must agree regardless of ordering.
        first_rows_equal = np.all(msa_ref[0, :] == msa_new[0, :])
        self.assertTrue(first_rows_equal)

        # Every reference row must occur exactly once in the produced MSA.
        row_matches = np.all(
            (msa_ref[:, None, ...] == msa_new[None, :, ...]), axis=-1
        )
        match_counts = np.sum(row_matches, axis=-1)
        self.assertTrue(np.all(match_counts == 1))

        already_checked.append("msa")

        # Reorder the produced deletion matrix to the reference row order
        # before comparing it.
        row_order = np.argmax(row_matches, axis=-1)
        reordered_dmi = produced["deletion_matrix_int"][row_order, :]
        self.assertTrue(
            np.all(reference["deletion_matrix_int"] == reordered_dmi)
        )

        already_checked.append("deletion_matrix_int")

        # Everything else has to match exactly.
        for key, ref_value in reference.items():
            if key in already_checked:
                self.assertTrue(True)
                continue
            self.assertTrue(np.all(ref_value == produced[key]))
class TestInputEmbedder(unittest.TestCase):
    """Shape check for ``InputEmbedder``."""

    def test_shape(self):
        # Small, pairwise-distinct sizes so a swapped dimension cannot pass.
        tf_dim, msa_dim = 2, 3
        c_z, c_m = 5, 7
        relpos_k = 11
        b, n_res, n_clust = 13, 17, 19

        tf = torch.rand((b, n_res, tf_dim))
        ri = torch.rand((b, n_res))
        msa = torch.rand((b, n_clust, n_res, msa_dim))

        embedder = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
        msa_emb, pair_emb = embedder(tf, ri, msa)

        expected_msa_shape = (b, n_clust, n_res, c_m)
        expected_pair_shape = (b, n_res, n_res, c_z)
        self.assertTrue(msa_emb.shape == expected_msa_shape)
        self.assertTrue(pair_emb.shape == expected_pair_shape)
92 | batch_size = 2 93 | n_templ = 3 94 | n_res = 5 95 | template_pair_dim = 7 96 | c_t = 11 97 | 98 | tpe = TemplatePairEmbedder( 99 | template_pair_dim, 100 | c_t, 101 | ) 102 | 103 | x = torch.rand((batch_size, n_templ, n_res, n_res, template_pair_dim)) 104 | x = tpe(x) 105 | 106 | self.assertTrue(x.shape == (batch_size, n_templ, n_res, n_res, c_t)) 107 | 108 | 109 | if __name__ == "__main__": 110 | unittest.main() 111 | -------------------------------------------------------------------------------- /openfold/tests/test_import_weights.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class TestAttentionCore(unittest.TestCase):
    """Compares ``attention_core`` with ``_attention`` on random CUDA inputs.

    Both tests require a CUDA device: every tensor is moved with ``.cuda()``.
    """

    @staticmethod
    def _random_mask_bias(n_seq, n_res, dtype):
        """Return an additive bias built from a random 0/1 mask.

        The bias is 0 where the mask is 1 and -1e9 where the mask is 0.
        The parentheses matter: ``1e9 * mask - 1`` would instead put about
        +1e9 on every kept position, which swamps the q.k logits in float32
        so the comparison would no longer depend on q and k.
        """
        mask = torch.randint(0, 2, [n_seq, n_res]).cuda()
        return (1e9 * (mask - 1))[..., None, None, :].to(dtype)

    def test_attention_core_forward(self):
        n_res = consts.n_res
        h = consts.n_heads_extra_msa
        n_seq = consts.n_extra
        c = consts.c_e
        dtype = torch.float32

        q = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
        k = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
        v = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
        mask_bias = self._random_mask_bias(n_seq, n_res, dtype)

        out_repro = attention_core(q, k, v, mask_bias, None)
        out_gt = _attention(q, k, v, [mask_bias])

        self.assertTrue(torch.max(torch.abs(out_repro - out_gt)) < consts.eps)

    def test_attention_core_backward(self):
        n_res = consts.n_res
        h = consts.n_heads_extra_msa
        n_seq = consts.n_extra
        c = consts.c_e
        dtype = torch.float32

        q = torch.rand(
            [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
        ).cuda()
        k = torch.rand(
            [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
        ).cuda()
        v = torch.rand(
            [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
        ).cuda()
        mask_bias = self._random_mask_bias(n_seq, n_res, dtype)

        def clone(t):
            # The clones are not leaf tensors, so their gradients are kept
            # only if explicitly requested.
            t = t.clone()
            if t.requires_grad:
                t.retain_grad()
            return t

        q_repro = clone(q)
        k_repro = clone(k)
        v_repro = clone(v)
        out_repro = attention_core(
            q_repro, k_repro, v_repro, mask_bias, None
        )

        loss_repro = torch.mean(out_repro)
        loss_repro.backward()

        q_gt = clone(q)
        k_gt = clone(k)
        v_gt = clone(v)
        out_gt = _attention(
            q_gt, k_gt, v_gt, [mask_bias]
        )

        loss_gt = torch.mean(out_gt)
        loss_gt.backward()

        pairs = zip([q_repro, k_repro, v_repro], [q_gt, k_gt, v_gt])
        for t_repro, t_gt in pairs:
            self.assertTrue(
                torch.max(torch.abs(t_repro.grad - t_gt.grad)) < consts.eps
            )
for this test 49 | 50 | model = AlphaFold(c) 51 | 52 | batch = {} 53 | tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,)) 54 | batch["target_feat"] = nn.functional.one_hot( 55 | tf, c.model.input_embedder.tf_dim 56 | ).float() 57 | batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1) 58 | batch["residue_index"] = torch.arange(n_res) 59 | batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim)) 60 | t_feats = random_template_feats(n_templ, n_res) 61 | batch.update({k: torch.tensor(v) for k, v in t_feats.items()}) 62 | extra_feats = random_extra_msa_feats(n_extra_seq, n_res) 63 | batch.update({k: torch.tensor(v) for k, v in extra_feats.items()}) 64 | batch["msa_mask"] = torch.randint( 65 | low=0, high=2, size=(n_seq, n_res) 66 | ).float() 67 | batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float() 68 | batch.update(data_transforms.make_atom14_masks(batch)) 69 | batch["no_recycling_iters"] = torch.tensor(2.) 70 | 71 | add_recycling_dims = lambda t: ( 72 | t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters) 73 | ) 74 | batch = tensor_tree_map(add_recycling_dims, batch) 75 | 76 | with torch.no_grad(): 77 | out = model(batch) 78 | 79 | @compare_utils.skip_unless_alphafold_installed() 80 | def test_compare(self): 81 | def run_alphafold(batch): 82 | config = compare_utils.get_alphafold_config() 83 | model = alphafold.model.modules.AlphaFold(config.model) 84 | return model( 85 | batch=batch, 86 | is_training=False, 87 | return_representations=True, 88 | ) 89 | 90 | f = hk.transform(run_alphafold) 91 | 92 | params = compare_utils.fetch_alphafold_module_weights("") 93 | 94 | with open("tests/test_data/sample_feats.pickle", "rb") as fp: 95 | batch = pickle.load(fp) 96 | 97 | out_gt = f.apply(params, jax.random.PRNGKey(42), batch) 98 | 99 | out_gt = out_gt["structure_module"]["final_atom_positions"] 100 | # atom37_to_atom14 doesn't like batches 101 | batch["residx_atom14_to_atom37"] = 
batch["residx_atom14_to_atom37"][0] 102 | batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0] 103 | out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch) 104 | out_gt = torch.as_tensor(np.array(out_gt.block_until_ready())) 105 | 106 | batch["no_recycling_iters"] = np.array([3., 3., 3., 3.,]) 107 | batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()} 108 | 109 | batch["aatype"] = batch["aatype"].long() 110 | batch["template_aatype"] = batch["template_aatype"].long() 111 | batch["extra_msa"] = batch["extra_msa"].long() 112 | batch["residx_atom37_to_atom14"] = batch[ 113 | "residx_atom37_to_atom14" 114 | ].long() 115 | batch["template_all_atom_mask"] = batch["template_all_atom_masks"] 116 | batch.update( 117 | data_transforms.atom37_to_torsion_angles("template_")(batch) 118 | ) 119 | 120 | # Move the recycling dimension to the end 121 | move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0) 122 | batch = tensor_tree_map(move_dim, batch) 123 | 124 | with torch.no_grad(): 125 | model = compare_utils.get_global_pretrained_openfold() 126 | out_repro = model(batch) 127 | 128 | out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro) 129 | 130 | out_repro = out_repro["sm"]["positions"][-1] 131 | out_repro = out_repro.squeeze(0) 132 | 133 | self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3) 134 | -------------------------------------------------------------------------------- /openfold/tests/test_outer_product_mean.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import numpy as np 17 | import unittest 18 | from openfold.model.outer_product_mean import OuterProductMean 19 | from openfold.utils.tensor_utils import tree_map 20 | import tests.compare_utils as compare_utils 21 | from tests.config import consts 22 | 23 | if compare_utils.alphafold_is_installed(): 24 | alphafold = compare_utils.import_alphafold() 25 | import jax 26 | import haiku as hk 27 | 28 | 29 | class TestOuterProductMean(unittest.TestCase): 30 | def test_shape(self): 31 | c = 31 32 | 33 | opm = OuterProductMean(consts.c_m, consts.c_z, c) 34 | 35 | m = torch.rand( 36 | (consts.batch_size, consts.n_seq, consts.n_res, consts.c_m) 37 | ) 38 | mask = torch.randint( 39 | 0, 2, size=(consts.batch_size, consts.n_seq, consts.n_res) 40 | ) 41 | m = opm(m, mask=mask, chunk_size=None) 42 | 43 | self.assertTrue( 44 | m.shape == 45 | (consts.batch_size, consts.n_res, consts.n_res, consts.c_z) 46 | ) 47 | 48 | @compare_utils.skip_unless_alphafold_installed() 49 | def test_opm_compare(self): 50 | def run_opm(msa_act, msa_mask): 51 | config = compare_utils.get_alphafold_config() 52 | c_evo = config.model.embeddings_and_evoformer.evoformer 53 | opm = alphafold.model.modules.OuterProductMean( 54 | c_evo.outer_product_mean, 55 | config.model.global_config, 56 | consts.c_z, 57 | ) 58 | act = opm(act=msa_act, mask=msa_mask) 59 | return act 60 | 61 | f = hk.transform(run_opm) 62 | 63 | n_res = consts.n_res 64 | n_seq = consts.n_seq 65 | c_m = consts.c_m 66 | 67 | msa_act = np.random.rand(n_seq, n_res, 
c_m).astype(np.float32) * 100 68 | msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res)).astype( 69 | np.float32 70 | ) 71 | 72 | # Fetch pretrained parameters (but only from one block)] 73 | params = compare_utils.fetch_alphafold_module_weights( 74 | "alphafold/alphafold_iteration/evoformer/" 75 | + "evoformer_iteration/outer_product_mean" 76 | ) 77 | params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) 78 | 79 | out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() 80 | out_gt = torch.as_tensor(np.array(out_gt)) 81 | 82 | model = compare_utils.get_global_pretrained_openfold() 83 | out_repro = ( 84 | model.evoformer.blocks[0].core 85 | .outer_product_mean( 86 | torch.as_tensor(msa_act).cuda(), 87 | chunk_size=4, 88 | mask=torch.as_tensor(msa_mask).cuda(), 89 | ) 90 | .cpu() 91 | ) 92 | 93 | # Even when correct, OPM has large, precision-related errors. It gets 94 | # a special pass from consts.eps. 95 | self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4) 96 | 97 | 98 | if __name__ == "__main__": 99 | unittest.main() 100 | -------------------------------------------------------------------------------- /openfold/tests/test_pair_transition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import torch 16 | import numpy as np 17 | import unittest 18 | from openfold.model.pair_transition import PairTransition 19 | from openfold.utils.tensor_utils import tree_map 20 | import tests.compare_utils as compare_utils 21 | from tests.config import consts 22 | 23 | if compare_utils.alphafold_is_installed(): 24 | alphafold = compare_utils.import_alphafold() 25 | import jax 26 | import haiku as hk 27 | 28 | 29 | class TestPairTransition(unittest.TestCase): 30 | def test_shape(self): 31 | c_z = consts.c_z 32 | n = 4 33 | 34 | pt = PairTransition(c_z, n) 35 | 36 | batch_size = consts.batch_size 37 | n_res = consts.n_res 38 | 39 | z = torch.rand((batch_size, n_res, n_res, c_z)) 40 | mask = torch.randint(0, 2, size=(batch_size, n_res, n_res)) 41 | shape_before = z.shape 42 | z = pt(z, mask=mask, chunk_size=None) 43 | shape_after = z.shape 44 | 45 | self.assertTrue(shape_before == shape_after) 46 | 47 | @compare_utils.skip_unless_alphafold_installed() 48 | def test_compare(self): 49 | def run_pair_transition(pair_act, pair_mask): 50 | config = compare_utils.get_alphafold_config() 51 | c_e = config.model.embeddings_and_evoformer.evoformer 52 | pt = alphafold.model.modules.Transition( 53 | c_e.pair_transition, 54 | config.model.global_config, 55 | name="pair_transition", 56 | ) 57 | act = pt(act=pair_act, mask=pair_mask) 58 | return act 59 | 60 | f = hk.transform(run_pair_transition) 61 | 62 | n_res = consts.n_res 63 | 64 | pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32) 65 | pair_mask = np.ones((n_res, n_res)).astype(np.float32) # no mask 66 | 67 | # Fetch pretrained parameters (but only from one block)] 68 | params = compare_utils.fetch_alphafold_module_weights( 69 | "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" 70 | + "pair_transition" 71 | ) 72 | params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) 73 | 74 | out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() 75 | out_gt = 
torch.as_tensor(np.array(out_gt.block_until_ready())) 76 | 77 | model = compare_utils.get_global_pretrained_openfold() 78 | out_repro = ( 79 | model.evoformer.blocks[0].core 80 | .pair_transition( 81 | torch.as_tensor(pair_act, dtype=torch.float32).cuda(), 82 | chunk_size=4, 83 | mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), 84 | ) 85 | .cpu() 86 | ) 87 | 88 | self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps)) 89 | 90 | 91 | if __name__ == "__main__": 92 | unittest.main() 93 | -------------------------------------------------------------------------------- /openfold/tests/test_primitives.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import torch 16 | import numpy as np 17 | import unittest 18 | 19 | from openfold.model.primitives import ( 20 | Attention, 21 | ) 22 | from tests.config import consts 23 | 24 | 25 | class TestLMA(unittest.TestCase): 26 | def test_lma_vs_attention(self): 27 | batch_size = consts.batch_size 28 | c_hidden = 32 29 | n = 2**12 30 | no_heads = 4 31 | 32 | q = torch.rand(batch_size, n, c_hidden).cuda() 33 | kv = torch.rand(batch_size, n, c_hidden).cuda() 34 | 35 | bias = [torch.rand(no_heads, 1, n)] 36 | bias = [b.cuda() for b in bias] 37 | 38 | gating_fill = torch.rand(c_hidden * no_heads, c_hidden) 39 | o_fill = torch.rand(c_hidden, c_hidden * no_heads) 40 | 41 | a = Attention( 42 | c_hidden, c_hidden, c_hidden, c_hidden, no_heads 43 | ).cuda() 44 | 45 | with torch.no_grad(): 46 | l = a(q, kv, biases=bias, use_lma=True) 47 | real = a(q, kv, biases=bias) 48 | 49 | self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps) 50 | 51 | 52 | if __name__ == "__main__": 53 | unittest.main() 54 | -------------------------------------------------------------------------------- /openfold/tests/test_triangular_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import copy 15 | 16 | import torch 17 | import numpy as np 18 | import unittest 19 | from openfold.model.triangular_attention import TriangleAttention 20 | from openfold.utils.tensor_utils import tree_map 21 | 22 | import tests.compare_utils as compare_utils 23 | from tests.config import consts 24 | 25 | if compare_utils.alphafold_is_installed(): 26 | alphafold = compare_utils.import_alphafold() 27 | import jax 28 | import haiku as hk 29 | 30 | 31 | class TestTriangularAttention(unittest.TestCase): 32 | def test_shape(self): 33 | c_z = consts.c_z 34 | c = 12 35 | no_heads = 4 36 | starting = True 37 | 38 | tan = TriangleAttention(c_z, c, no_heads, starting) 39 | 40 | batch_size = consts.batch_size 41 | n_res = consts.n_res 42 | 43 | x = torch.rand((batch_size, n_res, n_res, c_z)) 44 | shape_before = x.shape 45 | x = tan(x, chunk_size=None) 46 | shape_after = x.shape 47 | 48 | self.assertTrue(shape_before == shape_after) 49 | 50 | def _tri_att_compare(self, starting=False): 51 | name = ( 52 | "triangle_attention_" 53 | + ("starting" if starting else "ending") 54 | + "_node" 55 | ) 56 | 57 | def run_tri_att(pair_act, pair_mask): 58 | config = compare_utils.get_alphafold_config() 59 | c_e = config.model.embeddings_and_evoformer.evoformer 60 | tri_att = alphafold.model.modules.TriangleAttention( 61 | c_e.triangle_attention_starting_node 62 | if starting 63 | else c_e.triangle_attention_ending_node, 64 | config.model.global_config, 65 | name=name, 66 | ) 67 | act = tri_att(pair_act=pair_act, pair_mask=pair_mask) 68 | return act 69 | 70 | f = hk.transform(run_tri_att) 71 | 72 | n_res = consts.n_res 73 | 74 | pair_act = np.random.rand(n_res, n_res, consts.c_z) * 100 75 | pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res)) 76 | 77 | # Fetch pretrained parameters (but only from one block)] 78 | params = compare_utils.fetch_alphafold_module_weights( 79 | "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" 80 | + name 81 | ) 82 | params = 
tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) 83 | 84 | out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() 85 | out_gt = torch.as_tensor(np.array(out_gt)) 86 | 87 | model = compare_utils.get_global_pretrained_openfold() 88 | module = ( 89 | model.evoformer.blocks[0].core.tri_att_start 90 | if starting 91 | else model.evoformer.blocks[0].core.tri_att_end 92 | ) 93 | 94 | # To save memory, the full model transposes inputs outside of the 95 | # triangle attention module. We adjust the module here. 96 | module = copy.deepcopy(module) 97 | module.starting = starting 98 | 99 | out_repro = module( 100 | torch.as_tensor(pair_act, dtype=torch.float32).cuda(), 101 | mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), 102 | chunk_size=None, 103 | ).cpu() 104 | 105 | self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) 106 | 107 | @compare_utils.skip_unless_alphafold_installed() 108 | def test_tri_att_end_compare(self): 109 | self._tri_att_compare() 110 | 111 | @compare_utils.skip_unless_alphafold_installed() 112 | def test_tri_att_start_compare(self): 113 | self._tri_att_compare(starting=True) 114 | 115 | 116 | if __name__ == "__main__": 117 | unittest.main() 118 | -------------------------------------------------------------------------------- /openfold/tests/test_triangular_multiplicative_update.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import numpy as np 17 | import unittest 18 | from openfold.model.triangular_multiplicative_update import * 19 | from openfold.utils.tensor_utils import tree_map 20 | import tests.compare_utils as compare_utils 21 | from tests.config import consts 22 | 23 | if compare_utils.alphafold_is_installed(): 24 | alphafold = compare_utils.import_alphafold() 25 | import jax 26 | import haiku as hk 27 | 28 | 29 | class TestTriangularMultiplicativeUpdate(unittest.TestCase): 30 | def test_shape(self): 31 | c_z = consts.c_z 32 | c = 11 33 | 34 | tm = TriangleMultiplicationOutgoing( 35 | c_z, 36 | c, 37 | ) 38 | 39 | n_res = consts.c_z 40 | batch_size = consts.batch_size 41 | 42 | x = torch.rand((batch_size, n_res, n_res, c_z)) 43 | mask = torch.randint(0, 2, size=(batch_size, n_res, n_res)) 44 | shape_before = x.shape 45 | x = tm(x, mask) 46 | shape_after = x.shape 47 | 48 | self.assertTrue(shape_before == shape_after) 49 | 50 | def _tri_mul_compare(self, incoming=False): 51 | name = "triangle_multiplication_" + ( 52 | "incoming" if incoming else "outgoing" 53 | ) 54 | 55 | def run_tri_mul(pair_act, pair_mask): 56 | config = compare_utils.get_alphafold_config() 57 | c_e = config.model.embeddings_and_evoformer.evoformer 58 | tri_mul = alphafold.model.modules.TriangleMultiplication( 59 | c_e.triangle_multiplication_incoming 60 | if incoming 61 | else c_e.triangle_multiplication_outgoing, 62 | config.model.global_config, 63 | name=name, 64 | ) 65 | act = tri_mul(act=pair_act, mask=pair_mask) 66 | return act 67 | 68 | f = hk.transform(run_tri_mul) 69 | 70 | n_res = consts.n_res 71 | 72 | pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32) 73 | pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res)) 74 | pair_mask = pair_mask.astype(np.float32) 75 | 76 | # Fetch pretrained parameters (but only from one block)] 77 | 
params = compare_utils.fetch_alphafold_module_weights( 78 | "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" 79 | + name 80 | ) 81 | params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) 82 | 83 | out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() 84 | out_gt = torch.as_tensor(np.array(out_gt)) 85 | 86 | model = compare_utils.get_global_pretrained_openfold() 87 | module = ( 88 | model.evoformer.blocks[0].core.tri_mul_in 89 | if incoming 90 | else model.evoformer.blocks[0].core.tri_mul_out 91 | ) 92 | out_repro = module( 93 | torch.as_tensor(pair_act, dtype=torch.float32).cuda(), 94 | mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), 95 | inplace_safe=True, _inplace_chunk_size=4, 96 | ).cpu() 97 | 98 | self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) 99 | 100 | @compare_utils.skip_unless_alphafold_installed() 101 | def test_tri_mul_out_compare(self): 102 | self._tri_mul_compare() 103 | 104 | @compare_utils.skip_unless_alphafold_installed() 105 | def test_tri_mul_in_compare(self): 106 | self._tri_mul_compare(incoming=True) 107 | 108 | def _tri_mul_inplace(self, incoming=False): 109 | n_res = consts.n_res 110 | 111 | pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32) 112 | pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res)) 113 | pair_mask = pair_mask.astype(np.float32) 114 | 115 | 116 | model = compare_utils.get_global_pretrained_openfold() 117 | module = ( 118 | model.evoformer.blocks[0].core.tri_mul_in 119 | if incoming 120 | else model.evoformer.blocks[0].core.tri_mul_out 121 | ) 122 | out_stock = module( 123 | torch.as_tensor(pair_act, dtype=torch.float32).cuda(), 124 | mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), 125 | inplace_safe=False, 126 | ).cpu() 127 | 128 | # This has to come second because inference mode is in-place 129 | out_inplace = module( 130 | torch.as_tensor(pair_act, dtype=torch.float32).cuda(), 131 | 
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), 132 | inplace_safe=True, _inplace_chunk_size=2, 133 | ).cpu() 134 | 135 | self.assertTrue(torch.mean(torch.abs(out_stock - out_inplace)) < consts.eps) 136 | 137 | def test_tri_mul_out_inference(self): 138 | self._tri_mul_inplace() 139 | 140 | def test_tri_mul_in_inference(self): 141 | self._tri_mul_inplace(incoming=True) 142 | 143 | if __name__ == "__main__": 144 | unittest.main() 145 | --------------------------------------------------------------------------------