├── sam ├── __init__.py ├── nn │ ├── __init__.py │ ├── autoencoder │ │ ├── __init__.py │ │ ├── encoder │ │ │ └── __init__.py │ │ └── decoder │ │ │ └── __init__.py │ ├── noise_prediction │ │ ├── __init__.py │ │ └── eps │ │ │ └── __init__.py │ ├── ema.py │ └── generator.py ├── data │ ├── __init__.py │ ├── trajectory_reader.py │ └── topology.py ├── analysis │ ├── __init__.py │ └── melting_curves.py ├── diffusion │ ├── __init__.py │ └── common.py ├── evaluation │ └── __init__.py ├── openfold │ ├── np │ │ ├── __init__.py │ │ └── relax │ │ │ ├── __init__.py │ │ │ ├── utils.py │ │ │ ├── relax.py │ │ │ └── cleanup.py │ ├── data │ │ ├── __init__.py │ │ ├── tools │ │ │ ├── __init__.py │ │ │ ├── utils.py │ │ │ ├── parse_msa_files.py │ │ │ ├── kalign.py │ │ │ ├── hhsearch.py │ │ │ ├── hmmbuild.py │ │ │ ├── hmmsearch.py │ │ │ └── hhblits.py │ │ ├── errors.py │ │ ├── msa_identifiers.py │ │ ├── feature_pipeline.py │ │ ├── input_pipeline_multimer.py │ │ └── input_pipeline.py │ ├── model │ │ ├── __init__.py │ │ ├── dropout.py │ │ ├── pair_transition.py │ │ ├── outer_product_mean.py │ │ ├── triangular_attention.py │ │ └── torchscript.py │ ├── utils │ │ ├── __init__.py │ │ ├── kernel │ │ │ ├── __init__.py │ │ │ └── attention_core.py │ │ ├── callbacks.py │ │ ├── geometry │ │ │ ├── utils.py │ │ │ ├── __init__.py │ │ │ ├── quat_rigid.py │ │ │ ├── test_utils.py │ │ │ ├── rigid_matrix_vector.py │ │ │ └── rotation_matrix.py │ │ ├── precision_utils.py │ │ ├── argparse_utils.py │ │ ├── validation_metrics.py │ │ ├── exponential_moving_average.py │ │ ├── logger.py │ │ ├── lr_schedulers.py │ │ ├── checkpointing.py │ │ ├── superimposition.py │ │ └── tensor_utils.py │ ├── resources │ │ └── __init__.py │ └── __init__.py ├── minimizer │ ├── params │ │ ├── mizu_cfg.atlas.yaml │ │ ├── mizu_cfg.atlas_A.yaml │ │ ├── mizu_cfg.mdcath.yaml │ │ └── mizu_cfg.mdcath_A.yaml │ └── runner.py ├── dim_red.py ├── utils.py ├── trajectory.py └── coords.py ├── .gitignore ├── assets └── fig_1.png ├── pyproject.toml ├── setup.py ├── data ├── splits │ ├── atlas │ │ ├── val.max_len_500.txt │ │ ├── README.md │ │ └── test.txt │ └── mdcath │ │ ├── README.md │ │ ├── val.txt │ │ └── test.txt └── input │ └── README.md ├── scripts ├── ensemble_analysis.py ├── ensemble_comparison.py └── generate_ensemble.py └── config ├── mdcath_model.yaml └── atlas_model.yaml /sam/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/nn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/openfold/np/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/nn/autoencoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/openfold/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/openfold/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/openfold/np/relax/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/openfold/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/nn/noise_prediction/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/openfold/data/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/openfold/resources/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sam/openfold/utils/kernel/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | sam/training/ 3 | weights/ 4 | sam2.egg-info 5 | -------------------------------------------------------------------------------- /assets/fig_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/giacomo-janson/sam2/HEAD/assets/fig_1.png -------------------------------------------------------------------------------- /sam/openfold/__init__.py: -------------------------------------------------------------------------------- 1 | from . import model 2 | from . import utils 3 | from . import data 4 | from . import np 5 | from . 
import resources 6 | 7 | __all__ = ["model", "utils", "np", "data", "resources"] 8 | -------------------------------------------------------------------------------- /sam/diffusion/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | class DiffusionCommon: 6 | 7 | def get_sample_model(self): 8 | # if self.ema is None: 9 | # return self.eps_model 10 | # else: 11 | # return self.ema.ema_model 12 | return self.eps_model -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "sam2" 7 | version = "1.0.0" 8 | dependencies = [ 9 | "biopython==1.85", 10 | "diffusers==0.32.2", 11 | "dm-tree", 12 | "mdtraj==1.10.3", 13 | "torch>=1.13.1" 14 | ] 15 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup, find_packages 3 | 4 | version = "1.0.0" 5 | 6 | if os.getenv("RUNNING_ON_COLAB") == "1": 7 | # setup( 8 | # name='sam', 9 | # version=version, 10 | # packages=find_packages(), 11 | # ) 12 | raise NotImplementedError() 13 | else: 14 | setup( 15 | name="sam", 16 | version=version, 17 | packages=["sam"], 18 | package_dir={"sam": "./sam"} 19 | ) 20 | -------------------------------------------------------------------------------- /sam/minimizer/params/mizu_cfg.atlas.yaml: -------------------------------------------------------------------------------- 1 | opt: 2 | energy_params: 3 | angle_const: 1000 4 | bond_const: 10000 5 | cabl_const: 100000 6 | chi_const: 1000 7 | improper_dihedral_const: 10 8 | nb_const: 100 9 | nb_form: l2 10 | phi_psi_const: 50 11 | proper_dihedral_const: 10 12 | history_size: 100 13 | max_iter: 10 14 | nb_centers_threshold: 1.0 15 | step_size: 1.0 16 | steps: 20 17 | top: 18 | nb_mode: os 19 | nb_os_tol: null 20 | -------------------------------------------------------------------------------- /data/splits/atlas/val.max_len_500.txt: -------------------------------------------------------------------------------- 1 | 6irx_A 2 | 6cka_B 3 | 6hj6_A 4 | 6dgk_B 5 | 6fc0_B 6 | 6dlm_A 7 | 6mdw_A 8 | 6bwq_A 9 | 5ydn_A 10 | 5w82_E 11 | 6cb7_A 12 | 5yrv_I 13 | 5z51_A 14 | 6a9a_A 15 | 6atg_C 16 | 5w4a_B 17 | 5zmo_A 18 | 6bk4_A 19 | 6eu8_A 20 | 6mbg_A 21 | 6bm5_A 22 | 6f45_D 23 | 6fub_B 24 | 6gfx_C 25 | 6a02_A 26 | 6c0h_A 27 | 6dnm_A 28 | 6as3_A 29 | 6bn0_A 30 | 5wfy_A 31 | 6hem_A 32 | 5z1n_A 33 | 5zlq_A 34 | 5naz_A 35 | 6e33_A 36 | 5ok6_A 37 | 6e5y_A 38 | 6crk_G 39 | -------------------------------------------------------------------------------- /sam/minimizer/params/mizu_cfg.atlas_A.yaml: -------------------------------------------------------------------------------- 1 | opt: 2 | energy_params: 3 | angle_const: 1 4 | bond_const: 0.1 5 | cabl_const: 100000 6 | chi_const: 1000 7 | improper_dihedral_const: 1 8 | nb_const: 100 9 | nb_form: l2 10 | phi_psi_const: 50 11 | proper_dihedral_const: 1 12 | history_size: 100 13 | max_iter: 10 14 | nb_centers_threshold: 1.0 15 | step_size: 1.0 16 | steps: 20 17 | top: 18 | use_ff_consts: true 19 | nb_mode: os 20 | nb_os_tol: null 21 | -------------------------------------------------------------------------------- /sam/nn/ema.py: 
-------------------------------------------------------------------------------- 1 | from typing import Callable 2 | try: 3 | from torch_ema import ExponentialMovingAverage as EMA 4 | has_ema = True 5 | except ImportError: 6 | has_ema = False 7 | 8 | 9 | def get_ema(network: Callable, model_cfg: dict, network_key: str = "generator"): 10 | if model_cfg[network_key].get("ema"): 11 | if not has_ema: 12 | raise ImportError("torch_ema is not installed") 13 | ema = EMA( 14 | network.parameters(), decay=model_cfg[network_key]["ema"]["beta"] 15 | ) 16 | else: 17 | ema = None 18 | return ema -------------------------------------------------------------------------------- /data/splits/atlas/README.md: -------------------------------------------------------------------------------- 1 | # Notes 2 | Here are the splits that we used for training/evaluating aSAMc models on ATLAS data. We used the same splits as [AlphaFlow](https://github.com/bjing2016/alphaflow). 3 | ## Autoencoder training 4 | * First stage: 5 | * training: `training.max_len_320.txt` 6 | * validation: `val.max_len_500.txt` 7 | * Second stage: 8 | * training: `training.max_len_500.txt` 9 | * validation: `val.max_len_500.txt` 10 | * test: `test.txt` 11 | ## Diffusion model training 12 | * First stage: 13 | * training: `training.max_len_500.txt` 14 | * validation: `val.max_len_500.txt` 15 | * test: `test.txt` -------------------------------------------------------------------------------- /sam/openfold/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.utilities import rank_zero_info 2 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 3 | 4 | class EarlyStoppingVerbose(EarlyStopping): 5 | """ 6 | The default EarlyStopping callback's verbose mode is too verbose. 7 | This class outputs a message only when it's getting ready to stop. 8 | """ 9 | def _evaluate_stopping_criteria(self, *args, **kwargs): 10 | should_stop, reason = super()._evaluate_stopping_criteria(*args, **kwargs) 11 | if(should_stop): 12 | rank_zero_info(f"{reason}\n") 13 | 14 | return should_stop, reason 15 | -------------------------------------------------------------------------------- /sam/dim_red.py: -------------------------------------------------------------------------------- 1 | import sklearn.decomposition 2 | from sam.coords import calc_dmap_triu 3 | 4 | 5 | def run_pca(traj_ref, traj_hat, featurize_for_pca=None, get_x=False): 6 | # Function used to featurize the trajectories and obtain the PCA input features. 7 | if featurize_for_pca is None: 8 | featurize_for_pca = lambda x: calc_dmap_triu(x, backend="numpy") 9 | # PCA features. 10 | x_ref = featurize_for_pca(traj_ref.xyz) 11 | x_hat = featurize_for_pca(traj_hat.xyz) 12 | # Perform PCA. 13 | pca = sklearn.decomposition.PCA(n_components=10) 14 | pca.fit(x_ref) 15 | y_ref = pca.transform(x_ref) 16 | y_hat = pca.transform(x_hat) 17 | # Return the results. 18 | results = {"y_ref": y_ref, "y_hat": y_hat, "obj": pca} 19 | if get_x: 20 | results.update({"x_ref": x_ref, "x_hat": x_hat}) 21 | return results -------------------------------------------------------------------------------- /sam/openfold/utils/geometry/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utils for geometry library.""" 15 | 16 | import dataclasses 17 | 18 | 19 | def get_field_names(cls): 20 | fields = dataclasses.fields(cls) 21 | field_names = [f.name for f in fields] 22 | return field_names 23 | -------------------------------------------------------------------------------- /sam/openfold/utils/precision_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import importlib 15 | 16 | import torch 17 | 18 | def is_fp16_enabled(): 19 | # Autocast world 20 | fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16 21 | fp16_enabled = fp16_enabled and torch.is_autocast_enabled() 22 | 23 | return fp16_enabled 24 | -------------------------------------------------------------------------------- /data/splits/atlas/test.txt: -------------------------------------------------------------------------------- 1 | 6o2v_A 2 | 7ead_A 3 | 6uof_A 4 | 6lus_A 5 | 6qj0_A 6 | 6j56_A 7 | 7ec1_A 8 | 6xds_A 9 | 6xrx_A 10 | 6q9c_B 11 | 6rrv_A 12 | 7lao_A 13 | 6l4l_A 14 | 7asg_A 15 | 6kty_A 16 | 6vjg_A 17 | 6sms_A 18 | 6l3r_E 19 | 7qsu_A 20 | 7p46_A 21 | 7e2s_A 22 | 6pxz_B 23 | 6ovk_R 24 | 6ndw_B 25 | 6pce_B 26 | 7p41_D 27 | 6h86_A 28 | 7jfl_C 29 | 6iah_A 30 | 6y2x_A 31 | 7nmq_A 32 | 6xb3_H 33 | 6jwh_A 34 | 6l4p_B 35 | 6jpt_A 36 | 7a66_B 37 | 6okd_C 38 | 6in7_A 39 | 7onn_A 40 | 6ono_C 41 | 6d7y_A 42 | 6odd_B 43 | 6p5x_B 44 | 6tgk_C 45 | 7dmn_A 46 | 7lp1_A 47 | 6l34_A 48 | 7ned_A 49 | 7s86_A 50 | 6l8s_A 51 | 7bwf_B 52 | 7aex_A 53 | 6d7y_B 54 | 6e7e_A 55 | 7k7p_B 56 | 7buy_A 57 | 6yhu_B 58 | 6h49_A 59 | 7aqx_A 60 | 7c45_A 61 | 6gus_A 62 | 6q9c_A 63 | 7n0j_E 64 | 6o6y_A 65 | 6zsl_B 66 | 7rm7_A 67 | 6ypi_A 68 | 6ro6_A 69 | 7mf4_A 70 | 7jrq_A 71 | 7wab_A 72 | 5znj_A 73 | 6pnv_A 74 | 6rwt_A 75 | 6oz1_A 76 | 6nl2_A 77 | 6p5h_A 78 | 6q10_A 79 | 6jv8_A 80 | 6lrd_A 81 | 6tly_A 82 | 7la6_A 83 | -------------------------------------------------------------------------------- /sam/openfold/data/errors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """General-purpose errors used throughout the data pipeline""" 17 | class Error(Exception): 18 | """Base class for exceptions.""" 19 | 20 | 21 | class MultipleChainsError(Error): 22 | """An error indicating that multiple chains were found for a given ID.""" 23 | -------------------------------------------------------------------------------- /data/splits/mdcath/README.md: -------------------------------------------------------------------------------- 1 | # Notes 2 | Here are the splits that we used for training/evaluating aSAMt models on mdCATH data. See the aSAM publication for details on how the splits were created. The names of the systems follow this pattern: 3 | 4 | `${CATH_DOMAIN_ID}.${TEMPERATURE}` 5 | 6 | where `$TEMPERATURE` is the temperature of the MD simulation for that system. 7 | 8 | ## Autoencoder training 9 | * First stage: 10 | * training: `train.max_len_320.txt` 11 | * note: this split contains only the training systems whose protein domain has length <= 320. See the `train.txt` file for a list of all training systems. 12 | * validation: `val.txt` 13 | * test: `test.txt` 14 | ## Diffusion model training 15 | * First stage: 16 | * training: `train.max_len_320.fix.txt` 17 | * note: this list is derived from `train.max_len_320.txt`, but here we excluded all systems for 4j2nC00, 1b5fC00, 3hshE00, 1ow4A00, 3h33A00, 3vhoA00 and 5e99H01, because of errors when encoding their trajectories with our encoder network. 18 | * validation: `val.txt` 19 | * test: `test.txt` -------------------------------------------------------------------------------- /sam/nn/generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable 3 | import torch 4 | from sam.nn.noise_prediction.eps import get_eps_network 5 | from sam.diffusion.diffusers_dm import Diffusers 6 | 7 | 8 | def get_generator_net(model_cfg: dict): 9 | generator = get_eps_network(model_cfg) 10 | return generator 11 | 12 | 13 | def get_generative_model(model_cfg: dict, 14 | network: Callable, 15 | ema: Callable = None): 16 | 17 | # Diffusion modeling using the Diffusers library.
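# Illustrative layout of the config section read below, inferred from the
# lookups in this function (the actual schemas live in the config/*.yaml
# files, e.g. config/atlas_model.yaml and config/mdcath_model.yaml):
#   generative_model:
#     type: diffusers_dm
#     sched_params: {...}  # forwarded to the Diffusers noise scheduler
#     loss: l2             # optional, defaults to "l2"
#     extra_loss: {}       # optional, defaults to {}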
18 | if model_cfg["generative_model"]["type"] == "diffusers_dm": 19 | """ 20 | model = get_diffusion_model(model_cfg=model_cfg, network=network, ema=ema) 21 | """ 22 | model = Diffusers( 23 | eps_model=network, 24 | sched_params=model_cfg["generative_model"]["sched_params"], 25 | loss=model_cfg["generative_model"].get("loss", "l2"), 26 | extra_loss=model_cfg["generative_model"].get("extra_loss", {}), 27 | ema=ema, 28 | sc_params=None 29 | ) 30 | else: 31 | raise NotImplementedError() 32 | return model -------------------------------------------------------------------------------- /sam/openfold/utils/argparse_utils.py: -------------------------------------------------------------------------------- 1 | from argparse import HelpFormatter 2 | from operator import attrgetter 3 | 4 | class ArgparseAlphabetizer(HelpFormatter): 5 | """ 6 | Sorts the optional arguments of an argparse parser alphabetically 7 | """ 8 | 9 | @staticmethod 10 | def sort_actions(actions): 11 | return sorted(actions, key=attrgetter("option_strings")) 12 | 13 | # Formats the help message 14 | def add_arguments(self, actions): 15 | actions = ArgparseAlphabetizer.sort_actions(actions) 16 | super(ArgparseAlphabetizer, self).add_arguments(actions) 17 | 18 | # Formats the usage message 19 | def add_usage(self, usage, actions, groups, prefix=None): 20 | actions = ArgparseAlphabetizer.sort_actions(actions) 21 | args = usage, actions, groups, prefix 22 | super(ArgparseAlphabetizer, self).add_usage(*args) 23 | 24 | 25 | def remove_arguments(parser, args): 26 | for arg in args: 27 | for action in parser._actions: 28 | opts = vars(action)["option_strings"] 29 | if(arg in opts): 30 | parser._handle_conflict_resolve(None, [(arg, action)]) 31 | -------------------------------------------------------------------------------- /sam/minimizer/params/mizu_cfg.mdcath.yaml: -------------------------------------------------------------------------------- 1 | opt: 2 | energy_params: 3 | angle_const: 1000 4 | bond_const: 10000 5 | chi_const: 1000 6 | improper_dihedral_const: 10 7 | nb_const: 250 8 | nb_form: l2 9 | early_stopping_hc_score: 0.7 10 | early_stopping_hc_thresh: 0.175 11 | phi_psi_const: 50 12 | proper_dihedral_const: 10 13 | cabl_const: 100000 14 | history_size: 100 15 | max_iter: 10 16 | nb_centers_threshold: 1.0 17 | step_size: 1.0 18 | steps: 30 19 | cabl_init_range: [0.357, 0.411] 20 | # gradient_clip_mode: value 21 | # gradient_clip: 100000.0 # 80000.0 22 | opt: lbfgs 23 | 24 | opt_ini: 25 | energy_params: 26 | angle_const: 1000 27 | bond_const: 10000 28 | chi_const: 1000 29 | improper_dihedral_const: 10 30 | nb_const: 100 31 | nb_form: l2 32 | early_stopping_hc_score: 0.7 33 | early_stopping_hc_thresh: 0.175 34 | phi_psi_const: 50 35 | proper_dihedral_const: 10 36 | cabl_const: 100000 37 | nb_centers_threshold: 1.0 38 | cabl_init_range: [0.357, 0.411] 39 | step_size: 0.001 40 | steps: 50 41 | beta1: 0.5 42 | beta2: 0.9 43 | opt: adam 44 | 45 | top: 46 | nb_mode: os 47 | nb_os_tol: null 48 | -------------------------------------------------------------------------------- /sam/openfold/utils/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Geometry Module.""" 15 | 16 | from sam.openfold.utils.geometry import rigid_matrix_vector 17 | from sam.openfold.utils.geometry import rotation_matrix 18 | from sam.openfold.utils.geometry import vector 19 | 20 | Rot3Array = rotation_matrix.Rot3Array 21 | Rigid3Array = rigid_matrix_vector.Rigid3Array 22 | 23 | Vec3Array = vector.Vec3Array 24 | square_euclidean_distance = vector.square_euclidean_distance 25 | euclidean_distance = vector.euclidean_distance 26 | dihedral_angle = vector.dihedral_angle 27 | dot = vector.dot 28 | cross = vector.cross 29 | -------------------------------------------------------------------------------- /sam/minimizer/params/mizu_cfg.mdcath_A.yaml: -------------------------------------------------------------------------------- 1 | opt: 2 | energy_params: 3 | angle_const: 1 4 | bond_const: 0.1 5 | chi_const: 1000 6 | improper_dihedral_const: 1 7 | nb_const: 250 8 | nb_form: l2 9 | early_stopping_hc_score: 0.7 10 | early_stopping_hc_thresh: 0.175 11 | phi_psi_const: 50 12 | proper_dihedral_const: 1 13 | cabl_const: 100000 14 | history_size: 100 15 | max_iter: 10 16 | nb_centers_threshold: 1.0 17 | step_size: 1.0 18 | steps: 30 19 | cabl_init_range: [0.357, 0.411] 20 | # gradient_clip_mode: value 21 | # gradient_clip: 100000.0 # 80000.0 22 | opt: lbfgs 23 | 24 | opt_ini: 25 | energy_params: 26 | angle_const: 1 27 | bond_const: 0.1 28 | chi_const: 1000 29 | improper_dihedral_const: 1 30 | nb_const: 100 31 | nb_form: l2 32 | early_stopping_hc_score: 0.7 33 | early_stopping_hc_thresh: 0.175 34 | phi_psi_const: 50 35 | proper_dihedral_const: 1 36 | cabl_const: 100000 37 | nb_centers_threshold: 1.0 38 | cabl_init_range: [0.357, 0.411] 39 | step_size: 0.001 40 | steps: 50 41 | beta1: 0.5 42 | beta2: 0.9 43 | opt: adam 44 | 45 | top: 46 | nb_mode: os 47 | nb_os_tol: null 48 | use_ff_consts: true 49 | -------------------------------------------------------------------------------- /sam/nn/noise_prediction/eps/__init__.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from sam.data.sequences import get_num_beads 3 | from sam.nn.noise_prediction.eps.eps_simple import ( 4 | LatentEpsNetwork_v02, SAM_LatentEpsNetwork_v02 5 | ) 6 | 7 | def get_eps_network(model_cfg): 8 | # Get the class for the noise prediction network. 9 | if model_cfg["generator"]["arch"] == "eps_v02": 10 | model_cls = LatentEpsNetwork_v02 11 | wrapper_cls = SAM_LatentEpsNetwork_v02 12 | else: 13 | raise KeyError(model_cfg["generator"]["arch"]) 14 | # Get the arguments of the eps network class. 15 | eps_args = list(inspect.signature(model_cls.__init__).parameters.keys()) 16 | eps_args.remove("input_dim") 17 | # Get from 'model_cfg' the corresponding arguments. 18 | eps_params = {} 19 | for eps_arg in eps_args: 20 | if eps_arg in model_cfg["generator"]: 21 | eps_params[eps_arg] = model_cfg["generator"][eps_arg] 22 | # Initialize the network.
23 | return wrapper_cls( 24 | input_dim=model_cfg["generative_stack"]["encoding_dim"], 25 | use_res_ids=model_cfg.get("data", {}).get("res_ids_mode") is not None, 26 | num_beads=get_num_beads(model_cfg.get("data", {}).get("alphabet")), 27 | **eps_params 28 | ) 29 | -------------------------------------------------------------------------------- /sam/nn/autoencoder/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | from sam.nn.autoencoder.encoder.aa import AllAtomEncoder_v01 4 | from sam.data.sequences import get_num_beads 5 | 6 | 7 | def get_encoder(model_cfg, output_dim=None): 8 | """ 9 | Returns an object of the encoder class specified in `model_cfg`. 10 | """ 11 | 12 | # Use a coarse-grained representation. 13 | if model_cfg["encoder"]["arch"] == "enc_aa_v01": 14 | enc_class = AllAtomEncoder_v01 15 | else: 16 | raise KeyError(model_cfg["encoder"]["arch"]) 17 | 18 | # Get the arguments of the encoder network class. 19 | args = list( 20 | inspect.signature(enc_class.__init__).parameters.keys()) 21 | args.remove("encoding_dim") 22 | # Get from 'model_cfg' the corresponding arguments. 23 | params = {} 24 | for arg in args: 25 | if arg in model_cfg["encoder"]: 26 | params[arg] = model_cfg["encoder"][arg] 27 | # Initialize the network. 28 | return enc_class( 29 | encoding_dim=output_dim if output_dim is not None \ 30 | else model_cfg["generative_stack"]["encoding_dim"], 31 | use_res_ids=model_cfg.get("data", {}).get("res_ids_mode") is not None, 32 | num_beads=get_num_beads(model_cfg.get("data", {}).get("alphabet")), 33 | **params) -------------------------------------------------------------------------------- /sam/nn/autoencoder/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | from sam.nn.autoencoder.decoder.aa import AllAtomDecoder_v01 4 | from sam.data.sequences import get_num_beads 5 | 6 | 7 | def get_decoder(model_cfg, input_dim=None, output_dim=None): 8 | """ 9 | Returns an object of the decoder class specified in `model_cfg`. 10 | """ 11 | 12 | # Use a coarse-grained representation. 13 | if model_cfg["decoder"]["arch"] == "dec_aa_v01": 14 | dec_class = AllAtomDecoder_v01 15 | else: 16 | raise KeyError(model_cfg["decoder"]["arch"]) 17 | 18 | # Get the arguments of the decoder network class. 19 | args = list( 20 | inspect.signature(dec_class.__init__).parameters.keys()) 21 | args.remove("encoding_dim") 22 | # Get from 'model_cfg' the corresponding arguments. 23 | params = {} 24 | for arg in args: 25 | if arg in model_cfg["decoder"]: 26 | params[arg] = model_cfg["decoder"][arg] 27 | # Initialize the network. 
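# (Note: encoding_dim falls back to the shared latent size in
# model_cfg["generative_stack"] when no input_dim is given, and output_dim
# defaults to 3, i.e. xyz coordinates.)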
28 | return dec_class( 29 | encoding_dim=input_dim if input_dim is not None \ 30 | else model_cfg["generative_stack"]["encoding_dim"], 31 | output_dim=output_dim if output_dim is not None else 3, 32 | use_res_ids=model_cfg.get("data", {}).get("res_ids_mode") is not None, 33 | num_beads=get_num_beads(model_cfg.get("data", {}).get("alphabet")), 34 | **params 35 | ) -------------------------------------------------------------------------------- /sam/openfold/utils/geometry/quat_rigid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from sam.openfold.model.primitives import Linear 5 | from sam.openfold.utils.geometry.rigid_matrix_vector import Rigid3Array 6 | from sam.openfold.utils.geometry.rotation_matrix import Rot3Array 7 | from sam.openfold.utils.geometry.vector import Vec3Array 8 | 9 | 10 | class QuatRigid(nn.Module): 11 | def __init__(self, c_hidden, full_quat): 12 | super().__init__() 13 | self.full_quat = full_quat 14 | if self.full_quat: 15 | rigid_dim = 7 16 | else: 17 | rigid_dim = 6 18 | 19 | self.linear = Linear(c_hidden, rigid_dim, init="final", precision=torch.float32) 20 | 21 | def forward(self, activations: torch.Tensor) -> Rigid3Array: 22 | # NOTE: During training, this needs to be run in higher precision 23 | rigid_flat = self.linear(activations) 24 | 25 | rigid_flat = torch.unbind(rigid_flat, dim=-1) 26 | if(self.full_quat): 27 | qw, qx, qy, qz = rigid_flat[:4] 28 | translation = rigid_flat[4:] 29 | else: 30 | qx, qy, qz = rigid_flat[:3] 31 | qw = torch.ones_like(qx) 32 | translation = rigid_flat[3:] 33 | 34 | rotation = Rot3Array.from_quaternion( 35 | qw, qx, qy, qz, normalize=True, 36 | ) 37 | translation = Vec3Array(*translation) 38 | return Rigid3Array(rotation, translation) 39 | -------------------------------------------------------------------------------- /sam/openfold/data/tools/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Common utilities for data pipeline tools.""" 17 | import contextlib 18 | import datetime 19 | import logging 20 | import shutil 21 | import tempfile 22 | import time 23 | from typing import Optional 24 | 25 | 26 | @contextlib.contextmanager 27 | def tmpdir_manager(base_dir: Optional[str] = None): 28 | """Context manager that deletes a temporary directory on exit.""" 29 | tmpdir = tempfile.mkdtemp(dir=base_dir) 30 | try: 31 | yield tmpdir 32 | finally: 33 | shutil.rmtree(tmpdir, ignore_errors=True) 34 | 35 | 36 | @contextlib.contextmanager 37 | def timing(msg: str): 38 | logging.info("Started %s", msg) 39 | tic = time.perf_counter() 40 | yield 41 | toc = time.perf_counter() 42 | logging.info("Finished %s in %.3f seconds", msg, toc - tic) 43 | 44 | 45 | def to_date(s: str): 46 | return datetime.datetime( 47 | year=int(s[:4]), month=int(s[5:7]), day=int(s[8:10]) 48 | ) 49 | -------------------------------------------------------------------------------- /sam/data/trajectory_reader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Load trajectory data. 3 | """ 4 | import numpy as np 5 | import mdtraj 6 | from sam.data.topology import slice_ca_traj 7 | from sam.coords import sample_data 8 | 9 | def get_ca_traj( 10 | input: dict, 11 | slice_traj: bool = True, 12 | n_frames: int = 1000, 13 | n_trajs: int = None, 14 | frames_mode: str = "ensemble", 15 | get_xyx: bool = False): 16 | 17 | # Read xyz data from a trajectory file. 18 | xyz = [] 19 | if slice_traj: 20 | slice_func = slice_ca_traj 21 | else: 22 | slice_func = lambda t: t 23 | top_traj = slice_func(mdtraj.load(input["topology"])) 24 | 25 | if n_trajs is None: 26 | sel_trajectories = input["trajectories"] 27 | else: 28 | sel_trajectories = np.random.choice(input["trajectories"], n_trajs, 29 | replace=False) 30 | 31 | # Actually parse each trajectory file. 32 | for traj_fp_i in sel_trajectories: 33 | # Load the trajectory. 34 | traj_i = slice_func( 35 | mdtraj.load(traj_fp_i, top=input["topology"])) 36 | xyz_i = traj_i.xyz 37 | # Sample frames with mode "trajectory". 38 | if frames_mode == "trajectory": 39 | xyz_i = sample_data(data=xyz_i, n_samples=n_frames) 40 | if xyz_i.shape[0] == 0: 41 | raise ValueError() 42 | # Store the frames. 43 | xyz.append(xyz_i) 44 | 45 | # Begin preparing the results to return. 46 | if not xyz: 47 | raise ValueError("No data found in: {}".format( 48 | repr(input["trajectories"]))) 49 | xyz = np.concatenate(xyz, axis=0) 50 | 51 | # Sample frames with mode "ensemble". 
52 | if frames_mode == "ensemble": 53 | xyz = sample_data(data=xyz, n_samples=n_frames) 54 | if get_xyz: 55 | return xyz 56 | else: 57 | traj = mdtraj.Trajectory(xyz=xyz, topology=top_traj.topology) 58 | return traj -------------------------------------------------------------------------------- /sam/analysis/melting_curves.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import curve_fit 3 | 4 | 5 | # Define the sigmoid function 6 | def sigmoid(T, T_m, k): 7 | return 1 / (1 + np.exp(-(T - T_m) / k)) 8 | 9 | def get_Tm(T, f_T, p0=[500, -10], get_err=False): 10 | # Fit the data 11 | popt, pcov = curve_fit(sigmoid, T, f_T, p0=p0) # Initial guesses for T_m and k 12 | # Extract fitted parameters 13 | T_m, k = popt 14 | if not get_err: 15 | return T_m, k 16 | else: 17 | return (T_m, k), pcov 18 | 19 | def b_sigmoid(T, Tm, slope, top, bottom): 20 | return bottom + (top - bottom)/(1 + np.exp(-(T - Tm)/slope)) 21 | 22 | def fit_hTm(T, f_T, p0=[500, -10]): 23 | top = np.max(f_T) 24 | bottom = np.min(f_T) 25 | data_sigmoid = lambda T, Tm, slope: b_sigmoid(T, Tm, slope, top=top, bottom=bottom) 26 | popt, pcov = curve_fit(data_sigmoid, T, f_T, p0=p0) # Initial guesses for Tm and slope 27 | # Extract fitted parameters 28 | Tm, slope = popt 29 | fitted_sigmoid = lambda T: b_sigmoid(T, Tm=Tm, slope=slope, top=top, bottom=bottom) 30 | return Tm, slope, fitted_sigmoid 31 | 32 | def plot_sigmoid( 33 | ax, T, f_T, color, use_b_fit=True, p0=[500, -10], n_plot_points=250, 34 | label=None, use_label=True 35 | ): 36 | if not use_b_fit: 37 | Tm, k = get_Tm(T, f_T, p0=p0) 38 | # f_T = (f_T - f_T.min())/(f_T.max() - f_T.min()) 39 | T_fit = np.linspace(min(T), max(T), n_plot_points) 40 | f_T_fit = sigmoid(T_fit, Tm, k) 41 | # f_fit = (f_T.max() - f_T.min())*f_fit + f_T.min() 42 | else: 43 | Tm, _, fitted_sigmoid = fit_hTm(T, f_T) 44 | T_fit = np.linspace(min(T), max(T), n_plot_points) 45 | f_T_fit = fitted_sigmoid(T_fit) 46 | if use_label: 47 | if label is None: 48 | label = rf"$\hat{{T}}_m$={Tm:.0f} K" 49 | else: 50 | label = None 51 | plot_s = ax.plot( 52 | T_fit, f_T_fit, color=color, alpha=0.5, ls="--", 53 | label=label 54 | ) 55 | return Tm -------------------------------------------------------------------------------- /data/input/README.md: -------------------------------------------------------------------------------- 1 | # Input PDB files used in the aSAM article 2 | Use the links below to download the input PDB files used in the aSAM article. Please refer to the article for details on how the PDB files were obtained. 3 | 4 | 5 | ## ATLAS 6 | Input files for all training/validation/test domains. 7 | 8 | Link: [input_atlas.zip](https://github.com/giacomo-janson/sam2/releases/download/data-1.0/input_atlas.zip) 9 | 10 | ## mdCATH 11 | Input files for all training/validation/test domains. **NOTE**: To generate ensembles of test set domains at all temperatures, we used the input file of the domain at 320 K. 12 | 13 | Link: [input_mdcath.zip](https://github.com/giacomo-janson/sam2/releases/download/data-1.0/input_mdcath.zip) 14 | 15 | ## Fast-folding proteins 16 | Input files for the 12 fast-folding proteins analyzed in the aSAM article (Fig. 6). Their sequences are the same as those in the [How fast-folding proteins fold](https://pubmed.ncbi.nlm.nih.gov/22034434/) article.
17 | 18 | Link: [input_fast_folding.zip](https://github.com/giacomo-janson/sam2/releases/download/data-1.0/input_fast_folding.zip) 19 | 20 | ## Monomeric proteins with experimentally-determined melting temperature 21 | Input files for the 62 monomeric proteins (from [Pucci et al., 2017](https://pubmed.ncbi.nlm.nih.gov/29036273/)) analyzed in the aSAM article (Fig. 7a). 22 | 23 | Link: [input_experimental_tm.zip](https://github.com/giacomo-janson/sam2/releases/download/data-1.0/input_experimental_tm.zip) 24 | 25 | ## Homologous pairs of meso- and thermophilic proteins 26 | Input files for the 5 pairs of proteins (from [Razvi and Scholtz 2006](https://pubmed.ncbi.nlm.nih.gov/16815912/)) analyzed in the aSAM article (Fig. 7b). 27 | 28 | Link: [input_meso_thermo_pairs.zip](https://github.com/giacomo-janson/sam2/releases/download/data-1.0/input_meso_thermo_pairs.zip) 29 | 30 | ## RFdiffusion proteins, modeled by AF2 31 | Input files for the 9 [RFdiffusion](https://github.com/RosettaCommons/RFdiffusion) and control AlphaFold2 models from [AlphaFold Database](https://alphafold.ebi.ac.uk) analyzed in the aSAM article (Fig. 7c). RFdiffusion models were copied from [here](https://figshare.com/s/439fdd59488215753bc3). 32 | 33 | Link: [input_rfdiffusion.zip](https://github.com/giacomo-janson/sam2/releases/download/data-1.0/input_rfdiffusion.zip) -------------------------------------------------------------------------------- /sam/openfold/utils/validation_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | 16 | 17 | def drmsd(structure_1, structure_2, mask=None): 18 | def prep_d(structure): 19 | d = structure[..., :, None, :] - structure[..., None, :, :] 20 | d = d ** 2 21 | d = torch.sqrt(torch.sum(d, dim=-1)) 22 | return d 23 | 24 | d1 = prep_d(structure_1) 25 | d2 = prep_d(structure_2) 26 | 27 | drmsd = d1 - d2 28 | drmsd = drmsd ** 2 29 | if(mask is not None): 30 | drmsd = drmsd * (mask[..., None] * mask[..., None, :]) 31 | drmsd = torch.sum(drmsd, dim=(-1, -2)) 32 | n = d1.shape[-1] if mask is None else torch.min(torch.sum(mask, dim=-1)) 33 | drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
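# (n * (n - 1) counts the ordered residue pairs entering the sum above;
# the n > 1 guard avoids a division by zero for single-residue inputs.)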
34 | drmsd = torch.sqrt(drmsd) 35 | 36 | return drmsd 37 | 38 | 39 | def drmsd_np(structure_1, structure_2, mask=None): 40 | structure_1 = torch.tensor(structure_1) 41 | structure_2 = torch.tensor(structure_2) 42 | if(mask is not None): 43 | mask = torch.tensor(mask) 44 | 45 | return drmsd(structure_1, structure_2, mask) 46 | 47 | 48 | def gdt(p1, p2, mask, cutoffs): 49 | n = torch.sum(mask, dim=-1) 50 | 51 | p1 = p1.float() 52 | p2 = p2.float() 53 | distances = torch.sqrt(torch.sum((p1 - p2)**2, dim=-1)) 54 | scores = [] 55 | for c in cutoffs: 56 | score = torch.sum((distances <= c) * mask, dim=-1) / n 57 | score = torch.mean(score) 58 | scores.append(score) 59 | 60 | return sum(scores) / len(scores) 61 | 62 | 63 | def gdt_ts(p1, p2, mask): 64 | return gdt(p1, p2, mask, [1., 2., 4., 8.]) 65 | 66 | 67 | def gdt_ha(p1, p2, mask): 68 | return gdt(p1, p2, mask, [0.5, 1., 2., 4.]) 69 | 70 | -------------------------------------------------------------------------------- /scripts/ensemble_analysis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Analyze an ensemble against a reference structure and compute scores used in the 3 | aSAM article. Specifically: 4 | (i) Folded state fraction (FSF), 5 | (ii) Secondary structure elements preservation (SSEP), 6 | (iii) average initRMSD (RMSD with respect to the reference structure). 7 | """ 8 | 9 | import os 10 | import json 11 | import pathlib 12 | import argparse 13 | 14 | import mdtraj 15 | from sam.trajectory import calc_initrmsd, calc_ssep, calc_q_values 16 | 17 | 18 | if __name__ == "__main__": 19 | 20 | parser = argparse.ArgumentParser(description=__doc__) 21 | parser.add_argument('-n', '--native_pdb', type=str, required=True, 22 | help='a reference PDB file, which may represent a native structure or' 23 | ' an initial structure for MD' 24 | ) 25 | parser.add_argument('-p', '--ensemble_top', type=str, required=True, 26 | help='topology file of the input ensemble' 27 | ) 28 | parser.add_argument('-t', '--ensemble_traj', type=str, required=True, 29 | help='trajectory file of the input ensemble' 30 | ) 31 | parser.add_argument('--q_thresh', type=float, default=0.6, 32 | help='Q value threshold for defining the folded state' 33 | ) 34 | args = parser.parse_args() 35 | 36 | # Will store the results here. 37 | json_data = {} 38 | 39 | # Load the ensemble and the native structure. 40 | traj = mdtraj.load(args.ensemble_traj, top=args.ensemble_top) 41 | native_traj = mdtraj.load(args.native_pdb) 42 | 43 | # Compute FSF. 44 | # First compute Q values (we also use them in the paper and show their 45 | # histograms). 46 | q_values = calc_q_values( 47 | traj=traj, 48 | native_traj=native_traj, 49 | beta=50.0, 50 | lambda_=1.2, 51 | delta=0.0, 52 | threshold=1.0 53 | ) 54 | # Assign folded/unfolded states. 55 | folded_state = q_values > args.q_thresh 56 | # Calculate the fraction of snapshots in the folded state. 57 | json_data["fsf"] = float(folded_state.mean()) 58 | 59 | # Compute SSEP. 60 | ssep = calc_ssep(traj=traj, native_traj=native_traj) 61 | json_data["ssep"] = float(ssep.mean()) 62 | 63 | # Compute initRMSD. 64 | initrmsd = calc_initrmsd( 65 | traj=traj, init_traj=native_traj, is_ca=False, get_tm=False 66 | ) 67 | json_data["avg_initrmsd"] = float(initrmsd.mean()) 68 | 69 | # Print results.
70 | print(json_data) 71 | -------------------------------------------------------------------------------- /sam/openfold/model/dropout.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | import torch.nn as nn 18 | from functools import partialmethod 19 | from typing import Union, List 20 | 21 | 22 | class Dropout(nn.Module): 23 | """ 24 | Implementation of dropout with the ability to share the dropout mask 25 | along a particular dimension. 26 | 27 | If not in training mode, this module computes the identity function. 28 | """ 29 | 30 | def __init__(self, r: float, batch_dim: Union[int, List[int]]): 31 | """ 32 | Args: 33 | r: 34 | Dropout rate 35 | batch_dim: 36 | Dimension(s) along which the dropout mask is shared 37 | """ 38 | super(Dropout, self).__init__() 39 | 40 | self.r = r 41 | if type(batch_dim) == int: 42 | batch_dim = [batch_dim] 43 | self.batch_dim = batch_dim 44 | self.dropout = nn.Dropout(self.r) 45 | 46 | def forward(self, x: torch.Tensor) -> torch.Tensor: 47 | """ 48 | Args: 49 | x: 50 | Tensor to which dropout is applied. Can have any shape 51 | compatible with self.batch_dim 52 | """ 53 | shape = list(x.shape) 54 | if self.batch_dim is not None: 55 | for bd in self.batch_dim: 56 | shape[bd] = 1 57 | mask = x.new_ones(shape) 58 | mask = self.dropout(mask) 59 | x *= mask 60 | return x 61 | 62 | 63 | class DropoutRowwise(Dropout): 64 | """ 65 | Convenience class for rowwise dropout as described in subsection 66 | 1.11.6. 67 | """ 68 | 69 | __init__ = partialmethod(Dropout.__init__, batch_dim=-3) 70 | 71 | 72 | class DropoutColumnwise(Dropout): 73 | """ 74 | Convenience class for columnwise dropout as described in subsection 75 | 1.11.6. 
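With batch_dim=-2 the same mask is broadcast along the second-to-last
dimension, e.g. shared across the columns of a pair representation of
shape [*, N_res, N_res, C_z].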
76 | """ 77 | 78 | __init__ = partialmethod(Dropout.__init__, batch_dim=-2) 79 | -------------------------------------------------------------------------------- /sam/openfold/data/tools/parse_msa_files.py: -------------------------------------------------------------------------------- 1 | import os, argparse, pickle, tempfile, concurrent 2 | from sam.openfold.data import parsers 3 | from concurrent.futures import ProcessPoolExecutor 4 | 5 | 6 | def parse_stockholm_file(alignment_dir: str, stockholm_file: str): 7 | path = os.path.join(alignment_dir, stockholm_file) 8 | file_name,_ = os.path.splitext(stockholm_file) 9 | with open(path, "r") as infile: 10 | msa = parsers.parse_stockholm(infile.read()) 11 | infile.close() 12 | return {file_name: msa} 13 | 14 | 15 | def parse_a3m_file(alignment_dir: str, a3m_file: str): 16 | path = os.path.join(alignment_dir, a3m_file) 17 | file_name,_ = os.path.splitext(a3m_file) 18 | with open(path, "r") as infile: 19 | msa = parsers.parse_a3m(infile.read()) 20 | infile.close() 21 | return {file_name: msa} 22 | 23 | 24 | def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: list, alignment_dir:str): 25 | # Number of workers based on the tasks 26 | msa_results={} 27 | a3m_tasks = [(alignment_dir, f) for f in a3m_files] 28 | sto_tasks = [(alignment_dir, f) for f in stockholm_files] 29 | with ProcessPoolExecutor(max_workers = len(a3m_tasks) + len(sto_tasks)) as executor: 30 | a3m_futures = {executor.submit(parse_a3m_file, *task): task for task in a3m_tasks} 31 | sto_futures = {executor.submit(parse_stockholm_file, *task): task for task in sto_tasks} 32 | 33 | for future in concurrent.futures.as_completed(a3m_futures | sto_futures): 34 | try: 35 | result = future.result() 36 | msa_results.update(result) 37 | except Exception as exc: 38 | print(f'Task generated an exception: {exc}') 39 | return msa_results 40 | 41 | 42 | def main(): 43 | parser = argparse.ArgumentParser(description='Process msa files in parallel') 44 | parser.add_argument('--alignment_dir', type=str, help='path to alignment dir') 45 | args = parser.parse_args() 46 | alignment_dir = args.alignment_dir 47 | stockholm_files = [i for i in os.listdir(alignment_dir) 48 | if all([i.endswith('.sto'), "hmm_output" not in i, "uniprot_hits" not in i])] 49 | a3m_files = [i for i in os.listdir(alignment_dir) if i.endswith('.a3m')] 50 | msa_data = run_parse_all_msa_files_multiprocessing(stockholm_files, a3m_files, alignment_dir) 51 | with tempfile.NamedTemporaryFile('wb', suffix='.pkl', delete=False) as outfile: 52 | pickle.dump(msa_data, outfile) 53 | print(outfile.name) 54 | 55 | 56 | if __name__ == "__main__": 57 | main() -------------------------------------------------------------------------------- /sam/openfold/utils/exponential_moving_average.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import copy 3 | import torch 4 | import torch.nn as nn 5 | 6 | from sam.openfold.utils.tensor_utils import tensor_tree_map 7 | 8 | 9 | class ExponentialMovingAverage: 10 | """ 11 | Maintains moving averages of parameters with exponential decay 12 | 13 | At each step, the stored copy `copy` of each parameter `param` is 14 | updated as follows: 15 | 16 | `copy = decay * copy + (1 - decay) * param` 17 | 18 | where `decay` is an attribute of the ExponentialMovingAverage object. 
19 | """ 20 | 21 | def __init__(self, model: nn.Module, decay: float): 22 | """ 23 | Args: 24 | model: 25 | A torch.nn.Module whose parameters are to be tracked 26 | decay: 27 | A value (usually close to 1.) by which updates are 28 | weighted as part of the above formula 29 | """ 30 | super(ExponentialMovingAverage, self).__init__() 31 | 32 | clone_param = lambda t: t.clone().detach() 33 | self.params = tensor_tree_map(clone_param, model.state_dict()) 34 | self.decay = decay 35 | self.device = next(model.parameters()).device 36 | 37 | def to(self, device): 38 | self.params = tensor_tree_map(lambda t: t.to(device), self.params) 39 | self.device = device 40 | 41 | def _update_state_dict_(self, update, state_dict): 42 | with torch.no_grad(): 43 | for k, v in update.items(): 44 | stored = state_dict[k] 45 | if not isinstance(v, torch.Tensor): 46 | self._update_state_dict_(v, stored) 47 | else: 48 | diff = stored - v 49 | diff *= 1 - self.decay 50 | stored -= diff 51 | 52 | def update(self, model: torch.nn.Module) -> None: 53 | """ 54 | Updates the stored parameters using the state dict of the provided 55 | module. The module should have the same structure as that used to 56 | initialize the ExponentialMovingAverage object. 57 | """ 58 | self._update_state_dict_(model.state_dict(), self.params) 59 | 60 | def load_state_dict(self, state_dict: OrderedDict) -> None: 61 | for k in state_dict["params"].keys(): 62 | self.params[k] = state_dict["params"][k].clone() 63 | self.decay = state_dict["decay"] 64 | 65 | def state_dict(self) -> OrderedDict: 66 | return OrderedDict( 67 | { 68 | "params": self.params, 69 | "decay": self.decay, 70 | } 71 | ) 72 | -------------------------------------------------------------------------------- /scripts/ensemble_comparison.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compare two ensembles of the same proteins using the ensemble comparison 3 | and analysis scores used in the aSAM article (see Table 1). Specifically: 4 | (i) PCC Ca RMSF, 5 | (ii) chiJSD, 6 | (iii) heavy clashes, 7 | (iv) peptide bond length violations. 8 | """ 9 | 10 | import os 11 | import json 12 | import pathlib 13 | import argparse 14 | import numpy as np 15 | import mdtraj 16 | from sam.data.topology import slice_ca_traj 17 | from sam.evaluation.scores import score_pcc_ca_rmsf, score_chiJSD, mstats_stereo 18 | 19 | 20 | if __name__ == "__main__": 21 | 22 | parser = argparse.ArgumentParser(description=__doc__) 23 | parser.add_argument('-P', '--ref_top', type=str, required=True, 24 | help='topology file of the reference ensemble' 25 | ) 26 | parser.add_argument('-T', '--ref_traj', type=str, required=True, 27 | help='trajectory file of the reference ensemble' 28 | ) 29 | parser.add_argument('-p', '--hat_top', type=str, required=True, 30 | help='topology file of the proposed ensemble' 31 | ) 32 | parser.add_argument('-t', '--hat_traj', type=str, required=True, 33 | help='trajectory file of the proposed ensemble' 34 | ) 35 | parser.add_argument('-i', '--init_pdb', type=str, required=True, 36 | help='PDB file of some initial (or reference) structure' 37 | ) 38 | args = parser.parse_args() 39 | 40 | # Will store the results here. 41 | json_data = {} 42 | 43 | # Load the heavy atom (ha) ensembles. 44 | ref_ha_traj = mdtraj.load(args.ref_traj, top=args.ref_top) 45 | hat_ha_traj = mdtraj.load(args.hat_traj, top=args.hat_top) 46 | ini_ha_traj = mdtraj.load(args.init_pdb) 47 | # Create the Ca ensembles. 
48 | ref_ca_traj = slice_ca_traj(ref_ha_traj) 49 | hat_ca_traj = slice_ca_traj(hat_ha_traj) 50 | ini_ca_traj = slice_ca_traj(ini_ha_traj) 51 | 52 | # Score PCC Ca RMSF. 53 | pcc_ca_rmsf = score_pcc_ca_rmsf( 54 | ref_ca_traj=ref_ca_traj, 55 | hat_ca_traj=hat_ca_traj, 56 | ini_ca_traj=ini_ca_traj 57 | ) 58 | json_data["pcc_ca_rmsf"] = float(pcc_ca_rmsf) 59 | 60 | # Score chiJSD. 61 | chi_jsd = score_chiJSD( 62 | ref_traj=ref_ha_traj, hat_traj=hat_ha_traj 63 | ) 64 | json_data["chi_jsd"] = float(chi_jsd) 65 | 66 | # Heavy clashes and peptide bond length violations. 67 | stats = mstats_stereo(ref_ha_traj) 68 | json_data[f"ref_heavy_clash"] = float(np.mean(stats["heavy_clash_ha"])) 69 | json_data[f"ref_viol_c_n"] = float(np.mean(stats["viol_c_n"])) 70 | stats = mstats_stereo(hat_ha_traj) 71 | json_data[f"hat_heavy_clash"] = float(np.mean(stats["heavy_clash_ha"])) 72 | json_data[f"hat_viol_c_n"] = float(np.mean(stats["viol_c_n"])) 73 | 74 | # Print results. 75 | print(json_data) -------------------------------------------------------------------------------- /sam/openfold/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | import operator 16 | import time 17 | 18 | import dllogger as logger 19 | from dllogger import JSONStreamBackend, StdOutBackend, Verbosity 20 | import numpy as np 21 | from pytorch_lightning import Callback 22 | import torch.cuda.profiler as profiler 23 | 24 | 25 | def is_main_process(): 26 | return int(os.getenv("LOCAL_RANK", "0")) == 0 27 | 28 | 29 | class PerformanceLoggingCallback(Callback): 30 | def __init__(self, log_file, global_batch_size, warmup_steps: int = 0, profile: bool = False): 31 | logger.init(backends=[JSONStreamBackend(Verbosity.VERBOSE, log_file), StdOutBackend(Verbosity.VERBOSE)]) 32 | self.warmup_steps = warmup_steps 33 | self.global_batch_size = global_batch_size 34 | self.step = 0 35 | self.profile = profile 36 | self.timestamps = [] 37 | 38 | def do_step(self): 39 | self.step += 1 40 | if self.profile and self.step == self.warmup_steps: 41 | profiler.start() 42 | if self.step > self.warmup_steps: 43 | self.timestamps.append(time.time()) 44 | 45 | def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): 46 | self.do_step() 47 | 48 | def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): 49 | self.do_step() 50 | 51 | def process_performance_stats(self, deltas): 52 | def _round3(val): 53 | return round(val, 3) 54 | 55 | throughput_imgps = _round3(self.global_batch_size / np.mean(deltas)) 56 | timestamps_ms = 1000 * deltas 57 | stats = { 58 | f"throughput": throughput_imgps, 59 | f"latency_mean": _round3(timestamps_ms.mean()), 60 | } 61 | for level in [90, 95, 99]: 62 | stats.update({f"latency_{level}": _round3(np.percentile(timestamps_ms, level))}) 63 | 64 | return stats 65 | 66 | def _log(self): 67 | if is_main_process(): 68 | diffs = list(map(operator.sub, self.timestamps[1:], self.timestamps[:-1])) 69 | deltas = np.array(diffs) 70 | stats = self.process_performance_stats(deltas) 71 | logger.log(step=(), data=stats) 72 | logger.flush() 73 | 74 | def on_train_end(self, trainer, pl_module): 75 | if self.profile: 76 | profiler.stop() 77 | self._log() 78 | 79 | def on_epoch_end(self, trainer, pl_module): 80 | self._log() 81 | -------------------------------------------------------------------------------- /sam/openfold/utils/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler): 5 | """ Implements the learning rate schedule defined in the AlphaFold 2 6 | supplement. A linear warmup is followed by a plateau at the maximum 7 | learning rate and then exponential decay. 8 | 9 | Note that the initial learning rate of the optimizer in question is 10 | ignored; use this class' base_lr parameter to specify the starting 11 | point of the warmup. 
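For example, with the defaults (base_lr=0.0, max_lr=0.001,
warmup_no_steps=1000, start_decay_after_n_steps=50000,
decay_every_n_steps=50000, decay_factor=0.95): step 500 gives
lr = 0.0005, steps 1000 through 50000 hold lr = 0.001, and step 60000
gives lr = 0.001 * 0.95 = 0.00095.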
12 | """ 13 | def __init__(self, 14 | optimizer, 15 | last_epoch: int = -1, 16 | verbose: bool = False, 17 | base_lr: float = 0., 18 | max_lr: float = 0.001, 19 | warmup_no_steps: int = 1000, 20 | start_decay_after_n_steps: int = 50000, 21 | decay_every_n_steps: int = 50000, 22 | decay_factor: float = 0.95, 23 | ): 24 | step_counts = { 25 | "warmup_no_steps": warmup_no_steps, 26 | "start_decay_after_n_steps": start_decay_after_n_steps, 27 | } 28 | 29 | for k,v in step_counts.items(): 30 | if(v < 0): 31 | raise ValueError(f"{k} must be nonnegative") 32 | 33 | if(warmup_no_steps > start_decay_after_n_steps): 34 | raise ValueError( 35 | "warmup_no_steps must not exceed start_decay_after_n_steps" 36 | ) 37 | 38 | self.optimizer = optimizer 39 | self.last_epoch = last_epoch 40 | self.verbose = verbose 41 | self.base_lr = base_lr 42 | self.max_lr = max_lr 43 | self.warmup_no_steps = warmup_no_steps 44 | self.start_decay_after_n_steps = start_decay_after_n_steps 45 | self.decay_every_n_steps = decay_every_n_steps 46 | self.decay_factor = decay_factor 47 | 48 | super(AlphaFoldLRScheduler, self).__init__( 49 | optimizer, 50 | last_epoch=last_epoch, 51 | verbose=verbose, 52 | ) 53 | 54 | def state_dict(self): 55 | state_dict = { 56 | k:v for k,v in self.__dict__.items() if k not in ["optimizer"] 57 | } 58 | 59 | return state_dict 60 | 61 | def load_state_dict(self, state_dict): 62 | self.__dict__.update(state_dict) 63 | 64 | def get_lr(self): 65 | if(not self._get_lr_called_within_step): 66 | raise RuntimeError( 67 | "To get the last learning rate computed by the scheduler, use " 68 | "get_last_lr()" 69 | ) 70 | 71 | step_no = self.last_epoch 72 | 73 | if(step_no <= self.warmup_no_steps): 74 | lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr 75 | elif(step_no > self.start_decay_after_n_steps): 76 | steps_since_decay = step_no - self.start_decay_after_n_steps 77 | exp = (steps_since_decay // self.decay_every_n_steps) + 1 78 | lr = self.max_lr * (self.decay_factor ** exp) 79 | else: # plateau 80 | lr = self.max_lr 81 | 82 | return [lr for group in self.optimizer.param_groups] 83 | -------------------------------------------------------------------------------- /sam/openfold/utils/checkpointing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import importlib 15 | from typing import Any, Tuple, List, Callable, Optional 16 | 17 | deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None 18 | if(deepspeed_is_installed): 19 | import deepspeed 20 | 21 | import torch 22 | import torch.utils.checkpoint 23 | 24 | 25 | BLOCK_ARG = Any 26 | BLOCK_ARGS = List[BLOCK_ARG] 27 | 28 | 29 | def get_checkpoint_fn(): 30 | deepspeed_is_configured = ( 31 | deepspeed_is_installed and 32 | deepspeed.checkpointing.is_configured() 33 | ) 34 | if(deepspeed_is_configured): 35 | checkpoint = deepspeed.checkpointing.checkpoint 36 | else: 37 | checkpoint = torch.utils.checkpoint.checkpoint 38 | 39 | return checkpoint 40 | 41 | 42 | @torch.jit.ignore 43 | def checkpoint_blocks( 44 | blocks: List[Callable], 45 | args: BLOCK_ARGS, 46 | blocks_per_ckpt: Optional[int], 47 | ) -> BLOCK_ARGS: 48 | """ 49 | Chunk a list of blocks and run each chunk with activation 50 | checkpointing. We define a "block" as a callable whose only inputs are 51 | the outputs of the previous block. 52 | 53 | Implements Subsection 1.11.8 54 | 55 | Args: 56 | blocks: 57 | List of blocks 58 | args: 59 | Tuple of arguments for the first block. 60 | blocks_per_ckpt: 61 | Size of each chunk. A higher value corresponds to fewer 62 | checkpoints, and trades memory for speed. If None, no checkpointing 63 | is performed. 64 | Returns: 65 | The output of the final block 66 | """ 67 | def wrap(a): 68 | return (a,) if type(a) is not tuple else a 69 | 70 | def exec(b, a): 71 | for block in b: 72 | a = wrap(block(*a)) 73 | return a 74 | 75 | def chunker(s, e): 76 | def exec_sliced(*a): 77 | return exec(blocks[s:e], a) 78 | 79 | return exec_sliced 80 | 81 | # Avoids mishaps when the blocks take just one argument 82 | args = wrap(args) 83 | 84 | if blocks_per_ckpt is None or not torch.is_grad_enabled(): 85 | return exec(blocks, args) 86 | elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks): 87 | raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)") 88 | 89 | checkpoint = get_checkpoint_fn() 90 | 91 | for s in range(0, len(blocks), blocks_per_ckpt): 92 | e = s + blocks_per_ckpt 93 | args = checkpoint(chunker(s, e), *args) 94 | args = wrap(args) 95 | 96 | return args 97 | -------------------------------------------------------------------------------- /sam/openfold/model/pair_transition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from typing import Optional 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from sam.openfold.model.primitives import Linear, LayerNorm 21 | from sam.openfold.utils.chunk_utils import chunk_layer 22 | 23 | 24 | class PairTransition(nn.Module): 25 | """ 26 | Implements Algorithm 15. 
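    The transition is a position-wise two-layer MLP over the pair embedding:
    the input is layer-normed, projected from c_z up to n * c_z channels,
    passed through a ReLU, projected back down to c_z, and multiplied by the
    mask, so a [*, N_res, N_res, c_z] input yields an update of the same
    shape.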
27 | """ 28 | 29 | def __init__(self, c_z, n): 30 | """ 31 | Args: 32 | c_z: 33 | Pair transition channel dimension 34 | n: 35 | Factor by which c_z is multiplied to obtain hidden channel 36 | dimension 37 | """ 38 | super(PairTransition, self).__init__() 39 | 40 | self.c_z = c_z 41 | self.n = n 42 | 43 | self.layer_norm = LayerNorm(self.c_z) 44 | self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu") 45 | self.relu = nn.ReLU() 46 | self.linear_2 = Linear(self.n * self.c_z, c_z, init="final") 47 | 48 | def _transition(self, z, mask): 49 | # [*, N_res, N_res, C_z] 50 | z = self.layer_norm(z) 51 | 52 | # [*, N_res, N_res, C_hidden] 53 | z = self.linear_1(z) 54 | z = self.relu(z) 55 | 56 | # [*, N_res, N_res, C_z] 57 | z = self.linear_2(z) 58 | z = z * mask 59 | 60 | return z 61 | 62 | @torch.jit.ignore 63 | def _chunk(self, 64 | z: torch.Tensor, 65 | mask: torch.Tensor, 66 | chunk_size: int, 67 | ) -> torch.Tensor: 68 | return chunk_layer( 69 | self._transition, 70 | {"z": z, "mask": mask}, 71 | chunk_size=chunk_size, 72 | no_batch_dims=len(z.shape[:-2]), 73 | ) 74 | 75 | def forward(self, 76 | z: torch.Tensor, 77 | mask: Optional[torch.Tensor] = None, 78 | chunk_size: Optional[int] = None, 79 | ) -> torch.Tensor: 80 | """ 81 | Args: 82 | z: 83 | [*, N_res, N_res, C_z] pair embedding 84 | Returns: 85 | [*, N_res, N_res, C_z] pair embedding update 86 | """ 87 | # DISCREPANCY: DeepMind forgets to apply the mask in this module. 88 | if mask is None: 89 | mask = z.new_ones(z.shape[:-1]) 90 | 91 | # [*, N_res, N_res, 1] 92 | mask = mask.unsqueeze(-1) 93 | 94 | if chunk_size is not None: 95 | z = self._chunk(z, mask, chunk_size) 96 | else: 97 | z = self._transition(z=z, mask=mask) 98 | 99 | return z 100 | -------------------------------------------------------------------------------- /sam/openfold/data/msa_identifiers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for extracting identifiers from MSA sequence descriptions.""" 16 | 17 | import dataclasses 18 | import re 19 | from typing import Optional 20 | 21 | 22 | # Sequences coming from UniProtKB database come in the 23 | # `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE` 24 | # or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively). 25 | _UNIPROT_PATTERN = re.compile( 26 | r""" 27 | ^ 28 | # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot 29 | (?:tr|sp) 30 | \| 31 | # A primary accession number of the UniProtKB entry. 32 | (?P[A-Za-z0-9]{6,10}) 33 | # Occasionally there is a _0 or _1 isoform suffix, which we ignore. 34 | (?:_\d)? 35 | \| 36 | # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic 37 | # protein ID code. 38 | (?:[A-Za-z0-9]+) 39 | _ 40 | # A mnemonic species identification code. 41 | (?P([A-Za-z0-9]){1,5}) 42 | # Small BFD uses a final value after an underscore, which we ignore. 
43 | (?:_\d+)? 44 | $ 45 | """, 46 | re.VERBOSE) 47 | 48 | 49 | @dataclasses.dataclass(frozen=True) 50 | class Identifiers: 51 | species_id: str = '' 52 | 53 | 54 | def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: 55 | """Gets accession id and species from an msa sequence identifier. 56 | 57 | The sequence identifier has the format specified by 58 | _UNIPROT_PATTERN. 59 | An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE` 60 | 61 | Args: 62 | msa_sequence_identifier: a sequence identifier. 63 | 64 | Returns: 65 | An `Identifiers` instance with a species_id. This can be empty in the 66 | case where no identifier was found. 67 | """ 68 | matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip()) 69 | if matches: 70 | return Identifiers( 71 | species_id=matches.group('SpeciesIdentifier') 72 | ) 73 | return Identifiers() 74 | 75 | 76 | def _extract_sequence_identifier(description: str) -> Optional[str]: 77 | """Extracts sequence identifier from description. Returns None if no match.""" 78 | split_description = description.split() 79 | if split_description: 80 | return split_description[0].partition('/')[0] 81 | else: 82 | return None 83 | 84 | 85 | def get_identifiers(description: str) -> Identifiers: 86 | """Computes extra MSA features from the description.""" 87 | sequence_identifier = _extract_sequence_identifier(description) 88 | if sequence_identifier is None: 89 | return Identifiers() 90 | else: 91 | return _parse_sequence_identifier(sequence_identifier) 92 | -------------------------------------------------------------------------------- /sam/openfold/np/relax/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Utils for minimization.""" 17 | import io 18 | from sam.openfold.np import residue_constants 19 | from Bio import PDB 20 | import numpy as np 21 | from openmm import app as openmm_app 22 | from openmm.app.internal.pdbstructure import PdbStructure 23 | 24 | 25 | def overwrite_pdb_coordinates(pdb_str: str, pos) -> str: 26 | pdb_file = io.StringIO(pdb_str) 27 | structure = PdbStructure(pdb_file) 28 | topology = openmm_app.PDBFile(structure).getTopology() 29 | with io.StringIO() as f: 30 | openmm_app.PDBFile.writeFile(topology, pos, f) 31 | return f.getvalue() 32 | 33 | 34 | def overwrite_b_factors(pdb_str: str, bfactors: np.ndarray) -> str: 35 | """Overwrites the B-factors in pdb_str with contents of bfactors array. 36 | 37 | Args: 38 | pdb_str: An input PDB string. 39 | bfactors: A numpy array with shape [n_residues, 37]. We assume that the 40 | B-factors are per residue; i.e. that the nonzero entries are identical in 41 | [i, :]. 42 | 43 | Returns: 44 | A new PDB string with the B-factors replaced.
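    A sketch of a hypothetical caller: to write per-residue pLDDT values
    `plddt` of shape [n_residues] into the B-factor column, one could call
    overwrite_b_factors(pdb_str, np.tile(plddt[:, None], (1, 37))).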
45 | """ 46 | if bfactors.shape[-1] != residue_constants.atom_type_num: 47 | raise ValueError( 48 | f"Invalid final dimension size for bfactors: {bfactors.shape[-1]}." 49 | ) 50 | 51 | parser = PDB.PDBParser(QUIET=True) 52 | handle = io.StringIO(pdb_str) 53 | structure = parser.get_structure("", handle) 54 | 55 | curr_resid = ("", "", "") 56 | idx = -1 57 | for atom in structure.get_atoms(): 58 | atom_resid = atom.parent.get_id() 59 | if atom_resid != curr_resid: 60 | idx += 1 61 | if idx >= bfactors.shape[0]: 62 | raise ValueError( 63 | "Index into bfactors exceeds number of residues. " 64 | "B-factors shape: {shape}, idx: {idx}." 65 | ) 66 | curr_resid = atom_resid 67 | atom.bfactor = bfactors[idx, residue_constants.atom_order["CA"]] 68 | 69 | new_pdb = io.StringIO() 70 | pdb_io = PDB.PDBIO() 71 | pdb_io.set_structure(structure) 72 | pdb_io.save(new_pdb) 73 | return new_pdb.getvalue() 74 | 75 | 76 | def assert_equal_nonterminal_atom_types( 77 | atom_mask: np.ndarray, ref_atom_mask: np.ndarray 78 | ): 79 | """Checks that pre- and post-minimized proteins have same atom set.""" 80 | # Ignore any terminal OXT atoms which may have been added by minimization. 81 | oxt = residue_constants.atom_order["OXT"] 82 | no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool) 83 | no_oxt_mask[..., oxt] = False 84 | np.testing.assert_almost_equal( 85 | ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask] 86 | ) 87 | -------------------------------------------------------------------------------- /sam/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import requests 4 | import yaml 5 | import zipfile 6 | 7 | 8 | def read_cfg_file(cfg_fp): 9 | if cfg_fp.endswith(".json"): 10 | with open(cfg_fp, "r") as i_fh: 11 | model_cfg = json.load(i_fh) 12 | elif cfg_fp.endswith(".yaml"): 13 | with open(cfg_fp, 'r') as i_fh: 14 | model_cfg = yaml.safe_load(i_fh) 15 | else: 16 | raise TypeError( 17 | f"Invalid extension for configuration file: {cfg_fp}. 
Must be a" 18 | " json or yaml file.") 19 | return model_cfg 20 | 21 | def print_msg(msg, verbose=True, tag="verbose"): 22 | if verbose: 23 | print(f"[{tag}]:", msg) 24 | 25 | def to_json_compatible(report): 26 | _report = {} 27 | for k in report: 28 | if isinstance(report[k], (str, int, float)): 29 | _report[k] = report[k] 30 | else: 31 | try: 32 | _report[k] = float(report[k]) 33 | except TypeError: 34 | _report[k] = "ERROR" 35 | return _report 36 | 37 | 38 | github_releases_url = "https://github.com/giacomo-janson/sam2/releases/download/data-1.0" 39 | 40 | def check_sam_weights(cfg_path: str, verbose: bool = True): 41 | model_cfg = read_cfg_file(cfg_path) 42 | if model_cfg["weights"]["path"] is None: 43 | download_path = download_sam_weights(model_cfg["weights"]["version"]) 44 | model_cfg["weights"]["path"] = download_path 45 | with open(cfg_path, "w") as o_fh: 46 | yaml.dump(model_cfg, o_fh) 47 | else: 48 | if not os.path.isdir(model_cfg["weights"]["path"]): 49 | raise FileNotFoundError( 50 | "Weights directory not found at: {}".format( 51 | model_cfg["weights"]["path"] 52 | ) 53 | ) 54 | 55 | def download_sam_weights(version: str, verbose: bool = True): 56 | filename = f"{version}.zip" 57 | url = github_releases_url + "/" + filename 58 | 59 | if os.getenv("SAM_WEIGHTS_PATH") is None: 60 | download_path = os.path.expanduser("~/.sam2/weights") 61 | else: 62 | download_path = os.getenv("SAM_WEIGHTS_PATH") 63 | os.makedirs(download_path, exist_ok=True) 64 | 65 | print_msg( 66 | f"# No aSAM weights were detected, beginning download now.", 67 | verbose=verbose, tag="download" 68 | ) 69 | print_msg( 70 | f"- Downloading aSAM weights from: {url}", 71 | verbose=verbose, tag="download" 72 | ) 73 | print_msg( 74 | f"- Weights will be saved at: {download_path}", 75 | verbose=verbose, tag="download" 76 | ) 77 | res = requests.get(url) 78 | if res.status_code != 200: 79 | raise OSError( 80 | f"unable to download file (status code {res.status_code})." 
81 | ) 82 | save_path = os.path.join(download_path, filename) 83 | with open(save_path, "wb") as f: 84 | f.write(res.content) 85 | print_msg("- Download completed.", verbose=verbose, tag="download") 86 | print_msg("- Unzipping weight files.", verbose=verbose, tag="download") 87 | with zipfile.ZipFile(save_path, 'r') as zip_ref: 88 | zip_ref.extractall(download_path) 89 | os.remove(save_path) 90 | print_msg(f"- Weights are now ready.", verbose=verbose, tag="download") 91 | print_msg( 92 | "- Will update the input .yaml configuration file.", 93 | verbose=verbose, tag="download" 94 | ) 95 | return os.path.join(download_path, version) 96 | -------------------------------------------------------------------------------- /data/splits/mdcath/val.txt: -------------------------------------------------------------------------------- 1 | 1k12A00.320 2 | 1k12A00.348 3 | 1k12A00.379 4 | 1k12A00.413 5 | 1k12A00.450 6 | 4c0sB03.320 7 | 4c0sB03.348 8 | 4c0sB03.379 9 | 4c0sB03.413 10 | 4c0sB03.450 11 | 3rfyA02.320 12 | 3rfyA02.348 13 | 3rfyA02.379 14 | 3rfyA02.413 15 | 3rfyA02.450 16 | 4kcaA03.320 17 | 4kcaA03.348 18 | 4kcaA03.379 19 | 4kcaA03.413 20 | 4kcaA03.450 21 | 4zy7A00.320 22 | 4zy7A00.348 23 | 4zy7A00.379 24 | 4zy7A00.413 25 | 4zy7A00.450 26 | 1zvuA02.320 27 | 1zvuA02.348 28 | 1zvuA02.379 29 | 1zvuA02.413 30 | 1zvuA02.450 31 | 2yinA03.320 32 | 2yinA03.348 33 | 2yinA03.379 34 | 2yinA03.413 35 | 2yinA03.450 36 | 2ac1A02.320 37 | 2ac1A02.348 38 | 2ac1A02.379 39 | 2ac1A02.413 40 | 2ac1A02.450 41 | 1jsuC00.320 42 | 1jsuC00.348 43 | 1jsuC00.379 44 | 1jsuC00.413 45 | 1jsuC00.450 46 | 1x9mA01.320 47 | 1x9mA01.348 48 | 1x9mA01.379 49 | 1x9mA01.413 50 | 1x9mA01.450 51 | 2qz5A00.320 52 | 2qz5A00.348 53 | 2qz5A00.379 54 | 2qz5A00.413 55 | 2qz5A00.450 56 | 3fa9A00.320 57 | 3fa9A00.348 58 | 3fa9A00.379 59 | 3fa9A00.413 60 | 3fa9A00.450 61 | 2bvbA00.320 62 | 2bvbA00.348 63 | 2bvbA00.379 64 | 2bvbA00.413 65 | 2bvbA00.450 66 | 1vs5T00.320 67 | 1vs5T00.348 68 | 1vs5T00.379 69 | 1vs5T00.413 70 | 1vs5T00.450 71 | 1cpyA02.320 72 | 1cpyA02.348 73 | 1cpyA02.379 74 | 1cpyA02.413 75 | 1cpyA02.450 76 | 4g1iA02.320 77 | 4g1iA02.348 78 | 4g1iA02.379 79 | 4g1iA02.413 80 | 4g1iA02.450 81 | 1k8bA00.320 82 | 1k8bA00.348 83 | 1k8bA00.379 84 | 1k8bA00.413 85 | 1k8bA00.450 86 | 2vztA03.320 87 | 2vztA03.348 88 | 2vztA03.379 89 | 2vztA03.413 90 | 2vztA03.450 91 | 1q7lA00.320 92 | 1q7lA00.348 93 | 1q7lA00.379 94 | 1q7lA00.413 95 | 1q7lA00.450 96 | 3mq2A00.320 97 | 3mq2A00.348 98 | 3mq2A00.379 99 | 3mq2A00.413 100 | 3mq2A00.450 101 | 2h41A00.320 102 | 2h41A00.348 103 | 2h41A00.379 104 | 2h41A00.413 105 | 2h41A00.450 106 | 2nmsA00.320 107 | 2nmsA00.348 108 | 2nmsA00.379 109 | 2nmsA00.413 110 | 2nmsA00.450 111 | 2zyrA02.320 112 | 2zyrA02.348 113 | 2zyrA02.379 114 | 2zyrA02.413 115 | 2zyrA02.450 116 | 3pmgA02.320 117 | 3pmgA02.348 118 | 3pmgA02.379 119 | 3pmgA02.413 120 | 3pmgA02.450 121 | 2bbrA01.320 122 | 2bbrA01.348 123 | 2bbrA01.379 124 | 2bbrA01.413 125 | 2bbrA01.450 126 | 4wfoA04.320 127 | 4wfoA04.348 128 | 4wfoA04.379 129 | 4wfoA04.413 130 | 4wfoA04.450 131 | 1iq8A03.320 132 | 1iq8A03.348 133 | 1iq8A03.379 134 | 1iq8A03.413 135 | 1iq8A03.450 136 | 1z2zA02.320 137 | 1z2zA02.348 138 | 1z2zA02.379 139 | 1z2zA02.413 140 | 1z2zA02.450 141 | 5f64A02.320 142 | 5f64A02.348 143 | 5f64A02.379 144 | 5f64A02.413 145 | 5f64A02.450 146 | 1b3qA04.320 147 | 1b3qA04.348 148 | 1b3qA04.379 149 | 1b3qA04.413 150 | 1b3qA04.450 151 | 2c81A01.320 152 | 2c81A01.348 153 | 2c81A01.379 154 | 2c81A01.413 155 | 2c81A01.450 156 | 2wcyA01.320 157 | 2wcyA01.348 
158 | 2wcyA01.379 159 | 2wcyA01.413 160 | 2wcyA01.450 161 | 4gipD02.320 162 | 4gipD02.348 163 | 4gipD02.379 164 | 4gipD02.413 165 | 4gipD02.450 166 | 2nc9A00.320 167 | 2nc9A00.348 168 | 2nc9A00.379 169 | 2nc9A00.413 170 | 2nc9A00.450 171 | 2pmzN00.320 172 | 2pmzN00.348 173 | 2pmzN00.379 174 | 2pmzN00.413 175 | 2pmzN00.450 176 | 4impA02.320 177 | 4impA02.348 178 | 4impA02.379 179 | 4impA02.413 180 | 4impA02.450 181 | 4c12A01.320 182 | 4c12A01.348 183 | 4c12A01.379 184 | 4c12A01.413 185 | 4c12A01.450 186 | 2mh9A00.320 187 | 2mh9A00.348 188 | 2mh9A00.379 189 | 2mh9A00.413 190 | 2mh9A00.450 191 | 2yo2A01.320 192 | 2yo2A01.348 193 | 2yo2A01.379 194 | 2yo2A01.413 195 | 2yo2A01.450 196 | 1bihA02.320 197 | 1bihA02.348 198 | 1bihA02.379 199 | 1bihA02.413 200 | 1bihA02.450 201 | -------------------------------------------------------------------------------- /sam/openfold/utils/kernel/attention_core.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import importlib 15 | from functools import reduce 16 | from operator import mul 17 | 18 | import torch 19 | 20 | attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda") 21 | 22 | 23 | SUPPORTED_DTYPES = [torch.float32, torch.bfloat16] 24 | 25 | 26 | class AttentionCoreFunction(torch.autograd.Function): 27 | @staticmethod 28 | def forward(ctx, q, k, v, bias_1=None, bias_2=None): 29 | if(bias_1 is None and bias_2 is not None): 30 | raise ValueError("bias_1 must be specified before bias_2") 31 | if(q.dtype not in SUPPORTED_DTYPES): 32 | raise ValueError("Unsupported datatype") 33 | 34 | q = q.contiguous() 35 | k = k.contiguous() 36 | 37 | # [*, H, Q, K] 38 | attention_logits = torch.matmul( 39 | q, k.transpose(-1, -2), 40 | ) 41 | 42 | if(bias_1 is not None): 43 | attention_logits += bias_1 44 | if(bias_2 is not None): 45 | attention_logits += bias_2 46 | 47 | attn_core_inplace_cuda.forward_( 48 | attention_logits, 49 | reduce(mul, attention_logits.shape[:-1]), 50 | attention_logits.shape[-1], 51 | ) 52 | 53 | o = torch.matmul(attention_logits, v) 54 | 55 | ctx.bias_1_shape = bias_1.shape if bias_1 is not None else None 56 | ctx.bias_2_shape = bias_2.shape if bias_2 is not None else None 57 | ctx.save_for_backward(q, k, v, attention_logits) 58 | 59 | return o 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | q, k, v, attention_logits = ctx.saved_tensors 64 | grad_q = grad_k = grad_v = grad_bias_1 = grad_bias_2 = None 65 | 66 | grad_v = torch.matmul( 67 | attention_logits.transpose(-1, -2), 68 | grad_output 69 | ) 70 | 71 | attn_core_inplace_cuda.backward_( 72 | attention_logits, 73 | grad_output.contiguous(), 74 | v.contiguous(), # v is implicitly transposed in the kernel 75 | reduce(mul, attention_logits.shape[:-1]), 76 | attention_logits.shape[-1], 77 | grad_output.shape[-1], 78 | ) 79 | 80 | if(ctx.bias_1_shape is not None): 81 | grad_bias_1 = torch.sum( 
82 | attention_logits, 83 | dim=tuple(i for i,d in enumerate(ctx.bias_1_shape) if d == 1), 84 | keepdim=True, 85 | ) 86 | 87 | if(ctx.bias_2_shape is not None): 88 | grad_bias_2 = torch.sum( 89 | attention_logits, 90 | dim=tuple(i for i,d in enumerate(ctx.bias_2_shape) if d == 1), 91 | keepdim=True, 92 | ) 93 | 94 | grad_q = torch.matmul( 95 | attention_logits, k 96 | ) 97 | grad_k = torch.matmul( 98 | q.transpose(-1, -2), attention_logits, 99 | ).transpose(-1, -2) 100 | 101 | return grad_q, grad_k, grad_v, grad_bias_1, grad_bias_2 102 | 103 | attention_core = AttentionCoreFunction.apply 104 | -------------------------------------------------------------------------------- /sam/openfold/utils/superimposition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from Bio.SVDSuperimposer import SVDSuperimposer 15 | import numpy as np 16 | import torch 17 | 18 | 19 | def _superimpose_np(reference, coords): 20 | """ 21 | Superimposes coordinates onto a reference by minimizing RMSD using SVD. 22 | 23 | Args: 24 | reference: 25 | [N, 3] reference array 26 | coords: 27 | [N, 3] array 28 | Returns: 29 | A tuple of [N, 3] superimposed coords and the final RMSD. 30 | """ 31 | sup = SVDSuperimposer() 32 | sup.set(reference, coords) 33 | sup.run() 34 | return sup.get_transformed(), sup.get_rms() 35 | 36 | 37 | def _superimpose_single(reference, coords): 38 | reference_np = reference.detach().to(torch.float).cpu().numpy() 39 | coords_np = coords.detach().to(torch.float).cpu().numpy() 40 | superimposed, rmsd = _superimpose_np(reference_np, coords_np) 41 | return coords.new_tensor(superimposed), coords.new_tensor(rmsd) 42 | 43 | 44 | def superimpose(reference, coords, mask): 45 | """ 46 | Superimposes coordinates onto a reference by minimizing RMSD using SVD. 47 | 48 | Args: 49 | reference: 50 | [*, N, 3] reference tensor 51 | coords: 52 | [*, N, 3] tensor 53 | mask: 54 | [*, N] tensor 55 | Returns: 56 | A tuple of [*, N, 3] superimposed coords and [*] final RMSDs. 57 | """ 58 | def select_unmasked_coords(coords, mask): 59 | return torch.masked_select( 60 | coords, 61 | (mask > 0.)[..., None], 62 | ).reshape(-1, 3) 63 | 64 | batch_dims = reference.shape[:-2] 65 | flat_reference = reference.reshape((-1,) + reference.shape[-2:]) 66 | flat_coords = coords.reshape((-1,) + reference.shape[-2:]) 67 | flat_mask = mask.reshape((-1,) + mask.shape[-1:]) 68 | superimposed_list = [] 69 | rmsds = [] 70 | for r, c, m in zip(flat_reference, flat_coords, flat_mask): 71 | r_unmasked_coords = select_unmasked_coords(r, m) 72 | c_unmasked_coords = select_unmasked_coords(c, m) 73 | superimposed, rmsd = _superimpose_single( 74 | r_unmasked_coords, 75 | c_unmasked_coords 76 | ) 77 | 78 | # This is very inelegant, but idk how else to invert the masking 79 | # procedure. 
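        # A vectorized equivalent (sketch, not used here): the loop below
        # scatters the superimposed coordinates back into the unmasked rows,
        # which could also be written as
        #     superimposed_full_size = torch.zeros_like(r)
        #     superimposed_full_size[m > 0.] = superimposed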
80 | count = 0 81 | superimposed_full_size = torch.zeros_like(r) 82 | for i, unmasked in enumerate(m): 83 | if(unmasked): 84 | superimposed_full_size[i] = superimposed[count] 85 | count += 1 86 | 87 | superimposed_list.append(superimposed_full_size) 88 | rmsds.append(rmsd) 89 | 90 | superimposed_stacked = torch.stack(superimposed_list, dim=0) 91 | rmsds_stacked = torch.stack(rmsds, dim=0) 92 | 93 | superimposed_reshaped = superimposed_stacked.reshape( 94 | batch_dims + coords.shape[-2:] 95 | ) 96 | rmsds_reshaped = rmsds_stacked.reshape( 97 | batch_dims 98 | ) 99 | 100 | return superimposed_reshaped, rmsds_reshaped 101 | -------------------------------------------------------------------------------- /config/mdcath_model.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | res_ids_mode: index 3 | tbm: 4 | mode: single 5 | perturb: null 6 | type: xyz 7 | use_temperature: true 8 | decoder: 9 | accessory_activation: relu 10 | activation: relu 11 | add_bias_2d: true 12 | arch: dec_aa_v01 13 | bead_embed_dim: 32 14 | block_transition: null 15 | edge_dim: 128 16 | embed_inject_mode: outer_sum 17 | linear_bias: true 18 | mlp_dim: null 19 | node_dim: 256 20 | node_init_mode: mlp 21 | noise_sigma: null 22 | num_blocks: 0 23 | num_heads: 16 24 | pos_embed_r: 32 25 | sm_c_ipa: 16 26 | sm_c_resnet: 128 27 | sm_dropout_rate: 0.0 28 | sm_no_angles: 7 29 | sm_no_blocks: 5 30 | sm_no_heads_ipa: 12 31 | sm_no_qk_points: 4 32 | sm_no_resnet_blocks: 2 33 | sm_no_transition_layers: 1 34 | sm_no_v_points: 8 35 | sm_share: true 36 | sm_swiglu_transition: false 37 | sm_swiglu_transition_hr: 2 38 | sm_trans_scale_factor: 10 39 | encoder: 40 | accessory_activation: silu 41 | activation: silu 42 | add_bias_2d: true 43 | arch: enc_aa_v01 44 | bead_embed_dim: 32 45 | com_dmap_embed_params: 46 | cutoff_lower: 0.0 47 | cutoff_upper: 7.0 48 | num_rbf: 64 49 | trainable: true 50 | type: expnorm 51 | dmap_embed_params: 52 | cutoff_lower: 0.0 53 | cutoff_upper: 10.0 54 | num_rbf: 128 55 | trainable: true 56 | type: expnorm 57 | dmap_inject_mode: shallow 58 | dmap_merge_dim: 192 59 | dmap_merge_mode: cat_shallow 60 | edge_dim: 128 61 | edge_residual: false 62 | input_embed_params: 63 | label_terminus: true 64 | local_pos: true 65 | local_pos_span: 3 66 | linear_bias: true 67 | mlp_dim: 512 68 | no_dmap_embed_params: 69 | cutoff_lower: 0.0 70 | cutoff_upper: 3.0 71 | num_rbf: 32 72 | trainable: true 73 | type: expnorm 74 | node_dim: 256 75 | node_init_mode: mlp 76 | node_residual: true 77 | node_update_addition: false 78 | num_blocks: 4 79 | num_heads: 16 80 | out_ln: false 81 | out_mode: simple 82 | pos_embed_r: 32 83 | generative_model: 84 | loss: l2 85 | sched_params: 86 | beta_end: 0.02 87 | beta_schedule: sigmoid 88 | beta_start: 0.0001 89 | name: ddpm 90 | num_train_timesteps: 1000 91 | prediction_type: epsilon 92 | variance_type: fixed_small 93 | type: diffusers_dm 94 | generative_stack: 95 | bead_type: ca 96 | data_type: aa_protein 97 | encoding_dim: 32 98 | use_enc_scaler: true 99 | generator: 100 | activation: silu 101 | arch: eps_v02 102 | attention_mode: adanorm 103 | bead_embed_dim: 32 104 | conditioned_transition: false 105 | edge_dim: 144 106 | edge_embed_mode: idpsam 107 | edge_residual: false 108 | edge_update_addition: false 109 | edge_update_freq: null 110 | edge_update_mode: null 111 | edge_update_params: null 112 | generator: null 113 | input_embed_mode: mlp 114 | linear_bias: true 115 | node_dim: 512 116 | node_embed_mode: cat_mlp 117 | num_blocks: 22 118 
| num_heads: 32 119 | out_mode: mlp 120 | pos_embed_r: 32 121 | tem_inject_mode: xyz 122 | tem_inject_params: 123 | dmap_embed_params: 124 | cutoff_lower: 0.0 125 | cutoff_upper: 14.0 126 | num_rbf: 144 127 | trainable: true 128 | type: expnorm 129 | inject_edge_mode: null 130 | inject_node_mode: add 131 | node_angle_bins: 16 132 | node_angle_mask: extra 133 | node_dim: 512 134 | node_embed_resolution: aa 135 | node_mlp_depth: 2 136 | node_mlp_mult: 1 137 | temperature_embed_dim: 256 138 | temperature_embed_mode: scaler 139 | temperature_embed_params: 140 | scaler: 141 | max: 470 142 | min: 280 143 | time_embed_dim: 512 144 | time_embed_mode: sinusoidal 145 | time_embed_params: 146 | time_freq_dim: 256 147 | token_dim: 1024 148 | token_residual: false 149 | token_update_addition: false 150 | minimization: 151 | protocol: mdcath 152 | platform: 153 | device: cuda 154 | weights: 155 | path: null 156 | version: mdcath_1.0 157 | -------------------------------------------------------------------------------- /config/atlas_model.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | res_ids_mode: index 3 | tbm: 4 | mode: single 5 | perturb: null 6 | type: xyz 7 | decoder: 8 | accessory_activation: relu 9 | activation: relu 10 | add_bias_2d: true 11 | arch: dec_aa_v01 12 | bead_embed_dim: 32 13 | block_transition: null 14 | edge_dim: 128 15 | embed_inject_mode: outer_sum 16 | linear_bias: true 17 | mlp_dim: null 18 | node_dim: 256 19 | node_init_mode: mlp 20 | noise_sigma: null 21 | num_blocks: 0 22 | num_heads: 16 23 | pos_embed_r: 32 24 | sm_c_ipa: 16 25 | sm_c_resnet: 128 26 | sm_dropout_rate: 0.0 27 | sm_no_angles: 7 28 | sm_no_blocks: 5 29 | sm_no_heads_ipa: 12 30 | sm_no_qk_points: 4 31 | sm_no_resnet_blocks: 2 32 | sm_no_transition_layers: 1 33 | sm_no_v_points: 8 34 | sm_share: true 35 | sm_swiglu_transition: false 36 | sm_swiglu_transition_hr: 2 37 | sm_trans_scale_factor: 10 38 | encoder: 39 | accessory_activation: silu 40 | activation: silu 41 | add_bias_2d: true 42 | arch: enc_aa_v01 43 | bead_embed_dim: 32 44 | com_dmap_embed_params: 45 | cutoff_lower: 0.0 46 | cutoff_upper: 7.0 47 | num_rbf: 64 48 | trainable: true 49 | type: expnorm 50 | dmap_embed_params: 51 | cutoff_lower: 0.0 52 | cutoff_upper: 10.0 53 | num_rbf: 128 54 | trainable: true 55 | type: expnorm 56 | dmap_inject_mode: shallow 57 | dmap_merge_dim: 192 58 | dmap_merge_mode: cat_shallow 59 | edge_dim: 128 60 | edge_residual: false 61 | input_embed_params: 62 | label_terminus: true 63 | local_pos: true 64 | local_pos_span: 3 65 | linear_bias: true 66 | mlp_dim: 512 67 | no_dmap_embed_params: 68 | cutoff_lower: 0.0 69 | cutoff_upper: 3.0 70 | num_rbf: 32 71 | trainable: true 72 | type: expnorm 73 | node_dim: 256 74 | node_init_mode: mlp 75 | node_residual: true 76 | node_update_addition: false 77 | num_blocks: 4 78 | num_heads: 16 79 | out_ln: false 80 | out_mode: simple 81 | pos_embed_r: 32 82 | generative_model: 83 | loss: l2 84 | sched_params: 85 | beta_end: 0.02 86 | beta_schedule: sigmoid 87 | beta_start: 0.0001 88 | name: ddpm 89 | num_train_timesteps: 1000 90 | prediction_type: epsilon 91 | variance_type: fixed_small 92 | type: diffusers_dm 93 | generative_stack: 94 | bead_type: ca 95 | data_type: aa_protein 96 | encoding_dim: 32 97 | use_enc_scaler: true 98 | generator: 99 | activation: silu 100 | arch: eps_v02 101 | attention_mode: adanorm 102 | bead_embed_dim: 32 103 | conditioned_transition: false 104 | edge_dim: 128 105 | edge_embed_mode: idpsam 106 | 
edge_residual: false 107 | edge_update_addition: false 108 | edge_update_freq: null 109 | edge_update_mode: null 110 | edge_update_params: null 111 | input_embed_mode: mlp 112 | linear_bias: true 113 | node_dim: 512 114 | node_embed_mode: cat_mlp 115 | num_blocks: 22 116 | num_heads: 32 117 | out_mode: mlp 118 | pos_embed_r: 32 119 | tem_inject_mode: xyz 120 | tem_inject_params: 121 | com_dmap_embed_params: 122 | cutoff_lower: 0.0 123 | cutoff_upper: 7.0 124 | num_rbf: 64 125 | trainable: true 126 | type: expnorm 127 | dmap_embed_params: 128 | cutoff_lower: 0.0 129 | cutoff_upper: 14.0 130 | num_rbf: 144 131 | trainable: true 132 | type: expnorm 133 | inject_edge_mode: null 134 | inject_node_mode: add 135 | no_dmap_embed_params: 136 | cutoff_lower: 0.0 137 | cutoff_upper: 3.0 138 | num_rbf: 32 139 | trainable: true 140 | type: expnorm 141 | node_angle_bins: 16 142 | node_angle_mask: extra 143 | node_dim: 512 144 | node_embed_resolution: aa 145 | node_mlp_depth: 2 146 | node_mlp_mult: 1 147 | time_embed_dim: 512 148 | time_embed_mode: sinusoidal 149 | time_embed_params: 150 | time_freq_dim: 256 151 | token_dim: 1024 152 | token_residual: false 153 | token_update_addition: false 154 | minimization: 155 | protocol: atlas 156 | platform: 157 | device: cuda 158 | weights: 159 | path: null 160 | version: atlas_1.0 161 | -------------------------------------------------------------------------------- /sam/openfold/utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from functools import partial 17 | import logging 18 | from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | 24 | def add(m1, m2, inplace): 25 | # The first operation in a checkpoint can't be in-place, but it's 26 | # nice to have in-place addition during inference. Thus... 
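    # (Assumed rationale: under activation checkpointing the chunk's inputs
    # are saved and replayed in the backward pass, so mutating them in place
    # as the first op would corrupt the recomputation and trip autograd's
    # tensor-version checks.)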
27 | if(not inplace): 28 | m1 = m1 + m2 29 | else: 30 | m1 += m2 31 | 32 | return m1 33 | 34 | 35 | def permute_final_dims(tensor: torch.Tensor, inds: List[int]): 36 | zero_index = -1 * len(inds) 37 | first_inds = list(range(len(tensor.shape[:zero_index]))) 38 | return tensor.permute(first_inds + [zero_index + i for i in inds]) 39 | 40 | 41 | def flatten_final_dims(t: torch.Tensor, no_dims: int): 42 | return t.reshape(t.shape[:-no_dims] + (-1,)) 43 | 44 | 45 | def masked_mean(mask, value, dim, eps=1e-4): 46 | mask = mask.expand(*value.shape) 47 | return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) 48 | 49 | 50 | def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64): 51 | boundaries = torch.linspace( 52 | min_bin, max_bin, no_bins - 1, device=pts.device 53 | ) 54 | dists = torch.sqrt( 55 | torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1) 56 | ) 57 | return torch.bucketize(dists, boundaries) 58 | 59 | 60 | def dict_multimap(fn, dicts): 61 | first = dicts[0] 62 | new_dict = {} 63 | for k, v in first.items(): 64 | all_v = [d[k] for d in dicts] 65 | if type(v) is dict: 66 | new_dict[k] = dict_multimap(fn, all_v) 67 | else: 68 | new_dict[k] = fn(all_v) 69 | 70 | return new_dict 71 | 72 | 73 | def one_hot(x, v_bins): 74 | reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),)) 75 | diffs = x[..., None] - reshaped_bins 76 | am = torch.argmin(torch.abs(diffs), dim=-1) 77 | return nn.functional.one_hot(am, num_classes=len(v_bins)).float() 78 | 79 | 80 | def batched_gather(data, inds, dim=0, no_batch_dims=0): 81 | ranges = [] 82 | for i, s in enumerate(data.shape[:no_batch_dims]): 83 | r = torch.arange(s) 84 | r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) 85 | ranges.append(r) 86 | 87 | remaining_dims = [ 88 | slice(None) for _ in range(len(data.shape) - no_batch_dims) 89 | ] 90 | remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds 91 | ranges.extend(remaining_dims) 92 | return data[ranges] 93 | 94 | 95 | # With tree_map, a poor man's JAX tree_map 96 | def dict_map(fn, dic, leaf_type): 97 | new_dict = {} 98 | for k, v in dic.items(): 99 | if type(v) is dict: 100 | new_dict[k] = dict_map(fn, v, leaf_type) 101 | else: 102 | new_dict[k] = tree_map(fn, v, leaf_type) 103 | 104 | return new_dict 105 | 106 | 107 | def tree_map(fn, tree, leaf_type): 108 | if isinstance(tree, dict): 109 | return dict_map(fn, tree, leaf_type) 110 | elif isinstance(tree, list): 111 | return [tree_map(fn, x, leaf_type) for x in tree] 112 | elif isinstance(tree, tuple): 113 | return tuple([tree_map(fn, x, leaf_type) for x in tree]) 114 | elif isinstance(tree, leaf_type): 115 | return fn(tree) 116 | else: 117 | raise ValueError(f"Tree of type {type(tree)} not supported") 118 | 119 | 120 | tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) 121 | -------------------------------------------------------------------------------- /sam/openfold/data/tools/kalign.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A Python wrapper for Kalign.""" 17 | import os 18 | import subprocess 19 | from typing import Sequence 20 | 21 | from absl import logging 22 | 23 | from sam.openfold.data.tools import utils 24 | 25 | 26 | def _to_a3m(sequences: Sequence[str]) -> str: 27 | """Converts sequences to an a3m file.""" 28 | names = ["sequence %d" % i for i in range(1, len(sequences) + 1)] 29 | a3m = [] 30 | for sequence, name in zip(sequences, names): 31 | a3m.append(u">" + name + u"\n") 32 | a3m.append(sequence + u"\n") 33 | return "".join(a3m) 34 | 35 | 36 | class Kalign: 37 | """Python wrapper of the Kalign binary.""" 38 | 39 | def __init__(self, *, binary_path: str): 40 | """Initializes the Python Kalign wrapper. 41 | 42 | Args: 43 | binary_path: The path to the Kalign binary. 44 | 45 | Raises: 46 | RuntimeError: If Kalign binary not found within the path. 47 | """ 48 | self.binary_path = binary_path 49 | 50 | def align(self, sequences: Sequence[str]) -> str: 51 | """Aligns the sequences and returns the alignment in A3M string. 52 | 53 | Args: 54 | sequences: A list of query sequence strings. The sequences have to be at 55 | least 6 residues long (Kalign requires this). Note that the order in 56 | which you give the sequences might alter the output slightly as 57 | different alignment tree might get constructed. 58 | 59 | Returns: 60 | A string with the alignment in a3m format. 61 | 62 | Raises: 63 | RuntimeError: If Kalign fails. 64 | ValueError: If any of the sequences is less than 6 residues long. 65 | """ 66 | logging.info("Aligning %d sequences", len(sequences)) 67 | 68 | for s in sequences: 69 | if len(s) < 6: 70 | raise ValueError( 71 | "Kalign requires all sequences to be at least 6 " 72 | "residues long. Got %s (%d residues)." 
% (s, len(s)) 73 | ) 74 | 75 | with utils.tmpdir_manager() as query_tmp_dir: 76 | input_fasta_path = os.path.join(query_tmp_dir, "input.fasta") 77 | output_a3m_path = os.path.join(query_tmp_dir, "output.a3m") 78 | 79 | with open(input_fasta_path, "w") as f: 80 | f.write(_to_a3m(sequences)) 81 | 82 | cmd = [ 83 | self.binary_path, 84 | "-i", 85 | input_fasta_path, 86 | "-o", 87 | output_a3m_path, 88 | "-format", 89 | "fasta", 90 | ] 91 | 92 | logging.info('Launching subprocess "%s"', " ".join(cmd)) 93 | process = subprocess.Popen( 94 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE 95 | ) 96 | 97 | with utils.timing("Kalign query"): 98 | stdout, stderr = process.communicate() 99 | retcode = process.wait() 100 | logging.info( 101 | "Kalign stdout:\n%s\n\nstderr:\n%s\n", 102 | stdout.decode("utf-8"), 103 | stderr.decode("utf-8"), 104 | ) 105 | 106 | if retcode: 107 | raise RuntimeError( 108 | "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n" 109 | % (stdout.decode("utf-8"), stderr.decode("utf-8")) 110 | ) 111 | 112 | with open(output_a3m_path) as f: 113 | a3m = f.read() 114 | 115 | return a3m 116 | -------------------------------------------------------------------------------- /sam/openfold/utils/geometry/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Shared utils for tests.""" 15 | 16 | import dataclasses 17 | import torch 18 | 19 | from sam.openfold.utils.geometry import rigid_matrix_vector 20 | from sam.openfold.utils.geometry import rotation_matrix 21 | from sam.openfold.utils.geometry import vector 22 | 23 | 24 | def assert_rotation_matrix_equal(matrix1: rotation_matrix.Rot3Array, 25 | matrix2: rotation_matrix.Rot3Array): 26 | for field in dataclasses.fields(rotation_matrix.Rot3Array): 27 | field = field.name 28 | assert torch.equal( 29 | getattr(matrix1, field), getattr(matrix2, field)) 30 | 31 | 32 | def assert_rotation_matrix_close(mat1: rotation_matrix.Rot3Array, 33 | mat2: rotation_matrix.Rot3Array): 34 | assert torch.allclose(mat1.to_tensor(), mat2.to_tensor(), atol=1e-6) 35 | 36 | 37 | def assert_array_equal_to_rotation_matrix(array: torch.Tensor, 38 | matrix: rotation_matrix.Rot3Array): 39 | """Check that array and Matrix match.""" 40 | assert torch.equal(matrix.xx, array[..., 0, 0]) 41 | assert torch.equal(matrix.xy, array[..., 0, 1]) 42 | assert torch.equal(matrix.xz, array[..., 0, 2]) 43 | assert torch.equal(matrix.yx, array[..., 1, 0]) 44 | assert torch.equal(matrix.yy, array[..., 1, 1]) 45 | assert torch.equal(matrix.yz, array[..., 1, 2]) 46 | assert torch.equal(matrix.zx, array[..., 2, 0]) 47 | assert torch.equal(matrix.zy, array[..., 2, 1]) 48 | assert torch.equal(matrix.zz, array[..., 2, 2]) 49 | 50 | 51 | def assert_array_close_to_rotation_matrix(array: torch.Tensor, 52 | matrix: rotation_matrix.Rot3Array): 53 | assert torch.allclose(matrix.to_tensor(), array, atol=1e-6) 54 | 55 | 56 | def assert_vectors_equal(vec1: vector.Vec3Array, vec2: vector.Vec3Array): 57 | assert torch.equal(vec1.x, vec2.x) 58 | assert torch.equal(vec1.y, vec2.y) 59 | assert torch.equal(vec1.z, vec2.z) 60 | 61 | 62 | def assert_vectors_close(vec1: vector.Vec3Array, vec2: vector.Vec3Array): 63 | assert torch.allclose(vec1.x, vec2.x, atol=1e-6, rtol=0.) 64 | assert torch.allclose(vec1.y, vec2.y, atol=1e-6, rtol=0.) 65 | assert torch.allclose(vec1.z, vec2.z, atol=1e-6, rtol=0.) 66 | 67 | 68 | def assert_array_close_to_vector(array: torch.Tensor, vec: vector.Vec3Array): 69 | assert torch.allclose(vec.to_tensor(), array, atol=1e-6, rtol=0.) 
70 | 71 | 72 | def assert_array_equal_to_vector(array: torch.Tensor, vec: vector.Vec3Array): 73 | assert torch.equal(vec.to_tensor(), array) 74 | 75 | 76 | def assert_rigid_equal_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, 77 | rigid2: rigid_matrix_vector.Rigid3Array): 78 | assert_rot_trans_equal_to_rigid(rigid1.rotation, rigid1.translation, rigid2) 79 | 80 | 81 | def assert_rigid_close_to_rigid(rigid1: rigid_matrix_vector.Rigid3Array, 82 | rigid2: rigid_matrix_vector.Rigid3Array): 83 | assert_rot_trans_close_to_rigid(rigid1.rotation, rigid1.translation, rigid2) 84 | 85 | 86 | def assert_rot_trans_equal_to_rigid(rot: rotation_matrix.Rot3Array, 87 | trans: vector.Vec3Array, 88 | rigid: rigid_matrix_vector.Rigid3Array): 89 | assert_rotation_matrix_equal(rot, rigid.rotation) 90 | assert_vectors_equal(trans, rigid.translation) 91 | 92 | 93 | def assert_rot_trans_close_to_rigid(rot: rotation_matrix.Rot3Array, 94 | trans: vector.Vec3Array, 95 | rigid: rigid_matrix_vector.Rigid3Array): 96 | assert_rotation_matrix_close(rot, rigid.rotation) 97 | assert_vectors_close(trans, rigid.translation) 98 | -------------------------------------------------------------------------------- /sam/openfold/np/relax/relax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Amber relaxation.""" 17 | from typing import Any, Dict, Sequence, Tuple 18 | from sam.openfold.np import protein 19 | from sam.openfold.np.relax import amber_minimize, utils 20 | import numpy as np 21 | 22 | 23 | class AmberRelaxation(object): 24 | """Amber relaxation.""" 25 | def __init__( 26 | self, 27 | *, 28 | max_iterations: int, 29 | tolerance: float, 30 | stiffness: float, 31 | exclude_residues: Sequence[int], 32 | max_outer_iterations: int, 33 | use_gpu: bool, 34 | ): 35 | """Initialize Amber Relaxer. 36 | 37 | Args: 38 | max_iterations: Maximum number of L-BFGS iterations. 0 means no max. 39 | tolerance: kcal/mol, the energy tolerance of L-BFGS. 40 | stiffness: kcal/mol A**2, spring constant of heavy atom restraining 41 | potential. 42 | exclude_residues: Residues to exclude from per-atom restraining. 43 | Zero-indexed. 44 | max_outer_iterations: Maximum number of violation-informed relax 45 | iterations. A value of 1 will run the non-iterative procedure used in 46 | CASP14. Use 20 so that >95% of the bad cases are relaxed. Relax finishes 47 | as soon as there are no violations, hence in most cases this causes no 48 | slowdown. In the worst case we do 20 outer iterations. 
49 | use_gpu: Whether to run on GPU 50 | """ 51 | 52 | self._max_iterations = max_iterations 53 | self._tolerance = tolerance 54 | self._stiffness = stiffness 55 | self._exclude_residues = exclude_residues 56 | self._max_outer_iterations = max_outer_iterations 57 | self._use_gpu = use_gpu 58 | 59 | def process( 60 | self, *, prot: protein.Protein, cif_output: bool = False 61 | ) -> Tuple[str, Dict[str, Any], np.ndarray]: 62 | """Runs Amber relax on a prediction, adds hydrogens, returns PDB string.""" 63 | out = amber_minimize.run_pipeline( 64 | prot=prot, 65 | max_iterations=self._max_iterations, 66 | tolerance=self._tolerance, 67 | stiffness=self._stiffness, 68 | exclude_residues=self._exclude_residues, 69 | max_outer_iterations=self._max_outer_iterations, 70 | use_gpu=self._use_gpu, 71 | ) 72 | min_pos = out["pos"] 73 | start_pos = out["posinit"] 74 | rmsd = np.sqrt(np.sum((start_pos - min_pos) ** 2) / start_pos.shape[0]) 75 | debug_data = { 76 | "initial_energy": out["einit"], 77 | "final_energy": out["efinal"], 78 | "attempts": out["min_attempts"], 79 | "rmsd": rmsd, 80 | } 81 | pdb_str = amber_minimize.clean_protein(prot) 82 | min_pdb = utils.overwrite_pdb_coordinates(pdb_str, min_pos) 83 | min_pdb = utils.overwrite_b_factors(min_pdb, prot.b_factors) 84 | utils.assert_equal_nonterminal_atom_types( 85 | protein.from_pdb_string(min_pdb).atom_mask, prot.atom_mask 86 | ) 87 | violations = out["structural_violations"][ 88 | "total_per_residue_violations_mask" 89 | ] 90 | 91 | min_pdb = protein.add_pdb_headers(prot, min_pdb) 92 | output_str = min_pdb 93 | if cif_output: 94 | # TODO the model cif will be missing some metadata like headers (PARENTs and 95 | # REMARK with some details of the run, like num of recycles) 96 | final_prot = protein.from_pdb_string(min_pdb) 97 | output_str = protein.to_modelcif(final_prot) 98 | 99 | return output_str, debug_data, violations 100 | -------------------------------------------------------------------------------- /sam/openfold/data/tools/hhsearch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Library to run HHsearch from Python.""" 17 | import glob 18 | import logging 19 | import os 20 | import subprocess 21 | from typing import Sequence, Optional 22 | 23 | from sam.openfold.data import parsers 24 | from sam.openfold.data.tools import utils 25 | 26 | 27 | class HHSearch: 28 | """Python wrapper of the HHsearch binary.""" 29 | 30 | def __init__( 31 | self, 32 | *, 33 | binary_path: str, 34 | databases: Sequence[str], 35 | n_cpu: int = 2, 36 | maxseq: int = 1_000_000, 37 | ): 38 | """Initializes the Python HHsearch wrapper. 39 | 40 | Args: 41 | binary_path: The path to the HHsearch executable. 42 | databases: A sequence of HHsearch database paths. This should be the 43 | common prefix for the database files (i.e. 
up to but not including 44 | _hhm.ffindex etc.) 45 | n_cpu: The number of CPUs to use 46 | maxseq: The maximum number of rows in an input alignment. Note that this 47 | parameter is only supported in HHBlits version 3.1 and higher. 48 | 49 | Raises: 50 | RuntimeError: If HHsearch binary not found within the path. 51 | """ 52 | self.binary_path = binary_path 53 | self.databases = databases 54 | self.n_cpu = n_cpu 55 | self.maxseq = maxseq 56 | 57 | for database_path in self.databases: 58 | if not glob.glob(database_path + "_*"): 59 | logging.error( 60 | "Could not find HHsearch database %s", database_path 61 | ) 62 | raise ValueError( 63 | f"Could not find HHsearch database {database_path}" 64 | ) 65 | 66 | @property 67 | def output_format(self) -> str: 68 | return 'hhr' 69 | 70 | @property 71 | def input_format(self) -> str: 72 | return 'a3m' 73 | 74 | def query(self, a3m: str, output_dir: Optional[str] = None) -> str: 75 | """Queries the database using HHsearch using a given a3m.""" 76 | with utils.tmpdir_manager() as query_tmp_dir: 77 | input_path = os.path.join(query_tmp_dir, "query.a3m") 78 | output_dir = query_tmp_dir if output_dir is None else output_dir 79 | hhr_path = os.path.join(output_dir, "hhsearch_output.hhr") 80 | with open(input_path, "w") as f: 81 | f.write(a3m) 82 | 83 | db_cmd = [] 84 | for db_path in self.databases: 85 | db_cmd.append("-d") 86 | db_cmd.append(db_path) 87 | cmd = [ 88 | self.binary_path, 89 | "-i", 90 | input_path, 91 | "-o", 92 | hhr_path, 93 | "-maxseq", 94 | str(self.maxseq), 95 | "-cpu", 96 | str(self.n_cpu), 97 | ] + db_cmd 98 | 99 | logging.info('Launching subprocess "%s"', " ".join(cmd)) 100 | process = subprocess.Popen( 101 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE 102 | ) 103 | with utils.timing("HHsearch query"): 104 | stdout, stderr = process.communicate() 105 | retcode = process.wait() 106 | 107 | if retcode: 108 | # Stderr is truncated to prevent proto size errors in Beam. 109 | raise RuntimeError( 110 | "HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n" 111 | % (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8")) 112 | ) 113 | 114 | with open(hhr_path) as f: 115 | hhr = f.read() 116 | return hhr 117 | 118 | @staticmethod 119 | def get_template_hits( 120 | output_string: str, 121 | input_sequence: str 122 | ) -> Sequence[parsers.TemplateHit]: 123 | """Gets parsed template hits from the raw string output by the tool""" 124 | del input_sequence # Used by hmmsearch but not needed for hhsearch 125 | return parsers.parse_hhr(output_string) 126 | -------------------------------------------------------------------------------- /sam/minimizer/runner.py: -------------------------------------------------------------------------------- 1 | """ 2 | Class for streamline usage of the SAM minimizer. 3 | """ 4 | 5 | import os 6 | import yaml 7 | import numpy as np 8 | import mdtraj 9 | import torch 10 | 11 | from sam.data.aa_topology import sam_openfold_aa_map, get_traj_list 12 | from sam.minimizer import get_topology, initialize, minimize, reconstruct_atom14 13 | from sam.data.aa_protein import AllAtomProteinDataset 14 | 15 | 16 | class Minimizer: 17 | 18 | def __init__(self, 19 | name: str, 20 | top_fp: str, 21 | ens_fp: str, 22 | protocol: str, 23 | params_fp: str = None 24 | ): 25 | 26 | # Load some data. 
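        # The dataset wraps a single named ensemble (one topology file plus
        # one trajectory file) and, with frames_mode="ensemble", is set up to
        # serve the ensemble's frames; the atom14 coordinates it yields are
        # consumed by the minimization loop in run() below.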
27 | self.dataset = AllAtomProteinDataset( 28 | input=[{ 29 | "name": name, "topology": top_fp, "trajectories": [ens_fp] 30 | }], 31 | n_trajs=None, 32 | n_frames=None, 33 | frames_mode="ensemble", 34 | proteins=None, 35 | per_protein_frames=None, 36 | re_filter=None, 37 | res_ids_mode=None, 38 | bead_type="ca", 39 | alphabet="standard", 40 | xyz_sigma=None, 41 | xyz_perturb=None, 42 | verbose=False, 43 | random_seed=None 44 | ) 45 | 46 | # Load the minimization parameters. 47 | module_dp = os.path.dirname(__file__) 48 | if protocol == "atlas": 49 | params_fp = os.path.join(module_dp, "params", "mizu_cfg.atlas.yaml") 50 | elif protocol == "mdcath": 51 | params_fp = os.path.join(module_dp, "params", "mizu_cfg.mdcath.yaml") 52 | elif protocol == "custom": 53 | if params_fp is None: 54 | raise ValueError() 55 | else: 56 | raise KeyError(protocol) 57 | with open(params_fp, 'r') as i_fh: 58 | params = yaml.safe_load(i_fh) 59 | self.opt_params = params["opt"] 60 | if "opt_ini" in params: 61 | self.opt_ini_params = params["opt_ini"] 62 | else: 63 | self.opt_ini_params = None 64 | self.top_params = params["top"] 65 | self.data_params = params.get("data", {"batch_size": 50}) 66 | 67 | 68 | def run(self, 69 | batch_size: int = None, 70 | device: str = "cpu", 71 | verbose: bool = True 72 | ): 73 | # Setup the batch size. 74 | if batch_size is None: 75 | batch_size = self.data_params["batch_size"] 76 | 77 | # Setup the dataloader to serve the batches to minimize. 78 | dataloader = torch.utils.data.dataloader.DataLoader( 79 | dataset=self.dataset, batch_size=batch_size, shuffle=False 80 | ) 81 | 82 | # Iterate over the whole dataset in batches. 83 | minimized_traj = [] 84 | for i, batch in enumerate(dataloader): 85 | 86 | # Get the xyz coordinates and the amino acid sequence. 87 | positions = batch.atom14_gt_positions 88 | a = torch.tensor(sam_openfold_aa_map[batch.a], dtype=torch.long) 89 | positions = positions*0.1 90 | positions = positions.to(device) 91 | a = a.to(device) 92 | 93 | # Initialize. 94 | if i == 0: 95 | topology = get_topology(a, **self.top_params) 96 | positions = initialize(positions, a) 97 | 98 | # Brief initial minimization, typically with simple GD or Adam. 99 | if self.opt_ini_params is not None: 100 | positions, es = minimize( 101 | positions=positions, 102 | topology=topology, 103 | return_early_stopping=True, 104 | verbose=verbose, 105 | **self.opt_ini_params 106 | ) 107 | if not es: 108 | positions = torch.autograd.Variable(positions) 109 | positions.requires_grad = True 110 | else: 111 | es = False 112 | 113 | # Main minimization, typically with L-BFGS. 114 | if not es: 115 | # Run only if did not early-stop at the initial minimization. 116 | positions = minimize( 117 | positions=positions, 118 | topology=topology, 119 | verbose=verbose, 120 | **self.opt_params 121 | ) 122 | 123 | # Reconstruct atom14 trajectory. 124 | atom14_rec = reconstruct_atom14(positions, topology) 125 | traj = get_traj_list({"positions": atom14_rec, "a": a}, join=True) 126 | minimized_traj.append(traj) 127 | 128 | # Join the mdtraj trajectories and return. 129 | minimized_traj = mdtraj.join(minimized_traj) 130 | 131 | return minimized_traj -------------------------------------------------------------------------------- /sam/trajectory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Analyze different types of features of mdtraj trajectories. 
3 | """ 4 | 5 | import numpy as np 6 | import mdtraj 7 | from sam.data.topology import slice_ca_traj 8 | from sam.data.sequences import ofo_restype_name_to_atom14_names 9 | 10 | 11 | def calc_mdtraj_rmsf( 12 | traj: mdtraj.Trajectory, 13 | ref_traj: mdtraj.Trajectory, 14 | ref_index: int = 0 15 | ) -> np.ndarray: 16 | traj_c = mdtraj.Trajectory(traj.xyz, topology=traj.topology) 17 | ref_traj = mdtraj.Trajectory(ref_traj.xyz, topology=traj.topology) 18 | rmsf = mdtraj.rmsf(traj_c, ref_traj, ref_index) 19 | return rmsf 20 | 21 | 22 | std_atoms = set() 23 | for k in ofo_restype_name_to_atom14_names: 24 | for a in ofo_restype_name_to_atom14_names[k]: 25 | if a: 26 | std_atoms.add(a) 27 | 28 | 29 | def calc_q_values( 30 | traj, 31 | native_traj, 32 | beta=50.0, 33 | lambda_=1.2, 34 | delta=0.0, 35 | threshold=1.0 # in nanometers. 36 | ): 37 | 38 | if len(native_traj) != 1: 39 | raise NotImplementedError() 40 | 41 | dist = [] 42 | top_atoms_dict = [[a for a in r.atoms if a.name in std_atoms] for r in native_traj.topology.residues] 43 | 44 | top_ids = [] 45 | traj_ids = [] 46 | top_residues = list(native_traj.topology.residues) # 47 | top_atoms = list(native_traj.topology.atoms) # 48 | traj_residues = list(traj.topology.residues) 49 | traj_atoms = list(traj.topology.atoms) # 50 | for i, r_i in enumerate(native_traj.topology.residues): 51 | for j, r_j in enumerate(native_traj.topology.residues): 52 | if r_j.index - r_i.index > 3: 53 | a_k_ids = [a_k.index for a_k in top_atoms_dict[i]] 54 | a_k_atoms = top_atoms_dict[i] 55 | a_l_ids = [a_l.index for a_l in top_atoms_dict[j]] 56 | a_l_atoms = top_atoms_dict[j] 57 | dist_ij = np.sqrt( 58 | np.sum( 59 | np.square( 60 | native_traj.xyz[:,a_k_ids,None,:] - native_traj.xyz[:,None,a_l_ids,:] 61 | ), 62 | axis=-1 63 | ) 64 | )[0] 65 | a_k_pos, a_l_pos = np.unravel_index(dist_ij.argmax(), dist_ij.shape) 66 | if threshold is not None: 67 | if dist_ij[a_k_pos, a_l_pos] > threshold: 68 | continue 69 | top_ids.append((a_k_ids[a_k_pos], a_l_ids[a_l_pos])) 70 | traj_ids.append( 71 | (traj_residues[i].atom(a_k_atoms[a_k_pos].name).index, 72 | traj_residues[j].atom(a_l_atoms[a_l_pos].name).index) 73 | ) 74 | 75 | if not traj_ids: 76 | raise ValueError() 77 | ref_dist = mdtraj.compute_distances(native_traj, top_ids) 78 | traj_dist = mdtraj.compute_distances(traj, traj_ids) 79 | n_contacts = ref_dist.shape[1] 80 | q_x = np.sum(1/(1+np.exp(beta*(traj_dist - lambda_*(ref_dist + delta)))), axis=1)/n_contacts 81 | return q_x 82 | 83 | 84 | def calc_rmsd(hat_traj, ref_traj, ref_idx=0, prealigned=True, get_tm=False): 85 | ref_traj = ref_traj[ref_idx:ref_idx+1] 86 | if not prealigned: 87 | raise NotImplementedError() 88 | ref_xyz = ref_traj.xyz * 10.0 89 | hat_xyz = hat_traj.xyz * 10.0 90 | sq_dev = np.sum(np.square(ref_xyz - hat_xyz), axis=-1) 91 | rsq_dev = np.sqrt(sq_dev) 92 | n_res = ref_xyz.shape[1] 93 | rmsd = np.sqrt(sq_dev.sum(axis=1)/n_res) 94 | tm_score = (1 / n_res) * np.sum(1 / (1 + (rsq_dev/d_0(n_res))**2), axis=1) 95 | if not get_tm: 96 | return rmsd*0.1 97 | else: 98 | return tm_score, rmsd*0.1 99 | 100 | def d_0(n_res): 101 | return 1.24*(n_res - 15)**(1/3) - 1.8 102 | 103 | def calc_initrmsd(traj, init_traj, is_ca=False, get_tm=False): 104 | if not is_ca: 105 | init_traj = slice_ca_traj(init_traj) 106 | traj = slice_ca_traj(traj) 107 | traj.superpose(init_traj) 108 | scores = calc_rmsd( 109 | traj, init_traj, ref_idx=0, prealigned=True, get_tm=get_tm 110 | ) 111 | return scores 112 | 113 | def calc_ssep(traj, native_traj): 114 | if len(native_traj) > 1: 115 
| raise ValueError() 116 | ref_dssp = mdtraj.compute_dssp(native_traj) 117 | traj_dssp = mdtraj.compute_dssp(traj) 118 | match = ref_dssp == traj_dssp 119 | match = match.astype(int) 120 | score = match.mean(axis=1) 121 | he_mask = np.isin(ref_dssp, ['H', 'E']).astype(int) 122 | match_mask = match * he_mask 123 | score_mask = match_mask.sum(axis=1)/he_mask.sum() 124 | return score_mask 125 | 126 | def _calc_distances(xyz, eps=1e-9): 127 | d = np.sqrt( 128 | np.sum(np.square(xyz[:,:,None,:] - xyz[:,None,:,:]), axis=3)+eps 129 | ) 130 | return d -------------------------------------------------------------------------------- /sam/openfold/data/tools/hmmbuild.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A Python wrapper for hmmbuild - construct HMM profiles from MSA.""" 16 | 17 | import os 18 | import re 19 | import subprocess 20 | 21 | from absl import logging 22 | from sam.openfold.data.tools import utils 23 | 24 | 25 | class Hmmbuild(object): 26 | """Python wrapper of the hmmbuild binary.""" 27 | 28 | def __init__(self, 29 | *, 30 | binary_path: str, 31 | singlemx: bool = False): 32 | """Initializes the Python hmmbuild wrapper. 33 | 34 | Args: 35 | binary_path: The path to the hmmbuild executable. 36 | singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to 37 | just use a common substitution score matrix. 38 | 39 | Raises: 40 | RuntimeError: If hmmbuild binary not found within the path. 41 | """ 42 | self.binary_path = binary_path 43 | self.singlemx = singlemx 44 | 45 | def build_profile_from_sto(self, sto: str, model_construction='fast') -> str: 46 | """Builds an HMM for the aligned sequences given as a Stockholm string. 47 | 48 | Args: 49 | sto: A string with the aligned sequences in the Stockholm format. 50 | model_construction: Whether to use reference annotation in the msa to 51 | determine consensus columns ('hand') or default ('fast'). 52 | 53 | Returns: 54 | A string with the profile in the HMM format. 55 | 56 | Raises: 57 | RuntimeError: If hmmbuild fails. 58 | """ 59 | return self._build_profile(sto, model_construction=model_construction) 60 | 61 | def build_profile_from_a3m(self, a3m: str) -> str: 62 | """Builds an HMM for the aligned sequences given as an A3M string. 63 | 64 | Args: 65 | a3m: A string with the aligned sequences in the A3M format. 66 | 67 | Returns: 68 | A string with the profile in the HMM format. 69 | 70 | Raises: 71 | RuntimeError: If hmmbuild fails. 72 | """ 73 | lines = [] 74 | for line in a3m.splitlines(): 75 | if not line.startswith('>'): 76 | line = re.sub('[a-z]+', '', line) # Remove inserted residues.
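# In the A3M format, lowercase letters mark residues inserted relative to the query, so stripping them leaves a fixed-width alignment that hmmbuild can consume.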
77 | lines.append(line + '\n') 78 | msa = ''.join(lines) 79 | return self._build_profile(msa, model_construction='fast') 80 | 81 | def _build_profile(self, msa: str, model_construction: str = 'fast') -> str: 82 | """Builds an HMM for the aligned sequences given as an MSA string. 83 | 84 | Args: 85 | msa: A string with the aligned sequences, in A3M or STO format. 86 | model_construction: Whether to use reference annotation in the msa to 87 | determine consensus columns ('hand') or default ('fast'). 88 | 89 | Returns: 90 | A string with the profile in the HMM format. 91 | 92 | Raises: 93 | RuntimeError: If hmmbuild fails. 94 | ValueError: If an unsupported model_construction is provided. 95 | """ 96 | if model_construction not in {'hand', 'fast'}: 97 | raise ValueError(f'Invalid model_construction {model_construction} - only ' 98 | 'hand and fast supported.') 99 | 100 | with utils.tmpdir_manager() as query_tmp_dir: 101 | input_query = os.path.join(query_tmp_dir, 'query.msa') 102 | output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm') 103 | 104 | with open(input_query, 'w') as f: 105 | f.write(msa) 106 | 107 | cmd = [self.binary_path] 108 | # If adding flags, we have to do so before the output and input: 109 | 110 | if model_construction == 'hand': 111 | cmd.append(f'--{model_construction}') 112 | if self.singlemx: 113 | cmd.append('--singlemx') 114 | cmd.extend([ 115 | '--amino', 116 | output_hmm_path, 117 | input_query, 118 | ]) 119 | 120 | logging.info('Launching subprocess %s', cmd) 121 | process = subprocess.Popen(cmd, stdout=subprocess.PIPE, 122 | stderr=subprocess.PIPE) 123 | 124 | with utils.timing('hmmbuild query'): 125 | stdout, stderr = process.communicate() 126 | retcode = process.wait() 127 | logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n', 128 | stdout.decode('utf-8'), stderr.decode('utf-8')) 129 | 130 | if retcode: 131 | raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n' 132 | % (stdout.decode('utf-8'), stderr.decode('utf-8'))) 133 | 134 | with open(output_hmm_path, encoding='utf-8') as f: 135 | hmm = f.read() 136 | 137 | return hmm 138 | -------------------------------------------------------------------------------- /sam/openfold/data/feature_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import copy 17 | from typing import Mapping, Tuple, List, Optional, Dict, Sequence 18 | 19 | import ml_collections 20 | import numpy as np 21 | import torch 22 | 23 | from sam.openfold.data import input_pipeline, input_pipeline_multimer 24 | 25 | 26 | FeatureDict = Mapping[str, np.ndarray] 27 | TensorDict = Dict[str, torch.Tensor] 28 | 29 | 30 | def np_to_tensor_dict( 31 | np_example: Mapping[str, np.ndarray], 32 | features: Sequence[str], 33 | ) -> TensorDict: 34 | """Creates dict of tensors from a dict of NumPy arrays.
35 | 36 | Args: 37 | np_example: A dict of NumPy feature arrays. 38 | features: A list of strings of feature names to be returned in the dataset. 39 | 40 | Returns: 41 | A dictionary of features mapping feature names to features. Only the given 42 | features are returned, all other ones are filtered out. 43 | """ 44 | # torch generates warnings if feature is already a torch Tensor 45 | to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach() 46 | tensor_dict = { 47 | k: to_tensor(v) for k, v in np_example.items() if k in features 48 | } 49 | 50 | return tensor_dict 51 | 52 | 53 | def make_data_config( 54 | config: ml_collections.ConfigDict, 55 | mode: str, 56 | num_res: int, 57 | ) -> Tuple[ml_collections.ConfigDict, List[str]]: 58 | cfg = copy.deepcopy(config) 59 | mode_cfg = cfg[mode] 60 | with cfg.unlocked(): 61 | if mode_cfg.crop_size is None: 62 | mode_cfg.crop_size = num_res 63 | 64 | feature_names = cfg.common.unsupervised_features 65 | 66 | # Add seqemb related features if using seqemb mode. 67 | if cfg.seqemb_mode.enabled: 68 | feature_names += cfg.common.seqemb_features 69 | 70 | if cfg.common.use_templates: 71 | feature_names += cfg.common.template_features 72 | 73 | if cfg[mode].supervised: 74 | feature_names += cfg.supervised.supervised_features 75 | 76 | return cfg, feature_names 77 | 78 | 79 | def np_example_to_features( 80 | np_example: FeatureDict, 81 | config: ml_collections.ConfigDict, 82 | mode: str, 83 | is_multimer: bool = False 84 | ): 85 | np_example = dict(np_example) 86 | 87 | seq_length = np_example["seq_length"] 88 | num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length) 89 | cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) 90 | 91 | if "deletion_matrix_int" in np_example: 92 | np_example["deletion_matrix"] = np_example.pop( 93 | "deletion_matrix_int" 94 | ).astype(np.float32) 95 | 96 | tensor_dict = np_to_tensor_dict( 97 | np_example=np_example, features=feature_names 98 | ) 99 | 100 | with torch.no_grad(): 101 | if is_multimer: 102 | features = input_pipeline_multimer.process_tensors_from_config( 103 | tensor_dict, 104 | cfg.common, 105 | cfg[mode], 106 | ) 107 | else: 108 | features = input_pipeline.process_tensors_from_config( 109 | tensor_dict, 110 | cfg.common, 111 | cfg[mode], 112 | ) 113 | 114 | if mode == "train": 115 | p = torch.rand(1).item() 116 | use_clamped_fape_value = float(p < cfg.supervised.clamp_prob) 117 | features["use_clamped_fape"] = torch.full( 118 | size=[cfg.common.max_recycling_iters + 1], 119 | fill_value=use_clamped_fape_value, 120 | dtype=torch.float32, 121 | ) 122 | else: 123 | features["use_clamped_fape"] = torch.full( 124 | size=[cfg.common.max_recycling_iters + 1], 125 | fill_value=0.0, 126 | dtype=torch.float32, 127 | ) 128 | 129 | return {k: v for k, v in features.items()} 130 | 131 | 132 | class FeaturePipeline: 133 | def __init__( 134 | self, 135 | config: ml_collections.ConfigDict, 136 | ): 137 | self.config = config 138 | 139 | def process_features( 140 | self, 141 | raw_features: FeatureDict, 142 | mode: str = "train", 143 | is_multimer: bool = False, 144 | ) -> FeatureDict: 145 | # if(is_multimer and mode != "predict"): 146 | # raise ValueError("Multimer mode is not currently trainable") 147 | 148 | return np_example_to_features( 149 | np_example=raw_features, 150 | config=self.config, 151 | mode=mode, 152 | is_multimer=is_multimer, 153 | ) 154 | -------------------------------------------------------------------------------- 
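A minimal usage sketch for the FeaturePipeline above (a hedged illustration: `cfg` stands in for an OpenFold-style ml_collections config whose `data` section holds the pipeline settings, and `raw` for the NumPy feature dict produced by the upstream data pipeline; neither name appears in the original sources):

from sam.openfold.data.feature_pipeline import FeaturePipeline

pipeline = FeaturePipeline(cfg.data)  # cfg: assumed ml_collections.ConfigDict
feats = pipeline.process_features(raw, mode="predict")  # raw: dict of np.ndarray
# feats maps feature names to torch.Tensors stacked over recycling iterations.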
/sam/openfold/model/outer_product_mean.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from functools import partial 17 | from typing import Optional 18 | 19 | import torch 20 | import torch.nn as nn 21 | 22 | from sam.openfold.model.primitives import Linear 23 | from sam.openfold.utils.chunk_utils import chunk_layer 24 | from sam.openfold.utils.precision_utils import is_fp16_enabled 25 | 26 | 27 | class OuterProductMean(nn.Module): 28 | """ 29 | Implements Algorithm 10. 30 | """ 31 | 32 | def __init__(self, c_m, c_z, c_hidden, eps=1e-3): 33 | """ 34 | Args: 35 | c_m: 36 | MSA embedding channel dimension 37 | c_z: 38 | Pair embedding channel dimension 39 | c_hidden: 40 | Hidden channel dimension 41 | """ 42 | super(OuterProductMean, self).__init__() 43 | 44 | self.c_m = c_m 45 | self.c_z = c_z 46 | self.c_hidden = c_hidden 47 | self.eps = eps 48 | 49 | self.layer_norm = nn.LayerNorm(c_m) 50 | self.linear_1 = Linear(c_m, c_hidden) 51 | self.linear_2 = Linear(c_m, c_hidden) 52 | self.linear_out = Linear(c_hidden ** 2, c_z, init="final") 53 | 54 | def _opm(self, a, b): 55 | # [*, N_res, N_res, C, C] 56 | outer = torch.einsum("...bac,...dae->...bdce", a, b) 57 | 58 | # [*, N_res, N_res, C * C] 59 | outer = outer.reshape(outer.shape[:-2] + (-1,)) 60 | 61 | # [*, N_res, N_res, C_z] 62 | outer = self.linear_out(outer) 63 | 64 | return outer 65 | 66 | @torch.jit.ignore 67 | def _chunk(self, 68 | a: torch.Tensor, 69 | b: torch.Tensor, 70 | chunk_size: int 71 | ) -> torch.Tensor: 72 | # Since the "batch dim" in this case is not a true batch dimension 73 | # (in that the shape of the output depends on it), we need to 74 | # iterate over it ourselves 75 | a_reshape = a.reshape((-1,) + a.shape[-3:]) 76 | b_reshape = b.reshape((-1,) + b.shape[-3:]) 77 | out = [] 78 | for a_prime, b_prime in zip(a_reshape, b_reshape): 79 | outer = chunk_layer( 80 | partial(self._opm, b=b_prime), 81 | {"a": a_prime}, 82 | chunk_size=chunk_size, 83 | no_batch_dims=1, 84 | ) 85 | out.append(outer) 86 | 87 | # For some cursed reason making this distinction saves memory 88 | if(len(out) == 1): 89 | outer = out[0].unsqueeze(0) 90 | else: 91 | outer = torch.stack(out, dim=0) 92 | 93 | outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) 94 | 95 | return outer 96 | 97 | def _forward(self, 98 | m: torch.Tensor, 99 | mask: Optional[torch.Tensor] = None, 100 | chunk_size: Optional[int] = None, 101 | inplace_safe: bool = False, 102 | ) -> torch.Tensor: 103 | """ 104 | Args: 105 | m: 106 | [*, N_seq, N_res, C_m] MSA embedding 107 | mask: 108 | [*, N_seq, N_res] MSA mask 109 | Returns: 110 | [*, N_res, N_res, C_z] pair embedding update 111 | """ 112 | if mask is None: 113 | mask = m.new_ones(m.shape[:-1]) 114 | 115 | # [*, N_seq, N_res, C_m] 116 | ln = self.layer_norm(m) 117 | 118 | # [*, N_seq, N_res, C] 119 | 
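# Unsqueezing below gives the mask shape [*, N_seq, N_res, 1] so it broadcasts over the hidden channels of the two linear projections.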
mask = mask.unsqueeze(-1) 120 | a = self.linear_1(ln) 121 | a = a * mask 122 | 123 | b = self.linear_2(ln) 124 | b = b * mask 125 | 126 | del ln 127 | 128 | a = a.transpose(-2, -3) 129 | b = b.transpose(-2, -3) 130 | 131 | if chunk_size is not None: 132 | outer = self._chunk(a, b, chunk_size) 133 | else: 134 | outer = self._opm(a, b) 135 | 136 | # [*, N_res, N_res, 1] 137 | norm = torch.einsum("...abc,...adc->...bdc", mask, mask) 138 | norm = norm + self.eps 139 | 140 | # [*, N_res, N_res, C_z] 141 | if(inplace_safe): 142 | outer /= norm 143 | else: 144 | outer = outer / norm 145 | 146 | return outer 147 | 148 | def forward(self, 149 | m: torch.Tensor, 150 | mask: Optional[torch.Tensor] = None, 151 | chunk_size: Optional[int] = None, 152 | inplace_safe: bool = False, 153 | ) -> torch.Tensor: 154 | if(is_fp16_enabled()): 155 | with torch.cuda.amp.autocast(enabled=False): 156 | return self._forward(m.float(), mask, chunk_size, inplace_safe) 157 | else: 158 | return self._forward(m, mask, chunk_size, inplace_safe) 159 | 160 | -------------------------------------------------------------------------------- /sam/openfold/data/tools/hmmsearch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """A Python wrapper for hmmsearch - search profile against a sequence db.""" 16 | 17 | import os 18 | import subprocess 19 | from typing import Optional, Sequence 20 | 21 | from absl import logging 22 | from sam.openfold.data import parsers 23 | from sam.openfold.data.tools import hmmbuild 24 | from sam.openfold.data.tools import utils 25 | 26 | 27 | class Hmmsearch(object): 28 | """Python wrapper of the hmmsearch binary.""" 29 | 30 | def __init__(self, 31 | *, 32 | binary_path: str, 33 | hmmbuild_binary_path: str, 34 | database_path: str, 35 | flags: Optional[Sequence[str]] = None 36 | ): 37 | """Initializes the Python hmmsearch wrapper. 38 | 39 | Args: 40 | binary_path: The path to the hmmsearch executable. 41 | hmmbuild_binary_path: The path to the hmmbuild executable. Used to build 42 | an hmm from an input a3m. 43 | database_path: The path to the hmmsearch database (FASTA format). 44 | flags: List of flags to be used by hmmsearch. 45 | 46 | Raises: 47 | RuntimeError: If hmmsearch binary not found within the path. 48 | """ 49 | self.binary_path = binary_path 50 | self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path) 51 | self.database_path = database_path 52 | if flags is None: 53 | # Default hmmsearch run settings. 
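# These settings loosen HMMER's MSV/Viterbi/Forward prefilters (--F1/--F2/--F3) and open the reporting/inclusion E-value cutoffs up to 100, making the search deliberately permissive; strict hit selection is presumably deferred to downstream template filtering.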
54 | flags = ['--F1', '0.1', 55 | '--F2', '0.1', 56 | '--F3', '0.1', 57 | '--incE', '100', 58 | '-E', '100', 59 | '--domE', '100', 60 | '--incdomE', '100'] 61 | self.flags = flags 62 | 63 | if not os.path.exists(self.database_path): 64 | logging.error('Could not find hmmsearch database %s', database_path) 65 | raise ValueError(f'Could not find hmmsearch database {database_path}') 66 | 67 | @property 68 | def output_format(self) -> str: 69 | return 'sto' 70 | 71 | @property 72 | def input_format(self) -> str: 73 | return 'sto' 74 | 75 | def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str: 76 | """Queries the database using hmmsearch using a given stockholm msa.""" 77 | hmm = self.hmmbuild_runner.build_profile_from_sto( 78 | msa_sto, 79 | model_construction='hand' 80 | ) 81 | return self.query_with_hmm(hmm, output_dir) 82 | 83 | def query_with_hmm(self, 84 | hmm: str, 85 | output_dir: Optional[str] = None 86 | ) -> str: 87 | """Queries the database using hmmsearch using a given hmm.""" 88 | with utils.tmpdir_manager() as query_tmp_dir: 89 | hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') 90 | output_dir = query_tmp_dir if output_dir is None else output_dir 91 | out_path = os.path.join(output_dir, 'hmm_output.sto') 92 | with open(hmm_input_path, 'w') as f: 93 | f.write(hmm) 94 | 95 | cmd = [ 96 | self.binary_path, 97 | '--noali', # Don't include the alignment in stdout. 98 | '--cpu', '8' 99 | ] 100 | # If adding flags, we have to do so before the output and input: 101 | if self.flags: 102 | cmd.extend(self.flags) 103 | cmd.extend([ 104 | '-A', out_path, 105 | hmm_input_path, 106 | self.database_path, 107 | ]) 108 | 109 | logging.info('Launching sub-process %s', cmd) 110 | process = subprocess.Popen( 111 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 112 | with utils.timing( 113 | f'hmmsearch ({os.path.basename(self.database_path)}) query'): 114 | stdout, stderr = process.communicate() 115 | retcode = process.wait() 116 | 117 | if retcode: 118 | raise RuntimeError( 119 | 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % ( 120 | stdout.decode('utf-8'), stderr.decode('utf-8'))) 121 | 122 | with open(out_path) as f: 123 | out_msa = f.read() 124 | 125 | return out_msa 126 | 127 | @staticmethod 128 | def get_template_hits( 129 | output_string: str, 130 | input_sequence: str 131 | ) -> Sequence[parsers.TemplateHit]: 132 | """Gets parsed template hits from the raw string output by the tool.""" 133 | template_hits = parsers.parse_hmmsearch_sto( 134 | output_string, 135 | input_sequence, 136 | ) 137 | return template_hits 138 | -------------------------------------------------------------------------------- /sam/openfold/model/triangular_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
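# This module implements the triangle self-attention updates of AlphaFold 2 (Algorithms 13 and 14); the `starting` flag below selects attention around the starting or the ending node of each pair-representation edge.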
15 | 16 | from functools import partialmethod, partial 17 | import math 18 | from typing import Optional, List 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | from sam.openfold.model.primitives import Linear, LayerNorm, Attention 24 | from sam.openfold.utils.chunk_utils import chunk_layer 25 | from sam.openfold.utils.tensor_utils import ( 26 | permute_final_dims, 27 | flatten_final_dims, 28 | ) 29 | 30 | 31 | class TriangleAttention(nn.Module): 32 | def __init__( 33 | self, c_in, c_hidden, no_heads, starting=True, inf=1e9 34 | ): 35 | """ 36 | Args: 37 | c_in: 38 | Input channel dimension 39 | c_hidden: 40 | Overall hidden channel dimension (not per-head) 41 | no_heads: 42 | Number of attention heads 43 | """ 44 | super(TriangleAttention, self).__init__() 45 | 46 | self.c_in = c_in 47 | self.c_hidden = c_hidden 48 | self.no_heads = no_heads 49 | self.starting = starting 50 | self.inf = inf 51 | 52 | self.layer_norm = LayerNorm(self.c_in) 53 | 54 | self.linear = Linear(c_in, self.no_heads, bias=False, init="normal") 55 | 56 | self.mha = Attention( 57 | self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads 58 | ) 59 | 60 | @torch.jit.ignore 61 | def _chunk(self, 62 | x: torch.Tensor, 63 | biases: List[torch.Tensor], 64 | chunk_size: int, 65 | use_memory_efficient_kernel: bool = False, 66 | use_deepspeed_evo_attention: bool = False, 67 | use_lma: bool = False, 68 | inplace_safe: bool = False, 69 | ) -> torch.Tensor: 70 | "triangle! triangle!" 71 | mha_inputs = { 72 | "q_x": x, 73 | "kv_x": x, 74 | "biases": biases, 75 | } 76 | 77 | return chunk_layer( 78 | partial( 79 | self.mha, 80 | use_memory_efficient_kernel=use_memory_efficient_kernel, 81 | use_deepspeed_evo_attention=use_deepspeed_evo_attention, 82 | use_lma=use_lma 83 | ), 84 | mha_inputs, 85 | chunk_size=chunk_size, 86 | no_batch_dims=len(x.shape[:-2]), 87 | _out=x if inplace_safe else None, 88 | ) 89 | 90 | def forward(self, 91 | x: torch.Tensor, 92 | mask: Optional[torch.Tensor] = None, 93 | chunk_size: Optional[int] = None, 94 | use_memory_efficient_kernel: bool = False, 95 | use_deepspeed_evo_attention: bool = False, 96 | use_lma: bool = False, 97 | inplace_safe: bool = False, 98 | ) -> torch.Tensor: 99 | """ 100 | Args: 101 | x: 102 | [*, I, J, C_in] input tensor (e.g. 
the pair representation) 103 | Returns: 104 | [*, I, J, C_in] output tensor 105 | """ 106 | if mask is None: 107 | # [*, I, J] 108 | mask = x.new_ones( 109 | x.shape[:-1], 110 | ) 111 | 112 | if(not self.starting): 113 | x = x.transpose(-2, -3) 114 | mask = mask.transpose(-1, -2) 115 | 116 | # [*, I, J, C_in] 117 | x = self.layer_norm(x) 118 | 119 | # [*, I, 1, 1, J] 120 | mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] 121 | 122 | # [*, H, I, J] 123 | triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) 124 | 125 | # [*, 1, H, I, J] 126 | triangle_bias = triangle_bias.unsqueeze(-4) 127 | 128 | biases = [mask_bias, triangle_bias] 129 | 130 | if chunk_size is not None: 131 | x = self._chunk( 132 | x, 133 | biases, 134 | chunk_size, 135 | use_memory_efficient_kernel=use_memory_efficient_kernel, 136 | use_deepspeed_evo_attention=use_deepspeed_evo_attention, 137 | use_lma=use_lma, 138 | inplace_safe=inplace_safe, 139 | ) 140 | else: 141 | x = self.mha( 142 | q_x=x, 143 | kv_x=x, 144 | biases=biases, 145 | use_memory_efficient_kernel=use_memory_efficient_kernel, 146 | use_deepspeed_evo_attention=use_deepspeed_evo_attention, 147 | use_lma=use_lma 148 | ) 149 | 150 | if(not self.starting): 151 | x = x.transpose(-2, -3) 152 | 153 | return x 154 | 155 | 156 | # Implements Algorithm 13 157 | TriangleAttentionStartingNode = TriangleAttention 158 | 159 | 160 | class TriangleAttentionEndingNode(TriangleAttention): 161 | """ 162 | Implements Algorithm 14. 163 | """ 164 | __init__ = partialmethod(TriangleAttention.__init__, starting=False) 165 | -------------------------------------------------------------------------------- /sam/openfold/np/relax/cleanup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Cleans up a PDB file using pdbfixer in preparation for OpenMM simulations. 16 | 17 | fix_pdb uses a third-party tool. We also support fixing some additional edge 18 | cases like removing chains of length one (see clean_structure). 19 | """ 20 | import io 21 | 22 | import pdbfixer 23 | from openmm import app 24 | from openmm.app import element 25 | 26 | 27 | def fix_pdb(pdbfile, alterations_info): 28 | """Apply pdbfixer to the contents of a PDB file; return a PDB string result. 29 | 30 | 1) Replaces nonstandard residues. 31 | 2) Removes heterogens (non protein residues) including water. 32 | 3) Adds missing residues and missing atoms within existing residues. 33 | 4) Adds hydrogens assuming pH=7.0. 34 | 5) KeepIds is currently true, so the fixer must keep the existing chain and 35 | residue identifiers. This will fail for some files in wider PDB that have 36 | invalid IDs. 37 | 38 | Args: 39 | pdbfile: Input PDB file handle. 40 | alterations_info: A dict that will store details of changes made. 41 | 42 | Returns: 43 | A PDB string representing the fixed structure. 
44 | """ 45 | fixer = pdbfixer.PDBFixer(pdbfile=pdbfile) 46 | fixer.findNonstandardResidues() 47 | alterations_info["nonstandard_residues"] = fixer.nonstandardResidues 48 | fixer.replaceNonstandardResidues() 49 | _remove_heterogens(fixer, alterations_info, keep_water=False) 50 | fixer.findMissingResidues() 51 | alterations_info["missing_residues"] = fixer.missingResidues 52 | fixer.findMissingAtoms() 53 | alterations_info["missing_heavy_atoms"] = fixer.missingAtoms 54 | alterations_info["missing_terminals"] = fixer.missingTerminals 55 | fixer.addMissingAtoms(seed=0) 56 | fixer.addMissingHydrogens() 57 | out_handle = io.StringIO() 58 | app.PDBFile.writeFile( 59 | fixer.topology, fixer.positions, out_handle, keepIds=True 60 | ) 61 | return out_handle.getvalue() 62 | 63 | 64 | def clean_structure(pdb_structure, alterations_info): 65 | """Applies additional fixes to an OpenMM structure, to handle edge cases. 66 | 67 | Args: 68 | pdb_structure: An OpenMM structure to modify and fix. 69 | alterations_info: A dict that will store details of changes made. 70 | """ 71 | _replace_met_se(pdb_structure, alterations_info) 72 | _remove_chains_of_length_one(pdb_structure, alterations_info) 73 | 74 | 75 | def _remove_heterogens(fixer, alterations_info, keep_water): 76 | """Removes the residues that Pdbfixer considers to be heterogens. 77 | 78 | Args: 79 | fixer: A Pdbfixer instance. 80 | alterations_info: A dict that will store details of changes made. 81 | keep_water: If True, water (HOH) is not considered to be a heterogen. 82 | """ 83 | initial_resnames = set() 84 | for chain in fixer.topology.chains(): 85 | for residue in chain.residues(): 86 | initial_resnames.add(residue.name) 87 | fixer.removeHeterogens(keepWater=keep_water) 88 | final_resnames = set() 89 | for chain in fixer.topology.chains(): 90 | for residue in chain.residues(): 91 | final_resnames.add(residue.name) 92 | alterations_info["removed_heterogens"] = initial_resnames.difference( 93 | final_resnames 94 | ) 95 | 96 | 97 | def _replace_met_se(pdb_structure, alterations_info): 98 | """Replace the Se in any MET residues that were not marked as modified.""" 99 | modified_met_residues = [] 100 | for res in pdb_structure.iter_residues(): 101 | name = res.get_name_with_spaces().strip() 102 | if name == "MET": 103 | s_atom = res.get_atom("SD") 104 | if s_atom.element_symbol == "Se": 105 | s_atom.element_symbol = "S" 106 | s_atom.element = element.get_by_symbol("S") 107 | modified_met_residues.append(s_atom.residue_number) 108 | alterations_info["Se_in_MET"] = modified_met_residues 109 | 110 | 111 | def _remove_chains_of_length_one(pdb_structure, alterations_info): 112 | """Removes chains that correspond to a single amino acid. 113 | 114 | A single amino acid in a chain is both N and C terminus. There is no force 115 | template for this case. 116 | 117 | Args: 118 | pdb_structure: An OpenMM pdb_structure to modify and fix. 119 | alterations_info: A dict that will store details of changes made. 
120 | """ 121 | removed_chains = {} 122 | for model in pdb_structure.iter_models(): 123 | valid_chains = [c for c in model.iter_chains() if len(c) > 1] 124 | invalid_chain_ids = [ 125 | c.chain_id for c in model.iter_chains() if len(c) <= 1 126 | ] 127 | model.chains = valid_chains 128 | for chain_id in invalid_chain_ids: 129 | model.chains_by_id.pop(chain_id) 130 | removed_chains[model.number] = invalid_chain_ids 131 | alterations_info["removed_chains"] = removed_chains 132 | -------------------------------------------------------------------------------- /sam/data/topology.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import mdtraj 4 | from sam.data.sequences import (aa_one_letter, 5 | aa_three_letters, 6 | aa_one_to_three_dict) 7 | 8 | 9 | def get_ca_topology(sequence: str): 10 | topology = mdtraj.Topology() 11 | chain = topology.add_chain() 12 | for res in sequence: 13 | res_obj = topology.add_residue(aa_one_to_three_dict[res], chain) 14 | topology.add_atom("CA", mdtraj.core.topology.elem.carbon, res_obj) 15 | return topology 16 | 17 | 18 | def get_seq_from_top(top: str) -> str: 19 | if isinstance(top, str): 20 | top_traj = mdtraj.load(top) 21 | topology = top_traj.topology 22 | else: 23 | topology = top 24 | seq = [r.code for r in topology.residues if \ 25 | r.name in aa_three_letters] 26 | return "".join(seq) 27 | 28 | 29 | def _check_ca_atom(a, standard=True) -> bool: 30 | if standard: 31 | return a.name == "CA" and a.residue.name in aa_three_letters # a.residue.code in aa_one_letter 32 | else: 33 | return a.name == "CA" 34 | 35 | # def _check_cg_atom(a): 36 | # return a.name in ("CG", "CG2") and a.residue.code in aa_one_letter 37 | 38 | def slice_ca_traj(traj, standard=True): 39 | ca_ids = [a.index for a in traj.topology.atoms \ 40 | if _check_ca_atom(a, standard)] 41 | traj = traj.atom_slice(ca_ids) 42 | return traj 43 | 44 | def slice_traj_to_com(traj, get_xyz=True): 45 | ha_ids = [a.index for a in traj.topology.atoms if \ 46 | a.residue.name in aa_three_letters and \ 47 | a.element.symbol != "H"] 48 | ha_traj = traj.atom_slice(ha_ids) 49 | residues = list(ha_traj.topology.residues) 50 | com_xyz = np.zeros((ha_traj.xyz.shape[0], len(residues), 3)) 51 | for i, residue_i in enumerate(residues): 52 | ha_ids_i = [a.index for a in residue_i.atoms] 53 | masses_i = np.array([a.element.mass for a in residue_i.atoms]) 54 | masses_i = masses_i[None,:,None] 55 | tot_mass_i = masses_i.sum() 56 | com_xyz_i = np.sum(ha_traj.xyz[:,ha_ids_i,:]*masses_i, axis=1)/tot_mass_i 57 | com_xyz[:,i,:] = com_xyz_i 58 | if get_xyz: 59 | return com_xyz 60 | else: 61 | return mdtraj.Trajectory( 62 | xyz=com_xyz, 63 | topology=get_ca_topology( 64 | sequence="".join([r.code for r in ha_traj.topology.residues]) 65 | )) 66 | 67 | 68 | residue_atoms = { 69 | "ALA": ["C", "CA", "CB", "N", "O"], 70 | "ARG": ["C", "CA", "CB", "CG", "CD", "CZ", "N", "NE", "O", "NH1", "NH2"], 71 | "ASP": ["C", "CA", "CB", "CG", "N", "O", "OD1", "OD2"], 72 | "ASN": ["C", "CA", "CB", "CG", "N", "ND2", "O", "OD1"], 73 | "CYS": ["C", "CA", "CB", "N", "O", "SG"], 74 | "GLU": ["C", "CA", "CB", "CG", "CD", "N", "O", "OE1", "OE2"], 75 | "GLN": ["C", "CA", "CB", "CG", "CD", "N", "NE2", "O", "OE1"], 76 | "GLY": ["C", "CA", "N", "O"], 77 | "HIS": ["C", "CA", "CB", "CG", "CD2", "CE1", "N", "ND1", "NE2", "O"], 78 | "ILE": ["C", "CA", "CB", "CG1", "CG2", "CD1", "N", "O"], 79 | "LEU": ["C", "CA", "CB", "CG", "CD1", "CD2", "N", "O"], 80 | "LYS": ["C", "CA", "CB", "CG", 
"CD", "CE", "N", "NZ", "O"], 81 | "MET": ["C", "CA", "CB", "CG", "CE", "N", "O", "SD"], 82 | "PHE": ["C", "CA", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "N", "O"], 83 | "PRO": ["C", "CA", "CB", "CG", "CD", "N", "O"], 84 | "SER": ["C", "CA", "CB", "N", "O", "OG"], 85 | "THR": ["C", "CA", "CB", "CG2", "N", "O", "OG1"], 86 | "TRP": [ 87 | "C", 88 | "CA", 89 | "CB", 90 | "CG", 91 | "CD1", 92 | "CD2", 93 | "CE2", 94 | "CE3", 95 | "CZ2", 96 | "CZ3", 97 | "CH2", 98 | "N", 99 | "NE1", 100 | "O", 101 | ], 102 | "TYR": [ 103 | "C", 104 | "CA", 105 | "CB", 106 | "CG", 107 | "CD1", 108 | "CD2", 109 | "CE1", 110 | "CE2", 111 | "CZ", 112 | "N", 113 | "O", 114 | "OH", 115 | ], 116 | "VAL": ["C", "CA", "CB", "CG1", "CG2", "N", "O"], 117 | } 118 | 119 | atom_types = [ 120 | "N", 121 | "CA", 122 | "C", 123 | "CB", 124 | "O", 125 | "CG", 126 | "CG1", 127 | "CG2", 128 | "OG", 129 | "OG1", 130 | "SG", 131 | "CD", 132 | "CD1", 133 | "CD2", 134 | "ND1", 135 | "ND2", 136 | "OD1", 137 | "OD2", 138 | "SD", 139 | "CE", 140 | "CE1", 141 | "CE2", 142 | "CE3", 143 | "NE", 144 | "NE1", 145 | "NE2", 146 | "OE1", 147 | "OE2", 148 | "CH2", 149 | "NH1", 150 | "NH2", 151 | "OH", 152 | "CZ", 153 | "CZ2", 154 | "CZ3", 155 | "NZ", 156 | "OXT", 157 | ] 158 | 159 | def get_atom14_sam_data(traj): 160 | topology = traj.topology 161 | n_frames = len(traj) 162 | n_residues = topology.n_residues 163 | 164 | atom14_xyz = np.zeros((n_frames, n_residues, 14, 3)) 165 | atom14_mask = np.zeros((n_frames, n_residues, 14)) 166 | for res_idx, res in enumerate(topology.residues): 167 | if res.name not in residue_atoms: 168 | raise KeyError(res.name) 169 | for atom in res.atoms: 170 | if atom.name in residue_atoms[res.name]: 171 | atom14_mask[:,res_idx,residue_atoms[res.name].index(atom.name)] = 1 172 | atom14_xyz[:, res_idx,residue_atoms[res.name].index(atom.name)] = traj.xyz[:,atom.index] 173 | else: 174 | pass 175 | for res_idx in range(n_residues): 176 | if atom14_mask[:,res_idx,1].sum() != n_frames: 177 | raise ValueError(f"No Ca atoms for residue id {res_idx}") 178 | return {"xyz": atom14_xyz, "top": atom14_mask} -------------------------------------------------------------------------------- /scripts/generate_ensemble.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate with SAM conformational ensemble for an input PDB file. 3 | Notes: 4 | On the first time you use this script, weights for the aSAM models will be 5 | automatically downloaded to ~/.sam2/weights. Change the $SAM_WEIGHTS_PATH 6 | environmemtal variable to change the download path. 7 | """ 8 | 9 | import os 10 | import sys 11 | import argparse 12 | import time 13 | import numpy as np 14 | import mdtraj 15 | from sam.model import AllAtomSAM 16 | from sam.utils import read_cfg_file, print_msg, check_sam_weights 17 | from sam.data.topology import get_seq_from_top 18 | from sam.minimizer.runner import Minimizer 19 | 20 | 21 | if __name__ == "__main__": 22 | 23 | parser = argparse.ArgumentParser( 24 | description=__doc__) 25 | parser.add_argument('-c', '--config_fp', type=str, required=True, 26 | help='YAML or JSON configuration file for a SAM generative model.') 27 | parser.add_argument('-i', '--init', type=str, required=True, 28 | help='Input PDB file with the initial structure.') 29 | parser.add_argument('-o', '--out_path', type=str, required=True, 30 | help='Output path. 
File extensions for different file types will be' 31 | ' automatically added.') 32 | parser.add_argument('-u', '--out_fmt', type=str, default='dcd', 33 | choices=['dcd', 'xtc'], 34 | help='Output format for the file storing xyz coordinates.' 35 | ' (default: dcd)') 36 | parser.add_argument('-n', '--n_samples', type=int, default=250, 37 | help='Number of samples to generate. (default: 250)') 38 | parser.add_argument('-t', '--n_steps', type=int, default=100, 39 | help='Number of diffusion steps. (min=1, max=1000) (default: 100)') 40 | parser.add_argument('-b', '--batch_size', type=int, default=8, 41 | help='Batch size for sampling. (default: 8)') 42 | parser.add_argument('-d', '--device', type=str, default='cuda', 43 | choices=['cuda', 'cpu'], help='PyTorch device. (default: cuda)') 44 | parser.add_argument('-T', '--temperature', type=float, 45 | help='temperature (optional, only for temperature-based models)') 46 | parser.add_argument('-q', '--quiet', action='store_true', 47 | help='Quiet mode, will not print any output.') 48 | parser.add_argument('--no_minimize', action='store_true', 49 | help='Do not perform energy minimization.') 50 | parser.add_argument('--keep_no_min', action='store_true', 51 | help='If performing energy minimization, save also a trajectory file' 52 | ' for the non-minimized ensemble.') 53 | parser.add_argument('--ca', action='store_true', 54 | help='Save an additional Ca-only trajectory.') 55 | parser.add_argument('--time', action='store_true', 56 | help='Save an output file with the wall clock time of sampling.') 57 | args = parser.parse_args() 58 | 59 | 60 | #--------------- 61 | # Check input. - 62 | #--------------- 63 | 64 | timing = {"all": time.time(), "sample": None} 65 | 66 | if not os.path.isfile(args.init): 67 | raise FileNotFoundError(args.init) 68 | 69 | check_sam_weights(args.config_fp) 70 | model_cfg = read_cfg_file(args.config_fp) 71 | # check_env(model_cfg) 72 | 73 | tem_traj = mdtraj.load(args.init, top=args.init) 74 | tbm_data = {"xyz": tem_traj.xyz} 75 | if model_cfg["generative_stack"]["data_type"] == "aa_protein": 76 | tbm_data["topology"] = tem_traj.topology 77 | seq = get_seq_from_top(tbm_data["topology"]) 78 | 79 | #----------- 80 | # Run SAM. - 81 | #----------- 82 | 83 | # Initialize the SAM model. 84 | if model_cfg["generative_stack"]["data_type"] == "cg_protein": 85 | raise NotImplementedError() 86 | elif model_cfg["generative_stack"]["data_type"] == "aa_protein": 87 | model_cls = AllAtomSAM 88 | else: 89 | raise KeyError(model_cfg["generative_stack"]["data_type"]) 90 | 91 | model = model_cls( 92 | config_fp=args.config_fp, 93 | device=args.device, 94 | verbose=not args.quiet 95 | ) 96 | 97 | conditions = {} 98 | if args.temperature is not None: 99 | conditions["temperature"] = args.temperature 100 | sample_args = {} 101 | 102 | # Generate ensemble. 103 | timing["sample"] = time.time() 104 | out = model.sample( 105 | seq=seq, 106 | n_samples=args.n_samples, 107 | n_steps=args.n_steps, 108 | batch_size_eps=args.batch_size, 109 | batch_size_dec=args.batch_size, 110 | tbm_data=tbm_data, 111 | return_enc=False, 112 | sample_args=sample_args, 113 | conditions=conditions, 114 | use_cache=True 115 | ) 116 | timing["sample"] = time.time() - timing["sample"] 117 | 118 | # Save the output data. 
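# `model.save` is expected to return a dict including the 'aa_top' and 'aa_traj' file paths consumed by the Minimizer below.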
119 | save = model.save( 120 | out=out, 121 | out_path=args.out_path, 122 | out_fmt=args.out_fmt, 123 | save_ca=args.ca 124 | ) 125 | # tem_traj.save(f"{args.out_path}.template.pdb") 126 | 127 | #------------------------------------- 128 | # Energy minimize the conformations. - 129 | #------------------------------------- 130 | 131 | if args.no_minimize or model_cfg["minimization"]["protocol"] is None: 132 | pass 133 | else: 134 | timing["min"] = time.time() 135 | min_obj = Minimizer( 136 | name="sam_ensemble", 137 | top_fp=save["aa_top"], 138 | ens_fp=save["aa_traj"], 139 | protocol=model_cfg["minimization"]["protocol"] 140 | ) 141 | min_traj = min_obj.run(device=args.device, verbose=not args.quiet) 142 | if args.keep_no_min: 143 | min_out_str = ".min" 144 | else: 145 | min_out_str = "" 146 | min_traj_path = f"{args.out_path}{min_out_str}.traj.{args.out_fmt}" 147 | print_msg( 148 | f"- Saving a trajectory file to: {min_traj_path}", 149 | verbose=not args.quiet, 150 | tag="minimization" 151 | ) 152 | min_traj.save(min_traj_path) 153 | min_top_path = f"{args.out_path}{min_out_str}.top.pdb" 154 | print_msg( 155 | f"- Saving a topology PDB file to: {min_top_path}", 156 | verbose=not args.quiet, 157 | tag="minimization" 158 | ) 159 | min_traj[0].save(min_top_path) 160 | timing["min"] = time.time() - timing["min"] 161 | 162 | #------------ 163 | # Complete. - 164 | #------------ 165 | 166 | timing["all"] = time.time() - timing["all"] 167 | if args.time: 168 | with open(f"{args.out_path}.time.txt", "w") as o_fh: 169 | for stage in timing: 170 | o_fh.write(f"{stage}: {timing[stage]}\n") -------------------------------------------------------------------------------- /sam/openfold/data/input_pipeline_multimer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
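# Multimer variant of the OpenFold input pipeline: it assembles the ground-truth, non-ensembled and ensembled transform stacks that process_tensors_from_config applies below.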
15 | 16 | import random 17 | import torch 18 | 19 | from sam.openfold.data import ( 20 | data_transforms, 21 | data_transforms_multimer, 22 | ) 23 | 24 | 25 | def groundtruth_transforms_fns(): 26 | transforms = [data_transforms.make_atom14_masks, 27 | data_transforms.make_atom14_positions, 28 | data_transforms.atom37_to_frames, 29 | data_transforms.atom37_to_torsion_angles(""), 30 | data_transforms.make_pseudo_beta(""), 31 | data_transforms.get_backbone_frames, 32 | data_transforms.get_chi_angles] 33 | return transforms 34 | 35 | 36 | def nonensembled_transform_fns(): 37 | """Input pipeline data transformers that are not ensembled.""" 38 | transforms = [ 39 | data_transforms.cast_to_64bit_ints, 40 | data_transforms_multimer.make_msa_profile, 41 | data_transforms_multimer.create_target_feat, 42 | data_transforms.make_atom14_masks 43 | ] 44 | 45 | return transforms 46 | 47 | 48 | def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): 49 | """Input pipeline data transformers that can be ensembled and averaged.""" 50 | transforms = [] 51 | 52 | pad_msa_clusters = mode_cfg.max_msa_clusters 53 | max_msa_clusters = pad_msa_clusters 54 | max_extra_msa = mode_cfg.max_extra_msa 55 | 56 | msa_seed = None 57 | if(not common_cfg.resample_msa_in_recycling): 58 | msa_seed = ensemble_seed 59 | 60 | transforms.append( 61 | data_transforms_multimer.sample_msa( 62 | max_msa_clusters, 63 | max_extra_msa, 64 | seed=msa_seed, 65 | ) 66 | ) 67 | 68 | if "masked_msa" in common_cfg: 69 | # Masked MSA should come *before* MSA clustering so that 70 | # the clustering and full MSA profile do not leak information about 71 | # the masked locations and secret corrupted locations. 72 | transforms.append( 73 | data_transforms_multimer.make_masked_msa( 74 | common_cfg.masked_msa, 75 | mode_cfg.masked_msa_replace_fraction, 76 | seed=(msa_seed + 1) if msa_seed else None, 77 | ) 78 | ) 79 | 80 | transforms.append(data_transforms_multimer.nearest_neighbor_clusters()) 81 | transforms.append(data_transforms_multimer.create_msa_feat) 82 | 83 | crop_feats = dict(common_cfg.feat) 84 | 85 | if mode_cfg.fixed_size: 86 | transforms.append(data_transforms.select_feat(list(crop_feats))) 87 | 88 | if mode_cfg.crop: 89 | transforms.append( 90 | data_transforms_multimer.random_crop_to_size( 91 | crop_size=mode_cfg.crop_size, 92 | max_templates=mode_cfg.max_templates, 93 | shape_schema=crop_feats, 94 | spatial_crop_prob=mode_cfg.spatial_crop_prob, 95 | interface_threshold=mode_cfg.interface_threshold, 96 | subsample_templates=mode_cfg.subsample_templates, 97 | seed=ensemble_seed + 1, 98 | ) 99 | ) 100 | transforms.append( 101 | data_transforms.make_fixed_size( 102 | shape_schema=crop_feats, 103 | msa_cluster_size=pad_msa_clusters, 104 | extra_msa_size=mode_cfg.max_extra_msa, 105 | num_res=mode_cfg.crop_size, 106 | num_templates=mode_cfg.max_templates, 107 | ) 108 | ) 109 | else: 110 | transforms.append( 111 | data_transforms.crop_templates(mode_cfg.max_templates) 112 | ) 113 | 114 | return transforms 115 | 116 | 117 | def prepare_ground_truth_features(tensors): 118 | """Prepare ground truth features that are only needed for loss calculation during training""" 119 | 120 | gt_features = ['all_atom_mask', 'all_atom_positions', 'asym_id', 'sym_id', 'entity_id'] 121 | gt_tensors = {k: v for k, v in tensors.items() if k in gt_features} 122 | gt_tensors['aatype'] = tensors['aatype'].to(torch.long) 123 | gt_tensors = compose(groundtruth_transforms_fns())(gt_tensors) 124 | return gt_tensors 125 | 126 | 127 | def 
process_tensors_from_config(tensors, common_cfg, mode_cfg): 128 | """Based on the config, apply filters and transformations to the data.""" 129 | 130 | process_gt_feats = mode_cfg.supervised 131 | gt_tensors = {} 132 | if process_gt_feats: 133 | gt_tensors = prepare_ground_truth_features(tensors) 134 | 135 | ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max) 136 | tensors['aatype'] = tensors['aatype'].to(torch.long) 137 | nonensembled = nonensembled_transform_fns() 138 | tensors = compose(nonensembled)(tensors) 139 | if("no_recycling_iters" in tensors): 140 | num_recycling = int(tensors["no_recycling_iters"]) 141 | else: 142 | num_recycling = common_cfg.max_recycling_iters 143 | 144 | def wrap_ensemble_fn(data, i): 145 | """Function to be mapped over the ensemble dimension.""" 146 | d = data.copy() 147 | fns = ensembled_transform_fns( 148 | common_cfg, 149 | mode_cfg, 150 | ensemble_seed, 151 | ) 152 | fn = compose(fns) 153 | d["ensemble_index"] = i 154 | return fn(d) 155 | 156 | tensors = map_fn( 157 | lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1) 158 | ) 159 | 160 | if process_gt_feats: 161 | tensors['gt_features'] = gt_tensors 162 | 163 | return tensors 164 | 165 | @data_transforms.curry1 166 | def compose(x, fs): 167 | for f in fs: 168 | x = f(x) 169 | return x 170 | 171 | 172 | def map_fn(fun, x): 173 | ensembles = [fun(elem) for elem in x] 174 | features = ensembles[0].keys() 175 | ensembled_dict = {} 176 | for feat in features: 177 | ensembled_dict[feat] = torch.stack( 178 | [dict_i[feat] for dict_i in ensembles], dim=-1 179 | ) 180 | return ensembled_dict 181 | -------------------------------------------------------------------------------- /sam/openfold/utils/geometry/rigid_matrix_vector.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Rigid3Array Transformations represented by a Matrix and a Vector.""" 15 | 16 | from __future__ import annotations 17 | import dataclasses 18 | from typing import Union, List 19 | 20 | import torch 21 | 22 | from sam.openfold.utils.geometry import rotation_matrix 23 | from sam.openfold.utils.geometry import vector 24 | 25 | 26 | Float = Union[float, torch.Tensor] 27 | 28 | 29 | @dataclasses.dataclass(frozen=True) 30 | class Rigid3Array: 31 | """Rigid Transformation, i.e. 
element of special euclidean group.""" 32 | 33 | rotation: rotation_matrix.Rot3Array 34 | translation: vector.Vec3Array 35 | 36 | def __matmul__(self, other: Rigid3Array) -> Rigid3Array: 37 | new_rotation = self.rotation @ other.rotation # __matmul__ 38 | new_translation = self.apply_to_point(other.translation) 39 | return Rigid3Array(new_rotation, new_translation) 40 | 41 | def __getitem__(self, index) -> Rigid3Array: 42 | return Rigid3Array( 43 | self.rotation[index], 44 | self.translation[index], 45 | ) 46 | 47 | def __mul__(self, other: torch.Tensor) -> Rigid3Array: 48 | return Rigid3Array( 49 | self.rotation * other, 50 | self.translation * other, 51 | ) 52 | 53 | def map_tensor_fn(self, fn) -> Rigid3Array: 54 | return Rigid3Array( 55 | self.rotation.map_tensor_fn(fn), 56 | self.translation.map_tensor_fn(fn), 57 | ) 58 | 59 | def inverse(self) -> Rigid3Array: 60 | """Return Rigid3Array corresponding to inverse transform.""" 61 | inv_rotation = self.rotation.inverse() 62 | inv_translation = inv_rotation.apply_to_point(-self.translation) 63 | return Rigid3Array(inv_rotation, inv_translation) 64 | 65 | def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: 66 | """Apply Rigid3Array transform to point.""" 67 | return self.rotation.apply_to_point(point) + self.translation 68 | 69 | def apply(self, point: torch.Tensor) -> torch.Tensor: 70 | return self.apply_to_point(vector.Vec3Array.from_array(point)).to_tensor() 71 | 72 | def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: 73 | """Apply inverse Rigid3Array transform to point.""" 74 | new_point = point - self.translation 75 | return self.rotation.apply_inverse_to_point(new_point) 76 | 77 | def invert_apply(self, point: torch.Tensor) -> torch.Tensor: 78 | return self.apply_inverse_to_point(vector.Vec3Array.from_array(point)).to_tensor() 79 | 80 | def compose_rotation(self, other_rotation): 81 | rot = self.rotation @ other_rotation 82 | return Rigid3Array(rot, self.translation.clone()) 83 | 84 | def compose(self, other_rigid): 85 | return self @ other_rigid 86 | 87 | def unsqueeze(self, dim: int): 88 | return Rigid3Array( 89 | self.rotation.unsqueeze(dim), 90 | self.translation.unsqueeze(dim), 91 | ) 92 | 93 | @property 94 | def shape(self) -> torch.Size: 95 | return self.rotation.xx.shape 96 | 97 | @property 98 | def dtype(self) -> torch.dtype: 99 | return self.rotation.xx.dtype 100 | 101 | @property 102 | def device(self) -> torch.device: 103 | return self.rotation.xx.device 104 | 105 | @classmethod 106 | def identity(cls, shape, device) -> Rigid3Array: 107 | """Return identity Rigid3Array of given shape.""" 108 | return cls( 109 | rotation_matrix.Rot3Array.identity(shape, device), 110 | vector.Vec3Array.zeros(shape, device) 111 | ) 112 | 113 | @classmethod 114 | def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array: 115 | return cls( 116 | rotation_matrix.Rot3Array.cat( 117 | [r.rotation for r in rigids], dim=dim 118 | ), 119 | vector.Vec3Array.cat( 120 | [r.translation for r in rigids], dim=dim 121 | ), 122 | ) 123 | 124 | def scale_translation(self, factor: Float) -> Rigid3Array: 125 | """Scale translation in Rigid3Array by 'factor'.""" 126 | return Rigid3Array(self.rotation, self.translation * factor) 127 | 128 | def to_tensor(self) -> torch.Tensor: 129 | rot_array = self.rotation.to_tensor() 130 | vec_array = self.translation.to_tensor() 131 | array = torch.zeros( 132 | rot_array.shape[:-2] + (4, 4), 133 | device=rot_array.device, 134 | dtype=rot_array.dtype 135 | ) 136 | 
array[..., :3, :3] = rot_array 137 | array[..., :3, 3] = vec_array 138 | array[..., 3, 3] = 1. 139 | return array 140 | 141 | def to_tensor_4x4(self) -> torch.Tensor: 142 | return self.to_tensor() 143 | 144 | def reshape(self, new_shape) -> Rigid3Array: 145 | rots = self.rotation.reshape(new_shape) 146 | trans = self.translation.reshape(new_shape) 147 | return Rigid3Array(rots, trans) 148 | 149 | def stop_rot_gradient(self) -> Rigid3Array: 150 | return Rigid3Array( 151 | self.rotation.stop_gradient(), 152 | self.translation, 153 | ) 154 | 155 | @classmethod 156 | def from_array(cls, array): 157 | rot = rotation_matrix.Rot3Array.from_array( 158 | array[..., :3, :3], 159 | ) 160 | vec = vector.Vec3Array.from_array(array[..., :3, 3]) 161 | return cls(rot, vec) 162 | 163 | @classmethod 164 | def from_tensor_4x4(cls, array): 165 | return cls.from_array(array) 166 | 167 | @classmethod 168 | def from_array4x4(cls, array: torch.Tensor) -> Rigid3Array: 169 | """Construct Rigid3Array from homogeneous 4x4 array.""" 170 | rotation = rotation_matrix.Rot3Array( 171 | array[..., 0, 0], array[..., 0, 1], array[..., 0, 2], 172 | array[..., 1, 0], array[..., 1, 1], array[..., 1, 2], 173 | array[..., 2, 0], array[..., 2, 1], array[..., 2, 2] 174 | ) 175 | translation = vector.Vec3Array( 176 | array[..., 0, 3], array[..., 1, 3], array[..., 2, 3] 177 | ) 178 | return cls(rotation, translation) 179 | 180 | def cuda(self) -> Rigid3Array: 181 | return Rigid3Array.from_tensor_4x4(self.to_tensor_4x4().cuda()) 182 | -------------------------------------------------------------------------------- /sam/openfold/data/tools/hhblits.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Library to run HHblits from Python.""" 17 | import glob 18 | import logging 19 | import os 20 | import subprocess 21 | from typing import Any, List, Mapping, Optional, Sequence 22 | 23 | from sam.openfold.data.tools import utils 24 | 25 | 26 | _HHBLITS_DEFAULT_P = 20 27 | _HHBLITS_DEFAULT_Z = 500 28 | 29 | 30 | class HHBlits: 31 | """Python wrapper of the HHblits binary.""" 32 | 33 | def __init__( 34 | self, 35 | *, 36 | binary_path: str, 37 | databases: Sequence[str], 38 | n_cpu: int = 4, 39 | n_iter: int = 3, 40 | e_value: float = 0.001, 41 | maxseq: int = 1_000_000, 42 | realign_max: int = 100_000, 43 | maxfilt: int = 100_000, 44 | min_prefilter_hits: int = 1000, 45 | all_seqs: bool = False, 46 | alt: Optional[int] = None, 47 | p: int = _HHBLITS_DEFAULT_P, 48 | z: int = _HHBLITS_DEFAULT_Z, 49 | ): 50 | """Initializes the Python HHblits wrapper. 51 | 52 | Args: 53 | binary_path: The path to the HHblits executable. 54 | databases: A sequence of HHblits database paths. This should be the 55 | common prefix for the database files (i.e. up to but not including 56 | _hhm.ffindex etc.)
57 | n_cpu: The number of CPUs to give HHblits. 58 | n_iter: The number of HHblits iterations. 59 | e_value: The E-value, see HHblits docs for more details. 60 | maxseq: The maximum number of rows in an input alignment. Note that this 61 | parameter is only supported in HHBlits version 3.1 and higher. 62 | realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. 63 | maxfilt: Max number of hits allowed to pass the 2nd prefilter. 64 | HHblits default: 20000. 65 | min_prefilter_hits: Min number of hits to pass prefilter. 66 | HHblits default: 100. 67 | all_seqs: Return all sequences in the MSA / Do not filter the result MSA. 68 | HHblits default: False. 69 | alt: Show up to this many alternative alignments. 70 | p: Minimum Prob for a hit to be included in the output hhr file. 71 | HHblits default: 20. 72 | z: Hard cap on number of hits reported in the hhr file. 73 | HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. 74 | 75 | Raises: 76 | RuntimeError: If HHblits binary not found within the path. 77 | """ 78 | self.binary_path = binary_path 79 | self.databases = databases 80 | 81 | for database_path in self.databases: 82 | if not glob.glob(database_path + "_*"): 83 | logging.error( 84 | "Could not find HHBlits database %s", database_path 85 | ) 86 | raise ValueError( 87 | f"Could not find HHBlits database {database_path}" 88 | ) 89 | 90 | self.n_cpu = n_cpu 91 | self.n_iter = n_iter 92 | self.e_value = e_value 93 | self.maxseq = maxseq 94 | self.realign_max = realign_max 95 | self.maxfilt = maxfilt 96 | self.min_prefilter_hits = min_prefilter_hits 97 | self.all_seqs = all_seqs 98 | self.alt = alt 99 | self.p = p 100 | self.z = z 101 | 102 | def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]: 103 | """Queries the database using HHblits.""" 104 | with utils.tmpdir_manager() as query_tmp_dir: 105 | a3m_path = os.path.join(query_tmp_dir, "output.a3m") 106 | 107 | db_cmd = [] 108 | for db_path in self.databases: 109 | db_cmd.append("-d") 110 | db_cmd.append(db_path) 111 | cmd = [ 112 | self.binary_path, 113 | "-i", 114 | input_fasta_path, 115 | "-cpu", 116 | str(self.n_cpu), 117 | "-oa3m", 118 | a3m_path, 119 | "-o", 120 | "/dev/null", 121 | "-n", 122 | str(self.n_iter), 123 | "-e", 124 | str(self.e_value), 125 | "-maxseq", 126 | str(self.maxseq), 127 | "-realign_max", 128 | str(self.realign_max), 129 | "-maxfilt", 130 | str(self.maxfilt), 131 | "-min_prefilter_hits", 132 | str(self.min_prefilter_hits), 133 | ] 134 | if self.all_seqs: 135 | cmd += ["-all"] 136 | if self.alt: 137 | cmd += ["-alt", str(self.alt)] 138 | if self.p != _HHBLITS_DEFAULT_P: 139 | cmd += ["-p", str(self.p)] 140 | if self.z != _HHBLITS_DEFAULT_Z: 141 | cmd += ["-Z", str(self.z)] 142 | cmd += db_cmd 143 | 144 | logging.info('Launching subprocess "%s"', " ".join(cmd)) 145 | process = subprocess.Popen( 146 | cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE 147 | ) 148 | 149 | with utils.timing("HHblits query"): 150 | stdout, stderr = process.communicate() 151 | retcode = process.wait() 152 | 153 | if retcode: 154 | # Logs have a 15k character limit, so log HHblits error line by line. 155 | logging.error("HHblits failed. 
HHblits stderr begin:") 156 | for error_line in stderr.decode("utf-8").splitlines(): 157 | if error_line.strip(): 158 | logging.error(error_line.strip()) 159 | logging.error("HHblits stderr end") 160 | raise RuntimeError( 161 | "HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n" 162 | % (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8")) 163 | ) 164 | 165 | with open(a3m_path) as f: 166 | a3m = f.read() 167 | 168 | raw_output = dict( 169 | a3m=a3m, 170 | output=stdout, 171 | stderr=stderr, 172 | n_iter=self.n_iter, 173 | e_value=self.e_value, 174 | ) 175 | return [raw_output] 176 | -------------------------------------------------------------------------------- /sam/openfold/model/torchscript.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional, Sequence, Tuple 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from sam.openfold.model.dropout import ( 21 | DropoutRowwise, 22 | DropoutColumnwise, 23 | ) 24 | from sam.openfold.model.evoformer import ( 25 | EvoformerBlock, 26 | EvoformerStack, 27 | ) 28 | from sam.openfold.model.outer_product_mean import OuterProductMean 29 | from sam.openfold.model.msa import ( 30 | MSARowAttentionWithPairBias, 31 | MSAColumnAttention, 32 | MSAColumnGlobalAttention, 33 | ) 34 | from sam.openfold.model.pair_transition import PairTransition 35 | from sam.openfold.model.primitives import Attention, GlobalAttention 36 | from sam.openfold.model.structure_module import ( 37 | InvariantPointAttention, 38 | BackboneUpdate, 39 | ) 40 | from sam.openfold.model.template import TemplatePairStackBlock 41 | from sam.openfold.model.triangular_attention import ( 42 | TriangleAttentionStartingNode, 43 | TriangleAttentionEndingNode, 44 | ) 45 | from sam.openfold.model.triangular_multiplicative_update import ( 46 | TriangleMultiplicationOutgoing, 47 | TriangleMultiplicationIncoming, 48 | ) 49 | 50 | 51 | def script_preset_(model: torch.nn.Module): 52 | """ 53 | TorchScript a handful of low-level but frequently used submodule types 54 | that are known to be scriptable. 55 | 56 | Args: 57 | model: 58 | A torch.nn.Module. It should contain at least some modules from 59 | this repository, or this function won't do anything. 
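    Submodules are replaced in place; the function returns None.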
60 | """ 61 | script_submodules_( 62 | model, 63 | [ 64 | nn.Dropout, 65 | Attention, 66 | GlobalAttention, 67 | EvoformerBlock, 68 | #TemplatePairStackBlock, 69 | ], 70 | attempt_trace=False, 71 | batch_dims=None, 72 | ) 73 | 74 | 75 | def _get_module_device(module: torch.nn.Module) -> torch.device: 76 | """ 77 | Fetches the device of a module, assuming that all of the module's 78 | parameters reside on a single device 79 | 80 | Args: 81 | module: A torch.nn.Module 82 | Returns: 83 | The module's device 84 | """ 85 | return next(module.parameters()).device 86 | 87 | 88 | def _trace_module(module, batch_dims=None): 89 | if(batch_dims is None): 90 | batch_dims = () 91 | 92 | # Stand-in values 93 | n_seq = 10 94 | n_res = 10 95 | 96 | device = _get_module_device(module) 97 | 98 | def msa(channel_dim): 99 | return torch.rand( 100 | (*batch_dims, n_seq, n_res, channel_dim), 101 | device=device, 102 | ) 103 | 104 | def pair(channel_dim): 105 | return torch.rand( 106 | (*batch_dims, n_res, n_res, channel_dim), 107 | device=device, 108 | ) 109 | 110 | if(isinstance(module, MSARowAttentionWithPairBias)): 111 | inputs = { 112 | "forward": ( 113 | msa(module.c_in), # m 114 | pair(module.c_z), # z 115 | torch.randint( 116 | 0, 2, 117 | (*batch_dims, n_seq, n_res) 118 | ), # mask 119 | ), 120 | } 121 | elif(isinstance(module, MSAColumnAttention)): 122 | inputs = { 123 | "forward": ( 124 | msa(module.c_in), # m 125 | torch.randint( 126 | 0, 2, 127 | (*batch_dims, n_seq, n_res) 128 | ), # mask 129 | ), 130 | } 131 | elif(isinstance(module, OuterProductMean)): 132 | inputs = { 133 | "forward": ( 134 | msa(module.c_m), 135 | torch.randint( 136 | 0, 2, 137 | (*batch_dims, n_seq, n_res) 138 | ) 139 | ) 140 | } 141 | else: 142 | raise TypeError( 143 | f"tracing is not supported for modules of type {type(module)}" 144 | ) 145 | 146 | return torch.jit.trace_module(module, inputs) 147 | 148 | 149 | def _script_submodules_helper_( 150 | model, 151 | types, 152 | attempt_trace, 153 | to_trace, 154 | ): 155 | for name, child in model.named_children(): 156 | if(types is None or any(isinstance(child, t) for t in types)): 157 | try: 158 | scripted = torch.jit.script(child) 159 | setattr(model, name, scripted) 160 | continue 161 | except (RuntimeError, torch.jit.frontend.NotSupportedError) as e: 162 | if(attempt_trace): 163 | to_trace.add(type(child)) 164 | else: 165 | raise e 166 | 167 | _script_submodules_helper_(child, types, attempt_trace, to_trace) 168 | 169 | 170 | def _trace_submodules_( 171 | model, 172 | types, 173 | batch_dims=None, 174 | ): 175 | for name, child in model.named_children(): 176 | if(any(isinstance(child, t) for t in types)): 177 | traced = _trace_module(child, batch_dims=batch_dims) 178 | setattr(model, name, traced) 179 | else: 180 | _trace_submodules_(child, types, batch_dims=batch_dims) 181 | 182 | 183 | def script_submodules_( 184 | model: nn.Module, 185 | types: Optional[Sequence[type]] = None, 186 | attempt_trace: Optional[bool] = True, 187 | batch_dims: Optional[Tuple[int]] = None, 188 | ): 189 | """ 190 | Convert all submodules whose types match one of those in the input 191 | list to recursively scripted equivalents in place. To script the entire 192 | model, just call torch.jit.script on it directly. 193 | 194 | When types is None, all submodules are scripted. 195 | 196 | Args: 197 | model: 198 | A torch.nn.Module 199 | types: 200 | A list of types of submodules to script 201 | attempt_trace: 202 | Whether to attempt to trace specified modules if scripting 203 | fails. 
Recall that tracing eliminates all conditional 204 | logic---with great tracing comes the mild responsibility of 205 | having to remember to ensure that the modules in question 206 | perform the same computations no matter what. 207 | """ 208 | to_trace = set() 209 | 210 | # Aggressively script as much as possible first... 211 | _script_submodules_helper_(model, types, attempt_trace, to_trace) 212 | 213 | # ... and then trace stragglers. 214 | if(attempt_trace and len(to_trace) > 0): 215 | _trace_submodules_(model, to_trace, batch_dims=batch_dims) 216 | -------------------------------------------------------------------------------- /sam/openfold/data/input_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 AlQuraishi Laboratory 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import random 17 | 18 | import torch 19 | 20 | from sam.openfold.data import data_transforms 21 | 22 | 23 | def nonensembled_transform_fns(common_cfg, mode_cfg): 24 | """Input pipeline data transformers that are not ensembled.""" 25 | transforms = [ 26 | data_transforms.cast_to_64bit_ints, 27 | data_transforms.correct_msa_restypes, 28 | data_transforms.squeeze_features, 29 | data_transforms.randomly_replace_msa_with_unknown(0.0), 30 | data_transforms.make_seq_mask, 31 | data_transforms.make_msa_mask, 32 | data_transforms.make_hhblits_profile, 33 | ] 34 | if common_cfg.use_templates: 35 | transforms.extend( 36 | [ 37 | data_transforms.fix_templates_aatype, 38 | data_transforms.make_template_mask, 39 | data_transforms.make_pseudo_beta("template_"), 40 | ] 41 | ) 42 | if common_cfg.use_template_torsion_angles: 43 | transforms.extend( 44 | [ 45 | data_transforms.atom37_to_torsion_angles("template_"), 46 | ] 47 | ) 48 | 49 | transforms.extend( 50 | [ 51 | data_transforms.make_atom14_masks, 52 | ] 53 | ) 54 | 55 | if mode_cfg.supervised: 56 | transforms.extend( 57 | [ 58 | data_transforms.make_atom14_positions, 59 | data_transforms.atom37_to_frames, 60 | data_transforms.atom37_to_torsion_angles(""), 61 | data_transforms.make_pseudo_beta(""), 62 | data_transforms.get_backbone_frames, 63 | data_transforms.get_chi_angles, 64 | ] 65 | ) 66 | 67 | return transforms 68 | 69 | 70 | def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): 71 | """Input pipeline data transformers that can be ensembled and averaged.""" 72 | transforms = [] 73 | 74 | if mode_cfg.block_delete_msa: 75 | transforms.append(data_transforms.block_delete_msa(common_cfg.block_delete_msa)) 76 | 77 | if "max_distillation_msa_clusters" in mode_cfg: 78 | transforms.append( 79 | data_transforms.sample_msa_distillation( 80 | mode_cfg.max_distillation_msa_clusters 81 | ) 82 | ) 83 | 84 | if common_cfg.reduce_msa_clusters_by_max_templates: 85 | pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates 86 | else: 87 | pad_msa_clusters = mode_cfg.max_msa_clusters 88 | 89 | max_msa_clusters 
= pad_msa_clusters 90 | max_extra_msa = mode_cfg.max_extra_msa 91 | 92 | msa_seed = None 93 | if(not common_cfg.resample_msa_in_recycling): 94 | msa_seed = ensemble_seed 95 | 96 | transforms.append( 97 | data_transforms.sample_msa( 98 | max_msa_clusters, 99 | keep_extra=True, 100 | seed=msa_seed, 101 | ) 102 | ) 103 | 104 | if "masked_msa" in common_cfg: 105 | # Masked MSA should come *before* MSA clustering so that 106 | # the clustering and full MSA profile do not leak information about 107 | # the masked locations and secret corrupted locations. 108 | transforms.append( 109 | data_transforms.make_masked_msa( 110 | common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction, 111 | seed=(msa_seed + 1) if msa_seed else None, 112 | ) 113 | ) 114 | 115 | if common_cfg.msa_cluster_features: 116 | transforms.append(data_transforms.nearest_neighbor_clusters()) 117 | transforms.append(data_transforms.summarize_clusters()) 118 | 119 | # Crop after creating the cluster profiles. 120 | if max_extra_msa: 121 | transforms.append(data_transforms.crop_extra_msa(max_extra_msa)) 122 | else: 123 | transforms.append(data_transforms.delete_extra_msa) 124 | 125 | transforms.append(data_transforms.make_msa_feat()) 126 | 127 | crop_feats = dict(common_cfg.feat) 128 | 129 | if mode_cfg.fixed_size: 130 | transforms.append(data_transforms.select_feat(list(crop_feats))) 131 | transforms.append( 132 | data_transforms.random_crop_to_size( 133 | mode_cfg.crop_size, 134 | mode_cfg.max_templates, 135 | crop_feats, 136 | mode_cfg.subsample_templates, 137 | seed=ensemble_seed + 1, 138 | ) 139 | ) 140 | transforms.append( 141 | data_transforms.make_fixed_size( 142 | crop_feats, 143 | pad_msa_clusters, 144 | mode_cfg.max_extra_msa, 145 | mode_cfg.crop_size, 146 | mode_cfg.max_templates, 147 | ) 148 | ) 149 | else: 150 | transforms.append( 151 | data_transforms.crop_templates(mode_cfg.max_templates) 152 | ) 153 | 154 | return transforms 155 | 156 | 157 | def process_tensors_from_config(tensors, common_cfg, mode_cfg): 158 | """Based on the config, apply filters and transformations to the data.""" 159 | 160 | ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max) 161 | 162 | def wrap_ensemble_fn(data, i): 163 | """Function to be mapped over the ensemble dimension.""" 164 | d = data.copy() 165 | fns = ensembled_transform_fns( 166 | common_cfg, 167 | mode_cfg, 168 | ensemble_seed, 169 | ) 170 | fn = compose(fns) 171 | d["ensemble_index"] = i 172 | return fn(d) 173 | 174 | no_templates = True 175 | if("template_aatype" in tensors): 176 | no_templates = tensors["template_aatype"].shape[0] == 0 177 | 178 | nonensembled = nonensembled_transform_fns( 179 | common_cfg, 180 | mode_cfg, 181 | ) 182 | 183 | tensors = compose(nonensembled)(tensors) 184 | 185 | if("no_recycling_iters" in tensors): 186 | num_recycling = int(tensors["no_recycling_iters"]) 187 | else: 188 | num_recycling = common_cfg.max_recycling_iters 189 | 190 | tensors = map_fn( 191 | lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1) 192 | ) 193 | 194 | return tensors 195 | 196 | 197 | @data_transforms.curry1 198 | def compose(x, fs): 199 | for f in fs: 200 | x = f(x) 201 | return x 202 | 203 | 204 | def map_fn(fun, x): 205 | ensembles = [fun(elem) for elem in x] 206 | features = ensembles[0].keys() 207 | ensembled_dict = {} 208 | for feat in features: 209 | ensembled_dict[feat] = torch.stack( 210 | [dict_i[feat] for dict_i in ensembles], dim=-1 211 | ) 212 | return ensembled_dict 213 | 
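# Usage sketch (illustrative only; `raw_feats` is a hypothetical feature dict
# and the configs are assumed to be ml_collections-style objects exposing the
# fields read above, e.g. `max_recycling_iters` and `supervised`):
#
#   feats = process_tensors_from_config(raw_feats, common_cfg, mode_cfg)
#
# The nonensembled transforms run once, while the ensembled transforms are
# mapped over `num_recycling + 1` iterations, so `map_fn` leaves every output
# feature with a trailing recycling dimension (stacked with dim=-1 above).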
-------------------------------------------------------------------------------- /data/splits/mdcath/test.txt: -------------------------------------------------------------------------------- 1 | 1cb8A03.320 2 | 1cb8A03.348 3 | 1cb8A03.379 4 | 1cb8A03.413 5 | 1cb8A03.450 6 | 1cjyA01.320 7 | 1cjyA01.348 8 | 1cjyA01.379 9 | 1cjyA01.413 10 | 1cjyA01.450 11 | 1cp2A00.320 12 | 1cp2A00.348 13 | 1cp2A00.379 14 | 1cp2A00.413 15 | 1cp2A00.450 16 | 1djqA01.320 17 | 1djqA01.348 18 | 1djqA01.379 19 | 1djqA01.413 20 | 1djqA01.450 21 | 1e20A00.320 22 | 1e20A00.348 23 | 1e20A00.379 24 | 1e20A00.413 25 | 1e20A00.450 26 | 1ei6A02.320 27 | 1ei6A02.348 28 | 1ei6A02.379 29 | 1ei6A02.413 30 | 1ei6A02.450 31 | 1fb1A01.320 32 | 1fb1A01.348 33 | 1fb1A01.379 34 | 1fb1A01.413 35 | 1fb1A01.450 36 | 1ffsB00.320 37 | 1ffsB00.348 38 | 1ffsB00.379 39 | 1ffsB00.413 40 | 1ffsB00.450 41 | 1gaxA04.320 42 | 1gaxA04.348 43 | 1gaxA04.379 44 | 1gaxA04.413 45 | 1gaxA04.450 46 | 1gyxA01.320 47 | 1gyxA01.348 48 | 1gyxA01.379 49 | 1gyxA01.413 50 | 1gyxA01.450 51 | 1hkvA02.320 52 | 1hkvA02.348 53 | 1hkvA02.379 54 | 1hkvA02.413 55 | 1hkvA02.450 56 | 1i2tA00.320 57 | 1i2tA00.348 58 | 1i2tA00.379 59 | 1i2tA00.413 60 | 1i2tA00.450 61 | 1ileA03.320 62 | 1ileA03.348 63 | 1ileA03.379 64 | 1ileA03.413 65 | 1ileA03.450 66 | 1n7oA03.320 67 | 1n7oA03.348 68 | 1n7oA03.379 69 | 1n7oA03.413 70 | 1n7oA03.450 71 | 1ngmJ00.320 72 | 1ngmJ00.348 73 | 1ngmJ00.379 74 | 1ngmJ00.413 75 | 1ngmJ00.450 76 | 1nhwA00.320 77 | 1nhwA00.348 78 | 1nhwA00.379 79 | 1nhwA00.413 80 | 1nhwA00.450 81 | 1nijA01.320 82 | 1nijA01.348 83 | 1nijA01.379 84 | 1nijA01.413 85 | 1nijA01.450 86 | 1np7B02.320 87 | 1np7B02.348 88 | 1np7B02.379 89 | 1np7B02.413 90 | 1np7B02.450 91 | 1or7C00.320 92 | 1or7C00.348 93 | 1or7C00.379 94 | 1or7C00.413 95 | 1or7C00.450 96 | 1p90A00.320 97 | 1p90A00.348 98 | 1p90A00.379 99 | 1p90A00.413 100 | 1p90A00.450 101 | 1pseA00.320 102 | 1pseA00.348 103 | 1pseA00.379 104 | 1pseA00.413 105 | 1pseA00.450 106 | 1pwuA04.320 107 | 1pwuA04.348 108 | 1pwuA04.379 109 | 1pwuA04.413 110 | 1pwuA04.450 111 | 1pyaA00.320 112 | 1pyaA00.348 113 | 1pyaA00.379 114 | 1pyaA00.413 115 | 1pyaA00.450 116 | 1qsaA01.320 117 | 1qsaA01.348 118 | 1qsaA01.379 119 | 1qsaA01.413 120 | 1qsaA01.450 121 | 1qwjB00.320 122 | 1qwjB00.348 123 | 1qwjB00.379 124 | 1qwjB00.413 125 | 1qwjB00.450 126 | 1qzgA00.320 127 | 1qzgA00.348 128 | 1qzgA00.379 129 | 1qzgA00.413 130 | 1qzgA00.450 131 | 1r17A02.320 132 | 1r17A02.348 133 | 1r17A02.379 134 | 1r17A02.413 135 | 1r17A02.450 136 | 1rocA00.320 137 | 1rocA00.348 138 | 1rocA00.379 139 | 1rocA00.413 140 | 1rocA00.450 141 | 1sznA02.320 142 | 1sznA02.348 143 | 1sznA02.379 144 | 1sznA02.413 145 | 1sznA02.450 146 | 1tjoB00.320 147 | 1tjoB00.348 148 | 1tjoB00.379 149 | 1tjoB00.413 150 | 1tjoB00.450 151 | 1ugoA00.320 152 | 1ugoA00.348 153 | 1ugoA00.379 154 | 1ugoA00.413 155 | 1ugoA00.450 156 | 1vb3A01.320 157 | 1vb3A01.348 158 | 1vb3A01.379 159 | 1vb3A01.413 160 | 1vb3A01.450 161 | 1vf5B02.320 162 | 1vf5B02.348 163 | 1vf5B02.379 164 | 1vf5B02.413 165 | 1vf5B02.450 166 | 1wexA01.320 167 | 1wexA01.348 168 | 1wexA01.379 169 | 1wexA01.413 170 | 1wexA01.450 171 | 1zxfA00.320 172 | 1zxfA00.348 173 | 1zxfA00.379 174 | 1zxfA00.413 175 | 1zxfA00.450 176 | 2a06B02.320 177 | 2a06B02.348 178 | 2a06B02.379 179 | 2a06B02.413 180 | 2a06B02.450 181 | 2de3A02.320 182 | 2de3A02.348 183 | 2de3A02.379 184 | 2de3A02.413 185 | 2de3A02.450 186 | 2dgmA02.320 187 | 2dgmA02.348 188 | 2dgmA02.379 189 | 2dgmA02.413 190 | 2dgmA02.450 191 | 2dixA01.320 192 | 2dixA01.348 193 | 
2dixA01.379 194 | 2dixA01.413 195 | 2dixA01.450 196 | 2e0nB02.320 197 | 2e0nB02.348 198 | 2e0nB02.379 199 | 2e0nB02.413 200 | 2e0nB02.450 201 | 2e3tB03.320 202 | 2e3tB03.348 203 | 2e3tB03.379 204 | 2e3tB03.413 205 | 2e3tB03.450 206 | 2e63A00.320 207 | 2e63A00.348 208 | 2e63A00.379 209 | 2e63A00.413 210 | 2e63A00.450 211 | 2fm7A00.320 212 | 2fm7A00.348 213 | 2fm7A00.379 214 | 2fm7A00.413 215 | 2fm7A00.450 216 | 2ga1A02.320 217 | 2ga1A02.348 218 | 2ga1A02.379 219 | 2ga1A02.413 220 | 2ga1A02.450 221 | 2kjrA01.320 222 | 2kjrA01.348 223 | 2kjrA01.379 224 | 2kjrA01.413 225 | 2kjrA01.450 226 | 2l23A00.320 227 | 2l23A00.348 228 | 2l23A00.379 229 | 2l23A00.413 230 | 2l23A00.450 231 | 2l3lA01.320 232 | 2l3lA01.348 233 | 2l3lA01.379 234 | 2l3lA01.413 235 | 2l3lA01.450 236 | 2l7kA00.320 237 | 2l7kA00.348 238 | 2l7kA00.379 239 | 2l7kA00.413 240 | 2l7kA00.450 241 | 2lcqA01.320 242 | 2lcqA01.348 243 | 2lcqA01.379 244 | 2lcqA01.413 245 | 2lcqA01.450 246 | 2okmA00.320 247 | 2okmA00.348 248 | 2okmA00.379 249 | 2okmA00.413 250 | 2okmA00.450 251 | 2vy2A00.320 252 | 2vy2A00.348 253 | 2vy2A00.379 254 | 2vy2A00.413 255 | 2vy2A00.450 256 | 2wdqC00.320 257 | 2wdqC00.348 258 | 2wdqC00.379 259 | 2wdqC00.413 260 | 2wdqC00.450 261 | 2xtqA01.320 262 | 2xtqA01.348 263 | 2xtqA01.379 264 | 2xtqA01.413 265 | 2xtqA01.450 266 | 2xwpA02.320 267 | 2xwpA02.348 268 | 2xwpA02.379 269 | 2xwpA02.413 270 | 2xwpA02.450 271 | 2z84A00.320 272 | 2z84A00.348 273 | 2z84A00.379 274 | 2z84A00.413 275 | 2z84A00.450 276 | 3af5A02.320 277 | 3af5A02.348 278 | 3af5A02.379 279 | 3af5A02.413 280 | 3af5A02.450 281 | 3ahcA01.320 282 | 3ahcA01.348 283 | 3ahcA01.379 284 | 3ahcA01.413 285 | 3ahcA01.450 286 | 3ajdA01.320 287 | 3ajdA01.348 288 | 3ajdA01.379 289 | 3ajdA01.413 290 | 3ajdA01.450 291 | 3an1B07.320 292 | 3an1B07.348 293 | 3an1B07.379 294 | 3an1B07.413 295 | 3an1B07.450 296 | 3e6jA00.320 297 | 3e6jA00.348 298 | 3e6jA00.379 299 | 3e6jA00.413 300 | 3e6jA00.450 301 | 3er0A02.320 302 | 3er0A02.348 303 | 3er0A02.379 304 | 3er0A02.413 305 | 3er0A02.450 306 | 3ik4A02.320 307 | 3ik4A02.348 308 | 3ik4A02.379 309 | 3ik4A02.413 310 | 3ik4A02.450 311 | 3mudB02.320 312 | 3mudB02.348 313 | 3mudB02.379 314 | 3mudB02.413 315 | 3mudB02.450 316 | 3n6rA03.320 317 | 3n6rA03.348 318 | 3n6rA03.379 319 | 3n6rA03.413 320 | 3n6rA03.450 321 | 3nb0A03.320 322 | 3nb0A03.348 323 | 3nb0A03.379 324 | 3nb0A03.413 325 | 3nb0A03.450 326 | 3nb2A04.320 327 | 3nb2A04.348 328 | 3nb2A04.379 329 | 3nb2A04.413 330 | 3nb2A04.450 331 | 3p7lA03.320 332 | 3p7lA03.348 333 | 3p7lA03.379 334 | 3p7lA03.413 335 | 3p7lA03.450 336 | 3r1kA02.320 337 | 3r1kA02.348 338 | 3r1kA02.379 339 | 3r1kA02.413 340 | 3r1kA02.450 341 | 3slpA00.320 342 | 3slpA00.348 343 | 3slpA00.379 344 | 3slpA00.413 345 | 3slpA00.450 346 | 3tk8C02.320 347 | 3tk8C02.348 348 | 3tk8C02.379 349 | 3tk8C02.413 350 | 3tk8C02.450 351 | 3u28C00.320 352 | 3u28C00.348 353 | 3u28C00.379 354 | 3u28C00.413 355 | 3u28C00.450 356 | 3v7oA00.320 357 | 3v7oA00.348 358 | 3v7oA00.379 359 | 3v7oA00.413 360 | 3v7oA00.450 361 | 3vk8A02.320 362 | 3vk8A02.348 363 | 3vk8A02.379 364 | 3vk8A02.413 365 | 3vk8A02.450 366 | 3zrhA01.320 367 | 3zrhA01.348 368 | 3zrhA01.379 369 | 3zrhA01.413 370 | 3zrhA01.450 371 | 4a57A02.320 372 | 4a57A02.348 373 | 4a57A02.379 374 | 4a57A02.413 375 | 4a57A02.450 376 | 4b96A00.320 377 | 4b96A00.348 378 | 4b96A00.379 379 | 4b96A00.413 380 | 4b96A00.450 381 | 4c23B01.320 382 | 4c23B01.348 383 | 4c23B01.379 384 | 4c23B01.413 385 | 4c23B01.450 386 | 4feiA00.320 387 | 4feiA00.348 388 | 4feiA00.379 389 | 4feiA00.413 390 | 
4feiA00.450 391 | 4gvbA00.320 392 | 4gvbA00.348 393 | 4gvbA00.379 394 | 4gvbA00.413 395 | 4gvbA00.450 396 | 4kbuK02.320 397 | 4kbuK02.348 398 | 4kbuK02.379 399 | 4kbuK02.413 400 | 4kbuK02.450 401 | 4kmcA00.320 402 | 4kmcA00.348 403 | 4kmcA00.379 404 | 4kmcA00.413 405 | 4kmcA00.450 406 | 4labA00.320 407 | 4labA00.348 408 | 4labA00.379 409 | 4labA00.413 410 | 4labA00.450 411 | 4qbuA03.320 412 | 4qbuA03.348 413 | 4qbuA03.379 414 | 4qbuA03.413 415 | 4qbuA03.450 416 | 4r9iA02.320 417 | 4r9iA02.348 418 | 4r9iA02.379 419 | 4r9iA02.413 420 | 4r9iA02.450 421 | 4u8uM02.320 422 | 4u8uM02.348 423 | 4u8uM02.379 424 | 4u8uM02.413 425 | 4u8uM02.450 426 | 4uuwB02.320 427 | 4uuwB02.348 428 | 4uuwB02.379 429 | 4uuwB02.413 430 | 4uuwB02.450 431 | 4wj4A03.320 432 | 4wj4A03.348 433 | 4wj4A03.379 434 | 4wj4A03.413 435 | 4wj4A03.450 436 | 4zn3A02.320 437 | 4zn3A02.348 438 | 4zn3A02.379 439 | 4zn3A02.413 440 | 4zn3A02.450 441 | 5a1mA00.320 442 | 5a1mA00.348 443 | 5a1mA00.379 444 | 5a1mA00.413 445 | 5a1mA00.450 446 | 5f33A03.320 447 | 5f33A03.348 448 | 5f33A03.379 449 | 5f33A03.413 450 | 5f33A03.450 451 | -------------------------------------------------------------------------------- /sam/openfold/utils/geometry/rotation_matrix.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Rot3Array Matrix Class.""" 15 | 16 | from __future__ import annotations 17 | import dataclasses 18 | from typing import List 19 | 20 | import torch 21 | 22 | from sam.openfold.utils.geometry import utils 23 | from sam.openfold.utils.geometry import vector 24 | from sam.openfold.utils.tensor_utils import tensor_tree_map 25 | 26 | 27 | COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz'] 28 | 29 | @dataclasses.dataclass(frozen=True) 30 | class Rot3Array: 31 | """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.""" 32 | xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32}) 33 | xy: torch.Tensor 34 | xz: torch.Tensor 35 | yx: torch.Tensor 36 | yy: torch.Tensor 37 | yz: torch.Tensor 38 | zx: torch.Tensor 39 | zy: torch.Tensor 40 | zz: torch.Tensor 41 | 42 | __array_ufunc__ = None 43 | 44 | def __getitem__(self, index): 45 | field_names = utils.get_field_names(Rot3Array) 46 | return Rot3Array( 47 | **{ 48 | name: getattr(self, name)[index] 49 | for name in field_names 50 | } 51 | ) 52 | 53 | def __mul__(self, other: torch.Tensor): 54 | field_names = utils.get_field_names(Rot3Array) 55 | return Rot3Array( 56 | **{ 57 | name: getattr(self, name) * other 58 | for name in field_names 59 | } 60 | ) 61 | 62 | def __matmul__(self, other: Rot3Array) -> Rot3Array: 63 | """Composes two Rot3Arrays.""" 64 | c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx)) 65 | c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy)) 66 | c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz)) 67 | return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) 68 | 69 | def map_tensor_fn(self, fn) -> Rot3Array: 70 | field_names = utils.get_field_names(Rot3Array) 71 | return Rot3Array( 72 | **{ 73 | name: fn(getattr(self, name)) 74 | for name in field_names 75 | } 76 | ) 77 | 78 | def inverse(self) -> Rot3Array: 79 | """Returns inverse of Rot3Array.""" 80 | return Rot3Array( 81 | self.xx, self.yx, self.zx, 82 | self.xy, self.yy, self.zy, 83 | self.xz, self.yz, self.zz 84 | ) 85 | 86 | def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: 87 | """Applies Rot3Array to point.""" 88 | return vector.Vec3Array( 89 | self.xx * point.x + self.xy * point.y + self.xz * point.z, 90 | self.yx * point.x + self.yy * point.y + self.yz * point.z, 91 | self.zx * point.x + self.zy * point.y + self.zz * point.z 92 | ) 93 | 94 | def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: 95 | """Applies inverse Rot3Array to point.""" 96 | return self.inverse().apply_to_point(point) 97 | 98 | 99 | def unsqueeze(self, dim: int): 100 | return Rot3Array( 101 | *tensor_tree_map( 102 | lambda t: t.unsqueeze(dim), 103 | [getattr(self, c) for c in COMPONENTS] 104 | ) 105 | ) 106 | 107 | def stop_gradient(self) -> Rot3Array: 108 | return Rot3Array( 109 | *[getattr(self, c).detach() for c in COMPONENTS] 110 | ) 111 | 112 | @classmethod 113 | def identity(cls, shape, device) -> Rot3Array: 114 | """Returns identity of given shape.""" 115 | ones = torch.ones(shape, dtype=torch.float32, device=device) 116 | zeros = torch.zeros(shape, dtype=torch.float32, device=device) 117 | return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones) 118 | 119 | @classmethod 120 | def from_two_vectors( 121 | cls, e0: vector.Vec3Array, 122 | e1: vector.Vec3Array 123 | ) -> Rot3Array: 124 | """Construct Rot3Array from two Vectors. 
125 | 126 | Rot3Array is constructed such that in the corresponding frame 'e0' lies on 127 | the positive x-Axis and 'e1' lies in the xy plane with positive sign of y. 128 | 129 | Args: 130 | e0: Vector 131 | e1: Vector 132 | Returns: 133 | Rot3Array 134 | """ 135 | # Normalize the unit vector for the x-axis, e0. 136 | e0 = e0.normalized() 137 | # make e1 perpendicular to e0. 138 | c = e1.dot(e0) 139 | e1 = (e1 - c * e0).normalized() 140 | # Compute e2 as cross product of e0 and e1. 141 | e2 = e0.cross(e1) 142 | return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z) 143 | 144 | @classmethod 145 | def from_array(cls, array: torch.Tensor) -> Rot3Array: 146 | """Construct Rot3Array Matrix from array of shape. [..., 3, 3].""" 147 | rows = torch.unbind(array, dim=-2) 148 | rc = [torch.unbind(e, dim=-1) for e in rows] 149 | return cls(*[e for row in rc for e in row]) 150 | 151 | def to_tensor(self) -> torch.Tensor: 152 | """Convert Rot3Array to array of shape [..., 3, 3].""" 153 | return torch.stack( 154 | [ 155 | torch.stack([self.xx, self.xy, self.xz], dim=-1), 156 | torch.stack([self.yx, self.yy, self.yz], dim=-1), 157 | torch.stack([self.zx, self.zy, self.zz], dim=-1) 158 | ], 159 | dim=-2 160 | ) 161 | 162 | @classmethod 163 | def from_quaternion(cls, 164 | w: torch.Tensor, 165 | x: torch.Tensor, 166 | y: torch.Tensor, 167 | z: torch.Tensor, 168 | normalize: bool = True, 169 | eps: float = 1e-6 170 | ) -> Rot3Array: 171 | """Construct Rot3Array from components of quaternion.""" 172 | if normalize: 173 | inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps)) 174 | w = w * inv_norm 175 | x = x * inv_norm 176 | y = y * inv_norm 177 | z = z * inv_norm 178 | xx = 1.0 - 2.0 * (y ** 2 + z ** 2) 179 | xy = 2.0 * (x * y - w * z) 180 | xz = 2.0 * (x * z + w * y) 181 | yx = 2.0 * (x * y + w * z) 182 | yy = 1.0 - 2.0 * (x ** 2 + z ** 2) 183 | yz = 2.0 * (y * z - w * x) 184 | zx = 2.0 * (x * z - w * y) 185 | zy = 2.0 * (y * z + w * x) 186 | zz = 1.0 - 2.0 * (x ** 2 + y ** 2) 187 | return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) 188 | 189 | def reshape(self, new_shape): 190 | field_names = utils.get_field_names(Rot3Array) 191 | reshape_fn = lambda t: t.reshape(new_shape) 192 | return Rot3Array( 193 | **{ 194 | name: reshape_fn(getattr(self, name)) 195 | for name in field_names 196 | } 197 | ) 198 | 199 | @classmethod 200 | def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array: 201 | field_names = utils.get_field_names(Rot3Array) 202 | cat_fn = lambda l: torch.cat(l, dim=dim) 203 | return cls( 204 | **{ 205 | name: cat_fn([getattr(r, name) for r in rots]) 206 | for name in field_names 207 | } 208 | ) 209 | -------------------------------------------------------------------------------- /sam/coords.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def calc_dmap(xyz, epsilon=1e-12, backend="torch"): 6 | if backend == "torch": 7 | B = torch 8 | elif backend == "numpy": 9 | B = np 10 | else: 11 | raise KeyError(backend) 12 | if len(xyz.shape) == 2: 13 | if xyz.shape[1] != 3: 14 | raise ValueError(xyz.shape) 15 | elif len(xyz.shape) == 3: 16 | if xyz.shape[2] != 3: 17 | raise ValueError(xyz.shape) 18 | else: 19 | raise ValueError(xyz.shape) 20 | if len(xyz.shape) == 3: 21 | dmap = B.sqrt( 22 | B.sum( 23 | B.square(xyz[:,None,:,:] - xyz[:,:,None,:]), 24 | axis=3) + epsilon) 25 | exp_dim = 1 26 | else: 27 | dmap = B.sqrt( 28 | B.sum( 29 | B.square(xyz[None,:,:] - xyz[:,None,:]), 30 | axis=2) + 
epsilon) 31 | exp_dim = 0 32 | if backend == "torch": 33 | return dmap.unsqueeze(exp_dim) 34 | elif backend == "numpy": 35 | return np.expand_dims(dmap, exp_dim) 36 | else: 37 | raise KeyError(backend) 38 | 39 | 40 | def calc_dmap_triu(input_data, offset=1, epsilon=1e-12, backend="torch"): 41 | # Check the shape. 42 | if len(input_data.shape) == 2: 43 | if input_data.shape[1] != 3: 44 | raise ValueError(input_data.shape) 45 | dmap = calc_dmap(input_data, epsilon, backend) 46 | elif len(input_data.shape) == 3: 47 | if input_data.shape[2] != 3: 48 | raise ValueError(input_data.shape) 49 | dmap = calc_dmap(input_data, epsilon, backend) 50 | elif len(input_data.shape) == 4: 51 | if input_data.shape[1] != 1: 52 | raise ValueError(input_data.shape) 53 | if input_data.shape[2] != input_data.shape[3]: 54 | raise ValueError(input_data.shape) 55 | dmap = input_data 56 | else: 57 | raise ValueError(input_data.shape) 58 | # Get the triu ids. 59 | l = dmap.shape[2] 60 | if backend == "torch": 61 | triu_ids = torch.triu_indices(l, l, offset=offset) 62 | elif backend == "numpy": 63 | triu_ids = np.triu_indices(l, k=offset) 64 | else: 65 | raise KeyError(backend) 66 | # Return the values. 67 | if len(input_data.shape) != 2: 68 | return dmap[:,0,triu_ids[0],triu_ids[1]] 69 | else: 70 | return dmap[0,triu_ids[0],triu_ids[1]] 71 | 72 | 73 | def torch_chain_dihedrals(xyz, backend="torch"): 74 | if backend == "torch": 75 | r_sel = xyz 76 | elif backend == "numpy": 77 | r_sel = torch.tensor(xyz) 78 | else: 79 | raise KeyError(backend) 80 | b0 = -(r_sel[:,1:-2,:] - r_sel[:,0:-3,:]) 81 | b1 = r_sel[:,2:-1,:] - r_sel[:,1:-2,:] 82 | b2 = r_sel[:,3:,:] - r_sel[:,2:-1,:] 83 | b0xb1 = torch.cross(b0, b1, dim=2) # Normal of the (b0, b1) plane. 84 | b1xb2 = torch.cross(b2, b1, dim=2) # Normal of the (b1, b2) plane. 85 | b0xb1_x_b1xb2 = torch.cross(b0xb1, b1xb2, dim=2) 86 | y = torch.sum(b0xb1_x_b1xb2*b1, axis=2)*(1.0/torch.linalg.norm(b1, dim=2)) 87 | x = torch.sum(b0xb1*b1xb2, axis=2) 88 | dh_vals = torch.atan2(y, x) 89 | return dh_vals 90 | 91 | 92 | def calc_chain_bond_angles(xyz, backend="numpy"): 93 | ids = np.array([[i, i+1, i+2] for i in range(xyz.shape[1]-2)]) 94 | return calc_angles(xyz, ids, backend=backend) 95 | 96 | 97 | def calc_angles(xyz, angle_indices, backend="numpy"): 98 | if backend == "numpy": 99 | B = np 100 | elif backend == "torch": 101 | B = torch 102 | else: 103 | raise KeyError(backend) 104 | 105 | ix01 = angle_indices[:, [1, 0]] 106 | ix21 = angle_indices[:, [1, 2]] 107 | 108 | u_prime = xyz[:,ix01[:,1]]-xyz[:,ix01[:,0]] 109 | v_prime = xyz[:,ix21[:,1]]-xyz[:,ix21[:,0]] 110 | u_norm = B.sqrt((u_prime**2).sum(-1)) 111 | v_norm = B.sqrt((v_prime**2).sum(-1)) 112 | 113 | # Adding a new axis makes sure that broadcasting rules kick in on the 114 | # third dimension. 115 | u = u_prime / (u_norm[..., None]) 116 | v = v_prime / (v_norm[..., None]) 117 | 118 | return B.arccos((u * v).sum(-1)) 119 | 120 | 121 | def calc_torsion(A, B, C, D, dim=2): 122 | b0 = -(B - A) 123 | b1 = C - B 124 | b2 = D - C 125 | b0xb1 = torch.cross(b0, b1, dim=dim) 126 | b1xb2 = torch.cross(b2, b1, dim=dim) 127 | b0xb1_x_b1xb2 = torch.cross(b0xb1, b1xb2, dim=dim) 128 | y = torch.sum(b0xb1_x_b1xb2*b1, axis=dim)*(1.0/torch.linalg.norm(b1, dim=dim)) 129 | x = torch.sum(b0xb1*b1xb2, axis=dim) 130 | angle = torch.atan2(y, x) 131 | return angle 132 | 133 | 134 | def compute_rg(xyz): 135 | """ 136 | Adapted from the mdtraj library: https://github.com/mdtraj/mdtraj.
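    Computes a per-frame radius of gyration with uniform atom masses for
    `xyz` of shape (n_frames, n_atoms, 3):
    Rg = sqrt(sum_i w_i * ||r_i - mu||^2), with w_i = 1 / n_atoms and mu the
    centroid of each frame.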
137 | """ 138 | num_atoms = xyz.shape[1] 139 | masses = np.ones(num_atoms) 140 | weights = masses / masses.sum() 141 | mu = xyz.mean(1) 142 | centered = (xyz.transpose((1, 0, 2)) - mu).transpose((1, 0, 2)) 143 | squared_dists = (centered ** 2).sum(2) 144 | Rg = (squared_dists * weights).sum(1) ** 0.5 145 | return Rg 146 | 147 | 148 | def sample_data(data, n_samples, backend="numpy"): 149 | if backend in ("numpy", "torch"): 150 | if n_samples is not None: 151 | ids = np.random.choice(data.shape[0], 152 | n_samples, 153 | replace=data.shape[0] < n_samples) 154 | return data[ids] 155 | else: 156 | return data 157 | else: 158 | raise KeyError(backend) 159 | 160 | 161 | def get_edge_dmap(xyz, batch, epsilon=1e-12): 162 | row, col = batch.nr_edge_index 163 | dmap = torch.sqrt( 164 | torch.sum( 165 | torch.square(xyz[row] - xyz[col]), 166 | axis=1) + epsilon) 167 | return dmap 168 | 169 | 170 | def calc_bond_len(x, eps=1e-6): 171 | return torch.sqrt(torch.sum(torch.square(x[:,:-1,:] - x[:,1:,:]), dim=2) + eps) 172 | 173 | 174 | def calc_com_traj(positions, atom14_gt_exists, mult=0.1): 175 | """ 176 | Legacy alias for `calc_scen_pos`; the historical name no longer matches what it computes. 177 | """ 178 | return calc_scen_pos(positions, atom14_gt_exists, mult) 179 | 180 | def calc_scen_pos(positions, atom14_gt_exists, mult=0.1): 181 | """ 182 | Calculate side chain centroid positions from OpenFold data. 183 | """ 184 | # Get positions. 185 | pos = positions[:,:,4:,:]*mult # Side-chain atom positions. 186 | ca_pos = positions[:,:,1,:]*mult # CA positions. 187 | mask = atom14_gt_exists[:,:,4:].unsqueeze(-1) # Check if side-chain atoms exist. 188 | gly_mask = 1-atom14_gt_exists[:,:,4].unsqueeze(-1) # Check if CB atoms exist. 189 | mask = mask.to(dtype=pos.dtype) 190 | gly_mask = gly_mask.to(dtype=pos.dtype) 191 | # Compute centroid. 192 | com_pos = (pos.sum(dim=2) + ca_pos*gly_mask)/(mask.sum(dim=2) + gly_mask) 193 | return com_pos 194 | 195 | 196 | def calc_aa_cen_traj(positions, atom14_gt_exists, mult=0.1): 197 | # Get all atom positions. 198 | pos = positions*mult 199 | mask = atom14_gt_exists.unsqueeze(-1) 200 | # Calculate centroid. 201 | cen_pos = pos.sum(dim=2) / mask.sum(dim=2) 202 | return cen_pos 203 | 204 | 205 | def calc_bb_cen_traj(positions, atom14_gt_exists, mult=0.1): 206 | # Get backbone positions. 207 | pos = positions[:,:,0:4,:]*mult 208 | mask = atom14_gt_exists[:,:,0:4].unsqueeze(-1) 209 | # Calculate centroid. 210 | cen_pos = pos.sum(dim=2) / mask.sum(dim=2) 211 | return cen_pos 212 | 213 | def calc_nb_dist(xyz, get_triu_ids=False, offset=3, eps=1e-9): 214 | # `eps` is unused since the switch to torch.cdist; kept for API compatibility. 215 | triu_ids = torch.triu_indices(xyz.shape[1], xyz.shape[1], offset=offset) 216 | dmap_triu = torch.cdist(xyz, xyz, p=2.0) 217 | dmap_triu = dmap_triu[:,triu_ids[0],triu_ids[1]] 218 | if get_triu_ids: 219 | return dmap_triu, triu_ids 220 | else: 221 | return dmap_triu --------------------------------------------------------------------------------
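# Usage sketch for the distance helpers in sam/coords.py (illustrative only;
# it assumes the `sam` package is importable and follows the shape checks in
# `calc_dmap`/`calc_dmap_triu` above).
import torch

from sam.coords import calc_dmap, calc_dmap_triu, compute_rg

xyz = torch.randn(8, 50, 3)           # 8 conformations, 50 atoms
dmap = calc_dmap(xyz)                 # (8, 1, 50, 50) pairwise distance maps
triu = calc_dmap_triu(xyz, offset=1)  # (8, 1225) upper-triangle distances
rg = compute_rg(xyz.numpy())          # (8,) per-frame radius of gyration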