├── tests
│   ├── __init__.py
│   ├── test_loss_xyz.py
│   ├── test_data
│   │   ├── features.pkl
│   │   ├── short.fasta
│   │   └── alphafold_feature_dict.pickle
│   ├── config.py
│   ├── test_primitives.py
│   ├── test_kernels.py
│   ├── test_embedders.py
│   ├── data_utils.py
│   ├── test_pair_transition.py
│   ├── compare_utils.py
│   ├── test_outer_product_mean.py
│   ├── test_triangular_attention.py
│   ├── test_data_pipeline.py
│   ├── test_permutation.py
│   ├── test_triangular_multiplicative_update.py
│   └── test_model.py
├── opencomplex
│   ├── data
│   │   ├── __init__.py
│   │   ├── tools
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   ├── hhsearch.py
│   │   │   ├── kalign.py
│   │   │   └── hhblits.py
│   │   ├── errors.py
│   │   └── feature_pipeline.py
│   ├── config
│   │   ├── __init__.py
│   │   └── references.py
│   ├── model
│   │   ├── sm
│   │   │   ├── __init__.py
│   │   │   └── utils.py
│   │   ├── __init__.py
│   │   ├── dropout.py
│   │   ├── pair_transition.py
│   │   ├── triangular_attention.py
│   │   ├── outer_product_mean.py
│   │   └── torchscript.py
│   ├── resources
│   │   ├── __init__.py
│   │   └── stereo_chemical_props_RNA.txt
│   ├── utils
│   │   ├── kernel
│   │   │   ├── __init__.py
│   │   │   ├── csrc
│   │   │   │   ├── compat.h
│   │   │   │   └── softmax_cuda.cpp
│   │   │   └── attention_core.py
│   │   ├── __init__.py
│   │   ├── callbacks.py
│   │   ├── seed.py
│   │   ├── suppress_output.py
│   │   ├── argparse.py
│   │   ├── exponential_moving_average.py
│   │   ├── complex_utils.py
│   │   ├── logger.py
│   │   ├── validation_metrics.py
│   │   ├── lr_schedulers.py
│   │   ├── checkpointing.py
│   │   ├── superimposition.py
│   │   ├── tensor_utils.py
│   │   └── feats_rna.py
│   ├── faster_alphafold
│   │   ├── __init__.py
│   │   ├── libths_faster_alphafold.so
│   │   ├── faster_alphafold_config.py
│   │   └── faster_alphafold.py
│   ├── __init__.py
│   ├── loss
│   │   └── __init__.py
│   └── np
│       ├── __init__.py
│       └── relax
│           ├── __init__.py
│           ├── utils.py
│           ├── relax.py
│           └── cleanup.py
├── example_data
│   ├── filters
│   │   ├── rna_filter.txt
│   │   ├── complex_filter.txt
│   │   └── protein_filter.txt
│   ├── fasta
│   │   ├── 1a4t_B.fasta
│   │   ├── 1biv_B.fasta
│   │   └── 4P9R_A.fasta
│   ├── features
│   │   ├── 1A4T
│   │   │   └── features.pth
│   │   ├── 1MNB
│   │   │   └── features.pth
│   │   ├── 2jpw_A
│   │   │   └── features.pkl
│   │   ├── 2mvi_A
│   │   │   └── features.pkl
│   │   └── 4P9R_A
│   │       └── features.pkl
│   └── scripts
│       ├── infer_complex.sh
│       ├── infer_rna.sh
│       ├── infer_protein.sh
│       ├── train_complex.sh
│       └── train_protein.sh
├── img
│   ├── cases.png
│   └── logo.png
├── scripts
│   ├── vars.sh
│   ├── activate_conda_env.sh
│   ├── run_unit_tests.sh
│   ├── install_rmsa_petfold.sh
│   ├── install_hh_suite.sh
│   ├── install_third_party_dependencies.sh
│   ├── utils.py
│   ├── rna_extract_pkl_from_fas.py
│   └── extract_pkl_from_fas.py
├── OpenComplex I - tech report.pdf
├── deepspeed_config.json
├── environment.yml
├── compute_metrics.py
├── lib
│   └── openmm.patch
├── .gitignore
├── setup.py
└── README.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/test_loss_xyz.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/opencomplex/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/opencomplex/config/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/opencomplex/model/sm/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/opencomplex/resources/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/opencomplex/data/tools/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/opencomplex/utils/kernel/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/example_data/filters/rna_filter.txt:
--------------------------------------------------------------------------------
1 | 4P9R_A
--------------------------------------------------------------------------------
/opencomplex/faster_alphafold/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/example_data/filters/complex_filter.txt:
--------------------------------------------------------------------------------
1 | 1A4T
2 | 1MNB
--------------------------------------------------------------------------------
/example_data/filters/protein_filter.txt:
--------------------------------------------------------------------------------
1 | 2jpw_A
2 | 2mvi_A
--------------------------------------------------------------------------------
/example_data/fasta/1a4t_B.fasta:
--------------------------------------------------------------------------------
1 | >1a4t_B
2 | NAKTRRHERRRKLAIERDT
--------------------------------------------------------------------------------
/example_data/fasta/1biv_B.fasta:
--------------------------------------------------------------------------------
1 | >1biv_B
2 | SGPRPRGTRGKGRRIRR
--------------------------------------------------------------------------------
/img/cases.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ocx-lab/OpenComplex/HEAD/img/cases.png
--------------------------------------------------------------------------------
/img/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ocx-lab/OpenComplex/HEAD/img/logo.png
--------------------------------------------------------------------------------
/scripts/vars.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | ENV_NAME=opencomplex_venv
4 | CONDA_PATH=`conda info --base`
5 | 
--------------------------------------------------------------------------------
/tests/test_data/features.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ocx-lab/OpenComplex/HEAD/tests/test_data/features.pkl
--------------------------------------------------------------------------------
/OpenComplex I - tech report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ocx-lab/OpenComplex/HEAD/OpenComplex I - tech report.pdf
--------------------------------------------------------------------------------
/tests/test_data/short.fasta:
--------------------------------------------------------------------------------
1 | >query
2 | MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH
3 | 
--------------------------------------------------------------------------------
/example_data/features/1A4T/features.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ocx-lab/OpenComplex/HEAD/example_data/features/1A4T/features.pth
--------------------------------------------------------------------------------
/example_data/features/1MNB/features.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ocx-lab/OpenComplex/HEAD/example_data/features/1MNB/features.pth
--------------------------------------------------------------------------------
/example_data/features/2jpw_A/features.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ocx-lab/OpenComplex/HEAD/example_data/features/2jpw_A/features.pkl
--------------------------------------------------------------------------------
/example_data/features/2mvi_A/features.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ocx-lab/OpenComplex/HEAD/example_data/features/2mvi_A/features.pkl
--------------------------------------------------------------------------------
/example_data/features/4P9R_A/features.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ocx-lab/OpenComplex/HEAD/example_data/features/4P9R_A/features.pkl
--------------------------------------------------------------------------------
/tests/test_data/alphafold_feature_dict.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ocx-lab/OpenComplex/HEAD/tests/test_data/alphafold_feature_dict.pickle
--------------------------------------------------------------------------------
/scripts/activate_conda_env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | source scripts/vars.sh
4 | 
5 | source $CONDA_PATH/etc/profile.d/conda.sh
6 | conda activate $ENV_NAME
7 | 
--------------------------------------------------------------------------------
/opencomplex/faster_alphafold/libths_faster_alphafold.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ocx-lab/OpenComplex/HEAD/opencomplex/faster_alphafold/libths_faster_alphafold.so
--------------------------------------------------------------------------------
/opencomplex/__init__.py:
--------------------------------------------------------------------------------
1 | from . import model
2 | from . import utils
3 | from . import np
4 | from . import resources
5 | from . import data
6 | 
7 | __all__ = ["model", "utils", "np", "data", "resources"]
8 | 
--------------------------------------------------------------------------------
/scripts/run_unit_tests.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | export CUDA_VISIBLE_DEVICES="0"
4 | 
5 | python3 -m unittest "$@" || \
6 |     echo -e "\nTest(s) failed. Make sure you've installed all Python dependencies."
7 | 
--------------------------------------------------------------------------------
/scripts/install_rmsa_petfold.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | git clone https://github.com/kad-ecoli/rMSA opencomplex/resources/RNA \
4 |     && cd opencomplex/resources/RNA \
5 |     && ./database/script/update.sh # Download RNAcentral and nt
--------------------------------------------------------------------------------
/example_data/fasta/4P9R_A.fasta:
--------------------------------------------------------------------------------
1 | >4P9R_1|Chain A|RNA (189-MER)|Didymium iridis (5793)
2 | CAUCCGGUAUCCCAAGACAAUCUUCGGGUUGGGUUGGGAAGUAUCAUGGCUAAUCACCAUGAUGCAAUCGGGUUGAACACUUAAUUGGGUUAAAACGGUGGGGGACGAUCCCGUAACAUCCGUCCUAACGGCGACAGACUGCACGGCCCUGCCUCUUAGGUGUGUCCAAUGAACAGUCGUUCCGAAAGGAAG
3 | 
--------------------------------------------------------------------------------
/opencomplex/utils/kernel/csrc/compat.h:
--------------------------------------------------------------------------------
1 | // modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
2 | 
3 | #ifndef TORCH_CHECK
4 | #define TORCH_CHECK AT_CHECK
5 | #endif
6 | 
7 | #ifdef VERSION_GE_1_3
8 | #define DATA_PTR data_ptr
9 | #else
10 | #define DATA_PTR data
11 | #endif
12 | 
--------------------------------------------------------------------------------
/scripts/install_hh_suite.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | git clone --branch v3.3.0 https://github.com/soedinglab/hh-suite.git /tmp/hh-suite \
4 |     && mkdir /tmp/hh-suite/build \
5 |     && pushd /tmp/hh-suite/build \
6 |     && cmake -DCMAKE_INSTALL_PREFIX=/opt/hhsuite .. \
7 |     && make -j 4 && make install \
8 |     && ln -sf /opt/hhsuite/bin/* /usr/bin \
9 |     && popd \
10 |     && rm -rf /tmp/hh-suite
11 | 
--------------------------------------------------------------------------------
/opencomplex/loss/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import importlib as importlib
4 | 
5 | _files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
6 | __all__ = [
7 |     os.path.basename(f)[:-3]
8 |     for f in _files
9 |     if os.path.isfile(f) and not f.endswith("__init__.py")
10 | ]
11 | _modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
12 | for _m in _modules:
13 |     globals()[_m[0]] = _m[1]
14 | 
15 | # Avoid needlessly cluttering the global namespace
16 | del _files, _m, _modules
17 | 
--------------------------------------------------------------------------------
/opencomplex/model/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import importlib as importlib
4 | 
5 | _files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
6 | __all__ = [
7 |     os.path.basename(f)[:-3]
8 |     for f in _files
9 |     if os.path.isfile(f) and not f.endswith("__init__.py")
10 | ]
11 | _modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
12 | for _m in _modules:
13 |     globals()[_m[0]] = _m[1]
14 | 
15 | # Avoid needlessly cluttering the global namespace
16 | del _files, _m, _modules
17 | 
--------------------------------------------------------------------------------
/opencomplex/np/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import importlib as importlib
4 | 
5 | _files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
6 | __all__ = [
7 |     os.path.basename(f)[:-3]
8 |     for f in _files
9 |     if os.path.isfile(f) and not f.endswith("__init__.py")
10 | ]
11 | _modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
12 | for _m in _modules:
13 |     globals()[_m[0]] = _m[1]
14 | 
15 | # Avoid needlessly cluttering the global namespace
16 | del _files, _m, _modules
17 | 
--------------------------------------------------------------------------------
/opencomplex/np/relax/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import importlib as importlib
4 | 
5 | _files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
6 | __all__ = [
7 |     os.path.basename(f)[:-3]
8 |     for f in _files
9 |     if os.path.isfile(f) and not f.endswith("__init__.py")
10 | ]
11 | _modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
12 | for _m in _modules:
13 |     globals()[_m[0]] = _m[1]
14 | 
15 | # Avoid needlessly cluttering the global namespace
16 | del _files, _m, _modules
17 | 
--------------------------------------------------------------------------------
/deepspeed_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "fp16": {
3 |         "enabled": false,
4 |         "min_loss_scale": 1
5 |     },
6 |     "amp": {
7 |         "enabled": false,
8 |         "opt_level": "O2"
9 |     },
10 |     "bfloat16": {
11 |         "enabled": true
12 |     },
13 |     "zero_optimization": {
14 |         "stage": 2,
15 |         "cpu_offload": true,
16 |         "contiguous_gradients": true
17 |     },
18 |     "activation_checkpointing": {
19 |         "partition_activations": true,
20 |         "cpu_checkpointing": false,
21 |         "profile": false
22 |     },
23 |     "gradient_clipping": 0.1
24 | }
25 | 
--------------------------------------------------------------------------------
/opencomplex/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import importlib as importlib
4 | 
5 | from . import kernel
6 | 
7 | _files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
8 | __all__ = [
9 |     os.path.basename(f)[:-3]
10 |     for f in _files
11 |     if os.path.isfile(f) and not f.endswith("__init__.py")
12 | ] + ["kernel"]
13 | _modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
14 | for _m in _modules:
15 |     globals()[_m[0]] = _m[1]
16 | 
17 | # Avoid needlessly cluttering the global namespace
18 | del _files, _m, _modules
19 | 
--------------------------------------------------------------------------------
/opencomplex/faster_alphafold/faster_alphafold_config.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import ml_collections as mlc
3 | 
4 | 
5 | faster_alphafold_config = mlc.ConfigDict({
6 |     'layer_norm': True,
7 |     'softmax': False,
8 |     'attention': True,
9 |     'outer_product_mean': True,
10 |     'triangle_multiplicative_update': True,
11 | }) if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else mlc.ConfigDict({
12 |     'layer_norm': False,
13 |     'softmax': False,
14 |     'attention': False,
15 |     'outer_product_mean': False,
16 |     'triangle_multiplicative_update': False,
17 | })
18 | 
--------------------------------------------------------------------------------
/opencomplex/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.utilities import rank_zero_info
2 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
3 | 
4 | class EarlyStoppingVerbose(EarlyStopping):
5 |     """
6 |     The default EarlyStopping callback's verbose mode is too verbose.
7 |     This class outputs a message only when it's getting ready to stop.
8 |     """
9 |     def _evaluate_stopping_criteria(self, *args, **kwargs):
10 |         should_stop, reason = super()._evaluate_stopping_criteria(*args, **kwargs)
11 |         if(should_stop):
12 |             rank_zero_info(f"{reason}\n")
13 | 
14 |         return should_stop, reason
15 | 
--------------------------------------------------------------------------------
/opencomplex/utils/seed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import random
4 | import numpy as np
5 | from pytorch_lightning.utilities.seed import seed_everything
6 | 
7 | from opencomplex.utils.suppress_output import SuppressLogging
8 | 
9 | 
10 | def seed_globally(seed=None):
11 |     if("PL_GLOBAL_SEED" not in os.environ):
12 |         if(seed is None):
13 |             seed = random.randint(0, np.iinfo(np.uint32).max)
14 |         os.environ["PL_GLOBAL_SEED"] = str(seed)
15 |         logging.info(f'os.environ["PL_GLOBAL_SEED"] set to {seed}')
16 | 
17 |     # seed_everything is a bit log-happy
18 |     with SuppressLogging(logging.INFO):
19 |         seed_everything(seed=None)
20 | 
--------------------------------------------------------------------------------
/opencomplex/utils/suppress_output.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | 
4 | 
5 | class SuppressStdout:
6 |     def __enter__(self):
7 |         self.stdout = sys.stdout
8 |         dev_null = open("/dev/null", "w")
9 |         sys.stdout = dev_null
10 | 
11 |     def __exit__(self, typ, value, traceback):
12 |         fp = sys.stdout
13 |         sys.stdout = self.stdout
14 |         fp.close()
15 | 
16 | 
17 | class SuppressLogging:
18 |     def __init__(self, level):
19 |         self.level = level
20 | 
21 |     def __enter__(self):
22 |         logging.disable(self.level)
23 | 
24 |     def __exit__(self, typ, value, traceback):
25 |         logging.disable(logging.NOTSET)
26 | 
27 | 
--------------------------------------------------------------------------------
/tests/config.py:
--------------------------------------------------------------------------------
1 | import ml_collections as mlc
2 | 
3 | consts = mlc.ConfigDict(
4 |     {
5 |         "batch_size": 2,
6 |         "n_res": 11,
7 |         "n_seq": 13,
8 |         "n_templ": 3,
9 |         "n_extra": 17,
10 |         "eps": 5e-4,
11 |         # For compatibility with DeepMind's pretrained weights, it's easiest for
12 |         # everyone if these take their real values.
13 |         "c_m": 256,
14 |         "c_z": 128,
15 |         "c_s": 384,
16 |         "c_t": 64,
17 |         "c_e": 64,
18 |     }
19 | )
20 | 
21 | config = mlc.ConfigDict(
22 |     {
23 |         "data": {
24 |             "common": {
25 |                 "masked_msa": {
26 |                     "profile_prob": 0.1,
27 |                     "same_prob": 0.1,
28 |                     "uniform_prob": 0.1,
29 |                 },
30 |             }
31 |         }
32 |     }
33 | )
34 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: opencomplex_venv
2 | channels:
3 |   - conda-forge
4 |   - bioconda
5 |   - pytorch
6 | dependencies:
7 |   - conda-forge::python=3.9
8 |   - conda-forge::setuptools=59.5.0
9 |   - conda-forge::pip
10 |   - conda-forge::openmm=7.5.1
11 |   - conda-forge::pdbfixer
12 |   - conda-forge::cudatoolkit==11.3.*
13 |   - bioconda::hmmer==3.3.2
14 |   - bioconda::hhsuite==3.3.0
15 |   - bioconda::kalign2==2.04
16 |   - pytorch::pytorch=1.12.*
17 |   - pip:
18 |       - biopython==1.79
19 |       - deepspeed==0.5.10
20 |       - dm-tree==0.1.8
21 |       - ml-collections==0.1.0
22 |       - numpy==1.21.2
23 |       - PyYAML==5.4.1
24 |       - requests==2.26.0
25 |       - scipy==1.7.1
26 |       - tqdm==4.62.2
27 |       - typing-extensions==3.10.0.2
28 |       - pytorch_lightning==1.5.10
29 |       - wandb==0.12.21
30 |       - einops==0.6.0
31 |       - git+https://github.com/NVIDIA/dllogger.git
32 | 
--------------------------------------------------------------------------------
/opencomplex/data/errors.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | # Copyright 2021 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #      http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """General-purpose errors used throughout the data pipeline"""
17 | class Error(Exception):
18 |     """Base class for exceptions."""
19 | 
20 | 
21 | class MultipleChainsError(Error):
22 |     """An error indicating that multiple chains were found for a given ID."""
23 | 
--------------------------------------------------------------------------------
/example_data/scripts/infer_complex.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | source scripts/activate_conda_env.sh
4 | 
5 | python3 run_pretrained_opencomplex.py \
6 |     --features_dir example_data/features `# dir of generated features` \
7 |     --target_list_file example_data/filters/complex_filter.txt `# filter of target lists` \
8 |     --output_dir output/infer_result `# output directory` \
9 |     --use_gpu `# use gpu inference` \
10 |     --num_workers 1 `# number of parallel processes` \
11 |     --param_path /path/to/ckpt `# ckpt path` \
12 |     --config_preset "mix" `# config presets as in config.py` \
13 |     --complex_type "mix" `# protein, RNA, or mix (protein-RNA complex)` \
14 |     --skip_relaxation `# skip amber relaxation` \
15 |     --overwrite `# overwrite existing result`
16 | 
--------------------------------------------------------------------------------
/example_data/scripts/infer_rna.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | source scripts/activate_conda_env.sh
4 | 
5 | python3 run_pretrained_opencomplex.py \
6 |     --features_dir example_data/features `# dir of generated features` \
7 |     --target_list_file example_data/filters/rna_filter.txt `# filter of target lists` \
8 |     --output_dir output/infer_result `# output directory` \
9 |     --use_gpu `# use gpu inference` \
10 |     --num_workers 1 `# number of parallel processes` \
11 |     --param_path /path/to/ckpt `# ckpt path` \
12 |     --config_preset "RNA" `# config presets as in config.py` \
13 |     --complex_type "RNA" `# protein, RNA, or mix (protein-RNA complex)` \
14 |     --skip_relaxation `# skip amber relaxation` \
15 |     --overwrite `# overwrite existing result`
--------------------------------------------------------------------------------
/example_data/scripts/infer_protein.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | source scripts/activate_conda_env.sh
4 | 
5 | python3 run_pretrained_opencomplex.py \
6 |     --features_dir example_data/features `# dir of generated features` \
7 |     --target_list_file example_data/filters/protein_filter.txt `# filter of target lists` \
8 |     --output_dir output/infer_result `# output directory` \
9 |     --use_gpu `# use gpu inference` \
10 |     --num_workers 1 `# number of parallel processes` \
11 |     --param_path /path/to/ckpt `# ckpt path` \
12 |     --config_preset "initial_training" `# config presets as in config.py` \
13 |     --complex_type "protein" `# protein, RNA, or mix (protein-RNA complex)` \
14 |     --skip_relaxation `# skip amber relaxation` \
15 |     --overwrite `# overwrite existing result`
16 | 
--------------------------------------------------------------------------------
/opencomplex/utils/argparse.py:
--------------------------------------------------------------------------------
1 | from argparse import HelpFormatter
2 | from operator import attrgetter
3 | 
4 | class ArgparseAlphabetizer(HelpFormatter):
5 |     """
6 |     Sorts the optional arguments of an argparse parser alphabetically
7 |     """
8 | 
9 |     @staticmethod
10 |     def sort_actions(actions):
11 |         return sorted(actions, key=attrgetter("option_strings"))
12 | 
13 |     # Formats the help message
14 |     def add_arguments(self, actions):
15 |         actions = ArgparseAlphabetizer.sort_actions(actions)
16 |         super(ArgparseAlphabetizer, self).add_arguments(actions)
17 | 
18 |     # Formats the usage message
19 |     def add_usage(self, usage, actions, groups, prefix=None):
20 |         actions = ArgparseAlphabetizer.sort_actions(actions)
21 |         args = usage, actions, groups, prefix
22 |         super(ArgparseAlphabetizer, self).add_usage(*args)
23 | 
24 | 
25 | def remove_arguments(parser, args):
26 |     for arg in args:
27 |         for action in parser._actions:
28 |             opts = vars(action)["option_strings"]
29 |             if(arg in opts):
30 |                 parser._handle_conflict_resolve(None, [(arg, action)])
31 | 
--------------------------------------------------------------------------------
/scripts/install_third_party_dependencies.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source scripts/vars.sh
3 | $CONDA_PATH/bin/python3 -m pip install nvidia-pyindex
4 | conda install -y mamba -c conda-forge -n base
5 | 
6 | conda create -y -n ${ENV_NAME} python=3.9
7 | mamba env update --file environment.yml
8 | 
9 | source scripts/activate_conda_env.sh
10 | 
11 | echo "Attempting to install FlashAttention"
12 | pip install git+https://github.com/HazyResearch/flash-attention.git@5b838a8bef78186196244a4156ec35bbb58c337d && echo "Installation successful"
13 | 
14 | # Install DeepMind's OpenMM patch
15 | OPENCOMPLEX_DIR=$PWD
16 | pushd $CONDA_PATH/envs/$ENV_NAME/lib/python3.9/site-packages/ \
17 |     && patch -p0 < $OPENCOMPLEX_DIR/lib/openmm.patch \
18 |     && popd
19 | 
20 | # Download folding resources
21 | wget --no-check-certificate -P opencomplex/resources \
22 |     https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
23 | 
24 | # Certain tests need access to this file
25 | mkdir -p tests/test_data/alphafold/common
26 | ln -rs opencomplex/resources/stereo_chemical_props.txt tests/test_data/alphafold/common
27 | 
--------------------------------------------------------------------------------
/opencomplex/model/sm/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 BAAI
2 | # Copyright 2021 AlQuraishi Laboratory
3 | # Copyright 2021 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #      http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | from opencomplex.model.sm import structure_module_protein, structure_module_rna, structure_module_xyz
18 | from opencomplex.utils.complex_utils import ComplexType
19 | 
20 | 
21 | def create_structure_module(complex_type, *args, **kwargs):
22 |     if complex_type == ComplexType.PROTEIN:
23 |         sm = structure_module_protein.StructureModuleProtein
24 |     elif complex_type == ComplexType.RNA:
25 |         sm = structure_module_rna.StructureModuleRNA
26 |     elif complex_type == ComplexType.MIX:
27 |         sm = structure_module_xyz.StructureModuleXYZ
28 |     else:
29 |         raise ValueError("wrong complex type")
30 | 
31 |     return sm(*args, **kwargs)
32 | 
--------------------------------------------------------------------------------
/opencomplex/utils/kernel/csrc/softmax_cuda.cpp:
--------------------------------------------------------------------------------
1 | // Copyright 2021 AlQuraishi Laboratory
2 | //
3 | // Licensed under the Apache License, Version 2.0 (the "License");
4 | // you may not use this file except in compliance with the License.
5 | // You may obtain a copy of the License at
6 | //
7 | //      http://www.apache.org/licenses/LICENSE-2.0
8 | //
9 | // Unless required by applicable law or agreed to in writing, software
10 | // distributed under the License is distributed on an "AS IS" BASIS,
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | // See the License for the specific language governing permissions and
13 | // limitations under the License.
14 | 
15 | // modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp
16 | 
17 | #include <torch/extension.h>
18 | 
19 | void attn_softmax_inplace_forward_(
20 |     at::Tensor input,
21 |     long long rows, int cols
22 | );
23 | void attn_softmax_inplace_backward_(
24 |     at::Tensor output,
25 |     at::Tensor d_ov,
26 |     at::Tensor values,
27 |     long long rows,
28 |     int cols_output,
29 |     int cols_values
30 | );
31 | 
32 | 
33 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
34 |     m.def(
35 |         "forward_",
36 |         &attn_softmax_inplace_forward_,
37 |         "Softmax forward (CUDA)"
38 |     );
39 |     m.def(
40 |         "backward_",
41 |         &attn_softmax_inplace_backward_,
42 |         "Softmax backward (CUDA)"
43 |     );
44 | }
45 | 
--------------------------------------------------------------------------------
/example_data/scripts/train_complex.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | source scripts/activate_conda_env.sh
4 | 
5 | python3 train_opencomplex.py \
6 |     --config_preset "mix" `# config presets defined in config.py` \
7 |     --complex_type "mix" `# protein, RNA or mix (protein-RNA complex)` \
8 |     --train_data_dir example_data/mmcif `# ground truth directory` \
9 |     --train_feature_dir example_data/features `# features of training sample` \
10 |     --train_filter_path example_data/filters/complex_filter.txt `# optional filter of training sample` \
11 |     --val_data_dir example_data/mmcif `# optional ground truth directory of validation sample` \
12 |     --val_feature_dir example_data/features `# optional features of validation sample` \
13 |     --val_filter_path example_data/filters/complex_filter.txt `# optional filter of validation sample` \
14 |     --output_dir output/ `# output directory of checkpoints` \
15 |     --precision 32 `# bf16 has better speed but may slightly lower accuracy` \
16 |     --gpus 2 \
17 |     --replace_sampler_ddp=True \
18 |     --train_epoch_len 100 \
19 |     --max_epochs 10 \
20 |     --seed 4242022 `# in multi-gpu settings, the seed must be specified` \
21 |     --checkpoint_every_epoch \
22 |     --log_lr \
23 |     --wandb
24 | 
--------------------------------------------------------------------------------
/example_data/scripts/train_protein.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | source scripts/activate_conda_env.sh
4 | 
5 | python3 train_opencomplex.py \
6 |     --config_preset "initial_training" `# config presets defined in config.py` \
7 |     --complex_type "protein" `# protein, RNA or mix (protein-RNA complex)` \
8 |     --train_data_dir example_data/mmcif `# ground truth directory` \
9 |     --train_feature_dir example_data/features `# features of training sample` \
10 |     --train_filter_path example_data/filters/protein_filter.txt `# optional filter of training sample` \
11 |     --val_data_dir example_data/mmcif `# optional ground truth directory of validation sample` \
12 |     --val_feature_dir example_data/features `# optional features of validation sample` \
13 |     --val_filter_path example_data/filters/protein_filter.txt `# optional filter of validation sample` \
14 |     --output_dir output/ `# output directory of checkpoints` \
15 |     --precision 32 `# bf16 has better speed but may slightly lower accuracy` \
16 |     --gpus 2 \
17 |     --replace_sampler_ddp=True \
18 |     --train_epoch_len 100 \
19 |     --max_epochs 10 \
20 |     --seed 4242022 `# in multi-gpu settings, the seed must be specified` \
21 |     --checkpoint_every_epoch \
22 |     --log_lr \
23 |     --wandb
24 | 
--------------------------------------------------------------------------------
/opencomplex/data/tools/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | # Copyright 2021 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #      http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Common utilities for data pipeline tools."""
17 | import contextlib
18 | import datetime
19 | import logging
20 | import shutil
21 | import tempfile
22 | import time
23 | from typing import Optional
24 | 
25 | 
26 | @contextlib.contextmanager
27 | def tmpdir_manager(base_dir: Optional[str] = None):
28 |     """Context manager that deletes a temporary directory on exit."""
29 |     tmpdir = tempfile.mkdtemp(dir=base_dir)
30 |     try:
31 |         yield tmpdir
32 |     finally:
33 |         shutil.rmtree(tmpdir, ignore_errors=True)
34 | 
35 | 
36 | @contextlib.contextmanager
37 | def timing(msg: str):
38 |     logging.info("Started %s", msg)
39 |     tic = time.perf_counter()
40 |     yield
41 |     toc = time.perf_counter()
42 |     logging.info("Finished %s in %.3f seconds", msg, toc - tic)
43 | 
44 | 
45 | def to_date(s: str):
46 |     return datetime.datetime(
47 |         year=int(s[:4]), month=int(s[5:7]), day=int(s[8:10])
48 |     )
49 | 
--------------------------------------------------------------------------------
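A minimal sketch of how the two context managers above compose — timing a piece of work inside a scratch directory that cleans itself up. This example is not part of the repository; only the import path is taken from the tree above.

```python
# Hypothetical usage sketch, not part of the repository.
import logging

from opencomplex.data.tools.utils import timing, tmpdir_manager

logging.basicConfig(level=logging.INFO)

# timing() logs the start and wall-clock duration of the block;
# tmpdir_manager() removes the directory on exit, even if the body raises.
with timing("example preprocessing step"):
    with tmpdir_manager(base_dir="/tmp") as scratch_dir:
        print("working in", scratch_dir)
```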
/tests/test_primitives.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #      http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import torch
16 | import numpy as np
17 | import unittest
18 | 
19 | from opencomplex.model.primitives import (
20 |     Attention,
21 | )
22 | from tests.config import consts
23 | 
24 | 
25 | class TestLMA(unittest.TestCase):
26 |     def test_lma_vs_attention(self):
27 |         batch_size = consts.batch_size
28 |         c_hidden = 32
29 |         n = 2**12
30 |         no_heads = 4
31 | 
32 |         q = torch.rand(batch_size, n, c_hidden).cuda()
33 |         kv = torch.rand(batch_size, n, c_hidden).cuda()
34 | 
35 |         bias = [torch.rand(no_heads, 1, n)]
36 |         bias = [b.cuda() for b in bias]
37 | 
38 |         gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
39 |         o_fill = torch.rand(c_hidden, c_hidden * no_heads)
40 | 
41 |         a = Attention(
42 |             c_hidden, c_hidden, c_hidden, c_hidden, no_heads
43 |         ).cuda()
44 | 
45 |         with torch.no_grad():
46 |             l = a(q, kv, biases=bias, use_lma=True)
47 |             real = a(q, kv, biases=bias)
48 | 
49 |         self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
50 | 
51 | 
52 | if __name__ == "__main__":
53 |     unittest.main()
54 | 
--------------------------------------------------------------------------------
/compute_metrics.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing as mp
3 | 
4 | from functools import partial
5 | from tqdm import tqdm
6 | 
7 | from opencomplex.utils import metric_tool
8 | 
9 | 
10 | def main(target_file, args):
11 |     metric_tool.compute_metric(
12 |         args.native_dir,
13 |         target_file,
14 |         mode="multimer" if args.multimer else "monomer",
15 |         complex_type=args.complex_type,
16 |     )
17 | 
18 | 
19 | if __name__ == '__main__':
20 |     parser = argparse.ArgumentParser()
21 |     parser.add_argument(
22 |         "--prediction_dir", type=str,
23 |         help="""Directory to predicted pdb files."""
24 |     )
25 |     parser.add_argument(
26 |         "--native_dir", type=str,
27 |         help="""Directory to mmcif files."""
28 |     )
29 |     parser.add_argument(
30 |         "--target_list_file", type=str,
31 |         help="""File path to target list."""
32 |     )
33 |     parser.add_argument(
34 |         "--complex_type", type=str,
35 |         default="protein", choices=["protein", "RNA", "mix"],
36 |         help="""Complex type of predictions."""
37 |     )
38 |     parser.add_argument(
39 |         "--multimer", action="store_true",
40 |         help="""If the prediction has multiple chains."""
41 |     )
42 |     parser.add_argument(
43 |         "--num_workers", type=int, default=8,
44 |         help="""Number of workers to compute metrics in parallel."""
45 |     )
46 | 
47 |     args = parser.parse_args()
48 | 
49 | 
50 | 
51 |     target_list = metric_tool.get_prediction_list(args.prediction_dir, args.target_list_file)
52 |     worker = partial(main, args=args)
53 |     with mp.Pool(args.num_workers) as p:
54 |         list(tqdm(p.imap_unordered(worker, target_list), total=len(target_list)))
55 |     p.join()
56 | 
57 |     metric_tool.summarize_metrics(args.prediction_dir, target_list_file=args.target_list_file)
--------------------------------------------------------------------------------
/lib/openmm.patch:
--------------------------------------------------------------------------------
1 | Index: simtk/openmm/app/topology.py
2 | ===================================================================
3 | --- simtk.orig/openmm/app/topology.py
4 | +++ simtk/openmm/app/topology.py
5 | @@ -356,19 +356,35 @@
6 |          def isCyx(res):
7 |              names = [atom.name for atom in res._atoms]
8 |              return 'SG' in names and 'HG' not in names
9 | +        # This function is used to prevent multiple di-sulfide bonds from being
10 | +        # assigned to a given atom. This is a DeepMind modification.
11 | +        def isDisulfideBonded(atom):
12 | +            for b in self._bonds:
13 | +                if (atom in b and b[0].name == 'SG' and
14 | +                    b[1].name == 'SG'):
15 | +                    return True
16 | +
17 | +            return False
18 |  
19 |          cyx = [res for res in self.residues() if res.name == 'CYS' and isCyx(res)]
20 |          atomNames = [[atom.name for atom in res._atoms] for res in cyx]
21 |          for i in range(len(cyx)):
22 |              sg1 = cyx[i]._atoms[atomNames[i].index('SG')]
23 |              pos1 = positions[sg1.index]
24 | +            candidate_distance, candidate_atom = 0.3*nanometers, None
25 |              for j in range(i):
26 |                  sg2 = cyx[j]._atoms[atomNames[j].index('SG')]
27 |                  pos2 = positions[sg2.index]
28 |                  delta = [x-y for (x,y) in zip(pos1, pos2)]
29 |                  distance = sqrt(delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2])
30 | -                if distance < 0.3*nanometers:
31 | -                    self.addBond(sg1, sg2)
32 | +                if distance < candidate_distance and not isDisulfideBonded(sg2):
33 | +                    candidate_distance = distance
34 | +                    candidate_atom = sg2
35 | +            # Assign bond to closest pair.
36 | +            if candidate_atom:
37 | +                self.addBond(sg1, candidate_atom)
38 | +
39 | +
40 |  
41 |  class Chain(object):
42 |      """A Chain object represents a chain within a Topology."""
43 | 
44 | 
--------------------------------------------------------------------------------
/opencomplex/model/dropout.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #      http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | import torch
17 | import torch.nn as nn
18 | from functools import partialmethod
19 | from typing import Union, List
20 | 
21 | 
22 | class Dropout(nn.Module):
23 |     """
24 |     Implementation of dropout with the ability to share the dropout mask
25 |     along a particular dimension.
26 | 
27 |     If not in training mode, this module computes the identity function.
28 |     """
29 | 
30 |     def __init__(self, r: float, batch_dim: Union[int, List[int]]):
31 |         """
32 |         Args:
33 |             r:
34 |                 Dropout rate
35 |             batch_dim:
36 |                 Dimension(s) along which the dropout mask is shared
37 |         """
38 |         super(Dropout, self).__init__()
39 | 
40 |         self.r = r
41 |         if type(batch_dim) == int:
42 |             batch_dim = [batch_dim]
43 |         self.batch_dim = batch_dim
44 |         self.dropout = nn.Dropout(self.r)
45 | 
46 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
47 |         """
48 |         Args:
49 |             x:
50 |                 Tensor to which dropout is applied. Can have any shape
51 |                 compatible with self.batch_dim
52 |         """
53 |         shape = list(x.shape)
54 |         if self.batch_dim is not None:
55 |             for bd in self.batch_dim:
56 |                 shape[bd] = 1
57 |         mask = x.new_ones(shape)
58 |         mask = self.dropout(mask)
59 |         x *= mask
60 |         return x
61 | 
62 | 
63 | class DropoutRowwise(Dropout):
64 |     """
65 |     Convenience class for rowwise dropout as described in subsection
66 |     1.11.6.
67 |     """
68 | 
69 |     __init__ = partialmethod(Dropout.__init__, batch_dim=-3)
70 | 
71 | 
72 | class DropoutColumnwise(Dropout):
73 |     """
74 |     Convenience class for columnwise dropout as described in subsection
75 |     1.11.6.
76 |     """
77 | 
78 |     __init__ = partialmethod(Dropout.__init__, batch_dim=-2)
79 | 
--------------------------------------------------------------------------------
/opencomplex/utils/exponential_moving_average.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import copy
3 | import torch
4 | import torch.nn as nn
5 | 
6 | from opencomplex.utils.tensor_utils import tensor_tree_map
7 | 
8 | 
9 | class ExponentialMovingAverage:
10 |     """
11 |     Maintains moving averages of parameters with exponential decay
12 | 
13 |     At each step, the stored copy `copy` of each parameter `param` is
14 |     updated as follows:
15 | 
16 |         `copy = decay * copy + (1 - decay) * param`
17 | 
18 |     where `decay` is an attribute of the ExponentialMovingAverage object.
19 |     """
20 | 
21 |     def __init__(self, model: nn.Module, decay: float):
22 |         """
23 |         Args:
24 |             model:
25 |                 A torch.nn.Module whose parameters are to be tracked
26 |             decay:
27 |                 A value (usually close to 1.) by which updates are
28 |                 weighted as part of the above formula
29 |         """
30 |         super(ExponentialMovingAverage, self).__init__()
31 | 
32 |         clone_param = lambda t: t.clone().detach()
33 |         self.params = tensor_tree_map(clone_param, model.state_dict())
34 |         self.decay = decay
35 |         self.device = next(model.parameters()).device
36 | 
37 |     def to(self, device):
38 |         self.params = tensor_tree_map(lambda t: t.to(device), self.params)
39 |         self.device = device
40 | 
41 |     def _update_state_dict_(self, update, state_dict):
42 |         with torch.no_grad():
43 |             for k, v in update.items():
44 |                 stored = state_dict[k]
45 |                 if not isinstance(v, torch.Tensor):
46 |                     self._update_state_dict_(v, stored)
47 |                 else:
48 |                     diff = stored - v
49 |                     diff *= 1 - self.decay
50 |                     stored -= diff
51 | 
52 |     def update(self, model: torch.nn.Module) -> None:
53 |         """
54 |         Updates the stored parameters using the state dict of the provided
55 |         module. The module should have the same structure as that used to
56 |         initialize the ExponentialMovingAverage object.
57 |         """
58 |         self._update_state_dict_(model.state_dict(), self.params)
59 | 
60 |     def load_state_dict(self, state_dict: OrderedDict) -> None:
61 |         for k in state_dict["params"].keys():
62 |             self.params[k] = state_dict["params"][k].clone()
63 |         self.decay = state_dict["decay"]
64 | 
65 |     def state_dict(self) -> OrderedDict:
66 |         return OrderedDict(
67 |             {
68 |                 "params": self.params,
69 |                 "decay": self.decay,
70 |             }
71 |         )
72 | 
--------------------------------------------------------------------------------
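A minimal usage sketch for the EMA helper above (not part of the repository): track a toy module during training, then load the smoothed weights for evaluation. The module, optimizer, and decay value are illustrative.

```python
# Hypothetical usage sketch, not part of the repository.
import torch
import torch.nn as nn

from opencomplex.utils.exponential_moving_average import ExponentialMovingAverage

model = nn.Linear(8, 8)
ema = ExponentialMovingAverage(model, decay=0.999)
optimizer = torch.optim.Adam(model.parameters())

for _ in range(10):
    loss = model(torch.randn(4, 8)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)  # copy = decay * copy + (1 - decay) * param

# For evaluation, load the smoothed shadow weights back into the module.
model.load_state_dict(ema.params)
```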
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | #  Usually these files are written by a python script from a template
31 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | #   For a library or package, you might want to ignore these files since the code is
87 | #   intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | #   install all needed dependencies.
95 | #Pipfile.lock
96 | 
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 | 
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 | 
104 | # SageMath parsed files
105 | *.sage.py
106 | 
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 | 
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 | 
120 | # Rope project settings
121 | .ropeproject
122 | 
123 | # mkdocs documentation
124 | /site
125 | 
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 | 
131 | # Pyre type checker
132 | .pyre/
133 | 
134 | # pytype static type analyzer
135 | .pytype/
136 | 
137 | # Cython debug symbols
138 | cython_debug/
139 | 
140 | .vscode
141 | 
142 | 
143 | experiment/
144 | 
145 | output/
--------------------------------------------------------------------------------
/opencomplex/utils/complex_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 BAAI
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #      http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import enum
16 | import torch
17 | 
18 | class CaseInsensitiveEnumMeta(enum.EnumMeta):
19 |     def __getitem__(self, item):
20 |         if isinstance(item, str):
21 |             item = item.upper()
22 |         return super().__getitem__(item)
23 | 
24 | class ComplexType(enum.Enum, metaclass=CaseInsensitiveEnumMeta):
25 |     PROTEIN = 1
26 |     RNA = 2
27 |     MIX = 3
28 | 
29 | def determine_chain_type(complex_type, bio_id=None):
30 |     if complex_type != ComplexType.MIX:
31 |         return complex_type
32 | 
33 |     assert bio_id is not None
34 | 
35 |     if bio_id == 0:
36 |         return ComplexType.PROTEIN
37 |     else:
38 |         return ComplexType.RNA
39 | 
40 | def correct_rna_butype(butype):
41 |     if butype.numel() > 0 and torch.max(butype) > 7:
42 |         return butype - 20
43 |     return butype
44 | 
45 | def split_protein_rna_pos(bio_complex, complex_type=None):
46 |     device = bio_complex["butype"].device
47 | 
48 |     protein_pos = []
49 |     rna_pos = []
50 |     if complex_type is None:
51 |         complex_type = ComplexType.MIX if 'bio_id' in bio_complex else ComplexType.PROTEIN
52 | 
53 |     if complex_type == ComplexType.PROTEIN:
54 |         protein_pos = torch.arange(0, bio_complex['butype'].shape[-1], device=device)
55 |     elif complex_type == ComplexType.RNA:
56 |         rna_pos = torch.arange(0, bio_complex['butype'].shape[-1], device=device)
57 |     else:
58 |         protein_pos = torch.where(bio_complex['bio_id'] == 0)[-1]
59 |         rna_pos = torch.where(bio_complex['bio_id'] == 1)[-1]
60 | 
61 |     return protein_pos, rna_pos
62 | 
63 | 
64 | def complex_gather(protein_pos, rna_pos, protein_data, rna_data, dim):
65 |     if protein_pos is None or len(protein_pos) == 0:
66 |         return rna_data
67 |     elif rna_pos is None or len(rna_pos) == 0:
68 |         return protein_data
69 | 
70 |     n = protein_data.ndim
71 |     if dim < 0:
72 |         dim += n
73 |     shape = protein_data.shape
74 |     i = protein_data.new_zeros(shape[dim:], dtype=torch.bool)
75 |     i[protein_pos, ...] = True
76 |     for _ in range(dim):
77 |         i = i.unsqueeze(0)
78 |     i = i.tile(shape[:dim] + (1,)*(n-dim))
79 | 
80 |     return torch.where(i, protein_data, rna_data)
--------------------------------------------------------------------------------
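A toy demonstration of `complex_gather` above (not part of the repository): along `dim`, indices listed in `protein_pos` are filled from the protein tensor and the remaining indices from the RNA tensor.

```python
# Hypothetical demonstration, not part of the repository.
import torch

from opencomplex.utils.complex_utils import complex_gather

protein_data = torch.zeros(4, 2)
rna_data = torch.ones(4, 2)
protein_pos = torch.tensor([0, 2])
rna_pos = torch.tensor([1, 3])

merged = complex_gather(protein_pos, rna_pos, protein_data, rna_data, dim=0)
# rows 0 and 2 come from protein_data (zeros), rows 1 and 3 from rna_data (ones)
print(merged)
```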
/tests/test_kernels.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | import torch
4 | import unittest
5 | 
6 | from opencomplex.model.primitives import _attention
7 | from opencomplex.utils.kernel.attention_core import attention_core
8 | from tests.config import consts
9 | 
10 | 
11 | class TestAttentionCore(unittest.TestCase):
12 |     def test_attention_core_forward(self):
13 |         n_res = consts.n_res
14 |         h = consts.n_heads_extra_msa
15 |         n_seq = consts.n_extra
16 |         c = consts.c_e
17 |         dtype = torch.float32
18 | 
19 |         q = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
20 |         k = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
21 |         v = torch.rand([n_seq, h, n_res, c], dtype=dtype).cuda()
22 |         mask = torch.randint(0, 2, [n_seq, n_res]).cuda()
23 |         mask_bias = (1e9 * mask - 1)[..., None, None, :].to(dtype)
24 | 
25 |         out_repro = attention_core(q, k, v, mask_bias, None)
26 |         out_gt = _attention(q, k, v, [mask_bias])
27 | 
28 |         self.assertTrue(torch.max(torch.abs(out_repro - out_gt)) < consts.eps)
29 | 
30 |     def test_attention_core_backward(self):
31 |         n_res = consts.n_res
32 |         h = consts.n_heads_extra_msa
33 |         n_seq = consts.n_extra
34 |         c = consts.c_e
35 |         dtype = torch.float32
36 | 
37 |         q = torch.rand(
38 |             [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
39 |         ).cuda()
40 |         k = torch.rand(
41 |             [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
42 |         ).cuda()
43 |         v = torch.rand(
44 |             [n_seq, h, n_res, c], dtype=dtype, requires_grad=True
45 |         ).cuda()
46 |         mask = torch.randint(0, 2, [n_seq, n_res]).cuda()
47 |         mask_bias = (1e9 * mask - 1)[..., None, None, :].to(dtype)
48 | 
49 |         def clone(t):
50 |             t = t.clone()
51 |             if(t.requires_grad):
52 |                 t.retain_grad()
53 |             return t
54 | 
55 |         q_repro = clone(q)
56 |         k_repro = clone(k)
57 |         v_repro = clone(v)
58 |         out_repro = attention_core(
59 |             q_repro, k_repro, v_repro, mask_bias, None
60 |         )
61 | 
62 |         loss_repro = torch.mean(out_repro)
63 |         loss_repro.backward()
64 | 
65 |         q_gt = clone(q)
66 |         k_gt = clone(k)
67 |         v_gt = clone(v)
68 |         out_gt = _attention(
69 |             q_gt, k_gt, v_gt, [mask_bias]
70 |         )
71 | 
72 |         loss_gt = torch.mean(out_gt)
73 |         loss_gt.backward()
74 | 
75 |         pairs = zip([q_repro, k_repro, v_repro], [q_gt, k_gt, v_gt])
76 |         for t_repro, t_gt in pairs:
77 |             self.assertTrue(
78 |                 torch.max(torch.abs(t_repro.grad - t_gt.grad)) < consts.eps
79 |             )
80 | 
81 | 
82 | if __name__ == '__main__':
83 |     unittest.main()
84 | 
85 | 
--------------------------------------------------------------------------------
/opencomplex/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #      http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import operator
16 | import time
17 | 
18 | import dllogger as logger
19 | from dllogger import JSONStreamBackend, StdOutBackend, Verbosity
20 | import numpy as np
21 | from pytorch_lightning import Callback
22 | import torch.cuda.profiler as profiler
23 | 
24 | 
25 | def is_main_process():
26 |     return int(os.getenv("LOCAL_RANK", "0")) == 0
27 | 
28 | 
29 | class PerformanceLoggingCallback(Callback):
30 |     def __init__(self, log_file, global_batch_size, warmup_steps: int = 0, profile: bool = False):
31 |         logger.init(backends=[JSONStreamBackend(Verbosity.VERBOSE, log_file), StdOutBackend(Verbosity.VERBOSE)])
32 |         self.warmup_steps = warmup_steps
33 |         self.global_batch_size = global_batch_size
34 |         self.step = 0
35 |         self.profile = profile
36 |         self.timestamps = []
37 | 
38 |     def do_step(self):
39 |         self.step += 1
40 |         if self.profile and self.step == self.warmup_steps:
41 |             profiler.start()
42 |         if self.step > self.warmup_steps:
43 |             self.timestamps.append(time.time())
44 | 
45 |     def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
46 |         self.do_step()
47 | 
48 |     def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
49 |         self.do_step()
50 | 
51 |     def process_performance_stats(self, deltas):
52 |         def _round3(val):
53 |             return round(val, 3)
54 | 
55 |         throughput_imgps = _round3(self.global_batch_size / np.mean(deltas))
56 |         timestamps_ms = 1000 * deltas
57 |         stats = {
58 |             f"throughput": throughput_imgps,
59 |             f"latency_mean": _round3(timestamps_ms.mean()),
60 |         }
61 |         for level in [90, 95, 99]:
62 |             stats.update({f"latency_{level}": _round3(np.percentile(timestamps_ms, level))})
63 | 
64 |         return stats
65 | 
66 |     def _log(self):
67 |         if is_main_process():
68 |             diffs = list(map(operator.sub, self.timestamps[1:], self.timestamps[:-1]))
69 |             deltas = np.array(diffs)
70 |             stats = self.process_performance_stats(deltas)
71 |             logger.log(step=(), data=stats)
72 |             logger.flush()
73 | 
74 |     def on_train_end(self, trainer, pl_module):
75 |         if self.profile:
76 |             profiler.stop()
77 |         self._log()
78 | 
79 |     def on_epoch_end(self, trainer, pl_module):
80 |         self._log()
81 | 
--------------------------------------------------------------------------------
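A hedged sketch of wiring the callback above into a PyTorch Lightning run (not part of the repository); the file name, batch size, and trainer arguments are illustrative.

```python
# Hypothetical usage sketch, not part of the repository.
import pytorch_lightning as pl

from opencomplex.utils.logger import PerformanceLoggingCallback

perf_cb = PerformanceLoggingCallback(
    log_file="perf_log.json",   # dllogger JSON output
    global_batch_size=128,      # used to convert step time into throughput
    warmup_steps=20,            # steps excluded from the statistics
)
trainer = pl.Trainer(max_epochs=1, callbacks=[perf_cb])
# trainer.fit(model, datamodule=dm)  # model / datamodule defined elsewhere
```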
14 | import torch 15 | import numpy as np 16 | 17 | def drmsd(structure_1, structure_2, mask=None): 18 | def prep_d(structure): 19 | d = structure[..., :, None, :] - structure[..., None, :, :] 20 | d = d ** 2 21 | d = torch.sqrt(torch.sum(d, dim=-1)) 22 | return d 23 | 24 | d1 = prep_d(structure_1) 25 | d2 = prep_d(structure_2) 26 | 27 | drmsd = d1 - d2 28 | drmsd = drmsd ** 2 29 | if(mask is not None): 30 | drmsd = drmsd * (mask[..., None] * mask[..., None, :]) 31 | drmsd = torch.sum(drmsd, dim=(-1, -2)) 32 | n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1) 33 | 34 | v_ind = n>1 35 | drmsd = drmsd[v_ind] 36 | drmsd = 0 if len(drmsd) == 0 else torch.sqrt(drmsd / (n[v_ind] * (n[v_ind] - 1))).mean() 37 | 38 | return drmsd 39 | 40 | 41 | def drmsd_np(structure_1, structure_2, mask=None): 42 | structure_1 = torch.tensor(structure_1) 43 | structure_2 = torch.tensor(structure_2) 44 | if(mask is not None): 45 | mask = torch.tensor(mask) 46 | 47 | return drmsd(structure_1, structure_2, mask) 48 | 49 | 50 | def gdt(p1, p2, mask, cutoffs): 51 | n = torch.sum(mask, dim=-1) 52 | 53 | p1 = p1.float() 54 | p2 = p2.float() 55 | distances = torch.sqrt(torch.sum((p1 - p2)**2, dim=-1)) 56 | scores = [] 57 | for c in cutoffs: 58 | score = torch.sum((distances <= c) * mask, dim=-1) / n 59 | score = torch.mean(score) 60 | scores.append(score) 61 | 62 | return sum(scores) / len(scores) 63 | 64 | 65 | def gdt_ts(p1, p2, mask): 66 | return gdt(p1, p2, mask, [1., 2., 4., 8.]) 67 | 68 | 69 | def gdt_ha(p1, p2, mask): 70 | return gdt(p1, p2, mask, [0.5, 1., 2., 4.]) 71 | 72 | def tm_score(p1, p2, mask): 73 | p1 = p1.float() 74 | p2 = p2.float() 75 | distances = torch.sqrt(torch.sum((p1 - p2) ** 2, dim=-1)) 76 | l = len(distances) 77 | n = np.shape(p2)[0] 78 | clipped_n = max(n, 19) 79 | d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8 80 | score = [] 81 | for i in range(l): 82 | S = 1 / (1 + (distances[i] / d0) ** 2) 83 | score.append(S) 84 | tm = torch.sum(torch.tensor(score)) 85 | return tm / n 86 | 87 | def _rmsd(p1, p2, mask): 88 | p1 = p1.float() 89 | p2 = p2.float() 90 | n = np.shape(p2)[0] 91 | distances = torch.sqrt(torch.sum((p1 - p2)**2, dim=-1)) 92 | # RMSD is the root of the *mean squared* per-residue distance 93 | rmsd = torch.sqrt(torch.sum(distances ** 2) / n) 94 | return rmsd 95 | -------------------------------------------------------------------------------- /opencomplex/utils/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler): 5 | """ Implements the learning rate schedule defined in the AlphaFold 2 6 | supplement. A linear warmup is followed by a plateau at the maximum 7 | learning rate and then exponential decay. 8 | 9 | Note that the initial learning rate of the optimizer in question is 10 | ignored; use this class' base_lr parameter to specify the starting 11 | point of the warmup. 
12 | """ 13 | def __init__(self, 14 | optimizer, 15 | last_epoch: int = -1, 16 | verbose: bool = False, 17 | base_lr: float = 0., 18 | max_lr: float = 0.001, 19 | warmup_no_steps: int = 1000, 20 | start_decay_after_n_steps: int = 50000, 21 | decay_every_n_steps: int = 50000, 22 | decay_factor: float = 0.95, 23 | ): 24 | step_counts = { 25 | "warmup_no_steps": warmup_no_steps, 26 | "start_decay_after_n_steps": start_decay_after_n_steps, 27 | } 28 | 29 | for k,v in step_counts.items(): 30 | if(v < 0): 31 | raise ValueError(f"{k} must be nonnegative") 32 | 33 | if(warmup_no_steps > start_decay_after_n_steps): 34 | raise ValueError( 35 | "warmup_no_steps must not exceed start_decay_after_n_steps" 36 | ) 37 | 38 | self.optimizer = optimizer 39 | self.last_epoch = last_epoch 40 | self.verbose = verbose 41 | self.base_lr = base_lr 42 | self.max_lr = max_lr 43 | self.warmup_no_steps = warmup_no_steps 44 | self.start_decay_after_n_steps = start_decay_after_n_steps 45 | self.decay_every_n_steps = decay_every_n_steps 46 | self.decay_factor = decay_factor 47 | 48 | super(AlphaFoldLRScheduler, self).__init__( 49 | optimizer, 50 | last_epoch=last_epoch, 51 | verbose=verbose, 52 | ) 53 | 54 | def state_dict(self): 55 | state_dict = { 56 | k:v for k,v in self.__dict__.items() if k not in ["optimizer"] 57 | } 58 | 59 | return state_dict 60 | 61 | def load_state_dict(self, state_dict): 62 | self.__dict__.update(state_dict) 63 | 64 | def get_lr(self): 65 | if(not self._get_lr_called_within_step): 66 | raise RuntimeError( 67 | "To get the last learning rate computed by the scheduler, use " 68 | "get_last_lr()" 69 | ) 70 | 71 | step_no = self.last_epoch 72 | 73 | if(step_no <= self.warmup_no_steps): 74 | lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr 75 | elif(step_no > self.start_decay_after_n_steps): 76 | steps_since_decay = step_no - self.start_decay_after_n_steps 77 | exp = (steps_since_decay // self.decay_every_n_steps) + 1 78 | lr = self.max_lr * (self.decay_factor ** exp) 79 | else: # plateau 80 | lr = self.max_lr 81 | 82 | return [lr for group in self.optimizer.param_groups] 83 | -------------------------------------------------------------------------------- /opencomplex/utils/checkpointing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import importlib 15 | from typing import Any, Tuple, List, Callable, Optional 16 | 17 | deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None 18 | if(deepspeed_is_installed): 19 | import deepspeed 20 | 21 | import torch 22 | import torch.utils.checkpoint 23 | 24 | 25 | BLOCK_ARG = Any 26 | BLOCK_ARGS = List[BLOCK_ARG] 27 | 28 | 29 | def get_checkpoint_fn(): 30 | deepspeed_is_configured = ( 31 | deepspeed_is_installed and 32 | deepspeed.checkpointing.is_configured() 33 | ) 34 | if(deepspeed_is_configured): 35 | checkpoint = deepspeed.checkpointing.checkpoint 36 | else: 37 | checkpoint = torch.utils.checkpoint.checkpoint 38 | 39 | return checkpoint 40 | 41 | 42 | @torch.jit.ignore 43 | def checkpoint_blocks( 44 | blocks: List[Callable], 45 | args: BLOCK_ARGS, 46 | blocks_per_ckpt: Optional[int], 47 | ) -> BLOCK_ARGS: 48 | """ 49 | Chunk a list of blocks and run each chunk with activation 50 | checkpointing. We define a "block" as a callable whose only inputs are 51 | the outputs of the previous block. 52 | 53 | Implements Subsection 1.11.8 54 | 55 | Args: 56 | blocks: 57 | List of blocks 58 | args: 59 | Tuple of arguments for the first block. 60 | blocks_per_ckpt: 61 | Size of each chunk. A higher value corresponds to fewer 62 | checkpoints, and trades memory for speed. If None, no checkpointing 63 | is performed. 64 | Returns: 65 | The output of the final block 66 | """ 67 | def wrap(a): 68 | return (a,) if type(a) is not tuple else a 69 | 70 | def exec(b, a): 71 | for block in b: 72 | a = wrap(block(*a)) 73 | return a 74 | 75 | def chunker(s, e): 76 | def exec_sliced(*a): 77 | return exec(blocks[s:e], a) 78 | 79 | return exec_sliced 80 | 81 | # Avoids mishaps when the blocks take just one argument 82 | args = wrap(args) 83 | 84 | if blocks_per_ckpt is None or not torch.is_grad_enabled(): 85 | return exec(blocks, args) 86 | elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks): 87 | raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)") 88 | 89 | checkpoint = get_checkpoint_fn() 90 | 91 | for s in range(0, len(blocks), blocks_per_ckpt): 92 | e = s + blocks_per_ckpt 93 | args = checkpoint(chunker(s, e), *args) 94 | args = wrap(args) 95 | 96 | return args 97 | -------------------------------------------------------------------------------- /opencomplex/model/pair_transition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 BAAI 2 | # Copyright 2021 AlQuraishi Laboratory 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | from typing import Optional 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from opencomplex.model.primitives import Linear, LayerNorm 22 | from opencomplex.utils.chunk_utils import chunk_layer 23 | 24 | 25 | class PairTransition(nn.Module): 26 | """ 27 | Implements Algorithm 15. 
28 | """ 29 | 30 | def __init__(self, c_z, n): 31 | """ 32 | Args: 33 | c_z: 34 | Pair transition channel dimension 35 | n: 36 | Factor by which c_z is multiplied to obtain hidden channel 37 | dimension 38 | """ 39 | super(PairTransition, self).__init__() 40 | 41 | self.c_z = c_z 42 | self.n = n 43 | 44 | self.layer_norm = LayerNorm(self.c_z) 45 | self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu") 46 | self.relu = nn.ReLU() 47 | self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") 48 | 49 | def _transition(self, z, mask): 50 | # [*, N_res, N_res, C_z] 51 | z = self.layer_norm(z) 52 | 53 | # [*, N_res, N_res, C_hidden] 54 | z = self.linear_1(z) 55 | z = self.relu(z) 56 | 57 | # [*, N_res, N_res, C_z] 58 | z = self.linear_2(z) 59 | z = z * mask 60 | 61 | return z 62 | 63 | @torch.jit.ignore 64 | def _chunk(self, 65 | z: torch.Tensor, 66 | mask: torch.Tensor, 67 | chunk_size: int, 68 | ) -> torch.Tensor: 69 | return chunk_layer( 70 | self._transition, 71 | {"z": z, "mask": mask}, 72 | chunk_size=chunk_size, 73 | no_batch_dims=len(z.shape[:-2]), 74 | ) 75 | 76 | def forward(self, 77 | z: torch.Tensor, 78 | mask: Optional[torch.Tensor] = None, 79 | chunk_size: Optional[int] = None, 80 | ) -> torch.Tensor: 81 | """ 82 | Args: 83 | z: 84 | [*, N_res, N_res, C_z] pair embedding 85 | Returns: 86 | [*, N_res, N_res, C_z] pair embedding update 87 | """ 88 | # DISCREPANCY: DeepMind forgets to apply the mask in this module. 89 | if mask is None: 90 | mask = z.new_ones(z.shape[:-1]) 91 | 92 | # [*, N_res, N_res, 1] 93 | mask = mask.unsqueeze(-1) 94 | 95 | if chunk_size is not None: 96 | z = self._chunk(z, mask, chunk_size) 97 | else: 98 | z = self._transition(z=z, mask=mask) 99 | 100 | return z 101 | -------------------------------------------------------------------------------- /tests/test_embedders.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import torch 16 | import numpy as np 17 | import unittest 18 | from opencomplex.model.embedders import ( 19 | InputEmbedder, 20 | RecyclingEmbedder, 21 | TemplateAngleEmbedder, 22 | TemplatePairEmbedder, 23 | ) 24 | 25 | 26 | class TestInputEmbedder(unittest.TestCase): 27 | def test_shape(self): 28 | tf_dim = 2 29 | msa_dim = 3 30 | c_z = 5 31 | c_m = 7 32 | relpos_k = 11 33 | 34 | b = 13 35 | n_res = 17 36 | n_clust = 19 37 | 38 | tf = torch.rand((b, n_res, tf_dim)) 39 | ri = torch.rand((b, n_res)) 40 | msa = torch.rand((b, n_clust, n_res, msa_dim)) 41 | 42 | ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k) 43 | 44 | msa_emb, pair_emb = ie(tf, ri, msa) 45 | self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m)) 46 | self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z)) 47 | 48 | 49 | class TestRecyclingEmbedder(unittest.TestCase): 50 | def test_shape(self): 51 | batch_size = 2 52 | n = 3 53 | c_z = 5 54 | c_m = 7 55 | min_bin = 0 56 | max_bin = 10 57 | no_bins = 9 58 | 59 | re = RecyclingEmbedder(c_m, c_z, min_bin, max_bin, no_bins) 60 | 61 | m_1 = torch.rand((batch_size, n, c_m)) 62 | z = torch.rand((batch_size, n, n, c_z)) 63 | x = torch.rand((batch_size, n, 3)) 64 | 65 | m_1, z = re(m_1, z, x) 66 | 67 | self.assertTrue(z.shape == (batch_size, n, n, c_z)) 68 | self.assertTrue(m_1.shape == (batch_size, n, c_m)) 69 | 70 | 71 | class TestTemplateAngleEmbedder(unittest.TestCase): 72 | def test_shape(self): 73 | template_angle_dim = 51 74 | c_m = 256 75 | batch_size = 4 76 | n_templ = 4 77 | n_res = 256 78 | 79 | tae = TemplateAngleEmbedder( 80 | template_angle_dim, 81 | c_m, 82 | ) 83 | 84 | x = torch.rand((batch_size, n_templ, n_res, template_angle_dim)) 85 | x = tae(x) 86 | 87 | self.assertTrue(x.shape == (batch_size, n_templ, n_res, c_m)) 88 | 89 | 90 | class TestTemplatePairEmbedder(unittest.TestCase): 91 | def test_shape(self): 92 | batch_size = 2 93 | n_templ = 3 94 | n_res = 5 95 | template_pair_dim = 7 96 | c_t = 11 97 | 98 | tpe = TemplatePairEmbedder( 99 | template_pair_dim, 100 | c_t, 101 | ) 102 | 103 | x = torch.rand((batch_size, n_templ, n_res, n_res, template_pair_dim)) 104 | x = tpe(x) 105 | 106 | self.assertTrue(x.shape == (batch_size, n_templ, n_res, n_res, c_t)) 107 | 108 | 109 | if __name__ == "__main__": 110 | unittest.main() 111 | -------------------------------------------------------------------------------- /tests/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import numpy as np 16 | from scipy.spatial.transform import Rotation 17 | 18 | 19 | def random_template_feats(n_templ, n, batch_size=None): 20 | b = [] 21 | if batch_size is not None: 22 | b.append(batch_size) 23 | batch = { 24 | "template_mask": np.random.randint(0, 2, (*b, n_templ)), 25 | "template_pseudo_beta_mask": np.random.randint(0, 2, (*b, n_templ, n)), 26 | "template_pseudo_beta": np.random.rand(*b, n_templ, n, 3), 27 | "template_butype": np.random.randint(0, 22, (*b, n_templ, n)), 28 | "template_all_atom_mask": np.random.randint( 29 | 0, 2, (*b, n_templ, n, 37) 30 | ), 31 | "template_all_atom_positions": 32 | np.random.rand(*b, n_templ, n, 37, 3) * 10, 33 | "template_torsion_angles_sin_cos": 34 | np.random.rand(*b, n_templ, n, 7, 2), 35 | "template_alt_torsion_angles_sin_cos": 36 | np.random.rand(*b, n_templ, n, 7, 2), 37 | "template_torsion_angles_mask": 38 | np.random.rand(*b, n_templ, n, 7), 39 | } 40 | batch = {k: v.astype(np.float32) for k, v in batch.items()} 41 | batch["template_butype"] = batch["template_butype"].astype(np.int64) 42 | return batch 43 | 44 | 45 | def random_extra_msa_feats(n_extra, n, batch_size=None): 46 | b = [] 47 | if batch_size is not None: 48 | b.append(batch_size) 49 | batch = { 50 | "extra_msa": np.random.randint(0, 22, (*b, n_extra, n)).astype( 51 | np.int64 52 | ), 53 | "extra_has_deletion": np.random.randint(0, 2, (*b, n_extra, n)).astype( 54 | np.float32 55 | ), 56 | "extra_deletion_value": np.random.rand(*b, n_extra, n).astype( 57 | np.float32 58 | ), 59 | "extra_msa_mask": np.random.randint(0, 2, (*b, n_extra, n)).astype( 60 | np.float32 61 | ), 62 | } 63 | return batch 64 | 65 | 66 | def random_affines_vector(dim): 67 | prod_dim = 1 68 | for d in dim: 69 | prod_dim *= d 70 | 71 | affines = np.zeros((prod_dim, 7)).astype(np.float32) 72 | 73 | for i in range(prod_dim): 74 | affines[i, :4] = Rotation.random(random_state=42).as_quat() 75 | affines[i, 4:] = np.random.rand( 76 | 3, 77 | ).astype(np.float32) 78 | 79 | return affines.reshape(*dim, 7) 80 | 81 | 82 | def random_affines_4x4(dim): 83 | prod_dim = 1 84 | for d in dim: 85 | prod_dim *= d 86 | 87 | affines = np.zeros((prod_dim, 4, 4)).astype(np.float32) 88 | 89 | for i in range(prod_dim): 90 | affines[i, :3, :3] = Rotation.random(random_state=42).as_matrix() 91 | affines[i, :3, 3] = np.random.rand( 92 | 3, 93 | ).astype(np.float32) 94 | 95 | affines[:, 3, 3] = 1 96 | 97 | return affines.reshape(*dim, 4, 4) 98 | -------------------------------------------------------------------------------- /tests/test_pair_transition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
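Editor's note: the random feature generators in the preceding file are typically spliced together into a throwaway batch for quick shape tests, as in this sketch (key names follow the dicts constructed there):

```python
from tests.data_utils import random_template_feats, random_extra_msa_feats

batch = {
    **random_template_feats(n_templ=4, n=16, batch_size=2),
    **random_extra_msa_feats(n_extra=8, n=16, batch_size=2),
}
print(batch["template_all_atom_positions"].shape)  # (2, 4, 16, 37, 3)
print(batch["extra_msa"].shape)                    # (2, 8, 16)
```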
14 | 15 | import torch 16 | import numpy as np 17 | import unittest 18 | from opencomplex.model.pair_transition import PairTransition 19 | from opencomplex.utils.tensor_utils import tree_map 20 | import tests.compare_utils as compare_utils 21 | from tests.config import consts 22 | 23 | if compare_utils.alphafold_is_installed(): 24 | alphafold = compare_utils.import_alphafold() 25 | import jax 26 | import haiku as hk 27 | 28 | 29 | class TestPairTransition(unittest.TestCase): 30 | def test_shape(self): 31 | c_z = consts.c_z 32 | n = 4 33 | 34 | pt = PairTransition(c_z, n) 35 | 36 | batch_size = consts.batch_size 37 | n_res = consts.n_res 38 | 39 | z = torch.rand((batch_size, n_res, n_res, c_z)) 40 | mask = torch.randint(0, 2, size=(batch_size, n_res, n_res)) 41 | shape_before = z.shape 42 | z = pt(z, mask=mask, chunk_size=None) 43 | shape_after = z.shape 44 | 45 | self.assertTrue(shape_before == shape_after) 46 | 47 | @compare_utils.skip_unless_alphafold_installed() 48 | def test_compare(self): 49 | def run_pair_transition(pair_act, pair_mask): 50 | config = compare_utils.get_alphafold_config() 51 | c_e = config.model.embeddings_and_evoformer.evoformer 52 | pt = alphafold.model.modules.Transition( 53 | c_e.pair_transition, 54 | config.model.global_config, 55 | name="pair_transition", 56 | ) 57 | act = pt(act=pair_act, mask=pair_mask) 58 | return act 59 | 60 | f = hk.transform(run_pair_transition) 61 | 62 | n_res = consts.n_res 63 | 64 | pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32) 65 | pair_mask = np.ones((n_res, n_res)).astype(np.float32) # no mask 66 | 67 | # Fetch pretrained parameters (but only from one block) 68 | params = compare_utils.fetch_alphafold_module_weights( 69 | "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" 70 | + "pair_transition" 71 | ) 72 | params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) 73 | 74 | out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() 75 | out_gt = torch.as_tensor(np.array(out_gt)) 76 | 77 | model = compare_utils.get_global_pretrained_opencomplex() 78 | out_repro = ( 79 | model.evoformer.blocks[0].core 80 | .pair_transition( 81 | torch.as_tensor(pair_act, dtype=torch.float32).cuda(), 82 | chunk_size=4, 83 | mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), 84 | ) 85 | .cpu() 86 | ) 87 | 88 | self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) 89 | 90 | 91 | if __name__ == "__main__": 92 | unittest.main() 93 | -------------------------------------------------------------------------------- /opencomplex/np/relax/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
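Editor's note: a hedged sketch of the typical use of `overwrite_b_factors` (defined below) — stamping a per-residue confidence value such as pLDDT into the B-factor column of a predicted structure. The one-residue PDB fragment here is hypothetical; the 37 columns mirror `residue_constants.atom_type_num`.

```python
import numpy as np

from opencomplex.np.relax.utils import overwrite_b_factors

pdb_str = (
    "ATOM      1  N   ALA A   1      11.104   6.134  -6.504  1.00  0.00           N\n"
    "ATOM      2  CA  ALA A   1      11.639   6.071  -5.147  1.00  0.00           C\n"
    "TER\n"
    "END\n"
)

plddt = np.array([77.7])                          # one confidence value per residue
bfactors = np.repeat(plddt[:, None], 37, axis=1)  # [n_residues, 37], rows constant

print(overwrite_b_factors(pdb_str, bfactors))     # PDB text with B-factors set to 77.7
```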
15 | 16 | """Utils for minimization.""" 17 | import io 18 | from opencomplex.np import residue_constants 19 | from Bio import PDB 20 | import numpy as np 21 | try: 22 | # openmm >= 7.6 23 | from openmm import app as openmm_app 24 | from openmm.app.internal.pdbstructure import PdbStructure 25 | except ImportError: 26 | # openmm < 7.6 (requires DeepMind patch) 27 | from simtk.openmm import app as openmm_app 28 | from simtk.openmm.app.internal.pdbstructure import PdbStructure 29 | 30 | 31 | def overwrite_pdb_coordinates(pdb_str: str, pos) -> str: 32 | pdb_file = io.StringIO(pdb_str) 33 | structure = PdbStructure(pdb_file) 34 | topology = openmm_app.PDBFile(structure).getTopology() 35 | with io.StringIO() as f: 36 | openmm_app.PDBFile.writeFile(topology, pos, f) 37 | return f.getvalue() 38 | 39 | 40 | def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str: 41 | """Overwrites the B-factors in pdb_str with contents of bfactors array. 42 | 43 | Args: 44 | pdb_str: An input PDB string. 45 | bfactors: A numpy array with shape [n_residues, 37]. We assume that the 46 | B-factors are per residue; i.e. that the nonzero entries are identical in 47 | [i, :]. 48 | 49 | Returns: 50 | A new PDB string with the B-factors replaced. 51 | """ 52 | if bfactors.shape[-1] != residue_constants.atom_type_num: 53 | raise ValueError( 54 | f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}." 55 | ) 56 | 57 | parser = PDB.PDBParser(QUIET=True) 58 | handle = io.StringIO(pdb_str) 59 | structure = parser.get_structure("", handle) 60 | 61 | curr_resid = ("", "", "") 62 | idx = -1 63 | for atom in structure.get_atoms(): 64 | atom_resid = atom.parent.get_id() 65 | if atom_resid != curr_resid: 66 | idx += 1 67 | if idx >= bfactors.shape[0]: 68 | raise ValueError( 69 | "Index into bfactors exceeds number of residues. " 70 | f"B-factors shape: {bfactors.shape}, idx: {idx}." 71 | ) 72 | curr_resid = atom_resid 73 | atom.bfactor = bfactors[idx, residue_constants.atom_order["CA"]] 74 | 75 | new_pdb = io.StringIO() 76 | pdb_io = PDB.PDBIO() 77 | pdb_io.set_structure(structure) 78 | pdb_io.save(new_pdb) 79 | return new_pdb.getvalue() 80 | 81 | 82 | def assert_equal_nonterminal_atom_types( 83 | atom_mask: np.ndarray, ref_atom_mask: np.ndarray 84 | ): 85 | """Checks that pre- and post-minimized proteins have same atom set.""" 86 | # Ignore any terminal OXT atoms which may have been added by minimization. 87 | oxt = residue_constants.atom_order["OXT"] 88 | no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool) 89 | no_oxt_mask[..., oxt] = False 90 | np.testing.assert_almost_equal( 91 | ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask] 92 | ) 93 | -------------------------------------------------------------------------------- /tests/compare_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import importlib 4 | import pkgutil 5 | import sys 6 | import unittest 7 | 8 | import numpy as np 9 | 10 | from opencomplex.config.config import model_config 11 | from opencomplex.model.model import AlphaFold 12 | from opencomplex.utils.import_weights import import_jax_weights_ 13 | from tests.config import consts 14 | 15 | # Give JAX some GPU memory discipline 16 | # (by default it hogs 90% of GPU memory. 
This disables that behavior and also 17 | # forces it to proactively free memory that it allocates) 18 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" 19 | os.environ["JAX_PLATFORM_NAME"] = "gpu" 20 | 21 | 22 | def alphafold_is_installed(): 23 | return importlib.util.find_spec("alphafold") is not None 24 | 25 | 26 | def skip_unless_alphafold_installed(): 27 | return unittest.skipUnless(alphafold_is_installed(), "Requires AlphaFold") 28 | 29 | 30 | def import_alphafold(): 31 | """ 32 | If AlphaFold is installed using the provided setuptools script, this 33 | is necessary to expose all of AlphaFold's precious insides 34 | """ 35 | if "alphafold" in sys.modules: 36 | return sys.modules["alphafold"] 37 | module = importlib.import_module("alphafold") 38 | # Forcefully import alphafold's submodules 39 | submodules = pkgutil.walk_packages(module.__path__, prefix=("alphafold.")) 40 | for submodule_info in submodules: 41 | importlib.import_module(submodule_info.name) 42 | sys.modules["alphafold"] = module 43 | globals()["alphafold"] = module 44 | 45 | return module 46 | 47 | 48 | def get_alphafold_config(): 49 | config = alphafold.model.config.model_config("model_1_ptm") # noqa 50 | config.model.global_config.deterministic = True 51 | return config 52 | 53 | 54 | _param_path = "opencomplex/resources/params/params_model_1_ptm.npz" 55 | _model = None 56 | 57 | 58 | def get_global_pretrained_opencomplex(): 59 | global _model 60 | if _model is None: 61 | _model = AlphaFold(model_config("model_1_ptm")) 62 | _model = _model.eval() 63 | if not os.path.exists(_param_path): 64 | raise FileNotFoundError( 65 | """Cannot load pretrained parameters. Make sure to run the 66 | installation script before running tests.""" 67 | ) 68 | import_jax_weights_(_model, _param_path, version="model_1_ptm") 69 | _model = _model.cuda() 70 | 71 | return _model 72 | 73 | 74 | _orig_weights = None 75 | 76 | 77 | def _get_orig_weights(): 78 | global _orig_weights 79 | if _orig_weights is None: 80 | _orig_weights = np.load(_param_path) 81 | 82 | return _orig_weights 83 | 84 | 85 | def _remove_key_prefix(d, prefix): 86 | for k, v in list(d.items()): 87 | if k.startswith(prefix): 88 | d.pop(k) 89 | d[k[len(prefix) :]] = v 90 | 91 | 92 | def fetch_alphafold_module_weights(weight_path): 93 | orig_weights = _get_orig_weights() 94 | params = {k: v for k, v in orig_weights.items() if weight_path in k} 95 | if "/" in weight_path: 96 | spl = weight_path.split("/") 97 | spl = spl if len(spl[-1]) != 0 else spl[:-1] 98 | module_name = spl[-1] 99 | prefix = "/".join(spl[:-1]) + "/" 100 | _remove_key_prefix(params, prefix) 101 | 102 | try: 103 | params = alphafold.model.utils.flat_params_to_haiku(params) # noqa 104 | except: 105 | raise ImportError( 106 | "Make sure to call import_alphafold before running this function" 107 | ) 108 | return params 109 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ctypes 3 | from datetime import date 4 | 5 | 6 | def add_data_args(parser: argparse.ArgumentParser): 7 | parser.add_argument( 8 | '--uniref90_database_path', type=str, default=None, 9 | ) 10 | parser.add_argument( 11 | '--mgnify_database_path', type=str, default=None, 12 | ) 13 | parser.add_argument( 14 | '--pdb70_database_path', type=str, default=None, 15 | ) 16 | parser.add_argument( 17 | '--uniclust30_database_path', type=str, default=None, 18 | ) 19 | parser.add_argument( 20 
| '--bfd_database_path', type=str, default=None, 21 | ) 22 | parser.add_argument( 23 | '--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer' 24 | ) 25 | parser.add_argument( 26 | '--hhblits_binary_path', type=str, default='/usr/bin/hhblits' 27 | ) 28 | parser.add_argument( 29 | '--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch' 30 | ) 31 | parser.add_argument( 32 | '--kalign_binary_path', type=str, default='/usr/bin/kalign' 33 | ) 34 | parser.add_argument( 35 | '--max_template_date', type=str, 36 | default=date.today().strftime("%Y-%m-%d"), 37 | ) 38 | parser.add_argument( 39 | '--obsolete_pdbs_path', type=str, default=None 40 | ) 41 | parser.add_argument( 42 | '--release_dates_path', type=str, default=None 43 | ) 44 | 45 | 46 | def get_nvidia_cc(): 47 | """ 48 | Returns a tuple containing the Compute Capability of the first GPU 49 | installed in the system (formatted as a tuple of ints) and an error 50 | message. When the former is provided, the latter is None, and vice versa. 51 | 52 | Adapted from a script by Jan Schlüter at 53 | https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 54 | """ 55 | CUDA_SUCCESS = 0 56 | 57 | libnames = [ 58 | 'libcuda.so', 59 | 'libcuda.dylib', 60 | 'cuda.dll', 61 | '/usr/local/cuda/compat/libcuda.so', # For Docker 62 | ] 63 | for libname in libnames: 64 | try: 65 | cuda = ctypes.CDLL(libname) 66 | except OSError: 67 | continue 68 | else: 69 | break 70 | else: 71 | return None, "Could not load any of: " + ' '.join(libnames) 72 | 73 | nGpus = ctypes.c_int() 74 | cc_major = ctypes.c_int() 75 | cc_minor = ctypes.c_int() 76 | 77 | result = ctypes.c_int() 78 | device = ctypes.c_int() 79 | error_str = ctypes.c_char_p() 80 | 81 | result = cuda.cuInit(0) 82 | if result != CUDA_SUCCESS: 83 | cuda.cuGetErrorString(result, ctypes.byref(error_str)) 84 | if error_str.value: 85 | return None, error_str.value.decode() 86 | else: 87 | return None, "Unknown error: cuInit returned %d" % result 88 | result = cuda.cuDeviceGetCount(ctypes.byref(nGpus)) 89 | if result != CUDA_SUCCESS: 90 | cuda.cuGetErrorString(result, ctypes.byref(error_str)) 91 | return None, error_str.value.decode() 92 | 93 | if nGpus.value < 1: 94 | return None, "No GPUs detected" 95 | 96 | result = cuda.cuDeviceGet(ctypes.byref(device), 0) 97 | if result != CUDA_SUCCESS: 98 | cuda.cuGetErrorString(result, ctypes.byref(error_str)) 99 | return None, error_str.value.decode() 100 | 101 | if cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device) != CUDA_SUCCESS: 102 | return None, "Compute Capability not found" 103 | 104 | major = cc_major.value 105 | minor = cc_minor.value 106 | 107 | return (major, minor), None 108 | -------------------------------------------------------------------------------- /opencomplex/utils/superimposition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
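Editor's note: a usage sketch for `superimpose` (defined below), aligning a noisy copy of a structure onto its reference; it requires Biopython's `SVDSuperimposer`. Leading batch dimensions are flattened internally, so arbitrary batch shapes work.

```python
import torch

from opencomplex.utils.superimposition import superimpose

reference = torch.rand(2, 20, 3) * 10.0                 # [batch, N, 3]
moved = reference + 0.1 * torch.randn_like(reference)   # noisy copy
mask = torch.ones(2, 20)                                # all residues resolved

aligned, rmsds = superimpose(reference, moved, mask)
print(aligned.shape, rmsds.shape)  # torch.Size([2, 20, 3]) torch.Size([2])
```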
14 | from Bio.SVDSuperimposer import SVDSuperimposer 15 | import numpy as np 16 | import torch 17 | 18 | 19 | def _superimpose_np(reference, coords): 20 | """ 21 | Superimposes coordinates onto a reference by minimizing RMSD using SVD. 22 | 23 | Args: 24 | reference: 25 | [N, 3] reference array 26 | coords: 27 | [N, 3] array 28 | Returns: 29 | A tuple of [N, 3] superimposed coords and the final RMSD. 30 | """ 31 | sup = SVDSuperimposer() 32 | sup.set(reference, coords) 33 | sup.run() 34 | return sup.get_transformed(), sup.get_rms() 35 | 36 | 37 | def _superimpose_single(reference, coords): 38 | reference_np = reference.detach().cpu().numpy() 39 | coords_np = coords.detach().cpu().numpy() 40 | superimposed, rmsd = _superimpose_np(reference_np, coords_np) 41 | return coords.new_tensor(superimposed), coords.new_tensor(rmsd) 42 | 43 | 44 | def superimpose(reference, coords, mask): 45 | """ 46 | Superimposes coordinates onto a reference by minimizing RMSD using SVD. 47 | 48 | Args: 49 | reference: 50 | [*, N, 3] reference tensor 51 | coords: 52 | [*, N, 3] tensor 53 | mask: 54 | [*, N] tensor 55 | Returns: 56 | A tuple of [*, N, 3] superimposed coords and [*] final RMSDs. 57 | """ 58 | def select_unmasked_coords(coords, mask): 59 | return torch.masked_select( 60 | coords, 61 | (mask > 0.)[..., None], 62 | ).reshape(-1, 3) 63 | 64 | batch_dims = reference.shape[:-2] 65 | flat_reference = reference.reshape((-1,) + reference.shape[-2:]) 66 | flat_coords = coords.reshape((-1,) + reference.shape[-2:]) 67 | flat_mask = mask.reshape((-1,) + mask.shape[-1:]) 68 | superimposed_list = [] 69 | rmsds = [] 70 | for r, c, m in zip(flat_reference, flat_coords, flat_mask): 71 | r_unmasked_coords = select_unmasked_coords(r, m) 72 | c_unmasked_coords = select_unmasked_coords(c, m) 73 | superimposed, rmsd = _superimpose_single( 74 | r_unmasked_coords, 75 | c_unmasked_coords 76 | ) 77 | 78 | # This is very inelegant, but idk how else to invert the masking 79 | # procedure. 80 | count = 0 81 | superimposed_full_size = torch.zeros_like(r) 82 | for i, unmasked in enumerate(m): 83 | if(unmasked): 84 | superimposed_full_size[i] = superimposed[count] 85 | count += 1 86 | 87 | superimposed_list.append(superimposed_full_size) 88 | rmsds.append(rmsd) 89 | 90 | superimposed_stacked = torch.stack(superimposed_list, dim=0) 91 | rmsds_stacked = torch.stack(rmsds, dim=0) 92 | 93 | superimposed_reshaped = superimposed_stacked.reshape( 94 | batch_dims + coords.shape[-2:] 95 | ) 96 | rmsds_reshaped = rmsds_stacked.reshape( 97 | batch_dims 98 | ) 99 | 100 | return superimposed_reshaped, rmsds_reshaped 101 | -------------------------------------------------------------------------------- /opencomplex/utils/kernel/attention_core.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
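Editor's note: a hedged sketch of calling the fused `attention_core` kernel defined below. It assumes the `attn_core_inplace_cuda` extension has been built via `setup.py` and a CUDA device is available; the mask bias follows the `1e9 * (mask - 1)` convention used in the unit tests (0 for kept positions, -1e9 for masked ones).

```python
import torch

from opencomplex.utils.kernel.attention_core import attention_core

n_seq, h, n_res, c = 2, 4, 16, 8
q = torch.rand(n_seq, h, n_res, c).cuda()
k = torch.rand(n_seq, h, n_res, c).cuda()
v = torch.rand(n_seq, h, n_res, c).cuda()

mask = torch.randint(0, 2, (n_seq, n_res)).cuda()
mask_bias = (1e9 * (mask - 1)).float()[..., None, None, :]  # broadcast over heads/queries

out = attention_core(q, k, v, mask_bias, None)
print(out.shape)  # torch.Size([2, 4, 16, 8])
```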
14 | import importlib 15 | from functools import reduce 16 | from operator import mul 17 | 18 | import torch 19 | 20 | attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda") 21 | 22 | 23 | SUPPORTED_DTYPES = [torch.float32, torch.bfloat16] 24 | 25 | 26 | class AttentionCoreFunction(torch.autograd.Function): 27 | @staticmethod 28 | def forward(ctx, q, k, v, bias_1=None, bias_2=None): 29 | if(bias_1 is None and bias_2 is not None): 30 | raise ValueError("bias_1 must be specified before bias_2") 31 | if(q.dtype not in SUPPORTED_DTYPES): 32 | raise ValueError("Unsupported datatype") 33 | 34 | q = q.contiguous() 35 | k = k.contiguous() 36 | 37 | # [*, H, Q, K] 38 | attention_logits = torch.matmul( 39 | q, k.transpose(-1, -2), 40 | ) 41 | 42 | if(bias_1 is not None): 43 | attention_logits += bias_1 44 | if(bias_2 is not None): 45 | attention_logits += bias_2 46 | 47 | attn_core_inplace_cuda.forward_( 48 | attention_logits, 49 | reduce(mul, attention_logits.shape[:-1]), 50 | attention_logits.shape[-1], 51 | ) 52 | 53 | o = torch.matmul(attention_logits, v) 54 | 55 | ctx.bias_1_shape = bias_1.shape if bias_1 is not None else None 56 | ctx.bias_2_shape = bias_2.shape if bias_2 is not None else None 57 | ctx.save_for_backward(q, k, v, attention_logits) 58 | 59 | return o 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | q, k, v, attention_logits = ctx.saved_tensors 64 | grad_q = grad_k = grad_v = grad_bias_1 = grad_bias_2 = None 65 | 66 | grad_v = torch.matmul( 67 | attention_logits.transpose(-1, -2), 68 | grad_output 69 | ) 70 | 71 | attn_core_inplace_cuda.backward_( 72 | attention_logits, 73 | grad_output.contiguous(), 74 | v.contiguous(), # v is implicitly transposed in the kernel 75 | reduce(mul, attention_logits.shape[:-1]), 76 | attention_logits.shape[-1], 77 | grad_output.shape[-1], 78 | ) 79 | 80 | if(ctx.bias_1_shape is not None): 81 | grad_bias_1 = torch.sum( 82 | attention_logits, 83 | dim=tuple(i for i,d in enumerate(ctx.bias_1_shape) if d == 1), 84 | keepdim=True, 85 | ) 86 | 87 | if(ctx.bias_2_shape is not None): 88 | grad_bias_2 = torch.sum( 89 | attention_logits, 90 | dim=tuple(i for i,d in enumerate(ctx.bias_2_shape) if d == 1), 91 | keepdim=True, 92 | ) 93 | 94 | grad_q = torch.matmul( 95 | attention_logits, k 96 | ) 97 | grad_k = torch.matmul( 98 | q.transpose(-1, -2), attention_logits, 99 | ).transpose(-1, -2) 100 | 101 | return grad_q, grad_k, grad_v, grad_bias_1, grad_bias_2 102 | 103 | attention_core = AttentionCoreFunction.apply 104 | -------------------------------------------------------------------------------- /tests/test_outer_product_mean.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
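Editor's note: for reference, what `AttentionCoreFunction` in the preceding file computes can be written in a few lines of plain PyTorch — a sketch for readers without the compiled kernel (the kernel additionally performs the softmax in place, avoiding a second attention-sized tensor):

```python
import torch

def attention_reference(q, k, v, bias_1=None, bias_2=None):
    logits = torch.matmul(q, k.transpose(-1, -2))  # [*, H, Q, K]
    if bias_1 is not None:
        logits = logits + bias_1
    if bias_2 is not None:
        logits = logits + bias_2
    weights = torch.softmax(logits, dim=-1)        # done in place by the CUDA kernel
    return torch.matmul(weights, v)                # [*, H, Q, C_hidden]
```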
14 | 15 | import torch 16 | import numpy as np 17 | import unittest 18 | from opencomplex.model.outer_product_mean import OuterProductMean 19 | from opencomplex.utils.tensor_utils import tree_map 20 | import tests.compare_utils as compare_utils 21 | from tests.config import consts 22 | 23 | if compare_utils.alphafold_is_installed(): 24 | alphafold = compare_utils.import_alphafold() 25 | import jax 26 | import haiku as hk 27 | 28 | 29 | class TestOuterProductMean(unittest.TestCase): 30 | def test_shape(self): 31 | c = 31 32 | 33 | opm = OuterProductMean(consts.c_m, consts.c_z, c) 34 | 35 | m = torch.rand( 36 | (consts.batch_size, consts.n_seq, consts.n_res, consts.c_m) 37 | ) 38 | mask = torch.randint( 39 | 0, 2, size=(consts.batch_size, consts.n_seq, consts.n_res) 40 | ) 41 | m = opm(m, mask=mask, chunk_size=None) 42 | 43 | self.assertTrue( 44 | m.shape == 45 | (consts.batch_size, consts.n_res, consts.n_res, consts.c_z) 46 | ) 47 | 48 | @compare_utils.skip_unless_alphafold_installed() 49 | def test_opm_compare(self): 50 | def run_opm(msa_act, msa_mask): 51 | config = compare_utils.get_alphafold_config() 52 | c_evo = config.model.embeddings_and_evoformer.evoformer 53 | opm = alphafold.model.modules.OuterProductMean( 54 | c_evo.outer_product_mean, 55 | config.model.global_config, 56 | consts.c_z, 57 | ) 58 | act = opm(act=msa_act, mask=msa_mask) 59 | return act 60 | 61 | f = hk.transform(run_opm) 62 | 63 | n_res = consts.n_res 64 | n_seq = consts.n_seq 65 | c_m = consts.c_m 66 | 67 | msa_act = np.random.rand(n_seq, n_res, c_m).astype(np.float32) * 100 68 | msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res)).astype( 69 | np.float32 70 | ) 71 | 72 | # Fetch pretrained parameters (but only from one block)] 73 | params = compare_utils.fetch_alphafold_module_weights( 74 | "alphafold/alphafold_iteration/evoformer/" 75 | + "evoformer_iteration/outer_product_mean" 76 | ) 77 | params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) 78 | 79 | out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() 80 | out_gt = torch.as_tensor(np.array(out_gt)) 81 | 82 | model = compare_utils.get_global_pretrained_opencomplex() 83 | out_repro = ( 84 | model.evoformer.blocks[0].core 85 | .outer_product_mean( 86 | torch.as_tensor(msa_act).cuda(), 87 | chunk_size=4, 88 | mask=torch.as_tensor(msa_mask).cuda(), 89 | ) 90 | .cpu() 91 | ) 92 | 93 | # Even when correct, OPM has large, precision-related errors. It gets 94 | # a special pass from consts.eps. 95 | self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4) 96 | 97 | 98 | if __name__ == "__main__": 99 | unittest.main() 100 | -------------------------------------------------------------------------------- /opencomplex/data/feature_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 BAAI 2 | # Copyright 2021 AlQuraishi Laboratory 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import copy 18 | from typing import Mapping, Tuple, List, Optional, Dict, Sequence 19 | 20 | import ml_collections 21 | import numpy as np 22 | import torch 23 | 24 | from opencomplex.data import input_pipeline 25 | 26 | 27 | FeatureDict = Mapping[str, np.ndarray] 28 | TensorDict = Dict[str, torch.Tensor] 29 | 30 | 31 | def np_to_tensor_dict( 32 | np_example: Mapping[str, np.ndarray], 33 | features: Sequence[str], 34 | ) -> TensorDict: 35 | """Creates dict of tensors from a dict of NumPy arrays. 36 | 37 | Args: 38 | np_example: A dict of NumPy feature arrays. 39 | features: A list of strings of feature names to be returned in the dataset. 40 | 41 | Returns: 42 | A dictionary of features mapping feature names to features. Only the given 43 | features are returned, all other ones are filtered out. 44 | """ 45 | tensor_dict = { 46 | k: torch.tensor(v) if isinstance(v, np.ndarray) else v 47 | for k, v in np_example.items() if k in features 48 | } 49 | 50 | return tensor_dict 51 | 52 | 53 | def make_data_config( 54 | config: ml_collections.ConfigDict, 55 | mode: str, 56 | num_res: int, 57 | ) -> Tuple[ml_collections.ConfigDict, List[str]]: 58 | cfg = copy.deepcopy(config) 59 | mode_cfg = cfg[mode] 60 | with cfg.unlocked(): 61 | if mode_cfg.crop_size is None: 62 | mode_cfg.crop_size = num_res 63 | 64 | feature_names = cfg.common.unsupervised_features 65 | 66 | if cfg.common.use_templates: 67 | feature_names += cfg.common.template_features 68 | 69 | if cfg[mode].supervised: 70 | feature_names += cfg.supervised.supervised_features 71 | 72 | return cfg, feature_names 73 | 74 | 75 | def np_example_to_features( 76 | np_example: FeatureDict, 77 | config: ml_collections.ConfigDict, 78 | mode: str, 79 | ): 80 | np_example = dict(np_example) 81 | num_res = int(np_example["seq_length"][0]) 82 | cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) 83 | 84 | if "deletion_matrix_int" in np_example: 85 | np_example["deletion_matrix"] = np_example.pop( 86 | "deletion_matrix_int" 87 | ).astype(np.float32) 88 | 89 | tensor_dict = np_to_tensor_dict( 90 | np_example=np_example, features=feature_names 91 | ) 92 | with torch.no_grad(): 93 | features = input_pipeline.process_tensors_from_config( 94 | tensor_dict, 95 | cfg.common, 96 | cfg[mode], 97 | ) 98 | 99 | return {k: v for k, v in features.items()} 100 | 101 | 102 | class FeaturePipeline: 103 | def __init__( 104 | self, 105 | config: ml_collections.ConfigDict, 106 | ): 107 | self.config = config 108 | 109 | def process_features( 110 | self, 111 | raw_features: FeatureDict, 112 | mode: str = "train", 113 | ) -> FeatureDict: 114 | return np_example_to_features( 115 | np_example=raw_features, 116 | config=self.config, 117 | mode=mode, 118 | ) 119 | -------------------------------------------------------------------------------- /opencomplex/np/relax/relax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Amber relaxation.""" 17 | from typing import Any, Dict, Sequence, Tuple 18 | from opencomplex.np import protein 19 | from opencomplex.np.relax import amber_minimize, utils 20 | import numpy as np 21 | 22 | 23 | class AmberRelaxation(object): 24 | """Amber relaxation.""" 25 | def __init__( 26 | self, 27 | *, 28 | max_iterations: int, 29 | tolerance: float, 30 | stiffness: float, 31 | exclude_residues: Sequence[int], 32 | max_outer_iterations: int, 33 | use_gpu: bool, 34 | ): 35 | """Initialize Amber Relaxer. 36 | 37 | Args: 38 | max_iterations: Maximum number of L-BFGS iterations. 0 means no max. 39 | tolerance: kcal/mol, the energy tolerance of L-BFGS. 40 | stiffness: kcal/mol A**2, spring constant of heavy atom restraining 41 | potential. 42 | exclude_residues: Residues to exclude from per-atom restraining. 43 | Zero-indexed. 44 | max_outer_iterations: Maximum number of violation-informed relax 45 | iterations. A value of 1 will run the non-iterative procedure used in 46 | CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes 47 | as soon as there are no violations, hence in most cases this causes no 48 | slowdown. In the worst case we do 20 outer iterations. 49 | use_gpu: Whether to run on GPU 50 | """ 51 | 52 | self._max_iterations = max_iterations 53 | self._tolerance = tolerance 54 | self._stiffness = stiffness 55 | self._exclude_residues = exclude_residues 56 | self._max_outer_iterations = max_outer_iterations 57 | self._use_gpu = use_gpu 58 | 59 | def process( 60 | self, *, prot: protein.Protein 61 | ) -> Tuple[str, Dict[str, Any], np.ndarray]: 62 | """Runs Amber relax on a prediction, adds hydrogens, returns PDB string.""" 63 | out = amber_minimize.run_pipeline( 64 | prot=prot, 65 | max_iterations=self._max_iterations, 66 | tolerance=self._tolerance, 67 | stiffness=self._stiffness, 68 | exclude_residues=self._exclude_residues, 69 | max_outer_iterations=self._max_outer_iterations, 70 | use_gpu=self._use_gpu, 71 | ) 72 | min_pos = out["pos"] 73 | start_pos = out["posinit"] 74 | rmsd = np.sqrt(np.sum((start_pos - min_pos) ** 2) / start_pos.shape[0]) 75 | debug_data = { 76 | "initial_energy": out["einit"], 77 | "final_energy": out["efinal"], 78 | "attempts": out["min_attempts"], 79 | "rmsd": rmsd, 80 | } 81 | pdb_str = amber_minimize.clean_protein(prot) 82 | min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos) 83 | min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors) 84 | utils.assert_equal_nonterminal_atom_types( 85 | protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask 86 | ) 87 | violations = out["structural_violations"][ 88 | "total_per_residue_violations_mask" 89 | ] 90 | 91 | min_pdb = protein.add_pdb_headers(prot, min_pdb) 92 | 93 | return min_pdb, debug_data, violations 94 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 BAAI 2 | # Copyright 2021 AlQuraishi Laboratory 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed 
under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import os 17 | from setuptools import setup, Extension, find_packages 18 | import subprocess 19 | 20 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME 21 | 22 | from scripts.utils import get_nvidia_cc 23 | 24 | 25 | version_dependent_macros = [ 26 | '-DVERSION_GE_1_1', 27 | '-DVERSION_GE_1_3', 28 | '-DVERSION_GE_1_5', 29 | ] 30 | 31 | extra_cuda_flags = [ 32 | '-std=c++14', 33 | '-maxrregcount=50', 34 | '-U__CUDA_NO_HALF_OPERATORS__', 35 | '-U__CUDA_NO_HALF_CONVERSIONS__', 36 | '--expt-relaxed-constexpr', 37 | '--expt-extended-lambda' 38 | ] 39 | 40 | def get_cuda_bare_metal_version(cuda_dir): 41 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) 42 | output = raw_output.split() 43 | release_idx = output.index("release") + 1 44 | release = output[release_idx].split(".") 45 | bare_metal_major = release[0] 46 | bare_metal_minor = release[1][0] 47 | 48 | return raw_output, bare_metal_major, bare_metal_minor 49 | 50 | compute_capabilities = set([ 51 | (3, 7), # K80, e.g. 52 | (5, 2), # Titan X 53 | (6, 1), # GeForce 1000-series 54 | ]) 55 | 56 | compute_capabilities.add((7, 0)) 57 | _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) 58 | if int(bare_metal_major) >= 11: 59 | compute_capabilities.add((8, 0)) 60 | 61 | compute_capability, _ = get_nvidia_cc() 62 | if compute_capability is not None: 63 | compute_capabilities = set([compute_capability]) 64 | 65 | cc_flag = [] 66 | for major, minor in list(compute_capabilities): 67 | cc_flag.extend([ 68 | '-gencode', 69 | f'arch=compute_{major}{minor},code=sm_{major}{minor}', 70 | ]) 71 | 72 | extra_cuda_flags += cc_flag 73 | 74 | 75 | setup( 76 | name='opencomplex', 77 | version='1.0.0', 78 | description='A platform for protein and RNA complex structure prediction.', 79 | author='BAAI', 80 | author_email='jingchengyu.94@gmail.com', 81 | license='Apache License, Version 2.0', 82 | packages=find_packages(exclude=["tests", "scripts"]), 83 | include_package_data=True, 84 | package_data={ 85 | "opencomplex": ['utils/kernel/csrc/*'], 86 | "": ["resources/stereo_chemical_props.txt"] 87 | }, 88 | ext_modules=[CUDAExtension( 89 | name="attn_core_inplace_cuda", 90 | sources=[ 91 | "opencomplex/utils/kernel/csrc/softmax_cuda.cpp", 92 | "opencomplex/utils/kernel/csrc/softmax_cuda_kernel.cu", 93 | ], 94 | include_dirs=[ 95 | os.path.join( 96 | os.path.dirname(os.path.abspath(__file__)), 97 | 'opencomplex/utils/kernel/csrc/' 98 | ) 99 | ], 100 | extra_compile_args={ 101 | 'cxx': ['-O3'] + version_dependent_macros, 102 | 'nvcc': ( 103 | ['-O3', '--use_fast_math'] + 104 | version_dependent_macros + 105 | extra_cuda_flags 106 | ), 107 | } 108 | )], 109 | cmdclass={'build_ext': BuildExtension}, 110 | classifiers=[ 111 | 'License :: OSI Approved :: Apache Software License', 112 | 'Operating System :: POSIX :: Linux', 113 | 'Programming Language :: Python :: 3.9', 114 | 'Topic :: Scientific/Engineering :: 
Artificial Intelligence', 115 | ], 116 | ) 117 | -------------------------------------------------------------------------------- /opencomplex/data/tools/hhsearch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Library to run HHsearch from Python.""" 17 | import glob 18 | import logging 19 | import os 20 | import subprocess 21 | from typing import Sequence 22 | 23 | from opencomplex.data.tools import utils 24 | 25 | 26 | class HHSearch: 27 | """Python wrapper of the HHsearch binary.""" 28 | 29 | def __init__( 30 | self, 31 | *, 32 | binary_path: str, 33 | databases: Sequence[str], 34 | n_cpu: int = 2, 35 | maxseq: int = 1_000_000, 36 | ): 37 | """Initializes the Python HHsearch wrapper. 38 | 39 | Args: 40 | binary_path: The path to the HHsearch executable. 41 | databases: A sequence of HHsearch database paths. This should be the 42 | common prefix for the database files (i.e. up to but not including 43 | _hhm.ffindex etc.) 44 | n_cpu: The number of CPUs to use 45 | maxseq: The maximum number of rows in an input alignment. Note that this 46 | parameter is only supported in HHBlits version 3.1 and higher. 47 | 48 | Raises: 49 | RuntimeError: If HHsearch binary not found within the path. 50 | """ 51 | self.binary_path = binary_path 52 | self.databases = databases 53 | self.n_cpu = n_cpu 54 | self.maxseq = maxseq 55 | 56 | for database_path in self.databases: 57 | if not glob.glob(database_path + "_*"): 58 | logging.error( 59 | "Could not find HHsearch database %s", database_path 60 | ) 61 | raise ValueError( 62 | f"Could not find HHsearch database {database_path}" 63 | ) 64 | 65 | def query(self, a3m: str) -> str: 66 | """Queries the database using HHsearch using a given a3m.""" 67 | with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir: 68 | input_path = os.path.join(query_tmp_dir, "query.a3m") 69 | hhr_path = os.path.join(query_tmp_dir, "output.hhr") 70 | with open(input_path, "w") as f: 71 | f.write(a3m) 72 | 73 | db_cmd = [] 74 | for db_path in self.databases: 75 | db_cmd.append("-d") 76 | db_cmd.append(db_path) 77 | cmd = [ 78 | self.binary_path, 79 | "-i", 80 | input_path, 81 | "-o", 82 | hhr_path, 83 | "-maxseq", 84 | str(self.maxseq), 85 | "-cpu", 86 | str(self.n_cpu), 87 | ] + db_cmd 88 | 89 | logging.info('Launching subprocess "%s"', " ".join(cmd)) 90 | process = subprocess.Popen( 91 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE 92 | ) 93 | with utils.timing("HHsearch query"): 94 | stdout, stderr = process.communicate() 95 | retcode = process.wait() 96 | 97 | if retcode: 98 | # Stderr is truncated to prevent proto size errors in Beam. 
99 | raise RuntimeError( 100 | "HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n" 101 | % (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8")) 102 | ) 103 | 104 | with open(hhr_path) as f: 105 | hhr = f.read() 106 | return hhr 107 | -------------------------------------------------------------------------------- /tests/test_triangular_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import numpy as np 17 | import unittest 18 | from opencomplex.model.triangular_attention import TriangleAttention 19 | from opencomplex.utils.tensor_utils import tree_map 20 | 21 | import tests.compare_utils as compare_utils 22 | from tests.config import consts 23 | 24 | if compare_utils.alphafold_is_installed(): 25 | alphafold = compare_utils.import_alphafold() 26 | import jax 27 | import haiku as hk 28 | 29 | 30 | class TestTriangularAttention(unittest.TestCase): 31 | def test_shape(self): 32 | c_z = consts.c_z 33 | c = 12 34 | no_heads = 4 35 | starting = True 36 | 37 | tan = TriangleAttention(c_z, c, no_heads, starting) 38 | 39 | batch_size = consts.batch_size 40 | n_res = consts.n_res 41 | 42 | x = torch.rand((batch_size, n_res, n_res, c_z)) 43 | shape_before = x.shape 44 | x = tan(x, chunk_size=None) 45 | shape_after = x.shape 46 | 47 | self.assertTrue(shape_before == shape_after) 48 | 49 | def _tri_att_compare(self, starting=False): 50 | name = ( 51 | "triangle_attention_" 52 | + ("starting" if starting else "ending") 53 | + "_node" 54 | ) 55 | 56 | def run_tri_att(pair_act, pair_mask): 57 | config = compare_utils.get_alphafold_config() 58 | c_e = config.model.embeddings_and_evoformer.evoformer 59 | tri_att = alphafold.model.modules.TriangleAttention( 60 | c_e.triangle_attention_starting_node 61 | if starting 62 | else c_e.triangle_attention_ending_node, 63 | config.model.global_config, 64 | name=name, 65 | ) 66 | act = tri_att(pair_act=pair_act, pair_mask=pair_mask) 67 | return act 68 | 69 | f = hk.transform(run_tri_att) 70 | 71 | n_res = consts.n_res 72 | 73 | pair_act = np.random.rand(n_res, n_res, consts.c_z) * 100 74 | pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res)) 75 | 76 | # Fetch pretrained parameters (but only from one block)] 77 | params = compare_utils.fetch_alphafold_module_weights( 78 | "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" 79 | + name 80 | ) 81 | params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) 82 | 83 | out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() 84 | out_gt = torch.as_tensor(np.array(out_gt)) 85 | 86 | model = compare_utils.get_global_pretrained_opencomplex() 87 | module = ( 88 | model.evoformer.blocks[0].core.tri_att_start 89 | if starting 90 | else model.evoformer.blocks[0].core.tri_att_end 91 | ) 92 | out_repro = module( 93 | torch.as_tensor(pair_act, dtype=torch.float32).cuda(), 94 | 
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), 95 | chunk_size=None, 96 | ).cpu() 97 | 98 | self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) 99 | 100 | @compare_utils.skip_unless_alphafold_installed() 101 | def test_tri_att_end_compare(self): 102 | self._tri_att_compare() 103 | 104 | @compare_utils.skip_unless_alphafold_installed() 105 | def test_tri_att_start_compare(self): 106 | self._tri_att_compare(starting=True) 107 | 108 | 109 | if __name__ == "__main__": 110 | unittest.main() 111 | -------------------------------------------------------------------------------- /tests/test_data_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import pickle 16 | import shutil 17 | 18 | import torch 19 | import numpy as np 20 | import unittest 21 | 22 | from opencomplex.data.data_pipeline import DataPipeline 23 | from opencomplex.data.templates import TemplateHitFeaturizer 24 | from opencomplex.model.embedders import ( 25 | InputEmbedder, 26 | RecyclingEmbedder, 27 | TemplateAngleEmbedder, 28 | TemplatePairEmbedder, 29 | ) 30 | import tests.compare_utils as compare_utils 31 | 32 | if compare_utils.alphafold_is_installed(): 33 | alphafold = compare_utils.import_alphafold() 34 | import jax 35 | import haiku as hk 36 | 37 | 38 | class TestDataPipeline(unittest.TestCase): 39 | @compare_utils.skip_unless_alphafold_installed() 40 | def test_fasta_compare(self): 41 | # AlphaFold runs the alignments and feature processing at the same 42 | # time, taking forever. As such, we precompute AlphaFold's features 43 | # using scripts/generate_alphafold_feature_dict.py and the default 44 | # databases. 45 | with open("tests/test_data/alphafold_feature_dict.pickle", "rb") as fp: 46 | alphafold_feature_dict = pickle.load(fp) 47 | 48 | template_featurizer = TemplateHitFeaturizer( 49 | mmcif_dir="tests/test_data/mmcifs", 50 | max_template_date="2021-12-20", 51 | max_hits=20, 52 | kalign_binary_path=shutil.which("kalign"), 53 | _zero_center_positions=False, 54 | ) 55 | 56 | data_pipeline = DataPipeline( 57 | template_featurizer=template_featurizer, 58 | ) 59 | 60 | opencomplex_feature_dict = data_pipeline.process_fasta( 61 | "tests/test_data/short.fasta", 62 | "tests/test_data/alignments" 63 | ) 64 | 65 | opencomplex_feature_dict["template_all_atom_masks"] = opencomplex_feature_dict["template_all_atom_mask"] 66 | 67 | checked = [] 68 | 69 | # AlphaFold and opencomplex process their MSAs in slightly different 70 | # orders, which we compensate for below. 71 | m_a = alphafold_feature_dict["msa"] 72 | m_o = opencomplex_feature_dict["msa"] 73 | 74 | # The first row of both MSAs should be the same, no matter what 75 | self.assertTrue(np.all(m_a[0, :] == m_o[0, :])) 76 | 77 | # Each row of each MSA should appear exactly once somewhere in its 78 | # counterpart 79 | matching_rows = np.all((m_a[:, None, ...] 
== m_o[None, :, ...]), axis=-1)
80 |         self.assertTrue(
81 |             np.all(
82 |                 np.sum(matching_rows, axis=-1) == 1
83 |             )
84 |         )
85 | 
86 |         checked.append("msa")
87 | 
88 |         # The corresponding rows of the deletion matrix should also be equal
89 |         matching_idx = np.argmax(matching_rows, axis=-1)
90 |         rearranged_o_dmi = opencomplex_feature_dict["deletion_matrix_int"]
91 |         rearranged_o_dmi = rearranged_o_dmi[matching_idx, :]
92 |         self.assertTrue(
93 |             np.all(
94 |                 alphafold_feature_dict["deletion_matrix_int"] ==
95 |                 rearranged_o_dmi
96 |             )
97 |         )
98 | 
99 |         checked.append("deletion_matrix_int")
100 | 
101 |         # Remaining features have to be precisely equal
102 |         for k, v in alphafold_feature_dict.items():
103 |             self.assertTrue(
104 |                 k in checked or np.all(v == opencomplex_feature_dict[k])
105 |             )
106 | 
107 | 
108 | 
109 | if __name__ == "__main__":
110 |     unittest.main()
111 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![header](img/logo.png)
2 | # OpenComplex
3 | OpenComplex is an open-source platform for developing protein and RNA complex models.
4 | Based on DeepMind's [AlphaFold 2](https://github.com/deepmind/alphafold) and AQ Laboratory's [OpenFold](https://github.com/aqlaboratory/openfold), OpenComplex supports almost all features of AlphaFold 2 and OpenFold and introduces the following new features:
5 | * Reimplemented AlphaFold-Multimer models.
6 | * RNA and protein-RNA complex models with high precision.
7 | * Kernel fusion and optimization on Ampere (and newer) GPUs, bringing a 16% speed-up.
8 | 
9 | ![Figure 1. OpenComplex inference result of RNA and protein-RNA complex.](img/cases.png)
10 | 
11 | We will release training results and pretrained parameters soon.
12 | 
13 | ## Installation (Linux)
14 | 
15 | All Python dependencies are specified in `environment.yml`. For producing sequence
16 | alignments, you'll also need `kalign`, the [HH-suite](https://github.com/soedinglab/hh-suite),
17 | and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (nightly build)}
18 | installed on your system.
19 | Finally, some download scripts require `aria2c` and `aws`.
20 | 
21 | For convenience, we provide a script that installs Miniconda locally, creates a
22 | `conda` virtual environment, installs all Python dependencies, and downloads
23 | useful resources, including both sets of model parameters. Run:
24 | 
25 | ```bash
26 | scripts/install_third_party_dependencies.sh
27 | ```
28 | 
29 | To activate the environment, run:
30 | 
31 | ```bash
32 | source scripts/activate_conda_env.sh
33 | ```
34 | 
35 | With the environment active, compile CUDA kernels with
36 | 
37 | ```bash
38 | python3 setup.py install
39 | ```
40 | 
41 | To install the HH-suite to `/usr/bin`, run
42 | 
43 | ```bash
44 | scripts/install_hh_suite.sh
45 | ```
46 | 
47 | ## Usage
48 | 
49 | ### Data preparation
50 | 
51 | To run the feature generation pipeline from `.fasta` to `features.pkl` using DeepMind's MSA and template databases, run e.g.:
52 | ```bash
53 | python ./scripts/extract_pkl_from_fas.py ./example_data/fasta/ ./example_data/features/
54 | ```
55 | where `example_data` is the directory containing the example FASTA files. If `jackhmmer`,
56 | `hhblits`, `hhsearch` and `kalign` are available at the default path of
57 | `/usr/bin`, their `binary_path` command-line arguments can be dropped.
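For example, an invocation with non-default binary locations might look like the following (illustrative only — the exact flag names are defined in `scripts/extract_pkl_from_fas.py` and may differ in your checkout):

```bash
python ./scripts/extract_pkl_from_fas.py \
    ./example_data/fasta/ ./example_data/features/ \
    --jackhmmer_binary_path /opt/tools/bin/jackhmmer \
    --hhblits_binary_path /opt/tools/bin/hhblits \
    --hhsearch_binary_path /opt/tools/bin/hhsearch \
    --kalign_binary_path /opt/tools/bin/kalign
```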
58 | If you've already computed alignments for the query, you have the option to
59 | skip the expensive alignment computation here with
60 | `--use_precomputed_alignments`.
61 | 
62 | ### Training and Inference
63 | 
64 | See the example bash scripts in `example_data/scripts`.
65 | 
66 | ## Testing
67 | 
68 | To run unit tests, use
69 | 
70 | ```bash
71 | scripts/run_unit_tests.sh
72 | ```
73 | 
74 | The script is a thin wrapper around Python's `unittest` suite, and recognizes
75 | `unittest` arguments. E.g., to run a specific test verbosely:
76 | 
77 | ```bash
78 | scripts/run_unit_tests.sh -v tests.test_model
79 | ```
80 | 
81 | Certain tests require that AlphaFold (v2.0.1) be installed in the same Python
82 | environment. These run components of AlphaFold and OpenComplex side by side and
83 | ensure that output activations are adequately similar. For most modules, we
84 | target a maximum pointwise difference of `1e-4`.
85 | 
86 | ## Citation
87 | 
88 | If you find our open-sourced code & models helpful to your research, please consider starring🌟 and citing📑 this repo. Thank you for your support!
89 | ```
90 | @misc{OpenComplex_code,
91 |     author={Jingcheng, Yu and Zhaoming, Chen and Zhaoqun, Li and Mingliang, Zeng and Wenjun, Lin and He, Huang and Qiwei, Ye},
92 |     title={Code of OpenComplex},
93 |     year={2022},
94 |     howpublished = {\url{https://github.com/baaihealth/OpenComplex}}
95 | }
96 | ```
97 | It is recommended to also cite OpenFold and AlphaFold.
98 | 
99 | 
100 | ## License and Disclaimer
101 | 
102 | Copyright 2022 BAAI.
103 | 
104 | Extended from AlphaFold and OpenFold, OpenComplex is licensed under
105 | the permissive Apache License, Version 2.0.
106 | 
107 | ## Contributing
108 | 
109 | If you encounter problems using OpenComplex, feel free to create an issue! We also
110 | welcome pull requests from the community.
111 | 
112 | ## Contact Information
113 | For help or issues using the repo, please submit a GitHub issue.
114 | 
115 | For other communications, please contact Qiwei Ye (qwye@baai.ac.cn).
116 | 
--------------------------------------------------------------------------------
/opencomplex/data/tools/kalign.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | # Copyright 2021 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #      http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 16 | """A Python wrapper for Kalign.""" 17 | import os 18 | import subprocess 19 | from typing import Sequence 20 | 21 | from absl import logging 22 | 23 | from opencomplex.data.tools import utils 24 | 25 | 26 | def _to_a3m(sequences: Sequence[str]) -> str: 27 | """Converts sequences to an a3m file.""" 28 | names = ["sequence %d" % i for i in range(1, len(sequences) + 1)] 29 | a3m = [] 30 | for sequence, name in zip(sequences, names): 31 | a3m.append(u">" + name + u"\n") 32 | a3m.append(sequence + u"\n") 33 | return "".join(a3m) 34 | 35 | 36 | class Kalign: 37 | """Python wrapper of the Kalign binary.""" 38 | 39 | def __init__(self, *, binary_path: str): 40 | """Initializes the Python Kalign wrapper. 41 | 42 | Args: 43 | binary_path: The path to the Kalign binary. 44 | 45 | Raises: 46 | RuntimeError: If Kalign binary not found within the path. 47 | """ 48 | self.binary_path = binary_path 49 | 50 | def align(self, sequences: Sequence[str]) -> str: 51 | """Aligns the sequences and returns the alignment in A3M string. 52 | 53 | Args: 54 | sequences: A list of query sequence strings. The sequences have to be at 55 | least 6 residues long (Kalign requires this). Note that the order in 56 | which you give the sequences might alter the output slightly as 57 | different alignment tree might get constructed. 58 | 59 | Returns: 60 | A string with the alignment in a3m format. 61 | 62 | Raises: 63 | RuntimeError: If Kalign fails. 64 | ValueError: If any of the sequences is less than 6 residues long. 65 | """ 66 | logging.info("Aligning %d sequences", len(sequences)) 67 | 68 | for s in sequences: 69 | if len(s) < 6: 70 | raise ValueError( 71 | "Kalign requires all sequences to be at least 6 " 72 | "residues long. Got %s (%d residues)." % (s, len(s)) 73 | ) 74 | 75 | with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir: 76 | input_fasta_path = os.path.join(query_tmp_dir, "input.fasta") 77 | output_a3m_path = os.path.join(query_tmp_dir, "output.a3m") 78 | 79 | with open(input_fasta_path, "w") as f: 80 | f.write(_to_a3m(sequences)) 81 | 82 | cmd = [ 83 | self.binary_path, 84 | "-i", 85 | input_fasta_path, 86 | "-o", 87 | output_a3m_path, 88 | "-format", 89 | "fasta", 90 | ] 91 | 92 | logging.info('Launching subprocess "%s"', " ".join(cmd)) 93 | process = subprocess.Popen( 94 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE 95 | ) 96 | 97 | with utils.timing("Kalign query"): 98 | stdout, stderr = process.communicate() 99 | retcode = process.wait() 100 | logging.info( 101 | "Kalign stdout:\n%s\n\nstderr:\n%s\n", 102 | stdout.decode("utf-8"), 103 | stderr.decode("utf-8"), 104 | ) 105 | 106 | if retcode: 107 | raise RuntimeError( 108 | "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n" 109 | % (stdout.decode("utf-8"), stderr.decode("utf-8")) 110 | ) 111 | 112 | with open(output_a3m_path) as f: 113 | a3m = f.read() 114 | 115 | return a3m 116 | -------------------------------------------------------------------------------- /opencomplex/model/triangular_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 BAAI 2 | # Copyright 2021 AlQuraishi Laboratory 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from functools import partialmethod, partial 18 | import math 19 | from typing import Optional, List 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | from opencomplex.model.primitives import Linear, LayerNorm, Attention 25 | from opencomplex.utils.chunk_utils import chunk_layer 26 | from opencomplex.utils.tensor_utils import ( 27 | permute_final_dims, 28 | flatten_final_dims, 29 | ) 30 | 31 | 32 | class TriangleAttention(nn.Module): 33 | def __init__( 34 | self, c_in, c_hidden, no_heads, starting=True, inf=1e9 35 | ): 36 | """ 37 | Args: 38 | c_in: 39 | Input channel dimension 40 | c_hidden: 41 | Overall hidden channel dimension (not per-head) 42 | no_heads: 43 | Number of attention heads 44 | """ 45 | super(TriangleAttention, self).__init__() 46 | 47 | self.c_in = c_in 48 | self.c_hidden = c_hidden 49 | self.no_heads = no_heads 50 | self.starting = starting 51 | self.inf = inf 52 | 53 | self.layer_norm = LayerNorm(self.c_in) 54 | 55 | self.linear = Linear(c_in, self.no_heads, bias=False, init="normal") 56 | 57 | self.mha = Attention( 58 | self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads 59 | ) 60 | 61 | @torch.jit.ignore 62 | def _chunk(self, 63 | x: torch.Tensor, 64 | biases: List[torch.Tensor], 65 | chunk_size: int, 66 | use_memory_efficient_kernel: bool = False, 67 | use_lma: bool = False, 68 | inplace_safe: bool = False, 69 | ) -> torch.Tensor: 70 | "triangle! triangle!" 71 | mha_inputs = { 72 | "q_x": x, 73 | "kv_x": x, 74 | "biases": biases, 75 | } 76 | 77 | return chunk_layer( 78 | partial( 79 | self.mha, 80 | use_memory_efficient_kernel=use_memory_efficient_kernel, 81 | use_lma=use_lma 82 | ), 83 | mha_inputs, 84 | chunk_size=chunk_size, 85 | no_batch_dims=len(x.shape[:-2]), 86 | _out=x if inplace_safe else None, 87 | ) 88 | 89 | def forward(self, 90 | x: torch.Tensor, 91 | mask: Optional[torch.Tensor] = None, 92 | chunk_size: Optional[int] = None, 93 | use_memory_efficient_kernel: bool = False, 94 | use_lma: bool = False, 95 | inplace_safe: bool = False, 96 | ) -> torch.Tensor: 97 | """ 98 | Args: 99 | x: 100 | [*, I, J, C_in] input tensor (e.g. 
the pair representation) 101 | Returns: 102 | [*, I, J, C_in] output tensor 103 | """ 104 | if mask is None: 105 | # [*, I, J] 106 | mask = x.new_ones( 107 | x.shape[:-1], 108 | ) 109 | 110 | if(not self.starting): 111 | x = x.transpose(-2, -3) 112 | mask = mask.transpose(-1, -2) 113 | 114 | # [*, I, J, C_in] 115 | x = self.layer_norm(x) 116 | 117 | # [*, I, 1, 1, J] 118 | mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] 119 | 120 | # [*, H, I, J] 121 | triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) 122 | 123 | # [*, 1, H, I, J] 124 | triangle_bias = triangle_bias.unsqueeze(-4) 125 | 126 | biases = [mask_bias, triangle_bias] 127 | 128 | if chunk_size is not None: 129 | x = self._chunk( 130 | x, 131 | biases, 132 | chunk_size, 133 | use_memory_efficient_kernel=use_memory_efficient_kernel, 134 | use_lma=use_lma, 135 | inplace_safe=inplace_safe, 136 | ) 137 | else: 138 | x = self.mha( 139 | q_x=x, 140 | kv_x=x, 141 | biases=biases, 142 | use_memory_efficient_kernel=use_memory_efficient_kernel, 143 | use_lma=use_lma 144 | ) 145 | 146 | if(not self.starting): 147 | x = x.transpose(-2, -3) 148 | 149 | return x 150 | 151 | 152 | # Implements Algorithm 13 153 | TriangleAttentionStartingNode = TriangleAttention 154 | 155 | 156 | class TriangleAttentionEndingNode(TriangleAttention): 157 | """ 158 | Implements Algorithm 14. 159 | """ 160 | __init__ = partialmethod(TriangleAttention.__init__, starting=False) 161 | -------------------------------------------------------------------------------- /tests/test_permutation.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from opencomplex.utils.rigid_utils import Rotation, Rigid 6 | from opencomplex.utils import permutation 7 | 8 | C = 0.7071067811865475 9 | ROT_135 = [ 10 | [-C, -C, 0], 11 | [C, -C, 0], 12 | [0, 0, 1]] 13 | TRANS = [1., 2., 3.] 
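# ROT_135 above is a 135° rotation about the z-axis (C = cos 45° = sin 45°);
# together with TRANS it defines the fixed rigid transform R used by the
# tests below.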
14 | 
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 | R = Rigid(Rotation(torch.tensor(ROT_135, device=device)), torch.tensor(TRANS, device=device))
17 | N_res = 11
18 | eps = 1e-3
19 | 
20 | class TestMultichainPermutationAlignment(unittest.TestCase):
21 |     def test_get_transform(self):
22 |         pred = torch.rand(N_res, 3, device=device)
23 |         mask = torch.rand(N_res, device=device) >= 0.1
24 |         gt = R.apply(pred)
25 | 
26 |         rot, tran = permutation.get_transform(pred, gt, mask, device)
27 | 
28 |         pred_transformed = gt @ rot + tran
29 |         diff = torch.abs(pred_transformed - pred)
30 | 
31 |         # The result may have 1e-3 order of error due to tf32 on Ampere GPU
32 |         # https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere
33 |         self.assertTrue(torch.all(diff < eps))
34 | 
35 |     def test_split_chains(self):
36 |         asym_id = torch.tensor([1,1,1,2,2,2,2,3,4,4,5])
37 |         chains = permutation.split_chains(asym_id)
38 |         true_chains = [(0, 3), (3, 7), (7, 8), (8, 10), (10, 11)]
39 | 
40 |         self.assertTrue(chains == true_chains)
41 | 
42 |     def test_perm_to_idx(self):
43 |         chains = [(0, 3), (3, 5), (5, 6), (6, 8)]
44 |         perm = [3, 2, 1, 0]
45 |         idx = permutation.perm_to_idx(perm, chains)
46 |         idx_true = [6, 7, 5, 3, 4, 0, 1, 2]
47 | 
48 |         self.assertTrue(idx == idx_true)
49 | 
50 |     def test_multichain_permutation_alignment(self):
51 |         batch = {}
52 |         out = {}
53 | 
54 |         batch_size = 2
55 |         chain_lens = [10, 10, 10, 8, 8, 15, 15]
56 |         entitys = [1, 1, 1, 2, 2, 3, 3]
57 |         syms = [1, 2, 3, 1, 2, 1, 2]
58 |         st = 0
59 |         N_res = sum(chain_lens)
60 |         chains = []
61 |         asym_id = []
62 |         entity_id = []
63 |         sym_id = []
64 |         for i, l in enumerate(chain_lens):
65 |             ed = st + l
66 |             chains.append((st, ed))
67 |             for j in range(l):
68 |                 asym_id.append(i + 1)
69 |                 entity_id.append(entitys[i])
70 |                 sym_id.append(syms[i])
71 |             st = ed
72 | 
73 |         asym_id = [asym_id] * batch_size
74 |         entity_id = [entity_id] * batch_size
75 |         sym_id = [sym_id] * batch_size
76 | 
77 |         batch['butype'] = torch.zeros([batch_size, N_res], device=device)
78 |         batch['asym_id'] = torch.tensor(asym_id, device=device)
79 |         batch['sym_id'] = torch.tensor(sym_id, device=device)
80 |         batch['entity_id'] = torch.tensor(entity_id, device=device)
81 | 
82 |         batch['anchor_asym_id'] = 5
83 | 
84 |         best_perm = [1, 2, 0, 4, 3, 5, 6]
85 |         best_perm_idx = permutation.perm_to_idx(best_perm, chains)
86 |         best_perm_idx = torch.tensor([best_perm_idx] * batch_size, device=device)
87 | 
88 |         out['final_atom_positions'] = torch.rand([batch_size, N_res, 37, 3], device=device) * 100
89 |         out['final_atom_mask'] = torch.rand([batch_size, N_res, 37], device=device) >= 0.1
90 | 
91 |         batch['all_atom_positions'] = permutation.apply_permutation_core(
92 |             R.apply(out['final_atom_positions']), best_perm_idx, dims=-3)
93 | 
94 |         perm_idx = permutation.multichain_permutation_alignment(batch, out)
95 |         self.assertTrue(torch.all(perm_idx == best_perm_idx))
96 | 
97 | 
98 |     def test_apply_permutation_core(self):
99 |         # test permutation over 1 dimension
100 |         x = torch.arange(0, 20, device=device).view(2, 1, 5, 2)
101 |         idx = torch.tensor([[1, 2, 0, 4, 3], [3, 0, 2, 4, 1]], device=device)
102 |         y = permutation.apply_permutation_core(x, idx, -2)
103 | 
104 |         y_true = torch.tensor(
105 |             [[[[2, 3],[4, 5],[0, 1],[8,9],[6,7]]],
106 |              [[[16,17],[10,11],[14,15],[18,19],[12,13]]]],
107 |             device=device)
108 | 
109 |         self.assertTrue(torch.all(y == y_true))
110 | 
111 |         # test permutation over 2 dimensions
112 |         x = torch.arange(0, 9, device=device,
113 |                          dtype=torch.float32).view(1, 3, 3)
114 |         x.requires_grad = True
115 |         idx = torch.tensor([[1, 2, 0]], device=device)
116 |         y = permutation.apply_permutation_core(x, idx, [-1, -2])
117 | 
118 |         y_true = torch.tensor(
119 |             [[4,5,3],
120 |              [7,8,6],
121 |              [1,2,0]],
122 |             device=device,
123 |             dtype=torch.float32)
124 |         self.assertTrue(torch.all(y == y_true))
125 | 
126 |         # test that gradients propagate back to x correctly
127 |         d = y ** 2
128 |         external_grad = torch.arange(1, 10, device=device).view(d.shape)
129 |         d.backward(gradient=external_grad)
130 | 
131 |         g_x_true = torch.tensor(
132 |             [[0, 14, 32],
133 |              [18, 8, 20],
134 |              [72, 56, 80]],
135 |             dtype=torch.float32,
136 |             device=device
137 |         )
138 | 
139 |         self.assertTrue(torch.all(g_x_true == x.grad))
140 | 
141 | 
142 | if __name__ == '__main__':
143 |     unittest.main()
--------------------------------------------------------------------------------
/opencomplex/model/outer_product_mean.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 BAAI
2 | # Copyright 2021 AlQuraishi Laboratory
3 | # Copyright 2021 DeepMind Technologies Limited
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #      http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | from functools import partial
18 | from typing import Optional
19 | 
20 | import torch
21 | import torch.nn as nn
22 | 
23 | from opencomplex.model.primitives import Linear, LayerNorm
24 | from opencomplex.utils.chunk_utils import chunk_layer
25 | from opencomplex.faster_alphafold.faster_alphafold_config import faster_alphafold_config
26 | 
27 | 
28 | class OuterProductMean(nn.Module):
29 |     """
30 |     Implements Algorithm 10.
31 | """ 32 | 33 | def __init__(self, c_m, c_z, c_hidden, eps=1e-3): 34 | """ 35 | Args: 36 | c_m: 37 | MSA embedding channel dimension 38 | c_z: 39 | Pair embedding channel dimension 40 | c_hidden: 41 | Hidden channel dimension 42 | """ 43 | super(OuterProductMean, self).__init__() 44 | 45 | self.c_m = c_m 46 | self.c_z = c_z 47 | self.c_hidden = c_hidden 48 | self.eps = eps 49 | 50 | self.layer_norm = LayerNorm(c_m) 51 | if faster_alphafold_config.outer_product_mean: 52 | linear_1 = Linear(c_m, c_hidden) 53 | linear_2 = Linear(c_m, c_hidden) 54 | self.linear_12 = Linear(c_m, c_hidden * 2) 55 | self.linear_12.weight.data.copy_(torch.cat((linear_1.weight.data, linear_2.weight.data), 0)) 56 | else: 57 | self.linear_1 = Linear(c_m, c_hidden) 58 | self.linear_2 = Linear(c_m, c_hidden) 59 | self.linear_out = Linear(c_hidden ** 2, c_z, init="final") 60 | 61 | def _opm(self, a, b): 62 | # [*, N_res, N_res, C, C] 63 | outer = torch.einsum("...bac,...dae->...bdce", a, b) 64 | 65 | # [*, N_res, N_res, C * C] 66 | outer = outer.reshape(outer.shape[:-2] + (-1,)) 67 | 68 | # [*, N_res, N_res, C_z] 69 | outer = self.linear_out(outer) 70 | 71 | return outer 72 | 73 | @torch.jit.ignore 74 | def _chunk(self, 75 | a: torch.Tensor, 76 | b: torch.Tensor, 77 | chunk_size: int 78 | ) -> torch.Tensor: 79 | # Since the "batch dim" in this case is not a true batch dimension 80 | # (in that the shape of the output depends on it), we need to 81 | # iterate over it ourselves 82 | a_reshape = a.reshape((-1,) + a.shape[-3:]) 83 | b_reshape = b.reshape((-1,) + b.shape[-3:]) 84 | out = [] 85 | for a_prime, b_prime in zip(a_reshape, b_reshape): 86 | outer = chunk_layer( 87 | partial(self._opm, b=b_prime), 88 | {"a": a_prime}, 89 | chunk_size=chunk_size, 90 | no_batch_dims=1, 91 | ) 92 | out.append(outer) 93 | 94 | # For some cursed reason making this distinction saves memory 95 | if(len(out) == 1): 96 | outer = out[0].unsqueeze(0) 97 | else: 98 | outer = torch.stack(out, dim=0) 99 | 100 | outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) 101 | 102 | return outer 103 | 104 | def forward(self, 105 | m: torch.Tensor, 106 | mask: Optional[torch.Tensor] = None, 107 | chunk_size: Optional[int] = None, 108 | inplace_safe: bool = False, 109 | ) -> torch.Tensor: 110 | """ 111 | Args: 112 | m: 113 | [*, N_seq, N_res, C_m] MSA embedding 114 | mask: 115 | [*, N_seq, N_res] MSA mask 116 | Returns: 117 | [*, N_res, N_res, C_z] pair embedding update 118 | """ 119 | if mask is None: 120 | mask = m.new_ones(m.shape[:-1]) 121 | 122 | # [*, N_seq, N_res, C_m] 123 | ln = self.layer_norm(m) 124 | 125 | # [*, N_seq, N_res, C] 126 | mask = mask.unsqueeze(-1) 127 | if faster_alphafold_config.outer_product_mean: 128 | a, b = self.linear_12(ln).chunk(2, dim=-1) 129 | else: 130 | a = self.linear_1(ln) 131 | b = self.linear_2(ln) 132 | 133 | a = a * mask 134 | b = b * mask 135 | 136 | del ln 137 | 138 | a = a.transpose(-2, -3) 139 | b = b.transpose(-2, -3) 140 | 141 | if chunk_size is not None: 142 | outer = self._chunk(a, b, chunk_size) 143 | else: 144 | outer = self._opm(a, b) 145 | 146 | # [*, N_res, N_res, 1] 147 | norm = torch.einsum("...abc,...adc->...bdc", mask, mask) 148 | norm = norm + self.eps 149 | 150 | # [*, N_res, N_res, C_z] 151 | if(inplace_safe): 152 | outer /= norm 153 | else: 154 | outer = outer / norm 155 | 156 | return outer 157 | -------------------------------------------------------------------------------- /tests/test_triangular_multiplicative_update.py: 
--------------------------------------------------------------------------------
1 | # Copyright 2021 AlQuraishi Laboratory
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #      http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import torch
16 | import numpy as np
17 | import unittest
18 | from opencomplex.model.triangular_multiplicative_update import *
19 | from opencomplex.utils.tensor_utils import tree_map
20 | import tests.compare_utils as compare_utils
21 | from tests.config import consts
22 | 
23 | if compare_utils.alphafold_is_installed():
24 |     alphafold = compare_utils.import_alphafold()
25 |     import jax
26 |     import haiku as hk
27 | 
28 | 
29 | class TestTriangularMultiplicativeUpdate(unittest.TestCase):
30 |     def test_shape(self):
31 |         c_z = consts.c_z
32 |         c = 11
33 | 
34 |         tm = TriangleMultiplicationOutgoing(
35 |             c_z,
36 |             c,
37 |         )
38 | 
39 |         n_res = consts.n_res
40 |         batch_size = consts.batch_size
41 | 
42 |         x = torch.rand((batch_size, n_res, n_res, c_z))
43 |         mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
44 |         shape_before = x.shape
45 |         x = tm(x, mask)
46 |         shape_after = x.shape
47 | 
48 |         self.assertTrue(shape_before == shape_after)
49 | 
50 |     def _tri_mul_compare(self, incoming=False):
51 |         name = "triangle_multiplication_" + (
52 |             "incoming" if incoming else "outgoing"
53 |         )
54 | 
55 |         def run_tri_mul(pair_act, pair_mask):
56 |             config = compare_utils.get_alphafold_config()
57 |             c_e = config.model.embeddings_and_evoformer.evoformer
58 |             tri_mul = alphafold.model.modules.TriangleMultiplication(
59 |                 c_e.triangle_multiplication_incoming
60 |                 if incoming
61 |                 else c_e.triangle_multiplication_outgoing,
62 |                 config.model.global_config,
63 |                 name=name,
64 |             )
65 |             act = tri_mul(act=pair_act, mask=pair_mask)
66 |             return act
67 | 
68 |         f = hk.transform(run_tri_mul)
69 | 
70 |         n_res = consts.n_res
71 | 
72 |         pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
73 |         pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
74 |         pair_mask = pair_mask.astype(np.float32)
75 | 
76 |         # Fetch pretrained parameters (but only from one block)
77 |         params = compare_utils.fetch_alphafold_module_weights(
78 |             "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
79 |             + name
80 |         )
81 |         params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
82 | 
83 |         out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
84 |         out_gt = torch.as_tensor(np.array(out_gt))
85 | 
86 |         model = compare_utils.get_global_pretrained_opencomplex()
87 |         module = (
88 |             model.evoformer.blocks[0].core.tri_mul_in
89 |             if incoming
90 |             else model.evoformer.blocks[0].core.tri_mul_out
91 |         )
92 |         out_repro = module(
93 |             torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
94 |             mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
95 |             _inplace=True, _inplace_chunk_size=4,
96 |         ).cpu()
97 | 
98 |         self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
99 | 
100 |     @compare_utils.skip_unless_alphafold_installed()
101 |     def test_tri_mul_out_compare(self):
102 | 
self._tri_mul_compare() 103 | 104 | @compare_utils.skip_unless_alphafold_installed() 105 | def test_tri_mul_in_compare(self): 106 | self._tri_mul_compare(incoming=True) 107 | 108 | def _tri_mul_inplace(self, incoming=False): 109 | n_res = consts.n_res 110 | 111 | pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32) 112 | pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res)) 113 | pair_mask = pair_mask.astype(np.float32) 114 | 115 | 116 | model = compare_utils.get_global_pretrained_opencomplex() 117 | module = ( 118 | model.evoformer.blocks[0].core.tri_mul_in 119 | if incoming 120 | else model.evoformer.blocks[0].core.tri_mul_out 121 | ) 122 | out_stock = module( 123 | torch.as_tensor(pair_act, dtype=torch.float32).cuda(), 124 | mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), 125 | _inplace=False, 126 | ).cpu() 127 | 128 | # This has to come second because inference mode is in-place 129 | out_inplace = module( 130 | torch.as_tensor(pair_act, dtype=torch.float32).cuda(), 131 | mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), 132 | _inplace=True, _inplace_chunk_size=2, 133 | ).cpu() 134 | 135 | self.assertTrue(torch.mean(torch.abs(out_stock - out_inplace)) < consts.eps) 136 | 137 | def test_tri_mul_out_inference(self): 138 | self._tri_mul_inplace() 139 | 140 | def test_tri_mul_in_inference(self): 141 | self._tri_mul_inplace(incoming=True) 142 | 143 | if __name__ == "__main__": 144 | unittest.main() 145 | -------------------------------------------------------------------------------- /opencomplex/faster_alphafold/faster_alphafold.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | from torch import nn, autograd 5 | 6 | from opencomplex.faster_alphafold.faster_alphafold_config import faster_alphafold_config 7 | 8 | path='opencomplex/faster_alphafold/libths_faster_alphafold.so' 9 | try: 10 | torch.ops.load_library(path) 11 | except Exception as e: 12 | logging.warn("Failed to load faster alphafold library...") 13 | logging.warn(e) 14 | for k in faster_alphafold_config: 15 | faster_alphafold_config[k] = False 16 | 17 | 18 | class LayerNormFunction(autograd.Function): 19 | @staticmethod 20 | def forward(ctx, input_tensor, gamma, beta, residual=None): 21 | input_tensor = input_tensor.contiguous() 22 | if residual is not None: 23 | residual = residual.contiguous() 24 | 25 | # result[] = output, mean, var_rsqrt, input_tensor if residual is None else input_add_residual 26 | result = torch.ops.FasterAlphaFold.LayerNorm_forward(input_tensor, gamma, beta, residual) 27 | ctx.save_for_backward(gamma, result[1], result[2], result[3]) 28 | ctx.add_residual = residual is not None 29 | return result[0] 30 | 31 | @staticmethod 32 | def backward(ctx, grad_out): 33 | grad_out = grad_out.contiguous() 34 | gamma, mean, var_rsqrt, input_add_residual = ctx.saved_tensors 35 | 36 | #grad[] = grad_in, grad_gamma, grad_beta, [grad_residual] 37 | grad = torch.ops.FasterAlphaFold.LayerNorm_backward(grad_out, gamma, mean, var_rsqrt, input_add_residual, ctx.add_residual) 38 | return grad[0], grad[1], grad[2], None if not ctx.add_residual else grad[3] 39 | 40 | def FasterLayerNorm(input_tensor, weight, bias, residual=None): 41 | return LayerNormFunction.apply(input_tensor, weight, bias, residual) 42 | 43 | 44 | class MatMulFunction(autograd.Function): 45 | @staticmethod 46 | def forward(ctx, input_A, input_B, transpose_a = False, transpose_b = False, scale = 1.0): 47 | input_A = 
input_A.contiguous() 48 | input_B = input_B.contiguous() 49 | matmul_out = torch.ops.FasterAlphaFold.MatMul_forward( 50 | input_A, input_B, transpose_a, transpose_b, scale) 51 | ctx.transpose_a = transpose_a 52 | ctx.transpose_b = transpose_b 53 | ctx.scale = scale 54 | ctx.save_for_backward(input_A, input_B) 55 | return matmul_out 56 | 57 | @staticmethod 58 | def backward(ctx, grad_out): 59 | input_A, input_B = ctx.saved_tensors 60 | grad_A, grad_B = torch.ops.FasterAlphaFold.MatMul_backward( 61 | grad_out, input_A, input_B, ctx.transpose_a, ctx.transpose_b, ctx.scale) 62 | return grad_A, grad_B, None, None, None 63 | 64 | def faster_matmul(input_A, input_B, transpose_a = False, transpose_b = False, scale = 1.0): 65 | return MatMulFunction.apply(input_A, input_B, transpose_a, transpose_b, scale) 66 | 67 | 68 | class SoftmaxFunction(autograd.Function): 69 | @staticmethod 70 | def forward(ctx, input_tensor, mask_tensor=None, head_num=1): 71 | input_tensor = input_tensor.contiguous() 72 | softmax_out = torch.ops.FasterAlphaFold.Softmax_forward(input_tensor, mask_tensor, head_num) 73 | ctx.save_for_backward(softmax_out) 74 | return softmax_out 75 | 76 | @staticmethod 77 | def backward(ctx, grad_out): 78 | grad_out = grad_out.contiguous() 79 | (softmax_out,) = ctx.saved_tensors 80 | grad_in = torch.ops.FasterAlphaFold.Softmax_backward(grad_out, softmax_out) 81 | return grad_in, None, None 82 | 83 | def faster_softmax(input_tensor, mask_tensor=None, head_num=1): 84 | return SoftmaxFunction.apply(input_tensor, mask_tensor, head_num) 85 | 86 | 87 | class TriangleUpdateABGFunction(autograd.Function): 88 | @staticmethod 89 | def forward(ctx, input_tensor, mask_tensor, c_z, c_hidden): 90 | a, b, g = torch.ops.FasterAlphaFold.TriangleUpdateABG_forward(input_tensor, mask_tensor, c_z, c_hidden) 91 | ctx.save_for_backward(input_tensor, mask_tensor) 92 | ctx.c_z = c_z 93 | ctx.c_hidden = c_hidden 94 | return a, b, g 95 | 96 | @staticmethod 97 | def backward(ctx, grad_a, grad_b, grad_g): 98 | input_tensor, mask_tensor = ctx.saved_tensors 99 | input_grad, mask_grad = torch.ops.FasterAlphaFold.TriangleUpdateABG_backward( 100 | grad_a.contiguous(), grad_b.contiguous(), grad_g.contiguous(), 101 | input_tensor, mask_tensor) 102 | return input_grad, mask_grad, None, None 103 | 104 | class TriangleUpdateABG(nn.Module): 105 | def __init__(self, c_z, c_hidden): 106 | super(TriangleUpdateABG, self).__init__() 107 | self.c_z = c_z 108 | self.c_hidden = c_hidden 109 | 110 | def forward(self, input_tensor, mask_tensor): 111 | return TriangleUpdateABGFunction.apply(input_tensor, mask_tensor, self.c_z, self.c_hidden) 112 | 113 | def extra_repr(self): 114 | return 'TriangleUpdateABG c_z={}, c_hidden={}'.format(self.c_z, self.c_hidden) 115 | 116 | 117 | class AttentionSplitQKVFunction(autograd.Function): 118 | @staticmethod 119 | def forward(ctx, input_tensor, head_num=1): 120 | input_tensor = input_tensor.contiguous() 121 | q_out, k_out, v_out = torch.ops.FasterAlphaFold.AttentionSplitQKV_forward(input_tensor, head_num) 122 | return q_out, k_out, v_out 123 | 124 | @staticmethod 125 | def backward(ctx, grad_q, grad_k, grad_v): 126 | grad_qkv = torch.ops.FasterAlphaFold.AttentionSplitQKV_backward(grad_q.contiguous(), grad_k.contiguous(), grad_v.contiguous()) 127 | return grad_qkv, None 128 | 129 | def attention_split_qkv(input_tensor, head_num=1): 130 | return AttentionSplitQKVFunction.apply(input_tensor, head_num) 131 | -------------------------------------------------------------------------------- 
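The classes above are thin `torch.autograd.Function` bindings around the fused CUDA kernels in `libths_faster_alphafold.so`; when the library fails to load, every switch in `faster_alphafold_config` is flipped off and the stock PyTorch paths are used instead. Below is a minimal sketch of how the fused ops might be exercised directly — the shapes are illustrative and a working CUDA build of the library is assumed:

```python
import torch

from opencomplex.faster_alphafold.faster_alphafold import (
    FasterLayerNorm,
    faster_softmax,
)

c_m = 256
x = torch.randn(4, 128, c_m, device="cuda", requires_grad=True)
weight = torch.ones(c_m, device="cuda", requires_grad=True)
bias = torch.zeros(c_m, device="cuda", requires_grad=True)

# Fused LayerNorm; a residual tensor could also be folded into the kernel
# via the optional fourth argument.
y = FasterLayerNorm(x, weight, bias)

# Fused softmax with an optional mask, here over 8 attention heads.
scores = torch.randn(4, 8, 128, 128, device="cuda", requires_grad=True)
probs = faster_softmax(scores, head_num=8)

# Gradients flow through the custom autograd.Functions as usual.
(y.sum() + probs.sum()).backward()
```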
/opencomplex/np/relax/cleanup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations. 16 | 17 | fix_pdb uses a third-party tool. We also support fixing some additional edge 18 | cases like removing chains of length one (see clean_structure). 19 | """ 20 | import io 21 | 22 | import pdbfixer 23 | try: 24 | # openmm >= 7.6 25 | from openmm import app 26 | from openmm.app import element 27 | except ImportError: 28 | # openmm < 7.6 (requires DeepMind patch) 29 | from simtk.openmm import app 30 | from simtk.openmm.app import element 31 | 32 | 33 | def fix_pdb(pdbfile, alterations_info): 34 | """Apply pdbfixer to the contents of a PDB file; return a PDB string result. 35 | 36 | 1) Replaces nonstandard residues. 37 | 2) Removes heterogens (non protein residues) including water. 38 | 3) Adds missing residues and missing atoms within existing residues. 39 | 4) Adds hydrogens assuming pH=7.0. 40 | 5) KeepIds is currently true, so the fixer must keep the existing chain and 41 | residue identifiers. This will fail for some files in wider PDB that have 42 | invalid IDs. 43 | 44 | Args: 45 | pdbfile: Input PDB file handle. 46 | alterations_info: A dict that will store details of changes made. 47 | 48 | Returns: 49 | A PDB string representing the fixed structure. 50 | """ 51 | fixer = pdbfixer.PDBFixer(pdbfile=pdbfile) 52 | fixer.findNonstandardResidues() 53 | alterations_info["nonstandard_residues"] = fixer.nonstandardResidues 54 | fixer.replaceNonstandardResidues() 55 | _remove_heterogens(fixer, alterations_info, keep_water=False) 56 | fixer.findMissingResidues() 57 | alterations_info["missing_residues"] = fixer.missingResidues 58 | fixer.findMissingAtoms() 59 | alterations_info["missing_heavy_atoms"] = fixer.missingAtoms 60 | alterations_info["missing_terminals"] = fixer.missingTerminals 61 | fixer.addMissingAtoms(seed=0) 62 | fixer.addMissingHydrogens() 63 | out_handle = io.StringIO() 64 | app.PDBFile.writeFile( 65 | fixer.topology, fixer.positions, out_handle, keepIds=True 66 | ) 67 | return out_handle.getvalue() 68 | 69 | 70 | def clean_structure(pdb_structure, alterations_info): 71 | """Applies additional fixes to an OpenMM structure, to handle edge cases. 72 | 73 | Args: 74 | pdb_structure: An OpenMM structure to modify and fix. 75 | alterations_info: A dict that will store details of changes made. 76 | """ 77 | _replace_met_se(pdb_structure, alterations_info) 78 | _remove_chains_of_length_one(pdb_structure, alterations_info) 79 | 80 | 81 | def _remove_heterogens(fixer, alterations_info, keep_water): 82 | """Removes the residues that Pdbfixer considers to be heterogens. 83 | 84 | Args: 85 | fixer: A Pdbfixer instance. 86 | alterations_info: A dict that will store details of changes made. 
87 | keep_water: If True, water (HOH) is not considered to be a heterogen. 88 | """ 89 | initial_resnames = set() 90 | for chain in fixer.topology.chains(): 91 | for residue in chain.residues(): 92 | initial_resnames.add(residue.name) 93 | fixer.removeHeterogens(keepWater=keep_water) 94 | final_resnames = set() 95 | for chain in fixer.topology.chains(): 96 | for residue in chain.residues(): 97 | final_resnames.add(residue.name) 98 | alterations_info["removed_heterogens"] = initial_resnames.difference( 99 | final_resnames 100 | ) 101 | 102 | 103 | def _replace_met_se(pdb_structure, alterations_info): 104 | """Replace the Se in any MET residues that were not marked as modified.""" 105 | modified_met_residues = [] 106 | for res in pdb_structure.iter_residues(): 107 | name = res.get_name_with_spaces().strip() 108 | if name == "MET": 109 | s_atom = res.get_atom("SD") 110 | if s_atom.element_symbol == "Se": 111 | s_atom.element_symbol = "S" 112 | s_atom.element = element.get_by_symbol("S") 113 | modified_met_residues.append(s_atom.residue_number) 114 | alterations_info["Se_in_MET"] = modified_met_residues 115 | 116 | 117 | def _remove_chains_of_length_one(pdb_structure, alterations_info): 118 | """Removes chains that correspond to a single amino acid. 119 | 120 | A single amino acid in a chain is both N and C terminus. There is no force 121 | template for this case. 122 | 123 | Args: 124 | pdb_structure: An OpenMM pdb_structure to modify and fix. 125 | alterations_info: A dict that will store details of changes made. 126 | """ 127 | removed_chains = {} 128 | for model in pdb_structure.iter_models(): 129 | valid_chains = [c for c in model.iter_chains() if len(c) > 1] 130 | invalid_chain_ids = [ 131 | c.chain_id for c in model.iter_chains() if len(c) <= 1 132 | ] 133 | model.chains = valid_chains 134 | for chain_id in invalid_chain_ids: 135 | model.chains_by_id.pop(chain_id) 136 | removed_chains[model.number] = invalid_chain_ids 137 | alterations_info["removed_chains"] = removed_chains 138 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import pickle 16 | import torch 17 | import torch.nn as nn 18 | import numpy as np 19 | import unittest 20 | from opencomplex.config.config import model_config 21 | from opencomplex.data import data_transforms 22 | from opencomplex.model.model import AlphaFold 23 | import opencomplex.utils.feats as feats 24 | from opencomplex.utils.tensor_utils import tree_map, tensor_tree_map 25 | import tests.compare_utils as compare_utils 26 | from tests.config import consts 27 | from tests.data_utils import ( 28 | random_template_feats, 29 | random_extra_msa_feats, 30 | ) 31 | 32 | if compare_utils.alphafold_is_installed(): 33 | alphafold = compare_utils.import_alphafold() 34 | import jax 35 | import haiku as hk 36 | 37 | 38 | class TestModel(unittest.TestCase): 39 | def test_dry_run(self): 40 | n_seq = consts.n_seq 41 | n_templ = consts.n_templ 42 | n_res = consts.n_res 43 | n_extra_seq = consts.n_extra 44 | 45 | c = model_config("model_1") 46 | c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here 47 | c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up 48 | # deepspeed for this test 49 | 50 | model = AlphaFold(c) 51 | 52 | batch = {} 53 | tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,)) 54 | batch["target_feat"] = nn.functional.one_hot( 55 | tf, c.model.input_embedder.tf_dim 56 | ).float() 57 | batch["butype"] = torch.argmax(batch["target_feat"], dim=-1) 58 | batch["residue_index"] = torch.arange(n_res) 59 | batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim)) 60 | t_feats = random_template_feats(n_templ, n_res) 61 | batch.update({k: torch.tensor(v) for k, v in t_feats.items()}) 62 | extra_feats = random_extra_msa_feats(n_extra_seq, n_res) 63 | batch.update({k: torch.tensor(v) for k, v in extra_feats.items()}) 64 | batch["msa_mask"] = torch.randint( 65 | low=0, high=2, size=(n_seq, n_res) 66 | ).float() 67 | batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float() 68 | batch.update(data_transforms.make_atom14_masks(batch)) 69 | batch["no_recycling_iters"] = torch.tensor(2.) 
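# Every feature is then tiled along a trailing recycling dimension, which the
# model consumes one slice per recycling iteration.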
70 | 71 | add_recycling_dims = lambda t: ( 72 | t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters) 73 | ) 74 | batch = tensor_tree_map(add_recycling_dims, batch) 75 | 76 | with torch.no_grad(): 77 | out = model(batch) 78 | 79 | @compare_utils.skip_unless_alphafold_installed() 80 | def test_compare(self): 81 | def run_alphafold(batch): 82 | config = compare_utils.get_alphafold_config() 83 | model = alphafold.model.modules.AlphaFold(config.model) 84 | return model( 85 | batch=batch, 86 | is_training=False, 87 | return_representations=True, 88 | ) 89 | 90 | f = hk.transform(run_alphafold) 91 | 92 | params = compare_utils.fetch_alphafold_module_weights("") 93 | 94 | with open("tests/test_data/sample_feats.pickle", "rb") as fp: 95 | batch = pickle.load(fp) 96 | 97 | out_gt = f.apply(params, jax.random.PRNGKey(42), batch) 98 | 99 | out_gt = out_gt["structure_module"]["final_atom_positions"] 100 | # atom37_to_atom14 doesn't like batches 101 | batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0] 102 | batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0] 103 | out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch) 104 | out_gt = torch.as_tensor(np.array(out_gt.block_until_ready())) 105 | 106 | batch["no_recycling_iters"] = np.array([3., 3., 3., 3.,]) 107 | batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()} 108 | 109 | batch["butype"] = batch["butype"].long() 110 | batch["template_butype"] = batch["template_butype"].long() 111 | batch["extra_msa"] = batch["extra_msa"].long() 112 | batch["residx_atom37_to_atom14"] = batch[ 113 | "residx_atom37_to_atom14" 114 | ].long() 115 | batch["template_all_atom_mask"] = batch["template_all_atom_masks"] 116 | batch.update( 117 | data_transforms.atom37_to_torsion_angles("template_")(batch) 118 | ) 119 | 120 | # Move the recycling dimension to the end 121 | move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0) 122 | batch = tensor_tree_map(move_dim, batch) 123 | 124 | with torch.no_grad(): 125 | model = compare_utils.get_global_pretrained_opencomplex() 126 | out_repro = model(batch) 127 | 128 | out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro) 129 | 130 | out_repro = out_repro["sm"]["positions"][-1] 131 | out_repro = out_repro.squeeze(0) 132 | 133 | print(torch.mean(torch.abs(out_gt - out_repro))) 134 | print(torch.max(torch.abs(out_gt - out_repro))) 135 | self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3) 136 | -------------------------------------------------------------------------------- /opencomplex/utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 BAAI 2 | # Copyright 2021 AlQuraishi Laboratory 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | from functools import partial 18 | import logging 19 | from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | import numpy as np 25 | 26 | 27 | def add(m1, m2, inplace): 28 | # The first operation in a checkpoint can't be in-place, but it's 29 | # nice to have in-place addition during inference. Thus... 30 | if(not inplace): 31 | m1 = m1 + m2 32 | else: 33 | m1 += m2 34 | 35 | return m1 36 | 37 | 38 | def permute_final_dims(tensor: torch.Tensor, inds: List[int]): 39 | zero_index = -1 * len(inds) 40 | first_inds = list(range(len(tensor.shape[:zero_index]))) 41 | return tensor.permute(first_inds + [zero_index + i for i in inds]) 42 | 43 | 44 | def flatten_final_dims(t: torch.Tensor, no_dims: int): 45 | return t.reshape(t.shape[:-no_dims] + (-1,)) 46 | 47 | 48 | def masked_mean(mask, value, dim, eps=1e-4): 49 | mask = mask.expand(*value.shape) 50 | return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) 51 | 52 | 53 | def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): 54 | boundaries = torch.linspace( 55 | min_bin, max_bin, no_bins - 1, device=pts.device 56 | ) 57 | dists = torch.sqrt( 58 | torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1) 59 | ) 60 | return torch.bucketize(dists, boundaries) 61 | 62 | 63 | def dict_multimap(fn, dicts): 64 | first = dicts[0] 65 | new_dict = {} 66 | for k, v in first.items(): 67 | all_v = [d[k] for d in dicts] 68 | if type(v) is dict: 69 | new_dict[k] = dict_multimap(fn, all_v) 70 | else: 71 | new_dict[k] = fn(all_v) 72 | 73 | return new_dict 74 | 75 | 76 | def one_hot(x, v_bins): 77 | reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) 78 | diffs = x[..., None] - reshaped_bins 79 | am = torch.argmin(torch.abs(diffs), dim=-1) 80 | return nn.functional.one_hot(am, num_classes=len(v_bins)).float() 81 | 82 | 83 | def batched_gather(data, inds, dim=0, no_batch_dims=0): 84 | ranges = [] 85 | for i, s in enumerate(data.shape[:no_batch_dims]): 86 | r = torch.arange(s) 87 | r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) 88 | ranges.append(r) 89 | 90 | remaining_dims = [ 91 | slice(None) for _ in range(len(data.shape) - no_batch_dims) 92 | ] 93 | remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds 94 | ranges.extend(remaining_dims) 95 | return data[ranges] 96 | 97 | 98 | # With tree_map, a poor man's JAX tree_map 99 | def dict_map(fn, dic, leaf_type): 100 | new_dict = {} 101 | for k, v in dic.items(): 102 | if type(v) is dict: 103 | new_dict[k] = dict_map(fn, v, leaf_type) 104 | else: 105 | new_dict[k] = tree_map(fn, v, leaf_type) 106 | 107 | return new_dict 108 | 109 | 110 | def tree_map(fn, tree, leaf_type): 111 | if isinstance(tree, dict): 112 | return dict_map(fn, tree, leaf_type) 113 | elif isinstance(tree, list): 114 | return [tree_map(fn, x, leaf_type) for x in tree] 115 | elif isinstance(tree, tuple): 116 | return tuple([tree_map(fn, x, leaf_type) for x in tree]) 117 | elif isinstance(tree, leaf_type): 118 | return fn(tree) 119 | else: 120 | print(type(tree)) 121 | raise ValueError("Not supported") 122 | 123 | 124 | tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) 125 | 126 | 127 | def padcat(tensors, axis=0): 128 | tensors = [t for t in tensors if t is not None and t.shape[axis] > 0] 129 | if len(tensors) == 1: 130 | return tensors[0] 131 | 132 | ndim = tensors[0].ndim 133 | if axis < 0: 134 | axis += ndim 135 | 136 | axis_max_len = [ 137 | max(t.shape[i] for t 
in tensors) 138 | for i in range(ndim) 139 | ] 140 | 141 | is_np = False 142 | if not isinstance(tensors[0], torch.Tensor): 143 | if tensors[0].dtype == np.object_: 144 | return np.concatenate(tensors, axis=axis) 145 | 146 | is_np = True 147 | tensors = [torch.tensor(t) for t in tensors] 148 | 149 | for i, t in enumerate(tensors): 150 | pads = [0 for _ in range(ndim * 2)] 151 | for j in range(0, ndim): 152 | if j != axis: 153 | pads[(ndim - j - 1) * 2 + 1] = axis_max_len[j] - t.shape[j] 154 | 155 | if any(pads): 156 | tensors[i] = torch.nn.functional.pad(t, tuple(pads)) 157 | 158 | ret = torch.cat(tensors, axis=axis) 159 | if is_np: 160 | return ret.numpy() 161 | 162 | return ret 163 | 164 | 165 | def map_padcat(a, b, axis=0): 166 | for i, j in zip(a, b): 167 | yield padcat([i, j], axis) 168 | 169 | 170 | def padto(t, shape): 171 | ndim = len(shape) 172 | pads = [0 for _ in range(ndim * 2)] 173 | for j in range(0, ndim): 174 | pads[(ndim - j - 1) * 2 + 1] = shape[j] - t.shape[j] 175 | if any(pads): 176 | t = torch.nn.functional.pad(t, tuple(pads)) 177 | return t 178 | -------------------------------------------------------------------------------- /opencomplex/utils/feats_rna.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 BAAI 2 | # Copyright 2021 AlQuraishi Laboratory 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import torch 18 | import torch.nn as nn 19 | from opencomplex.utils.rigid_utils import Rigid, Rotation 20 | from opencomplex.utils.tensor_utils import ( 21 | batched_gather, 22 | ) 23 | 24 | 25 | def torsion_angles_to_frames( 26 | r: Rigid, 27 | alpha: torch.Tensor, 28 | butype: torch.Tensor, 29 | rrgdf: torch.Tensor, 30 | ): 31 | 32 | # [*, N, 9, 4, 4] 33 | default_4x4 = rrgdf[butype, ...] 34 | 35 | # [*, N, 9] transformations, i.e. 36 | # One [*, N, 9, 3, 3] rotation matrix and 37 | # One [*, N, 9, 3] translation matrix 38 | default_r = r.from_tensor_4x4(default_4x4) 39 | 40 | bb_rot1 = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) 41 | bb_rot1[..., 1] = 1 42 | 43 | bb_rot2 = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2)) 44 | bb_rot2[..., 1] = 1 45 | 46 | alpha = torch.cat( 47 | [ 48 | bb_rot1.expand(*alpha.shape[:-2], -1, -1), 49 | bb_rot2.expand(*alpha.shape[:-2], -1, -1), 50 | alpha, 51 | ], dim=-2 52 | ) 53 | 54 | # [*, N, 9, 3, 3] 55 | # Produces rotation matrices of the form: 56 | # [ 57 | # [1, 0 , 0 ], 58 | # [0, a_2,-a_1], 59 | # [0, a_1, a_2] 60 | # ] 61 | # This follows the original code rather than the supplement, which uses 62 | # different indices. 
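# alpha holds (sin, cos) pairs, so cos lands on the rotation diagonal and sin
# on the off-diagonal below; the two bb_rot vectors prepended above encode
# identity rotations (sin = 0, cos = 1) for the backbone frames.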
63 | all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) 64 | all_rots[..., 0, 0] = 1 65 | all_rots[..., 1, 1] = alpha[..., 1] 66 | all_rots[..., 1, 2] = -alpha[..., 0] 67 | all_rots[..., 2, 1:] = alpha 68 | 69 | all_rots = Rigid(Rotation(rot_mats=all_rots), None) 70 | 71 | all_frames = default_r.compose(all_rots) 72 | 73 | delta_frame_to_frame = all_frames[..., 2] 74 | gamma_frame_to_frame = all_frames[..., 3] 75 | beta_frame_to_frame = all_frames[..., 4] 76 | alpha1_frame_to_frame = all_frames[..., 5] 77 | alpha2_frame_to_frame = all_frames[..., 6] 78 | tm_frame_to_frame = all_frames[..., 7] 79 | chi_frame_to_frame = all_frames[..., 8] 80 | 81 | delta_frame_to_bb = delta_frame_to_frame 82 | gamma_frame_to_bb = gamma_frame_to_frame 83 | beta_frame_to_bb = gamma_frame_to_bb.compose(beta_frame_to_frame) 84 | alpha1_frame_to_bb = beta_frame_to_bb.compose(alpha1_frame_to_frame) 85 | alpha2_frame_to_bb = beta_frame_to_bb.compose(alpha2_frame_to_frame) 86 | tm_frame_to_bb = tm_frame_to_frame 87 | chi_frame_to_bb = chi_frame_to_frame 88 | 89 | all_frames_to_bb = Rigid.cat( 90 | [ 91 | all_frames[..., :2], 92 | delta_frame_to_bb.unsqueeze(-1), 93 | gamma_frame_to_bb.unsqueeze(-1), 94 | beta_frame_to_bb.unsqueeze(-1), 95 | alpha1_frame_to_bb.unsqueeze(-1), 96 | alpha2_frame_to_bb.unsqueeze(-1), 97 | tm_frame_to_bb.unsqueeze(-1), 98 | chi_frame_to_bb.unsqueeze(-1), 99 | ], 100 | dim=-1, 101 | ) 102 | 103 | all_frames_to_global1 = r[..., 0, None].compose(all_frames_to_bb[..., 0:1]) 104 | all_frames_to_global2 = r[..., 1, None].compose(all_frames_to_bb[..., 1:2]) 105 | all_frames_to_global3 = r[..., 0, None].compose(all_frames_to_bb[..., 2:7]) 106 | all_frames_to_global4 = r[..., 1, None].compose(all_frames_to_bb[..., 7:9]) 107 | all_frames_to_global = Rigid.cat( 108 | [ 109 | all_frames_to_global1, 110 | all_frames_to_global2, 111 | all_frames_to_global3, 112 | all_frames_to_global4, 113 | ] 114 | , dim=-1) 115 | 116 | return all_frames_to_global 117 | 118 | 119 | def frames_and_literature_positions_to_atom23_pos( 120 | r: Rigid, 121 | butype: torch.Tensor, 122 | default_frames, 123 | group_idx, 124 | atom_mask, 125 | lit_positions, 126 | ): 127 | # [*, N, 23] 128 | group_mask = group_idx[butype, ...] 129 | 130 | # # [*, N, 23, 9] 131 | group_mask = nn.functional.one_hot( 132 | group_mask, 133 | num_classes=default_frames.shape[-3], 134 | ) 135 | 136 | # [*, N, 23, 9] 137 | t_atoms_to_global = r[..., None, :] * group_mask 138 | 139 | # # [*, N, 23] 140 | t_atoms_to_global = t_atoms_to_global.map_tensor_fn( 141 | lambda x: torch.sum(x, dim=-1) 142 | ) 143 | 144 | # [*, N, 23, 1] 145 | atom_mask = atom_mask[butype, ...].unsqueeze(-1) 146 | 147 | # [*, N, 23, 3] 148 | lit_positions = lit_positions[butype, ...] 
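# Map each atom's idealized (literature) coordinates through the global frame
# of its rigid group, then zero out atoms that do not exist for this base type.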
149 | pred_positions = t_atoms_to_global.apply(lit_positions) 150 | pred_positions = pred_positions * atom_mask 151 | 152 | return pred_positions 153 | 154 | ########################################################### 155 | 156 | 157 | # Map the dense 23-atom representation to the sparse 27-atom representation 158 | def atom23_to_atom27_train(atom23, batch): 159 | atom27_data = batched_gather( 160 | atom23, 161 | batch["residx_atom27_to_atom23"], 162 | dim=-2, 163 | no_batch_dims=len(atom23.shape[:-2]), 164 | ) 165 | 166 | atom27_data = atom27_data * batch["atom27_atom_exists"][..., None] 167 | 168 | return atom27_data 169 | 170 | 171 | def atom23_to_atom27_infer(sm, batch): 172 | 173 | atom23 = sm['positions'][-1] 174 | 175 | atom27_data = batched_gather( 176 | atom23, 177 | batch["residx_atom27_to_atom23"][..., 0], 178 | dim=-2, 179 | no_batch_dims=len(atom23.shape[:-2]), 180 | ) 181 | 182 | atom27_data = atom27_data * batch["atom27_atom_exists"][..., 0:1] 183 | 184 | return atom27_data -------------------------------------------------------------------------------- /scripts/rna_extract_pkl_from_fas.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import random 5 | import sys 6 | import time 7 | import json 8 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 9 | import build_rna_features  # imported after the sys.path update 10 | 11 | 12 | def generate_pkl_from_fas(args): 13 | for fas in os.listdir(args.fasta_path): 14 | local_fasta_path = os.path.join(args.fasta_path, fas) 15 | feature_dir = os.path.join(args.output_dir, fas) 16 | 17 | timings = {} 18 | if args.use_precomputed_alignments is None: 19 | local_alignment_dir = os.path.join(feature_dir, "msa_hmm") 20 | if not os.path.exists(local_alignment_dir): 21 | os.makedirs(local_alignment_dir) 22 | # logging.info(f"Generating MSA for {fas} ...") 23 | print(f"Generating MSA for {fas} ...") 24 | assert "rMSA.pl" in os.listdir(args.rmsa_path), "rMSA.pl not found. Please check that the rMSA package is installed and provide the correct path."
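            # The rMSA invocation assembled below has the rough shape
            # (paths illustrative):
            #
            #   perl <rmsa_path>/rMSA.pl <fasta_dir>/<id>.fasta -cpu=<n>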
25 | cmd_perl = "perl " + os.path.join(args.rmsa_path, 'rMSA.pl ') + os.path.join(local_fasta_path, fas + '.fasta') + " -cpu={}".format(args.cpus) 26 | pt = time.time() 27 | os.system(cmd_perl) 28 | timings['msa'] = time.time() - pt 29 | if sum(1 for _ in open(os.path.join(local_fasta_path, fas + '.afa'))) // 2 > args.max_msa: 30 | cmd_head = "head -n {} {} > {}".format(args.max_msa * 2, os.path.join(local_fasta_path, fas + '.afa'), os.path.join(local_alignment_dir, fas + '.afa')) 31 | os.system(cmd_head) 32 | else: 33 | cmd_mv = "mv " + os.path.join(local_fasta_path, fas + '.afa ') + os.path.join(local_alignment_dir, fas + '.afa') 34 | os.system(cmd_mv) 35 | cmd_mv = "mv " + os.path.join(local_fasta_path, fas + '.cm ') + os.path.join(local_alignment_dir, fas + '.cm') 36 | os.system(cmd_mv) 37 | # logging.info(f"MSA generation for {fas} is complete.") 38 | print(f"MSA generation for {fas} is complete.") 39 | else: 40 | local_alignment_dir = args.use_precomputed_alignments 41 | 42 | if args.use_precomputed_ss is None: 43 | local_ss_dir = os.path.join(feature_dir, "ss") 44 | if not os.path.exists(local_ss_dir): 45 | os.makedirs(local_ss_dir) 46 | print(f"Computing secondary structure for {fas} ...") 47 | cmd_pet = os.path.join(args.rmsa_path, 'bin/PETfold ') + " -f " + os.path.join(local_alignment_dir, fas + '.afa ') + " -r " + os.path.join(local_ss_dir, '{}_ss.txt'.format(fas)) 48 | os.environ['PETFOLDBIN'] = os.path.join(args.rmsa_path, "data") 49 | pt = time.time() 50 | os.system(cmd_pet) 51 | timings['ss'] = time.time() - pt 52 | # logging.info(f"Secondary structure computation for {fas} is complete.") 53 | print(f"Secondary structure computation for {fas} is complete.") 54 | else: 55 | local_ss_dir = args.use_precomputed_ss 56 | 57 | seq_file = os.path.join(local_fasta_path, '{}.fasta'.format(fas)) 58 | msa_file = os.path.join(local_alignment_dir, '{}.afa'.format(fas)) 59 | hmm_file = os.path.join(local_alignment_dir, '{}.cm'.format(fas)) 60 | ss_file = os.path.join(local_ss_dir, '{}_ss.txt'.format(fas)) 61 | features_dict = build_rna_features.processing_fas_features().collect_features(seq_file=seq_file, msa_file=msa_file, hmm_file=hmm_file, ss_file=ss_file) 62 | 63 | if args.add_mmcif_features is not None: 64 | pdbID, chainID = fas.split('_') if '_' in fas else [fas, None] 65 | cif_path = os.path.join(args.add_mmcif_features, pdbID + '.cif') 66 | assert os.path.exists(cif_path), "Cannot find {}.cif. Please provide the correct file location.".format(pdbID) 67 | features_cif = build_rna_features.processing_cif_features().collect_features(cif_path=cif_path, pdbID=pdbID, chainID=chainID, butype=features_dict['butype']) 68 | features_dict.update(features_cif) 69 | 70 | features_output_path = os.path.join(feature_dir, 'features.pkl') 71 | with open(features_output_path, 'wb') as f: 72 | pickle.dump(features_dict, f, protocol=4) 73 | timings_output_path = os.path.join(feature_dir, 'timings.json') 74 | with open(timings_output_path, 'w') as fp: 75 | json.dump(timings, fp, indent=4) 76 | # logging.info(f"Processing of {fas} is complete.") 77 | print(f"Processing of {fas} is complete.") 78 | 79 | 80 | if __name__ == '__main__': 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument( 83 | "--fasta_path", type=str, 84 | help="Path to directory containing FASTA files, one sequence per file" 85 | ) 86 | parser.add_argument( 87 | "--output_dir", type=str, 88 | help="Directory in which to write the output features.") 89 | parser.add_argument( 90 | "--use_precomputed_alignments", type=str, default=None, 91 | help="""Path to alignment directory. 
If provided, alignment computation 92 | is skipped.""" 93 | ) 94 | parser.add_argument( 95 | "--use_precomputed_ss", type=str, default=None, 96 | help="""Path to secondary structure directory. If provided, secondary structure computation 97 | is skipped.""" 98 | ) 99 | parser.add_argument( 100 | "--add_mmcif_features", type=str, default=None, 101 | help="""Path to a directory of mmCIF files. If provided, features derived from the structure will also be added to the output.""" 102 | ) 103 | parser.add_argument( 104 | "--max_msa", type=int, default=1000, 105 | help="""Max number of MSA sequences to use.""" 106 | ) 107 | parser.add_argument( 108 | "--cpus", type=int, default=12, 109 | help="""Number of CPUs with which to run alignment tools.""" 110 | ) 111 | parser.add_argument( 112 | "--rmsa_path", type=str, default="opencomplex/resources/RNA", 113 | help="""Path to the rMSA package. To install it to opencomplex/resources/RNA, run scripts/install_rmsa_petfold.sh""" 114 | ) 115 | args = parser.parse_args() 116 | 117 | generate_pkl_from_fas(args) -------------------------------------------------------------------------------- /opencomplex/config/references.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 BAAI 2 | # Copyright 2021 AlQuraishi Laboratory 3 | # Copyright 2021 DeepMind Technologies Limited 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
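# The FieldReference objects defined below let several config entries share
# one mutable value. A minimal sketch of the idiom (the names here are
# illustrative, not part of this module):
#
#   import ml_collections as mlc
#   ref = mlc.FieldReference(128, field_type=int)
#   cfg = mlc.ConfigDict({"c_z": ref, "pair_dim": ref})
#   cfg.c_z = 256  # cfg.pair_dim now also resolves to 256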
16 | 17 | import ml_collections as mlc 18 | 19 | c_z = mlc.FieldReference(128, field_type=int) 20 | c_m = mlc.FieldReference(256, field_type=int) 21 | c_t = mlc.FieldReference(64, field_type=int) 22 | c_e = mlc.FieldReference(64, field_type=int) 23 | c_s = mlc.FieldReference(384, field_type=int) 24 | blocks_per_ckpt = mlc.FieldReference(None, field_type=int) 25 | chunk_size = mlc.FieldReference(4, field_type=int) 26 | aux_distogram_bins = mlc.FieldReference(64, field_type=int) 27 | tm_enabled = mlc.FieldReference(False, field_type=bool) 28 | eps = mlc.FieldReference(1e-8, field_type=float) 29 | templates_enabled = mlc.FieldReference(True, field_type=bool) 30 | embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool) 31 | tune_chunk_size = mlc.FieldReference(True, field_type=bool) 32 | 33 | NUM_RES = "num residues placeholder" 34 | NUM_FULL_RES = "num residues before crop placeholder" 35 | NUM_MSA_SEQ = "msa placeholder" 36 | NUM_EXTRA_SEQ = "extra msa placeholder" 37 | NUM_TEMPLATES = "num templates placeholder" 38 | 39 | common_feats = mlc.ConfigDict({ 40 | "butype": [NUM_RES], 41 | "all_atom_mask": [NUM_RES, None], 42 | "all_atom_positions": [NUM_RES, None, None], 43 | "alt_chi_angles": [NUM_RES, None], 44 | "dense_atom_alt_gt_exists": [NUM_RES, None], 45 | "dense_atom_alt_gt_positions": [NUM_RES, None, None], 46 | "dense_atom_exists": [NUM_RES, None], 47 | "dense_atom_is_ambiguous": [NUM_RES, None], 48 | "dense_atom_gt_exists": [NUM_RES, None], 49 | "dense_atom_gt_positions": [NUM_RES, None, None], 50 | "all_atom_exists": [NUM_RES, None], 51 | "backbone_rigid_mask": [NUM_RES, None], 52 | "backbone_rigid_tensor": [NUM_RES, None, None, None], 53 | "bert_mask": [NUM_MSA_SEQ, NUM_RES], 54 | "chi_angles_sin_cos": [NUM_RES, None, None], 55 | "chi_mask": [NUM_RES, None], 56 | "extra_deletion_value": [NUM_EXTRA_SEQ, NUM_RES], 57 | "extra_has_deletion": [NUM_EXTRA_SEQ, NUM_RES], 58 | "extra_msa": [NUM_EXTRA_SEQ, NUM_RES], 59 | "extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES], 60 | "extra_msa_row_mask": [NUM_EXTRA_SEQ], 61 | "is_distillation": [], 62 | "msa_feat": [NUM_MSA_SEQ, NUM_RES, None], 63 | "msa_mask": [NUM_MSA_SEQ, NUM_RES], 64 | "msa_row_mask": [NUM_MSA_SEQ], 65 | "no_recycling_iters": [], 66 | "pseudo_beta": [NUM_RES, None], 67 | "pseudo_beta_mask": [NUM_RES], 68 | "residue_index": [NUM_RES], 69 | "residx_dense_to_all": [NUM_RES, None], 70 | "residx_all_to_dense": [NUM_RES, None], 71 | "resolution": [], 72 | "rigidgroups_alt_gt_frames": [NUM_RES, None, None, None], 73 | "rigidgroups_group_exists": [NUM_RES, None], 74 | "rigidgroups_group_is_ambiguous": [NUM_RES, None], 75 | "rigidgroups_gt_exists": [NUM_RES, None], 76 | "rigidgroups_gt_frames": [NUM_RES, None, None, None], 77 | "seq_length": [], 78 | "seq_mask": [NUM_RES], 79 | "target_feat": [NUM_RES, None], 80 | "template_butype": [NUM_TEMPLATES, NUM_RES], 81 | "template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None], 82 | "template_all_atom_positions": [ 83 | NUM_TEMPLATES, NUM_RES, None, None, 84 | ], 85 | "template_alt_torsion_angles_sin_cos": [ 86 | NUM_TEMPLATES, NUM_RES, None, None, 87 | ], 88 | "template_backbone_rigid_mask": [NUM_TEMPLATES, NUM_RES], 89 | "template_backbone_rigid_tensor": [ 90 | NUM_TEMPLATES, NUM_RES, None, None, 91 | ], 92 | "template_mask": [NUM_TEMPLATES], 93 | "template_pseudo_beta": [NUM_TEMPLATES, NUM_RES, None], 94 | "template_pseudo_beta_mask": [NUM_TEMPLATES, NUM_RES], 95 | "template_sum_probs": [NUM_TEMPLATES, None], 96 | "template_torsion_angles_mask": [ 97 | NUM_TEMPLATES, NUM_RES, 
None, 98 | ], 99 | "template_torsion_angles_sin_cos": [ 100 | NUM_TEMPLATES, NUM_RES, None, None, 101 | ], 102 | "true_msa": [NUM_MSA_SEQ, NUM_RES], 103 | "use_clamped_fape": [], 104 | }) 105 | 106 | multimer_feats = mlc.ConfigDict({ 107 | "sym_id": [NUM_RES], 108 | "asym_id": [NUM_RES], 109 | "entity_id": [NUM_RES], 110 | 111 | "origin_all_atom_mask": [NUM_FULL_RES, None], 112 | "origin_all_atom_positions": [NUM_FULL_RES, None, None], 113 | "origin_sym_id": [NUM_FULL_RES], 114 | "origin_asym_id": [NUM_FULL_RES], 115 | "origin_entity_id": [NUM_FULL_RES], 116 | 117 | "pseudo_beta_mask": [NUM_FULL_RES], 118 | "rigidgroups_gt_exists": [NUM_FULL_RES, None], 119 | "dense_atom_gt_exists": [NUM_FULL_RES, None], 120 | "rigidgroups_gt_frames": [NUM_FULL_RES, None, None, None], 121 | "rigidgroups_group_exists": [NUM_FULL_RES, None], 122 | "dense_atom_alt_gt_exists": [NUM_FULL_RES, None], 123 | "backbone_rigid_tensor": [NUM_FULL_RES, None, None, None], 124 | "pseudo_beta": [NUM_FULL_RES, None], 125 | "dense_atom_alt_gt_positions": [NUM_FULL_RES, None, None], 126 | "backbone_rigid_mask": [NUM_FULL_RES, None], 127 | "chi_mask": [NUM_FULL_RES, None], 128 | "rigidgroups_alt_gt_frames": [NUM_FULL_RES, None, None, None], 129 | "rigidgroups_group_is_ambiguous": [NUM_FULL_RES, None], 130 | "chi_angles_sin_cos": [NUM_FULL_RES, None, None], 131 | "dense_atom_gt_positions": [NUM_FULL_RES, None, None], 132 | "dense_atom_is_ambiguous": [NUM_FULL_RES, None], 133 | 134 | "resolution": [None], 135 | }) 136 | 137 | __all__ = [ 138 | "c_z", "c_m", "c_t", "c_e", "c_s", "blocks_per_ckpt", "chunk_size", "aux_distogram_bins", "tm_enabled", 139 | "eps", "templates_enabled", "embed_template_torsion_angles", "tune_chunk_size", 140 | "NUM_RES", "NUM_FULL_RES", "NUM_MSA_SEQ", "NUM_EXTRA_SEQ", "NUM_TEMPLATES", 141 | "common_feats", "multimer_feats" 142 | ] -------------------------------------------------------------------------------- /scripts/extract_pkl_from_fas.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import pickle 5 | import random 6 | import sys 7 | import time 8 | import json 9 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 10 | 11 | from opencomplex.data import templates, data_pipeline 12 | 13 | 14 | def generate_pkl_from_fas(args): 15 | template_featurizer = templates.TemplateHitFeaturizer( 16 | mmcif_dir=args.template_mmcif_dir, 17 | max_template_date=args.max_template_date, 18 | max_hits=args.max_templates, 19 | kalign_binary_path=args.kalign_binary_path, 20 | release_dates_path=args.release_dates_path, 21 | obsolete_pdbs_path=args.obsolete_pdbs_path 22 | ) 23 | 24 | data_processor = data_pipeline.DataPipeline( 25 | template_featurizer=template_featurizer, 26 | ) 27 | random_seed = args.data_random_seed 28 | if random_seed is None: 29 | random_seed = random.randrange(sys.maxsize) 30 | 31 | alignment_runner = data_pipeline.AlignmentRunner( 32 | jackhmmer_binary_path=args.jackhmmer_binary_path, 33 | hhblits_binary_path=args.hhblits_binary_path, 34 | hhsearch_binary_path=args.hhsearch_binary_path, 35 | uniref90_database_path=args.uniref90_database_path, 36 | mgnify_database_path=args.mgnify_database_path, 37 | bfd_database_path=args.bfd_database_path, 38 | uniclust30_database_path=args.uniclust30_database_path, 39 | pdb70_database_path=args.pdb70_database_path, 40 | use_small_bfd=args.use_small_bfd, 41 | no_cpus=args.cpus, 42 | ) 43 | for fas in os.listdir(args.fasta_path): 44 | 
local_fasta_path = os.path.join(args.fasta_path, fas) 45 | fas_name = fas.split('.')[0] 46 | feature_dir = os.path.join(args.output_dir, fas_name) 47 | 48 | if args.use_precomputed_alignments is None: 49 | local_alignment_dir = os.path.join(feature_dir, "msas") 50 | if not os.path.exists(local_alignment_dir): 51 | os.makedirs(local_alignment_dir) 52 | else: 53 | local_alignment_dir = args.use_precomputed_alignments 54 | 55 | 56 | # logging.info(f"Generating features for {fas_name} ...") 57 | print(f"Generating features for {fas_name} ...") 58 | # Time the alignment and feature pipeline for this target. 59 | timings = {} 60 | pt = time.time() 61 | alignment_runner.run( 62 | local_fasta_path, local_alignment_dir 63 | ) 64 | feature_dict = data_processor.process_fasta( 65 | fasta_path=local_fasta_path, alignment_dir=local_alignment_dir 66 | ) 67 | timings['data_pipeline'] = time.time() - pt 68 | features_output_path = os.path.join(feature_dir, 'features.pkl') 69 | with open(features_output_path, 'wb') as f: 70 | pickle.dump(feature_dict, f, protocol=4) 71 | timings_output_path = os.path.join(feature_dir, 'timings.json') 72 | with open(timings_output_path, 'w') as fp: 73 | json.dump(timings, fp, indent=4) 74 | # logging.info(f"Processing of {fas_name} is complete.") 75 | print(f"Processing of {fas_name} is complete.") 76 | 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument( 81 | "--fasta_path", type=str, 82 | help="Path to directory containing FASTA files, one sequence per file" 83 | ) 84 | parser.add_argument( 85 | "--output_dir", type=str, 86 | help="Directory in which to write the output features.") 87 | parser.add_argument( 88 | "--use_precomputed_alignments", type=str, default=None, 89 | help="""Path to alignment directory. If provided, alignment computation 90 | is skipped and database path arguments are ignored.""" 91 | ) 92 | parser.add_argument( 93 | "--max_templates", type=int, default=20, 94 | help="""Max number of templates to use.""" 95 | ) 96 | parser.add_argument( 97 | "--cpus", type=int, default=12, 98 | help="""Number of CPUs with which to run alignment tools.""" 99 | ) 100 | parser.add_argument( 101 | "--data_random_seed", type=str, default=None 102 | ) 103 | parser.add_argument( 104 | "--template_mmcif_dir", type=str, default="/mnt/database/pdb_mmcif/mmcif_files", 105 | help="""Path to the template mmCIF directory.""" 106 | ) 107 | parser.add_argument( 108 | "--uniref90_database_path", type=str, default="/mnt/database/uniref90_latest/uniref90.fasta", 109 | help="""Path to the UniRef90 database file.""" 110 | ) 111 | parser.add_argument( 112 | "--mgnify_database_path", type=str, default="/mnt/database/mgnify/mgy_clusters.fa", 113 | help="""Path to the MGnify database file.""" 114 | ) 115 | parser.add_argument( 116 | "--pdb70_database_path", type=str, default="/mnt/database/pdb70_latest/pdb70", 117 | help="""Path to the PDB70 database prefix.""" 118 | ) 119 | parser.add_argument( 120 | "--uniclust30_database_path", type=str, default="/mnt/database/uniref30_latest/UniRef30_2021_03", 121 | help="""Path to the UniClust30 database prefix.""" 122 | ) 123 | parser.add_argument( 124 | "--bfd_database_path", type=str, default="/mnt/database/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt", 125 | help="""Path to the BFD database prefix.""" 126 | ) 127 | parser.add_argument( 128 | "--use_small_bfd", default=False, action="store_true", 129 | help="""Whether to use the small BFD database.""" 130 | ) 131 | parser.add_argument( 132 | "--obsolete_pdbs_path", type=str, default="/mnt/database/pdb_mmcif/obsolete.dat", 133 | help="""Path to the obsolete PDBs file.""" 134 | ) 135 | parser.add_argument( 136 | 
"--jackhmmer_binary_path", type=str, default="/opt/conda/envs/opencomplex_venv/bin/jackhmmer", 137 | help="""Binary path of jackhmmer.""" 138 | ) 139 | parser.add_argument( 140 | "--hhblits_binary_path", type=str, default="/opt/cnda/envs/opencomplex_venv/bin/hhblits", 141 | help="""Binary path of hhblits.""" 142 | ) 143 | parser.add_argument( 144 | "--hhsearch_binary_path", type=str, default="/opt/conda/envs/opencomplex_venv/bin/hhsearch", 145 | help="""Binary path of hhsearch.""" 146 | ) 147 | parser.add_argument( 148 | "--kalign_binary_path", type=str, default="/opt/conda/envs/opencomplex_venv/bin/kalign", 149 | help="""Binary path of kalign.""" 150 | ) 151 | parser.add_argument( 152 | "--max_template_date", type=str, default="2022-04-24", 153 | help="""max_template_date.""" 154 | ) 155 | parser.add_argument( 156 | "--release_dates_path", type=str, default=None, 157 | help="""release_dates_path.""" 158 | ) 159 | args = parser.parse_args() 160 | 161 | generate_pkl_from_fas(args) -------------------------------------------------------------------------------- /opencomplex/data/tools/hhblits.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Library to run HHblits from Python.""" 17 | import glob 18 | import logging 19 | import os 20 | import subprocess 21 | from typing import Any, Mapping, Optional, Sequence 22 | 23 | from opencomplex.data.tools import utils 24 | 25 | 26 | _HHBLITS_DEFAULT_P = 20 27 | _HHBLITS_DEFAULT_Z = 500 28 | 29 | 30 | class HHBlits: 31 | """Python wrapper of the HHblits binary.""" 32 | 33 | def __init__( 34 | self, 35 | *, 36 | binary_path: str, 37 | databases: Sequence[str], 38 | n_cpu: int = 4, 39 | n_iter: int = 3, 40 | e_value: float = 0.001, 41 | maxseq: int = 1_000_000, 42 | realign_max: int = 100_000, 43 | maxfilt: int = 100_000, 44 | min_prefilter_hits: int = 1000, 45 | all_seqs: bool = False, 46 | alt: Optional[int] = None, 47 | p: int = _HHBLITS_DEFAULT_P, 48 | z: int = _HHBLITS_DEFAULT_Z, 49 | ): 50 | """Initializes the Python HHblits wrapper. 51 | 52 | Args: 53 | binary_path: The path to the HHblits executable. 54 | databases: A sequence of HHblits database paths. This should be the 55 | common prefix for the database files (i.e. up to but not including 56 | _hhm.ffindex etc.) 57 | n_cpu: The number of CPUs to give HHblits. 58 | n_iter: The number of HHblits iterations. 59 | e_value: The E-value, see HHblits docs for more details. 60 | maxseq: The maximum number of rows in an input alignment. Note that this 61 | parameter is only supported in HHBlits version 3.1 and higher. 62 | realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. 63 | maxfilt: Max number of hits allowed to pass the 2nd prefilter. 64 | HHblits default: 20000. 65 | min_prefilter_hits: Min number of hits to pass prefilter. 
66 | HHblits default: 100. 67 | all_seqs: Return all sequences in the MSA / Do not filter the result MSA. 68 | HHblits default: False. 69 | alt: Show up to this many alternative alignments. 70 | p: Minimum Prob for a hit to be included in the output hhr file. 71 | HHblits default: 20. 72 | z: Hard cap on number of hits reported in the hhr file. 73 | HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. 74 | 75 | Raises: 76 | ValueError: If any of the HHblits databases cannot be found. 77 | """ 78 | self.binary_path = binary_path 79 | self.databases = databases 80 | 81 | for database_path in self.databases: 82 | if not glob.glob(database_path + "_*"): 83 | logging.error( 84 | "Could not find HHBlits database %s", database_path 85 | ) 86 | raise ValueError( 87 | f"Could not find HHBlits database {database_path}" 88 | ) 89 | 90 | self.n_cpu = n_cpu 91 | self.n_iter = n_iter 92 | self.e_value = e_value 93 | self.maxseq = maxseq 94 | self.realign_max = realign_max 95 | self.maxfilt = maxfilt 96 | self.min_prefilter_hits = min_prefilter_hits 97 | self.all_seqs = all_seqs 98 | self.alt = alt 99 | self.p = p 100 | self.z = z 101 | 102 | def query(self, input_fasta_path: str) -> Mapping[str, Any]: 103 | """Queries the database using HHblits.""" 104 | with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir: 105 | a3m_path = os.path.join(query_tmp_dir, "output.a3m") 106 | 107 | db_cmd = [] 108 | for db_path in self.databases: 109 | db_cmd.append("-d") 110 | db_cmd.append(db_path) 111 | cmd = [ 112 | self.binary_path, 113 | "-i", 114 | input_fasta_path, 115 | "-cpu", 116 | str(self.n_cpu), 117 | "-oa3m", 118 | a3m_path, 119 | "-o", 120 | "/dev/null", 121 | "-n", 122 | str(self.n_iter), 123 | "-e", 124 | str(self.e_value), 125 | "-maxseq", 126 | str(self.maxseq), 127 | "-realign_max", 128 | str(self.realign_max), 129 | "-maxfilt", 130 | str(self.maxfilt), 131 | "-min_prefilter_hits", 132 | str(self.min_prefilter_hits), 133 | ] 134 | if self.all_seqs: 135 | cmd += ["-all"] 136 | if self.alt: 137 | cmd += ["-alt", str(self.alt)] 138 | if self.p != _HHBLITS_DEFAULT_P: 139 | cmd += ["-p", str(self.p)] 140 | if self.z != _HHBLITS_DEFAULT_Z: 141 | cmd += ["-Z", str(self.z)] 142 | cmd += db_cmd 143 | 144 | logging.info('Launching subprocess "%s"', " ".join(cmd)) 145 | process = subprocess.Popen( 146 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE 147 | ) 148 | 149 | with utils.timing("HHblits query"): 150 | stdout, stderr = process.communicate() 151 | retcode = process.wait() 152 | 153 | if retcode: 154 | # Logs have a 15k character limit, so log HHblits error line by line. 155 | logging.error("HHblits failed. 
HHblits stderr begin:") 156 | for error_line in stderr.decode("utf-8").splitlines(): 157 | if error_line.strip(): 158 | logging.error(error_line.strip()) 159 | logging.error("HHblits stderr end") 160 | raise RuntimeError( 161 | "HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n" 162 | % (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8")) 163 | ) 164 | 165 | with open(a3m_path) as f: 166 | a3m = f.read() 167 | 168 | raw_output = dict( 169 | a3m=a3m, 170 | output=stdout, 171 | stderr=stderr, 172 | n_iter=self.n_iter, 173 | e_value=self.e_value, 174 | ) 175 | return raw_output 176 | -------------------------------------------------------------------------------- /opencomplex/model/torchscript.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional, Sequence, Tuple 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from opencomplex.model.dropout import ( 21 | DropoutRowwise, 22 | DropoutColumnwise, 23 | ) 24 | from opencomplex.model.evoformer import ( 25 | EvoformerBlock, 26 | EvoformerStack, 27 | ) 28 | from opencomplex.model.outer_product_mean import OuterProductMean 29 | from opencomplex.model.msa import ( 30 | MSARowAttentionWithPairBias, 31 | MSAColumnAttention, 32 | MSAColumnGlobalAttention, 33 | ) 34 | from opencomplex.model.pair_transition import PairTransition 35 | from opencomplex.model.primitives import Attention, GlobalAttention 36 | from opencomplex.model.sm.structure_module import ( 37 | InvariantPointAttention, 38 | BackboneUpdate, 39 | ) 40 | from opencomplex.model.template import TemplatePairStackBlock 41 | from opencomplex.model.triangular_attention import ( 42 | TriangleAttentionStartingNode, 43 | TriangleAttentionEndingNode, 44 | ) 45 | from opencomplex.model.triangular_multiplicative_update import ( 46 | TriangleMultiplicationOutgoing, 47 | TriangleMultiplicationIncoming, 48 | ) 49 | 50 | 51 | def script_preset_(model: torch.nn.Module): 52 | """ 53 | TorchScript a handful of low-level but frequently used submodule types 54 | that are known to be scriptable. 55 | 56 | Args: 57 | model: 58 | A torch.nn.Module. It should contain at least some modules from 59 | this repository, or this function won't do anything. 
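    Example (illustrative):
        script_preset_(model)  # matching submodules are scripted in place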
60 | """ 61 | script_submodules_( 62 | model, 63 | [ 64 | nn.Dropout, 65 | Attention, 66 | GlobalAttention, 67 | EvoformerBlock, 68 | #TemplatePairStackBlock, 69 | ], 70 | attempt_trace=False, 71 | batch_dims=None, 72 | ) 73 | 74 | 75 | def _get_module_device(module: torch.nn.Module) -> torch.device: 76 | """ 77 | Fetches the device of a module, assuming that all of the module's 78 | parameters reside on a single device 79 | 80 | Args: 81 | module: A torch.nn.Module 82 | Returns: 83 | The module's device 84 | """ 85 | return next(module.parameters()).device 86 | 87 | 88 | def _trace_module(module, batch_dims=None): 89 | if(batch_dims is None): 90 | batch_dims = () 91 | 92 | # Stand-in values 93 | n_seq = 10 94 | n_res = 10 95 | 96 | device = _get_module_device(module) 97 | 98 | def msa(channel_dim): 99 | return torch.rand( 100 | (*batch_dims, n_seq, n_res, channel_dim), 101 | device=device, 102 | ) 103 | 104 | def pair(channel_dim): 105 | return torch.rand( 106 | (*batch_dims, n_res, n_res, channel_dim), 107 | device=device, 108 | ) 109 | 110 | if(isinstance(module, MSARowAttentionWithPairBias)): 111 | inputs = { 112 | "forward": ( 113 | msa(module.c_in), # m 114 | pair(module.c_z), # z 115 | torch.randint( 116 | 0, 2, 117 | (*batch_dims, n_seq, n_res) 118 | ), # mask 119 | ), 120 | } 121 | elif(isinstance(module, MSAColumnAttention)): 122 | inputs = { 123 | "forward": ( 124 | msa(module.c_in), # m 125 | torch.randint( 126 | 0, 2, 127 | (*batch_dims, n_seq, n_res) 128 | ), # mask 129 | ), 130 | } 131 | elif(isinstance(module, OuterProductMean)): 132 | inputs = { 133 | "forward": ( 134 | msa(module.c_m), 135 | torch.randint( 136 | 0, 2, 137 | (*batch_dims, n_seq, n_res) 138 | ) 139 | ) 140 | } 141 | else: 142 | raise TypeError( 143 | f"tracing is not supported for modules of type {type(module)}" 144 | ) 145 | 146 | return torch.jit.trace_module(module, inputs) 147 | 148 | 149 | def _script_submodules_helper_( 150 | model, 151 | types, 152 | attempt_trace, 153 | to_trace, 154 | ): 155 | for name, child in model.named_children(): 156 | if(types is None or any(isinstance(child, t) for t in types)): 157 | try: 158 | scripted = torch.jit.script(child) 159 | setattr(model, name, scripted) 160 | continue 161 | except (RuntimeError, torch.jit.frontend.NotSupportedError) as e: 162 | if(attempt_trace): 163 | to_trace.add(type(child)) 164 | else: 165 | raise e 166 | 167 | _script_submodules_helper_(child, types, attempt_trace, to_trace) 168 | 169 | 170 | def _trace_submodules_( 171 | model, 172 | types, 173 | batch_dims=None, 174 | ): 175 | for name, child in model.named_children(): 176 | if(any(isinstance(child, t) for t in types)): 177 | traced = _trace_module(child, batch_dims=batch_dims) 178 | setattr(model, name, traced) 179 | else: 180 | _trace_submodules_(child, types, batch_dims=batch_dims) 181 | 182 | 183 | def script_submodules_( 184 | model: nn.Module, 185 | types: Optional[Sequence[type]] = None, 186 | attempt_trace: Optional[bool] = True, 187 | batch_dims: Optional[Tuple[int]] = None, 188 | ): 189 | """ 190 | Convert all submodules whose types match one of those in the input 191 | list to recursively scripted equivalents in place. To script the entire 192 | model, just call torch.jit.script on it directly. 193 | 194 | When types is None, all submodules are scripted. 195 | 196 | Args: 197 | model: 198 | A torch.nn.Module 199 | types: 200 | A list of types of submodules to script 201 | attempt_trace: 202 | Whether to attempt to trace specified modules if scripting 203 | fails. 
Recall that tracing eliminates all conditional 204 | logic---with great tracing comes the mild responsibility of 205 | having to remember to ensure that the modules in question 206 | perform the same computations no matter what. 207 | """ 208 | to_trace = set() 209 | 210 | # Aggressively script as much as possible first... 211 | _script_submodules_helper_(model, types, attempt_trace, to_trace) 212 | 213 | # ... and then trace stragglers. 214 | if(attempt_trace and len(to_trace) > 0): 215 | _trace_submodules_(model, to_trace, batch_dims=batch_dims) 216 | -------------------------------------------------------------------------------- /opencomplex/resources/stereo_chemical_props_RNA.txt: -------------------------------------------------------------------------------- 1 | Bond Residue Mean StdDev 2 | C3'-C4' A 1.523 0.000 3 | C4'-C5' A 1.512 0.000 4 | C5'-O5' A 1.428 0.000 5 | O5'-P A 1.593 0.000 6 | O3'-C3' A 1.422 0.000 7 | C4'-O4' A 1.450 0.000 8 | O4'-C1' A 1.415 0.000 9 | C1'-C2' A 1.526 0.000 10 | C2'-C3' A 1.524 0.000 11 | C2'-O2' A 1.416 0.000 12 | C1'-N9 A 1.469 0.000 13 | N9-C8 A 1.372 0.000 14 | N9-C4 A 1.374 0.000 15 | C8-N7 A 1.311 0.000 16 | C4-C5 A 1.382 0.000 17 | C4-N3 A 1.345 0.000 18 | N7-C5 A 1.387 0.000 19 | N3-C2 A 1.332 0.000 20 | C5-C6 A 1.406 0.000 21 | C2-N1 A 1.339 0.000 22 | N1-C6 A 1.351 0.000 23 | C6-N6 A 1.335 0.000 24 | C3'-C4' U 1.522 0.000 25 | C4'-C5' U 1.513 0.000 26 | C5'-O5' U 1.428 0.000 27 | O5'-P U 1.593 0.000 28 | O3'-C3' U 1.422 0.000 29 | C4'-O4' U 1.451 0.000 30 | O4'-C1' U 1.416 0.000 31 | C1'-C2' U 1.527 0.000 32 | C2'-C3' U 1.525 0.000 33 | C2'-O2' U 1.416 0.000 34 | C1'-N1 U 1.481 0.000 35 | N1-C6 U 1.376 0.000 36 | N1-C2 U 1.383 0.000 37 | C6-C5 U 1.338 0.000 38 | C2-N3 U 1.373 0.000 39 | C2-O2 U 1.220 0.000 40 | C5-C4 U 1.431 0.000 41 | N3-C4 U 1.378 0.000 42 | C4-O4 U 1.231 0.000 43 | C3'-C4' G 1.522 0.000 44 | C4'-C5' G 1.513 0.000 45 | C5'-O5' G 1.428 0.000 46 | O5'-P G 1.593 0.000 47 | O3'-C3' G 1.423 0.000 48 | C4'-O4' G 1.450 0.000 49 | O4'-C1' G 1.416 0.000 50 | C1'-C2' G 1.526 0.000 51 | C2'-C3' G 1.524 0.000 52 | C2'-O2' G 1.416 0.000 53 | C1'-N9 G 1.469 0.000 54 | N9-C8 G 1.374 0.000 55 | N9-C4 G 1.375 0.000 56 | C8-N7 G 1.306 0.000 57 | C4-C5 G 1.378 0.000 58 | C4-N3 G 1.350 0.000 59 | N7-C5 G 1.388 0.000 60 | N3-C2 G 1.325 0.000 61 | C5-C6 G 1.418 0.000 62 | C2-N1 G 1.372 0.000 63 | C2-N2 G 1.337 0.000 64 | N1-C6 G 1.390 0.000 65 | C6-O6 G 1.236 0.000 66 | C3'-C4' C 1.523 0.000 67 | C4'-C5' C 1.516 0.000 68 | C5'-O5' C 1.431 0.000 69 | O5'-P C 1.592 0.000 70 | O3'-C3' C 1.425 0.000 71 | C4'-O4' C 1.451 0.000 72 | O4'-C1' C 1.418 0.000 73 | C1'-C2' C 1.527 0.000 74 | C2'-C3' C 1.525 0.000 75 | C2'-O2' C 1.416 0.000 76 | C1'-N1 C 1.483 0.001 77 | N1-C6 C 1.369 0.000 78 | N1-C2 C 1.399 0.000 79 | C6-C5 C 1.340 0.000 80 | C2-N3 C 1.354 0.000 81 | C2-O2 C 1.240 0.000 82 | C5-C4 C 1.424 0.000 83 | N3-C4 C 1.336 0.000 84 | C4-N4 C 1.334 0.000 85 | - 86 | 87 | Angle Residue Mean StdDev 88 | O3'-C3'-C4' A 111.632 7.942 89 | O3'-C3'-C2' A 112.610 7.154 90 | C4'-C3'-C2' A 102.389 1.492 91 | C3'-C4'-C5' A 115.549 2.276 92 | C3'-C4'-O4' A 104.361 2.464 93 | C5'-C4'-O4' A 109.650 1.942 94 | C4'-C5'-O5' A 110.966 2.716 95 | C5'-O5'-P A 120.837 3.983 96 | C4'-O4'-C1' A 109.797 1.172 97 | O4'-C1'-C2' A 107.086 1.466 98 | O4'-C1'-N9 A 108.934 2.554 99 | C2'-C1'-N9 A 112.956 3.856 100 | C1'-C2'-C3' A 101.556 0.908 101 | C1'-C2'-O2' A 109.285 6.512 102 | C3'-C2'-O2' A 112.025 7.034 103 | C1'-N9-C8 A 127.676 2.447 104 | C1'-N9-C4 A 126.495 2.742 105 | 
C8-N9-C4 A 105.768 0.253 106 | N9-C8-N7 A 113.837 0.360 107 | C8-N7-C5 A 103.799 0.271 108 | N7-C5-C4 A 110.788 0.212 109 | N7-C5-C6 A 132.272 0.518 110 | C4-C5-C6 A 116.927 0.562 111 | C5-C4-N9 A 105.786 0.197 112 | C5-C4-N3 A 126.769 0.883 113 | N9-C4-N3 A 127.431 0.767 114 | C4-N3-C2 A 110.832 0.807 115 | N3-C2-N1 A 128.999 0.696 116 | C2-N1-C6 A 118.660 0.521 117 | N1-C6-C5 A 117.771 0.649 118 | N1-C6-N6 A 118.673 1.890 119 | C5-C6-N6 A 123.544 1.069 120 | O3'-C3'-C4' U 111.531 7.240 121 | O3'-C3'-C2' U 112.639 8.535 122 | C4'-C3'-C2' U 102.424 1.868 123 | C3'-C4'-C5' U 115.581 2.755 124 | C3'-C4'-O4' U 104.369 2.775 125 | C5'-C4'-O4' U 109.879 2.099 126 | C4'-C5'-O5' U 111.091 2.626 127 | C5'-O5'-P U 120.981 4.677 128 | C4'-O4'-C1' U 109.731 1.114 129 | O4'-C1'-C2' U 106.941 1.845 130 | O4'-C1'-N1 U 109.466 3.200 131 | C2'-C1'-N1 U 113.042 3.927 132 | C1'-C2'-C3' U 101.706 1.070 133 | C1'-C2'-O2' U 109.546 6.281 134 | C3'-C2'-O2' U 112.088 6.918 135 | C1'-N1-C6 U 121.242 2.116 136 | C1'-N1-C2 U 117.986 2.492 137 | C6-N1-C2 U 120.712 0.359 138 | N1-C6-C5 U 122.825 0.430 139 | C6-C5-C4 U 119.606 0.443 140 | C5-C4-N3 U 114.713 0.807 141 | C5-C4-O4 U 125.765 0.562 142 | O4-C4-N3 U 119.510 0.587 143 | N1-C2-N3 U 115.158 0.521 144 | N1-C2-O2 U 123.096 0.885 145 | N3-C2-O2 U 121.737 0.934 146 | C2-N3-C4 U 126.918 0.913 147 | O3'-C3'-C4' G 111.684 8.338 148 | O3'-C3'-C2' G 112.787 7.588 149 | C4'-C3'-C2' G 102.152 1.693 150 | C3'-C4'-C5' G 115.768 3.227 151 | C3'-C4'-O4' G 104.107 2.818 152 | C5'-C4'-O4' G 109.820 2.773 153 | C4'-C5'-O5' G 111.089 3.325 154 | C5'-O5'-P G 120.703 4.167 155 | C4'-O4'-C1' G 109.639 1.281 156 | O4'-C1'-C2' G 107.027 1.648 157 | O4'-C1'-N9 G 109.367 3.204 158 | C2'-C1'-N9 G 112.516 4.190 159 | C1'-C2'-C3' G 101.409 1.004 160 | C1'-C2'-O2' G 109.473 5.678 161 | C3'-C2'-O2' G 112.004 5.928 162 | C1'-N9-C8 G 127.214 2.239 163 | C1'-N9-C4 G 126.514 2.382 164 | C8-N9-C4 G 106.215 0.328 165 | N9-C8-N7 G 113.270 0.327 166 | C8-N7-C5 G 104.206 0.199 167 | N7-C5-C4 G 110.817 0.235 168 | N7-C5-C6 G 130.301 0.552 169 | C4-C5-C6 G 118.862 0.339 170 | C5-C4-N9 G 105.472 0.259 171 | C5-C4-N3 G 128.440 0.447 172 | N9-C4-N3 G 126.076 0.465 173 | C4-N3-C2 G 111.993 0.431 174 | N3-C2-N1 G 123.983 0.804 175 | N3-C2-N2 G 119.871 0.780 176 | N1-C2-N2 G 116.130 0.966 177 | C2-N1-C6 G 124.888 1.044 178 | N1-C6-C5 G 111.778 0.921 179 | N1-C6-O6 G 119.972 1.595 180 | C5-C6-O6 G 128.241 1.155 181 | O3'-C3'-C4' C 111.742 6.480 182 | O3'-C3'-C2' C 113.144 7.679 183 | C4'-C3'-C2' C 102.179 1.761 184 | C3'-C4'-C5' C 115.764 2.739 185 | C3'-C4'-O4' C 104.146 2.182 186 | C5'-C4'-O4' C 109.936 2.200 187 | C4'-C5'-O5' C 111.207 4.087 188 | C5'-O5'-P C 120.962 7.515 189 | C4'-O4'-C1' C 109.609 0.974 190 | O4'-C1'-C2' C 107.079 1.423 191 | O4'-C1'-N1 C 109.841 5.841 192 | C2'-C1'-N1 C 113.020 4.422 193 | C1'-C2'-C3' C 101.655 0.952 194 | C1'-C2'-O2' C 109.664 4.232 195 | C3'-C2'-O2' C 112.177 4.954 196 | C1'-N1-C6 C 120.635 3.297 197 | C1'-N1-C2 C 119.284 3.615 198 | C6-N1-C2 C 119.991 0.498 199 | N1-C6-C5 C 121.326 0.450 200 | C6-C5-C4 C 117.473 0.212 201 | C5-C4-N3 C 121.711 0.479 202 | C5-C4-N4 C 120.108 0.479 203 | N4-C4-N3 C 118.171 0.903 204 | N1-C2-N3 C 119.259 0.409 205 | N1-C2-O2 C 119.367 0.999 206 | N3-C2-O2 C 121.364 0.965 207 | C2-N3-C4 C 120.163 0.466 208 | - 209 | 210 | Non-bonded distance Minimum Dist Tolerance 211 | C-C 1.710 1.5 212 | C-N 2.096 1.5 213 | C-P 2.115 1.5 214 | C-O 1.737 1.5 215 | N-N 1.986 1.5 216 | N-P 2.954 1.5 217 | N-O 2.126 1.5 218 | O-P 2.333 1.5 219 | O-O 1.938 
1.5 220 | P-P 3.657 1.5 221 | - 222 | --------------------------------------------------------------------------------