├── utils
│   ├── __init__.py
│   ├── common.py
│   └── metrics.py
├── assets
│   ├── model_overview.jpg
│   └── results_on_PRA.png
├── data
│   ├── protein
│   │   └── __init__.py
│   ├── rna
│   │   ├── __init__.py
│   │   ├── base_constants.py
│   │   ├── rnas.py
│   │   └── sec_struct_utils.py
│   ├── __init__.py
│   ├── transforms
│   │   ├── __init__.py
│   │   ├── geometric.py
│   │   ├── select_chain.py
│   │   ├── noise.py
│   │   ├── _base.py
│   │   ├── corrupt_chi.py
│   │   ├── select_atom.py
│   │   ├── mask.py
│   │   └── patch.py
│   ├── register.py
│   ├── complex.py
│   ├── sequence_dataset.py
│   └── pri30k_dataset.py
├── pl_modules
│   ├── __init__.py
│   ├── data_module.py
│   ├── model_module.py
│   └── pretune_module.py
├── models
│   ├── __init__.py
│   ├── components
│   │   ├── loss.py
│   │   ├── rope.py
│   │   ├── coformer.py
│   │   ├── valina_transformer.py
│   │   ├── valina_attn.py
│   │   └── attention.py
│   ├── register.py
│   ├── lora_tune.py
│   ├── encoders
│   │   ├── single.py
│   │   ├── pair.py
│   │   ├── layers.py
│   │   └── attn.py
│   ├── ipa.py
│   └── esm_rinalmo_seq.py
├── config
│   ├── runs
│   │   ├── finetune_struct.yml
│   │   ├── finetune_sequence.yml
│   │   ├── zero_shot_blindtest.yml
│   │   ├── pretune_struct.yml
│   │   └── test_basic.yml
│   ├── datasets
│   │   ├── PRA201.yml
│   │   ├── PRA310.yml
│   │   ├── PRI30k.yml
│   │   ├── mCSM.yml
│   │   └── blindtest.yml
│   └── models
│       ├── esm2_rinalmo_seq.yml
│       └── copra.yml
├── environment.yml
├── LICENSE
├── .gitignore
├── README.md
└── run.py

/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import *
2 | from .metrics import *
3 | from .geometry import *
--------------------------------------------------------------------------------
/assets/model_overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hanrthu/CoPRA/HEAD/assets/model_overview.jpg
--------------------------------------------------------------------------------
/assets/results_on_PRA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hanrthu/CoPRA/HEAD/assets/results_on_PRA.png
--------------------------------------------------------------------------------
/data/protein/__init__.py:
--------------------------------------------------------------------------------
1 | from .atom_convert import *
2 | from .proteins import *
3 | from .residue_constants import *
--------------------------------------------------------------------------------
/pl_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_module import *
2 | from .model_module import *
3 | from .ddg_module import *
4 | from .pretune_module import *
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .esm_rinalmo_seq import *
2 | from .model import *
3 | from .components.coformer import *
4 | from .register import *
5 | from .ipa import *
--------------------------------------------------------------------------------
/data/rna/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_constants import *
2 | from .data_utils import *
3 | from .featurizer import *
4 | from .rnas import *
5 | # from .sec_struct_utils import *
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .sequence_dataset import *
2 | from .structure_dataset import *
3 | from .register import *
4 |
from data.protein.proteins import * 5 | from data.protein.residue_constants import * 6 | from .pri30k_dataset import * -------------------------------------------------------------------------------- /config/runs/finetune_struct.yml: -------------------------------------------------------------------------------- 1 | epochs: 140 2 | patience: 150 3 | output_dir: './outputs/esm2rinalmo_650M/scratch/PRA310' 4 | gpus: 5 | - 7 6 | 7 | ckpt: null 8 | 9 | run_name: 'esm2rinalmo_scratch_patch256' 10 | wandb: False 11 | num_folds: 5 -------------------------------------------------------------------------------- /config/runs/finetune_sequence.yml: -------------------------------------------------------------------------------- 1 | epochs: 30 2 | patience: 50 3 | output_dir: './outputs/esm2rinalmo_650M/finetune_sequence/with_transformer_pdbbind_dG_restrict_remove_cls_token' 4 | gpus: 5 | - 0 6 | ckpt: null 7 | run_name: 'esm2_rinalmo_650M_clstoken' 8 | wandb: true 9 | num_folds: 5 -------------------------------------------------------------------------------- /config/runs/zero_shot_blindtest.yml: -------------------------------------------------------------------------------- 1 | iters: 10500 2 | epochs: 110 3 | patience: 50 4 | output_dir: './outputs/esm2rinalmo_650M/mCSM/blind_test' 5 | gpus: 6 | - 0 7 | ckpts: 8 | - './weights/CoPRA_fold3.ckpt' 9 | 10 | run_name: 'esm2rinalmo_mmCSM_patch256' 11 | wandb: False 12 | num_folds: 1 -------------------------------------------------------------------------------- /config/runs/pretune_struct.yml: -------------------------------------------------------------------------------- 1 | # iters: 30000 2 | epochs: 100 3 | patience: 50 4 | output_dir: './outputs/esm2rinalmo_650M/pretune_struct/esm2rinalmo_650M_cformer_PRI30k_fixed_interface_dist_random_mask0.15_tokenpooling' 5 | gpus: 6 | - 0 7 | - 1 8 | - 2 9 | - 3 10 | ckpt: null 11 | 12 | run_name: 'esm2rinalmo_650M_cformer_PRI30k_fixed_interface_dist_random_mask0.15_tokenpooling' 13 | wandb: false 14 | num_folds: 1 -------------------------------------------------------------------------------- /config/runs/test_basic.yml: -------------------------------------------------------------------------------- 1 | output_dir: './outputs/esm2rinalmo_650M/test_struct/best_model' 2 | gpus: 3 | - 0 4 | 5 | ckpts: 6 | - ./weights/CoPRA_5fold/CoPRA_fold0.ckpt 7 | - ./weights/CoPRA_5fold/CoPRA_fold1.ckpt 8 | - ./weights/CoPRA_5fold/CoPRA_fold2.ckpt 9 | - ./weights/CoPRA_5fold/CoPRA_fold3.ckpt 10 | - ./weights/CoPRA_5fold/CoPRA_fold4.ckpt 11 | 12 | run_name: 'esm2rinalmo_test_best' 13 | wandb: False 14 | num_folds: 5 -------------------------------------------------------------------------------- /data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Transforms 2 | from .patch import FocusedRandomPatch, RandomPatch, SelectedRegionWithPaddingPatch, SelectedRegionFixedSizePatch 3 | from .select_chain import SelectFocused 4 | from .select_atom import SelectAtom 5 | from .mask import RandomMaskAminoAcids, MaskSelectedAminoAcids 6 | from .noise import AddAtomNoise, AddChiAngleNoise 7 | from .corrupt_chi import CorruptChiAngle 8 | from .geometric import SubtractCOM 9 | # Factory 10 | from ._base import get_transform, Compose, _get_CB_positions 11 | -------------------------------------------------------------------------------- /config/datasets/PRA201.yml: -------------------------------------------------------------------------------- 1 | dataset_type: 
'structure_dataset' 2 | df_path: './datasets/PRA310/splits/PRA201.csv' 3 | batch_size: 16 4 | data_root: './datasets/PRA310/PDBs' 5 | num_workers: 2 6 | col_prot_chain: 'Protein chains' 7 | col_na_chain: 'RNA chains' 8 | col_prot: 'Protein sequences' 9 | col_na: 'RNA sequences' 10 | col_label: '△G(kcal/mol)' 11 | pin_memory: True 12 | cache_dir: './cache/pra201' 13 | loss_type: regression 14 | strategy: separate 15 | 16 | transform: 17 | - type: select_atom 18 | resolution: backbone 19 | - type: selected_region_with_distmap 20 | patch_size: 256 21 | - type: subtract_center_of_mass -------------------------------------------------------------------------------- /config/datasets/PRA310.yml: -------------------------------------------------------------------------------- 1 | dataset_type: 'structure_dataset' 2 | df_path: './datasets/PRA310/splits/PRA310.csv' 3 | batch_size: 16 4 | data_root: './datasets/PRA310/PDBs' 5 | num_workers: 2 6 | col_prot_chain: 'Protein chains' 7 | col_na_chain: 'RNA chains' 8 | col_prot: 'Protein sequences' 9 | col_na: 'RNA sequences' 10 | col_label: '△G(kcal/mol)' 11 | pin_memory: True 12 | cache_dir: './cache/pra310' 13 | loss_type: regression 14 | strategy: separate 15 | 16 | transform: 17 | - type: select_atom 18 | resolution: backbone 19 | - type: selected_region_with_distmap 20 | patch_size: 256 21 | - type: subtract_center_of_mass -------------------------------------------------------------------------------- /config/datasets/PRI30k.yml: -------------------------------------------------------------------------------- 1 | dataset_type: 'pri30k_dataset' 2 | df_path: './datasets/PRI30k/splits/pretrain_length_750_clean.csv' 3 | batch_size: 20 4 | data_root: './datasets/PRI30k/PDBs' 5 | num_workers: 0 6 | col_prot_name: PDB 7 | col_prot_chain: Protein chains 8 | col_na_chain: RNA chains 9 | col_binding_site: Binding site renumbered merged 10 | col_ligand: Binding ligands 11 | pin_memory: True 12 | cache_dir: './cache/pri30k' 13 | loss_type: regression 14 | strategy: separate 15 | 16 | transform: 17 | - type: select_atom 18 | resolution: backbone 19 | - type: selected_region_with_distmap 20 | patch_size: 256 21 | - type: subtract_center_of_mass -------------------------------------------------------------------------------- /data/transforms/geometric.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | from ._base import register_transform 5 | 6 | @register_transform('subtract_center_of_mass') 7 | class SubtractCOM(object): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def __call__(self, data): 12 | pos = data['pos_atoms'] 13 | mask = data['mask_atoms'] 14 | if mask is None: 15 | center = np.zeros(3) 16 | elif mask.sum() == 0: 17 | center = np.zeros(3) 18 | else: 19 | center = pos[mask].mean(axis=0) 20 | data['pos_atoms'] = pos - center[None, None, :] 21 | return data 22 | 23 | -------------------------------------------------------------------------------- /config/datasets/mCSM.yml: -------------------------------------------------------------------------------- 1 | dataset_type: 'structure_dataset' 2 | df_path: './datasets/mCSM_RNA/splits/crossvalidation.csv' 3 | batch_size: 8 4 | data_root: './datasets/mCSM_RNA/PDBs' 5 | num_workers: 2 6 | col_prot_chain: 'Protein chains' 7 | col_na_chain: 'RNA chains' 8 | col_prot: 'Protein sequences' 9 | col_mut: 'Mutation sequences' 10 | mut: True 11 | col_na: 'RNA sequences' 12 | col_label: 'DDG' 13 | pin_memory: True 14 | cache_dir: 
'./cache/mmCSM' 15 | loss_type: regression 16 | strategy: separate 17 | 18 | transform: 19 | - type: select_atom 20 | resolution: backbone 21 | - type: selected_region_with_distmap 22 | patch_size: 256 23 | - type: subtract_center_of_mass -------------------------------------------------------------------------------- /config/datasets/blindtest.yml: -------------------------------------------------------------------------------- 1 | dataset_type: 'structure_dataset' 2 | df_path: './datasets/mCSM_RNA/splits/blindtest.csv' 3 | batch_size: 8 4 | data_root: './datasets/mCSM_RNA/PDBs' 5 | num_workers: 2 6 | col_prot_chain: 'Protein chains' 7 | col_na_chain: 'RNA chains' 8 | col_prot: 'Protein sequences' 9 | col_mut: 'Mutation sequences' 10 | mut: True 11 | col_na: 'RNA sequences' 12 | col_label: 'DDG' 13 | pin_memory: True 14 | cache_dir: './cache/blindtest' 15 | loss_type: regression 16 | strategy: separate 17 | 18 | transform: 19 | - type: select_atom 20 | resolution: backbone 21 | - type: selected_region_with_distmap 22 | patch_size: 256 23 | - type: subtract_center_of_mass -------------------------------------------------------------------------------- /models/components/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class PearsonCorrLoss(nn.Module): 6 | def __init__(self): 7 | super(PearsonCorrLoss, self).__init__() 8 | 9 | def forward(self, pred, target): 10 | pred_mean = torch.mean(pred) 11 | target_mean = torch.mean(target) 12 | 13 | pred_centered = pred - pred_mean 14 | target_centered = target - target_mean 15 | 16 | numerator = torch.sum(pred_centered * target_centered) 17 | denominator = torch.sqrt(torch.sum(pred_centered ** 2) * torch.sum(target_centered ** 2)) 18 | 19 | pearson_corr = numerator / (denominator + 1e-5) 20 | loss = 1 - pearson_corr 21 | 22 | return loss -------------------------------------------------------------------------------- /config/models/esm2_rinalmo_seq.yml: -------------------------------------------------------------------------------- 1 | resume: null 2 | model: 3 | model_type: esm2_rinalmo_seq 4 | esm_type: 650M 5 | pooling: mean 6 | output_dim: 1 7 | fix_lms: True 8 | lora_tune: false 9 | lora_rank: 16 10 | lora_alpha: 32 11 | vallina: True 12 | representation_layer: 33 13 | transformer: 14 | embed_dim: 320 15 | num_blocks: 6 16 | num_heads: 20 17 | use_rot_emb: true 18 | attn_qkv_bias: false 19 | attention_dropout: 0.1 20 | transition_dropout: 0.0 21 | residual_dropout: 0.1 22 | transition_factor: 4 23 | use_flash_attn: false 24 | train: 25 | max_iters: 30_000 26 | val_freq: 1000 27 | seed: 2024 28 | # max_grad_norm: 100.0 29 | optimizer: 30 | type: adam 31 | lr: 3.e-5 32 | weight_decay: 0.0 33 | beta1: 0.9 34 | beta2: 0.999 35 | scheduler: 36 | type: plateau 37 | factor: 0.8 38 | patience: 5 39 | min_lr: 1.e-6 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: copra_h 2 | channels: 3 | - pytorch3d 4 | - pytorch 5 | - conda-forge 6 | - nvidia 7 | - pyg 8 | dependencies: 9 | - python=3.10.14 10 | - pytorch:pytorch=2.1.2 11 | - pytorch:pytorch-cuda=11.8 12 | # - llvm-openmp<16 13 | - einops=0.6.1 14 | - pytorch-lightning 15 | - pandas 16 | - dm-tree=0.1.7 17 | - diskcache=5.4.0 18 | - fire=0.5.0 19 | - numpy=1.26.4 20 | - scikit-learn=1.3.0 21 | - pip 22 | - gpustat 23 | - wandb 24 | - easydict=1.9 25 | - 
matplotlib=3.8.0 26 | - pytorch_geometric 27 | - pytorch-scatter 28 | - pytorch-sparse 29 | - pytorch-cluster 30 | - pip: 31 | - biopython==1.84 32 | - fair-esm==2.0.0 33 | - tree==0.2.4 34 | - seaborn==0.13.2 35 | - triton==2.1.0 36 | - torchsummary==1.5.1 37 | - peft==0.14.0 38 | - biotite==0.41.2 39 | - cpdb-protein==0.2.0 40 | - -i https://pypi.tuna.tsinghua.edu.cn/simple -------------------------------------------------------------------------------- /config/models/copra.yml: -------------------------------------------------------------------------------- 1 | resume: null 2 | model: 3 | model_type: copra 4 | esm_type: 650M 5 | pooling: token 6 | output_dim: 1 7 | fix_lms: true 8 | lora_tune: false 9 | lora_rank: 16 10 | lora_alpha: 32 11 | pair_dim: 40 12 | dist_dim: 40 13 | representation_layer: 33 14 | coformer: 15 | embed_dim: 320 16 | pair_dim: 40 17 | num_blocks: 6 18 | num_heads: 20 19 | use_rot_emb: true 20 | attn_qkv_bias: false 21 | attention_dropout: 0.1 22 | transition_dropout: 0.0 23 | residual_dropout: 0.1 24 | transition_factor: 4 25 | use_flash_attn: false 26 | 27 | train: 28 | temperature: 0.2 29 | max_iters: 30_000 30 | val_freq: 1000 31 | seed: 2024 32 | max_grad_norm: 100.0 33 | optimizer: 34 | type: adam 35 | lr: 3.e-5 36 | weight_decay: 0.0 37 | beta1: 0.9 38 | beta2: 0.999 39 | scheduler: 40 | type: plateau 41 | factor: 0.8 42 | patience: 5 43 | min_lr: 1.e-6 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Rong Han 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/transforms/select_chain.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | from ._base import _mask_select_data, register_transform 5 | 6 | 7 | @register_transform('random_interacting_chain') 8 | class RandomInteractingChain(object): 9 | 10 | def __init__(self, interaction_attr): 11 | super().__init__() 12 | self.interaction_attr = interaction_attr 13 | 14 | def __call__(self, data): 15 | # Randomly choose a chain to mask 16 | interact_flag = (data[self.interaction_attr] > 0) # (L, ) 17 | if interact_flag.sum() == 0: 18 | # If there is no active residues, randomly pick one. 
19 | interact_flag[random.randint(0, interact_flag.size(0)-1)] = True 20 | seed_idx = torch.multinomial(interact_flag.float(), num_samples=1).item() 21 | 22 | chain_nb_selected = data['chain_nb'][seed_idx].item() 23 | mask_chain = (data['chain_nb'] == chain_nb_selected) 24 | return _mask_select_data(data, mask_chain) 25 | 26 | 27 | @register_transform('select_focused') 28 | class SelectFocused(object): 29 | 30 | def __init__(self, focus_attr): 31 | super().__init__() 32 | self.focus_attr = focus_attr 33 | 34 | def __call__(self, data): 35 | mask_focus = (data[self.focus_attr] > 0) 36 | return _mask_select_data(data, mask_focus) 37 | 38 | -------------------------------------------------------------------------------- /data/transforms/noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from ._base import register_transform 5 | 6 | 7 | @register_transform('add_atom_noise') 8 | class AddAtomNoise(object): 9 | 10 | def __init__(self, noise_std=0.02): 11 | super().__init__() 12 | self.noise_std = noise_std 13 | 14 | def __call__(self, data): 15 | pos_atoms = data['pos_atoms'] # (L, A, 3) 16 | mask_atoms = data['mask_atoms'] # (L, A) 17 | noise = (torch.randn_like(pos_atoms) * self.noise_std) * mask_atoms[:, :, None] 18 | pos_noisy = pos_atoms + noise 19 | data['pos_atoms'] = pos_noisy 20 | return data 21 | 22 | 23 | @register_transform('add_chi_angle_noise') 24 | class AddChiAngleNoise(object): 25 | 26 | def __init__(self, noise_std=0.02): 27 | super().__init__() 28 | self.noise_std = noise_std 29 | 30 | def _normalize_angles(self, angles): 31 | angles = angles % (2*np.pi) 32 | return torch.where(angles > np.pi, angles - 2*np.pi, angles) 33 | 34 | def __call__(self, data): 35 | chi, chi_alt = data['chi'], data['chi_alt'] # (L, 4) 36 | chi_mask = data['chi_mask'] # (L, 4) 37 | 38 | _get_noise = lambda: ((torch.randn_like(chi) * self.noise_std) * chi_mask) 39 | data['chi'] = self._normalize_angles( chi + _get_noise() ) 40 | data['chi_alt'] = self._normalize_angles( chi_alt + _get_noise() ) 41 | return data 42 | -------------------------------------------------------------------------------- /data/rna/base_constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | 4 | 5 | PROJECT_PATH = os.environ.get("PROJECT_PATH") 6 | 7 | DATA_PATH = os.environ.get("DATA_PATH") 8 | 9 | X3DNA_PATH = os.environ.get("X3DNA") 10 | 11 | ETERNAFOLD_PATH = os.environ.get("ETERNAFOLD") 12 | 13 | 14 | # Value to fill missing coordinate entries when reading PDB files 15 | FILL_VALUE = 1e-5 16 | 17 | 18 | # Small epsilon value added to distances to avoid division by zero 19 | DISTANCE_EPS = 0.001 20 | 21 | 22 | # List of possible atoms in RNA nucleotides 23 | RNA_ATOMS = [ 24 | 'P', "C5'", "O5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", 25 | 'N1', 26 | 'C2', 27 | 'O2', 'N2', 28 | 'N3', 29 | 'C4', 'O4', 'N4', 30 | 'C5', 31 | 'C6', 32 | 'O6', 'N6', 33 | 'N7', 34 | 'C8', 35 | 'N9', 36 | 'OP1', 'OP2', 37 | ] 38 | 39 | 40 | # List of possible RNA nucleotides 41 | RNA_NUCLEOTIDES = [ 42 | 'A', 43 | 'G', 44 | 'C', 45 | 'U', 46 | # '_' # placeholder for missing/unknown nucleotides 47 | ] 48 | 49 | 50 | # List of purine nucleotides 51 | PURINES = ["A", "G"] 52 | 53 | 54 | # List of pyrimidine nucleotides 55 | PYRIMIDINES = ["C", "U"] 56 | 57 | 58 | # 59 | LETTER_TO_NUM = dict(zip( 60 | RNA_NUCLEOTIDES, 61 | list(range(len(RNA_NUCLEOTIDES))) 62 | )) 63 | 64 | 65 | # 66 
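# Inverse mapping: integer index back to nucleotide letter ('A', 'G', 'C', 'U').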
| NUM_TO_LETTER = {v:k for k, v in LETTER_TO_NUM.items()} 67 | 68 | 69 | # 70 | DOTBRACKET_TO_NUM = { 71 | '.': 0, 72 | '(': 1, 73 | ')': 2 74 | } 75 | -------------------------------------------------------------------------------- /models/register.py: -------------------------------------------------------------------------------- 1 | from utils import singleton 2 | 3 | REGISTERED_MODELS = [] 4 | 5 | @singleton 6 | class ModelRegister(dict): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self._dict = {} 10 | def register(self, target): 11 | def add_register_item(key, value): 12 | if not callable(value): 13 | raise Exception(f"register object must be callable! But receice:{value} is not callable!") 14 | if key in self._dict: 15 | print(f"warning: \033[33m{value.__name__} has been registered before, so we will overriden it\033[0m") 16 | self[key] = value 17 | REGISTERED_MODELS.append(key) 18 | return value 19 | 20 | if callable(target): 21 | return add_register_item(target.__name__, target) 22 | else: 23 | return lambda x: add_register_item(target, x) 24 | 25 | def __call__(self, target): 26 | return self.register(target) 27 | 28 | def __setitem__(self, key, value): 29 | self._dict[key] = value 30 | 31 | def __getitem__(self, key): 32 | return self._dict[key] 33 | 34 | def __contains__(self, key): 35 | return key in self._dict 36 | 37 | def __str__(self): 38 | return str(self._dict) 39 | 40 | def keys(self): 41 | return self._dict.keys() 42 | 43 | def values(self): 44 | return self._dict.values() 45 | 46 | def items(self): 47 | return self._dict.items() 48 | 49 | 50 | -------------------------------------------------------------------------------- /data/register.py: -------------------------------------------------------------------------------- 1 | from utils import singleton 2 | 3 | REGISTERED_DATASETS = [] 4 | 5 | @singleton 6 | class DataRegister(dict): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self._dict = {} 10 | def register(self, target): 11 | def add_register_item(key, value): 12 | if not callable(value): 13 | raise Exception(f"register object must be callable! 
But receice:{value} is not callable!") 14 | if key in self._dict: 15 | print(f"warning: \033[33m{value.__name__} has been registered before, so we will overriden it\033[0m") 16 | self[key] = value 17 | REGISTERED_DATASETS.append(key) 18 | return value 19 | 20 | if callable(target): 21 | return add_register_item(target.__name__, target) 22 | else: 23 | return lambda x: add_register_item(target, x) 24 | 25 | def __call__(self, target): 26 | return self.register(target) 27 | 28 | def __setitem__(self, key, value): 29 | self._dict[key] = value 30 | 31 | def __getitem__(self, key): 32 | # print(self._dict) 33 | return self._dict[key] 34 | 35 | def __contains__(self, key): 36 | return key in self._dict 37 | 38 | def __str__(self): 39 | return str(self._dict) 40 | 41 | def keys(self): 42 | return self._dict.keys() 43 | 44 | def values(self): 45 | return self._dict.values() 46 | 47 | def items(self): 48 | return self._dict.items() 49 | 50 | 51 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pandas as pd 4 | import torch 5 | from torch_geometric.data import Data 6 | from torch_geometric.typing import OptTensor, SparseTensor 7 | 8 | def singleton(cls): 9 | _instance = {} 10 | 11 | def inner(*args, **kwargs): 12 | if cls not in _instance: 13 | _instance[cls] = cls(*args, **kwargs) 14 | return _instance[cls] 15 | 16 | return inner 17 | 18 | 19 | class MyData(Data): 20 | chain: pd.Series 21 | channel_weights: torch.Tensor 22 | 23 | def __init__( 24 | self, x: OptTensor = None, edge_index: OptTensor = None, 25 | edge_attr: OptTensor = None, y: OptTensor = None, 26 | pos: OptTensor = None, **kwargs 27 | ): 28 | super(MyData, self).__init__(x, edge_index, edge_attr, y, pos, **kwargs) 29 | 30 | def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any: 31 | if 'batch' in key: 32 | return int(value.max()) + 1 33 | elif 'index' in key or 'face' in key: 34 | return self.num_nodes 35 | elif 'interaction' in key: 36 | return self.num_edges 37 | elif 'chains' in key: 38 | return self.num_chains 39 | else: 40 | return 0 41 | def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any: 42 | if isinstance(value, SparseTensor) and 'adj' in key: 43 | return (0, 1) 44 | elif 'index' in key or 'face' in key or 'interaction' in key: 45 | return -1 46 | else: 47 | return 0 48 | @property 49 | def num_chains(self): 50 | return len(torch.unique(self.chains)) -------------------------------------------------------------------------------- /models/components/rope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Code heavily inspired by https://blog.eleuther.ai/rotary-embeddings/, GPT-NeoX (Pytorch) implementation 5 | # and ESM2 implementation https://github.com/facebookresearch/esm/blob/main/esm/rotary_embedding.py 6 | 7 | def rotate_half(x): 8 | x1, x2 = x.chunk(2, dim=-1) 9 | return torch.cat((-x2, x1), dim=-1) 10 | 11 | 12 | @torch.jit.script 13 | def apply_rotary_pos_emb(q, k, cos, sin): 14 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 15 | 16 | class RotaryPositionEmbedding(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | base: int = 10000 21 | ): 22 | super().__init__() 23 | 24 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 25 | self.register_buffer("inv_freq", inv_freq) 26 | 
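# Lazily-built caches for the rotary cos/sin tables: _update_cached() below
# rebuilds them whenever the key sequence length changes, and forward() then
# applies q*cos + rotate_half(q)*sin (and likewise for k) via apply_rotary_pos_emb.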
27 | self.seq_len_cached = None 28 | self.cos_cached = None 29 | self.sin_cached = None 30 | 31 | def _update_cached(self, x, seq_dim): 32 | seq_len = x.shape[seq_dim] 33 | 34 | if seq_len != self.seq_len_cached: 35 | self.seq_len_cached = seq_len 36 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) 37 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 38 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 39 | 40 | self.cos_cached = emb.cos()[None, None, :, :] 41 | self.sin_cached = emb.sin()[None, None, :, :] 42 | 43 | def forward(self, q, k): 44 | self._update_cached(k, seq_dim=-2) 45 | return apply_rotary_pos_emb(q, k, self.cos_cached, self.sin_cached) 46 | -------------------------------------------------------------------------------- /models/lora_tune.py: -------------------------------------------------------------------------------- 1 | from transformers import PreTrainedModel, BertConfig, PretrainedConfig, BertForMaskedLM 2 | from typing import List 3 | class ESMConfig(PretrainedConfig): 4 | model_type = "esm2" 5 | def __init__( 6 | self, 7 | pooling='mean', 8 | output_dim=1, 9 | fix_lms=True, 10 | lora_tune=True, 11 | **kwargs, 12 | ): 13 | self.pooling = pooling 14 | self.output_dim = output_dim 15 | self.fix_lms = fix_lms 16 | self.lora_tune = lora_tune 17 | super().__init__(**kwargs) 18 | 19 | class RiNALMoConfig(PretrainedConfig): 20 | model_type = "rinalmo_flash" 21 | def __init__( 22 | self, 23 | pooling='mean', 24 | output_dim=1, 25 | fix_lms=True, 26 | lora_tune=True, 27 | **kwargs, 28 | ): 29 | self.pooling = pooling 30 | self.output_dim = output_dim 31 | self.fix_lms = fix_lms 32 | self.lora_tune = lora_tune 33 | super().__init__(**kwargs) 34 | 35 | 36 | class LoRAESM(PreTrainedModel): 37 | config_class = ESMConfig 38 | def __init__(self, model, config): 39 | super().__init__(config) 40 | self.model = model 41 | def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False): 42 | return self.model(tokens, repr_layers, need_head_weights, return_contacts) 43 | 44 | class LoRARiNALMo(PreTrainedModel): 45 | config_class = RiNALMoConfig 46 | def __init__(self, model, config): 47 | super().__init__(config) 48 | self.model = model 49 | def forward(self, tokens, need_attn_weights=False): 50 | return self.model(tokens, need_attn_weights) -------------------------------------------------------------------------------- /models/encoders/single.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.encoders.layers import AngularEncoding 5 | 6 | 7 | class PerResidueEncoder(nn.Module): 8 | 9 | def __init__(self, feat_dim, max_num_atoms, max_aa_types=27): 10 | super().__init__() 11 | self.max_num_atoms = max_num_atoms 12 | self.max_aa_types = max_aa_types 13 | self.aatype_embed = nn.Embedding(self.max_aa_types, feat_dim) 14 | # self.dihed_embed = AngularEncoding() 15 | # infeat_dim = feat_dim + self.dihed_embed.get_out_dim(6) # Phi, Psi, Chi1-4 16 | infeat_dim = feat_dim 17 | self.mlp = nn.Sequential( 18 | nn.Linear(infeat_dim, feat_dim * 2), nn.ReLU(), 19 | nn.Linear(feat_dim * 2, feat_dim), nn.ReLU(), 20 | nn.Linear(feat_dim, feat_dim), nn.ReLU(), 21 | nn.Linear(feat_dim, feat_dim) 22 | ) 23 | 24 | # def forward(self, aa, phi, phi_mask, psi, psi_mask, chi, chi_mask, mask_residue): 25 | def forward(self, aa, mask_residue): 26 | """ 27 | Args: 28 | aa: (N, L) 29 | # phi, phi_mask: (N, L) 30 | # psi, psi_mask: (N, L) 31 | # chi, 
chi_mask: (N, L, 4) 32 | mask_residue: (N, L) 33 | """ 34 | N, L = aa.size() 35 | 36 | # Amino acid identity features 37 | aa_feat = self.aatype_embed(aa) # (N, L, feat) 38 | 39 | # Dihedral features 40 | # dihedral = torch.cat( 41 | # [phi[..., None], psi[..., None], chi], 42 | # dim=-1 43 | # ) # (N, L, 6) 44 | # dihedral_mask = torch.cat([ 45 | # phi_mask[..., None], psi_mask[..., None], chi_mask], 46 | # dim=-1 47 | # ) # (N, L, 6) 48 | # dihedral_feat = self.dihed_embed(dihedral[..., None]) * dihedral_mask[..., None] # (N, L, 6, feat) 49 | # dihedral_feat = dihedral_feat.reshape(N, L, -1) 50 | 51 | # Mix 52 | # out_feat = self.mlp(torch.cat([aa_feat, dihedral_feat], dim=-1)) # (N, L, F) 53 | out_feat = self.mlp(aa_feat) 54 | out_feat = out_feat * mask_residue[:, :, None] 55 | return out_feat 56 | -------------------------------------------------------------------------------- /data/rna/rnas.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | from typing import Dict, Optional, List 4 | import numpy as np 5 | from data.rna.data_utils import pdb_to_array, cif_to_array 6 | 7 | 8 | @dataclass 9 | class RNAInput: 10 | seq: str # L 11 | mask: np.ndarray # (L, ) 12 | basetype: np.ndarray # (L, ) 13 | atom_mask: np.ndarray # (L, 27) 14 | 15 | atom_positions: Optional[np.ndarray] = field(default=None) # (L, 27, 3) 16 | residue_index: Optional[np.ndarray] = field(default=None) # (L, ) 17 | 18 | res_nb: Optional[np.ndarray] = field(default=None) # (L, ) 19 | resseq: Optional[np.ndarray] = field(default=None) # (L, ) 20 | chain_nb: Optional[np.ndarray] = field(default=None) # (L, ) 21 | chain_id: Optional[List] = field(default=None) # len(chain_id) = L 22 | 23 | chainid: Optional[np.ndarray] = field(default=None) # (L) 24 | 25 | 26 | @classmethod 27 | def from_path(self, pdb_filepath, valid_chains): 28 | """ 29 | Load RNA backbone from PDB file. 30 | 31 | Args: 32 | pdb_filepath (str): Path to PDB file. 
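            valid_chains: Chain identifiers to keep when parsing the structure.
        Returns:
            A dict of RNAInput objects, keyed by the entries returned by the parser.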
33 | """ 34 | pdb_filepath = str(pdb_filepath) 35 | if '.pdb' in pdb_filepath: 36 | rnas = pdb_to_array( 37 | pdb_filepath, valid_chains, return_sec_struct=True, return_sasa=False) 38 | else: 39 | rnas = cif_to_array(pdb_filepath, valid_chains) 40 | # print(rnas) 41 | rna_dict = {} 42 | for key in rnas: 43 | rna = self(**rnas[key]) 44 | rna_dict[key] = rna 45 | return rna_dict 46 | # coords = get_backbone_coords(coords, sequence) 47 | # rna = { 48 | # 'sequence': sequence, 49 | # 'coords_list': coords, 50 | # } 51 | # return rna 52 | 53 | @property 54 | def length(self): 55 | return len(self.seq) 56 | 57 | def __repr__(self): 58 | return self.__str__() 59 | 60 | def __str__(self): 61 | texts = [] 62 | texts += [f'seq: {self.seq}'] 63 | texts += [f'length: {len(self.seq)}'] 64 | texts += [f"mask: {''.join(self.mask.astype('int').astype('str'))}"] 65 | if self.chainid is not None: 66 | texts += [f"chainid: {''.join(self.chainid.astype('int').astype('str'))}"] 67 | 68 | names = [ 69 | 'basetype', 70 | 'atom_mask', 71 | 'atom_positions', 72 | ] 73 | for name in names: 74 | value = getattr(self, name) 75 | if value is None: 76 | text = f'{name}: None' 77 | else: 78 | text = f'{name}: {value.shape}' 79 | texts += [text] 80 | text = ', \n '.join(texts) 81 | text = f'RNA(\n {text}\n)' 82 | return text -------------------------------------------------------------------------------- /data/transforms/_base.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | 4 | 5 | class Compose: 6 | 7 | def __init__(self, transforms): 8 | self.transforms = transforms 9 | 10 | def __call__(self, data): 11 | for t in self.transforms: 12 | data = t(data) 13 | return data 14 | 15 | 16 | _TRANSFORM_DICT = {} 17 | 18 | 19 | def register_transform(name): 20 | def decorator(cls): 21 | _TRANSFORM_DICT[name] = cls 22 | return cls 23 | return decorator 24 | 25 | 26 | def get_transform(cfg): 27 | if cfg is None or len(cfg) == 0: 28 | return None 29 | tfms = [] 30 | for t_dict in cfg: 31 | t_dict = copy.deepcopy(t_dict) 32 | cls = _TRANSFORM_DICT[t_dict.pop('type')] 33 | tfms.append(cls(**t_dict)) 34 | return Compose(tfms) 35 | 36 | 37 | def _index_select(v, index, n): 38 | if isinstance(v, torch.Tensor) and v.size(0) == n: 39 | return v[index] 40 | elif isinstance(v, list) and len(v) == n: 41 | return [v[i] for i in index] 42 | else: 43 | return v 44 | 45 | 46 | def _index_select_with_dist(v, index, n): 47 | if isinstance(v, torch.Tensor) and v.size(0) == n and (len(v.shape) == 1 or v.size(0) != v.size(1)): 48 | return v[index] 49 | elif isinstance(v, torch.Tensor) and v.size(0) == v.size(1) and v.size(0) == n: 50 | tmp = v[index] 51 | new_v = tmp[:, index] 52 | return new_v 53 | elif isinstance(v, list) and len(v) == n: 54 | return [v[i] for i in index] 55 | else: 56 | return v 57 | 58 | def _index_select_data(data, index): 59 | return { 60 | k: _index_select(v, index, data['aa'].size(0)) 61 | for k, v in data.items() 62 | } 63 | 64 | def _index_select_complex(data, index): 65 | return { 66 | k: _index_select_with_dist(v, index, data['restype'].size(0)) 67 | for k, v in data.items() 68 | } 69 | 70 | def _mask_select(v, mask): 71 | if isinstance(v, torch.Tensor) and v.size(0) == mask.size(0): 72 | return v[mask] 73 | elif isinstance(v, list) and len(v) == mask.size(0): 74 | return [v[i] for i, b in enumerate(mask) if b] 75 | else: 76 | return v 77 | 78 | 79 | def _mask_select_data(data, mask): 80 | return { 81 | k: _mask_select(v, mask) 82 | for k, v in 
data.items() 83 | } 84 | 85 | 86 | def _get_CB_positions(pos_atoms, mask_atoms): 87 | """ 88 | Args: 89 | pos_atoms: (L, A, 3) 90 | mask_atoms: (L, A) 91 | """ 92 | L = pos_atoms.size(0) 93 | pos_CA = pos_atoms[:, 1] # (L, 3) # CA 94 | if pos_atoms.size(1) < 5: 95 | return pos_CA 96 | pos_CB = pos_atoms[:, 4] #CB 97 | mask_CB = mask_atoms[:, 4, None].expand(L, 3) #CB 98 | return torch.where(mask_CB, pos_CB, pos_CA) 99 | 100 | -------------------------------------------------------------------------------- /data/transforms/corrupt_chi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from ._base import register_transform, _get_CB_positions 5 | 6 | 7 | @register_transform('corrupt_chi_angle') 8 | class CorruptChiAngle(object): 9 | 10 | def __init__(self, ratio_mask=0.1, add_noise=True, maskable_flag_attr=None): 11 | super().__init__() 12 | self.ratio_mask = ratio_mask 13 | self.add_noise = add_noise 14 | self.maskable_flag_attr = maskable_flag_attr 15 | 16 | def _normalize_angles(self, angles): 17 | angles = angles % (2*np.pi) 18 | return torch.where(angles > np.pi, angles - 2*np.pi, angles) 19 | 20 | def _get_min_dist(self, data, center_idx): 21 | pos_beta_all = _get_CB_positions(data['pos_atoms'], data['mask_atoms']) 22 | pos_beta_center = pos_beta_all[center_idx] 23 | cdist = torch.cdist(pos_beta_all, pos_beta_center) # (L, K) 24 | min_dist = cdist.min(dim=1)[0] # (L, ) 25 | return min_dist 26 | 27 | def _get_noise_std(self, min_dist): 28 | return torch.clamp_min((-1/16) * min_dist + 1, 0) 29 | 30 | def _get_flip_prob(self, min_dist): 31 | return torch.where( 32 | min_dist <= 8.0, 33 | torch.full_like(min_dist, 0.25), 34 | torch.zeros_like(min_dist,), 35 | ) 36 | 37 | def _add_chi_gaussian_noise(self, chi, noise_std, chi_mask): 38 | """ 39 | Args: 40 | chi: (L, 4) 41 | noise_std: (L, ) 42 | chi_mask: (L, 4) 43 | """ 44 | noise = torch.randn_like(chi) * noise_std[:, None] * chi_mask 45 | return self._normalize_angles(chi + noise) 46 | 47 | def _random_flip_chi(self, chi, flip_prob, chi_mask): 48 | """ 49 | Args: 50 | chi: (L, 4) 51 | flip_prob: (L, ) 52 | chi_mask: (L, 4) 53 | """ 54 | delta = torch.where( 55 | torch.rand_like(chi) <= flip_prob[:, None], 56 | torch.full_like(chi, np.pi), 57 | torch.zeros_like(chi), 58 | ) * chi_mask 59 | return self._normalize_angles(chi + delta) 60 | 61 | def __call__(self, data): 62 | L = data['aa'].size(0) # Calculate aa num 63 | idx = torch.arange(0, L) 64 | num_mask = max(int(self.ratio_mask * L), 1) # Calculate mask ratio 65 | if self.maskable_flag_attr is not None: 66 | flag = data[self.maskable_flag_attr] 67 | idx = idx[flag] 68 | 69 | idx = idx.tolist() 70 | np.random.shuffle(idx) 71 | idx_mask = idx[:num_mask] 72 | min_dist = self._get_min_dist(data, idx_mask) 73 | noise_std = self._get_noise_std(min_dist) 74 | flip_prob = self._get_flip_prob(min_dist) 75 | 76 | chi_native = torch.where( 77 | torch.randn_like(data['chi']) > 0, 78 | data['chi'], 79 | data['chi_alt'], 80 | ) # (L, 4), randomly pick from chi and chi_alt 81 | chi = chi_native.clone() 82 | chi_mask = data['chi_mask'] 83 | 84 | if self.add_noise: 85 | chi = self._add_chi_gaussian_noise(chi, noise_std, chi_mask) 86 | chi = self._random_flip_chi(chi, flip_prob, chi_mask) 87 | chi[idx_mask] = 0.0 # Mask chi angles 88 | 89 | corrupt_flag = torch.zeros(L, dtype=torch.bool) 90 | corrupt_flag[idx_mask] = True 91 | corrupt_flag[min_dist <= 8] = True 92 | 93 | masked_flag = torch.zeros(L, dtype=torch.bool) 94 | 
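        # Flag semantics: corrupt_flag (set above) covers the masked residues plus every
        # residue within 8 Å (CB–CB) of one, i.e. everything whose chi angles may have been
        # perturbed; masked_flag below marks only the residues whose chi angles were zeroed.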
masked_flag[idx_mask] = True 95 | 96 | data['chi_native'] = chi_native 97 | data['chi_corrupt'] = chi 98 | data['chi_corrupt_flag'] = corrupt_flag 99 | data['chi_masked_flag'] = masked_flag 100 | return data 101 | -------------------------------------------------------------------------------- /models/encoders/pair.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from utils.geometry import angstrom_to_nm, pairwise_dihedrals 6 | from models.encoders.layers import AngularEncoding 7 | 8 | 9 | class ResiduePairEncoder(nn.Module): 10 | 11 | def __init__(self, feat_dim, max_num_atoms, max_aa_types=30, max_relpos=32): 12 | super().__init__() 13 | self.max_num_atoms = max_num_atoms 14 | self.max_aa_types = max_aa_types 15 | self.max_relpos = max_relpos 16 | self.aa_pair_embed = nn.Embedding(self.max_aa_types * self.max_aa_types, feat_dim) 17 | self.relpos_embed = nn.Embedding(2 * max_relpos + 1, feat_dim) 18 | 19 | self.aapair_to_distcoef = nn.Embedding(self.max_aa_types * self.max_aa_types, max_num_atoms * max_num_atoms) 20 | # nn.init.zeros_(self.aapair_to_distcoef.weight) 21 | self.distance_embed = nn.Sequential( 22 | nn.Linear(max_num_atoms * max_num_atoms, feat_dim), nn.ReLU(), 23 | nn.Linear(feat_dim, feat_dim), nn.ReLU(), 24 | ) 25 | 26 | self.dihedral_embed = AngularEncoding() 27 | feat_dihed_dim = self.dihedral_embed.get_out_dim(2) # Phi and Psi 28 | 29 | infeat_dim = feat_dim + feat_dim + feat_dim + feat_dihed_dim 30 | self.out_mlp = nn.Sequential( 31 | nn.Linear(infeat_dim, feat_dim), nn.ReLU(), 32 | nn.Linear(feat_dim, feat_dim), nn.ReLU(), 33 | nn.Linear(feat_dim, feat_dim), 34 | ) 35 | 36 | def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms): 37 | """ 38 | Args: 39 | aa: (N, L). 40 | res_nb: (N, L). 41 | chain_nb: (N, L). 
42 | pos_atoms: (N, L, A, 3) 43 | mask_atoms: (N, L, A) 44 | Returns: 45 | (N, L, L, feat_dim) 46 | """ 47 | N, L = aa.size() 48 | mask_residue = mask_atoms[:, :, 1] # (N, L) 49 | mask_pair = mask_residue[:, :, None] * mask_residue[:, None, :] 50 | 51 | # Pair identities 52 | aa_pair = aa[:, :, None] * self.max_aa_types + aa[:, None, :] # (N, L, L) 53 | feat_aapair = self.aa_pair_embed(aa_pair) 54 | 55 | # Relative positions 56 | same_chain = (chain_nb[:, :, None] == chain_nb[:, None, :]) 57 | relpos = torch.clamp( 58 | res_nb[:, :, None] - res_nb[:, None, :], 59 | min=-self.max_relpos, max=self.max_relpos, 60 | ) # (N, L, L) 61 | 62 | feat_relpos = self.relpos_embed(relpos + self.max_relpos) * same_chain[:, :, :, None] 63 | 64 | # Distances 65 | d = angstrom_to_nm(torch.linalg.norm( 66 | pos_atoms[:, :, None, :, None] - pos_atoms[:, None, :, None, :], 67 | dim=-1, ord=2, 68 | )).reshape(N, L, L, -1) # (N, L, L, A*A) 69 | c = F.softplus(self.aapair_to_distcoef(aa_pair)) # (N, L, L, A*A) 70 | d_gauss = torch.exp(-1 * c * d ** 2) 71 | mask_atom_pair = (mask_atoms[:, :, None, :, None] * mask_atoms[:, None, :, None, :]).reshape(N, L, L, -1) 72 | feat_dist = self.distance_embed(d_gauss * mask_atom_pair) 73 | 74 | # Orientations 75 | dihed = pairwise_dihedrals(pos_atoms) # (N, L, L, 2) 76 | feat_dihed = self.dihedral_embed(dihed) 77 | # if torch.isnan(feat_dihed).any(): 78 | # print("Found nan in dihed!") 79 | # All 80 | feat_all = torch.cat([feat_aapair, feat_relpos, feat_dist, feat_dihed], dim=-1) 81 | feat_all = self.out_mlp(feat_all) # (N, L, L, F) 82 | feat_all = feat_all * mask_pair[:, :, :, None] 83 | if torch.isnan(feat_all).any(): 84 | # print("??:", torch.isnan(aa_pair).any()) 85 | print("Let's check:", torch.isnan(feat_aapair).any(), torch.isnan(feat_relpos).any(), torch.isnan(feat_dist).any(), torch.isnan(feat_dihed).any()) 86 | print("Let's check 2:", torch.isnan(aa).any(), torch.isnan(res_nb).any(), torch.isnan(chain_nb).any(), torch.isnan(pos_atoms).any(), torch.isnan(mask_atoms).any()) 87 | return feat_all 88 | 89 | -------------------------------------------------------------------------------- /data/transforms/select_atom.py: -------------------------------------------------------------------------------- 1 | 2 | from ._base import register_transform 3 | import torch 4 | from data.rna import get_backbone_coords, RNA_ATOMS 5 | 6 | 7 | @register_transform('select_atom') 8 | class SelectAtom(object): 9 | 10 | def __init__(self, resolution): 11 | super().__init__() 12 | assert resolution in ('full', 'backbone', 'backbone+R1', 'C_Only') 13 | self.resolution = resolution 14 | 15 | def __call__(self, data): 16 | if self.resolution == 'full': 17 | data['pos_atoms'] = data['pos_heavyatom'][:, :] 18 | data['mask_atoms'] = data['mask_heavyatom'][:, :] 19 | 20 | elif self.resolution == 'backbone': 21 | pos_atoms = torch.zeros([data['pos_heavyatom'].shape[0]] + [4, 3], device=data['pos_heavyatom'].device) 22 | mask_atoms = torch.zeros([data['mask_heavyatom'].shape[0]] + [4], device=data['mask_heavyatom'].device).bool() 23 | # print(pos_atoms.shape, mask_atoms.shape) 24 | pos_atoms[data['identifier']==0] = data['pos_heavyatom'][data['identifier']==0, :4] 25 | mask_atoms[data['identifier']==0] = data['mask_heavyatom'][data['identifier']==0, :4] 26 | na_atoms = data['pos_heavyatom'][data['identifier']==1, 14:] 27 | na_seqs = data['seq'] 28 | indices = torch.nonzero(data['identifier']) 29 | first_index = indices[0].item() if len(indices) > 0 else -1 30 | na_seqs = na_seqs[first_index:] 31 | 
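            # RNA residues are reduced to a 4-atom pseudo-backbone (P, C4', C1', plus N1 for
            # pyrimidines or N9 for purines); atoms that keep the fill value (1e-5, matching
            # FILL_VALUE in data/rna/base_constants.py) are treated as missing and masked out below.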
fill_value = 1e-5 32 | pyrimidine_bb_indices = [RNA_ATOMS.index("P"), RNA_ATOMS.index("C4'"), RNA_ATOMS.index("C1'"), RNA_ATOMS.index("N1")] 33 | purine_bb_indices = [RNA_ATOMS.index("P"), RNA_ATOMS.index("C4'"), RNA_ATOMS.index("C1'"), RNA_ATOMS.index("N9")] 34 | backbone_coords = get_backbone_coords(na_atoms, na_seqs, pyrimidine_bb_indices, purine_bb_indices, fill_value) 35 | pos_atoms[data['identifier']==1, :4] = backbone_coords 36 | mask_atoms[data['identifier']==1, :4] = (backbone_coords != fill_value).sum(-1).bool() 37 | data['pos_atoms'] = pos_atoms 38 | data['mask_atoms'] = mask_atoms 39 | 40 | elif self.resolution == 'backbone+R1': 41 | # For Protein, it's N,CA,C,O,CB; For RNA, it's P, C4', N1/N9, C1' 42 | pos_atoms = torch.zeros([data['pos_heavyatom'].shape[0]] + [5, 3], device=data['pos_heavyatom'].device) 43 | mask_atoms = torch.zeros(data['mask_heavyatom'].shape[0] + [5], device=data['mask_heavyatom'].device).bool() 44 | pos_atoms[data['identifier']==0] = data['pos_heavyatom'][data['identifier']==0, :5] 45 | mask_atoms[data['identifier']==0] = data['mask_heavyatom'][data['identifier']==0, :5] 46 | 47 | na_atoms = data['pos_heavyatom'][data['identifier']==1, 14:] 48 | na_seqs = data['seq'] 49 | indices = torch.nonzero(data['identifier']) 50 | first_index = indices[0].item() if len(indices) > 0 else -1 51 | na_seqs = na_seqs[first_index:] 52 | fill_value = 1e-5 53 | pyrimidine_bb_indices = [RNA_ATOMS.index("P"), RNA_ATOMS.index("C4'"), RNA_ATOMS.index("C1'"), RNA_ATOMS.index("C5'"), RNA_ATOMS.index("N1")] 54 | purine_bb_indices = [RNA_ATOMS.index("P"), RNA_ATOMS.index("C4'"), RNA_ATOMS.index("C1'"), RNA_ATOMS.index("C5'"), RNA_ATOMS.index("N9")] 55 | backbone_coords = get_backbone_coords(na_atoms, na_seqs, pyrimidine_bb_indices, purine_bb_indices, fill_value) 56 | pos_atoms[data['identifier']==1, :5] = backbone_coords 57 | mask_atoms[data['identifier']==1, :5] = backbone_coords != fill_value 58 | 59 | data['pos_atoms'] = data['pos_heavyatom'][:, :5] 60 | data['mask_atoms'] = data['mask_heavyatom'][:, :5] 61 | 62 | elif self.resolution == 'C_Only': 63 | data['pos_atoms'] = data['pos_heavyatom'][:, :1] 64 | data['mask_atoms'] = data['mask_heavyatom'][:, :1] 65 | 66 | return data 67 | -------------------------------------------------------------------------------- /data/transforms/mask.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from ._base import register_transform 7 | 8 | 9 | def _extend_mask(mask, chain_nb): 10 | """ 11 | Args: 12 | mask, chain_nb: (L, ). 
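    Returns:
        The input mask dilated by one residue to the left and right, without
        crossing chain boundaries (neighbours must share the same chain_nb).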
13 | """ 14 | # Shift right 15 | mask_sr = torch.logical_and( 16 | F.pad(mask[:-1], pad=(1, 0), value=0), 17 | (F.pad(chain_nb[:-1], pad=(1, 0), value=-1) == chain_nb) 18 | ) 19 | # Shift left 20 | mask_sl = torch.logical_and( 21 | F.pad(mask[1:], pad=(0, 1), value=0), 22 | (F.pad(chain_nb[1:], pad=(0, 1), value=-1) == chain_nb) 23 | ) 24 | return torch.logical_or(mask, torch.logical_or(mask_sr, mask_sl)) 25 | 26 | 27 | def _mask_sidechains(pos_atoms, mask_atoms, mask_idx): 28 | """ 29 | Args: 30 | pos_atoms: (L, A, 3) 31 | mask_atoms: (L, A) 32 | """ 33 | pos_atoms = pos_atoms.clone() 34 | pos_atoms[mask_idx, 4:] = 0.0 35 | 36 | mask_atoms = mask_atoms.clone() 37 | mask_atoms[mask_idx, 4:] = False 38 | return pos_atoms, mask_atoms 39 | 40 | 41 | @register_transform('random_mask_amino_acids') 42 | class RandomMaskAminoAcids(object): 43 | 44 | def __init__( 45 | self, 46 | mask_ratio_in_all=0.05, 47 | ratio_in_maskable_limit=0.5, 48 | mask_token=20, 49 | maskable_flag_attr='core_flag', 50 | extend_maskable_flag=False, 51 | mask_ratio_mode='constant', 52 | ): 53 | super().__init__() 54 | self.mask_ratio_in_all = mask_ratio_in_all 55 | self.ratio_in_maskable_limit = ratio_in_maskable_limit 56 | self.mask_token = mask_token 57 | self.maskable_flag_attr = maskable_flag_attr 58 | self.extend_maskable_flag = extend_maskable_flag 59 | assert mask_ratio_mode in ('constant', 'random') 60 | self.mask_ratio_mode = mask_ratio_mode 61 | 62 | def __call__(self, data): 63 | if self.maskable_flag_attr is None: 64 | maskable_flag = torch.ones([data['aa'].size(0), ], dtype=torch.bool) 65 | else: 66 | maskable_flag = data[self.maskable_flag_attr] 67 | if self.extend_maskable_flag: 68 | maskable_flag = _extend_mask(maskable_flag, data['chain_nb']) 69 | 70 | num_masked_max = math.ceil(self.mask_ratio_in_all * data['aa'].size(0)) 71 | if self.mask_ratio_mode == 'random': 72 | num_masked = random.randint(1, num_masked_max) 73 | else: 74 | num_masked = num_masked_max 75 | mask_idx = torch.multinomial( 76 | maskable_flag.float() / maskable_flag.sum(), 77 | num_samples=num_masked, 78 | ) 79 | mask_idx = mask_idx[:math.ceil(self.ratio_in_maskable_limit * maskable_flag.sum().item())] 80 | 81 | aa_masked = data['aa'].clone() 82 | aa_masked[mask_idx] = self.mask_token 83 | data['aa_true'] = data['aa'] 84 | data['aa_masked'] = aa_masked 85 | 86 | data['pos_atoms'], data['mask_atoms'] = _mask_sidechains( 87 | data['pos_atoms'], data['mask_atoms'], mask_idx 88 | ) 89 | 90 | return data 91 | 92 | 93 | @register_transform('mask_selected_amino_acids') 94 | class MaskSelectedAminoAcids(object): 95 | 96 | def __init__(self, select_attr, mask_token=20): 97 | super().__init__() 98 | self.select_attr = select_attr 99 | self.mask_token = mask_token 100 | 101 | def __call__(self, data): 102 | mask_flag = (data[self.select_attr] > 0) 103 | 104 | aa_masked = data['aa'].clone() 105 | aa_masked[mask_flag] = self.mask_token 106 | data['aa_true'] = data['aa'] 107 | data['aa_masked'] = aa_masked 108 | 109 | data['pos_atoms'], data['mask_atoms'] = _mask_sidechains( 110 | data['pos_atoms'], data['mask_atoms'], mask_flag 111 | ) 112 | 113 | return data 114 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | 
downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | datasets/* 165 | .vscode 166 | *.whl 167 | weights 168 | RiNALMo 169 | outputs 170 | cache 171 | wandb 172 | *.pdb 173 | # code_play.ipynb 174 | *.ipynb 175 | hugging_face -------------------------------------------------------------------------------- /pl_modules/data_module.py: -------------------------------------------------------------------------------- 1 | from data.sequence_dataset import CustomSeqCollate 2 | from data.structure_dataset import CustomStructCollate 3 | from data.pri30k_dataset import PRI30kStructCollate 4 | from data import DataRegister 5 | import pytorch_lightning as pl 6 | import diskcache 7 | import pandas as pd 8 | from torch.utils.data import DataLoader 9 | from torch_geometric.loader import DataLoader as GraphLoader 10 | 11 | def get_dataset(data_args:dict=None): 12 | register = DataRegister() 13 | dataset_cls = register[data_args.dataset_type] 14 | return dataset_cls 15 | 16 | def get_collate(dataset_type): 17 | collate_dict = {'sequence_dataset': CustomSeqCollate, 18 | 'structure_dataset': CustomStructCollate, 19 | 'pri30k_dataset': PRI30kStructCollate, 20 | } 21 | return collate_dict[dataset_type] 22 | 23 | class DataModule(pl.LightningDataModule): 24 | def __init__(self, 25 | df_path='', 26 | col_group='fold_0', 27 | batch_size=32, 28 | num_workers=0, 29 | pin_memory=True, 30 | cache_dir=None, 31 | strategy='separate', 32 | dataset_args=None, 33 | **kwargs): 34 | super().__init__() 35 | self.df_path = df_path 36 | self.col_group=col_group 37 | self.batch_size=batch_size 38 | self.num_workers=num_workers 39 | self.pin_memory=pin_memory 40 | self.cache_dir=cache_dir 41 | self.strategy=strategy 42 | self.dataset_args=dataset_args 43 | # print("Dataset Args:", dataset_args) 44 | 45 | def setup(self, stage=None): 46 | if self.cache_dir is None: 47 | cache = None 48 | else: 49 | print("Using diskcache at {}.".format(self.cache_dir)) 50 | cache = diskcache.Cache(directory=self.cache_dir, eviction_policy='none') 51 | 52 | df = pd.read_csv(self.df_path) 53 | df_train = df[df[self.col_group].isin(['train'])] 54 | df_val = df[df[self.col_group].isin(['val'])] 55 | df_test = df[df[self.col_group].isin(['test'])] 56 | dataset_cls = get_dataset(self.dataset_args) 57 | self.train_dataset = dataset_cls(df_train, **self.dataset_args, diskcache=cache) 58 | self.val_dataset = dataset_cls(df_val, **self.dataset_args, diskcache=cache) 59 
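        # Test-set selection (below): rows labelled 'test' in the split CSV form the test
        # set when present; otherwise the validation fold of the current split
        # (self.col_group, e.g. 'fold_0') is reused for testing.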
| 60 | 61 | if len(df_test) > 0 : 62 | print(f"Using Test Fold to test the model!") 63 | self.test_dataset = dataset_cls(df_test, **self.dataset_args, diskcache=cache) 64 | else: 65 | print(f"Using Validation Fold {self.col_group} to test the model!") 66 | self.test_dataset = dataset_cls(df_val, **self.dataset_args, diskcache=cache) 67 | 68 | def train_dataloader(self): 69 | if self.dataset_args.dataset_type != 'graph_dataset': 70 | collate = get_collate(self.dataset_args.dataset_type) 71 | return DataLoader( 72 | self.train_dataset, 73 | batch_size=self.batch_size, 74 | shuffle=True, 75 | num_workers=self.num_workers, 76 | pin_memory=self.pin_memory, 77 | persistent_workers=self.num_workers > 0, 78 | collate_fn=collate(strategy=self.strategy), 79 | ) 80 | else: 81 | return GraphLoader( 82 | self.train_dataset, 83 | batch_size=self.batch_size, 84 | shuffle=True, 85 | num_workers=self.num_workers, 86 | pin_memory=self.pin_memory, 87 | persistent_workers=self.num_workers > 0, 88 | ) 89 | 90 | def val_dataloader(self): 91 | if self.dataset_args.dataset_type != 'graph_dataset': 92 | collate = get_collate(self.dataset_args.dataset_type) 93 | return DataLoader( 94 | self.val_dataset, 95 | batch_size=self.batch_size, 96 | shuffle=False, 97 | num_workers=self.num_workers, 98 | pin_memory=self.pin_memory, 99 | persistent_workers=self.num_workers > 0, 100 | collate_fn=collate(strategy=self.strategy), 101 | ) 102 | else: 103 | return GraphLoader( 104 | self.val_dataset, 105 | batch_size=self.batch_size, 106 | shuffle=True, 107 | num_workers=self.num_workers, 108 | pin_memory=self.pin_memory, 109 | persistent_workers=self.num_workers > 0, 110 | ) 111 | 112 | 113 | def test_dataloader(self): 114 | if self.dataset_args.dataset_type != 'graph_dataset': 115 | collate = get_collate(self.dataset_args.dataset_type) 116 | return DataLoader( 117 | self.test_dataset, 118 | batch_size=self.batch_size, 119 | shuffle=False, 120 | num_workers=self.num_workers, 121 | pin_memory=self.pin_memory, 122 | persistent_workers=self.num_workers > 0, 123 | collate_fn=collate(strategy=self.strategy), 124 | ) 125 | else: 126 | return GraphLoader( 127 | self.test_dataset, 128 | batch_size=self.batch_size, 129 | shuffle=True, 130 | num_workers=self.num_workers, 131 | pin_memory=self.pin_memory, 132 | persistent_workers=self.num_workers > 0, 133 | ) 134 | 135 | -------------------------------------------------------------------------------- /models/encoders/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def mask_zero(mask, value): 6 | return torch.where(mask, value, torch.zeros_like(value)) 7 | 8 | 9 | class DistanceToBins(nn.Module): 10 | 11 | def __init__(self, dist_min=0.0, dist_max=20.0, num_bins=64, use_onehot=False): 12 | super().__init__() 13 | self.dist_min = dist_min 14 | self.dist_max = dist_max 15 | self.num_bins = num_bins 16 | self.use_onehot = use_onehot 17 | 18 | if use_onehot: 19 | offset = torch.linspace(dist_min, dist_max, self.num_bins) 20 | else: 21 | offset = torch.linspace(dist_min, dist_max, self.num_bins - 1) # 1 overflow flag 22 | self.coeff = -0.5 / ((offset[1] - offset[0]) * 0.2).item() ** 2 # `*0.2`: makes it not too blurred 23 | self.register_buffer('offset', offset) 24 | 25 | @property 26 | def out_channels(self): 27 | return self.num_bins 28 | 29 | def forward(self, dist, dim, normalize=True): 30 | """ 31 | Args: 32 | dist: (N, *, 1, *) 33 | Returns: 34 | (N, *, num_bins, *) 35 | """ 36 | assert 
dist.size()[dim] == 1 37 | offset_shape = [1] * len(dist.size()) 38 | offset_shape[dim] = -1 39 | 40 | if self.use_onehot: 41 | diff = torch.abs(dist - self.offset.view(*offset_shape)) # (N, *, num_bins, *) 42 | bin_idx = torch.argmin(diff, dim=dim, keepdim=True) # (N, *, 1, *) 43 | y = torch.zeros_like(diff).scatter_(dim=dim, index=bin_idx, value=1.0) 44 | else: 45 | overflow_symb = (dist >= self.dist_max).float() # (N, *, 1, *) 46 | y = dist - self.offset.view(*offset_shape) # (N, *, num_bins-1, *) 47 | y = torch.exp(self.coeff * torch.pow(y, 2)) # (N, *, num_bins-1, *) 48 | y = torch.cat([y, overflow_symb], dim=dim) # (N, *, num_bins, *) 49 | if normalize: 50 | y = y / y.sum(dim=dim, keepdim=True) 51 | 52 | return y 53 | 54 | 55 | class PositionalEncoding(nn.Module): 56 | 57 | def __init__(self, num_funcs=6): 58 | super().__init__() 59 | self.num_funcs = num_funcs 60 | self.register_buffer('freq_bands', 2.0 ** torch.linspace(0.0, num_funcs - 1, num_funcs)) 61 | 62 | def get_out_dim(self, in_dim): 63 | return in_dim * (2 * self.num_funcs + 1) 64 | 65 | def forward(self, x): 66 | """ 67 | Args: 68 | x: (..., d). 69 | """ 70 | shape = list(x.shape[:-1]) + [-1] 71 | x = x.unsqueeze(-1) # (..., d, 1) 72 | code = torch.cat([x, torch.sin(x * self.freq_bands), torch.cos(x * self.freq_bands)], dim=-1) # (..., d, 2f+1) 73 | code = code.reshape(shape) 74 | return code 75 | 76 | 77 | class AngularEncoding(nn.Module): 78 | 79 | def __init__(self, num_funcs=3): 80 | super().__init__() 81 | self.num_funcs = num_funcs 82 | self.register_buffer('freq_bands', torch.FloatTensor( 83 | [i + 1 for i in range(num_funcs)] + [1. / (i + 1) for i in range(num_funcs)] 84 | )) 85 | 86 | def get_out_dim(self, in_dim): 87 | return in_dim * (1 + 2 * 2 * self.num_funcs) 88 | 89 | def forward(self, x): 90 | """ 91 | Args: 92 | x: (..., d). 93 | """ 94 | shape = list(x.shape[:-1]) + [-1] 95 | x = x.unsqueeze(-1) # (..., d, 1) 96 | code = torch.cat([x, torch.sin(x * self.freq_bands), torch.cos(x * self.freq_bands)], dim=-1) # (..., d, 2f+1) 97 | code = code.reshape(shape) 98 | return code 99 | 100 | 101 | class LayerNorm(nn.Module): 102 | 103 | def __init__(self, 104 | normal_shape, 105 | gamma=True, 106 | beta=True, 107 | epsilon=1e-10): 108 | """Layer normalization layer 109 | See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf) 110 | :param normal_shape: The shape of the input tensor or the last dimension of the input tensor. 111 | :param gamma: Add a scale parameter if it is True. 112 | :param beta: Add an offset parameter if it is True. 113 | :param epsilon: Epsilon for calculating variance. 
114 | """ 115 | super().__init__() 116 | if isinstance(normal_shape, int): 117 | normal_shape = (normal_shape,) 118 | else: 119 | normal_shape = (normal_shape[-1],) 120 | self.normal_shape = torch.Size(normal_shape) 121 | self.epsilon = epsilon 122 | if gamma: 123 | self.gamma = nn.Parameter(torch.Tensor(*normal_shape)) 124 | else: 125 | self.register_parameter('gamma', None) 126 | if beta: 127 | self.beta = nn.Parameter(torch.Tensor(*normal_shape)) 128 | else: 129 | self.register_parameter('beta', None) 130 | self.reset_parameters() 131 | 132 | def reset_parameters(self): 133 | if self.gamma is not None: 134 | self.gamma.data.fill_(1) 135 | if self.beta is not None: 136 | self.beta.data.zero_() 137 | 138 | def forward(self, x): 139 | mean = x.mean(dim=-1, keepdim=True) 140 | var = ((x - mean) ** 2).mean(dim=-1, keepdim=True) 141 | std = (var + self.epsilon).sqrt() 142 | y = (x - mean) / std 143 | if self.gamma is not None: 144 | y *= self.gamma 145 | if self.beta is not None: 146 | y += self.beta 147 | return y 148 | 149 | def extra_repr(self): 150 | return 'normal_shape={}, gamma={}, beta={}, epsilon={}'.format( 151 | self.normal_shape, self.gamma is not None, self.beta is not None, self.epsilon, 152 | ) 153 | -------------------------------------------------------------------------------- /models/components/coformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from models.components.attention import MultiHeadSelfAttention, FlashMultiHeadSelfAttention 5 | 6 | 7 | from models.components.rope import RotaryPositionEmbedding 8 | 9 | import torch.utils.checkpoint as checkpoint 10 | 11 | class SwiGLU(nn.Module): 12 | """ 13 | Swish-Gated Linear Unit 14 | https://arxiv.org/pdf/2002.05202v1.pdf 15 | In the cited paper beta is set to 1 and is not learnable; 16 | but by the Swish definition it is learnable parameter otherwise 17 | it is SiLU activation function (https://paperswithcode.com/method/swish) 18 | """ 19 | def __init__(self, size_in, size_out, beta_is_learnable=True, bias=True): 20 | """ 21 | Args: 22 | size_in: input embedding dimension 23 | size_out: output embedding dimension 24 | beta_is_learnable: whether beta is learnable or set to 1, learnable by default 25 | bias: whether use bias term, enabled by default 26 | """ 27 | super().__init__() 28 | self.linear = nn.Linear(size_in, size_out, bias=bias) 29 | self.linear_gate = nn.Linear(size_in, size_out, bias=bias) 30 | self.beta = nn.Parameter(torch.ones(1), requires_grad=beta_is_learnable) 31 | 32 | def forward(self, x): 33 | linear_out = self.linear(x) 34 | swish_out = linear_out * torch.sigmoid(self.beta * linear_out) 35 | return swish_out * self.linear_gate(x) 36 | 37 | 38 | class CoFormer(nn.Module): 39 | def __init__(self, embed_dim, pair_dim, num_blocks, num_heads, use_rot_emb=True, attn_qkv_bias=False, transition_dropout=0.0, attention_dropout=0.0, residual_dropout=0.0, transition_factor=4, use_flash_attn=False): 40 | super().__init__() 41 | 42 | self.use_flash_attn = use_flash_attn 43 | 44 | self.blocks = nn.ModuleList( 45 | [ 46 | TransformerBlock(embed_dim, pair_dim, num_heads, use_rot_emb, attn_qkv_bias, transition_dropout, attention_dropout, residual_dropout, transition_factor, use_flash_attn) for _ in range(num_blocks) 47 | ] 48 | ) 49 | 50 | self.final_layer_norm = nn.LayerNorm(embed_dim) 51 | self.pair_final_layer_norm = nn.LayerNorm(pair_dim) 52 | 53 | def forward(self, x, struct_embed, 
key_padding_mask=None, need_attn_weights=False, attn_mask=None): 54 | attn_weights = None 55 | if need_attn_weights: 56 | attn_weights = [] 57 | 58 | for block in self.blocks: 59 | # x, struct_embed, attn = checkpoint.checkpoint( 60 | # block, 61 | # x, 62 | # struct_embed, 63 | # key_padding_mask, 64 | # need_attn_weights, 65 | # use_reentrant=False 66 | # ) 67 | x, struct_embed, attn = block(x, struct_embed, key_padding_mask, attn_mask) 68 | if need_attn_weights: 69 | attn_weights.append(attn) 70 | 71 | x = self.final_layer_norm(x) 72 | struct_embed = self.pair_final_layer_norm(struct_embed) 73 | return x, struct_embed, attn_weights 74 | 75 | class TransformerBlock(nn.Module): 76 | def __init__(self, embed_dim, pair_dim, num_heads, use_rot_emb=True, attn_qkv_bias=False, transition_dropout=0.0, attention_dropout=0.0, residual_dropout=0.0, transition_factor=4, use_flash_attn=False): 77 | super().__init__() 78 | 79 | self.use_flash_attn = use_flash_attn 80 | 81 | if use_flash_attn: 82 | self.mh_attn = FlashMultiHeadSelfAttention(embed_dim, num_heads, attention_dropout, causal=False, use_rot_emb=use_rot_emb, bias=attn_qkv_bias) 83 | else: 84 | self.mh_attn = MultiHeadSelfAttention(embed_dim, pair_dim, num_heads, attention_dropout, use_rot_emb, attn_qkv_bias) 85 | 86 | self.attn_layer_norm = nn.LayerNorm(embed_dim) 87 | 88 | self.transition = nn.Sequential( 89 | SwiGLU(embed_dim, int(2 / 3 * transition_factor * embed_dim), beta_is_learnable=True, bias=True), 90 | nn.Dropout(p=transition_dropout), 91 | nn.Linear(int(2 / 3 * transition_factor * embed_dim), embed_dim, bias=True), 92 | ) 93 | 94 | 95 | # self.transition_struct = nn.Sequential( 96 | # SwiGLU(embed_dim, int(2 / 3 * transition_factor * embed_dim), beta_is_learnable=True, bias=True), 97 | # nn.Dropout(p=transition_dropout), 98 | # nn.Linear(int(2 / 3 * transition_factor * embed_dim), embed_dim, bias=True), 99 | # ) 100 | 101 | self.out_layer_norm = nn.LayerNorm(embed_dim) 102 | self.pair_layer_norm = nn.LayerNorm(pair_dim) 103 | 104 | self.residual_dropout_1 = nn.Dropout(p=residual_dropout) 105 | self.residual_dropout_2 = nn.Dropout(p=residual_dropout) 106 | 107 | def forward(self, x, struct_embed, key_padding_mask=None, attn_mask=None): 108 | x = self.attn_layer_norm(x) 109 | # if self.use_flash_attn: 110 | # mh_out, attn = self.mh_attn(x, key_padding_mask=key_padding_mask, return_attn_probs=need_attn_weights) 111 | # else: 112 | # Temporarily unable flash_attn 113 | mh_out, struct_out, attn = self.mh_attn(x, struct_embed, attn_mask, key_pad_mask=key_padding_mask) 114 | x = x + self.residual_dropout_1(mh_out) 115 | struct_embed = struct_embed + self.residual_dropout_1(struct_out) 116 | residual = x 117 | struct_residual = struct_embed 118 | x = self.out_layer_norm(x) 119 | struct_embed = self.pair_layer_norm(struct_embed) 120 | x = residual + self.residual_dropout_2(self.transition(x)) 121 | struct_embed = struct_residual + self.residual_dropout_2(struct_embed) 122 | return x, struct_embed, attn -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🥥 CoPRA 2 | 3 |


20 | This is the official implementation of CoPRA: Bridging Cross-domain Pretrained Sequence Models with Complex Structures for Protein-RNA Binding Affinity Prediction (AAAI 2025) 21 | 22 | Overview of CoPRA 23 | 24 | 25 | 26 | CoPRA is a state-of-the-art predictor of protein-RNA binding affinity. The framework of CoPRA is built on a protein language model and an RNA language model, with the complex structure as input. The model was pre-trained on the PRI30k dataset via a bi-scope strategy and fine-tuned on PRA310. CoPRA can also be redirected to predict mutation effects, showing strong per-structure prediction performance on the mCSM_RNA dataset. Please see more details in [our paper](https://arxiv.org/abs/2409.03773). 27 | 28 | Please do not hesitate to contact us or create an issue/PR if you have any questions or suggestions! 29 | 30 | ## 🛠️ Installation 31 | 32 | **Step 1**. Clone this repository and set up the environment. We recommend installing the dependencies via the fast package management tool [mamba](https://mamba.readthedocs.io/en/latest/mamba-installation.html) (you can also replace the command 'mamba' with 'conda'). Generally, CoPRA works with Python 3.10.14 and PyTorch 2.1.2. 33 | ``` 34 | git clone git@github.com:hanrthu/CoPRA.git 35 | cd CoPRA 36 | mamba env create -f environment.yml 37 | ``` 38 | 39 | **Step 2**. Install flash-attn and RiNALMo with the following commands. You may also need to download the RiNALMo-650M model and place it in the `./weights` folder of this repo. 40 | ``` 41 | # Download the flash-attn 2.6.3 wheel file from https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl 42 | pip install flash_attn-2.6.3+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl 43 | git clone git@github.com:lbcb-sci/RiNALMo.git 44 | cd RiNALMo 45 | pip install -e . 46 | ``` 47 | ## 📖 Datasets and model weights for Protein-RNA binding affinity prediction 48 | Here we first provide our proposed datasets, including PRA310, PRA201 and PRI30k, together with the mCSM_RNA dataset. You can easily access them through 🤗Huggingface: [/Jesse7/CoPRA_data](https://huggingface.co/datasets/Jesse7/CoPRA_data/tree/main). The only difference between PRA201 and PRA310 is the selected samples, so the PRA201 labels and splits are provided in PRA310/splits/PRA201.csv. Download these datasets and place them in the `./datasets` folder. 49 | 50 | The number of samples in each dataset is shown below; we use PRA as the abbreviation of protein-RNA binding affinity: 51 | 52 | | Dataset | Type | Size | 53 | | :---: | :---: | :---: | 54 | | PRA310 | PRA | 310 | 55 | | PRA201 | PRA (pair-only) | 201 | 56 | | PRI30k | Unsupervised complexes | 30006 | 57 | | mCSM-RNA | Mutation effect on PRA | 79 | 58 | 59 | 60 | We also provide five-fold model checkpoints obtained by pre-training Co-Former on PRI30k and fine-tuning it on PRA310; they can be downloaded through 🤗Huggingface: [/Jesse7/CoPRA](https://huggingface.co/Jesse7/CoPRA). This repository also contains pretrained RiNALMo-650M weights. Download these weights and place them in the `./weights` folder. 61 | 62 | The 5-fold cross-validation performance on PRA310 reaches the state of the art; here is the comparison: 63 | 64 | Results on PRA 65 | 66 | 67 | 68 | ## 🚀 Training on the protein-RNA datasets 69 | 70 | **Note1:** It is normal for the first training epoch on a new dataset to be relatively slow, because the processed samples are cached during the first pass (see the sketch below).
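Caching is handled by `diskcache` inside the data module (`pl_modules/data_module.py`); the cache location is controlled by the `cache_dir` argument of `DataModule`, which is filled from the dataset YAML. As a minimal sketch only (the exact key names in your dataset config may differ, and `./cache/PRA310` is an illustrative path, not a documented default):

```
# illustrative snippet for a dataset YAML such as ./config/datasets/PRA310.yml
cache_dir: './cache/PRA310'   # processed samples are written here during the first epoch
```

Runs that point at the same `cache_dir` reuse the cached entries instead of re-processing the structures.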
71 | 72 | **Note2:** We also support LoRA tuning and all-param tuning. For LoRA tuning, just specify `lora_tune: true` in `./config/models/copra.yml`. For all-param tuning, just specify `fix_lms: false` in `./config/models/copra.yml`. 73 | 74 | ### Run 5-fold inference on PRA310 75 | ``` 76 | python run.py test dG --model_config ./config/models/copra.yml --data_config ./config/datasets/PRA310.yml --run_config ./config/runs/test_basic.yml 77 | ``` 78 | 79 | ### Run finetune on PRA310 80 | ``` 81 | python run.py finetune dG --model_config ./config/models/copra.yml --data_config ./config/datasets/PRA310.yml --run_config ./config/runs/finetune_struct.yml 82 | ``` 83 | 84 | ### Run finetune on PRA201 85 | ``` 86 | python run.py finetune dG --model_config ./config/models/copra.yml --data_config ./config/datasets/PRA201.yml --run_config ./config/runs/finetune_struct.yml 87 | ``` 88 | 89 | ### Run Bi-scope Pre-training on PRI30k 90 | ``` 91 | python run.py finetune pretune --model_config ./config/models/copra.yml --data_config ./config/datasets/biolip.yml --run_config ./config/runs/pretune_struct.yml 92 | ``` 93 | After pretraining, you can continue to finetune on a new dataset with the finetuning scripts and the specification of ckpt for the pretrained model in config/runs/finetune_struct.yml 94 | 95 | ## 🚀 Zero-shot Blind-test on the protein-RNA mutation effect datasets 96 | 97 | ``` 98 | python run.py test ddG --model_config ./config/models/copra.yml --data_config ./config/datasets/blindtest.yml --run_config ./config/runs/zero_shot_blindtest.yml 99 | ``` 100 | 101 | ## 🖌️ Citation 102 | If you find our repo useful, please kindly consider citing: 103 | ``` 104 | @article{han2024copra, 105 | title={CoPRA: Bridging Cross-domain Pretrained Sequence Models with Complex Structures for Protein-RNA Binding Affinity Prediction}, 106 | author={Han, Rong and Liu, Xiaohong and Pan, Tong and Xu, Jing and Wang, Xiaoyu and Lan, Wuyang and Li, Zhenyu and Wang, Zixuan and Song, Jiangning and Wang, Guangyu and others}, 107 | journal={arXiv preprint arXiv:2409.03773}, 108 | year={2024} 109 | } 110 | ``` -------------------------------------------------------------------------------- /models/components/valina_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from models.components.valina_attn import MultiHeadSelfAttention, FlashMultiHeadSelfAttention 6 | 7 | import torch.utils.checkpoint as checkpoint 8 | 9 | class TokenDropout(nn.Module): 10 | def __init__( 11 | self, 12 | active: bool, 13 | mask_ratio: float, 14 | mask_tkn_prob: float, 15 | mask_tkn_idx: int, 16 | pad_tkn_idx: int, 17 | ): 18 | super().__init__() 19 | 20 | self.active = active 21 | 22 | self.mask_ratio_train = mask_ratio * mask_tkn_prob 23 | 24 | self.mask_tkn_idx = mask_tkn_idx 25 | self.pad_tkn_idx = pad_tkn_idx 26 | 27 | def forward(self, x, tokens): 28 | if self.active: 29 | pad_mask = tokens.eq(self.pad_tkn_idx) 30 | src_lens = (~pad_mask).sum(dim=-1) 31 | 32 | x = torch.where((tokens == self.mask_tkn_idx).unsqueeze(dim=-1), 0.0, x) 33 | mask_ratio_observed = (tokens == self.mask_tkn_idx).sum(dim=-1) / src_lens 34 | x = x * (1 - self.mask_ratio_train) / (1 - mask_ratio_observed[..., None, None]) 35 | 36 | return x 37 | 38 | class Transformer(nn.Module): 39 | def __init__(self, embed_dim, num_blocks, num_heads, use_rot_emb=True, attn_qkv_bias=False, transition_dropout=0.0, attention_dropout=0.0, 
residual_dropout=0.0, transition_factor=4, use_flash_attn=False): 40 | super().__init__() 41 | 42 | self.use_flash_attn = use_flash_attn 43 | 44 | self.blocks = nn.ModuleList( 45 | [ 46 | TransformerBlock(embed_dim, num_heads, use_rot_emb, attn_qkv_bias, transition_dropout, attention_dropout, residual_dropout, transition_factor, use_flash_attn) for _ in range(num_blocks) 47 | ] 48 | ) 49 | 50 | self.final_layer_norm = nn.LayerNorm(embed_dim) 51 | 52 | def forward(self, x, key_padding_mask=None, need_attn_weights=False): 53 | attn_weights = None 54 | if need_attn_weights: 55 | attn_weights = [] 56 | 57 | for i, block in enumerate(self.blocks): 58 | # x, attn = checkpoint.checkpoint( 59 | # block, 60 | # x, 61 | # key_padding_mask=key_padding_mask, 62 | # need_attn_weights=need_attn_weights, 63 | # use_reentrant=False 64 | # ) 65 | 66 | x, attn = block(x, key_padding_mask, need_attn_weights) 67 | if torch.isnan(x).any(): 68 | print("Find Nan in X! at layer {}".format(i), x.shape) 69 | if need_attn_weights: 70 | attn_weights.append(attn) 71 | 72 | x = self.final_layer_norm(x) 73 | 74 | return x, attn_weights 75 | 76 | class SwiGLU(nn.Module): 77 | """ 78 | Swish-Gated Linear Unit 79 | https://arxiv.org/pdf/2002.05202v1.pdf 80 | In the cited paper beta is set to 1 and is not learnable; 81 | but by the Swish definition it is learnable parameter otherwise 82 | it is SiLU activation function (https://paperswithcode.com/method/swish) 83 | """ 84 | def __init__(self, size_in, size_out, beta_is_learnable=True, bias=True): 85 | """ 86 | Args: 87 | size_in: input embedding dimension 88 | size_out: output embedding dimension 89 | beta_is_learnable: whether beta is learnable or set to 1, learnable by default 90 | bias: whether use bias term, enabled by default 91 | """ 92 | super().__init__() 93 | self.linear = nn.Linear(size_in, size_out, bias=bias) 94 | self.linear_gate = nn.Linear(size_in, size_out, bias=bias) 95 | self.beta = nn.Parameter(torch.ones(1), requires_grad=beta_is_learnable) 96 | 97 | def forward(self, x): 98 | linear_out = self.linear(x) 99 | swish_out = linear_out * torch.sigmoid(self.beta * linear_out) 100 | return swish_out * self.linear_gate(x) 101 | 102 | class TransformerBlock(nn.Module): 103 | def __init__(self, embed_dim, num_heads, use_rot_emb=True, attn_qkv_bias=False, transition_dropout=0.0, attention_dropout=0.0, residual_dropout=0.0, transition_factor=4, use_flash_attn=False): 104 | super().__init__() 105 | 106 | self.use_flash_attn = use_flash_attn 107 | 108 | if use_flash_attn: 109 | self.mh_attn = FlashMultiHeadSelfAttention(embed_dim, num_heads, attention_dropout, causal=False, use_rot_emb=use_rot_emb, bias=attn_qkv_bias) 110 | else: 111 | self.mh_attn = MultiHeadSelfAttention(embed_dim, num_heads, attention_dropout, use_rot_emb, attn_qkv_bias) 112 | 113 | self.attn_layer_norm = nn.LayerNorm(embed_dim) 114 | 115 | self.transition = nn.Sequential( 116 | SwiGLU(embed_dim, int(2 / 3 * transition_factor * embed_dim), beta_is_learnable=True, bias=True), 117 | nn.Dropout(p=transition_dropout), 118 | nn.Linear(int(2 / 3 * transition_factor * embed_dim), embed_dim, bias=True), 119 | ) 120 | self.out_layer_norm = nn.LayerNorm(embed_dim) 121 | 122 | self.residual_dropout_1 = nn.Dropout(p=residual_dropout) 123 | self.residual_dropout_2 = nn.Dropout(p=residual_dropout) 124 | 125 | def forward(self, input, key_padding_mask=None, need_attn_weights=None): 126 | x = self.attn_layer_norm(input) 127 | if self.use_flash_attn: 128 | mh_out, attn = self.mh_attn(x, 
key_padding_mask=key_padding_mask, return_attn_probs=need_attn_weights) 129 | else: 130 | mh_out, attn = self.mh_attn(x, attn_mask=None, key_pad_mask=key_padding_mask) 131 | x = x + self.residual_dropout_1(mh_out) 132 | 133 | residual = x 134 | x = self.out_layer_norm(x) 135 | x = residual + self.residual_dropout_2(self.transition(x)) 136 | if torch.isnan(x).any(): 137 | print("Find Nan in X!", x.shape) 138 | return x, attn 139 | 140 | class MaskedLanguageModelHead(nn.Module): 141 | def __init__(self, embed_dim, alphabet_size): 142 | super().__init__() 143 | 144 | self.linear1 = nn.Linear(embed_dim, embed_dim) 145 | self.layer_norm = nn.LayerNorm(embed_dim) 146 | self.linear2 = nn.Linear(embed_dim, alphabet_size) 147 | 148 | def forward(self, x): 149 | x = self.linear1(x) 150 | x = F.gelu(x) 151 | x = self.layer_norm(x) 152 | x = self.linear2(x) 153 | 154 | return x 155 | -------------------------------------------------------------------------------- /data/complex.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dataclasses 3 | import sys 4 | from pathlib import Path 5 | import io 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | from dataclasses import dataclass 9 | sys.path.append('/home/CoPRA/') 10 | import data.protein.proteins as proteins 11 | from data.protein.atom_convert import atom37_to_atom14 12 | from data.protein.proteins import chains_from_cif_string, chains_from_pdb_string 13 | import data.rna.rnas as rnas 14 | 15 | rna_residues = ['A', 'G', 'C', 'U'] 16 | protein_residues = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE', 'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL'] 17 | SUPER_PROT_IDX = 27 18 | SUPER_RNA_IDX = 28 19 | SUPER_CPLX_IDX = 29 20 | SUPER_CHAIN_IDX = 4 21 | PADDING_NODE_IDX = 26 22 | 23 | @dataclass 24 | class ComplexInput: 25 | seq: str # L 26 | mask: np.ndarray # (L, ) 27 | restype: np.ndarray # (L, ) # In total 21 + 5 = 26 types, including '_' and 'X' 28 | res_nb: np.ndarray # (L, ) 29 | prot_seqs: list 30 | na_seqs: list 31 | atom_mask: np.ndarray # (L, 37 + 27) 32 | atom_positions: np.ndarray #(L, 37 + 27, 3) 33 | 34 | atom41_mask: np.ndarray # (L, 14 + 27) 35 | atom41_positions: np.ndarray #(L, 14 + 27, 3) 36 | 37 | identifier: np.ndarray #(L, ), to identify rna or protein 38 | chainid: np.ndarray # (L, ), to identify chain of the complex 39 | 40 | @classmethod 41 | def from_path(self, path, valid_rna_chains=None, valid_prot_chains=None): 42 | if isinstance(path, io.IOBase): 43 | file_string = path.read() 44 | else: 45 | # print(path) 46 | path = Path(path) 47 | file_string = path.read_text() 48 | if valid_prot_chains is None or valid_rna_chains is None: 49 | valid_prot_chains = [] 50 | valid_rna_chains = [] 51 | if '.pdb' in str(path): 52 | chains = chains_from_pdb_string(file_string) 53 | elif '.cif' in str(path): 54 | chains = chains_from_cif_string(file_string) 55 | 56 | for chain in chains: 57 | for residue in chain: 58 | if residue.get_resname() in protein_residues: 59 | valid_prot_chains.append(chain.get_full_id()[2]) 60 | break 61 | if residue.get_resname() in rna_residues: 62 | valid_rna_chains.append(chain.get_full_id()[2]) 63 | break 64 | # valid_prot_chains = list(set(valid_prot_chains)) 65 | # print(valid_prot_chains, valid_rna_chains) 66 | protein = proteins.ProteinInput.from_path(path, with_angles=False, return_dict=True, valid_chains=valid_prot_chains) 67 | rna = rnas.RNAInput.from_path(path, valid_rna_chains) 68 | 
complex_dict = complex_merge([protein[chain] for chain in valid_prot_chains], [rna[chain] for chain in valid_rna_chains]) 69 | return self(**complex_dict) 70 | 71 | @property 72 | def length(self): 73 | return len(self.seq) 74 | 75 | def __repr__(self): 76 | return self.__str__() 77 | 78 | def __str__(self): 79 | texts = [] 80 | texts += [f'seq: {self.seq}'] 81 | texts += [f'length: {len(self.seq)}'] 82 | texts += [f"mask: {''.join(self.mask.astype('int').astype('str'))}"] 83 | if self.chainid is not None: 84 | texts += [f"chainid: {''.join(self.chainid.astype('int').astype('str'))}"] 85 | texts += [f"identifier: {''.join(self.identifier.astype('int').astype('str'))}"] 86 | names = [ 87 | 'restype', 88 | 'atom_mask', 89 | 'atom_positions', 90 | ] 91 | for name in names: 92 | value = getattr(self, name) 93 | if value is None: 94 | text = f'{name}: None' 95 | else: 96 | text = f'{name}: {value.shape}' 97 | texts += [text] 98 | text = ', \n '.join(texts) 99 | text = f'Protein-RNA Complex(\n {text}\n)' 100 | return text 101 | 102 | 103 | def complex_merge(protein, rna): 104 | assert len(protein) > 0 105 | assert len(rna) > 0 106 | 107 | p_lengths = [p.length for i, p in enumerate(protein)] 108 | r_lengths = [r.length for i, r in enumerate(rna)] 109 | lengths = p_lengths + r_lengths 110 | prot_list = [] 111 | na_list = [] 112 | seq = "".join([item.seq for item in protein + rna]) 113 | # identifier = np.array([0] * len(protein) + [1] * len(rna), dtype=np.int32) 114 | mask = np.concatenate([item.mask for item in protein + rna] ) 115 | chain_arr = np.concatenate([[i] * p for i, p in enumerate(lengths)]).astype('int') 116 | res_nb = np.zeros([len(seq)]) 117 | restype = np.zeros([len(seq)]) 118 | atom_positions = np.zeros([len(seq), 37 + 27, 3]) 119 | atom41_positions = np.zeros([len(seq), 14 + 27, 3]) 120 | atom41_masks = np.zeros([len(seq), 14 + 27]) 121 | atom_masks = np.zeros([len(seq), 37 + 27]) 122 | identifier = np.zeros([len(seq)]) 123 | curr_idx = 0 124 | for item in protein: 125 | prot_list.append(item.seq) 126 | res_nb[curr_idx: curr_idx+item.length] = item.res_nb 127 | restype[curr_idx: curr_idx+item.length] = item.aatype 128 | identifier[curr_idx: curr_idx+item.length] = 0 129 | atom_positions[curr_idx: curr_idx+item.length, :37, :] = item.atom_positions 130 | atom14, mask_14, arrs = atom37_to_atom14(item.aatype, item.atom_positions, [item.atom_mask]) 131 | mask_14 = arrs[0] * mask_14 132 | atom41_positions[curr_idx: curr_idx+item.length, :14, :] = atom14 133 | atom41_masks[curr_idx: curr_idx+item.length, :14] = mask_14 134 | atom_masks[curr_idx: curr_idx+item.length, :37] = item.atom_mask 135 | curr_idx += item.length 136 | for item in rna: 137 | na_list.append(item.seq) 138 | res_nb[curr_idx: curr_idx+item.length] = item.res_nb 139 | restype[curr_idx: curr_idx+item.length] = item.basetype + 21 140 | identifier[curr_idx: curr_idx+item.length] = 1 141 | atom_positions[curr_idx: curr_idx+item.length, 37:, :] = item.atom_positions 142 | atom41_positions[curr_idx: curr_idx+item.length, 14:, :] = item.atom_positions 143 | atom41_masks[curr_idx: curr_idx+item.length, 14:] = item.atom_mask 144 | atom_masks[curr_idx: curr_idx+item.length, 37:] = item.atom_mask 145 | curr_idx += item.length 146 | 147 | complex_dict = { 148 | 'seq': seq, 149 | 'mask': mask, 150 | 'restype': restype, 151 | 'res_nb': res_nb, 152 | 153 | 'prot_seqs': prot_list, 154 | 'na_seqs': na_list, 155 | 156 | 'atom_mask': atom_masks, 157 | 'atom_positions': atom_positions, 158 | 159 | 'atom41_mask': atom41_masks, 160 | 
'atom41_positions': atom41_positions, 161 | 162 | 'identifier': identifier, 163 | 'chainid': chain_arr 164 | } 165 | 166 | return complex_dict 167 | 168 | if __name__ == '__main__': 169 | comp = ComplexInput.from_path('./datasets/PRA310/PDBs/1RPU.pdb') 170 | print("Complex:", comp) 171 | print(comp.atom_positions[1]) -------------------------------------------------------------------------------- /data/rna/sec_struct_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from datetime import datetime 4 | import numpy as np 5 | import wandb 6 | from typing import Any, List, Literal, Optional 7 | 8 | from Bio import SeqIO 9 | from Bio.Seq import Seq 10 | from Bio.SeqRecord import SeqRecord 11 | 12 | import biotite 13 | from biotite.structure.io import load_structure 14 | from biotite.structure import dot_bracket_from_structure 15 | 16 | from data.rna.base_constants import ( 17 | PROJECT_PATH, 18 | X3DNA_PATH, 19 | ETERNAFOLD_PATH, 20 | DOTBRACKET_TO_NUM 21 | ) 22 | 23 | 24 | def pdb_to_sec_struct( 25 | pdb_file_path: str, 26 | sequence: str, 27 | keep_pseudoknots: bool = False, 28 | x3dna_path: str = os.path.join(X3DNA_PATH, "bin/find_pair"), 29 | max_len_for_biotite: int = 1000, 30 | ) -> str: 31 | """ 32 | Get secondary structure in dot-bracket notation from a PDB file. 33 | 34 | Args: 35 | pdb_file_path (str): Path to PDB file. 36 | sequence (str): Sequence of RNA molecule. 37 | keep_pseudoknots (bool, optional): Whether to keep pseudoknots in 38 | secondary structure. Defaults to False. 39 | x3dna_path (str, optional): Path to x3dna find_pair tool. 40 | max_len_for_biotite (int, optional): Maximum length of sequence for 41 | which to use biotite. Otherwise use X3DNA Defaults to 1000. 42 | """ 43 | if len(sequence) < max_len_for_biotite: 44 | try: 45 | # get secondary structure using biotite 46 | atom_array = load_structure(pdb_file_path) 47 | sec_struct = dot_bracket_from_structure(atom_array)[0] 48 | if not keep_pseudoknots: 49 | # replace all characters that are not '.', '(', ')' with '.' 50 | sec_struct = "".join([dotbrac if dotbrac in ['.', '(', ')'] else '.' 
for dotbrac in sec_struct]) 51 | 52 | except Exception as e: 53 | # biotite fails for very short seqeunces 54 | if "out of bounds for array" not in str(e): raise e 55 | # get secondary structure using x3dna find_pair tool 56 | # does not support pseudoknots 57 | sec_struct = x3dna_to_sec_struct( 58 | pdb_to_x3dna(pdb_file_path, x3dna_path), 59 | sequence 60 | ) 61 | 62 | else: 63 | # get secondary structure using x3dna find_pair tool 64 | # does not support pseudoknots 65 | sec_struct = x3dna_to_sec_struct( 66 | pdb_to_x3dna(pdb_file_path, x3dna_path), 67 | sequence 68 | ) 69 | 70 | return sec_struct 71 | 72 | def pdb_to_x3dna( 73 | pdb_file_path: str, 74 | x3dna_path: str = os.path.join(X3DNA_PATH, "bin/find_pair") 75 | ) -> List[str]: 76 | # Run x3dna find_pair tool 77 | cmd = [ 78 | x3dna_path, 79 | pdb_file_path, 80 | ] 81 | output = subprocess.run(cmd, check=True, capture_output=True).stdout.decode("utf-8") 82 | output = output.split("\n") 83 | 84 | # Delete temporary files 85 | # os.remove("./bestpairs.pdb") 86 | # os.remove("./bp_order.dat") 87 | # os.remove("./col_chains.scr") 88 | # os.remove("./col_helices.scr") 89 | # os.remove("./hel_regions.pdb") 90 | # os.remove("./ref_frames.dat") 91 | 92 | return output 93 | 94 | 95 | def x3dna_to_sec_struct(output: List[str], sequence: str) -> str: 96 | # Secondary structure in dot-bracket notation 97 | num_base_pairs = int(output[3].split()[0]) 98 | sec_struct = ["."] * len(sequence) 99 | for i in range(1, num_base_pairs+1): 100 | line = output[4 + i].split() 101 | start, end = int(line[0]), int(line[1]) 102 | sec_struct[start-1] = "(" 103 | sec_struct[end-1] = ")" 104 | return "".join(sec_struct) 105 | 106 | 107 | def predict_sec_struct( 108 | sequence: Optional[str] = None, 109 | fasta_file_path: Optional[str] = None, 110 | eternafold_path: str = os.path.join(ETERNAFOLD_PATH, "src/contrafold"), 111 | n_samples: int = 1, 112 | ) -> str: 113 | """ 114 | Predict secondary structure using EternaFold. 115 | 116 | Notes: 117 | - EternaFold does not support pseudoknots. 118 | - EternaFold only supports single chains in a fasta file. 119 | - When sampling multiple structures, EternaFold only supports nsamples=100. 120 | 121 | Args: 122 | sequence (str, optional): Sequence of RNA molecule. Defaults to None. 123 | fasta_file_path (str, optional): Path to fasta file. Defaults to None. 124 | eternafold_path (str, optional): Path to EternaFold. Defaults to ETERNAFOLD_PATH env variable. 125 | n_samples (int, optional): Number of samples to take. Defaults to 1. 126 | """ 127 | if sequence is not None: 128 | assert fasta_file_path is None 129 | # Write sequence to temporary fasta file 130 | current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S") 131 | try: 132 | fasta_file_path = os.path.join(wandb.run.dir, f"temp_{current_datetime}.fasta") 133 | except AttributeError: 134 | fasta_file_path = os.path.join(PROJECT_PATH, f"temp_{current_datetime}.fasta") 135 | SeqIO.write( 136 | SeqRecord(Seq(sequence), id="temp"), 137 | fasta_file_path, "fasta" 138 | ) 139 | 140 | # Run EternaFold 141 | if n_samples > 1: 142 | assert n_samples == 100, "EternaFold using subprocess only supports nsamples=100" 143 | cmd = [ 144 | eternafold_path, 145 | "sample", 146 | fasta_file_path, 147 | # f" --nsamples {n_samples}", 148 | # It seems like EternaFold using subprocess can only sample the default nsamples=100... 
149 | # Reason: unknown for now 150 | ] 151 | else: 152 | cmd = [ 153 | eternafold_path, 154 | "predict", 155 | fasta_file_path, 156 | ] 157 | 158 | output = subprocess.run(cmd, check=True, capture_output=True).stdout.decode("utf-8") 159 | 160 | # Delete temporary files 161 | if sequence is not None: 162 | os.remove(fasta_file_path) 163 | 164 | if n_samples > 1: 165 | return output.split("\n")[:-1] 166 | else: 167 | return [output.split("\n")[-2]] 168 | 169 | 170 | def dotbracket_to_paired(sec_struct: str) -> np.ndarray: 171 | """ 172 | Return whether each residue is paired (1) or unpaired (0) given 173 | secondary structure in dot-bracket notation. 174 | """ 175 | is_paired = np.zeros(len(sec_struct), dtype=np.int8) 176 | for i, c in enumerate(sec_struct): 177 | if c == '(' or c == ')': 178 | is_paired[i] = 1 179 | return is_paired 180 | 181 | 182 | def dotbracket_to_num(sec_struct: str) -> np.ndarray: 183 | """ 184 | Convert secondary structure in dot-bracket notation to 185 | numerical representation. 186 | """ 187 | return np.array([DOTBRACKET_TO_NUM[c] for c in sec_struct]) 188 | 189 | 190 | def dotbracket_to_adjacency(sec_struct: str) -> np.ndarray: 191 | """ 192 | Convert secondary structure in dot-bracket notation to 193 | adjacency matrix. 194 | """ 195 | n = len(sec_struct) 196 | adj = np.zeros((n, n), dtype=np.int8) 197 | stack = [] 198 | for i, db_char in enumerate(sec_struct): 199 | if db_char == '(': 200 | stack.append(i) 201 | elif db_char == ')': 202 | j = stack.pop() 203 | adj[i, j] = 1 204 | adj[j, i] = 1 205 | return adj -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | os.environ["NUMEXPR_MAX_THREADS"] = '56' 4 | os.environ["MKL_NUM_THREADS"] = '4' 5 | os.environ["OMP_NUM_THREADS"] = '4' 6 | import fire 7 | from pathlib import Path 8 | import pandas as pd 9 | 10 | import numpy as np 11 | import yaml 12 | import wandb 13 | import time 14 | from easydict import EasyDict 15 | import torch 16 | import pytorch_lightning as pl 17 | from pytorch_lightning.loggers import CSVLogger, WandbLogger 18 | from pytorch_lightning.callbacks import TQDMProgressBar, EarlyStopping, ModelCheckpoint, ModelSummary 19 | from pytorch_lightning.strategies.ddp import DDPStrategy 20 | from pl_modules import ModelModule, DataModule, PretuneModule, DDGModule 21 | from collections import defaultdict 22 | 23 | torch.set_num_threads(16) 24 | 25 | def parse_yaml(yaml_dir): 26 | with open(yaml_dir, 'r') as f: 27 | content = f.read() 28 | config_dict = EasyDict(yaml.load(content, Loader=yaml.FullLoader)) 29 | # args = Namespace(**config_dict) 30 | return config_dict 31 | def init_pytorch_settings(): 32 | # Multiprocess Setting to speedup dataloader 33 | torch.multiprocessing.set_start_method('forkserver') 34 | torch.multiprocessing.set_sharing_strategy('file_system') 35 | # torch.set_float32_matmul_precision('high') 36 | torch.set_num_threads(4) 37 | torch.backends.cuda.matmul.allow_tf32 = True 38 | torch.backends.cudnn.allow_tf32 = True 39 | 40 | class LightningRunner(object): 41 | def __init__(self, model_config='./config/models/esm2_rinalmo.yaml', data_config='./config/datasets/rpi.yaml', 42 | run_config='./config/runs/finetune_sequence.yaml'): 43 | super(LightningRunner, self).__init__() 44 | self.model_args = parse_yaml(model_config) 45 | self.dataset_args = parse_yaml(data_config) 46 | self.run_args = parse_yaml(run_config) 47 | 
init_pytorch_settings() 48 | 49 | def save_model(self, model, output_dir, trainer): 50 | print("Best Model Path:", trainer.checkpoint_callback.best_model_path) 51 | module = ModelModule.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) 52 | if trainer.global_rank == 0: 53 | best_model = module.model 54 | (output_dir / 'model_data.json').write_text(json.dumps(vars(self.dataset_args), indent=2)) 55 | torch.save(best_model, str(output_dir / 'model.pt')) 56 | 57 | def select_module(self, stage, log_dir): 58 | if stage=='pretune': 59 | model = PretuneModule(output_dir=log_dir, model_args=self.model_args, data_args=self.dataset_args, run_args=self.run_args) 60 | elif stage=='dG': 61 | model = ModelModule(output_dir=log_dir, model_args=self.model_args, data_args=self.dataset_args, run_args=self.run_args) 62 | elif stage=='ddG': 63 | model = DDGModule(output_dir=log_dir, model_args=self.model_args, data_args=self.dataset_args, run_args=self.run_args) 64 | else: 65 | raise NotImplementedError 66 | return model 67 | 68 | def finetune(self, stage='dG'): 69 | print("Run args:", self.run_args, "\n") 70 | print("Model args:", self.model_args, "\n") 71 | print("Dataset args:", self.dataset_args, "\n") 72 | output_dir, gpus = (self.run_args.output_dir, self.run_args.gpus) 73 | self.model_args.model.stage = stage 74 | # Setup datamodule 75 | run_results = [] 76 | for k in range(self.run_args.num_folds): 77 | # if k != 4: 78 | # continue 79 | print(f"Training fold {k} Started!") 80 | output_dir = Path(output_dir) 81 | log_dir = output_dir / f'log_fold_{k}' 82 | data_module = DataModule(dataset_args=self.dataset_args, **self.dataset_args, col_group=f'fold_{k}') 83 | 84 | # Setup model module 85 | model = self.select_module(stage, log_dir) 86 | # Trainer setting 87 | name = self.run_args.run_name + time.strftime("%Y-%m-%d-%H-%M-%S") 88 | if self.run_args.wandb: 89 | wandb.init(project='copra', name=name) 90 | logger = WandbLogger() 91 | else: 92 | logger = CSVLogger(str(log_dir)) 93 | # version_dir = Path(logger_csv.log_dir) 94 | pl.seed_everything(self.model_args.train.seed) 95 | print("Successfully initialized, start trainer...") 96 | strategy=DDPStrategy(find_unused_parameters=True) 97 | # strategy.lightning_restore_optimizer = False 98 | trainer = pl.Trainer( 99 | devices=gpus, 100 | # max_steps=self.run_args.iters, 101 | max_epochs=self.run_args.epochs, 102 | logger=logger, 103 | callbacks=[ 104 | EarlyStopping(monitor="val_loss", mode="min", patience=self.run_args.patience, strict=False), 105 | ModelCheckpoint(dirpath=(log_dir / 'checkpoint'), filename='{epoch}-{val_loss:.3f}', 106 | monitor="val_loss", mode="min", save_last=True, save_top_k=3), 107 | ], 108 | # gradient_clip_val=self.model_args.train.max_grad_norm if self.model_args.train.max_grad_norm is not None else None, 109 | # gradient_clip_algorithm='norm' if self.model_args.train.max_grad_norm is not None else None, 110 | strategy=strategy, 111 | log_every_n_steps=3, 112 | ) 113 | trainer.fit(model=model, datamodule=data_module, ckpt_path=self.run_args.ckpt) 114 | print(f"Training fold {k} Finished!") 115 | trainer.strategy.barrier() 116 | print("Best Validation Results:") 117 | _ = trainer.test(model=model, ckpt_path="best", datamodule=data_module) 118 | res = model.res 119 | run_results.append(res) 120 | if trainer.global_rank == 0: 121 | self.save_model(model, output_dir, trainer) 122 | result_dir = Path(output_dir) / name 123 | os.makedirs(result_dir, exist_ok=True) 124 | with open(result_dir / 'res.json', 'w') as f: 125 | 
json.dump(run_results, f) 126 | results_df = pd.DataFrame(run_results) 127 | print(results_df.describe()) 128 | 129 | def test(self, stage='dG'): 130 | print("Args:", self.run_args, self.dataset_args, self.model_args) 131 | output_dir, ckpts, gpus = (self.run_args.output_dir, self.run_args.ckpts, 132 | self.run_args.gpus) 133 | run_results = [] 134 | for k in range(self.run_args.num_folds): 135 | output_dir = Path(output_dir) 136 | log_dir = output_dir / f'log_fold_{k}' 137 | data_module = DataModule(dataset_args=self.dataset_args, **self.dataset_args, col_group=f'fold_{k}') 138 | # data_module.setup() 139 | model = self.select_module(stage, log_dir) 140 | logger = CSVLogger(str(log_dir)) 141 | strategy=DDPStrategy(find_unused_parameters=True) 142 | # strategy.lightning_restore_optimizer = False 143 | trainer = pl.Trainer( 144 | devices=gpus, 145 | max_epochs=0, 146 | logger=[ 147 | logger, 148 | ], 149 | callbacks=[ 150 | TQDMProgressBar(refresh_rate=1), 151 | ], 152 | strategy=strategy, 153 | ) 154 | 155 | _ = trainer.test(model=model, ckpt_path=ckpts[k], datamodule=data_module) 156 | res = model.res 157 | run_results.append(res) 158 | if trainer.global_rank == 0: 159 | results_df = pd.DataFrame(run_results) 160 | print(results_df.describe()) 161 | 162 | 163 | if __name__ == '__main__': 164 | fire.Fire(LightningRunner) -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import pandas as pd 6 | import scipy.stats as stats 7 | from sklearn.metrics import roc_auc_score 8 | import torch.nn.functional as F 9 | from sklearn.metrics import mean_squared_error 10 | from models.components.loss import PearsonCorrLoss 11 | 12 | 13 | class BlackHole(object): 14 | def __setattr__(self, name, value): 15 | pass 16 | 17 | def __call__(self, *args, **kwargs): 18 | return self 19 | 20 | def __getattr__(self, name): 21 | return self 22 | 23 | 24 | class ScalarMetricAccumulator(object): 25 | 26 | def __init__(self): 27 | super().__init__() 28 | self.accum_dict = {} 29 | self.count_dict = {} 30 | 31 | @torch.no_grad() 32 | def add(self, name, value, batchsize=None, mode=None): 33 | assert mode is None or mode in ('mean', 'sum') 34 | 35 | if mode is None: 36 | delta = value.sum() 37 | count = value.size(0) 38 | elif mode == 'mean': 39 | delta = value * batchsize 40 | count = batchsize 41 | elif mode == 'sum': 42 | delta = value 43 | count = batchsize 44 | delta = delta.item() if isinstance(delta, torch.Tensor) else delta 45 | 46 | if name not in self.accum_dict: 47 | self.accum_dict[name] = 0 48 | self.count_dict[name] = 0 49 | self.accum_dict[name] += delta 50 | self.count_dict[name] += count 51 | 52 | def log(self, it, tag, logger=BlackHole(), writer=BlackHole()): 53 | summary = {k: self.accum_dict[k] / self.count_dict[k] for k in self.accum_dict} 54 | logstr = '[%s] Iter %05d' % (tag, it) 55 | for k, v in summary.items(): 56 | logstr += ' | %s %.4f' % (k, v) 57 | writer.add_scalar('%s/%s' % (tag, k), v, it) 58 | logger.info(logstr) 59 | 60 | def get_average(self, name): 61 | return self.accum_dict[name] / self.count_dict[name] 62 | 63 | 64 | def per_complex_corr(df, pred_attr='y_pred', true_attr='y_true', limit=10): 65 | corr_table = [] 66 | for cplx in df['complex'].unique(): 67 | df_cplx = df.query(f'complex == "{cplx}"') 68 | if len(df_cplx) <= 2: 69 | continue 70 | # if len(df_cplx) < limit: 71 | # 
continue 72 | y_pred = np.array(df_cplx[pred_attr]) 73 | y_true = np.array(df_cplx[true_attr]) 74 | # print("HIIII!", cal_pearson(y_pred, y_true)) 75 | # print("Pred_and True:", y_pred, y_true) 76 | corr_table.append({ 77 | 'complex': cplx, 78 | 'pearson': abs(cal_pearson(y_pred, y_true)), 79 | 'spearman': abs(cal_spearman(y_pred, y_true)), 80 | 'rmse': cal_rmse(y_pred, y_true), 81 | 'mae': cal_mae(y_pred, y_true) 82 | }) 83 | # print("Corr_tabel:", corr_table) 84 | corr_table = pd.DataFrame(corr_table) 85 | corr_table.fillna(0) 86 | avg = corr_table[['pearson', 'spearman', 'rmse', 'mae']].mean() 87 | return avg['pearson'], avg['spearman'], avg['rmse'], avg['mae'] 88 | 89 | 90 | def per_complex_acc(df, pred_attr='y_pred', true_attr='y_true', limit=10): 91 | acc_table = [] 92 | for cplx in df['complex'].unique(): 93 | df_cplx = df.query(f'complex == "{cplx}"') 94 | if len(df_cplx) <= 2: 95 | continue 96 | y_pred = np.array(df_cplx[pred_attr]) 97 | y_true = np.array(df_cplx[true_attr]) 98 | acc_table.append({ 99 | 'complex': cplx, 100 | 'accuracy': cal_accuracy(y_pred, y_true), 101 | # 'auc': cal_auc(y_pred, y_true), 102 | 'precision': cal_precision(y_pred, y_true), 103 | 'recall': cal_recall(y_pred, y_true) 104 | }) 105 | acc_table = pd.DataFrame(acc_table) 106 | # avg = acc_table[['accuracy', 'auc', 'precision', 'recall']].mean() 107 | # return avg['accuracy'], avg['auc'], avg['precision'], avg['recall'] 108 | avg = acc_table[['accuracy', 'precision', 'recall']].mean() 109 | return avg['accuracy'], avg['precision'], avg['recall'] 110 | 111 | 112 | def sum_weighted_losses(losses, weights): 113 | """ 114 | Args: 115 | losses: Dict of scalar tensors. 116 | weights: Dict of weights. 117 | """ 118 | loss = 0 119 | for k in losses.keys(): 120 | if weights is None: 121 | loss = loss + losses[k] 122 | else: 123 | loss = loss + weights[k] * losses[k] 124 | return loss 125 | 126 | 127 | def get_loss(loss_type, pred, y, reduction='none'): 128 | if loss_type == 'regression': 129 | # criterion = PearsonCorrLoss() 130 | # losses = F.huber_loss(pred, y, delta=2, reduction=reduction) 131 | losses = F.mse_loss(pred, y, reduction=reduction) 132 | # print("MSE Loss:", F.mse_loss(pred, y, reduction=reduction)) 133 | # print("Pearson Loss:", criterion(pred, y)) 134 | elif loss_type == 'binary': 135 | losses = F.binary_cross_entropy_with_logits(pred, y, reduction=reduction) 136 | else: 137 | raise NotImplementedError("Loss Not Implemented!") 138 | return losses 139 | 140 | 141 | def cal_weighted_loss(pred_dict, y, mask, loss_types, loss_weights): 142 | loss_list = [] 143 | y_pred = pred_dict['y_pred'] 144 | # print("Y_pred:", y_pred.shape, y.shape, mask.shape) 145 | if len(y_pred.shape) > 1: 146 | assert y_pred.shape[1] == len(loss_types) 147 | else: 148 | y_pred = y_pred.unsqueeze(1) 149 | for i, l_type in enumerate(loss_types): 150 | y_pred_i = y_pred[:, i] 151 | y_i = y[:, i] 152 | mask_i = mask[:, i] 153 | l_i = get_loss(l_type, y_pred_i, y_i) * float(loss_weights[i]) 154 | if 'y_pred_inv' in pred_dict and l_type == 'regression' and i == 0: 155 | # Only ddG task has the inversion property 156 | y_pred_inv = pred_dict['y_pred_inv'] 157 | if len(y_pred_inv.shape) == 1: 158 | y_pred_inv = y_pred_inv.unsqueeze(1) 159 | y_pred_inv_i = y_pred_inv[:, i] 160 | l_i_inv = get_loss(l_type, y_pred_inv_i, -y_i) * float(loss_weights[i]) 161 | l_i = (l_i + l_i_inv) / 2 162 | loss_list.append(l_i) 163 | losses = torch.stack(loss_list, dim=-1) 164 | loss = (losses * mask).sum() / (mask.sum().clip(min=1)) 165 | return loss 
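# Usage sketch (illustrative, toy values): the per-complex helpers above expect a
# DataFrame with a 'complex' column plus prediction/label columns, e.g.
#   df = pd.DataFrame({'complex': ['1RPU'] * 3 + ['2ABC'] * 3,
#                      'y_pred': [1.2, 0.8, 0.5, 2.0, 1.1, 0.3],
#                      'y_true': [1.0, 0.9, 0.4, 2.2, 1.0, 0.2]})
#   pearson, spearman, rmse, mae = per_complex_corr(df)
# Complexes with two or fewer rows are skipped (the `len(df_cplx) <= 2` check), and the
# returned values are averages over the remaining complexes.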
166 | 167 | 168 | def cal_pearson(pred, gt): 169 | # print("Pearson Cal:", np.unique(pred.shape), np.unique(gt.shape)) 170 | if np.isnan(stats.pearsonr(pred, gt).statistic): 171 | print("Pearson Cal:", pred, gt) 172 | return stats.pearsonr(pred, gt).statistic 173 | 174 | def cal_spearman(pred, gt): 175 | if np.isnan(stats.spearmanr(pred, gt).statistic): 176 | print("Spearman Cal:", pred, gt) 177 | return stats.spearmanr(pred, gt).statistic 178 | 179 | def cal_rmse(pred, gt): 180 | return math.sqrt(mean_squared_error(pred, gt)) 181 | 182 | def cal_mae(pred, gt): 183 | return np.abs(pred-gt).sum() / len(pred) 184 | 185 | def cal_accuracy(pred, gt, thres=0.5): 186 | logits = 1 / (1+np.exp(-pred)) 187 | binary = np.where(logits>=thres, 1, 0) 188 | acc = (binary == gt).sum() / np.size(binary) 189 | return acc 190 | 191 | def cal_auc(pred, gt): 192 | logits = 1 / (1+np.exp(-pred)) 193 | if len(np.unique(gt)) == 1: 194 | return 0 195 | return roc_auc_score(gt.astype(np.int32), logits) 196 | 197 | def cal_precision(pred, gt, thres=0.5): 198 | logits = 1 / (1+np.exp(-pred)) 199 | binary = np.where(logits>=thres, 1, 0) 200 | true_positives = (binary == 1) & (gt == 1) 201 | positive_preds = binary == 1 202 | return np.sum(true_positives) / np.sum(positive_preds) if np.sum(positive_preds) > 0 else 0 203 | 204 | def cal_recall(pred, gt, thres=0.5): 205 | logits = 1 / (1+np.exp(-pred)) 206 | binary = np.where(logits>=thres, 1, 0) 207 | true_positives = (binary == 1) & (gt == 1) 208 | actual_positives = (gt == 1) 209 | return np.sum(true_positives) / np.sum(actual_positives) if np.sum(actual_positives) > 0 else 0 210 | 211 | 212 | 213 | 214 | 215 | if __name__ == '__main__': 216 | parser = argparse.ArgumentParser() 217 | parser.add_argument('--pred_dir', default=None, type=str) 218 | args = parser.parse_args() 219 | df = pd.read_csv(args.pred_dir) 220 | pred = df['pred_0'] 221 | gt = df['gt_0'] 222 | pearson = cal_pearson(pred, gt) 223 | spearman = cal_spearman(pred, gt) 224 | rmse = cal_rmse(pred, gt) 225 | mae = cal_mae(pred, gt) 226 | print("Pearson Score is {}".format(pearson)) 227 | print("Spearman Score is {}".format(spearman)) 228 | print("RMSE Score is {}".format(rmse)) 229 | print("MAE Score is {}".format(mae)) 230 | -------------------------------------------------------------------------------- /models/encoders/attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from utils.geometry import global_to_local, local_to_global, normalize_vector, construct_3d_basis, angstrom_to_nm 7 | from models.encoders.layers import mask_zero, LayerNorm 8 | 9 | 10 | def _alpha_from_logits(logits, mask, inf=1e5): 11 | """ 12 | Args: 13 | logits: Logit matrices, (N, L_i, L_j, num_heads). 14 | mask: Masks, (N, L). 15 | Returns: 16 | alpha: Attention weights.
17 | """ 18 | N, L, _, _ = logits.size() 19 | mask_row = mask.view(N, L, 1, 1).expand_as(logits) # (N, L, *, *) 20 | mask_pair = mask_row * mask_row.permute(0, 2, 1, 3) # (N, L, L, *) 21 | 22 | logits = torch.where(mask_pair, logits, logits - inf) 23 | alpha = torch.softmax(logits, dim=2) # (N, L, L, num_heads) 24 | alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha)) 25 | return alpha 26 | 27 | 28 | def _heads(x, n_heads, n_ch): 29 | """ 30 | Args: 31 | x: (..., num_heads * num_channels) 32 | Returns: 33 | (..., num_heads, num_channels) 34 | """ 35 | s = list(x.size())[:-1] + [n_heads, n_ch] 36 | return x.view(*s) 37 | 38 | 39 | class GABlock(nn.Module): 40 | 41 | def __init__(self, node_feat_dim, pair_feat_dim, value_dim=32, query_key_dim=32, num_query_points=8, 42 | num_value_points=8, num_heads=12, bias=False): 43 | super().__init__() 44 | self.node_feat_dim = node_feat_dim 45 | self.pair_feat_dim = pair_feat_dim 46 | self.value_dim = value_dim 47 | self.query_key_dim = query_key_dim 48 | self.num_query_points = num_query_points 49 | self.num_value_points = num_value_points 50 | self.num_heads = num_heads 51 | 52 | # Node 53 | self.proj_query = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias) 54 | self.proj_key = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias) 55 | self.proj_value = nn.Linear(node_feat_dim, value_dim * num_heads, bias=bias) 56 | 57 | # Pair 58 | self.proj_pair_bias = nn.Linear(pair_feat_dim, num_heads, bias=bias) 59 | 60 | # Spatial 61 | self.spatial_coef = nn.Parameter(torch.full([1, 1, 1, self.num_heads], fill_value=np.log(np.exp(1.) - 1.)), 62 | requires_grad=True) 63 | self.proj_query_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias) 64 | self.proj_key_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias) 65 | self.proj_value_point = nn.Linear(node_feat_dim, num_value_points * num_heads * 3, bias=bias) 66 | 67 | # Output 68 | self.out_transform = nn.Linear( 69 | in_features=(num_heads * pair_feat_dim) + (num_heads * value_dim) + ( 70 | num_heads * num_value_points * (3 + 3 + 1)), 71 | out_features=node_feat_dim, 72 | ) 73 | 74 | self.layer_norm_1 = LayerNorm(node_feat_dim) 75 | self.mlp_transition = nn.Sequential(nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(), 76 | nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(), 77 | nn.Linear(node_feat_dim, node_feat_dim)) 78 | self.layer_norm_2 = LayerNorm(node_feat_dim) 79 | 80 | def _node_logits(self, x): 81 | query_l = _heads(self.proj_query(x), self.num_heads, self.query_key_dim) # (N, L, n_heads, qk_ch) 82 | key_l = _heads(self.proj_key(x), self.num_heads, self.query_key_dim) # (N, L, n_heads, qk_ch) 83 | logits_node = (query_l.unsqueeze(2) * key_l.unsqueeze(1) * 84 | (1 / np.sqrt(self.query_key_dim))).sum(-1) # (N, L, L, num_heads) 85 | return logits_node 86 | 87 | def _pair_logits(self, z): 88 | logits_pair = self.proj_pair_bias(z) 89 | return logits_pair 90 | 91 | def _spatial_logits(self, R, t, x): 92 | N, L, _ = t.size() 93 | # Query 94 | query_points = _heads(self.proj_query_point(x), self.num_heads * self.num_query_points, 95 | 3) # (N, L, n_heads * n_pnts, 3) 96 | query_points = local_to_global(R, t, query_points) # Global query coordinates, (N, L, n_heads * n_pnts, 3) 97 | query_s = query_points.reshape(N, L, self.num_heads, -1) # (N, L, n_heads, n_pnts*3) 98 | # Key 99 | key_points = _heads(self.proj_key_point(x), self.num_heads * self.num_query_points, 100 | 3) # (N, L, 3, n_heads * n_pnts) 101 | key_points = 
local_to_global(R, t, key_points) # Global key coordinates, (N, L, n_heads * n_pnts, 3) 102 | key_s = key_points.reshape(N, L, self.num_heads, -1) # (N, L, n_heads, n_pnts*3) 103 | # Q-K Product 104 | sum_sq_dist = ((query_s.unsqueeze(2) - key_s.unsqueeze(1)) ** 2).sum(-1) # (N, L, L, n_heads) 105 | gamma = F.softplus(self.spatial_coef) 106 | logits_spatial = sum_sq_dist * ((-1 * gamma * np.sqrt(2 / (9 * self.num_query_points))) 107 | / 2) # (N, L, L, n_heads) 108 | return logits_spatial 109 | 110 | def _pair_aggregation(self, alpha, z): 111 | N, L = z.shape[:2] 112 | feat_p2n = alpha.unsqueeze(-1) * z.unsqueeze(-2) # (N, L, L, n_heads, C) 113 | feat_p2n = feat_p2n.sum(dim=2) # (N, L, n_heads, C) 114 | return feat_p2n.reshape(N, L, -1) 115 | 116 | def _node_aggregation(self, alpha, x): 117 | N, L = x.shape[:2] 118 | 119 | value_l = _heads(self.proj_value(x), self.num_heads, self.query_key_dim) # (N, L, n_heads, v_ch) 120 | feat_node = alpha.unsqueeze(-1) * value_l.unsqueeze(1) # (N, L, L, n_heads, *) @ (N, *, L, n_heads, v_ch) 121 | feat_node = feat_node.sum(dim=2) # (N, L, n_heads, v_ch) 122 | return feat_node.reshape(N, L, -1) 123 | 124 | def _spatial_aggregation(self, alpha, R, t, x): 125 | N, L, _ = t.size() 126 | value_points = _heads(self.proj_value_point(x), self.num_heads * self.num_value_points, 127 | 3) # (N, L, n_heads * n_v_pnts, 3) 128 | value_points = local_to_global(R, t, value_points.reshape(N, L, self.num_heads, self.num_value_points, 129 | 3)) # (N, L, n_heads, n_v_pnts, 3) 130 | aggr_points = alpha.reshape(N, L, L, self.num_heads, 1, 1) * \ 131 | value_points.unsqueeze(1) # (N, *, L, n_heads, n_pnts, 3) 132 | aggr_points = aggr_points.sum(dim=2) # (N, L, n_heads, n_pnts, 3) 133 | 134 | feat_points = global_to_local(R, t, aggr_points) # (N, L, n_heads, n_pnts, 3) 135 | feat_distance = feat_points.norm(dim=-1) # (N, L, n_heads, n_pnts) 136 | feat_direction = normalize_vector(feat_points, dim=-1, eps=1e-4) # (N, L, n_heads, n_pnts, 3) 137 | 138 | feat_spatial = torch.cat([ 139 | feat_points.reshape(N, L, -1), 140 | feat_distance.reshape(N, L, -1), 141 | feat_direction.reshape(N, L, -1), 142 | ], dim=-1) 143 | 144 | return feat_spatial 145 | 146 | def forward(self, R, t, x, z, mask): 147 | """ 148 | Args: 149 | R: Frame basis matrices, (N, L, 3, 3_index). 150 | t: Frame external (absolute) coordinates, (N, L, 3). 151 | x: Node-wise features, (N, L, F). 152 | z: Pair-wise features, (N, L, L, C). 153 | mask: Masks, (N, L). 154 | Returns: 155 | x': Updated node-wise features, (N, L, F). 156 | """ 157 | # Attention logits 158 | logits_node = self._node_logits(x) 159 | logits_pair = self._pair_logits(z) 160 | logits_spatial = self._spatial_logits(R, t, x) 161 | # Summing logits up and apply `softmax`. 
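# The three logit terms (node content, pair bias, spatial point-distance) are treated as
# contributions of comparable variance, so their sum is rescaled by sqrt(1/3) before the
# masked softmax: alpha = softmax((logits_node + logits_pair + logits_spatial) / sqrt(3)).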
162 | logits_sum = logits_node + logits_pair + logits_spatial 163 | alpha = _alpha_from_logits(logits_sum * np.sqrt(1 / 3), mask) # (N, L, L, n_heads) 164 | 165 | # Aggregate features 166 | feat_p2n = self._pair_aggregation(alpha, z) 167 | feat_node = self._node_aggregation(alpha, x) 168 | feat_spatial = self._spatial_aggregation(alpha, R, t, x) 169 | 170 | # Finally 171 | feat_all = self.out_transform(torch.cat([feat_p2n, feat_node, feat_spatial], dim=-1)) # (N, L, F) 172 | feat_all = mask_zero(mask.unsqueeze(-1), feat_all) 173 | x_updated = self.layer_norm_1(x + feat_all) 174 | x_updated = self.layer_norm_2(x_updated + self.mlp_transition(x_updated)) 175 | return x_updated 176 | 177 | 178 | # Graph Attention Encoder 179 | class GAEncoder(nn.Module): 180 | 181 | def __init__(self, node_feat_dim, pair_feat_dim, num_layers, ga_block_opt={}): 182 | super(GAEncoder, self).__init__() 183 | self.blocks = nn.ModuleList([ 184 | GABlock(node_feat_dim, pair_feat_dim, **ga_block_opt) 185 | for _ in range(num_layers) 186 | ]) 187 | 188 | def forward(self, pos_atoms, res_feat, pair_feat, mask): 189 | R = construct_3d_basis( 190 | pos_atoms[:, :, 1], 191 | pos_atoms[:, :, 2], 192 | pos_atoms[:, :, 0] 193 | ) 194 | t = pos_atoms[:, :, 1] 195 | t = angstrom_to_nm(t) 196 | for block in self.blocks: 197 | res_feat = block(R, t, res_feat, pair_feat, mask) 198 | return res_feat 199 | -------------------------------------------------------------------------------- /data/sequence_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from data.register import DataRegister 3 | import esm 4 | from rinalmo.data.constants import * 5 | from rinalmo.data.alphabet import Alphabet 6 | import torch 7 | from tqdm import tqdm 8 | R = DataRegister() 9 | 10 | na_alphabet_config = { 11 | "standard_tkns": RNA_TOKENS, 12 | "special_tkns": [CLS_TKN, PAD_TKN, EOS_TKN, UNK_TKN, MASK_TKN], 13 | } 14 | 15 | 16 | CLS_TOKEN_IDX = 0 17 | PAD_TOKEN_IDX = 1 18 | EOS_TOKEN_IDX = 2 19 | 20 | @R.register('sequence_dataset') 21 | class SequenceDataset(Dataset): 22 | def __init__(self, dataframe, 23 | col_prot='protein', col_na='na',col_label='dG', col_prot_name = 'PDB', 24 | diskcache=None, 25 | **kwargs): 26 | super(SequenceDataset, self).__init__() 27 | self.df = dataframe 28 | self.col_protein = col_prot 29 | self.col_prot_name = col_prot_name 30 | self.col_na = col_na 31 | self.col_label = col_label 32 | self.type = 'reg' 33 | self.diskcache = diskcache 34 | self.prot_alphabet = esm.data.Alphabet.from_architecture("ESM-1b") 35 | self.na_alphabet = Alphabet(**na_alphabet_config) 36 | self.load_data() 37 | 38 | def load_data(self): 39 | self.data = [] 40 | 41 | for i, row in tqdm(self.df.iterrows(), total=self.df.shape[0]): 42 | structure_id = row[self.col_prot_name] 43 | if self.diskcache is None or structure_id not in self.diskcache: 44 | max_prot_length = 0 45 | max_na_length = 0 46 | prot_seqs_info = row[self.col_protein].split(',') 47 | na_seqs_info = row[self.col_na].split(',') 48 | prot_seqs = [] 49 | na_seqs = [] 50 | for prot_seq in prot_seqs_info: 51 | if ':' in prot_seq: 52 | prot_seq = prot_seq.split(':')[1] 53 | if len(prot_seq) > max_prot_length: 54 | max_prot_length = len(prot_seq) 55 | prot_seqs.append(prot_seq) 56 | for na_seq in na_seqs_info: 57 | if ':' in na_seq: 58 | na_seq = na_seq.split(':')[1] 59 | if len(na_seq) > max_na_length: 60 | max_na_length = len(na_seq) 61 | na_seqs.append(na_seq) 62 | 63 | label = row[self.col_label] 64 | 
item = { 65 | 'id': structure_id, 66 | 'prot_seqs': prot_seqs, 67 | 'na_seqs': na_seqs, 68 | 'label': label, 69 | 'max_prot_length': max_prot_length, 70 | 'max_na_length': max_na_length 71 | } 72 | 73 | self.data.append(item) 74 | if self.diskcache is not None: 75 | self.diskcache[structure_id] = item 76 | else: 77 | self.data.append(self.diskcache[structure_id]) 78 | 79 | def __getitem__(self, idx): 80 | return self.data[idx] 81 | 82 | def __len__(self): 83 | return len(self.df) 84 | 85 | class CustomSeqCollate(object): 86 | def __init__(self, strategy='separate'): 87 | super(CustomSeqCollate, self).__init__() 88 | self.strategy=strategy 89 | def __call__(self, batch): 90 | size = len(batch) 91 | labels = torch.tensor([item['labels'] for item in batch], dtype=torch.float32) 92 | prot_alphabet = esm.data.Alphabet.from_architecture("ESM-1b") 93 | na_alphabet = Alphabet(**na_alphabet_config) 94 | # print("Batch:", batch) 95 | # for item in batch: 96 | # for seq in item['na_seqs']: 97 | # print("RNA seqs:", seq, len(seq)) 98 | prot_chains = [len(item['prot_seqs']) for item in batch] 99 | na_chains = [len(item['na_seqs']) for item in batch] 100 | if self.strategy == 'separate': 101 | max_item_prot_length = [item['max_prot_length'] for item in batch] 102 | max_item_na_length = [item['max_na_length'] for item in batch] 103 | max_prot_length = max(max_item_prot_length) 104 | max_na_length = max(max_item_na_length) 105 | total_prot_chains = sum(prot_chains) 106 | total_na_chains = sum(na_chains) 107 | prot_batch = torch.empty([total_prot_chains, max_prot_length+2]) 108 | prot_batch.fill_(prot_alphabet.padding_idx) 109 | na_batch = torch.empty([total_na_chains, max_na_length+2]) 110 | na_batch.fill_(na_alphabet.pad_idx) 111 | curr_prot_idx = 0 112 | curr_na_idx = 0 113 | for item in batch: 114 | prot_seqs = item['prot_seqs'] 115 | na_seqs = item['na_seqs'] 116 | # print(item['id']) 117 | for prot_seq in prot_seqs: 118 | prot_batch[curr_prot_idx, 0] = prot_alphabet.cls_idx 119 | prot_seq_encode = prot_alphabet.encode(prot_seq) 120 | seq = torch.tensor(prot_seq_encode, dtype=torch.int64) 121 | prot_batch[curr_prot_idx, 1: len(prot_seq_encode)+1] = seq 122 | prot_batch[curr_prot_idx, len(prot_seq_encode)+1] = prot_alphabet.eos_idx 123 | curr_prot_idx += 1 124 | for na_seq in na_seqs: 125 | # na_batch[curr_na_idx, 0] = na_alphabet.cls_idx 126 | # NA encoder adds CLS and EOS by default 127 | na_seq_encode = na_alphabet.encode(na_seq) 128 | seq = torch.tensor(na_seq_encode, dtype=torch.int64) 129 | na_batch[curr_na_idx, :len(seq)] = seq 130 | # na_batch[curr_na_idx, len(na_seq_encode)+1] = na_alphabet.eos_idx 131 | curr_na_idx += 1 132 | 133 | elif self.strategy == 'combine': 134 | # prot_linker = 'G' * 25 135 | # na_linker = 'T' * 15 136 | prot_linker = '' 137 | na_linker = '' 138 | complex_prot_max_length = 0 139 | complex_na_max_length = 0 140 | for i, item in enumerate(batch): 141 | prot_seqs = item['prot_seqs'] 142 | na_seqs = item['na_seqs'] 143 | prot_complex_seq = prot_linker.join(prot_seqs) 144 | na_complex_seq = na_linker.join(na_seqs) 145 | if len(prot_complex_seq) > complex_prot_max_length: 146 | complex_prot_max_length = len(prot_complex_seq) 147 | if len(na_complex_seq) > complex_na_max_length: 148 | complex_na_max_length = len(na_complex_seq) 149 | 150 | prot_batch = torch.empty([len(batch), complex_prot_max_length+2]) 151 | prot_batch.fill_(prot_alphabet.padding_idx) 152 | na_batch = torch.empty([len(batch), complex_na_max_length+5]) 153 | na_batch.fill_(na_alphabet.pad_idx) 154 | 155 | 
for i, item in enumerate(batch): 156 | prot_batch[i, 0] = prot_alphabet.cls_idx 157 | prot_complex_encode = prot_alphabet.encode(prot_complex_seq) 158 | seq = torch.tensor(prot_complex_encode, dtype=torch.int64) 159 | prot_batch[i, 1: len(prot_complex_encode)+1] = seq 160 | prot_batch[i, len(prot_complex_encode)+1] = prot_alphabet.eos_idx 161 | prot_linker_start = 1 162 | if len(item['prot_seqs']) > 1 and len(prot_linker) > 0: 163 | # print("Combining Protein...", len(item['prot_seqs'])) 164 | for j, p_seq in enumerate(item['prot_seqs']): 165 | if j == len(item['prot_seqs']) - 1: 166 | break 167 | seq_len = len(p_seq) 168 | prot_linker_start += seq_len 169 | linker_len = len(prot_linker) 170 | prot_batch[i, prot_linker_start: prot_linker_start + linker_len] = prot_alphabet.padding_idx 171 | # print("Done!", i) 172 | prot_linker_start += linker_len 173 | # na_batch[i, 0] = na_alphabet.cls_idx 174 | na_complex_encode = na_alphabet.encode(na_complex_seq) 175 | seq = torch.tensor(na_complex_encode, dtype=torch.int64) 176 | na_batch[i, :len(seq)] = seq 177 | na_linker_start = 1 178 | if len(item['na_seqs']) > 1 and len(na_linker) > 0: 179 | # print("Combining NA...", len(item['na_seqs'])) 180 | for j, n_seq in enumerate(item['na_seqs']): 181 | if j == len(item['na_seqs']) - 1: 182 | break 183 | seq_len = len(n_seq) 184 | na_linker_start += seq_len 185 | linker_len = len(na_linker) 186 | na_batch[i, na_linker_start: na_linker_start + linker_len] = na_alphabet.pad_idx 187 | na_linker_start += linker_len 188 | # na_batch[i, len(na_complex_encode)+1] = na_alphabet.eos_idx 189 | else: 190 | raise ValueError 191 | prot_mask = torch.zeros_like(prot_batch) 192 | na_mask = torch.zeros_like(na_batch) 193 | prot_mask[(prot_batch!=prot_alphabet.padding_idx) & (prot_batch!=prot_alphabet.eos_idx) & (prot_batch!=prot_alphabet.cls_idx)] = 1 194 | na_mask[(na_batch!=na_alphabet.pad_idx) & (na_batch!=na_alphabet.eos_idx) & (na_batch!=na_alphabet.cls_idx)] = 1 195 | 196 | data = { 197 | 'size': size, 198 | 'labels': labels, 199 | 'prot': prot_batch.long(), 200 | 'prot_chains': prot_chains, 201 | 'protein_mask': prot_mask, 202 | 'na': na_batch.long(), 203 | 'na_chains': na_chains, 204 | 'na_mask': na_mask, 205 | 'strategy': self.strategy 206 | } 207 | # print(data) 208 | return data -------------------------------------------------------------------------------- /models/components/valina_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | import math 5 | 6 | from rinalmo.model.rope import RotaryPositionEmbedding 7 | 8 | from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func 9 | from flash_attn.layers.rotary import RotaryEmbedding 10 | 11 | from flash_attn.bert_padding import unpad_input, pad_input 12 | 13 | from einops import rearrange 14 | 15 | def dot_product_attention(q, k, v, attn_mask=None, key_pad_mask=None, dropout=None): 16 | c = q.shape[-1] 17 | attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(c) 18 | 19 | if attn_mask is not None: 20 | attn = attn.masked_fill(attn_mask, float("-inf")) 21 | 22 | if key_pad_mask is not None: 23 | attn = attn.masked_fill(key_pad_mask.unsqueeze(1).unsqueeze(2), float("-inf")) 24 | 25 | attn = attn.softmax(dim=-1) 26 | if dropout is not None: 27 | attn = dropout(attn) 28 | 29 | output = torch.matmul(attn, v) 30 | if torch.isnan(output).any(): 31 | print("Found nan in attn!") 32 | return output, attn 33 | 34 | class MultiHeadAttention(nn.Module): 35 | def 
__init__(self, c_in, num_heads, attention_dropout=0.0, use_rot_emb=True, bias=False): 36 | super().__init__() 37 | assert c_in % num_heads == 0, "Embedding dimensionality must be divisible with number of attention heads!" 38 | 39 | self.c_in = c_in 40 | self.num_heads = num_heads 41 | 42 | self.c_head = c_in // self.num_heads 43 | self.c_qkv = self.c_head * num_heads 44 | 45 | self.use_rot_emb = use_rot_emb 46 | if self.use_rot_emb: 47 | self.rotary_emb = RotaryPositionEmbedding(self.c_head) 48 | 49 | self.to_q = nn.Linear(self.c_in, self.c_qkv, bias=bias) 50 | self.to_k = nn.Linear(self.c_in, self.c_qkv, bias=bias) 51 | self.to_v = nn.Linear(self.c_in, self.c_qkv, bias=bias) 52 | 53 | self.attention_dropout = nn.Dropout(p=attention_dropout) 54 | 55 | self.out_proj = nn.Linear(c_in, c_in, bias=bias) 56 | 57 | def forward(self, q, k, v, attn_mask=None, key_pad_mask=None): 58 | bs = q.shape[0] 59 | 60 | q = self.to_q(q).view(bs, -1, self.num_heads, self.c_head).transpose(-2, -3) 61 | k = self.to_k(k).view(bs, -1, self.num_heads, self.c_head).transpose(-2, -3) 62 | v = self.to_v(v).view(bs, -1, self.num_heads, self.c_head).transpose(-2, -3) 63 | 64 | if self.use_rot_emb: 65 | q, k = self.rotary_emb(q, k) 66 | 67 | output, attn = dot_product_attention(q, k, v, attn_mask, key_pad_mask, self.attention_dropout) 68 | 69 | output = output.transpose(-2, -3).contiguous().view(bs, -1, self.num_heads * self.c_head) 70 | output = self.out_proj(output) 71 | if torch.isnan(output).any(): 72 | print("Found nan in attn!") 73 | return output, attn 74 | 75 | class MultiHeadSelfAttention(nn.Module): 76 | def __init__(self, c_in, num_heads, attention_dropout=0.0, use_rot_emb=True, bias=False): 77 | super().__init__() 78 | 79 | self.mh_attn = MultiHeadAttention(c_in, num_heads, attention_dropout, use_rot_emb, bias) 80 | 81 | def forward(self, x, attn_mask=None, key_pad_mask=None): 82 | return self.mh_attn(x, x, x, attn_mask, key_pad_mask) 83 | 84 | class FlashAttention(nn.Module): 85 | """ 86 | Implement the scaled dot product attention with softmax. 87 | """ 88 | def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): 89 | """ 90 | Args: 91 | causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). 92 | softmax_scale: float. The scaling of QK^T before applying softmax. 93 | Default to 1 / sqrt(headdim). 94 | attention_dropout: float. The dropout rate to apply to the attention 95 | (default: 0.0) 96 | """ 97 | super().__init__() 98 | self.softmax_scale = softmax_scale 99 | self.attention_dropout = attention_dropout 100 | self.causal = causal 101 | 102 | def forward(self, qkv, cu_seqlens=None, max_seqlen=None, return_attn_probs=False): 103 | """ 104 | Arguments 105 | --------- 106 | qkv: The tensor containing the query, key, and value. 107 | If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D). 108 | If cu_seqlens is not None and max_seqlen is not None, then qkv has shape 109 | (total, 3, H, D), where total is the sum of the sequence lengths in the batch. 110 | cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 111 | of the sequences in the batch, used to index into qkv. 112 | max_seqlen: int. Maximum sequence length in the batch. 113 | return_attn_probs: bool. Whether to return the attention probabilities. This option is for 114 | testing only. The returned probabilities are not guaranteed to be correct 115 | (they might not have the right scaling). 
116 | Returns: 117 | -------- 118 | out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None, 119 | else (B, S, H, D). 120 | """ 121 | assert qkv.dtype in [torch.float16, torch.bfloat16] 122 | assert qkv.is_cuda 123 | 124 | unpadded = cu_seqlens is not None 125 | 126 | if unpadded: 127 | assert cu_seqlens.dtype == torch.int32 128 | assert max_seqlen is not None 129 | assert isinstance(max_seqlen, int) 130 | return flash_attn_varlen_qkvpacked_func( 131 | qkv, 132 | cu_seqlens, 133 | max_seqlen, 134 | self.attention_dropout if self.training else 0.0, 135 | softmax_scale=self.softmax_scale, 136 | causal=self.causal, 137 | return_attn_probs=return_attn_probs 138 | ) 139 | else: 140 | return flash_attn_qkvpacked_func( 141 | qkv, 142 | self.attention_dropout if self.training else 0.0, 143 | softmax_scale=self.softmax_scale, 144 | causal=self.causal, 145 | return_attn_probs=return_attn_probs 146 | ) 147 | 148 | class FlashMultiHeadSelfAttention(nn.Module): 149 | """ 150 | Multi-head self-attention implemented using FlashAttention. 151 | """ 152 | def __init__(self, embed_dim, num_heads, attention_dropout=0.0, causal=False, use_rot_emb=True, bias=False): 153 | super().__init__() 154 | assert embed_dim % num_heads == 0, "Embedding dimensionality must be divisible with number of attention heads!" 155 | 156 | self.causal = causal 157 | 158 | self.embed_dim = embed_dim 159 | self.num_heads = num_heads 160 | 161 | self.head_dim = self.embed_dim // self.num_heads 162 | self.qkv_dim = self.head_dim * num_heads * 3 163 | 164 | self.rotary_emb_dim = self.head_dim 165 | self.use_rot_emb = use_rot_emb 166 | if self.use_rot_emb: 167 | self.rotary_emb = RotaryEmbedding( 168 | dim=self.rotary_emb_dim, 169 | base=10000.0, 170 | interleaved=False, 171 | scale_base=None, 172 | pos_idx_in_fp32=True, # fp32 RoPE precision 173 | device=None 174 | ) 175 | self.flash_self_attn = FlashAttention(causal=self.causal, softmax_scale=None, attention_dropout=attention_dropout) 176 | 177 | self.Wqkv = nn.Linear(self.embed_dim, self.qkv_dim, bias=bias) 178 | 179 | self.attention_dropout = nn.Dropout(p=attention_dropout) 180 | 181 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias) 182 | 183 | def forward(self, x, key_padding_mask=None, return_attn_probs=False): 184 | """ 185 | Arguments: 186 | x: (batch, seqlen, hidden_dim) (where hidden_dim = num_heads * head_dim) 187 | key_pad_mask: boolean mask, True means to keep, False means to mask out. 
188 | (batch, seqlen) 189 | return_attn_probs: whether to return attention masks (False by default) 190 | """ 191 | 192 | qkv = self.Wqkv(x) 193 | qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) 194 | 195 | if self.use_rot_emb: 196 | qkv = self.rotary_emb(qkv, seqlen_offset=0) 197 | 198 | if return_attn_probs: 199 | bs = qkv.shape[0] 200 | qkv = torch.permute(qkv, (0, 3, 2, 1, 4)) 201 | q = qkv[:, :, 0, :, :] 202 | k = qkv[:, :, 1, :, :] 203 | v = qkv[:, :, 2, :, :] 204 | out, attn = dot_product_attention(q, k, v, key_pad_mask=torch.logical_not(key_padding_mask) if key_padding_mask is not None else None, dropout=self.attention_dropout) 205 | output = out.transpose(-2, -3).contiguous().view(bs, -1, self.num_heads * self.head_dim) 206 | output = self.out_proj(output) 207 | return output, attn 208 | 209 | if key_padding_mask is not None: 210 | batch_size = qkv.shape[0] 211 | seqlen = qkv.shape[1] 212 | x_unpad, indices, cu_seqlens, max_s = unpad_input(qkv, key_padding_mask) 213 | output_unpad = self.flash_self_attn( 214 | x_unpad, 215 | cu_seqlens=cu_seqlens, 216 | max_seqlen=max_s, 217 | return_attn_probs=return_attn_probs 218 | ) 219 | out = pad_input(rearrange(output_unpad, '... h d -> ... (h d)'), indices, batch_size, seqlen) 220 | else: 221 | output = self.flash_self_attn( 222 | qkv, 223 | cu_seqlens=None, 224 | max_seqlen=None, 225 | return_attn_probs=return_attn_probs 226 | ) 227 | out = rearrange(output, '... h d -> ... (h d)') 228 | 229 | out = self.out_proj(out) 230 | return out, None 231 | -------------------------------------------------------------------------------- /models/ipa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.encoders.single import PerResidueEncoder 5 | from models.encoders.pair import ResiduePairEncoder 6 | from models.encoders.attn import GAEncoder 7 | from models.register import ModelRegister 8 | from models.model import load_esm, load_rinalmo, segment_cat_pad, cat_pad 9 | from models.lora_tune import LoRAESM, LoRARiNALMo, ESMConfig, RiNALMoConfig 10 | from peft import ( 11 | LoraConfig, 12 | get_peft_model, 13 | ) 14 | R = ModelRegister() 15 | 16 | @R.register('ipa') 17 | class InvariantPointAttention(nn.Module): 18 | def __init__(self, 19 | node_feat_dim=640, 20 | rinalmo_weights='./weights/rinalmo_giga_pretrained.pt', 21 | esm_type='650M', 22 | use_lm=False, 23 | fix_lms=True, 24 | pair_feat_dim=64, 25 | num_layers=3, 26 | pooling='mean', 27 | output_dim=1, 28 | representation_layer=33, 29 | lora_tune=False, 30 | lora_rank=16, 31 | lora_alpha=32, 32 | **kwargs): 33 | 34 | super().__init__() 35 | self.use_lm = use_lm 36 | self.fix_lms = fix_lms 37 | self.proj = 0 38 | self.representation_layer = representation_layer 39 | if self.use_lm: 40 | self.esm, esm_feat_size = load_esm(esm_type) 41 | self.rinalmo, rinalmo_feat_size = load_rinalmo(rinalmo_weights) 42 | if esm_feat_size != rinalmo_feat_size: 43 | self.proj = 1 44 | self.project_feat= nn.Linear(esm_feat_size, rinalmo_feat_size) 45 | self.feat_size = rinalmo_feat_size 46 | self.proj_cplx= nn.Linear(self.feat_size, node_feat_dim) 47 | if lora_tune: 48 | print("Getting Lora!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 49 | # copied from LongLoRA 50 | rinalmo_lora_config = LoraConfig( 51 | r=lora_rank, 52 | bias="none", 53 | lora_alpha=lora_alpha 54 | ) 55 | esm_lora_config = LoraConfig( 56 | r=lora_rank, 57 | bias="none", 58 | lora_alpha=lora_alpha 59 | ) 60 | rinalmo_config = 
RiNALMoConfig() 61 | esm_config = ESMConfig() 62 | self.rinalmo = LoRARiNALMo(self.rinalmo, rinalmo_config) 63 | self.esm = LoRAESM(self.esm, esm_config) 64 | self.rinalmo = get_peft_model(self.rinalmo, rinalmo_lora_config) 65 | print("Get RINALMO DONE!!!!!") 66 | self.esm = get_peft_model(self.esm, esm_lora_config) 67 | print("Get ESM DONE!!!!!") 68 | elif fix_lms: 69 | for p in self.rinalmo.parameters(): 70 | p.requires_grad_(False) 71 | for p in self.esm.parameters(): 72 | p.requires_grad_(False) 73 | # Encoding 74 | else: 75 | self.single_encoder = PerResidueEncoder( 76 | feat_dim=node_feat_dim, 77 | max_num_atoms=4, # N, CA, C, O, CB, 78 | ) 79 | self.masked_bias = nn.Embedding( 80 | num_embeddings=2, 81 | embedding_dim=node_feat_dim, 82 | padding_idx=0, 83 | ) 84 | self.pair_encoder = ResiduePairEncoder( 85 | feat_dim=pair_feat_dim, 86 | max_num_atoms=4, # N, CA, C, O, CB, 87 | ) 88 | self.pooling = pooling 89 | self.attn_encoder = GAEncoder(node_feat_dim=node_feat_dim, pair_feat_dim=pair_feat_dim, num_layers=num_layers) 90 | self.pred_head = nn.Sequential( 91 | nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(), 92 | nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(), 93 | nn.Linear(node_feat_dim, output_dim) 94 | ) 95 | 96 | def forward(self, input, strategy='separate'): 97 | mask_residue = input['mask_atoms'][:, :, 1] #CA 98 | if not self.use_lm: 99 | x = self.single_encoder( 100 | aa=input['restype'], 101 | # phi=batch['phi'], phi_mask=batch['phi_mask'], 102 | # psi=batch['psi'], psi_mask=batch['psi_mask'], 103 | # chi=batch['chi'], chi_mask=batch['chi_mask'], 104 | mask_residue=mask_residue, 105 | ) 106 | else: 107 | prot_input = input['prot'] 108 | prot_chains = input['prot_chains'] 109 | prot_mask = input['protein_mask'] 110 | na_input = input['na'] 111 | na_chains = input['na_chains'] 112 | na_mask = input['na_mask'] 113 | with torch.cuda.amp.autocast(): 114 | prot_embedding = self.esm(prot_input, repr_layers=[self.representation_layer], return_contacts=False)['representations'][self.representation_layer] 115 | na_embedding = self.rinalmo(na_input)['representation'] 116 | if self.proj: 117 | prot_embedding = self.proj_cplx(prot_embedding) 118 | prot_embedding = prot_embedding.float() 119 | na_embedding = na_embedding.float() 120 | # print("Original Embedding:", prot_embedding, na_embedding) 121 | max_len = input['pos_atoms'].shape[1] 122 | if 'patch_idx' in input: 123 | patch_idx = input['patch_idx'] 124 | else: 125 | patch_idx = None 126 | # Adjust the embeddings from LMs for CFormer 127 | if strategy == 'separate': 128 | # input shape [N', L], where N' is flexible in every batch 129 | out_embedding, masks = segment_cat_pad(prot_embedding, prot_chains, prot_mask, na_embedding, na_chains, na_mask, max_len, patch_idx) 130 | assert out_embedding.shape[0] == input['size'] 131 | else: 132 | out_embedding, masks = cat_pad(prot_embedding, prot_mask, na_embedding, na_mask, max_len, patch_idx) 133 | assert out_embedding.shape[0] == input['size'] 134 | x = self.proj_cplx(out_embedding) 135 | 136 | aa=input['restype'] 137 | res_nb=input['res_nb'] 138 | chain_nb=input['chain_nb'] 139 | pos_atoms=input['pos_atoms'] 140 | mask_atoms=input['mask_atoms'] 141 | 142 | if self.pooling == 'token': 143 | mask_special = torch.zeros((len(out_embedding), 1), device=out_embedding.device, dtype=key_padding_mask.dtype) 144 | cplx_embed = self.complex_embedding.repeat(len(out_embedding), 1, 1) 145 | prot_embed = self.prot_embedding.repeat(len(out_embedding), 1, 1) 146 | rna_embed = 
self.rna_embedding.repeat(len(out_embedding), 1, 1) 147 | 148 | out_embedding = torch.cat([cplx_embed, prot_embed, rna_embed, out_embedding], dim=1) 149 | key_padding_mask = torch.cat([mask_special, mask_special, mask_special, key_padding_mask], dim=1) 150 | 151 | cplx_type = torch.ones_like(mask_special, device=out_embedding.device, dtype=aa.dtype) * SUPER_CPLX_IDX 152 | prot_type = torch.ones_like(mask_special, device=out_embedding.device, dtype=aa.dtype) * SUPER_PROT_IDX 153 | rna_type = torch.ones_like(mask_special, device=out_embedding.device, dtype=aa.dtype) * SUPER_RNA_IDX 154 | aa = torch.cat([cplx_type, prot_type, rna_type, aa], dim=1) 155 | 156 | res_nb_cplx = torch.ones_like(mask_special, device=out_embedding.device, dtype=res_nb.dtype) * 0 157 | res_nb_prot = torch.ones_like(mask_special, device=out_embedding.device, dtype=res_nb.dtype) * 1 158 | res_nb_rna = torch.ones_like(mask_special, device=out_embedding.device, dtype=res_nb.dtype) * 2 159 | 160 | res_nb = torch.cat([res_nb_cplx, res_nb_prot, res_nb_rna, res_nb], dim=1) 161 | super_chain_id = torch.ones_like(mask_special, device=out_embedding.device, dtype=chain_nb.dtype) * SUPER_CHAIN_IDX 162 | chain_nb = torch.cat([super_chain_id, super_chain_id, super_chain_id, chain_nb], dim=1) 163 | 164 | center_cplx = torch.zeros((len(out_embedding), 1, pos_atoms.shape[2], 3), device=out_embedding.device, dtype=pos_atoms.dtype) 165 | center_prot = ((pos_atoms * input['identifier'][:, :, None, None] * mask_atoms.unsqueeze(-1)).reshape([len(out_embedding), -1, 3]).sum(dim=1) / (input['identifier'][:, :, None] * mask_atoms).reshape([len(out_embedding), -1]).sum(dim=-1).unsqueeze(-1))[:, None, None, :].repeat(1, 1, 4, 1) 166 | center_rna = ((pos_atoms * (1-input['identifier'][:, :, None, None]) * mask_atoms.unsqueeze(-1)).reshape([len(out_embedding), -1, 3]).sum(dim=1) / ((1-input['identifier'][:, :, None]) * mask_atoms).reshape([len(out_embedding), -1]).sum(dim=-1).unsqueeze(-1))[:, None, None, :].repeat(1, 1, 4, 1) 167 | pos_atoms = torch.cat([center_cplx, center_prot, center_rna, pos_atoms], dim=1) 168 | mask_atom = torch.zeros((len(out_embedding), 1, pos_atoms.shape[2]), device=out_embedding.device, dtype=mask_atoms.dtype) 169 | mask_atom[:,:,0] = 1 170 | mask_atoms = torch.cat([mask_atom, mask_atom, mask_atom, mask_atoms], dim=1) 171 | 172 | 173 | z = self.pair_encoder( 174 | aa=aa, 175 | res_nb=res_nb, 176 | chain_nb=chain_nb, 177 | pos_atoms=pos_atoms, 178 | mask_atoms=mask_atoms, 179 | ) 180 | 181 | x = self.attn_encoder( 182 | pos_atoms=input['pos_atoms'], 183 | res_feat=x, pair_feat=z, 184 | mask=mask_residue 185 | ) 186 | if self.pooling == 'token': 187 | complex_embedding = x[:, 0, :] 188 | else: 189 | complex_embedding = (x * input['seq_mask'].unsqueeze(-1)).sum(dim=1) 190 | if self.pooling == 'mean': 191 | # Prot_mask: [N, L] 192 | complex_mask_sum = input['seq_mask'].sum(dim=1, keepdim=True) 193 | complex_embedding = complex_embedding / (complex_mask_sum + 1e-10) 194 | 195 | output = self.pred_head(complex_embedding) 196 | output = output.squeeze(1) 197 | 198 | return output -------------------------------------------------------------------------------- /models/esm_rinalmo_seq.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import esm 4 | from rinalmo.config import model_config 5 | from rinalmo.model.model import RiNALMo 6 | from models.register import ModelRegister 7 | from peft import ( 8 | LoraConfig, 9 | get_peft_model, 10 | ) 11 | 
from models.lora_tune import LoRAESM, LoRARiNALMo, ESMConfig, RiNALMoConfig 12 | from models.components.valina_transformer import Transformer 13 | from models.model import cat_pad, segment_cat_pad 14 | R = ModelRegister() 15 | 16 | def load_esm(esm_type): 17 | if esm_type == '650M': 18 | model, _ = esm.pretrained.esm2_t33_650M_UR50D() 19 | elif esm_type == '3B': 20 | model, _ = esm.pretrained.esm2_t36_3B_UR50D() 21 | elif esm_type == '15B': 22 | model, _ = esm.pretrained.esm2_t48_15B_UR50D() 23 | else: 24 | raise NotImplementedError 25 | feat_size = model.embed_dim 26 | return model, feat_size 27 | 28 | def load_rinalmo(rinalmo_weights): 29 | config = model_config('giga') 30 | model = RiNALMo(config) 31 | # alphabet = Alphabet(**config['alphabet']) 32 | model.load_state_dict(torch.load(rinalmo_weights)) 33 | feat_size = config.globals.embed_dim 34 | return model, feat_size 35 | 36 | def segment_pool(input, chains, mask, pooling): 37 | # input shape [N', L, E], mask_shape [N', L] 38 | result = input.new_full([len(chains), input.shape[-1]], 0) # (N, E) 39 | mask_result = mask.new_full([len(chains), 1], 0) #(N, 1) 40 | input_flattened = input.reshape((-1, input.shape[-1])) #(N'*L, E) 41 | mask_flattened = mask.reshape((-1, 1)) #(N'*L, 1) 42 | # print("Shapes:", result.shape, mask_result.shape, input_flattened.shape, mask_flattened.shape) 43 | # segment_id shape (N', ) 44 | segment_id = torch.tensor(sum([[i] * chain for i, chain in enumerate(chains)], start=[]), device=result.device, dtype=torch.int64) 45 | segment_id = segment_id.repeat_interleave(input.shape[1]) #(N'*L) 46 | result.scatter_add_(0, segment_id.unsqueeze(1).expand_as(input_flattened), input_flattened*mask_flattened) 47 | mask_result.scatter_add_(0, segment_id.unsqueeze(1), mask_flattened) 48 | mask_result.reshape((-1, )) 49 | 50 | if pooling == 'mean': 51 | result = result / (mask_result + 1e-10) 52 | 53 | return result 54 | 55 | @R.register('esm2_rinalmo_seq') 56 | class ESM2RiNALMo(nn.Module): 57 | def __init__(self, 58 | rinalmo_weights='./weights/rinalmo_giga_pretrained.pt', 59 | esm_type='650M', 60 | pooling='token', 61 | output_dim=1, 62 | fix_lms=True, 63 | lora_tune=False, 64 | lora_rank=16, 65 | lora_alpha=32, 66 | representation_layer=33, 67 | vallina=True, 68 | **kwargs 69 | ): 70 | super(ESM2RiNALMo, self).__init__() 71 | self.esm, esm_feat_size = load_esm(esm_type) 72 | self.rinalmo, rinalmo_feat_size = load_rinalmo(rinalmo_weights) 73 | self.vallina=vallina 74 | # if esm_feat_size != rinalmo_feat_size: 75 | # self.project_layer = nn.Linear(esm_feat_size, rinalmo_feat_size) 76 | self.cat_size = esm_feat_size + rinalmo_feat_size 77 | self.feat_size = rinalmo_feat_size 78 | self.representation_layer = representation_layer 79 | self.transformer = Transformer(**kwargs['transformer']) 80 | self.complex_dim = kwargs['transformer']['embed_dim'] 81 | self.proj_cplx= nn.Linear(self.feat_size, self.complex_dim) 82 | self.pooling = pooling 83 | if self.pooling == 'token': 84 | self.prot_embedding = nn.Parameter(torch.zeros((1, self.complex_dim), dtype=torch.float32)) 85 | self.rna_embedding = nn.Parameter(torch.zeros((1, self.complex_dim), dtype=torch.float32)) 86 | self.complex_embedding = nn.Parameter(torch.zeros((1, self.complex_dim), dtype=torch.float32)) 87 | nn.init.normal_(self.prot_embedding) 88 | nn.init.normal_(self.rna_embedding) 89 | nn.init.normal_(self.complex_embedding) 90 | if lora_tune: 91 | # copied from LongLoRA 92 | rinalmo_lora_config = LoraConfig( 93 | r=lora_rank, 94 | bias="none", 95 | 
lora_alpha=lora_alpha 96 | ) 97 | esm_lora_config = LoraConfig( 98 | r=lora_rank, 99 | bias="none", 100 | lora_alpha=lora_alpha 101 | ) 102 | rinalmo_config = RiNALMoConfig() 103 | esm_config = ESMConfig() 104 | self.rinalmo = LoRARiNALMo(self.rinalmo, rinalmo_config) 105 | # print(esm_config) 106 | self.esm = LoRAESM(self.esm, esm_config) 107 | # print("ESM:", self.esm) 108 | self.rinalmo = get_peft_model(self.rinalmo, rinalmo_lora_config) 109 | # print("Get RINALMO DONE!!!!!") 110 | self.esm = get_peft_model(self.esm, esm_lora_config) 111 | # print("Get ESM DONE!!!!!") 112 | 113 | elif fix_lms: 114 | for p in self.rinalmo.parameters(): 115 | p.requires_grad_(False) 116 | for p in self.esm.parameters(): 117 | p.requires_grad_(False) 118 | self.pooling = pooling 119 | self.pred_head = nn.Sequential( 120 | nn.Linear(self.complex_dim, self.feat_size), nn.ReLU(), 121 | nn.Linear(self.feat_size, output_dim) 122 | ) 123 | if self.vallina: 124 | print("Using vallina version!") 125 | self.cat_pred_head = nn.Sequential( 126 | nn.Linear(self.cat_size, self.feat_size), nn.ReLU(), 127 | nn.Linear(self.feat_size, output_dim) 128 | ) 129 | 130 | 131 | def forward(self, input, strategy='separate'): 132 | prot_input = input['prot'] 133 | prot_chains = input['prot_chains'] 134 | prot_mask = input['protein_mask'] 135 | na_input = input['na'] 136 | na_chains = input['na_chains'] 137 | na_mask = input['na_mask'] 138 | # print("Input Shape:", prot_input.shape, na_input.shape, prot_chains, na_chains) 139 | with torch.cuda.amp.autocast(): 140 | prot_embedding = self.esm(prot_input, repr_layers=[self.representation_layer], return_contacts=False)['representations'][self.representation_layer] 141 | na_embedding = self.rinalmo(na_input)['representation'] 142 | 143 | # Vallina implementation with mean pooling 144 | # print("Original Embedding:", prot_embedding.shape, na_embedding.shape) 145 | if self.vallina: 146 | if strategy == 'separate': 147 | # input shape [N', L], where N' is flexible in every batch 148 | prot_embedding = segment_pool(prot_embedding, prot_chains, prot_mask, pooling=self.pooling) 149 | na_embedding = segment_pool(na_embedding, na_chains, na_mask, pooling=self.pooling) 150 | else: 151 | if self.pooling == 'max': 152 | prot_embedding = (prot_embedding * prot_mask.unsqueeze(-1)).max(dim=1)[0] 153 | na_embedding = (na_embedding * na_mask.unsqueeze(-1)).max(dim=1)[0] 154 | else: 155 | prot_embedding = (prot_embedding * prot_mask.unsqueeze(-1)).sum(dim=1) 156 | na_embedding = (na_embedding * na_mask.unsqueeze(-1)).sum(dim=1) 157 | if self.pooling == 'mean': 158 | # Prot_mask: [N, L] 159 | prot_mask_sum = prot_mask.sum(dim=1, keepdim=True) 160 | na_mask_sum = na_mask.sum(dim=1, keepdim=True) 161 | prot_embedding = prot_embedding / (prot_mask_sum + 1e-10) 162 | na_embedding = na_embedding / (na_mask_sum + 1e-10) 163 | complex_embedding = torch.cat([prot_embedding, na_embedding], dim=1) 164 | output = self.cat_pred_head(complex_embedding) 165 | output = output.squeeze(1) 166 | else: 167 | prot_embedding = prot_embedding.float() 168 | na_embedding = na_embedding.float() 169 | # print("Original Embedding:", prot_embedding, na_embedding) 170 | max_len = input['pos_atoms'].shape[1] 171 | # Adjust the embeddings from LMs for CFormer 172 | if 'patch_idx' in input: 173 | patch_idx = input['patch_idx'] 174 | else: 175 | patch_idx = None 176 | 177 | if strategy == 'separate': 178 | # input shape [N', L], where N' is flexible in every batch 179 | out_embedding, masks = segment_cat_pad(prot_embedding, prot_chains, 
prot_mask, na_embedding, na_chains, na_mask, max_len, patch_idx) 180 | assert out_embedding.shape[0] == input['size'] 181 | else: 182 | out_embedding, masks = cat_pad(prot_embedding, prot_mask, na_embedding, na_mask, max_len, patch_idx) 183 | assert out_embedding.shape[0] == input['size'] 184 | 185 | out_embedding = self.proj_cplx(out_embedding) 186 | key_padding_mask = ~masks 187 | 188 | if self.pooling == 'token': 189 | mask_special = torch.zeros((len(out_embedding), 1), device=out_embedding.device, dtype=key_padding_mask.dtype) 190 | cplx_embed = self.complex_embedding.repeat(len(out_embedding), 1, 1) 191 | prot_embed = self.prot_embedding.repeat(len(out_embedding), 1, 1) 192 | rna_embed = self.rna_embedding.repeat(len(out_embedding), 1, 1) 193 | out_embedding = torch.cat([cplx_embed, prot_embed, rna_embed, out_embedding], dim=1) 194 | key_padding_mask = torch.cat([mask_special, mask_special, mask_special, key_padding_mask], dim=1) 195 | 196 | output, _ = self.transformer(out_embedding, key_padding_mask=key_padding_mask, need_attn_weights=False) 197 | 198 | if self.pooling == 'token': 199 | complex_embedding = output[:, 0, :].squeeze(1) 200 | else: 201 | complex_embedding = (output * (~key_padding_mask).unsqueeze(-1)).sum(dim=1) 202 | if self.pooling == 'mean': 203 | # Prot_mask: [N, L] 204 | seq_mask_sum = (~key_padding_mask).sum(dim=1, keepdim=True) 205 | complex_embedding = complex_embedding / (seq_mask_sum + 1e-10) 206 | 207 | output = self.pred_head(complex_embedding) 208 | output = output.squeeze(1) 209 | 210 | return output 211 | 212 | 213 | 214 | 215 | 216 | -------------------------------------------------------------------------------- /pl_modules/model_module.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import pandas as pd 8 | import pytorch_lightning as pl 9 | from models import ModelRegister 10 | from utils.metrics import ScalarMetricAccumulator, cal_pearson, cal_spearman, cal_rmse, cal_mae, get_loss 11 | def get_model(model_args:dict=None): 12 | register = ModelRegister() 13 | model_args_ori = {} 14 | model_args_ori.update(model_args) 15 | model_cls = register[model_args['model_type']] 16 | model = model_cls(**model_args_ori) 17 | return model 18 | 19 | class ModelModule(pl.LightningModule): 20 | def __init__(self, output_dir=None, model_args=None, data_args=None, run_args=None): 21 | super().__init__() 22 | self.save_hyperparameters() 23 | if model_args is None: 24 | model_args = {} 25 | if data_args is None: 26 | data_args = {} 27 | self.output_dir = output_dir 28 | if self.output_dir is not None: 29 | self.output_dir = Path(self.output_dir) / 'pred' 30 | self.output_dir.mkdir(parents=True, exist_ok=True) 31 | self.l_type = data_args.loss_type 32 | self.model = get_model(model_args=model_args.model) 33 | self.model_args = model_args 34 | self.data_args = data_args 35 | self.run_args = run_args 36 | self.optimizers_cfg = self.model_args.train.optimizer 37 | self.scheduler_cfg = self.model_args.train.scheduler 38 | self.valid_it = 0 39 | self.batch_size = data_args.batch_size 40 | 41 | self.train_loss = None 42 | 43 | def get_progress_bar_dict(self): 44 | tqdm_dict = super().get_progress_bar_dict() 45 | tqdm_dict.pop('v_num', None) 46 | return tqdm_dict 47 | 48 | def configure_optimizers(self): 49 | if self.optimizers_cfg.type == 'adam': 50 | optimizer = 
torch.optim.Adam(self.parameters(), 51 | lr=self.optimizers_cfg.lr, 52 | betas=(self.optimizers_cfg.beta1, self.optimizers_cfg.beta2, )) 53 | elif self.optimizers_cfg.type == 'sgd': 54 | optimizer = torch.optim.SGD(self.parameters(), lr=self.optimizers_cfg.lr) 55 | elif self.optimizers_cfg.type == 'rmsprop': 56 | optimizer = torch.optim.RMSprop(self.parameters(), lr=self.optimizers_cfg.lr) 57 | else: 58 | raise NotImplementedError('Optimizer not supported: %s' % self.optimizers_cfg.type) 59 | 60 | if self.scheduler_cfg.type == 'plateau': 61 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 62 | factor=self.scheduler_cfg.factor, 63 | patience=self.scheduler_cfg.patience, 64 | min_lr=self.scheduler_cfg.min_lr) 65 | elif self.scheduler_cfg.type == 'multistep': 66 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 67 | milestones=self.scheduler_cfg.milestones, 68 | gamma=self.scheduler_cfg.gamma) 69 | elif self.scheduler_cfg.type == 'exp': 70 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 71 | gamma=self.scheduler_cfg.gamma) 72 | else: 73 | raise NotImplementedError('Scheduler not supported: %s' % self.scheduler_cfg.type) 74 | 75 | if self.model_args.resume is not None: 76 | print("Resuming from checkloint: %s" % self.model_args.resume) 77 | ckpt = torch.load(self.model_args.resume, map_location=self.model_args.device) 78 | it_first = ckpt['iteration'] 79 | lsd_result = self.model.load_state_dict(ckpt['state_dict'], strict=False) 80 | print('Missing keys (%d): %s' % (len(lsd_result.missing_keys), ', '.join(lsd_result.missing_keys))) 81 | print( 82 | 'Unexpected keys (%d): %s' % (len(lsd_result.unexpected_keys), ', '.join(lsd_result.unexpected_keys))) 83 | 84 | print('Resuming optimizer states...') 85 | optimizer.load_state_dict(ckpt['optimizer']) 86 | print('Resuming scheduler states...') 87 | scheduler.load_state_dict(ckpt['scheduler']) 88 | 89 | if self.scheduler_cfg.type == 'plateau': 90 | optim_dict = { 91 | "optimizer": optimizer, 92 | "lr_scheduler": { 93 | "scheduler": scheduler, 94 | "monitor": 'val_loss' 95 | } 96 | } 97 | else: 98 | optim_dict = { 99 | "optimizer": optimizer, 100 | "lr_scheduler": { 101 | "scheduler": scheduler, 102 | } 103 | } 104 | return optim_dict 105 | 106 | def on_train_start(self): 107 | log_hyperparams = {'model_args':self.model_args, 'data_args': self.data_args, 'run_args': self.run_args} 108 | self.logger.log_hyperparams(log_hyperparams) 109 | 110 | def on_before_optimizer_step(self, optimizer) -> None: 111 | pass 112 | # for name, param in self.named_parameters(): 113 | # if param.grad is None: 114 | # print(name) 115 | # print("Found Unused Parameters") 116 | 117 | def training_step(self, batch, batch_idx): 118 | y = batch['labels'] 119 | pred = self.model(batch, self.data_args.strategy) 120 | # print(y.shape, pred.shape) 121 | loss = get_loss(self.l_type, pred, y, reduction='mean') 122 | # if torch.isnan(loss).any(): 123 | # print("Found nan in loss!", input) 124 | # exit() 125 | self.train_loss = loss.detach() 126 | self.log("train_loss", float(self.train_loss), batch_size=self.batch_size, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True) 127 | return loss 128 | 129 | def on_validation_epoch_start(self): 130 | self.scalar_accum = ScalarMetricAccumulator() 131 | self.results = [] 132 | 133 | def validation_step(self, batch, batch_idx): 134 | y = batch['labels'] 135 | pred = self.model(batch, self.data_args.strategy) 136 | val_loss = get_loss(self.l_type, pred, y, reduction='mean') 137 | 
self.scalar_accum.add(name='val_loss', value=val_loss, batchsize=self.batch_size, mode='mean') 138 | self.log("val_loss_step", val_loss, batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True) 139 | 140 | for y_true, y_pred in zip(batch['labels'], pred): 141 | result = {} 142 | result['y_true'] = y_true.item() 143 | result['y_pred'] = y_pred.item() 144 | self.results.append(result) 145 | return val_loss 146 | 147 | def on_validation_epoch_end(self): 148 | results = pd.DataFrame(self.results) 149 | # print("Validation:", results) 150 | if self.output_dir is not None: 151 | results.to_csv(os.path.join(self.output_dir, f'results_{self.valid_it}.csv'), index=False) 152 | y_pred = np.array(results[f'y_pred']) 153 | y_true = np.array(results[f'y_true']) 154 | pearson_all = np.abs(cal_pearson(y_pred, y_true)) 155 | spearman_all = np.abs(cal_spearman(y_pred, y_true)) 156 | rmse_all = cal_rmse(y_pred, y_true) 157 | mae_all = cal_mae(y_pred, y_true) 158 | print(f'[All_Task] Pearson {pearson_all:.6f} Spearman {spearman_all:.6f} RMSE {rmse_all:.6f} MAE {mae_all:.6f}') 159 | 160 | self.log(f'val/all_pearson', pearson_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True) 161 | self.log(f'val/all_spearman', spearman_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True) 162 | self.log(f'val/all_rmse', rmse_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True) 163 | self.log(f'val/all_mae', mae_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True) 164 | 165 | val_loss = rmse_all * rmse_all 166 | self.log('val_loss', val_loss, batch_size=self.batch_size, on_epoch=True, sync_dist=True) 167 | # Trigger scheduler 168 | self.valid_it += 1 169 | return val_loss 170 | 171 | def on_test_epoch_start(self) -> None: 172 | self.results = [] 173 | self.scalar_accum = ScalarMetricAccumulator() 174 | 175 | def test_step(self, batch, batch_idx): 176 | y = batch['labels'] 177 | pred = self.model(batch, self.data_args.strategy) 178 | test_loss = get_loss(self.l_type, pred, y, reduction='mean') 179 | self.scalar_accum.add(name='loss', value = test_loss, batchsize=self.batch_size, mode='mean') 180 | self.log("test_loss_step", test_loss, batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True) 181 | 182 | for y_true, y_pred in zip(batch['labels'], pred): 183 | result = {} 184 | result['y_true'] = y_true.item() 185 | result['y_pred'] = y_pred.item() 186 | self.results.append(result) 187 | return test_loss 188 | 189 | def on_test_epoch_end(self): 190 | results = pd.DataFrame(self.results) 191 | if self.output_dir is not None: 192 | results.to_csv(os.path.join(self.output_dir, f'results_test.csv'), index=False) 193 | y_pred = np.array(results[f'y_pred']) 194 | y_true = np.array(results[f'y_true']) 195 | pearson_all = np.abs(cal_pearson(y_pred, y_true)) 196 | spearman_all = np.abs(cal_spearman(y_pred, y_true)) 197 | rmse_all = cal_rmse(y_pred, y_true) 198 | mae_all = cal_mae(y_pred, y_true) 199 | print(f'[All_Task] Pearson {pearson_all:.6f} Spearman {spearman_all:.6f} RMSE {rmse_all:.6f} MAE {mae_all:.6f}') 200 | 201 | self.log(f'test/all_pearson', pearson_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True) 202 | self.log(f'test/all_spearman', spearman_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True) 203 | self.log(f'test/all_rmse', rmse_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True) 204 | self.log(f'test/all_mae', mae_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True) 
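# The epoch-level summary below is cached on the module so whatever drives trainer.test()
# (e.g. a cross-validation loop over the configured folds) can read the aggregated metrics
# afterwards; test_loss is reported as rmse_all squared, i.e. the mean squared error,
# matching how val_loss is derived in on_validation_epoch_end.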
205 | self.res = {"pearson": pearson_all,"spearman": spearman_all, "rmse": rmse_all, "mae": mae_all} 206 | print("Self.Res:", self.res) 207 | # test_loss = self.scalar_accum.get_average('loss') 208 | test_loss = rmse_all * rmse_all 209 | self.log('test_loss', test_loss, batch_size=self.batch_size, on_epoch=True, sync_dist=True) 210 | 211 | return test_loss -------------------------------------------------------------------------------- /models/components/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | import math 5 | 6 | from models.components.rope import RotaryPositionEmbedding 7 | 8 | from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func 9 | from flash_attn.layers.rotary import RotaryEmbedding 10 | 11 | from flash_attn.bert_padding import unpad_input, pad_input 12 | 13 | from einops import rearrange 14 | 15 | def dot_product_attention(q, k, v, struct_embed, attn_mask=None, key_pad_mask=None, dropout=None): 16 | c = q.shape[-1] 17 | attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(c) 18 | # Add Structure into attn 19 | attn += struct_embed 20 | if torch.isnan(attn).any(): 21 | print("Found nan 0!") 22 | if attn_mask is not None: 23 | if len(attn_mask.shape) < len(attn.shape): 24 | attn_mask = attn_mask[:, None, ...] 25 | attn = attn.masked_fill(attn_mask, float("-inf")) 26 | if torch.isnan(attn).any(): 27 | print("Found nan 1!") 28 | # attn shape is [B, H, L, L] 29 | if key_pad_mask is not None: 30 | attn = attn.masked_fill(key_pad_mask.unsqueeze(1).unsqueeze(2), float("-inf")) 31 | if torch.isnan(attn).any(): 32 | print("Found nan 2!") 33 | attn = attn.softmax(dim=-1) 34 | if dropout is not None: 35 | attn = dropout(attn) 36 | 37 | output = torch.matmul(attn, v) 38 | return output, attn 39 | 40 | class MultiHeadAttention(nn.Module): 41 | def __init__(self, c_in, pair_dim, num_heads, attention_dropout=0.0, use_rot_emb=True, bias=False, c_squeeze=8): 42 | super().__init__() 43 | assert c_in % num_heads == 0, "Embedding dimensionality must be divisible with number of attention heads!"
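# Unlike the plain MultiHeadAttention in valina_attn.py, this variant also consumes a pair
# ("structure") embedding: struct_bias projects it to a per-head additive bias on the QK
# logits, and after attention an outer product of the squeezed node outputs
# (outer_squeeze -> outer_linear) is added back onto the pair embedding, so the node track
# also feeds an update into the pair track.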
44 | 45 | self.c_in = c_in 46 | self.num_heads = num_heads 47 | 48 | self.c_head = c_in // self.num_heads 49 | self.c_qkv = self.c_head * num_heads 50 | 51 | self.use_rot_emb = use_rot_emb 52 | if self.use_rot_emb: 53 | self.rotary_emb = RotaryPositionEmbedding(self.c_head) 54 | 55 | self.struct_bias = nn.Linear(pair_dim, self.num_heads, bias=bias) 56 | self.outer_linear = nn.Linear(c_squeeze * c_squeeze, pair_dim) 57 | self.outer_squeeze = nn.Linear(self.c_in, c_squeeze) 58 | self.to_q = nn.Linear(self.c_in, self.c_qkv, bias=bias) 59 | self.to_k = nn.Linear(self.c_in, self.c_qkv, bias=bias) 60 | self.to_v = nn.Linear(self.c_in, self.c_qkv, bias=bias) 61 | 62 | self.attention_dropout = nn.Dropout(p=attention_dropout) 63 | 64 | self.out_proj = nn.Linear(c_in, c_in, bias=bias) 65 | self.struct_out_proj = nn.Linear(pair_dim, pair_dim, bias=bias) 66 | 67 | def forward(self, q, k, v, struct_embed, attn_mask=None, key_pad_mask=None): 68 | bs = q.shape[0] 69 | 70 | q = self.to_q(q).view(bs, -1, self.num_heads, self.c_head).transpose(-2, -3) 71 | k = self.to_k(k).view(bs, -1, self.num_heads, self.c_head).transpose(-2, -3) 72 | v = self.to_v(v).view(bs, -1, self.num_heads, self.c_head).transpose(-2, -3) 73 | 74 | struct_attr = self.struct_bias(struct_embed).permute(0, 3, 1, 2) # (N, H, L, L) 75 | # print("Shape of Structure attribute:", struct_attr.shape) 76 | if self.use_rot_emb: 77 | q, k = self.rotary_emb(q, k) 78 | 79 | output, attn = dot_product_attention(q, k, v, struct_attr, attn_mask, key_pad_mask, self.attention_dropout) 80 | 81 | output = output.transpose(-2, -3).contiguous().view(bs, -1, self.num_heads * self.c_head) 82 | 83 | a = self.outer_squeeze(output) 84 | outer_product = torch.einsum('...bc, ...de -> ...bdce', a, a) # N, L, L, C, C 85 | struct_output = struct_embed + self.outer_linear(outer_product.reshape(list(outer_product.shape[:-2]) + [-1])) 86 | 87 | output = self.out_proj(output) 88 | struct_output = self.struct_out_proj(struct_output) 89 | 90 | return output, struct_output, attn 91 | 92 | class MultiHeadSelfAttention(nn.Module): 93 | def __init__(self, c_in, pair_dim, num_heads, attention_dropout=0.0, use_rot_emb=True, bias=False): 94 | super().__init__() 95 | 96 | self.mh_attn = MultiHeadAttention(c_in, pair_dim, num_heads, attention_dropout, use_rot_emb, bias) 97 | 98 | def forward(self, x, struct_embed, attn_mask=None, key_pad_mask=None): 99 | return self.mh_attn(x, x, x, struct_embed, attn_mask, key_pad_mask) 100 | 101 | class FlashAttention(nn.Module): 102 | """ 103 | Implement the scaled dot product attention with softmax. 104 | """ 105 | def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): 106 | """ 107 | Args: 108 | causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). 109 | softmax_scale: float. The scaling of QK^T before applying softmax. 110 | Default to 1 / sqrt(headdim). 111 | attention_dropout: float. The dropout rate to apply to the attention 112 | (default: 0.0) 113 | """ 114 | super().__init__() 115 | self.softmax_scale = softmax_scale 116 | self.attention_dropout = attention_dropout 117 | self.causal = causal 118 | 119 | def forward(self, qkv, cu_seqlens=None, max_seqlen=None, return_attn_probs=False): 120 | """ 121 | Arguments 122 | --------- 123 | qkv: The tensor containing the query, key, and value. 124 | If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D). 
125 | If cu_seqlens is not None and max_seqlen is not None, then qkv has shape 126 | (total, 3, H, D), where total is the sum of the sequence lengths in the batch. 127 | cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 128 | of the sequences in the batch, used to index into qkv. 129 | max_seqlen: int. Maximum sequence length in the batch. 130 | return_attn_probs: bool. Whether to return the attention probabilities. This option is for 131 | testing only. The returned probabilities are not guaranteed to be correct 132 | (they might not have the right scaling). 133 | Returns: 134 | -------- 135 | out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None, 136 | else (B, S, H, D). 137 | """ 138 | assert qkv.dtype in [torch.float16, torch.bfloat16] 139 | assert qkv.is_cuda 140 | 141 | unpadded = cu_seqlens is not None 142 | 143 | if unpadded: 144 | assert cu_seqlens.dtype == torch.int32 145 | assert max_seqlen is not None 146 | assert isinstance(max_seqlen, int) 147 | return flash_attn_varlen_qkvpacked_func( 148 | qkv, 149 | cu_seqlens, 150 | max_seqlen, 151 | self.attention_dropout if self.training else 0.0, 152 | softmax_scale=self.softmax_scale, 153 | causal=self.causal, 154 | return_attn_probs=return_attn_probs 155 | ) 156 | else: 157 | return flash_attn_qkvpacked_func( 158 | qkv, 159 | self.attention_dropout if self.training else 0.0, 160 | softmax_scale=self.softmax_scale, 161 | causal=self.causal, 162 | return_attn_probs=return_attn_probs 163 | ) 164 | 165 | class FlashMultiHeadSelfAttention(nn.Module): 166 | """ 167 | Multi-head self-attention implemented using FlashAttention. 168 | """ 169 | def __init__(self, embed_dim, num_heads, attention_dropout=0.0, causal=False, use_rot_emb=True, bias=False): 170 | super().__init__() 171 | assert embed_dim % num_heads == 0, "Embedding dimensionality must be divisible with number of attention heads!" 172 | 173 | self.causal = causal 174 | 175 | self.embed_dim = embed_dim 176 | self.num_heads = num_heads 177 | 178 | self.head_dim = self.embed_dim // self.num_heads 179 | self.qkv_dim = self.head_dim * num_heads * 3 180 | 181 | self.rotary_emb_dim = self.head_dim 182 | self.use_rot_emb = use_rot_emb 183 | if self.use_rot_emb: 184 | self.rotary_emb = RotaryEmbedding( 185 | dim=self.rotary_emb_dim, 186 | base=10000.0, 187 | interleaved=False, 188 | scale_base=None, 189 | pos_idx_in_fp32=True, # fp32 RoPE precision 190 | device=None 191 | ) 192 | self.flash_self_attn = FlashAttention(causal=self.causal, softmax_scale=None, attention_dropout=attention_dropout) 193 | 194 | self.Wqkv = nn.Linear(self.embed_dim, self.qkv_dim, bias=bias) 195 | 196 | self.attention_dropout = nn.Dropout(p=attention_dropout) 197 | 198 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias) 199 | 200 | def forward(self, x, key_padding_mask=None, return_attn_probs=False): 201 | """ 202 | Arguments: 203 | x: (batch, seqlen, hidden_dim) (where hidden_dim = num_heads * head_dim) 204 | key_pad_mask: boolean mask, True means to keep, False means to mask out. 
205 | (batch, seqlen) 206 | return_attn_probs: whether to return attention masks (False by default) 207 | """ 208 | 209 | qkv = self.Wqkv(x) 210 | qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) 211 | 212 | if self.use_rot_emb: 213 | qkv = self.rotary_emb(qkv, seqlen_offset=0) 214 | 215 | if return_attn_probs: 216 | bs = qkv.shape[0] 217 | qkv = torch.permute(qkv, (0, 3, 2, 1, 4)) 218 | q = qkv[:, :, 0, :, :] 219 | k = qkv[:, :, 1, :, :] 220 | v = qkv[:, :, 2, :, :] 221 | out, attn = dot_product_attention(q, k, v, key_pad_mask=torch.logical_not(key_padding_mask) if key_padding_mask is not None else None, dropout=self.attention_dropout) 222 | output = out.transpose(-2, -3).contiguous().view(bs, -1, self.num_heads * self.head_dim) 223 | output = self.out_proj(output) 224 | return output, attn 225 | 226 | if key_padding_mask is not None: 227 | batch_size = qkv.shape[0] 228 | seqlen = qkv.shape[1] 229 | x_unpad, indices, cu_seqlens, max_s = unpad_input(qkv, key_padding_mask) 230 | output_unpad = self.flash_self_attn( 231 | x_unpad, 232 | cu_seqlens=cu_seqlens, 233 | max_seqlen=max_s, 234 | return_attn_probs=return_attn_probs 235 | ) 236 | out = pad_input(rearrange(output_unpad, '... h d -> ... (h d)'), indices, batch_size, seqlen) 237 | else: 238 | output = self.flash_self_attn( 239 | qkv, 240 | cu_seqlens=None, 241 | max_seqlen=None, 242 | return_attn_probs=return_attn_probs 243 | ) 244 | out = rearrange(output, '... h d -> ... (h d)') 245 | 246 | out = self.out_proj(out) 247 | return out, None 248 | -------------------------------------------------------------------------------- /data/transforms/patch.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | from ._base import _index_select_data, register_transform, _get_CB_positions, _index_select_complex 5 | import math 6 | 7 | @register_transform('focused_random_patch') 8 | class FocusedRandomPatch(object): 9 | 10 | def __init__(self, focus_attr, seed_nbh_size=32, patch_size=128): 11 | super().__init__() 12 | self.focus_attr = focus_attr 13 | self.seed_nbh_size = seed_nbh_size 14 | self.patch_size = patch_size 15 | 16 | def __call__(self, data): 17 | focus_flag = (data[self.focus_attr] > 0) # (L, ) 18 | if focus_flag.sum() == 0: 19 | # If there is no active residues, randomly pick one. 20 | focus_flag[random.randint(0, focus_flag.size(0) - 1)] = True 21 | seed_idx = torch.multinomial(focus_flag.float(), num_samples=1).item() 22 | 23 | pos_CB = _get_CB_positions(data['pos_atoms'], data['mask_atoms']) # (L, ) 24 | pos_seed = pos_CB[seed_idx:seed_idx + 1] # (1, ) 25 | dist_from_seed = torch.cdist(pos_CB, pos_seed)[:, 0] # (L, 1) -> (L, ) 26 | nbh_seed_idx = dist_from_seed.argsort()[:self.seed_nbh_size] # (Nb, ) 27 | 28 | core_idx = nbh_seed_idx[focus_flag[nbh_seed_idx]] # (Ac, ), the core-set must be a subset of the focus-set 29 | dist_from_core = torch.cdist(pos_CB, pos_CB[core_idx]).min(dim=1)[0] # (L, ) 30 | patch_idx = dist_from_core.argsort()[:self.patch_size] # (P, ) # The distance to the itself is zero, thus the item must be chose. 
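# Sorting the selected indices back into ascending order restores the original residue
# ordering inside the patch, so chain numbering and sequence-derived features stay aligned
# after the crop.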
31 | patch_idx = patch_idx.sort()[0] 32 | 33 | core_flag = torch.zeros([data['aa'].size(0), ], dtype=torch.bool) 34 | core_flag[core_idx] = True 35 | data['core_flag'] = core_flag 36 | 37 | data_patch = _index_select_data(data, patch_idx) 38 | return data_patch 39 | 40 | 41 | @register_transform('random_patch') 42 | class RandomPatch(object): 43 | 44 | def __init__(self, seed_nbh_size=32, patch_size=128): 45 | super().__init__() 46 | self.seed_nbh_size = seed_nbh_size 47 | self.patch_size = patch_size 48 | 49 | def __call__(self, data): 50 | seed_idx = random.randint(0, data['aa'].size(0) - 1) 51 | 52 | pos_CB = _get_CB_positions(data['pos_atoms'], data['mask_atoms']) # (L, ) 53 | pos_seed = pos_CB[seed_idx:seed_idx + 1] # (1, ) 54 | dist_from_seed = torch.cdist(pos_CB, pos_seed)[:, 0] # (L, 1) -> (L, ) 55 | core_idx = dist_from_seed.argsort()[:self.seed_nbh_size] # (Nb, ) 56 | 57 | dist_from_core = torch.cdist(pos_CB, pos_CB[core_idx]).min(dim=1)[0] # (L, ) 58 | patch_idx = dist_from_core.argsort()[:self.patch_size] # (P, ) 59 | patch_idx = patch_idx.sort()[0] 60 | 61 | core_flag = torch.zeros([data['aa'].size(0), ], dtype=torch.bool) 62 | core_flag[core_idx] = True 63 | data['core_flag'] = core_flag 64 | 65 | data_patch = _index_select_data(data, patch_idx) 66 | return data_patch 67 | 68 | 69 | @register_transform('selected_region_with_padding_patch') 70 | class SelectedRegionWithPaddingPatch(object): 71 | 72 | def __init__(self, select_attr, each_residue_nbh_size, patch_size_limit): 73 | super().__init__() 74 | self.select_attr = select_attr 75 | self.each_residue_nbh_size = each_residue_nbh_size 76 | self.patch_size_limit = patch_size_limit 77 | 78 | def __call__(self, data): 79 | select_flag = (data[self.select_attr] > 0) 80 | 81 | pos_CB = _get_CB_positions(data['pos_atoms'], data['mask_atoms']) # (L, 3) 82 | pos_sel = pos_CB[select_flag] # (S, 3) 83 | dist_from_sel = torch.cdist(pos_CB, pos_sel) # (L, S) 84 | nbh_sel_idx = torch.argsort(dist_from_sel, dim=0)[:self.each_residue_nbh_size, :] # (nbh, S) 85 | patch_idx = nbh_sel_idx.view(-1).unique() # (patchsize,) 86 | 87 | data_patch = _index_select_data(data, patch_idx) 88 | return data_patch 89 | 90 | 91 | @register_transform('selected_region_fixed_size_patch') 92 | class SelectedRegionFixedSizePatch(object): 93 | 94 | def __init__(self, select_attr, patch_size): 95 | super().__init__() 96 | self.select_attr = select_attr 97 | self.patch_size = patch_size 98 | 99 | def __call__(self, data): 100 | select_flag = (data[self.select_attr] > 0) 101 | 102 | pos_CB = _get_CB_positions(data['pos_atoms'], data['mask_atoms']) # (L, 3) 103 | pos_sel = pos_CB[select_flag] # (S, 3) 104 | # print("Pos CB and sel:", pos_sel.shape, pos_CB.shape, select_flag.shape) 105 | dist_from_sel = torch.cdist(pos_CB, pos_sel).min(dim=1)[0] # (L, ) 106 | # print(self.patch_size) 107 | patch_idx = torch.argsort(dist_from_sel)[:self.patch_size] 108 | 109 | data_patch = _index_select_data(data, patch_idx) 110 | return data_patch 111 | 112 | @register_transform('selected_region_with_distmap') 113 | class SelectedRegionWithDistmap(object): 114 | 115 | def __init__(self, patch_size): 116 | super().__init__() 117 | self.patch_size = patch_size 118 | 119 | def __call__(self, data): 120 | atoms_dist_min = data['atom_min_dist'] 121 | 122 | identifier = data['identifier'] 123 | tmp = atoms_dist_min[identifier==0] 124 | interface_distance = tmp[:, identifier==1] 125 | prot_min_dist = interface_distance.min(dim=1)[0] 126 | rna_min_dist = interface_distance.transpose(0, 
1).min(dim=1)[0] 127 | total_min = torch.cat([prot_min_dist, rna_min_dist], dim=0) 128 | patch_idx = torch.argsort(total_min)[:self.patch_size] 129 | patch_idx, _ = torch.sort(patch_idx) 130 | # print(self.patch_size) 131 | data_patch = _index_select_complex(data, patch_idx) 132 | data_patch['patch_idx'] = patch_idx 133 | return data_patch 134 | 135 | @register_transform('interface_max_size_patch') 136 | class InterfaceFixedMaxSizePatch(object): 137 | def __init__(self, max_size, interface_dist=7.5, kernel_size=100): 138 | super().__init__() 139 | self.max_size = max_size 140 | self.interface_dist = interface_dist 141 | self.kernel_size = kernel_size 142 | 143 | def segment_fill_zero(self, mask, chain_nb): 144 | unique_segs = torch.unique(chain_nb) 145 | result = mask.clone() 146 | can_fills = [] 147 | for seg in unique_segs: 148 | seg_mask = chain_nb == seg 149 | seg_x = mask[seg_mask] 150 | 151 | ones_mask = seg_x == 1 152 | seg_result = ones_mask.clone() 153 | # cumsum = torch.cumsum(ones_mask, dim=0) 154 | indices = torch.nonzero(ones_mask).flatten() 155 | start = indices.min() 156 | end = indices.max() + 1 157 | seg_result[start: end] = True 158 | if len(seg_result) <= 30: 159 | seg_result[:] = True 160 | # seg_result = (cumsum > 0).bool() 161 | can_fill = (~seg_result).sum() 162 | can_fills.append(can_fill.item()) 163 | result[seg_mask] = seg_result 164 | return result, can_fills 165 | 166 | def segment_fill_to_max(self, mask, chain_nb, length_to_fill, can_fill, prot_seqs, na_seqs): 167 | unique_segs = torch.unique(chain_nb) 168 | result = mask.clone() 169 | # assert sum(can_fill) >= length_to_fill 170 | remainders = length_to_fill 171 | prot_seqs_new = [] 172 | na_seqs_new = [] 173 | for i, seg in enumerate(unique_segs): 174 | to_fill = math.ceil(int((can_fill[i] / (sum(can_fill) + 1e-5)) * length_to_fill)) 175 | if remainders < to_fill: 176 | to_fill = remainders 177 | remainders -= to_fill 178 | seg_mask = chain_nb == seg 179 | seg_x = mask[seg_mask] 180 | 181 | indices = torch.nonzero(seg_x).flatten() 182 | try: 183 | left_len = indices.min() 184 | right_len = len(seg_x) - indices.max() - 1 185 | except: 186 | print("???") 187 | 188 | if to_fill >= left_len + right_len: 189 | seg_x[:] = 1 190 | start = 0 191 | end = indices.max() + 1 192 | else: 193 | left_fill = random.randint(max(to_fill-right_len, 0), min(left_len,to_fill)) 194 | right_fill = to_fill - left_fill 195 | assert left_fill + right_fill == to_fill 196 | start = indices.min() - left_fill 197 | end = indices.max() + right_fill + 1 198 | seg_x[start: end] = 1 199 | 200 | result[seg_mask] = seg_x 201 | 202 | if i < len(prot_seqs): 203 | prot_seq_new = prot_seqs[i][start: end] 204 | prot_seqs_new.append(prot_seq_new) 205 | else: 206 | na_seq_new = na_seqs[i-len(prot_seqs)][start: end] 207 | na_seqs_new.append(na_seq_new) 208 | return result, prot_seqs_new, na_seqs_new 209 | 210 | def __call__(self, data): 211 | if len(data['pos_atoms']) <= self.max_size: 212 | # print("Short data, no need to process!") 213 | return data 214 | pos_CB = _get_CB_positions(data['pos_atoms'], data['mask_atoms']) 215 | identifier = data['identifier'] 216 | chain_nb = data['chain_nb'] 217 | dist_map = torch.cdist(pos_CB, pos_CB) 218 | dist_map[identifier[:, None] == identifier[None, :]] = 10000 219 | # print("Original data:", data['id'], len(''.join(data['prot_seqs'])), len(''.join(data['rna_seqs']))) 220 | contact_dis_min = torch.min(dist_map, dim=-1)[0] 221 | kernel_area = torch.zeros_like(contact_dis_min).bool() 222 | # kernel_area = 
contact_dis_min <= self.interface_dist 223 | contact_indices = torch.argsort(contact_dis_min) 224 | kernel_indices = contact_indices[: self.kernel_size] 225 | kernel_area.index_fill_(dim=0, index=kernel_indices, value=True) 226 | continuous_kernel_area, can_fill = self.segment_fill_zero(kernel_area, chain_nb) 227 | # if torch.sum(continuous_kernel_area) > self.max_size: 228 | # print("Max size:", torch.sum(continuous_kernel_area)) 229 | to_fill = max((self.max_size - torch.sum(continuous_kernel_area)).item(), 0) 230 | continuous_kernel_area, prot_seq_new, na_seq_new = self.segment_fill_to_max(continuous_kernel_area, chain_nb, 231 | to_fill, can_fill, 232 | data['prot_seqs'], data['rna_seqs']) 233 | select_idx = torch.nonzero(continuous_kernel_area).flatten() 234 | # print("New selected data:", len(select_idx)) 235 | data = _index_select_complex(data, select_idx) 236 | data['prot_seqs'] = prot_seq_new 237 | data['rna_seqs'] = na_seq_new 238 | prot_lengths = [len(item) for item in prot_seq_new] 239 | na_lengths = [len(item) for item in na_seq_new] 240 | data['max_prot_length'] = max(prot_lengths) 241 | data['max_na_length'] = max(na_lengths) 242 | return data -------------------------------------------------------------------------------- /data/pri30k_dataset.py: -------------------------------------------------------------------------------- 1 | from data.register import DataRegister 2 | from torch.utils.data import Dataset 3 | import pandas as pd 4 | import esm 5 | import torch 6 | from rinalmo.data.constants import * 7 | from rinalmo.data.alphabet import Alphabet 8 | from tqdm import tqdm 9 | import diskcache 10 | import os 11 | import math 12 | import time 13 | from data.transforms import get_transform 14 | from torch.utils.data._utils.collate import default_collate 15 | from typing import Optional, Dict 16 | from easydict import EasyDict 17 | from data.structure_dataset import _process_structure 18 | 19 | na_alphabet_config = { 20 | "standard_tkns": RNA_TOKENS, 21 | "special_tkns": [CLS_TKN, PAD_TKN, EOS_TKN, UNK_TKN, MASK_TKN], 22 | } 23 | 24 | R = DataRegister() 25 | # ATOM_N, ATOM_CA, ATOM_C, ATOM_O, ATOM_CB = 0, 1, 2, 3, 4 26 | # ATOM_P, ATOM_C4, ATOM_NB = 37, 38, 27 | @R.register('pri30k_dataset') 28 | class PRI30kDataset(Dataset): 29 | ''' 30 | The implementation of Protein-RNA structure Dataset 31 | ''' 32 | def __init__(self, 33 | dataframe, 34 | data_root, 35 | col_prot_name='PDB', 36 | col_prot_chain='Protein chains', 37 | col_na_chain='RNA chains', 38 | col_binding_site='Binding site renumbered merged', 39 | col_ligand='Binding ligands', 40 | diskcache=None, 41 | transform=None, 42 | **kwargs 43 | ): 44 | self.data_root = data_root 45 | self.df: pd.DataFrame = dataframe.copy() 46 | self.df.reset_index(drop=True, inplace=True) 47 | self.col_prot_name = col_prot_name 48 | self.col_prot_chain = col_prot_chain 49 | self.col_na_chain = col_na_chain 50 | self.col_binding_site = col_binding_site 51 | self.col_ligand = col_ligand 52 | self.diskcache = diskcache 53 | self.prot_alphabet = esm.data.Alphabet.from_architecture("ESM-1b") 54 | self.na_alphabet = Alphabet(**na_alphabet_config) 55 | 56 | self.transform = get_transform(transform) 57 | 58 | # self.load_data() 59 | 60 | def load_data(self, idx): 61 | row = self.df.loc[idx] 62 | structure_id = row[self.col_prot_name] 63 | prot_chains = [row[self.col_prot_chain]] 64 | na_chains = [row[self.col_na_chain]] 65 | structure_id = structure_id + '_' + prot_chains[0] + '_' + na_chains[0] 66 | if self.diskcache is None or structure_id not 
in self.diskcache: 67 | structure_name = structure_id + '.cif' 68 | pdb_path = os.path.join(self.data_root, structure_name) 69 | 70 | cplx = _process_structure(pdb_path, structure_id, prot_chains, na_chains, gpu='cuda:0') 71 | 72 | ligand_id = row[self.col_prot_name] + '_' + row[self.col_na_chain] 73 | L = len(cplx['seq']) 74 | gpu_atoms = cplx['pos_heavyatom'] 75 | gpu_masks = cplx['mask_heavyatom'] 76 | distance_map = torch.linalg.norm(gpu_atoms[:, None, :, None, :]- gpu_atoms[None, :, None, :, :], dim=-1, ord=2).reshape(L, L, -1) 77 | mask = (gpu_masks[:, None, :, None] * gpu_masks[None, :, None, :]).reshape(L, L, -1) 78 | distance_map[~mask] = torch.inf 79 | atom_min_dist = torch.min(distance_map, dim=-1)[0] 80 | 81 | max_prot_length = 0 82 | max_na_length = 0 83 | for prot_seq in cplx.prot_seqs: 84 | if len(prot_seq) > max_prot_length: 85 | max_prot_length = len(prot_seq) 86 | for na_seq in cplx.rna_seqs: 87 | if len(na_seq) > max_na_length: 88 | max_na_length = len(na_seq) 89 | 90 | item = { 91 | 'ligand_id': ligand_id, # no need to pad 92 | 'atom_min_dist': atom_min_dist, # needs 2D padding 93 | 'max_prot_length': max_prot_length, # will be ignored in batching 94 | 'max_na_length': max_na_length, # will be ignored in batching 95 | 'can_bind': row[self.col_ligand] # no need to pad 96 | } 97 | 98 | cplx.update(item) 99 | if self.diskcache is not None: 100 | for key in cplx: 101 | if isinstance(cplx[key], torch.Tensor): 102 | cplx[key] = cplx[key].detach().cpu() 103 | self.diskcache[structure_id] = cplx 104 | return cplx 105 | else: 106 | return self.diskcache[structure_id] 107 | 108 | def __len__(self): 109 | return len(self.df) 110 | 111 | def __getitem__(self, idx): 112 | data = self.load_data(idx) 113 | if self.transform is not None: 114 | data = self.transform(data) 115 | return data 116 | 117 | EXCLUDE_KEYS = [] 118 | DEFAULT_PAD_VALUES = { 119 | 'restype': 26, 120 | 'mask_atoms': 0, 121 | 'chain_nb': -1, 122 | 'identifier': -1 123 | } 124 | 125 | class PRI30kStructCollate(object): 126 | def __init__(self, strategy='separate', length_ref_key='restype', pad_values=DEFAULT_PAD_VALUES, exclude_keys=EXCLUDE_KEYS, eight=True): 127 | super().__init__() 128 | self.strategy = strategy 129 | self.length_ref_key = length_ref_key 130 | self.pad_values = pad_values 131 | self.exclude_keys = exclude_keys 132 | self.eight = eight 133 | 134 | @staticmethod 135 | def _pad_last(x, n, value=0): 136 | if isinstance(x, torch.Tensor): 137 | assert x.size(0) <= n 138 | if x.size(0) == n: 139 | return x 140 | pad_size = [n - x.size(0)] + list(x.shape[1:]) 141 | pad = torch.full(pad_size, fill_value=value).to(x) 142 | return torch.cat([x, pad], dim=0) 143 | elif isinstance(x, list): 144 | pad = [value] * (n - len(x)) 145 | return x + pad 146 | else: 147 | return x 148 | 149 | @staticmethod 150 | def _get_pad_mask(l, n): 151 | return torch.cat([ 152 | torch.ones([l], dtype=torch.bool), 153 | torch.zeros([n - l], dtype=torch.bool) 154 | ], dim=0) 155 | 156 | @staticmethod 157 | def _get_common_keys(list_of_dict): 158 | keys = set(list_of_dict[0].keys()) 159 | for d in list_of_dict[1:]: 160 | keys = keys.intersection(d.keys()) 161 | return keys 162 | 163 | def _get_pad_value(self, key): 164 | if key not in self.pad_values: 165 | return 0 166 | return self.pad_values[key] 167 | 168 | def pad_2d(self, x, n, value=10000): 169 | assert isinstance(x, torch.Tensor) 170 | assert x.shape[0] == x.shape[1] 171 | if x.size(0) == n and x.size(1) == n: 172 | return x 173 | pad_size_1 = [n - x.size(0)] + 
list(x.shape[1:]) 174 | pad = torch.full(pad_size_1, fill_value=value).to(x) 175 | x_padded_1 = torch.cat([x, pad], dim=0) 176 | pad_size_2 = [x_padded_1.shape[0]] + [n - x_padded_1.size(1)] + list(x_padded_1.shape[2:]) 177 | pad = torch.full(pad_size_2, fill_value=value).to(x) 178 | x_padded = torch.cat([x_padded_1, pad], dim=1) 179 | return x_padded 180 | 181 | 182 | 183 | def collate_complex(self, data_list): 184 | max_length = max([data[self.length_ref_key].size(0) for data in data_list]) 185 | keys_inter = self._get_common_keys(data_list) 186 | keys = [] 187 | keys_not_pad = [] 188 | keys_ignore = ['prot_seqs', 'rna_seqs', 'max_prot_length', 'max_na_length', 'can_bind', 'ligand_id', 'atom_min_dist'] 189 | pad_2d = ['atom_min_dist'] 190 | for key in keys_inter: 191 | if key in keys_ignore: 192 | continue 193 | elif key not in self.exclude_keys: 194 | keys.append(key) 195 | else: 196 | keys_not_pad.append(key) 197 | 198 | if self.eight: 199 | max_length = math.ceil(max_length / 8) * 8 200 | data_list_padded = [] 201 | 202 | for data in data_list: 203 | data_padded = { 204 | k: self._pad_last(v, max_length, value=self._get_pad_value(k)) 205 | for k, v in data.items() 206 | if k in keys 207 | } 208 | 209 | for k in pad_2d: 210 | data_padded[k] = self.pad_2d(data[k], max_length) 211 | 212 | for k in keys_not_pad: 213 | data_padded[k] = data[k] 214 | 215 | data_padded['mask'] = self._get_pad_mask(data[self.length_ref_key].size(0), max_length) 216 | data_list_padded.append(data_padded) 217 | return data_list_padded 218 | 219 | def pad_for_berts(self, batch): 220 | prot_alphabet = esm.data.Alphabet.from_architecture("ESM-1b") 221 | na_alphabet = Alphabet(**na_alphabet_config) 222 | prot_chains = [len(item['prot_seqs']) for item in batch] 223 | na_chains = [len(item['rna_seqs']) for item in batch] 224 | max_item_prot_length = [item['max_prot_length'] for item in batch] 225 | max_item_na_length = [item['max_na_length'] for item in batch] 226 | max_prot_length = max(max_item_prot_length) 227 | max_na_length = max(max_item_na_length) 228 | total_prot_chains = sum(prot_chains) 229 | total_na_chains = sum(na_chains) 230 | if self.eight: 231 | max_prot_length = math.ceil((max_prot_length + 2) / 8) * 8 232 | max_na_length = math.ceil((max_na_length + 2) / 8) * 8 233 | else: 234 | max_prot_length = max_prot_length + 2 235 | max_na_length = max_na_length + 2 236 | prot_batch = torch.empty([total_prot_chains, max_prot_length]) 237 | prot_batch.fill_(prot_alphabet.padding_idx) 238 | na_batch = torch.empty([total_na_chains, max_na_length]) 239 | na_batch.fill_(na_alphabet.pad_idx) 240 | curr_prot_idx = 0 241 | curr_na_idx = 0 242 | for item in batch: 243 | prot_seqs = item['prot_seqs'] 244 | na_seqs = item['rna_seqs'] 245 | for prot_seq in prot_seqs: 246 | prot_batch[curr_prot_idx, 0] = prot_alphabet.cls_idx 247 | prot_seq_encode = prot_alphabet.encode(prot_seq) 248 | seq = torch.tensor(prot_seq_encode, dtype=torch.int64) 249 | prot_batch[curr_prot_idx, 1: len(prot_seq_encode)+1] = seq 250 | prot_batch[curr_prot_idx, len(prot_seq_encode)+1] = prot_alphabet.eos_idx 251 | curr_prot_idx += 1 252 | for na_seq in na_seqs: 253 | na_seq_encode = na_alphabet.encode(na_seq) 254 | seq = torch.tensor(na_seq_encode, dtype=torch.int64) 255 | na_batch[curr_na_idx, :len(seq)] = seq 256 | curr_na_idx += 1 257 | 258 | prot_mask = torch.zeros_like(prot_batch) 259 | na_mask = torch.zeros_like(na_batch) 260 | prot_mask[(prot_batch!=prot_alphabet.padding_idx) & (prot_batch!=prot_alphabet.eos_idx) & 
(prot_batch!=prot_alphabet.cls_idx)] = 1 261 | na_mask[(na_batch!=na_alphabet.pad_idx) & (na_batch!=na_alphabet.eos_idx) & (na_batch!=na_alphabet.cls_idx)] = 1 262 | return prot_batch.long(), prot_chains, prot_mask, na_batch.long(), na_chains, na_mask 263 | 264 | def gen_clip_label(self, ligand_ids, can_bind_info): 265 | clip_label = torch.eye(len(ligand_ids)) 266 | for i, can_bind in enumerate(can_bind_info): 267 | candidates = can_bind.split(',') 268 | for j, ligand_id in enumerate(ligand_ids): 269 | if ligand_id in candidates: 270 | clip_label[i, j] = 1 271 | return clip_label.long() 272 | 273 | 274 | def __call__(self, data_list): 275 | data_list_padded = self.collate_complex(data_list) 276 | batch = default_collate(data_list_padded) 277 | batch['size'] = len(data_list_padded) 278 | prot_batch, prot_chains, prot_mask, na_batch, na_chains, na_mask = self.pad_for_berts(data_list) 279 | batch['prot'] = prot_batch 280 | batch['prot_chains'] = prot_chains 281 | batch['protein_mask'] = prot_mask 282 | batch['na'] = na_batch 283 | batch['na_chains'] = na_chains 284 | batch['na_mask'] = na_mask 285 | batch['strategy'] = self.strategy 286 | ligand_ids = [item['ligand_id'] for item in data_list] 287 | can_bind_info = [item['can_bind'] for item in data_list] 288 | batch['clip_label'] = self.gen_clip_label(ligand_ids, can_bind_info) 289 | return batch -------------------------------------------------------------------------------- /pl_modules/pretune_module.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import pandas as pd 8 | import pytorch_lightning as pl 9 | from models import ModelRegister 10 | from utils.metrics import ScalarMetricAccumulator 11 | 12 | def get_model(model_args:dict=None): 13 | register = ModelRegister() 14 | model_args_ori = {} 15 | model_args_ori.update(model_args) 16 | model_cls = register[model_args['model_type']] 17 | model = model_cls(**model_args_ori) 18 | return model 19 | 20 | class PretuneModule(pl.LightningModule): 21 | def __init__(self, output_dir=None, model_args=None, data_args=None, run_args=None): 22 | super().__init__() 23 | self.save_hyperparameters() 24 | if model_args is None: 25 | model_args = {} 26 | if data_args is None: 27 | data_args = {} 28 | self.output_dir = output_dir 29 | if self.output_dir is not None: 30 | self.output_dir = Path(self.output_dir) / 'pred' 31 | self.output_dir.mkdir(parents=True, exist_ok=True) 32 | self.l_type = data_args.loss_type 33 | self.model = get_model(model_args=model_args.model) 34 | self.model_args = model_args 35 | self.data_args = data_args 36 | self.run_args = run_args 37 | self.optimizers_cfg = self.model_args.train.optimizer 38 | self.scheduler_cfg = self.model_args.train.scheduler 39 | self.valid_it = 0 40 | self.temperature = model_args.train.temperature 41 | self.batch_size = data_args.batch_size 42 | 43 | self.train_loss = None 44 | 45 | def get_progress_bar_dict(self): 46 | tqdm_dict = super().get_progress_bar_dict() 47 | tqdm_dict.pop('v_num', None) 48 | return tqdm_dict 49 | 50 | def configure_optimizers(self): 51 | if self.optimizers_cfg.type == 'adam': 52 | optimizer = torch.optim.Adam(self.parameters(), 53 | lr=self.optimizers_cfg.lr, 54 | betas=(self.optimizers_cfg.beta1, self.optimizers_cfg.beta2, )) 55 | elif self.optimizers_cfg.type == 'sgd': 56 | optimizer = torch.optim.SGD(self.parameters(), 
                                        lr=self.optimizers_cfg.lr)
57 |         elif self.optimizers_cfg.type == 'rmsprop':
58 |             optimizer = torch.optim.RMSprop(self.parameters(), lr=self.optimizers_cfg.lr)
59 |         else:
60 |             raise NotImplementedError('Optimizer not supported: %s' % self.optimizers_cfg.type)
61 | 
62 |         if self.scheduler_cfg.type == 'plateau':
63 |             scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
64 |                                                                    factor=self.scheduler_cfg.factor,
65 |                                                                    patience=self.scheduler_cfg.patience,
66 |                                                                    min_lr=self.scheduler_cfg.min_lr)
67 |         elif self.scheduler_cfg.type == 'multistep':
68 |             scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
69 |                                                              milestones=self.scheduler_cfg.milestones,
70 |                                                              gamma=self.scheduler_cfg.gamma)
71 |         elif self.scheduler_cfg.type == 'exp':
72 |             scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
73 |                                                                gamma=self.scheduler_cfg.gamma)
74 |         else:
75 |             raise NotImplementedError('Scheduler not supported: %s' % self.scheduler_cfg.type)
76 | 
77 |         if self.model_args.resume is not None:
78 |             print("Resuming from checkpoint: %s" % self.model_args.resume)
79 |             ckpt = torch.load(self.model_args.resume, map_location=self.model_args.device)
80 |             it_first = ckpt['iteration']
81 |             lsd_result = self.model.load_state_dict(ckpt['state_dict'], strict=False)
82 |             print('Missing keys (%d): %s' % (len(lsd_result.missing_keys), ', '.join(lsd_result.missing_keys)))
83 |             print(
84 |                 'Unexpected keys (%d): %s' % (len(lsd_result.unexpected_keys), ', '.join(lsd_result.unexpected_keys)))
85 | 
86 |             print('Resuming optimizer states...')
87 |             optimizer.load_state_dict(ckpt['optimizer'])
88 |             print('Resuming scheduler states...')
89 |             scheduler.load_state_dict(ckpt['scheduler'])
90 | 
91 |         if self.scheduler_cfg.type == 'plateau':
92 |             optim_dict = {
93 |                 "optimizer": optimizer,
94 |                 "lr_scheduler": {
95 |                     "scheduler": scheduler,
96 |                     "monitor": 'val_loss'
97 |                 }
98 |             }
99 |         else:
100 |             optim_dict = {
101 |                 "optimizer": optimizer,
102 |                 "lr_scheduler": {
103 |                     "scheduler": scheduler,
104 |                 }
105 |             }
106 |         return optim_dict
107 | 
108 |     def on_train_start(self):
109 |         log_hyperparams = {'model_args': self.model_args, 'data_args': self.data_args, 'run_args': self.run_args}
110 |         self.logger.log_hyperparams(log_hyperparams)
111 | 
112 |     def on_before_optimizer_step(self, optimizer) -> None:
113 |         pass
114 |         # for name, param in self.named_parameters():
115 |         #     if param.grad is None:
116 |         #         print(name)
117 |         #         print("Found Unused Parameters")
118 | 
119 |     def continuous_to_discrete_tensor(self, values):
120 |         discrete_values = torch.zeros_like(values, dtype=torch.long, device=values.device)
121 | 
122 |         mask_0_8 = values < 8
123 |         discrete_values[mask_0_8] = (values[mask_0_8] * 2 + 0.5).long()
124 | 
125 |         mask_8_32 = (values >= 8) & (values < 32)
126 |         discrete_values[mask_8_32] = (16 + (values[mask_8_32] - 8)).long()
127 | 
128 |         mask_ge_32 = (values >= 32)
129 |         discrete_values[mask_ge_32] = 39
130 |         discrete_values = torch.clamp(discrete_values, 0, 39).long()
131 |         return discrete_values
132 | 
133 |     def cal_loss(self, pred_clip, y_clip, pred_dist, y_dist, identifier):
134 |         # y_dist = torch.clamp(y_dist.long(), 0, 31).long()
135 |         y_dist = self.continuous_to_discrete_tensor(y_dist)
136 |         total_y_dists = []
137 |         total_pred_dists = []
138 |         total_y_dists_inv = []
139 |         total_pred_dists_inv = []
140 |         for i in range(identifier.shape[0]):
141 |             item_identifier = identifier[i].squeeze()
142 |             y_tmp = y_dist[i, item_identifier==0]
143 |             y_interface_dist = y_tmp[:, item_identifier==1].flatten()
144 |             total_y_dists.append(y_interface_dist)
145 | 
146 |             y_pred_tmp = pred_dist[i, item_identifier==0]
147 |             y_interface_pred = y_pred_tmp[:, item_identifier==1].reshape([-1, pred_dist.shape[-1]])
148 |             total_pred_dists.append(y_interface_pred)
149 | 
150 |             y_tmp_inv = y_dist[i, item_identifier==1]
151 |             y_interface_dist_inv = y_tmp_inv[:, item_identifier==0].flatten()
152 |             total_y_dists_inv.append(y_interface_dist_inv)
153 | 
154 |             y_pred_tmp_inv = pred_dist[i, item_identifier==1]
155 |             y_interface_pred_inv = y_pred_tmp_inv[:, item_identifier==0].reshape([-1, pred_dist.shape[-1]])
156 |             total_pred_dists_inv.append(y_interface_pred_inv)
157 | 
158 |         total_pred_dists = torch.cat(total_pred_dists, dim=0)
159 |         total_pred_dists_inv = torch.cat(total_pred_dists_inv, dim=0)
160 |         total_y_dists = torch.cat(total_y_dists)
161 |         total_y_dists_inv = torch.cat(total_y_dists_inv)
162 |         # y_dist_one_hot = F.one_hot(indices, num_classes=32).float()
163 |         loss_clip_forward = F.cross_entropy(pred_clip / self.temperature, y_clip.long())
164 |         loss_clip_inverse = F.cross_entropy(pred_clip.transpose(0, 1) / self.temperature, y_clip.long())
165 |         loss_clip = 0.5 * (loss_clip_forward + loss_clip_inverse)
166 |         # print("Loss CLIP:", loss_clip)
167 |         # loss_dist_forward = F.cross_entropy(y_interface_pred.permute(0, 3, 1, 2) / self.temperature, y_interface_dist.long())
168 |         # loss_dist_inv = F.cross_entropy(y_interface_pred_inv.permute(0, 3, 1, 2) / self.temperature, y_interface_dist_inv.long())
169 |         loss_dist_forward = F.cross_entropy(total_pred_dists / self.temperature, total_y_dists.long())
170 |         loss_dist_inverse = F.cross_entropy(total_pred_dists_inv / self.temperature, total_y_dists_inv.long())
171 | 
172 |         loss_dist = 0.5 * (loss_dist_forward + loss_dist_inverse)
173 |         # print("Loss Dist:", loss_dist)
174 |         return loss_clip, loss_dist
175 | 
176 |     def training_step(self, batch, batch_idx):
177 |         y_dist = batch['atom_min_dist']
178 |         res_identifier = batch['identifier']
179 |         # y_clip = batch['clip_label']
180 |         y_clip = torch.arange(batch['size'], dtype=torch.long, device=y_dist.device)
181 |         pred_dist, pred_clip = self.model(batch, self.data_args.strategy, stage='pretune', need_mask=True)
182 |         loss_clip, loss_dist = self.cal_loss(pred_clip, y_clip, pred_dist, y_dist, res_identifier)
183 |         loss = loss_clip + loss_dist
184 |         if torch.isnan(loss).any():
185 |             print("Found nan in loss!")
186 |             exit()
187 |         self.train_loss = loss.detach()
188 |         self.log("train_loss", float(self.train_loss), batch_size=self.batch_size, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
189 |         self.log("train_clip_loss", float(loss_clip.detach()), batch_size=self.batch_size, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
190 |         self.log("train_dist_loss", float(loss_dist.detach()), batch_size=self.batch_size, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
191 |         return loss
192 | 
193 |     def on_validation_epoch_start(self):
194 |         self.scalar_accum = ScalarMetricAccumulator()
195 |         self.results = []
196 | 
197 |     def validation_step(self, batch, batch_idx):
198 |         y_dist = batch['atom_min_dist']
199 |         res_identifier = batch['identifier']
200 |         y_clip = torch.arange(batch['size'], dtype=torch.long, device=y_dist.device)
201 |         pred_dist, pred_clip = self.model(batch, self.data_args.strategy, stage='pretune', need_mask=True)
202 | 
203 |         loss_clip, loss_dist = self.cal_loss(pred_clip, y_clip, pred_dist, y_dist, res_identifier)
204 |         val_loss = loss_clip + loss_dist
205 |         self.scalar_accum.add(name='val_loss', value=val_loss, batchsize=batch['size'], mode='mean')
206 |         self.log("val_loss_step", val_loss, batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
207 |         self.log("val_clip_loss", float(loss_clip.detach()), batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
208 |         self.log("val_dist_loss", float(loss_dist.detach()), batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
209 |         return val_loss
210 | 
211 |     def on_validation_epoch_end(self):
212 |         val_loss = self.scalar_accum.get_average('val_loss')
213 |         self.log('val_loss', val_loss, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
214 |         # Trigger scheduler
215 |         self.valid_it += 1
216 |         return val_loss
217 | 
218 |     def on_test_epoch_start(self) -> None:
219 |         self.results = []
220 |         self.scalar_accum = ScalarMetricAccumulator()
221 | 
222 |     def test_step(self, batch, batch_idx):
223 |         y_dist = batch['atom_min_dist']
224 |         res_identifier = batch['identifier']
225 |         y_clip = torch.arange(batch['size'], dtype=torch.long, device=y_dist.device)
226 |         pred_dist, pred_clip = self.model(batch, self.data_args.strategy, stage='pretune', need_mask=True)
227 |         loss_clip, loss_dist = self.cal_loss(pred_clip, y_clip, pred_dist, y_dist, res_identifier)
228 |         test_loss = loss_clip + loss_dist
229 |         self.scalar_accum.add(name='loss', value=test_loss, batchsize=batch['size'], mode='mean')
230 |         self.log("test_loss_step", test_loss, batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
231 |         self.log("test_clip_loss", float(loss_clip.detach()), batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
232 |         self.log("test_dist_loss", float(loss_dist.detach()), batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
233 |         self.res = {"test_loss_step": float(test_loss.detach()), "test_clip_loss": float(loss_clip.detach()), "test_dist_loss": float(loss_dist.detach())}
234 |         return test_loss
235 | 
236 |     def on_test_epoch_end(self):
237 |         test_loss = self.scalar_accum.get_average('loss')
238 |         self.log('test_loss', test_loss, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
239 | 
240 |         return test_loss
--------------------------------------------------------------------------------
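A minimal usage sketch for FlashMultiHeadSelfAttention (models/components/attention.py). It assumes the flash-attn package is installed and a CUDA device is available, since the inner FlashAttention module asserts fp16/bf16 inputs on GPU; the tensor sizes and variable names below are illustrative, not taken from the repository.

import torch
from models.components.attention import FlashMultiHeadSelfAttention

# Hypothetical sizes; embed_dim must be divisible by num_heads.
mhsa = FlashMultiHeadSelfAttention(embed_dim=512, num_heads=8, attention_dropout=0.1).cuda().half()

x = torch.randn(2, 96, 512, dtype=torch.float16, device='cuda')         # (batch, seqlen, hidden_dim)
key_padding_mask = torch.ones(2, 96, dtype=torch.bool, device='cuda')   # True = keep, False = padding
key_padding_mask[1, 80:] = False                                        # second sequence is only 80 tokens long

out, attn = mhsa(x, key_padding_mask=key_padding_mask)                  # out: (2, 96, 512); attn is None on the flash path
out_eager, probs = mhsa(x, key_padding_mask=key_padding_mask, return_attn_probs=True)  # eager path, also returns probabilities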
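A small, hedged demonstration of the padding conventions in PRI30kStructCollate (data/pri30k_dataset.py): per-residue tensors are right-padded with per-key values (with lengths rounded up to a multiple of 8 when eight=True), while the 2D atom_min_dist map is padded on both axes with a large sentinel (10000). Importing data.pri30k_dataset pulls in esm, rinalmo, and diskcache, so those packages are assumed to be installed; the tensors here are toy inputs.

import torch
from data.pri30k_dataset import PRI30kStructCollate

collate = PRI30kStructCollate(eight=True)

x = torch.arange(5)
print(PRI30kStructCollate._pad_last(x, 8, value=26))   # tensor([ 0,  1,  2,  3,  4, 26, 26, 26])
print(PRI30kStructCollate._get_pad_mask(5, 8))         # 5 x True followed by 3 x False

dist = torch.zeros(3, 3)
padded = collate.pad_2d(dist, 8)
print(padded.shape, padded[0, -1].item())              # torch.Size([8, 8]) 10000.0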
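For reference, cal_loss in pl_modules/pretune_module.py discretizes the minimum heavy-atom distance map into 40 classes before applying cross-entropy over the protein-RNA interface block, alongside the symmetric CLIP-style batch loss. The sketch below mirrors continuous_to_discrete_tensor as a standalone function (the name to_distance_bins is my own) to show the binning: roughly 0.5 A bins below 8 A, roughly 1 A bins between 8 and 32 A, and a single catch-all bin (index 39) for anything at or beyond 32 A.

import torch

def to_distance_bins(values: torch.Tensor) -> torch.Tensor:
    # Mirrors PretuneModule.continuous_to_discrete_tensor (40 classes, indices 0-39).
    bins = torch.zeros_like(values, dtype=torch.long)
    lt8 = values < 8
    bins[lt8] = (values[lt8] * 2 + 0.5).long()        # ~0.5 A resolution below 8 A
    mid = (values >= 8) & (values < 32)
    bins[mid] = (16 + (values[mid] - 8)).long()       # ~1 A resolution between 8 and 32 A
    bins[values >= 32] = 39                           # distant (or masked/inf) pairs share one bin
    return torch.clamp(bins, 0, 39)

d = torch.tensor([1.2, 3.7, 7.9, 10.0, 25.4, 50.0])   # distances in Angstrom
print(to_distance_bins(d))                            # tensor([ 2,  7, 16, 18, 33, 39])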