├── utils ├── __init__.py ├── common.py └── metrics.py ├── assets ├── model_overview.jpg └── results_on_PRA.png ├── data ├── protein │ └── __init__.py ├── rna │ ├── __init__.py │ ├── base_constants.py │ ├── rnas.py │ └── sec_struct_utils.py ├── __init__.py ├── transforms │ ├── __init__.py │ ├── geometric.py │ ├── select_chain.py │ ├── noise.py │ ├── _base.py │ ├── corrupt_chi.py │ ├── select_atom.py │ ├── mask.py │ └── patch.py ├── register.py ├── complex.py ├── sequence_dataset.py └── pri30k_dataset.py ├── pl_modules ├── __init__.py ├── data_module.py ├── model_module.py └── pretune_module.py ├── models ├── __init__.py ├── components │ ├── loss.py │ ├── rope.py │ ├── coformer.py │ ├── valina_transformer.py │ ├── valina_attn.py │ └── attention.py ├── register.py ├── lora_tune.py ├── encoders │ ├── single.py │ ├── pair.py │ ├── layers.py │ └── attn.py ├── ipa.py └── esm_rinalmo_seq.py ├── config ├── runs │ ├── finetune_struct.yml │ ├── finetune_sequence.yml │ ├── zero_shot_blindtest.yml │ ├── pretune_struct.yml │ └── test_basic.yml ├── datasets │ ├── PRA201.yml │ ├── PRA310.yml │ ├── PRI30k.yml │ ├── mCSM.yml │ └── blindtest.yml └── models │ ├── esm2_rinalmo_seq.yml │ └── copra.yml ├── environment.yml ├── LICENSE ├── .gitignore ├── README.md └── run.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | from .metrics import * 3 | from .geometry import * -------------------------------------------------------------------------------- /assets/model_overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hanrthu/CoPRA/HEAD/assets/model_overview.jpg -------------------------------------------------------------------------------- /assets/results_on_PRA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hanrthu/CoPRA/HEAD/assets/results_on_PRA.png -------------------------------------------------------------------------------- /data/protein/__init__.py: -------------------------------------------------------------------------------- 1 | from .atom_convert import * 2 | from .proteins import * 3 | from .residue_constants import * -------------------------------------------------------------------------------- /pl_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_module import * 2 | from .model_module import * 3 | from .ddg_module import * 4 | from .pretune_module import * -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .esm_rinalmo_seq import * 2 | from .model import * 3 | from .components.coformer import * 4 | from .register import * 5 | from .ipa import * -------------------------------------------------------------------------------- /data/rna/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_constants import * 2 | from .data_utils import * 3 | from .featurizer import * 4 | from .rnas import * 5 | # from .sec_struct_utils import * -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .sequence_dataset import * 2 | from .structure_dataset import * 3 | from .register import * 4 | 
from data.protein.proteins import * 5 | from data.protein.residue_constants import * 6 | from .pri30k_dataset import * -------------------------------------------------------------------------------- /config/runs/finetune_struct.yml: -------------------------------------------------------------------------------- 1 | epochs: 140 2 | patience: 150 3 | output_dir: './outputs/esm2rinalmo_650M/scratch/PRA310' 4 | gpus: 5 | - 7 6 | 7 | ckpt: null 8 | 9 | run_name: 'esm2rinalmo_scratch_patch256' 10 | wandb: False 11 | num_folds: 5 -------------------------------------------------------------------------------- /config/runs/finetune_sequence.yml: -------------------------------------------------------------------------------- 1 | epochs: 30 2 | patience: 50 3 | output_dir: './outputs/esm2rinalmo_650M/finetune_sequence/with_transformer_pdbbind_dG_restrict_remove_cls_token' 4 | gpus: 5 | - 0 6 | ckpt: null 7 | run_name: 'esm2_rinalmo_650M_clstoken' 8 | wandb: true 9 | num_folds: 5 -------------------------------------------------------------------------------- /config/runs/zero_shot_blindtest.yml: -------------------------------------------------------------------------------- 1 | iters: 10500 2 | epochs: 110 3 | patience: 50 4 | output_dir: './outputs/esm2rinalmo_650M/mCSM/blind_test' 5 | gpus: 6 | - 0 7 | ckpts: 8 | - './weights/CoPRA_fold3.ckpt' 9 | 10 | run_name: 'esm2rinalmo_mmCSM_patch256' 11 | wandb: False 12 | num_folds: 1 -------------------------------------------------------------------------------- /config/runs/pretune_struct.yml: -------------------------------------------------------------------------------- 1 | # iters: 30000 2 | epochs: 100 3 | patience: 50 4 | output_dir: './outputs/esm2rinalmo_650M/pretune_struct/esm2rinalmo_650M_cformer_PRI30k_fixed_interface_dist_random_mask0.15_tokenpooling' 5 | gpus: 6 | - 0 7 | - 1 8 | - 2 9 | - 3 10 | ckpt: null 11 | 12 | run_name: 'esm2rinalmo_650M_cformer_PRI30k_fixed_interface_dist_random_mask0.15_tokenpooling' 13 | wandb: false 14 | num_folds: 1 -------------------------------------------------------------------------------- /config/runs/test_basic.yml: -------------------------------------------------------------------------------- 1 | output_dir: './outputs/esm2rinalmo_650M/test_struct/best_model' 2 | gpus: 3 | - 0 4 | 5 | ckpts: 6 | - ./weights/CoPRA_5fold/CoPRA_fold0.ckpt 7 | - ./weights/CoPRA_5fold/CoPRA_fold1.ckpt 8 | - ./weights/CoPRA_5fold/CoPRA_fold2.ckpt 9 | - ./weights/CoPRA_5fold/CoPRA_fold3.ckpt 10 | - ./weights/CoPRA_5fold/CoPRA_fold4.ckpt 11 | 12 | run_name: 'esm2rinalmo_test_best' 13 | wandb: False 14 | num_folds: 5 -------------------------------------------------------------------------------- /data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Transforms 2 | from .patch import FocusedRandomPatch, RandomPatch, SelectedRegionWithPaddingPatch, SelectedRegionFixedSizePatch 3 | from .select_chain import SelectFocused 4 | from .select_atom import SelectAtom 5 | from .mask import RandomMaskAminoAcids, MaskSelectedAminoAcids 6 | from .noise import AddAtomNoise, AddChiAngleNoise 7 | from .corrupt_chi import CorruptChiAngle 8 | from .geometric import SubtractCOM 9 | # Factory 10 | from ._base import get_transform, Compose, _get_CB_positions 11 | -------------------------------------------------------------------------------- /config/datasets/PRA201.yml: -------------------------------------------------------------------------------- 1 | dataset_type: 
'structure_dataset' 2 | df_path: './datasets/PRA310/splits/PRA201.csv' 3 | batch_size: 16 4 | data_root: './datasets/PRA310/PDBs' 5 | num_workers: 2 6 | col_prot_chain: 'Protein chains' 7 | col_na_chain: 'RNA chains' 8 | col_prot: 'Protein sequences' 9 | col_na: 'RNA sequences' 10 | col_label: '△G(kcal/mol)' 11 | pin_memory: True 12 | cache_dir: './cache/pra201' 13 | loss_type: regression 14 | strategy: separate 15 | 16 | transform: 17 | - type: select_atom 18 | resolution: backbone 19 | - type: selected_region_with_distmap 20 | patch_size: 256 21 | - type: subtract_center_of_mass -------------------------------------------------------------------------------- /config/datasets/PRA310.yml: -------------------------------------------------------------------------------- 1 | dataset_type: 'structure_dataset' 2 | df_path: './datasets/PRA310/splits/PRA310.csv' 3 | batch_size: 16 4 | data_root: './datasets/PRA310/PDBs' 5 | num_workers: 2 6 | col_prot_chain: 'Protein chains' 7 | col_na_chain: 'RNA chains' 8 | col_prot: 'Protein sequences' 9 | col_na: 'RNA sequences' 10 | col_label: '△G(kcal/mol)' 11 | pin_memory: True 12 | cache_dir: './cache/pra310' 13 | loss_type: regression 14 | strategy: separate 15 | 16 | transform: 17 | - type: select_atom 18 | resolution: backbone 19 | - type: selected_region_with_distmap 20 | patch_size: 256 21 | - type: subtract_center_of_mass -------------------------------------------------------------------------------- /config/datasets/PRI30k.yml: -------------------------------------------------------------------------------- 1 | dataset_type: 'pri30k_dataset' 2 | df_path: './datasets/PRI30k/splits/pretrain_length_750_clean.csv' 3 | batch_size: 20 4 | data_root: './datasets/PRI30k/PDBs' 5 | num_workers: 0 6 | col_prot_name: PDB 7 | col_prot_chain: Protein chains 8 | col_na_chain: RNA chains 9 | col_binding_site: Binding site renumbered merged 10 | col_ligand: Binding ligands 11 | pin_memory: True 12 | cache_dir: './cache/pri30k' 13 | loss_type: regression 14 | strategy: separate 15 | 16 | transform: 17 | - type: select_atom 18 | resolution: backbone 19 | - type: selected_region_with_distmap 20 | patch_size: 256 21 | - type: subtract_center_of_mass -------------------------------------------------------------------------------- /data/transforms/geometric.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | from ._base import register_transform 5 | 6 | @register_transform('subtract_center_of_mass') 7 | class SubtractCOM(object): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def __call__(self, data): 12 | pos = data['pos_atoms'] 13 | mask = data['mask_atoms'] 14 | if mask is None: 15 | center = pos.new_zeros(3) 16 | elif mask.sum() == 0: 17 | center = pos.new_zeros(3) 18 | else: 19 | center = pos[mask].mean(axis=0) 20 | data['pos_atoms'] = pos - center[None, None, :] 21 | return data 22 | 23 | -------------------------------------------------------------------------------- /config/datasets/mCSM.yml: -------------------------------------------------------------------------------- 1 | dataset_type: 'structure_dataset' 2 | df_path: './datasets/mCSM_RNA/splits/crossvalidation.csv' 3 | batch_size: 8 4 | data_root: './datasets/mCSM_RNA/PDBs' 5 | num_workers: 2 6 | col_prot_chain: 'Protein chains' 7 | col_na_chain: 'RNA chains' 8 | col_prot: 'Protein sequences' 9 | col_mut: 'Mutation sequences' 10 | mut: True 11 | col_na: 'RNA sequences' 12 | col_label: 'DDG' 13 | pin_memory: True 14 | cache_dir: 
'./cache/mmCSM' 15 | loss_type: regression 16 | strategy: separate 17 | 18 | transform: 19 | - type: select_atom 20 | resolution: backbone 21 | - type: selected_region_with_distmap 22 | patch_size: 256 23 | - type: subtract_center_of_mass -------------------------------------------------------------------------------- /config/datasets/blindtest.yml: -------------------------------------------------------------------------------- 1 | dataset_type: 'structure_dataset' 2 | df_path: './datasets/mCSM_RNA/splits/blindtest.csv' 3 | batch_size: 8 4 | data_root: './datasets/mCSM_RNA/PDBs' 5 | num_workers: 2 6 | col_prot_chain: 'Protein chains' 7 | col_na_chain: 'RNA chains' 8 | col_prot: 'Protein sequences' 9 | col_mut: 'Mutation sequences' 10 | mut: True 11 | col_na: 'RNA sequences' 12 | col_label: 'DDG' 13 | pin_memory: True 14 | cache_dir: './cache/blindtest' 15 | loss_type: regression 16 | strategy: separate 17 | 18 | transform: 19 | - type: select_atom 20 | resolution: backbone 21 | - type: selected_region_with_distmap 22 | patch_size: 256 23 | - type: subtract_center_of_mass -------------------------------------------------------------------------------- /models/components/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class PearsonCorrLoss(nn.Module): 6 | def __init__(self): 7 | super(PearsonCorrLoss, self).__init__() 8 | 9 | def forward(self, pred, target): 10 | pred_mean = torch.mean(pred) 11 | target_mean = torch.mean(target) 12 | 13 | pred_centered = pred - pred_mean 14 | target_centered = target - target_mean 15 | 16 | numerator = torch.sum(pred_centered * target_centered) 17 | denominator = torch.sqrt(torch.sum(pred_centered ** 2) * torch.sum(target_centered ** 2)) 18 | 19 | pearson_corr = numerator / (denominator + 1e-5) 20 | loss = 1 - pearson_corr 21 | 22 | return loss -------------------------------------------------------------------------------- /config/models/esm2_rinalmo_seq.yml: -------------------------------------------------------------------------------- 1 | resume: null 2 | model: 3 | model_type: esm2_rinalmo_seq 4 | esm_type: 650M 5 | pooling: mean 6 | output_dim: 1 7 | fix_lms: True 8 | lora_tune: false 9 | lora_rank: 16 10 | lora_alpha: 32 11 | vallina: True 12 | representation_layer: 33 13 | transformer: 14 | embed_dim: 320 15 | num_blocks: 6 16 | num_heads: 20 17 | use_rot_emb: true 18 | attn_qkv_bias: false 19 | attention_dropout: 0.1 20 | transition_dropout: 0.0 21 | residual_dropout: 0.1 22 | transition_factor: 4 23 | use_flash_attn: false 24 | train: 25 | max_iters: 30_000 26 | val_freq: 1000 27 | seed: 2024 28 | # max_grad_norm: 100.0 29 | optimizer: 30 | type: adam 31 | lr: 3.e-5 32 | weight_decay: 0.0 33 | beta1: 0.9 34 | beta2: 0.999 35 | scheduler: 36 | type: plateau 37 | factor: 0.8 38 | patience: 5 39 | min_lr: 1.e-6 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: copra_h 2 | channels: 3 | - pytorch3d 4 | - pytorch 5 | - conda-forge 6 | - nvidia 7 | - pyg 8 | dependencies: 9 | - python=3.10.14 10 | - pytorch:pytorch=2.1.2 11 | - pytorch:pytorch-cuda=11.8 12 | # - llvm-openmp<16 13 | - einops=0.6.1 14 | - pytorch-lightning 15 | - pandas 16 | - dm-tree=0.1.7 17 | - diskcache=5.4.0 18 | - fire=0.5.0 19 | - numpy=1.26.4 20 | - scikit-learn=1.3.0 21 | - pip 22 | - gpustat 23 | - wandb 24 | - easydict=1.9 25 | - 
matplotlib=3.8.0 26 | - pytorch_geometric 27 | - pytorch-scatter 28 | - pytorch-sparse 29 | - pytorch-cluster 30 | - pip: 31 | - biopython==1.84 32 | - fair-esm==2.0.0 33 | - tree==0.2.4 34 | - seaborn==0.13.2 35 | - triton==2.1.0 36 | - torchsummary==1.5.1 37 | - peft==0.14.0 38 | - biotite==0.41.2 39 | - cpdb-protein==0.2.0 40 | - -i https://pypi.tuna.tsinghua.edu.cn/simple -------------------------------------------------------------------------------- /config/models/copra.yml: -------------------------------------------------------------------------------- 1 | resume: null 2 | model: 3 | model_type: copra 4 | esm_type: 650M 5 | pooling: token 6 | output_dim: 1 7 | fix_lms: true 8 | lora_tune: false 9 | lora_rank: 16 10 | lora_alpha: 32 11 | pair_dim: 40 12 | dist_dim: 40 13 | representation_layer: 33 14 | coformer: 15 | embed_dim: 320 16 | pair_dim: 40 17 | num_blocks: 6 18 | num_heads: 20 19 | use_rot_emb: true 20 | attn_qkv_bias: false 21 | attention_dropout: 0.1 22 | transition_dropout: 0.0 23 | residual_dropout: 0.1 24 | transition_factor: 4 25 | use_flash_attn: false 26 | 27 | train: 28 | temperature: 0.2 29 | max_iters: 30_000 30 | val_freq: 1000 31 | seed: 2024 32 | max_grad_norm: 100.0 33 | optimizer: 34 | type: adam 35 | lr: 3.e-5 36 | weight_decay: 0.0 37 | beta1: 0.9 38 | beta2: 0.999 39 | scheduler: 40 | type: plateau 41 | factor: 0.8 42 | patience: 5 43 | min_lr: 1.e-6 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Rong Han 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/transforms/select_chain.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | from ._base import _mask_select_data, register_transform 5 | 6 | 7 | @register_transform('random_interacting_chain') 8 | class RandomInteractingChain(object): 9 | 10 | def __init__(self, interaction_attr): 11 | super().__init__() 12 | self.interaction_attr = interaction_attr 13 | 14 | def __call__(self, data): 15 | # Randomly choose a chain to mask 16 | interact_flag = (data[self.interaction_attr] > 0) # (L, ) 17 | if interact_flag.sum() == 0: 18 | # If there is no active residues, randomly pick one. 
19 | interact_flag[random.randint(0, interact_flag.size(0)-1)] = True 20 | seed_idx = torch.multinomial(interact_flag.float(), num_samples=1).item() 21 | 22 | chain_nb_selected = data['chain_nb'][seed_idx].item() 23 | mask_chain = (data['chain_nb'] == chain_nb_selected) 24 | return _mask_select_data(data, mask_chain) 25 | 26 | 27 | @register_transform('select_focused') 28 | class SelectFocused(object): 29 | 30 | def __init__(self, focus_attr): 31 | super().__init__() 32 | self.focus_attr = focus_attr 33 | 34 | def __call__(self, data): 35 | mask_focus = (data[self.focus_attr] > 0) 36 | return _mask_select_data(data, mask_focus) 37 | 38 | -------------------------------------------------------------------------------- /data/transforms/noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from ._base import register_transform 5 | 6 | 7 | @register_transform('add_atom_noise') 8 | class AddAtomNoise(object): 9 | 10 | def __init__(self, noise_std=0.02): 11 | super().__init__() 12 | self.noise_std = noise_std 13 | 14 | def __call__(self, data): 15 | pos_atoms = data['pos_atoms'] # (L, A, 3) 16 | mask_atoms = data['mask_atoms'] # (L, A) 17 | noise = (torch.randn_like(pos_atoms) * self.noise_std) * mask_atoms[:, :, None] 18 | pos_noisy = pos_atoms + noise 19 | data['pos_atoms'] = pos_noisy 20 | return data 21 | 22 | 23 | @register_transform('add_chi_angle_noise') 24 | class AddChiAngleNoise(object): 25 | 26 | def __init__(self, noise_std=0.02): 27 | super().__init__() 28 | self.noise_std = noise_std 29 | 30 | def _normalize_angles(self, angles): 31 | angles = angles % (2*np.pi) 32 | return torch.where(angles > np.pi, angles - 2*np.pi, angles) 33 | 34 | def __call__(self, data): 35 | chi, chi_alt = data['chi'], data['chi_alt'] # (L, 4) 36 | chi_mask = data['chi_mask'] # (L, 4) 37 | 38 | _get_noise = lambda: ((torch.randn_like(chi) * self.noise_std) * chi_mask) 39 | data['chi'] = self._normalize_angles( chi + _get_noise() ) 40 | data['chi_alt'] = self._normalize_angles( chi_alt + _get_noise() ) 41 | return data 42 | -------------------------------------------------------------------------------- /data/rna/base_constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | 4 | 5 | PROJECT_PATH = os.environ.get("PROJECT_PATH") 6 | 7 | DATA_PATH = os.environ.get("DATA_PATH") 8 | 9 | X3DNA_PATH = os.environ.get("X3DNA") 10 | 11 | ETERNAFOLD_PATH = os.environ.get("ETERNAFOLD") 12 | 13 | 14 | # Value to fill missing coordinate entries when reading PDB files 15 | FILL_VALUE = 1e-5 16 | 17 | 18 | # Small epsilon value added to distances to avoid division by zero 19 | DISTANCE_EPS = 0.001 20 | 21 | 22 | # List of possible atoms in RNA nucleotides 23 | RNA_ATOMS = [ 24 | 'P', "C5'", "O5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", 25 | 'N1', 26 | 'C2', 27 | 'O2', 'N2', 28 | 'N3', 29 | 'C4', 'O4', 'N4', 30 | 'C5', 31 | 'C6', 32 | 'O6', 'N6', 33 | 'N7', 34 | 'C8', 35 | 'N9', 36 | 'OP1', 'OP2', 37 | ] 38 | 39 | 40 | # List of possible RNA nucleotides 41 | RNA_NUCLEOTIDES = [ 42 | 'A', 43 | 'G', 44 | 'C', 45 | 'U', 46 | # '_' # placeholder for missing/unknown nucleotides 47 | ] 48 | 49 | 50 | # List of purine nucleotides 51 | PURINES = ["A", "G"] 52 | 53 | 54 | # List of pyrimidine nucleotides 55 | PYRIMIDINES = ["C", "U"] 56 | 57 | 58 | # 59 | LETTER_TO_NUM = dict(zip( 60 | RNA_NUCLEOTIDES, 61 | list(range(len(RNA_NUCLEOTIDES))) 62 | )) 63 | 64 | 65 | # 66 
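# Inverse mapping from integer index back to nucleotide letter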
| NUM_TO_LETTER = {v:k for k, v in LETTER_TO_NUM.items()} 67 | 68 | 69 | # 70 | DOTBRACKET_TO_NUM = { 71 | '.': 0, 72 | '(': 1, 73 | ')': 2 74 | } 75 | -------------------------------------------------------------------------------- /models/register.py: -------------------------------------------------------------------------------- 1 | from utils import singleton 2 | 3 | REGISTERED_MODELS = [] 4 | 5 | @singleton 6 | class ModelRegister(dict): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self._dict = {} 10 | def register(self, target): 11 | def add_register_item(key, value): 12 | if not callable(value): 13 | raise Exception(f"register object must be callable! But receice:{value} is not callable!") 14 | if key in self._dict: 15 | print(f"warning: \033[33m{value.__name__} has been registered before, so we will overriden it\033[0m") 16 | self[key] = value 17 | REGISTERED_MODELS.append(key) 18 | return value 19 | 20 | if callable(target): 21 | return add_register_item(target.__name__, target) 22 | else: 23 | return lambda x: add_register_item(target, x) 24 | 25 | def __call__(self, target): 26 | return self.register(target) 27 | 28 | def __setitem__(self, key, value): 29 | self._dict[key] = value 30 | 31 | def __getitem__(self, key): 32 | return self._dict[key] 33 | 34 | def __contains__(self, key): 35 | return key in self._dict 36 | 37 | def __str__(self): 38 | return str(self._dict) 39 | 40 | def keys(self): 41 | return self._dict.keys() 42 | 43 | def values(self): 44 | return self._dict.values() 45 | 46 | def items(self): 47 | return self._dict.items() 48 | 49 | 50 | -------------------------------------------------------------------------------- /data/register.py: -------------------------------------------------------------------------------- 1 | from utils import singleton 2 | 3 | REGISTERED_DATASETS = [] 4 | 5 | @singleton 6 | class DataRegister(dict): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self._dict = {} 10 | def register(self, target): 11 | def add_register_item(key, value): 12 | if not callable(value): 13 | raise Exception(f"register object must be callable! 
But receice:{value} is not callable!") 14 | if key in self._dict: 15 | print(f"warning: \033[33m{value.__name__} has been registered before, so we will overriden it\033[0m") 16 | self[key] = value 17 | REGISTERED_DATASETS.append(key) 18 | return value 19 | 20 | if callable(target): 21 | return add_register_item(target.__name__, target) 22 | else: 23 | return lambda x: add_register_item(target, x) 24 | 25 | def __call__(self, target): 26 | return self.register(target) 27 | 28 | def __setitem__(self, key, value): 29 | self._dict[key] = value 30 | 31 | def __getitem__(self, key): 32 | # print(self._dict) 33 | return self._dict[key] 34 | 35 | def __contains__(self, key): 36 | return key in self._dict 37 | 38 | def __str__(self): 39 | return str(self._dict) 40 | 41 | def keys(self): 42 | return self._dict.keys() 43 | 44 | def values(self): 45 | return self._dict.values() 46 | 47 | def items(self): 48 | return self._dict.items() 49 | 50 | 51 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pandas as pd 4 | import torch 5 | from torch_geometric.data import Data 6 | from torch_geometric.typing import OptTensor, SparseTensor 7 | 8 | def singleton(cls): 9 | _instance = {} 10 | 11 | def inner(*args, **kwargs): 12 | if cls not in _instance: 13 | _instance[cls] = cls(*args, **kwargs) 14 | return _instance[cls] 15 | 16 | return inner 17 | 18 | 19 | class MyData(Data): 20 | chain: pd.Series 21 | channel_weights: torch.Tensor 22 | 23 | def __init__( 24 | self, x: OptTensor = None, edge_index: OptTensor = None, 25 | edge_attr: OptTensor = None, y: OptTensor = None, 26 | pos: OptTensor = None, **kwargs 27 | ): 28 | super(MyData, self).__init__(x, edge_index, edge_attr, y, pos, **kwargs) 29 | 30 | def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any: 31 | if 'batch' in key: 32 | return int(value.max()) + 1 33 | elif 'index' in key or 'face' in key: 34 | return self.num_nodes 35 | elif 'interaction' in key: 36 | return self.num_edges 37 | elif 'chains' in key: 38 | return self.num_chains 39 | else: 40 | return 0 41 | def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any: 42 | if isinstance(value, SparseTensor) and 'adj' in key: 43 | return (0, 1) 44 | elif 'index' in key or 'face' in key or 'interaction' in key: 45 | return -1 46 | else: 47 | return 0 48 | @property 49 | def num_chains(self): 50 | return len(torch.unique(self.chains)) -------------------------------------------------------------------------------- /models/components/rope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Code heavily inspired by https://blog.eleuther.ai/rotary-embeddings/, GPT-NeoX (Pytorch) implementation 5 | # and ESM2 implementation https://github.com/facebookresearch/esm/blob/main/esm/rotary_embedding.py 6 | 7 | def rotate_half(x): 8 | x1, x2 = x.chunk(2, dim=-1) 9 | return torch.cat((-x2, x1), dim=-1) 10 | 11 | 12 | @torch.jit.script 13 | def apply_rotary_pos_emb(q, k, cos, sin): 14 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 15 | 16 | class RotaryPositionEmbedding(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | base: int = 10000 21 | ): 22 | super().__init__() 23 | 24 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) 25 | self.register_buffer("inv_freq", inv_freq) 26 | 
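# inv_freq holds the per-dimension rotary frequencies 1 / base**(2i/dim); registering it as a buffer
# keeps it on the module's device without making it a trainable parameter. The cos/sin tables below
# are cached and rebuilt lazily whenever the key sequence length changes.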
27 | self.seq_len_cached = None 28 | self.cos_cached = None 29 | self.sin_cached = None 30 | 31 | def _update_cached(self, x, seq_dim): 32 | seq_len = x.shape[seq_dim] 33 | 34 | if seq_len != self.seq_len_cached: 35 | self.seq_len_cached = seq_len 36 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) 37 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 38 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 39 | 40 | self.cos_cached = emb.cos()[None, None, :, :] 41 | self.sin_cached = emb.sin()[None, None, :, :] 42 | 43 | def forward(self, q, k): 44 | self._update_cached(k, seq_dim=-2) 45 | return apply_rotary_pos_emb(q, k, self.cos_cached, self.sin_cached) 46 | -------------------------------------------------------------------------------- /models/lora_tune.py: -------------------------------------------------------------------------------- 1 | from transformers import PreTrainedModel, BertConfig, PretrainedConfig, BertForMaskedLM 2 | from typing import List 3 | class ESMConfig(PretrainedConfig): 4 | model_type = "esm2" 5 | def __init__( 6 | self, 7 | pooling='mean', 8 | output_dim=1, 9 | fix_lms=True, 10 | lora_tune=True, 11 | **kwargs, 12 | ): 13 | self.pooling = pooling 14 | self.output_dim = output_dim 15 | self.fix_lms = fix_lms 16 | self.lora_tune = lora_tune 17 | super().__init__(**kwargs) 18 | 19 | class RiNALMoConfig(PretrainedConfig): 20 | model_type = "rinalmo_flash" 21 | def __init__( 22 | self, 23 | pooling='mean', 24 | output_dim=1, 25 | fix_lms=True, 26 | lora_tune=True, 27 | **kwargs, 28 | ): 29 | self.pooling = pooling 30 | self.output_dim = output_dim 31 | self.fix_lms = fix_lms 32 | self.lora_tune = lora_tune 33 | super().__init__(**kwargs) 34 | 35 | 36 | class LoRAESM(PreTrainedModel): 37 | config_class = ESMConfig 38 | def __init__(self, model, config): 39 | super().__init__(config) 40 | self.model = model 41 | def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False): 42 | return self.model(tokens, repr_layers, need_head_weights, return_contacts) 43 | 44 | class LoRARiNALMo(PreTrainedModel): 45 | config_class = RiNALMoConfig 46 | def __init__(self, model, config): 47 | super().__init__(config) 48 | self.model = model 49 | def forward(self, tokens, need_attn_weights=False): 50 | return self.model(tokens, need_attn_weights) -------------------------------------------------------------------------------- /models/encoders/single.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.encoders.layers import AngularEncoding 5 | 6 | 7 | class PerResidueEncoder(nn.Module): 8 | 9 | def __init__(self, feat_dim, max_num_atoms, max_aa_types=27): 10 | super().__init__() 11 | self.max_num_atoms = max_num_atoms 12 | self.max_aa_types = max_aa_types 13 | self.aatype_embed = nn.Embedding(self.max_aa_types, feat_dim) 14 | # self.dihed_embed = AngularEncoding() 15 | # infeat_dim = feat_dim + self.dihed_embed.get_out_dim(6) # Phi, Psi, Chi1-4 16 | infeat_dim = feat_dim 17 | self.mlp = nn.Sequential( 18 | nn.Linear(infeat_dim, feat_dim * 2), nn.ReLU(), 19 | nn.Linear(feat_dim * 2, feat_dim), nn.ReLU(), 20 | nn.Linear(feat_dim, feat_dim), nn.ReLU(), 21 | nn.Linear(feat_dim, feat_dim) 22 | ) 23 | 24 | # def forward(self, aa, phi, phi_mask, psi, psi_mask, chi, chi_mask, mask_residue): 25 | def forward(self, aa, mask_residue): 26 | """ 27 | Args: 28 | aa: (N, L) 29 | # phi, phi_mask: (N, L) 30 | # psi, psi_mask: (N, L) 31 | # chi, 
chi_mask: (N, L, 4) 32 | mask_residue: (N, L) 33 | """ 34 | N, L = aa.size() 35 | 36 | # Amino acid identity features 37 | aa_feat = self.aatype_embed(aa) # (N, L, feat) 38 | 39 | # Dihedral features 40 | # dihedral = torch.cat( 41 | # [phi[..., None], psi[..., None], chi], 42 | # dim=-1 43 | # ) # (N, L, 6) 44 | # dihedral_mask = torch.cat([ 45 | # phi_mask[..., None], psi_mask[..., None], chi_mask], 46 | # dim=-1 47 | # ) # (N, L, 6) 48 | # dihedral_feat = self.dihed_embed(dihedral[..., None]) * dihedral_mask[..., None] # (N, L, 6, feat) 49 | # dihedral_feat = dihedral_feat.reshape(N, L, -1) 50 | 51 | # Mix 52 | # out_feat = self.mlp(torch.cat([aa_feat, dihedral_feat], dim=-1)) # (N, L, F) 53 | out_feat = self.mlp(aa_feat) 54 | out_feat = out_feat * mask_residue[:, :, None] 55 | return out_feat 56 | -------------------------------------------------------------------------------- /data/rna/rnas.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | from typing import Dict, Optional, List 4 | import numpy as np 5 | from data.rna.data_utils import pdb_to_array, cif_to_array 6 | 7 | 8 | @dataclass 9 | class RNAInput: 10 | seq: str # L 11 | mask: np.ndarray # (L, ) 12 | basetype: np.ndarray # (L, ) 13 | atom_mask: np.ndarray # (L, 27) 14 | 15 | atom_positions: Optional[np.ndarray] = field(default=None) # (L, 27, 3) 16 | residue_index: Optional[np.ndarray] = field(default=None) # (L, ) 17 | 18 | res_nb: Optional[np.ndarray] = field(default=None) # (L, ) 19 | resseq: Optional[np.ndarray] = field(default=None) # (L, ) 20 | chain_nb: Optional[np.ndarray] = field(default=None) # (L, ) 21 | chain_id: Optional[List] = field(default=None) # len(chain_id) = L 22 | 23 | chainid: Optional[np.ndarray] = field(default=None) # (L) 24 | 25 | 26 | @classmethod 27 | def from_path(self, pdb_filepath, valid_chains): 28 | """ 29 | Load RNA backbone from PDB file. 30 | 31 | Args: 32 | pdb_filepath (str): Path to PDB file. 
33 | """ 34 | pdb_filepath = str(pdb_filepath) 35 | if '.pdb' in pdb_filepath: 36 | rnas = pdb_to_array( 37 | pdb_filepath, valid_chains, return_sec_struct=True, return_sasa=False) 38 | else: 39 | rnas = cif_to_array(pdb_filepath, valid_chains) 40 | # print(rnas) 41 | rna_dict = {} 42 | for key in rnas: 43 | rna = self(**rnas[key]) 44 | rna_dict[key] = rna 45 | return rna_dict 46 | # coords = get_backbone_coords(coords, sequence) 47 | # rna = { 48 | # 'sequence': sequence, 49 | # 'coords_list': coords, 50 | # } 51 | # return rna 52 | 53 | @property 54 | def length(self): 55 | return len(self.seq) 56 | 57 | def __repr__(self): 58 | return self.__str__() 59 | 60 | def __str__(self): 61 | texts = [] 62 | texts += [f'seq: {self.seq}'] 63 | texts += [f'length: {len(self.seq)}'] 64 | texts += [f"mask: {''.join(self.mask.astype('int').astype('str'))}"] 65 | if self.chainid is not None: 66 | texts += [f"chainid: {''.join(self.chainid.astype('int').astype('str'))}"] 67 | 68 | names = [ 69 | 'basetype', 70 | 'atom_mask', 71 | 'atom_positions', 72 | ] 73 | for name in names: 74 | value = getattr(self, name) 75 | if value is None: 76 | text = f'{name}: None' 77 | else: 78 | text = f'{name}: {value.shape}' 79 | texts += [text] 80 | text = ', \n '.join(texts) 81 | text = f'RNA(\n {text}\n)' 82 | return text -------------------------------------------------------------------------------- /data/transforms/_base.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | 4 | 5 | class Compose: 6 | 7 | def __init__(self, transforms): 8 | self.transforms = transforms 9 | 10 | def __call__(self, data): 11 | for t in self.transforms: 12 | data = t(data) 13 | return data 14 | 15 | 16 | _TRANSFORM_DICT = {} 17 | 18 | 19 | def register_transform(name): 20 | def decorator(cls): 21 | _TRANSFORM_DICT[name] = cls 22 | return cls 23 | return decorator 24 | 25 | 26 | def get_transform(cfg): 27 | if cfg is None or len(cfg) == 0: 28 | return None 29 | tfms = [] 30 | for t_dict in cfg: 31 | t_dict = copy.deepcopy(t_dict) 32 | cls = _TRANSFORM_DICT[t_dict.pop('type')] 33 | tfms.append(cls(**t_dict)) 34 | return Compose(tfms) 35 | 36 | 37 | def _index_select(v, index, n): 38 | if isinstance(v, torch.Tensor) and v.size(0) == n: 39 | return v[index] 40 | elif isinstance(v, list) and len(v) == n: 41 | return [v[i] for i in index] 42 | else: 43 | return v 44 | 45 | 46 | def _index_select_with_dist(v, index, n): 47 | if isinstance(v, torch.Tensor) and v.size(0) == n and (len(v.shape) == 1 or v.size(0) != v.size(1)): 48 | return v[index] 49 | elif isinstance(v, torch.Tensor) and v.size(0) == v.size(1) and v.size(0) == n: 50 | tmp = v[index] 51 | new_v = tmp[:, index] 52 | return new_v 53 | elif isinstance(v, list) and len(v) == n: 54 | return [v[i] for i in index] 55 | else: 56 | return v 57 | 58 | def _index_select_data(data, index): 59 | return { 60 | k: _index_select(v, index, data['aa'].size(0)) 61 | for k, v in data.items() 62 | } 63 | 64 | def _index_select_complex(data, index): 65 | return { 66 | k: _index_select_with_dist(v, index, data['restype'].size(0)) 67 | for k, v in data.items() 68 | } 69 | 70 | def _mask_select(v, mask): 71 | if isinstance(v, torch.Tensor) and v.size(0) == mask.size(0): 72 | return v[mask] 73 | elif isinstance(v, list) and len(v) == mask.size(0): 74 | return [v[i] for i, b in enumerate(mask) if b] 75 | else: 76 | return v 77 | 78 | 79 | def _mask_select_data(data, mask): 80 | return { 81 | k: _mask_select(v, mask) 82 | for k, v in 
data.items() 83 | } 84 | 85 | 86 | def _get_CB_positions(pos_atoms, mask_atoms): 87 | """ 88 | Args: 89 | pos_atoms: (L, A, 3) 90 | mask_atoms: (L, A) 91 | """ 92 | L = pos_atoms.size(0) 93 | pos_CA = pos_atoms[:, 1] # (L, 3) # CA 94 | if pos_atoms.size(1) < 5: 95 | return pos_CA 96 | pos_CB = pos_atoms[:, 4] #CB 97 | mask_CB = mask_atoms[:, 4, None].expand(L, 3) #CB 98 | return torch.where(mask_CB, pos_CB, pos_CA) 99 | 100 | -------------------------------------------------------------------------------- /data/transforms/corrupt_chi.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from ._base import register_transform, _get_CB_positions 5 | 6 | 7 | @register_transform('corrupt_chi_angle') 8 | class CorruptChiAngle(object): 9 | 10 | def __init__(self, ratio_mask=0.1, add_noise=True, maskable_flag_attr=None): 11 | super().__init__() 12 | self.ratio_mask = ratio_mask 13 | self.add_noise = add_noise 14 | self.maskable_flag_attr = maskable_flag_attr 15 | 16 | def _normalize_angles(self, angles): 17 | angles = angles % (2*np.pi) 18 | return torch.where(angles > np.pi, angles - 2*np.pi, angles) 19 | 20 | def _get_min_dist(self, data, center_idx): 21 | pos_beta_all = _get_CB_positions(data['pos_atoms'], data['mask_atoms']) 22 | pos_beta_center = pos_beta_all[center_idx] 23 | cdist = torch.cdist(pos_beta_all, pos_beta_center) # (L, K) 24 | min_dist = cdist.min(dim=1)[0] # (L, ) 25 | return min_dist 26 | 27 | def _get_noise_std(self, min_dist): 28 | return torch.clamp_min((-1/16) * min_dist + 1, 0) 29 | 30 | def _get_flip_prob(self, min_dist): 31 | return torch.where( 32 | min_dist <= 8.0, 33 | torch.full_like(min_dist, 0.25), 34 | torch.zeros_like(min_dist,), 35 | ) 36 | 37 | def _add_chi_gaussian_noise(self, chi, noise_std, chi_mask): 38 | """ 39 | Args: 40 | chi: (L, 4) 41 | noise_std: (L, ) 42 | chi_mask: (L, 4) 43 | """ 44 | noise = torch.randn_like(chi) * noise_std[:, None] * chi_mask 45 | return self._normalize_angles(chi + noise) 46 | 47 | def _random_flip_chi(self, chi, flip_prob, chi_mask): 48 | """ 49 | Args: 50 | chi: (L, 4) 51 | flip_prob: (L, ) 52 | chi_mask: (L, 4) 53 | """ 54 | delta = torch.where( 55 | torch.rand_like(chi) <= flip_prob[:, None], 56 | torch.full_like(chi, np.pi), 57 | torch.zeros_like(chi), 58 | ) * chi_mask 59 | return self._normalize_angles(chi + delta) 60 | 61 | def __call__(self, data): 62 | L = data['aa'].size(0) # Calculate aa num 63 | idx = torch.arange(0, L) 64 | num_mask = max(int(self.ratio_mask * L), 1) # Calculate mask ratio 65 | if self.maskable_flag_attr is not None: 66 | flag = data[self.maskable_flag_attr] 67 | idx = idx[flag] 68 | 69 | idx = idx.tolist() 70 | np.random.shuffle(idx) 71 | idx_mask = idx[:num_mask] 72 | min_dist = self._get_min_dist(data, idx_mask) 73 | noise_std = self._get_noise_std(min_dist) 74 | flip_prob = self._get_flip_prob(min_dist) 75 | 76 | chi_native = torch.where( 77 | torch.randn_like(data['chi']) > 0, 78 | data['chi'], 79 | data['chi_alt'], 80 | ) # (L, 4), randomly pick from chi and chi_alt 81 | chi = chi_native.clone() 82 | chi_mask = data['chi_mask'] 83 | 84 | if self.add_noise: 85 | chi = self._add_chi_gaussian_noise(chi, noise_std, chi_mask) 86 | chi = self._random_flip_chi(chi, flip_prob, chi_mask) 87 | chi[idx_mask] = 0.0 # Mask chi angles 88 | 89 | corrupt_flag = torch.zeros(L, dtype=torch.bool) 90 | corrupt_flag[idx_mask] = True 91 | corrupt_flag[min_dist <= 8] = True 92 | 93 | masked_flag = torch.zeros(L, dtype=torch.bool) 94 | 
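# masked_flag marks only the residues whose chi angles were zeroed out above; corrupt_flag
# additionally covers every residue within 8 Å of a masked site, whose chi angles were perturbed
# with Gaussian noise and random pi-flips.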
masked_flag[idx_mask] = True 95 | 96 | data['chi_native'] = chi_native 97 | data['chi_corrupt'] = chi 98 | data['chi_corrupt_flag'] = corrupt_flag 99 | data['chi_masked_flag'] = masked_flag 100 | return data 101 | -------------------------------------------------------------------------------- /models/encoders/pair.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from utils.geometry import angstrom_to_nm, pairwise_dihedrals 6 | from models.encoders.layers import AngularEncoding 7 | 8 | 9 | class ResiduePairEncoder(nn.Module): 10 | 11 | def __init__(self, feat_dim, max_num_atoms, max_aa_types=30, max_relpos=32): 12 | super().__init__() 13 | self.max_num_atoms = max_num_atoms 14 | self.max_aa_types = max_aa_types 15 | self.max_relpos = max_relpos 16 | self.aa_pair_embed = nn.Embedding(self.max_aa_types * self.max_aa_types, feat_dim) 17 | self.relpos_embed = nn.Embedding(2 * max_relpos + 1, feat_dim) 18 | 19 | self.aapair_to_distcoef = nn.Embedding(self.max_aa_types * self.max_aa_types, max_num_atoms * max_num_atoms) 20 | # nn.init.zeros_(self.aapair_to_distcoef.weight) 21 | self.distance_embed = nn.Sequential( 22 | nn.Linear(max_num_atoms * max_num_atoms, feat_dim), nn.ReLU(), 23 | nn.Linear(feat_dim, feat_dim), nn.ReLU(), 24 | ) 25 | 26 | self.dihedral_embed = AngularEncoding() 27 | feat_dihed_dim = self.dihedral_embed.get_out_dim(2) # Phi and Psi 28 | 29 | infeat_dim = feat_dim + feat_dim + feat_dim + feat_dihed_dim 30 | self.out_mlp = nn.Sequential( 31 | nn.Linear(infeat_dim, feat_dim), nn.ReLU(), 32 | nn.Linear(feat_dim, feat_dim), nn.ReLU(), 33 | nn.Linear(feat_dim, feat_dim), 34 | ) 35 | 36 | def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms): 37 | """ 38 | Args: 39 | aa: (N, L). 40 | res_nb: (N, L). 41 | chain_nb: (N, L). 
42 | pos_atoms: (N, L, A, 3) 43 | mask_atoms: (N, L, A) 44 | Returns: 45 | (N, L, L, feat_dim) 46 | """ 47 | N, L = aa.size() 48 | mask_residue = mask_atoms[:, :, 1] # (N, L) 49 | mask_pair = mask_residue[:, :, None] * mask_residue[:, None, :] 50 | 51 | # Pair identities 52 | aa_pair = aa[:, :, None] * self.max_aa_types + aa[:, None, :] # (N, L, L) 53 | feat_aapair = self.aa_pair_embed(aa_pair) 54 | 55 | # Relative positions 56 | same_chain = (chain_nb[:, :, None] == chain_nb[:, None, :]) 57 | relpos = torch.clamp( 58 | res_nb[:, :, None] - res_nb[:, None, :], 59 | min=-self.max_relpos, max=self.max_relpos, 60 | ) # (N, L, L) 61 | 62 | feat_relpos = self.relpos_embed(relpos + self.max_relpos) * same_chain[:, :, :, None] 63 | 64 | # Distances 65 | d = angstrom_to_nm(torch.linalg.norm( 66 | pos_atoms[:, :, None, :, None] - pos_atoms[:, None, :, None, :], 67 | dim=-1, ord=2, 68 | )).reshape(N, L, L, -1) # (N, L, L, A*A) 69 | c = F.softplus(self.aapair_to_distcoef(aa_pair)) # (N, L, L, A*A) 70 | d_gauss = torch.exp(-1 * c * d ** 2) 71 | mask_atom_pair = (mask_atoms[:, :, None, :, None] * mask_atoms[:, None, :, None, :]).reshape(N, L, L, -1) 72 | feat_dist = self.distance_embed(d_gauss * mask_atom_pair) 73 | 74 | # Orientations 75 | dihed = pairwise_dihedrals(pos_atoms) # (N, L, L, 2) 76 | feat_dihed = self.dihedral_embed(dihed) 77 | # if torch.isnan(feat_dihed).any(): 78 | # print("Found nan in dihed!") 79 | # All 80 | feat_all = torch.cat([feat_aapair, feat_relpos, feat_dist, feat_dihed], dim=-1) 81 | feat_all = self.out_mlp(feat_all) # (N, L, L, F) 82 | feat_all = feat_all * mask_pair[:, :, :, None] 83 | if torch.isnan(feat_all).any(): 84 | # print("??:", torch.isnan(aa_pair).any()) 85 | print("Let's check:", torch.isnan(feat_aapair).any(), torch.isnan(feat_relpos).any(), torch.isnan(feat_dist).any(), torch.isnan(feat_dihed).any()) 86 | print("Let's check 2:", torch.isnan(aa).any(), torch.isnan(res_nb).any(), torch.isnan(chain_nb).any(), torch.isnan(pos_atoms).any(), torch.isnan(mask_atoms).any()) 87 | return feat_all 88 | 89 | -------------------------------------------------------------------------------- /data/transforms/select_atom.py: -------------------------------------------------------------------------------- 1 | 2 | from ._base import register_transform 3 | import torch 4 | from data.rna import get_backbone_coords, RNA_ATOMS 5 | 6 | 7 | @register_transform('select_atom') 8 | class SelectAtom(object): 9 | 10 | def __init__(self, resolution): 11 | super().__init__() 12 | assert resolution in ('full', 'backbone', 'backbone+R1', 'C_Only') 13 | self.resolution = resolution 14 | 15 | def __call__(self, data): 16 | if self.resolution == 'full': 17 | data['pos_atoms'] = data['pos_heavyatom'][:, :] 18 | data['mask_atoms'] = data['mask_heavyatom'][:, :] 19 | 20 | elif self.resolution == 'backbone': 21 | pos_atoms = torch.zeros([data['pos_heavyatom'].shape[0]] + [4, 3], device=data['pos_heavyatom'].device) 22 | mask_atoms = torch.zeros([data['mask_heavyatom'].shape[0]] + [4], device=data['mask_heavyatom'].device).bool() 23 | # print(pos_atoms.shape, mask_atoms.shape) 24 | pos_atoms[data['identifier']==0] = data['pos_heavyatom'][data['identifier']==0, :4] 25 | mask_atoms[data['identifier']==0] = data['mask_heavyatom'][data['identifier']==0, :4] 26 | na_atoms = data['pos_heavyatom'][data['identifier']==1, 14:] 27 | na_seqs = data['seq'] 28 | indices = torch.nonzero(data['identifier']) 29 | first_index = indices[0].item() if len(indices) > 0 else -1 30 | na_seqs = na_seqs[first_index:] 31 | 
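# Assemble a 4-atom RNA pseudo-backbone: P, C4', C1' and the glycosidic nitrogen (N1 for
# pyrimidines, N9 for purines); positions left at the fill value are treated as missing and
# masked out below.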
fill_value = 1e-5 32 | pyrimidine_bb_indices = [RNA_ATOMS.index("P"), RNA_ATOMS.index("C4'"), RNA_ATOMS.index("C1'"), RNA_ATOMS.index("N1")] 33 | purine_bb_indices = [RNA_ATOMS.index("P"), RNA_ATOMS.index("C4'"), RNA_ATOMS.index("C1'"), RNA_ATOMS.index("N9")] 34 | backbone_coords = get_backbone_coords(na_atoms, na_seqs, pyrimidine_bb_indices, purine_bb_indices, fill_value) 35 | pos_atoms[data['identifier']==1, :4] = backbone_coords 36 | mask_atoms[data['identifier']==1, :4] = (backbone_coords != fill_value).sum(-1).bool() 37 | data['pos_atoms'] = pos_atoms 38 | data['mask_atoms'] = mask_atoms 39 | 40 | elif self.resolution == 'backbone+R1': 41 | # For Protein, it's N,CA,C,O,CB; For RNA, it's P, C4', N1/N9, C1' 42 | pos_atoms = torch.zeros([data['pos_heavyatom'].shape[0]] + [5, 3], device=data['pos_heavyatom'].device) 43 | mask_atoms = torch.zeros([data['mask_heavyatom'].shape[0]] + [5], device=data['mask_heavyatom'].device).bool() 44 | pos_atoms[data['identifier']==0] = data['pos_heavyatom'][data['identifier']==0, :5] 45 | mask_atoms[data['identifier']==0] = data['mask_heavyatom'][data['identifier']==0, :5] 46 | 47 | na_atoms = data['pos_heavyatom'][data['identifier']==1, 14:] 48 | na_seqs = data['seq'] 49 | indices = torch.nonzero(data['identifier']) 50 | first_index = indices[0].item() if len(indices) > 0 else -1 51 | na_seqs = na_seqs[first_index:] 52 | fill_value = 1e-5 53 | pyrimidine_bb_indices = [RNA_ATOMS.index("P"), RNA_ATOMS.index("C4'"), RNA_ATOMS.index("C1'"), RNA_ATOMS.index("C5'"), RNA_ATOMS.index("N1")] 54 | purine_bb_indices = [RNA_ATOMS.index("P"), RNA_ATOMS.index("C4'"), RNA_ATOMS.index("C1'"), RNA_ATOMS.index("C5'"), RNA_ATOMS.index("N9")] 55 | backbone_coords = get_backbone_coords(na_atoms, na_seqs, pyrimidine_bb_indices, purine_bb_indices, fill_value) 56 | pos_atoms[data['identifier']==1, :5] = backbone_coords 57 | mask_atoms[data['identifier']==1, :5] = (backbone_coords != fill_value).sum(-1).bool() 58 | 59 | data['pos_atoms'] = data['pos_heavyatom'][:, :5] 60 | data['mask_atoms'] = data['mask_heavyatom'][:, :5] 61 | 62 | elif self.resolution == 'C_Only': 63 | data['pos_atoms'] = data['pos_heavyatom'][:, :1] 64 | data['mask_atoms'] = data['mask_heavyatom'][:, :1] 65 | 66 | return data 67 | -------------------------------------------------------------------------------- /data/transforms/mask.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from ._base import register_transform 7 | 8 | 9 | def _extend_mask(mask, chain_nb): 10 | """ 11 | Args: 12 | mask, chain_nb: (L, ). 
13 | """ 14 | # Shift right 15 | mask_sr = torch.logical_and( 16 | F.pad(mask[:-1], pad=(1, 0), value=0), 17 | (F.pad(chain_nb[:-1], pad=(1, 0), value=-1) == chain_nb) 18 | ) 19 | # Shift left 20 | mask_sl = torch.logical_and( 21 | F.pad(mask[1:], pad=(0, 1), value=0), 22 | (F.pad(chain_nb[1:], pad=(0, 1), value=-1) == chain_nb) 23 | ) 24 | return torch.logical_or(mask, torch.logical_or(mask_sr, mask_sl)) 25 | 26 | 27 | def _mask_sidechains(pos_atoms, mask_atoms, mask_idx): 28 | """ 29 | Args: 30 | pos_atoms: (L, A, 3) 31 | mask_atoms: (L, A) 32 | """ 33 | pos_atoms = pos_atoms.clone() 34 | pos_atoms[mask_idx, 4:] = 0.0 35 | 36 | mask_atoms = mask_atoms.clone() 37 | mask_atoms[mask_idx, 4:] = False 38 | return pos_atoms, mask_atoms 39 | 40 | 41 | @register_transform('random_mask_amino_acids') 42 | class RandomMaskAminoAcids(object): 43 | 44 | def __init__( 45 | self, 46 | mask_ratio_in_all=0.05, 47 | ratio_in_maskable_limit=0.5, 48 | mask_token=20, 49 | maskable_flag_attr='core_flag', 50 | extend_maskable_flag=False, 51 | mask_ratio_mode='constant', 52 | ): 53 | super().__init__() 54 | self.mask_ratio_in_all = mask_ratio_in_all 55 | self.ratio_in_maskable_limit = ratio_in_maskable_limit 56 | self.mask_token = mask_token 57 | self.maskable_flag_attr = maskable_flag_attr 58 | self.extend_maskable_flag = extend_maskable_flag 59 | assert mask_ratio_mode in ('constant', 'random') 60 | self.mask_ratio_mode = mask_ratio_mode 61 | 62 | def __call__(self, data): 63 | if self.maskable_flag_attr is None: 64 | maskable_flag = torch.ones([data['aa'].size(0), ], dtype=torch.bool) 65 | else: 66 | maskable_flag = data[self.maskable_flag_attr] 67 | if self.extend_maskable_flag: 68 | maskable_flag = _extend_mask(maskable_flag, data['chain_nb']) 69 | 70 | num_masked_max = math.ceil(self.mask_ratio_in_all * data['aa'].size(0)) 71 | if self.mask_ratio_mode == 'random': 72 | num_masked = random.randint(1, num_masked_max) 73 | else: 74 | num_masked = num_masked_max 75 | mask_idx = torch.multinomial( 76 | maskable_flag.float() / maskable_flag.sum(), 77 | num_samples=num_masked, 78 | ) 79 | mask_idx = mask_idx[:math.ceil(self.ratio_in_maskable_limit * maskable_flag.sum().item())] 80 | 81 | aa_masked = data['aa'].clone() 82 | aa_masked[mask_idx] = self.mask_token 83 | data['aa_true'] = data['aa'] 84 | data['aa_masked'] = aa_masked 85 | 86 | data['pos_atoms'], data['mask_atoms'] = _mask_sidechains( 87 | data['pos_atoms'], data['mask_atoms'], mask_idx 88 | ) 89 | 90 | return data 91 | 92 | 93 | @register_transform('mask_selected_amino_acids') 94 | class MaskSelectedAminoAcids(object): 95 | 96 | def __init__(self, select_attr, mask_token=20): 97 | super().__init__() 98 | self.select_attr = select_attr 99 | self.mask_token = mask_token 100 | 101 | def __call__(self, data): 102 | mask_flag = (data[self.select_attr] > 0) 103 | 104 | aa_masked = data['aa'].clone() 105 | aa_masked[mask_flag] = self.mask_token 106 | data['aa_true'] = data['aa'] 107 | data['aa_masked'] = aa_masked 108 | 109 | data['pos_atoms'], data['mask_atoms'] = _mask_sidechains( 110 | data['pos_atoms'], data['mask_atoms'], mask_flag 111 | ) 112 | 113 | return data 114 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | 
downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | datasets/* 165 | .vscode 166 | *.whl 167 | weights 168 | RiNALMo 169 | outputs 170 | cache 171 | wandb 172 | *.pdb 173 | # code_play.ipynb 174 | *.ipynb 175 | hugging_face -------------------------------------------------------------------------------- /pl_modules/data_module.py: -------------------------------------------------------------------------------- 1 | from data.sequence_dataset import CustomSeqCollate 2 | from data.structure_dataset import CustomStructCollate 3 | from data.pri30k_dataset import PRI30kStructCollate 4 | from data import DataRegister 5 | import pytorch_lightning as pl 6 | import diskcache 7 | import pandas as pd 8 | from torch.utils.data import DataLoader 9 | from torch_geometric.loader import DataLoader as GraphLoader 10 | 11 | def get_dataset(data_args:dict=None): 12 | register = DataRegister() 13 | dataset_cls = register[data_args.dataset_type] 14 | return dataset_cls 15 | 16 | def get_collate(dataset_type): 17 | collate_dict = {'sequence_dataset': CustomSeqCollate, 18 | 'structure_dataset': CustomStructCollate, 19 | 'pri30k_dataset': PRI30kStructCollate, 20 | } 21 | return collate_dict[dataset_type] 22 | 23 | class DataModule(pl.LightningDataModule): 24 | def __init__(self, 25 | df_path='', 26 | col_group='fold_0', 27 | batch_size=32, 28 | num_workers=0, 29 | pin_memory=True, 30 | cache_dir=None, 31 | strategy='separate', 32 | dataset_args=None, 33 | **kwargs): 34 | super().__init__() 35 | self.df_path = df_path 36 | self.col_group=col_group 37 | self.batch_size=batch_size 38 | self.num_workers=num_workers 39 | self.pin_memory=pin_memory 40 | self.cache_dir=cache_dir 41 | self.strategy=strategy 42 | self.dataset_args=dataset_args 43 | # print("Dataset Args:", dataset_args) 44 | 45 | def setup(self, stage=None): 46 | if self.cache_dir is None: 47 | cache = None 48 | else: 49 | print("Using diskcache at {}.".format(self.cache_dir)) 50 | cache = diskcache.Cache(directory=self.cache_dir, eviction_policy='none') 51 | 52 | df = pd.read_csv(self.df_path) 53 | df_train = df[df[self.col_group].isin(['train'])] 54 | df_val = df[df[self.col_group].isin(['val'])] 55 | df_test = df[df[self.col_group].isin(['test'])] 56 | dataset_cls = get_dataset(self.dataset_args) 57 | self.train_dataset = dataset_cls(df_train, **self.dataset_args, diskcache=cache) 58 | self.val_dataset = dataset_cls(df_val, **self.dataset_args, diskcache=cache) 59 
| 60 | 61 | if len(df_test) > 0 : 62 | print(f"Using Test Fold to test the model!") 63 | self.test_dataset = dataset_cls(df_test, **self.dataset_args, diskcache=cache) 64 | else: 65 | print(f"Using Validation Fold {self.col_group} to test the model!") 66 | self.test_dataset = dataset_cls(df_val, **self.dataset_args, diskcache=cache) 67 | 68 | def train_dataloader(self): 69 | if self.dataset_args.dataset_type != 'graph_dataset': 70 | collate = get_collate(self.dataset_args.dataset_type) 71 | return DataLoader( 72 | self.train_dataset, 73 | batch_size=self.batch_size, 74 | shuffle=True, 75 | num_workers=self.num_workers, 76 | pin_memory=self.pin_memory, 77 | persistent_workers=self.num_workers > 0, 78 | collate_fn=collate(strategy=self.strategy), 79 | ) 80 | else: 81 | return GraphLoader( 82 | self.train_dataset, 83 | batch_size=self.batch_size, 84 | shuffle=True, 85 | num_workers=self.num_workers, 86 | pin_memory=self.pin_memory, 87 | persistent_workers=self.num_workers > 0, 88 | ) 89 | 90 | def val_dataloader(self): 91 | if self.dataset_args.dataset_type != 'graph_dataset': 92 | collate = get_collate(self.dataset_args.dataset_type) 93 | return DataLoader( 94 | self.val_dataset, 95 | batch_size=self.batch_size, 96 | shuffle=False, 97 | num_workers=self.num_workers, 98 | pin_memory=self.pin_memory, 99 | persistent_workers=self.num_workers > 0, 100 | collate_fn=collate(strategy=self.strategy), 101 | ) 102 | else: 103 | return GraphLoader( 104 | self.val_dataset, 105 | batch_size=self.batch_size, 106 | shuffle=True, 107 | num_workers=self.num_workers, 108 | pin_memory=self.pin_memory, 109 | persistent_workers=self.num_workers > 0, 110 | ) 111 | 112 | 113 | def test_dataloader(self): 114 | if self.dataset_args.dataset_type != 'graph_dataset': 115 | collate = get_collate(self.dataset_args.dataset_type) 116 | return DataLoader( 117 | self.test_dataset, 118 | batch_size=self.batch_size, 119 | shuffle=False, 120 | num_workers=self.num_workers, 121 | pin_memory=self.pin_memory, 122 | persistent_workers=self.num_workers > 0, 123 | collate_fn=collate(strategy=self.strategy), 124 | ) 125 | else: 126 | return GraphLoader( 127 | self.test_dataset, 128 | batch_size=self.batch_size, 129 | shuffle=True, 130 | num_workers=self.num_workers, 131 | pin_memory=self.pin_memory, 132 | persistent_workers=self.num_workers > 0, 133 | ) 134 | 135 | -------------------------------------------------------------------------------- /models/encoders/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def mask_zero(mask, value): 6 | return torch.where(mask, value, torch.zeros_like(value)) 7 | 8 | 9 | class DistanceToBins(nn.Module): 10 | 11 | def __init__(self, dist_min=0.0, dist_max=20.0, num_bins=64, use_onehot=False): 12 | super().__init__() 13 | self.dist_min = dist_min 14 | self.dist_max = dist_max 15 | self.num_bins = num_bins 16 | self.use_onehot = use_onehot 17 | 18 | if use_onehot: 19 | offset = torch.linspace(dist_min, dist_max, self.num_bins) 20 | else: 21 | offset = torch.linspace(dist_min, dist_max, self.num_bins - 1) # 1 overflow flag 22 | self.coeff = -0.5 / ((offset[1] - offset[0]) * 0.2).item() ** 2 # `*0.2`: makes it not too blurred 23 | self.register_buffer('offset', offset) 24 | 25 | @property 26 | def out_channels(self): 27 | return self.num_bins 28 | 29 | def forward(self, dist, dim, normalize=True): 30 | """ 31 | Args: 32 | dist: (N, *, 1, *) 33 | Returns: 34 | (N, *, num_bins, *) 35 | """ 36 | assert 
dist.size()[dim] == 1 37 | offset_shape = [1] * len(dist.size()) 38 | offset_shape[dim] = -1 39 | 40 | if self.use_onehot: 41 | diff = torch.abs(dist - self.offset.view(*offset_shape)) # (N, *, num_bins, *) 42 | bin_idx = torch.argmin(diff, dim=dim, keepdim=True) # (N, *, 1, *) 43 | y = torch.zeros_like(diff).scatter_(dim=dim, index=bin_idx, value=1.0) 44 | else: 45 | overflow_symb = (dist >= self.dist_max).float() # (N, *, 1, *) 46 | y = dist - self.offset.view(*offset_shape) # (N, *, num_bins-1, *) 47 | y = torch.exp(self.coeff * torch.pow(y, 2)) # (N, *, num_bins-1, *) 48 | y = torch.cat([y, overflow_symb], dim=dim) # (N, *, num_bins, *) 49 | if normalize: 50 | y = y / y.sum(dim=dim, keepdim=True) 51 | 52 | return y 53 | 54 | 55 | class PositionalEncoding(nn.Module): 56 | 57 | def __init__(self, num_funcs=6): 58 | super().__init__() 59 | self.num_funcs = num_funcs 60 | self.register_buffer('freq_bands', 2.0 ** torch.linspace(0.0, num_funcs - 1, num_funcs)) 61 | 62 | def get_out_dim(self, in_dim): 63 | return in_dim * (2 * self.num_funcs + 1) 64 | 65 | def forward(self, x): 66 | """ 67 | Args: 68 | x: (..., d). 69 | """ 70 | shape = list(x.shape[:-1]) + [-1] 71 | x = x.unsqueeze(-1) # (..., d, 1) 72 | code = torch.cat([x, torch.sin(x * self.freq_bands), torch.cos(x * self.freq_bands)], dim=-1) # (..., d, 2f+1) 73 | code = code.reshape(shape) 74 | return code 75 | 76 | 77 | class AngularEncoding(nn.Module): 78 | 79 | def __init__(self, num_funcs=3): 80 | super().__init__() 81 | self.num_funcs = num_funcs 82 | self.register_buffer('freq_bands', torch.FloatTensor( 83 | [i + 1 for i in range(num_funcs)] + [1. / (i + 1) for i in range(num_funcs)] 84 | )) 85 | 86 | def get_out_dim(self, in_dim): 87 | return in_dim * (1 + 2 * 2 * self.num_funcs) 88 | 89 | def forward(self, x): 90 | """ 91 | Args: 92 | x: (..., d). 93 | """ 94 | shape = list(x.shape[:-1]) + [-1] 95 | x = x.unsqueeze(-1) # (..., d, 1) 96 | code = torch.cat([x, torch.sin(x * self.freq_bands), torch.cos(x * self.freq_bands)], dim=-1) # (..., d, 2f+1) 97 | code = code.reshape(shape) 98 | return code 99 | 100 | 101 | class LayerNorm(nn.Module): 102 | 103 | def __init__(self, 104 | normal_shape, 105 | gamma=True, 106 | beta=True, 107 | epsilon=1e-10): 108 | """Layer normalization layer 109 | See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf) 110 | :param normal_shape: The shape of the input tensor or the last dimension of the input tensor. 111 | :param gamma: Add a scale parameter if it is True. 112 | :param beta: Add an offset parameter if it is True. 113 | :param epsilon: Epsilon for calculating variance. 
114 | """ 115 | super().__init__() 116 | if isinstance(normal_shape, int): 117 | normal_shape = (normal_shape,) 118 | else: 119 | normal_shape = (normal_shape[-1],) 120 | self.normal_shape = torch.Size(normal_shape) 121 | self.epsilon = epsilon 122 | if gamma: 123 | self.gamma = nn.Parameter(torch.Tensor(*normal_shape)) 124 | else: 125 | self.register_parameter('gamma', None) 126 | if beta: 127 | self.beta = nn.Parameter(torch.Tensor(*normal_shape)) 128 | else: 129 | self.register_parameter('beta', None) 130 | self.reset_parameters() 131 | 132 | def reset_parameters(self): 133 | if self.gamma is not None: 134 | self.gamma.data.fill_(1) 135 | if self.beta is not None: 136 | self.beta.data.zero_() 137 | 138 | def forward(self, x): 139 | mean = x.mean(dim=-1, keepdim=True) 140 | var = ((x - mean) ** 2).mean(dim=-1, keepdim=True) 141 | std = (var + self.epsilon).sqrt() 142 | y = (x - mean) / std 143 | if self.gamma is not None: 144 | y *= self.gamma 145 | if self.beta is not None: 146 | y += self.beta 147 | return y 148 | 149 | def extra_repr(self): 150 | return 'normal_shape={}, gamma={}, beta={}, epsilon={}'.format( 151 | self.normal_shape, self.gamma is not None, self.beta is not None, self.epsilon, 152 | ) 153 | -------------------------------------------------------------------------------- /models/components/coformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from models.components.attention import MultiHeadSelfAttention, FlashMultiHeadSelfAttention 5 | 6 | 7 | from models.components.rope import RotaryPositionEmbedding 8 | 9 | import torch.utils.checkpoint as checkpoint 10 | 11 | class SwiGLU(nn.Module): 12 | """ 13 | Swish-Gated Linear Unit 14 | https://arxiv.org/pdf/2002.05202v1.pdf 15 | In the cited paper beta is set to 1 and is not learnable; 16 | but by the Swish definition it is learnable parameter otherwise 17 | it is SiLU activation function (https://paperswithcode.com/method/swish) 18 | """ 19 | def __init__(self, size_in, size_out, beta_is_learnable=True, bias=True): 20 | """ 21 | Args: 22 | size_in: input embedding dimension 23 | size_out: output embedding dimension 24 | beta_is_learnable: whether beta is learnable or set to 1, learnable by default 25 | bias: whether use bias term, enabled by default 26 | """ 27 | super().__init__() 28 | self.linear = nn.Linear(size_in, size_out, bias=bias) 29 | self.linear_gate = nn.Linear(size_in, size_out, bias=bias) 30 | self.beta = nn.Parameter(torch.ones(1), requires_grad=beta_is_learnable) 31 | 32 | def forward(self, x): 33 | linear_out = self.linear(x) 34 | swish_out = linear_out * torch.sigmoid(self.beta * linear_out) 35 | return swish_out * self.linear_gate(x) 36 | 37 | 38 | class CoFormer(nn.Module): 39 | def __init__(self, embed_dim, pair_dim, num_blocks, num_heads, use_rot_emb=True, attn_qkv_bias=False, transition_dropout=0.0, attention_dropout=0.0, residual_dropout=0.0, transition_factor=4, use_flash_attn=False): 40 | super().__init__() 41 | 42 | self.use_flash_attn = use_flash_attn 43 | 44 | self.blocks = nn.ModuleList( 45 | [ 46 | TransformerBlock(embed_dim, pair_dim, num_heads, use_rot_emb, attn_qkv_bias, transition_dropout, attention_dropout, residual_dropout, transition_factor, use_flash_attn) for _ in range(num_blocks) 47 | ] 48 | ) 49 | 50 | self.final_layer_norm = nn.LayerNorm(embed_dim) 51 | self.pair_final_layer_norm = nn.LayerNorm(pair_dim) 52 | 53 | def forward(self, x, struct_embed, 
key_padding_mask=None, need_attn_weights=False, attn_mask=None): 54 | attn_weights = None 55 | if need_attn_weights: 56 | attn_weights = [] 57 | 58 | for block in self.blocks: 59 | # x, struct_embed, attn = checkpoint.checkpoint( 60 | # block, 61 | # x, 62 | # struct_embed, 63 | # key_padding_mask, 64 | # need_attn_weights, 65 | # use_reentrant=False 66 | # ) 67 | x, struct_embed, attn = block(x, struct_embed, key_padding_mask, attn_mask) 68 | if need_attn_weights: 69 | attn_weights.append(attn) 70 | 71 | x = self.final_layer_norm(x) 72 | struct_embed = self.pair_final_layer_norm(struct_embed) 73 | return x, struct_embed, attn_weights 74 | 75 | class TransformerBlock(nn.Module): 76 | def __init__(self, embed_dim, pair_dim, num_heads, use_rot_emb=True, attn_qkv_bias=False, transition_dropout=0.0, attention_dropout=0.0, residual_dropout=0.0, transition_factor=4, use_flash_attn=False): 77 | super().__init__() 78 | 79 | self.use_flash_attn = use_flash_attn 80 | 81 | if use_flash_attn: 82 | self.mh_attn = FlashMultiHeadSelfAttention(embed_dim, num_heads, attention_dropout, causal=False, use_rot_emb=use_rot_emb, bias=attn_qkv_bias) 83 | else: 84 | self.mh_attn = MultiHeadSelfAttention(embed_dim, pair_dim, num_heads, attention_dropout, use_rot_emb, attn_qkv_bias) 85 | 86 | self.attn_layer_norm = nn.LayerNorm(embed_dim) 87 | 88 | self.transition = nn.Sequential( 89 | SwiGLU(embed_dim, int(2 / 3 * transition_factor * embed_dim), beta_is_learnable=True, bias=True), 90 | nn.Dropout(p=transition_dropout), 91 | nn.Linear(int(2 / 3 * transition_factor * embed_dim), embed_dim, bias=True), 92 | ) 93 | 94 | 95 | # self.transition_struct = nn.Sequential( 96 | # SwiGLU(embed_dim, int(2 / 3 * transition_factor * embed_dim), beta_is_learnable=True, bias=True), 97 | # nn.Dropout(p=transition_dropout), 98 | # nn.Linear(int(2 / 3 * transition_factor * embed_dim), embed_dim, bias=True), 99 | # ) 100 | 101 | self.out_layer_norm = nn.LayerNorm(embed_dim) 102 | self.pair_layer_norm = nn.LayerNorm(pair_dim) 103 | 104 | self.residual_dropout_1 = nn.Dropout(p=residual_dropout) 105 | self.residual_dropout_2 = nn.Dropout(p=residual_dropout) 106 | 107 | def forward(self, x, struct_embed, key_padding_mask=None, attn_mask=None): 108 | x = self.attn_layer_norm(x) 109 | # if self.use_flash_attn: 110 | # mh_out, attn = self.mh_attn(x, key_padding_mask=key_padding_mask, return_attn_probs=need_attn_weights) 111 | # else: 112 | # Temporarily unable flash_attn 113 | mh_out, struct_out, attn = self.mh_attn(x, struct_embed, attn_mask, key_pad_mask=key_padding_mask) 114 | x = x + self.residual_dropout_1(mh_out) 115 | struct_embed = struct_embed + self.residual_dropout_1(struct_out) 116 | residual = x 117 | struct_residual = struct_embed 118 | x = self.out_layer_norm(x) 119 | struct_embed = self.pair_layer_norm(struct_embed) 120 | x = residual + self.residual_dropout_2(self.transition(x)) 121 | struct_embed = struct_residual + self.residual_dropout_2(struct_embed) 122 | return x, struct_embed, attn -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🥥 CoPRA 2 | 3 |
4 |
5 | [Figure: CoPRA model overview (assets/model_overview.jpg)]
6 |
26 | CoPRA is a state-of-the-art predictor of protein-RNA binding affinity. The framework combines a protein language model and an RNA language model, with the complex structure as input. The model was pre-trained on the PRI30k dataset via a bi-scope strategy and fine-tuned on PRA310. CoPRA can also be redirected to predict mutation effects, showing strong per-structure prediction performance on the mCSM-RNA dataset. Please see more details in [our paper](https://arxiv.org/abs/2409.03773).
27 |
28 | Please do not hesitate to contact us or create an issue/PR if you have any questions or suggestions!
29 |
30 | ## 🛠️ Installation
31 |
32 | **Step 1**. Clone this repository and set up the environment. We recommend installing the dependencies with the fast package manager [mamba](https://mamba.readthedocs.io/en/latest/mamba-installation.html) (you can also replace the `mamba` command with `conda`). CoPRA is known to work with Python 3.10.14 and PyTorch 2.1.2.
33 | ```
34 | git clone git@github.com:hanrthu/CoPRA.git
35 | cd CoPRA
36 | mamba env create -f environment.yml
37 | ```
38 |
39 | **Step 2**. Install flash-attn and RiNALMo with the following commands. You may also need to download the RiNALMo-650M model and place it in the `./weights` folder of this repo.
40 | ```
41 | # Download flash-attn-2.6.3 wheel file at https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
42 | pip install flash_attn-2.6.3+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
43 | git clone git@github.com:lbcb-sci/RiNALMo.git
44 | cd RiNALMo
45 | pip install -e .
46 | ```
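
As a quick, optional sanity check of the environment (assuming the packages expose the usual `flash_attn` and `rinalmo` module names), you can try:
```
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import flash_attn, rinalmo; print('flash-attn and RiNALMo imported successfully')"
```
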
47 | ## 📖 Datasets and model weights for Protein-RNA binding affinity prediction
48 | Here, we provide our proposed datasets, including PRA310, PRA201 and PRI30k, together with the mCSM-RNA dataset; you can easily access them through 🤗Huggingface: [/Jesse7/CoPRA_data](https://huggingface.co/datasets/Jesse7/CoPRA_data/tree/main). The only difference between PRA201 and PRA310 is the set of selected samples, so the PRA201 labels and splits are provided in PRA310/splits/PRA201.csv. Download these datasets and place them in the `./datasets` folder.
49 |
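As a rough illustration (assuming the datasets are placed under `./datasets` and the split CSVs use the `fold_k` columns with `train`/`val`/`test` values expected by `pl_modules/data_module.py`; the exact path below may differ), you can inspect a split with pandas:
```
import pandas as pd

# Hypothetical local path; adjust it to wherever you placed the datasets.
df = pd.read_csv('./datasets/PRA310/splits/PRA201.csv')
print(df['fold_0'].value_counts())  # number of train/val/test samples in fold 0
```
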
50 | The number of samples in each dataset is shown below; we use PRA as the abbreviation of protein-RNA binding affinity:
51 |
52 | | Dataset | Type | Size |
53 | | :---: | :---: | :---: |
54 | | PRA310 | PRA | 310 |
55 | | PRA201 | PRA (pair-only) | 201 |
56 | | PRI30k | Unsupervised complexes | 30006 |
57 | | mCSM-RNA | Mutation effect on PRA | 79 |
58 |
59 |
60 | We also provide five-fold model checkpoints obtained by pre-training Co-Former on PRI30k and fine-tuning it on PRA310; they can be downloaded through 🤗Huggingface: [/Jesse7/CoPRA](https://huggingface.co/Jesse7/CoPRA). This repository also contains the pretrained RiNALMo-650M weights. Download these weights and place them in the `./weights` folder.
61 |
62 | The 5-fold cross-validation performance on PRA310 reaches the state of the art; the comparison is shown below:
63 |
64 |
65 | [Figure: comparison of 5-fold cross-validation results on PRA310 (assets/results_on_PRA.png)]
66 |
67 |
68 | ## 🚀 Training on the protein-RNA datasets
69 |
70 | **Note 1:** It is normal for the first training epoch on a new dataset to be relatively slow, because the caching procedure runs during it.
71 |
72 | **Note 2:** We also support LoRA tuning and all-parameter tuning. For LoRA tuning, specify `lora_tune: true` in `./config/models/copra.yml`; for all-parameter tuning, specify `fix_lms: false` in the same file (see the sketch below).
73 |
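A minimal sketch of the relevant keys (only the two options above are shown; their exact position inside `./config/models/copra.yml` may differ):
```
# LoRA tuning
lora_tune: true

# All-parameter tuning (alternative to LoRA)
fix_lms: false
```
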
74 | ### Run 5-fold inference on PRA310
75 | ```
76 | python run.py test dG --model_config ./config/models/copra.yml --data_config ./config/datasets/PRA310.yml --run_config ./config/runs/test_basic.yml
77 | ```
78 |
79 | ### Run finetune on PRA310
80 | ```
81 | python run.py finetune dG --model_config ./config/models/copra.yml --data_config ./config/datasets/PRA310.yml --run_config ./config/runs/finetune_struct.yml
82 | ```
83 |
84 | ### Run finetune on PRA201
85 | ```
86 | python run.py finetune dG --model_config ./config/models/copra.yml --data_config ./config/datasets/PRA201.yml --run_config ./config/runs/finetune_struct.yml
87 | ```
88 |
89 | ### Run Bi-scope Pre-training on PRI30k
90 | ```
91 | python run.py finetune pretune --model_config ./config/models/copra.yml --data_config ./config/datasets/PRI30k.yml --run_config ./config/runs/pretune_struct.yml
92 | ```
93 | After pretraining, you can continue to fine-tune on a new dataset with the finetuning scripts by pointing the `ckpt` field in `config/runs/finetune_struct.yml` to the pretrained checkpoint, for example:
94 |
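A minimal sketch of the change (the checkpoint path is a hypothetical example; point it to your own pretraining output):
```
# config/runs/finetune_struct.yml (excerpt)
ckpt: './outputs/esm2rinalmo_650M/pretune_struct/<your_pretrain_run>/checkpoint/last.ckpt'
```
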
95 | ## 🚀 Zero-shot Blind-test on the protein-RNA mutation effect datasets
96 |
97 | ```
98 | python run.py test ddG --model_config ./config/models/copra.yml --data_config ./config/datasets/blindtest.yml --run_config ./config/runs/zero_shot_blindtest.yml
99 | ```
100 |
101 | ## 🖌️ Citation
102 | If you find our repo useful, please kindly consider citing:
103 | ```
104 | @article{han2024copra,
105 | title={CoPRA: Bridging Cross-domain Pretrained Sequence Models with Complex Structures for Protein-RNA Binding Affinity Prediction},
106 | author={Han, Rong and Liu, Xiaohong and Pan, Tong and Xu, Jing and Wang, Xiaoyu and Lan, Wuyang and Li, Zhenyu and Wang, Zixuan and Song, Jiangning and Wang, Guangyu and others},
107 | journal={arXiv preprint arXiv:2409.03773},
108 | year={2024}
109 | }
110 | ```
--------------------------------------------------------------------------------
/models/components/valina_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from models.components.valina_attn import MultiHeadSelfAttention, FlashMultiHeadSelfAttention
6 |
7 | import torch.utils.checkpoint as checkpoint
8 |
9 | class TokenDropout(nn.Module):
10 | def __init__(
11 | self,
12 | active: bool,
13 | mask_ratio: float,
14 | mask_tkn_prob: float,
15 | mask_tkn_idx: int,
16 | pad_tkn_idx: int,
17 | ):
18 | super().__init__()
19 |
20 | self.active = active
21 |
22 | self.mask_ratio_train = mask_ratio * mask_tkn_prob
23 |
24 | self.mask_tkn_idx = mask_tkn_idx
25 | self.pad_tkn_idx = pad_tkn_idx
26 |
27 | def forward(self, x, tokens):
28 | if self.active:
29 | pad_mask = tokens.eq(self.pad_tkn_idx)
30 | src_lens = (~pad_mask).sum(dim=-1)
31 |
32 | x = torch.where((tokens == self.mask_tkn_idx).unsqueeze(dim=-1), 0.0, x)
33 | mask_ratio_observed = (tokens == self.mask_tkn_idx).sum(dim=-1) / src_lens
34 | x = x * (1 - self.mask_ratio_train) / (1 - mask_ratio_observed[..., None, None])
35 |
36 | return x
37 |
38 | class Transformer(nn.Module):
39 | def __init__(self, embed_dim, num_blocks, num_heads, use_rot_emb=True, attn_qkv_bias=False, transition_dropout=0.0, attention_dropout=0.0, residual_dropout=0.0, transition_factor=4, use_flash_attn=False):
40 | super().__init__()
41 |
42 | self.use_flash_attn = use_flash_attn
43 |
44 | self.blocks = nn.ModuleList(
45 | [
46 | TransformerBlock(embed_dim, num_heads, use_rot_emb, attn_qkv_bias, transition_dropout, attention_dropout, residual_dropout, transition_factor, use_flash_attn) for _ in range(num_blocks)
47 | ]
48 | )
49 |
50 | self.final_layer_norm = nn.LayerNorm(embed_dim)
51 |
52 | def forward(self, x, key_padding_mask=None, need_attn_weights=False):
53 | attn_weights = None
54 | if need_attn_weights:
55 | attn_weights = []
56 |
57 | for i, block in enumerate(self.blocks):
58 | # x, attn = checkpoint.checkpoint(
59 | # block,
60 | # x,
61 | # key_padding_mask=key_padding_mask,
62 | # need_attn_weights=need_attn_weights,
63 | # use_reentrant=False
64 | # )
65 |
66 | x, attn = block(x, key_padding_mask, need_attn_weights)
67 | if torch.isnan(x).any():
68 | print("Find Nan in X! at layer {}".format(i), x.shape)
69 | if need_attn_weights:
70 | attn_weights.append(attn)
71 |
72 | x = self.final_layer_norm(x)
73 |
74 | return x, attn_weights
75 |
76 | class SwiGLU(nn.Module):
77 | """
78 | Swish-Gated Linear Unit
79 | https://arxiv.org/pdf/2002.05202v1.pdf
80 | In the cited paper beta is fixed to 1 and is not learnable;
81 | by the Swish definition beta is a learnable parameter, and with
82 | beta fixed to 1 the function reduces to the SiLU activation (https://paperswithcode.com/method/swish)
83 | """
84 | def __init__(self, size_in, size_out, beta_is_learnable=True, bias=True):
85 | """
86 | Args:
87 | size_in: input embedding dimension
88 | size_out: output embedding dimension
89 | beta_is_learnable: whether beta is learnable or set to 1, learnable by default
90 | bias: whether to use a bias term, enabled by default
91 | """
92 | super().__init__()
93 | self.linear = nn.Linear(size_in, size_out, bias=bias)
94 | self.linear_gate = nn.Linear(size_in, size_out, bias=bias)
95 | self.beta = nn.Parameter(torch.ones(1), requires_grad=beta_is_learnable)
96 |
97 | def forward(self, x):
98 | linear_out = self.linear(x)
99 | swish_out = linear_out * torch.sigmoid(self.beta * linear_out)
100 | return swish_out * self.linear_gate(x)
101 |
102 | class TransformerBlock(nn.Module):
103 | def __init__(self, embed_dim, num_heads, use_rot_emb=True, attn_qkv_bias=False, transition_dropout=0.0, attention_dropout=0.0, residual_dropout=0.0, transition_factor=4, use_flash_attn=False):
104 | super().__init__()
105 |
106 | self.use_flash_attn = use_flash_attn
107 |
108 | if use_flash_attn:
109 | self.mh_attn = FlashMultiHeadSelfAttention(embed_dim, num_heads, attention_dropout, causal=False, use_rot_emb=use_rot_emb, bias=attn_qkv_bias)
110 | else:
111 | self.mh_attn = MultiHeadSelfAttention(embed_dim, num_heads, attention_dropout, use_rot_emb, attn_qkv_bias)
112 |
113 | self.attn_layer_norm = nn.LayerNorm(embed_dim)
114 |
115 | self.transition = nn.Sequential(
116 | SwiGLU(embed_dim, int(2 / 3 * transition_factor * embed_dim), beta_is_learnable=True, bias=True),
117 | nn.Dropout(p=transition_dropout),
118 | nn.Linear(int(2 / 3 * transition_factor * embed_dim), embed_dim, bias=True),
119 | )
120 | self.out_layer_norm = nn.LayerNorm(embed_dim)
121 |
122 | self.residual_dropout_1 = nn.Dropout(p=residual_dropout)
123 | self.residual_dropout_2 = nn.Dropout(p=residual_dropout)
124 |
125 | def forward(self, input, key_padding_mask=None, need_attn_weights=None):
126 | x = self.attn_layer_norm(input)
127 | if self.use_flash_attn:
128 | mh_out, attn = self.mh_attn(x, key_padding_mask=key_padding_mask, return_attn_probs=need_attn_weights)
129 | else:
130 | mh_out, attn = self.mh_attn(x, attn_mask=None, key_pad_mask=key_padding_mask)
131 | x = x + self.residual_dropout_1(mh_out)
132 |
133 | residual = x
134 | x = self.out_layer_norm(x)
135 | x = residual + self.residual_dropout_2(self.transition(x))
136 | if torch.isnan(x).any():
137 | print("Find Nan in X!", x.shape)
138 | return x, attn
139 |
140 | class MaskedLanguageModelHead(nn.Module):
141 | def __init__(self, embed_dim, alphabet_size):
142 | super().__init__()
143 |
144 | self.linear1 = nn.Linear(embed_dim, embed_dim)
145 | self.layer_norm = nn.LayerNorm(embed_dim)
146 | self.linear2 = nn.Linear(embed_dim, alphabet_size)
147 |
148 | def forward(self, x):
149 | x = self.linear1(x)
150 | x = F.gelu(x)
151 | x = self.layer_norm(x)
152 | x = self.linear2(x)
153 |
154 | return x
155 |
--------------------------------------------------------------------------------
/data/complex.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import dataclasses
3 | import sys
4 | from pathlib import Path
5 | import io
6 | import warnings
7 | warnings.filterwarnings("ignore")
8 | from dataclasses import dataclass
9 | sys.path.append('/home/CoPRA/')
10 | import data.protein.proteins as proteins
11 | from data.protein.atom_convert import atom37_to_atom14
12 | from data.protein.proteins import chains_from_cif_string, chains_from_pdb_string
13 | import data.rna.rnas as rnas
14 |
15 | rna_residues = ['A', 'G', 'C', 'U']
16 | protein_residues = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE', 'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL']
17 | SUPER_PROT_IDX = 27
18 | SUPER_RNA_IDX = 28
19 | SUPER_CPLX_IDX = 29
20 | SUPER_CHAIN_IDX = 4
21 | PADDING_NODE_IDX = 26
22 |
23 | @dataclass
24 | class ComplexInput:
25 | seq: str # L
26 | mask: np.ndarray # (L, )
27 | restype: np.ndarray # (L, ) # In total 21 + 5 = 26 types, including '_' and 'X'
28 | res_nb: np.ndarray # (L, )
29 | prot_seqs: list
30 | na_seqs: list
31 | atom_mask: np.ndarray # (L, 37 + 27)
32 | atom_positions: np.ndarray #(L, 37 + 27, 3)
33 |
34 | atom41_mask: np.ndarray # (L, 14 + 27)
35 | atom41_positions: np.ndarray #(L, 14 + 27, 3)
36 |
37 | identifier: np.ndarray #(L, ), to identify rna or protein
38 | chainid: np.ndarray # (L, ), to identify chain of the complex
39 |
40 | @classmethod
41 | def from_path(cls, path, valid_rna_chains=None, valid_prot_chains=None):
42 | if isinstance(path, io.IOBase):
43 | file_string = path.read()
44 | else:
45 | # print(path)
46 | path = Path(path)
47 | file_string = path.read_text()
48 | if valid_prot_chains is None or valid_rna_chains is None:
49 | valid_prot_chains = []
50 | valid_rna_chains = []
51 | if '.pdb' in str(path):
52 | chains = chains_from_pdb_string(file_string)
53 | elif '.cif' in str(path):
54 | chains = chains_from_cif_string(file_string)
55 |
56 | for chain in chains:
57 | for residue in chain:
58 | if residue.get_resname() in protein_residues:
59 | valid_prot_chains.append(chain.get_full_id()[2])
60 | break
61 | if residue.get_resname() in rna_residues:
62 | valid_rna_chains.append(chain.get_full_id()[2])
63 | break
64 | # valid_prot_chains = list(set(valid_prot_chains))
65 | # print(valid_prot_chains, valid_rna_chains)
66 | protein = proteins.ProteinInput.from_path(path, with_angles=False, return_dict=True, valid_chains=valid_prot_chains)
67 | rna = rnas.RNAInput.from_path(path, valid_rna_chains)
68 | complex_dict = complex_merge([protein[chain] for chain in valid_prot_chains], [rna[chain] for chain in valid_rna_chains])
69 | return cls(**complex_dict)
70 |
71 | @property
72 | def length(self):
73 | return len(self.seq)
74 |
75 | def __repr__(self):
76 | return self.__str__()
77 |
78 | def __str__(self):
79 | texts = []
80 | texts += [f'seq: {self.seq}']
81 | texts += [f'length: {len(self.seq)}']
82 | texts += [f"mask: {''.join(self.mask.astype('int').astype('str'))}"]
83 | if self.chainid is not None:
84 | texts += [f"chainid: {''.join(self.chainid.astype('int').astype('str'))}"]
85 | texts += [f"identifier: {''.join(self.identifier.astype('int').astype('str'))}"]
86 | names = [
87 | 'restype',
88 | 'atom_mask',
89 | 'atom_positions',
90 | ]
91 | for name in names:
92 | value = getattr(self, name)
93 | if value is None:
94 | text = f'{name}: None'
95 | else:
96 | text = f'{name}: {value.shape}'
97 | texts += [text]
98 | text = ', \n '.join(texts)
99 | text = f'Protein-RNA Complex(\n {text}\n)'
100 | return text
101 |
102 |
103 | def complex_merge(protein, rna):
104 | assert len(protein) > 0
105 | assert len(rna) > 0
106 |
107 | p_lengths = [p.length for i, p in enumerate(protein)]
108 | r_lengths = [r.length for i, r in enumerate(rna)]
109 | lengths = p_lengths + r_lengths
110 | prot_list = []
111 | na_list = []
112 | seq = "".join([item.seq for item in protein + rna])
113 | # identifier = np.array([0] * len(protein) + [1] * len(rna), dtype=np.int32)
114 | mask = np.concatenate([item.mask for item in protein + rna] )
115 | chain_arr = np.concatenate([[i] * p for i, p in enumerate(lengths)]).astype('int')
116 | res_nb = np.zeros([len(seq)])
117 | restype = np.zeros([len(seq)])
118 | atom_positions = np.zeros([len(seq), 37 + 27, 3])
119 | atom41_positions = np.zeros([len(seq), 14 + 27, 3])
120 | atom41_masks = np.zeros([len(seq), 14 + 27])
121 | atom_masks = np.zeros([len(seq), 37 + 27])
122 | identifier = np.zeros([len(seq)])
123 | curr_idx = 0
124 | for item in protein:
125 | prot_list.append(item.seq)
126 | res_nb[curr_idx: curr_idx+item.length] = item.res_nb
127 | restype[curr_idx: curr_idx+item.length] = item.aatype
128 | identifier[curr_idx: curr_idx+item.length] = 0
129 | atom_positions[curr_idx: curr_idx+item.length, :37, :] = item.atom_positions
130 | atom14, mask_14, arrs = atom37_to_atom14(item.aatype, item.atom_positions, [item.atom_mask])
131 | mask_14 = arrs[0] * mask_14
132 | atom41_positions[curr_idx: curr_idx+item.length, :14, :] = atom14
133 | atom41_masks[curr_idx: curr_idx+item.length, :14] = mask_14
134 | atom_masks[curr_idx: curr_idx+item.length, :37] = item.atom_mask
135 | curr_idx += item.length
136 | for item in rna:
137 | na_list.append(item.seq)
138 | res_nb[curr_idx: curr_idx+item.length] = item.res_nb
139 | restype[curr_idx: curr_idx+item.length] = item.basetype + 21
140 | identifier[curr_idx: curr_idx+item.length] = 1
141 | atom_positions[curr_idx: curr_idx+item.length, 37:, :] = item.atom_positions
142 | atom41_positions[curr_idx: curr_idx+item.length, 14:, :] = item.atom_positions
143 | atom41_masks[curr_idx: curr_idx+item.length, 14:] = item.atom_mask
144 | atom_masks[curr_idx: curr_idx+item.length, 37:] = item.atom_mask
145 | curr_idx += item.length
146 |
147 | complex_dict = {
148 | 'seq': seq,
149 | 'mask': mask,
150 | 'restype': restype,
151 | 'res_nb': res_nb,
152 |
153 | 'prot_seqs': prot_list,
154 | 'na_seqs': na_list,
155 |
156 | 'atom_mask': atom_masks,
157 | 'atom_positions': atom_positions,
158 |
159 | 'atom41_mask': atom41_masks,
160 | 'atom41_positions': atom41_positions,
161 |
162 | 'identifier': identifier,
163 | 'chainid': chain_arr
164 | }
165 |
166 | return complex_dict
167 |
168 | if __name__ == '__main__':
169 | comp = ComplexInput.from_path('./datasets/PRA310/PDBs/1RPU.pdb')
170 | print("Complex:", comp)
171 | print(comp.atom_positions[1])
--------------------------------------------------------------------------------
/data/rna/sec_struct_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | from datetime import datetime
4 | import numpy as np
5 | import wandb
6 | from typing import Any, List, Literal, Optional
7 |
8 | from Bio import SeqIO
9 | from Bio.Seq import Seq
10 | from Bio.SeqRecord import SeqRecord
11 |
12 | import biotite
13 | from biotite.structure.io import load_structure
14 | from biotite.structure import dot_bracket_from_structure
15 |
16 | from data.rna.base_constants import (
17 | PROJECT_PATH,
18 | X3DNA_PATH,
19 | ETERNAFOLD_PATH,
20 | DOTBRACKET_TO_NUM
21 | )
22 |
23 |
24 | def pdb_to_sec_struct(
25 | pdb_file_path: str,
26 | sequence: str,
27 | keep_pseudoknots: bool = False,
28 | x3dna_path: str = os.path.join(X3DNA_PATH, "bin/find_pair"),
29 | max_len_for_biotite: int = 1000,
30 | ) -> str:
31 | """
32 | Get secondary structure in dot-bracket notation from a PDB file.
33 |
34 | Args:
35 | pdb_file_path (str): Path to PDB file.
36 | sequence (str): Sequence of RNA molecule.
37 | keep_pseudoknots (bool, optional): Whether to keep pseudoknots in
38 | secondary structure. Defaults to False.
39 | x3dna_path (str, optional): Path to x3dna find_pair tool.
40 | max_len_for_biotite (int, optional): Maximum length of sequence for
41 | which to use biotite; otherwise, use X3DNA. Defaults to 1000.
42 | """
43 | if len(sequence) < max_len_for_biotite:
44 | try:
45 | # get secondary structure using biotite
46 | atom_array = load_structure(pdb_file_path)
47 | sec_struct = dot_bracket_from_structure(atom_array)[0]
48 | if not keep_pseudoknots:
49 | # replace all characters that are not '.', '(', ')' with '.'
50 | sec_struct = "".join([dotbrac if dotbrac in ['.', '(', ')'] else '.' for dotbrac in sec_struct])
51 |
52 | except Exception as e:
53 | # biotite fails for very short sequences
54 | if "out of bounds for array" not in str(e): raise e
55 | # get secondary structure using x3dna find_pair tool
56 | # does not support pseudoknots
57 | sec_struct = x3dna_to_sec_struct(
58 | pdb_to_x3dna(pdb_file_path, x3dna_path),
59 | sequence
60 | )
61 |
62 | else:
63 | # get secondary structure using x3dna find_pair tool
64 | # does not support pseudoknots
65 | sec_struct = x3dna_to_sec_struct(
66 | pdb_to_x3dna(pdb_file_path, x3dna_path),
67 | sequence
68 | )
69 |
70 | return sec_struct
71 |
72 | def pdb_to_x3dna(
73 | pdb_file_path: str,
74 | x3dna_path: str = os.path.join(X3DNA_PATH, "bin/find_pair")
75 | ) -> List[str]:
76 | # Run x3dna find_pair tool
77 | cmd = [
78 | x3dna_path,
79 | pdb_file_path,
80 | ]
81 | output = subprocess.run(cmd, check=True, capture_output=True).stdout.decode("utf-8")
82 | output = output.split("\n")
83 |
84 | # Delete temporary files
85 | # os.remove("./bestpairs.pdb")
86 | # os.remove("./bp_order.dat")
87 | # os.remove("./col_chains.scr")
88 | # os.remove("./col_helices.scr")
89 | # os.remove("./hel_regions.pdb")
90 | # os.remove("./ref_frames.dat")
91 |
92 | return output
93 |
94 |
95 | def x3dna_to_sec_struct(output: List[str], sequence: str) -> str:
96 | # Secondary structure in dot-bracket notation
97 | num_base_pairs = int(output[3].split()[0])
98 | sec_struct = ["."] * len(sequence)
99 | for i in range(1, num_base_pairs+1):
100 | line = output[4 + i].split()
101 | start, end = int(line[0]), int(line[1])
102 | sec_struct[start-1] = "("
103 | sec_struct[end-1] = ")"
104 | return "".join(sec_struct)
105 |
106 |
107 | def predict_sec_struct(
108 | sequence: Optional[str] = None,
109 | fasta_file_path: Optional[str] = None,
110 | eternafold_path: str = os.path.join(ETERNAFOLD_PATH, "src/contrafold"),
111 | n_samples: int = 1,
112 | ) -> str:
113 | """
114 | Predict secondary structure using EternaFold.
115 |
116 | Notes:
117 | - EternaFold does not support pseudoknots.
118 | - EternaFold only supports single chains in a fasta file.
119 | - When sampling multiple structures, EternaFold only supports nsamples=100.
120 |
121 | Args:
122 | sequence (str, optional): Sequence of RNA molecule. Defaults to None.
123 | fasta_file_path (str, optional): Path to fasta file. Defaults to None.
124 | eternafold_path (str, optional): Path to EternaFold. Defaults to ETERNAFOLD_PATH env variable.
125 | n_samples (int, optional): Number of samples to take. Defaults to 1.
126 | """
127 | if sequence is not None:
128 | assert fasta_file_path is None
129 | # Write sequence to temporary fasta file
130 | current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
131 | try:
132 | fasta_file_path = os.path.join(wandb.run.dir, f"temp_{current_datetime}.fasta")
133 | except AttributeError:
134 | fasta_file_path = os.path.join(PROJECT_PATH, f"temp_{current_datetime}.fasta")
135 | SeqIO.write(
136 | SeqRecord(Seq(sequence), id="temp"),
137 | fasta_file_path, "fasta"
138 | )
139 |
140 | # Run EternaFold
141 | if n_samples > 1:
142 | assert n_samples == 100, "EternaFold using subprocess only supports nsamples=100"
143 | cmd = [
144 | eternafold_path,
145 | "sample",
146 | fasta_file_path,
147 | # f" --nsamples {n_samples}",
148 | # It seems like EternaFold using subprocess can only sample the default nsamples=100...
149 | # Reason: unknown for now
150 | ]
151 | else:
152 | cmd = [
153 | eternafold_path,
154 | "predict",
155 | fasta_file_path,
156 | ]
157 |
158 | output = subprocess.run(cmd, check=True, capture_output=True).stdout.decode("utf-8")
159 |
160 | # Delete temporary files
161 | if sequence is not None:
162 | os.remove(fasta_file_path)
163 |
164 | if n_samples > 1:
165 | return output.split("\n")[:-1]
166 | else:
167 | return [output.split("\n")[-2]]
168 |
169 |
170 | def dotbracket_to_paired(sec_struct: str) -> np.ndarray:
171 | """
172 | Return whether each residue is paired (1) or unpaired (0) given
173 | secondary structure in dot-bracket notation.
174 | """
175 | is_paired = np.zeros(len(sec_struct), dtype=np.int8)
176 | for i, c in enumerate(sec_struct):
177 | if c == '(' or c == ')':
178 | is_paired[i] = 1
179 | return is_paired
180 |
181 |
182 | def dotbracket_to_num(sec_struct: str) -> np.ndarray:
183 | """
184 | Convert secondary structure in dot-bracket notation to
185 | numerical representation.
186 | """
187 | return np.array([DOTBRACKET_TO_NUM[c] for c in sec_struct])
188 |
189 |
190 | def dotbracket_to_adjacency(sec_struct: str) -> np.ndarray:
191 | """
192 | Convert secondary structure in dot-bracket notation to
193 | adjacency matrix.
194 | """
195 | n = len(sec_struct)
196 | adj = np.zeros((n, n), dtype=np.int8)
197 | stack = []
198 | for i, db_char in enumerate(sec_struct):
199 | if db_char == '(':
200 | stack.append(i)
201 | elif db_char == ')':
202 | j = stack.pop()
203 | adj[i, j] = 1
204 | adj[j, i] = 1
205 | return adj
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | os.environ["NUMEXPR_MAX_THREADS"] = '56'
4 | os.environ["MKL_NUM_THREADS"] = '4'
5 | os.environ["OMP_NUM_THREADS"] = '4'
6 | import fire
7 | from pathlib import Path
8 | import pandas as pd
9 |
10 | import numpy as np
11 | import yaml
12 | import wandb
13 | import time
14 | from easydict import EasyDict
15 | import torch
16 | import pytorch_lightning as pl
17 | from pytorch_lightning.loggers import CSVLogger, WandbLogger
18 | from pytorch_lightning.callbacks import TQDMProgressBar, EarlyStopping, ModelCheckpoint, ModelSummary
19 | from pytorch_lightning.strategies.ddp import DDPStrategy
20 | from pl_modules import ModelModule, DataModule, PretuneModule, DDGModule
21 | from collections import defaultdict
22 |
23 | torch.set_num_threads(16)
24 |
25 | def parse_yaml(yaml_dir):
26 | with open(yaml_dir, 'r') as f:
27 | content = f.read()
28 | config_dict = EasyDict(yaml.load(content, Loader=yaml.FullLoader))
29 | # args = Namespace(**config_dict)
30 | return config_dict
31 | def init_pytorch_settings():
32 | # Multiprocess Setting to speedup dataloader
33 | torch.multiprocessing.set_start_method('forkserver')
34 | torch.multiprocessing.set_sharing_strategy('file_system')
35 | # torch.set_float32_matmul_precision('high')
36 | torch.set_num_threads(4)
37 | torch.backends.cuda.matmul.allow_tf32 = True
38 | torch.backends.cudnn.allow_tf32 = True
39 |
40 | class LightningRunner(object):
41 | def __init__(self, model_config='./config/models/esm2_rinalmo.yaml', data_config='./config/datasets/rpi.yaml',
42 | run_config='./config/runs/finetune_sequence.yaml'):
43 | super(LightningRunner, self).__init__()
44 | self.model_args = parse_yaml(model_config)
45 | self.dataset_args = parse_yaml(data_config)
46 | self.run_args = parse_yaml(run_config)
47 | init_pytorch_settings()
48 |
49 | def save_model(self, model, output_dir, trainer):
50 | print("Best Model Path:", trainer.checkpoint_callback.best_model_path)
51 | module = ModelModule.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
52 | if trainer.global_rank == 0:
53 | best_model = module.model
54 | (output_dir / 'model_data.json').write_text(json.dumps(vars(self.dataset_args), indent=2))
55 | torch.save(best_model, str(output_dir / 'model.pt'))
56 |
57 | def select_module(self, stage, log_dir):
58 | if stage=='pretune':
59 | model = PretuneModule(output_dir=log_dir, model_args=self.model_args, data_args=self.dataset_args, run_args=self.run_args)
60 | elif stage=='dG':
61 | model = ModelModule(output_dir=log_dir, model_args=self.model_args, data_args=self.dataset_args, run_args=self.run_args)
62 | elif stage=='ddG':
63 | model = DDGModule(output_dir=log_dir, model_args=self.model_args, data_args=self.dataset_args, run_args=self.run_args)
64 | else:
65 | raise NotImplementedError
66 | return model
67 |
68 | def finetune(self, stage='dG'):
69 | print("Run args:", self.run_args, "\n")
70 | print("Model args:", self.model_args, "\n")
71 | print("Dataset args:", self.dataset_args, "\n")
72 | output_dir, gpus = (self.run_args.output_dir, self.run_args.gpus)
73 | self.model_args.model.stage = stage
74 | # Setup datamodule
75 | run_results = []
76 | for k in range(self.run_args.num_folds):
77 | # if k != 4:
78 | # continue
79 | print(f"Training fold {k} Started!")
80 | output_dir = Path(output_dir)
81 | log_dir = output_dir / f'log_fold_{k}'
82 | data_module = DataModule(dataset_args=self.dataset_args, **self.dataset_args, col_group=f'fold_{k}')
83 |
84 | # Setup model module
85 | model = self.select_module(stage, log_dir)
86 | # Trainer setting
87 | name = self.run_args.run_name + time.strftime("%Y-%m-%d-%H-%M-%S")
88 | if self.run_args.wandb:
89 | wandb.init(project='copra', name=name)
90 | logger = WandbLogger()
91 | else:
92 | logger = CSVLogger(str(log_dir))
93 | # version_dir = Path(logger_csv.log_dir)
94 | pl.seed_everything(self.model_args.train.seed)
95 | print("Successfully initialized, start trainer...")
96 | strategy=DDPStrategy(find_unused_parameters=True)
97 | # strategy.lightning_restore_optimizer = False
98 | trainer = pl.Trainer(
99 | devices=gpus,
100 | # max_steps=self.run_args.iters,
101 | max_epochs=self.run_args.epochs,
102 | logger=logger,
103 | callbacks=[
104 | EarlyStopping(monitor="val_loss", mode="min", patience=self.run_args.patience, strict=False),
105 | ModelCheckpoint(dirpath=(log_dir / 'checkpoint'), filename='{epoch}-{val_loss:.3f}',
106 | monitor="val_loss", mode="min", save_last=True, save_top_k=3),
107 | ],
108 | # gradient_clip_val=self.model_args.train.max_grad_norm if self.model_args.train.max_grad_norm is not None else None,
109 | # gradient_clip_algorithm='norm' if self.model_args.train.max_grad_norm is not None else None,
110 | strategy=strategy,
111 | log_every_n_steps=3,
112 | )
113 | trainer.fit(model=model, datamodule=data_module, ckpt_path=self.run_args.ckpt)
114 | print(f"Training fold {k} Finished!")
115 | trainer.strategy.barrier()
116 | print("Best Validation Results:")
117 | _ = trainer.test(model=model, ckpt_path="best", datamodule=data_module)
118 | res = model.res
119 | run_results.append(res)
120 | if trainer.global_rank == 0:
121 | self.save_model(model, output_dir, trainer)
122 | result_dir = Path(output_dir) / name
123 | os.makedirs(result_dir, exist_ok=True)
124 | with open(result_dir / 'res.json', 'w') as f:
125 | json.dump(run_results, f)
126 | results_df = pd.DataFrame(run_results)
127 | print(results_df.describe())
128 |
129 | def test(self, stage='dG'):
130 | print("Args:", self.run_args, self.dataset_args, self.model_args)
131 | output_dir, ckpts, gpus = (self.run_args.output_dir, self.run_args.ckpts,
132 | self.run_args.gpus)
133 | run_results = []
134 | for k in range(self.run_args.num_folds):
135 | output_dir = Path(output_dir)
136 | log_dir = output_dir / f'log_fold_{k}'
137 | data_module = DataModule(dataset_args=self.dataset_args, **self.dataset_args, col_group=f'fold_{k}')
138 | # data_module.setup()
139 | model = self.select_module(stage, log_dir)
140 | logger = CSVLogger(str(log_dir))
141 | strategy=DDPStrategy(find_unused_parameters=True)
142 | # strategy.lightning_restore_optimizer = False
143 | trainer = pl.Trainer(
144 | devices=gpus,
145 | max_epochs=0,
146 | logger=[
147 | logger,
148 | ],
149 | callbacks=[
150 | TQDMProgressBar(refresh_rate=1),
151 | ],
152 | strategy=strategy,
153 | )
154 |
155 | _ = trainer.test(model=model, ckpt_path=ckpts[k], datamodule=data_module)
156 | res = model.res
157 | run_results.append(res)
158 | if trainer.global_rank == 0:
159 | results_df = pd.DataFrame(run_results)
160 | print(results_df.describe())
161 |
162 |
163 | if __name__ == '__main__':
164 | fire.Fire(LightningRunner)
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import argparse
4 | import numpy as np
5 | import pandas as pd
6 | import scipy.stats as stats
7 | from sklearn.metrics import roc_auc_score
8 | import torch.nn.functional as F
9 | from sklearn.metrics import mean_squared_error
10 | from models.components.loss import PearsonCorrLoss
11 |
12 |
13 | class BlackHole(object):
14 | def __setattr__(self, name, value):
15 | pass
16 |
17 | def __call__(self, *args, **kwargs):
18 | return self
19 |
20 | def __getattr__(self, name):
21 | return self
22 |
23 |
24 | class ScalarMetricAccumulator(object):
25 |
26 | def __init__(self):
27 | super().__init__()
28 | self.accum_dict = {}
29 | self.count_dict = {}
30 |
31 | @torch.no_grad()
32 | def add(self, name, value, batchsize=None, mode=None):
33 | assert mode is None or mode in ('mean', 'sum')
34 |
35 | if mode is None:
36 | delta = value.sum()
37 | count = value.size(0)
38 | elif mode == 'mean':
39 | delta = value * batchsize
40 | count = batchsize
41 | elif mode == 'sum':
42 | delta = value
43 | count = batchsize
44 | delta = delta.item() if isinstance(delta, torch.Tensor) else delta
45 |
46 | if name not in self.accum_dict:
47 | self.accum_dict[name] = 0
48 | self.count_dict[name] = 0
49 | self.accum_dict[name] += delta
50 | self.count_dict[name] += count
51 |
52 | def log(self, it, tag, logger=BlackHole(), writer=BlackHole()):
53 | summary = {k: self.accum_dict[k] / self.count_dict[k] for k in self.accum_dict}
54 | logstr = '[%s] Iter %05d' % (tag, it)
55 | for k, v in summary.items():
56 | logstr += ' | %s %.4f' % (k, v)
57 | writer.add_scalar('%s/%s' % (tag, k), v, it)
58 | logger.info(logstr)
59 |
60 | def get_average(self, name):
61 | return self.accum_dict[name] / self.count_dict[name]
62 |
63 |
64 | def per_complex_corr(df, pred_attr='y_pred', true_attr='y_true', limit=10):
65 | corr_table = []
66 | for cplx in df['complex'].unique():
67 | df_cplx = df.query(f'complex == "{cplx}"')
68 | if len(df_cplx) <= 2:
69 | continue
70 | # if len(df_cplx) < limit:
71 | # continue
72 | y_pred = np.array(df_cplx[pred_attr])
73 | y_true = np.array(df_cplx[true_attr])
74 | # print("HIIII!", cal_pearson(y_pred, y_true))
75 | # print("Pred_and True:", y_pred, y_true)
76 | corr_table.append({
77 | 'complex': cplx,
78 | 'pearson': abs(cal_pearson(y_pred, y_true)),
79 | 'spearman': abs(cal_spearman(y_pred, y_true)),
80 | 'rmse': cal_rmse(y_pred, y_true),
81 | 'mae': cal_mae(y_pred, y_true)
82 | })
83 | # print("Corr_tabel:", corr_table)
84 | corr_table = pd.DataFrame(corr_table)
85 | corr_table = corr_table.fillna(0)
86 | avg = corr_table[['pearson', 'spearman', 'rmse', 'mae']].mean()
87 | return avg['pearson'], avg['spearman'], avg['rmse'], avg['mae']
88 |
89 |
90 | def per_complex_acc(df, pred_attr='y_pred', true_attr='y_true', limit=10):
91 | acc_table = []
92 | for cplx in df['complex'].unique():
93 | df_cplx = df.query(f'complex == "{cplx}"')
94 | if len(df_cplx) <= 2:
95 | continue
96 | y_pred = np.array(df_cplx[pred_attr])
97 | y_true = np.array(df_cplx[true_attr])
98 | acc_table.append({
99 | 'complex': cplx,
100 | 'accuracy': cal_accuracy(y_pred, y_true),
101 | # 'auc': cal_auc(y_pred, y_true),
102 | 'precision': cal_precision(y_pred, y_true),
103 | 'recall': cal_recall(y_pred, y_true)
104 | })
105 | acc_table = pd.DataFrame(acc_table)
106 | # avg = acc_table[['accuracy', 'auc', 'precision', 'recall']].mean()
107 | # return avg['accuracy'], avg['auc'], avg['precision'], avg['recall']
108 | avg = acc_table[['accuracy', 'precision', 'recall']].mean()
109 | return avg['accuracy'], avg['precision'], avg['recall']
110 |
111 |
112 | def sum_weighted_losses(losses, weights):
113 | """
114 | Args:
115 | losses: Dict of scalar tensors.
116 | weights: Dict of weights.
117 | """
118 | loss = 0
119 | for k in losses.keys():
120 | if weights is None:
121 | loss = loss + losses[k]
122 | else:
123 | loss = loss + weights[k] * losses[k]
124 | return loss
125 |
126 |
127 | def get_loss(loss_type, pred, y, reduction='none'):
128 | if loss_type == 'regression':
129 | # criterion = PearsonCorrLoss()
130 | # losses = F.huber_loss(pred, y, delta=2, reduction=reduction)
131 | losses = F.mse_loss(pred, y, reduction=reduction)
132 | # print("MSE Loss:", F.mse_loss(pred, y, reduction=reduction))
133 | # print("Pearson Loss:", criterion(pred, y))
134 | elif loss_type == 'binary':
135 | losses = F.binary_cross_entropy_with_logits(pred, y, reduction=reduction)
136 | else:
137 | raise NotImplementedError("Loss Not Implemented!")
138 | return losses
139 |
140 |
141 | def cal_weighted_loss(pred_dict, y, mask, loss_types, loss_weights):
142 | loss_list = []
143 | y_pred = pred_dict['y_pred']
144 | # print("Y_pred:", y_pred.shape, y.shape, mask.shape)
145 | if len(y_pred.shape) > 1:
146 | assert y_pred.shape[1] == len(loss_types)
147 | else:
148 | y_pred = y_pred.unsqueeze(1)
149 | for i, l_type in enumerate(loss_types):
150 | y_pred_i = y_pred[:, i]
151 | y_i = y[:, i]
152 | mask_i = mask[:, i]
153 | l_i = get_loss(l_type, y_pred_i, y_i) * float(loss_weights[i])
154 | if 'y_pred_inv' in pred_dict and l_type == 'regression' and i == 0:
155 | # Only ddG task has the inversion property
156 | y_pred_inv = pred_dict['y_pred_inv']
157 | if len(y_pred_inv.shape) == 1:
158 | y_pred_inv = y_pred_inv.unsqueeze(1)
159 | y_pred_inv_i = y_pred_inv[:, i]
160 | l_i_inv = get_loss(l_type, y_pred_inv_i, -y_i) * float(loss_weights[i])
161 | l_i = (l_i + l_i_inv) / 2
162 | loss_list.append(l_i)
163 | losses = torch.stack(loss_list, dim=-1)
164 | loss = (losses * mask).sum() / (mask.sum().clip(min=1))
165 | return loss
166 |
167 |
168 | def cal_pearson(pred, gt):
169 | # print("Pearson Cal:", np.unique(pred.shape), np.unique(gt.shape))
170 | if np.isnan(stats.pearsonr(pred, gt).statistic):
171 | print("Pearson Cal:", pred, gt)
172 | return stats.pearsonr(pred, gt).statistic
173 |
174 | def cal_spearman(pred, gt):
175 | if np.isnan(stats.spearmanr(pred, gt).statistic):
176 | print("SPearman Cal:", pred, gt)
177 | return stats.spearmanr(pred, gt).statistic
178 |
179 | def cal_rmse(pred, gt):
180 | return math.sqrt(mean_squared_error(pred, gt))
181 |
182 | def cal_mae(pred, gt):
183 | return np.abs(pred-gt).sum() / len(pred)
184 |
185 | def cal_accuracy(pred, gt, thres=0.5):
186 | logits = 1 / (1+np.exp(-pred))
187 | binary = np.where(logits>=thres, 1, 0)
188 | acc = (binary == gt).sum() / np.size(binary)
189 | return acc
190 |
191 | def cal_auc(pred, gt):
192 | logits = 1 / (1+np.exp(-pred))
193 | if len(np.unique(gt)) == 1:
194 | return 0
195 | return roc_auc_score(gt.astype(np.int32), logits)
196 |
197 | def cal_precision(pred, gt, thres=0.5):
198 | logits = 1 / (1+np.exp(-pred))
199 | binary = np.where(logits>=thres, 1, 0)
200 | true_positives = (binary == 1) & (gt == 1)
201 | positive_preds = binary == 1
202 | return np.sum(true_positives) / np.sum(positive_preds) if np.sum(positive_preds) > 0 else 0
203 |
204 | def cal_recall(pred, gt, thres=0.5):
205 | logits = 1 / (1+np.exp(-pred))
206 | binary = np.where(logits>=thres, 1, 0)
207 | true_positives = (binary == 1) & (gt == 1)
208 | actual_positives = (gt == 1)
209 | return np.sum(true_positives) / np.sum(actual_positives) if np.sum(actual_positives) > 0 else 0
210 |
211 |
212 |
213 |
214 |
215 | if __name__ == '__main__':
216 | parser = argparse.ArgumentParser()
217 | parser.add_argument('--pred_dir', default=None, type=str)
218 | args = parser.parse_args()
219 | df = pd.read_csv(args.pred_dir)
220 | pred = df['pred_0']
221 | gt = df['gt_0']
222 | pearson = cal_pearson(pred, gt)
223 | spearman = cal_spearman(pred, gt)
224 | rmse = cal_rmse(pred, gt)
225 | mae = cal_mae(pred, gt)
226 | print("Pearson Score is {}".format(pearson))
227 | print("Spearman Score is {}".format(spearman))
228 | print("RMSE Score is {}".format(rmse))
229 | print("MAE Score is {}".format(mae))
230 |
--------------------------------------------------------------------------------
/models/encoders/attn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | from utils.geometry import global_to_local, local_to_global, normalize_vector, construct_3d_basis, angstrom_to_nm
7 | from models.encoders.layers import mask_zero, LayerNorm
8 |
9 |
10 | def _alpha_from_logits(logits, mask, inf=1e5):
11 | """
12 | Args:
13 | logits: Logit matrices, (N, L_i, L_j, num_heads).
14 | mask: Masks, (N, L).
15 | Returns:
16 | alpha: Attention weights.
17 | """
18 | N, L, _, _ = logits.size()
19 | mask_row = mask.view(N, L, 1, 1).expand_as(logits) # (N, L, *, *)
20 | mask_pair = mask_row * mask_row.permute(0, 2, 1, 3) # (N, L, L, *)
21 |
22 | logits = torch.where(mask_pair, logits, logits - inf)
23 | alpha = torch.softmax(logits, dim=2) # (N, L, L, num_heads)
24 | alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
25 | return alpha
26 |
27 |
28 | def _heads(x, n_heads, n_ch):
29 | """
30 | Args:
31 | x: (..., num_heads * num_channels)
32 | Returns:
33 | (..., num_heads, num_channels)
34 | """
35 | s = list(x.size())[:-1] + [n_heads, n_ch]
36 | return x.view(*s)
37 |
38 |
39 | class GABlock(nn.Module):
40 |
41 | def __init__(self, node_feat_dim, pair_feat_dim, value_dim=32, query_key_dim=32, num_query_points=8,
42 | num_value_points=8, num_heads=12, bias=False):
43 | super().__init__()
44 | self.node_feat_dim = node_feat_dim
45 | self.pair_feat_dim = pair_feat_dim
46 | self.value_dim = value_dim
47 | self.query_key_dim = query_key_dim
48 | self.num_query_points = num_query_points
49 | self.num_value_points = num_value_points
50 | self.num_heads = num_heads
51 |
52 | # Node
53 | self.proj_query = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
54 | self.proj_key = nn.Linear(node_feat_dim, query_key_dim * num_heads, bias=bias)
55 | self.proj_value = nn.Linear(node_feat_dim, value_dim * num_heads, bias=bias)
56 |
57 | # Pair
58 | self.proj_pair_bias = nn.Linear(pair_feat_dim, num_heads, bias=bias)
59 |
60 | # Spatial
61 | self.spatial_coef = nn.Parameter(torch.full([1, 1, 1, self.num_heads], fill_value=np.log(np.exp(1.) - 1.)),
62 | requires_grad=True)
63 | self.proj_query_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
64 | self.proj_key_point = nn.Linear(node_feat_dim, num_query_points * num_heads * 3, bias=bias)
65 | self.proj_value_point = nn.Linear(node_feat_dim, num_value_points * num_heads * 3, bias=bias)
66 |
67 | # Output
68 | self.out_transform = nn.Linear(
69 | in_features=(num_heads * pair_feat_dim) + (num_heads * value_dim) + (
70 | num_heads * num_value_points * (3 + 3 + 1)),
71 | out_features=node_feat_dim,
72 | )
73 |
74 | self.layer_norm_1 = LayerNorm(node_feat_dim)
75 | self.mlp_transition = nn.Sequential(nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
76 | nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
77 | nn.Linear(node_feat_dim, node_feat_dim))
78 | self.layer_norm_2 = LayerNorm(node_feat_dim)
79 |
80 | def _node_logits(self, x):
81 | query_l = _heads(self.proj_query(x), self.num_heads, self.query_key_dim) # (N, L, n_heads, qk_ch)
82 | key_l = _heads(self.proj_key(x), self.num_heads, self.query_key_dim) # (N, L, n_heads, qk_ch)
83 | logits_node = (query_l.unsqueeze(2) * key_l.unsqueeze(1) *
84 | (1 / np.sqrt(self.query_key_dim))).sum(-1) # (N, L, L, num_heads)
85 | return logits_node
86 |
87 | def _pair_logits(self, z):
88 | logits_pair = self.proj_pair_bias(z)
89 | return logits_pair
90 |
91 | def _spatial_logits(self, R, t, x):
92 | N, L, _ = t.size()
93 | # Query
94 | query_points = _heads(self.proj_query_point(x), self.num_heads * self.num_query_points,
95 | 3) # (N, L, n_heads * n_pnts, 3)
96 | query_points = local_to_global(R, t, query_points) # Global query coordinates, (N, L, n_heads * n_pnts, 3)
97 | query_s = query_points.reshape(N, L, self.num_heads, -1) # (N, L, n_heads, n_pnts*3)
98 | # Key
99 | key_points = _heads(self.proj_key_point(x), self.num_heads * self.num_query_points,
100 | 3) # (N, L, n_heads * n_pnts, 3)
101 | key_points = local_to_global(R, t, key_points) # Global key coordinates, (N, L, n_heads * n_pnts, 3)
102 | key_s = key_points.reshape(N, L, self.num_heads, -1) # (N, L, n_heads, n_pnts*3)
103 | # Q-K Product
104 | sum_sq_dist = ((query_s.unsqueeze(2) - key_s.unsqueeze(1)) ** 2).sum(-1) # (N, L, L, n_heads)
105 | gamma = F.softplus(self.spatial_coef)
106 | logits_spatial = sum_sq_dist * ((-1 * gamma * np.sqrt(2 / (9 * self.num_query_points)))
107 | / 2) # (N, L, L, n_heads)
108 | return logits_spatial
109 |
110 | def _pair_aggregation(self, alpha, z):
111 | N, L = z.shape[:2]
112 | feat_p2n = alpha.unsqueeze(-1) * z.unsqueeze(-2) # (N, L, L, n_heads, C)
113 | feat_p2n = feat_p2n.sum(dim=2) # (N, L, n_heads, C)
114 | return feat_p2n.reshape(N, L, -1)
115 |
116 | def _node_aggregation(self, alpha, x):
117 | N, L = x.shape[:2]
118 |
119 |         value_l = _heads(self.proj_value(x), self.num_heads, self.value_dim) # (N, L, n_heads, v_ch)
120 | feat_node = alpha.unsqueeze(-1) * value_l.unsqueeze(1) # (N, L, L, n_heads, *) @ (N, *, L, n_heads, v_ch)
121 | feat_node = feat_node.sum(dim=2) # (N, L, n_heads, v_ch)
122 | return feat_node.reshape(N, L, -1)
123 |
124 | def _spatial_aggregation(self, alpha, R, t, x):
125 | N, L, _ = t.size()
126 | value_points = _heads(self.proj_value_point(x), self.num_heads * self.num_value_points,
127 | 3) # (N, L, n_heads * n_v_pnts, 3)
128 | value_points = local_to_global(R, t, value_points.reshape(N, L, self.num_heads, self.num_value_points,
129 | 3)) # (N, L, n_heads, n_v_pnts, 3)
130 | aggr_points = alpha.reshape(N, L, L, self.num_heads, 1, 1) * \
131 | value_points.unsqueeze(1) # (N, *, L, n_heads, n_pnts, 3)
132 | aggr_points = aggr_points.sum(dim=2) # (N, L, n_heads, n_pnts, 3)
133 |
134 | feat_points = global_to_local(R, t, aggr_points) # (N, L, n_heads, n_pnts, 3)
135 | feat_distance = feat_points.norm(dim=-1) # (N, L, n_heads, n_pnts)
136 | feat_direction = normalize_vector(feat_points, dim=-1, eps=1e-4) # (N, L, n_heads, n_pnts, 3)
137 |
138 | feat_spatial = torch.cat([
139 | feat_points.reshape(N, L, -1),
140 | feat_distance.reshape(N, L, -1),
141 | feat_direction.reshape(N, L, -1),
142 | ], dim=-1)
143 |
144 | return feat_spatial
145 |
146 | def forward(self, R, t, x, z, mask):
147 | """
148 | Args:
149 | R: Frame basis matrices, (N, L, 3, 3_index).
150 | t: Frame external (absolute) coordinates, (N, L, 3).
151 | x: Node-wise features, (N, L, F).
152 | z: Pair-wise features, (N, L, L, C).
153 | mask: Masks, (N, L).
154 | Returns:
155 | x': Updated node-wise features, (N, L, F).
156 | """
157 | # Attention logits
158 | logits_node = self._node_logits(x)
159 | logits_pair = self._pair_logits(z)
160 | logits_spatial = self._spatial_logits(R, t, x)
161 |         # Sum the logits and apply `softmax`.
162 | logits_sum = logits_node + logits_pair + logits_spatial
163 | alpha = _alpha_from_logits(logits_sum * np.sqrt(1 / 3), mask) # (N, L, L, n_heads)
164 |
165 | # Aggregate features
166 | feat_p2n = self._pair_aggregation(alpha, z)
167 | feat_node = self._node_aggregation(alpha, x)
168 | feat_spatial = self._spatial_aggregation(alpha, R, t, x)
169 |
170 | # Finally
171 | feat_all = self.out_transform(torch.cat([feat_p2n, feat_node, feat_spatial], dim=-1)) # (N, L, F)
172 | feat_all = mask_zero(mask.unsqueeze(-1), feat_all)
173 | x_updated = self.layer_norm_1(x + feat_all)
174 | x_updated = self.layer_norm_2(x_updated + self.mlp_transition(x_updated))
175 | return x_updated
176 |
177 |
178 | # Graph Attention Encoder
179 | class GAEncoder(nn.Module):
180 |
181 | def __init__(self, node_feat_dim, pair_feat_dim, num_layers, ga_block_opt={}):
182 | super(GAEncoder, self).__init__()
183 | self.blocks = nn.ModuleList([
184 | GABlock(node_feat_dim, pair_feat_dim, **ga_block_opt)
185 | for _ in range(num_layers)
186 | ])
187 |
188 | def forward(self, pos_atoms, res_feat, pair_feat, mask):
189 | R = construct_3d_basis(
190 | pos_atoms[:, :, 1],
191 | pos_atoms[:, :, 2],
192 | pos_atoms[:, :, 0]
193 | )
194 | t = pos_atoms[:, :, 1]
195 | t = angstrom_to_nm(t)
196 | for block in self.blocks:
197 | res_feat = block(R, t, res_feat, pair_feat, mask)
198 | return res_feat
199 |
--------------------------------------------------------------------------------
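A quick way to sanity-check the masking convention in `_alpha_from_logits` and `GABlock` above is to run the masked softmax standalone on toy tensors. The sketch below is a minimal re-implementation for illustration only (plain PyTorch, toy shapes), not the repo's module; it shows that padded keys receive ~zero weight and padded query rows are zeroed out.

import torch

def masked_alpha(logits, mask, inf=1e5):
    # logits: (N, L, L, H) raw attention logits; mask: (N, L), 1 = real residue, 0 = padding
    N, L = mask.shape
    mask_row = mask.view(N, L, 1, 1).expand_as(logits).bool()    # valid query rows
    mask_pair = mask_row & mask_row.permute(0, 2, 1, 3)          # valid query-key pairs
    logits = torch.where(mask_pair, logits, logits - inf)        # suppress padded keys
    alpha = torch.softmax(logits, dim=2)                         # normalise over the key axis
    return torch.where(mask_row, alpha, torch.zeros_like(alpha)) # zero out padded query rows

N, L, H = 2, 4, 3
alpha = masked_alpha(torch.randn(N, L, L, H), torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]))
print(alpha[0, 0, :, 0])  # the weight on the padded key (index 3) is ~0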
/data/sequence_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from data.register import DataRegister
3 | import esm
4 | from rinalmo.data.constants import *
5 | from rinalmo.data.alphabet import Alphabet
6 | import torch
7 | from tqdm import tqdm
8 | R = DataRegister()
9 |
10 | na_alphabet_config = {
11 | "standard_tkns": RNA_TOKENS,
12 | "special_tkns": [CLS_TKN, PAD_TKN, EOS_TKN, UNK_TKN, MASK_TKN],
13 | }
14 |
15 |
16 | CLS_TOKEN_IDX = 0
17 | PAD_TOKEN_IDX = 1
18 | EOS_TOKEN_IDX = 2
19 |
20 | @R.register('sequence_dataset')
21 | class SequenceDataset(Dataset):
22 | def __init__(self, dataframe,
23 | col_prot='protein', col_na='na',col_label='dG', col_prot_name = 'PDB',
24 | diskcache=None,
25 | **kwargs):
26 | super(SequenceDataset, self).__init__()
27 | self.df = dataframe
28 | self.col_protein = col_prot
29 | self.col_prot_name = col_prot_name
30 | self.col_na = col_na
31 | self.col_label = col_label
32 | self.type = 'reg'
33 | self.diskcache = diskcache
34 | self.prot_alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
35 | self.na_alphabet = Alphabet(**na_alphabet_config)
36 | self.load_data()
37 |
38 | def load_data(self):
39 | self.data = []
40 |
41 | for i, row in tqdm(self.df.iterrows(), total=self.df.shape[0]):
42 | structure_id = row[self.col_prot_name]
43 | if self.diskcache is None or structure_id not in self.diskcache:
44 | max_prot_length = 0
45 | max_na_length = 0
46 | prot_seqs_info = row[self.col_protein].split(',')
47 | na_seqs_info = row[self.col_na].split(',')
48 | prot_seqs = []
49 | na_seqs = []
50 | for prot_seq in prot_seqs_info:
51 | if ':' in prot_seq:
52 | prot_seq = prot_seq.split(':')[1]
53 | if len(prot_seq) > max_prot_length:
54 | max_prot_length = len(prot_seq)
55 | prot_seqs.append(prot_seq)
56 | for na_seq in na_seqs_info:
57 | if ':' in na_seq:
58 | na_seq = na_seq.split(':')[1]
59 | if len(na_seq) > max_na_length:
60 | max_na_length = len(na_seq)
61 | na_seqs.append(na_seq)
62 |
63 | label = row[self.col_label]
64 | item = {
65 | 'id': structure_id,
66 | 'prot_seqs': prot_seqs,
67 | 'na_seqs': na_seqs,
68 | 'label': label,
69 | 'max_prot_length': max_prot_length,
70 | 'max_na_length': max_na_length
71 | }
72 |
73 | self.data.append(item)
74 | if self.diskcache is not None:
75 | self.diskcache[structure_id] = item
76 | else:
77 | self.data.append(self.diskcache[structure_id])
78 |
79 | def __getitem__(self, idx):
80 | return self.data[idx]
81 |
82 | def __len__(self):
83 | return len(self.df)
84 |
85 | class CustomSeqCollate(object):
86 | def __init__(self, strategy='separate'):
87 | super(CustomSeqCollate, self).__init__()
88 | self.strategy=strategy
89 | def __call__(self, batch):
90 | size = len(batch)
91 |         labels = torch.tensor([item['label'] for item in batch], dtype=torch.float32)
92 | prot_alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
93 | na_alphabet = Alphabet(**na_alphabet_config)
94 | # print("Batch:", batch)
95 | # for item in batch:
96 | # for seq in item['na_seqs']:
97 | # print("RNA seqs:", seq, len(seq))
98 | prot_chains = [len(item['prot_seqs']) for item in batch]
99 | na_chains = [len(item['na_seqs']) for item in batch]
100 | if self.strategy == 'separate':
101 | max_item_prot_length = [item['max_prot_length'] for item in batch]
102 | max_item_na_length = [item['max_na_length'] for item in batch]
103 | max_prot_length = max(max_item_prot_length)
104 | max_na_length = max(max_item_na_length)
105 | total_prot_chains = sum(prot_chains)
106 | total_na_chains = sum(na_chains)
107 | prot_batch = torch.empty([total_prot_chains, max_prot_length+2])
108 | prot_batch.fill_(prot_alphabet.padding_idx)
109 | na_batch = torch.empty([total_na_chains, max_na_length+2])
110 | na_batch.fill_(na_alphabet.pad_idx)
111 | curr_prot_idx = 0
112 | curr_na_idx = 0
113 | for item in batch:
114 | prot_seqs = item['prot_seqs']
115 | na_seqs = item['na_seqs']
116 | # print(item['id'])
117 | for prot_seq in prot_seqs:
118 | prot_batch[curr_prot_idx, 0] = prot_alphabet.cls_idx
119 | prot_seq_encode = prot_alphabet.encode(prot_seq)
120 | seq = torch.tensor(prot_seq_encode, dtype=torch.int64)
121 | prot_batch[curr_prot_idx, 1: len(prot_seq_encode)+1] = seq
122 | prot_batch[curr_prot_idx, len(prot_seq_encode)+1] = prot_alphabet.eos_idx
123 | curr_prot_idx += 1
124 | for na_seq in na_seqs:
125 | # na_batch[curr_na_idx, 0] = na_alphabet.cls_idx
126 | # NA encoder adds CLS and EOS by default
127 | na_seq_encode = na_alphabet.encode(na_seq)
128 | seq = torch.tensor(na_seq_encode, dtype=torch.int64)
129 | na_batch[curr_na_idx, :len(seq)] = seq
130 | # na_batch[curr_na_idx, len(na_seq_encode)+1] = na_alphabet.eos_idx
131 | curr_na_idx += 1
132 |
133 | elif self.strategy == 'combine':
134 | # prot_linker = 'G' * 25
135 | # na_linker = 'T' * 15
136 | prot_linker = ''
137 | na_linker = ''
138 | complex_prot_max_length = 0
139 | complex_na_max_length = 0
140 | for i, item in enumerate(batch):
141 | prot_seqs = item['prot_seqs']
142 | na_seqs = item['na_seqs']
143 | prot_complex_seq = prot_linker.join(prot_seqs)
144 | na_complex_seq = na_linker.join(na_seqs)
145 | if len(prot_complex_seq) > complex_prot_max_length:
146 | complex_prot_max_length = len(prot_complex_seq)
147 | if len(na_complex_seq) > complex_na_max_length:
148 | complex_na_max_length = len(na_complex_seq)
149 |
150 | prot_batch = torch.empty([len(batch), complex_prot_max_length+2])
151 | prot_batch.fill_(prot_alphabet.padding_idx)
152 | na_batch = torch.empty([len(batch), complex_na_max_length+5])
153 | na_batch.fill_(na_alphabet.pad_idx)
154 |
155 | for i, item in enumerate(batch):
156 | prot_batch[i, 0] = prot_alphabet.cls_idx
157 |             prot_complex_encode = prot_alphabet.encode(prot_linker.join(item['prot_seqs'])) # re-join per item; prot_complex_seq from the loop above only holds the last complex
158 | seq = torch.tensor(prot_complex_encode, dtype=torch.int64)
159 | prot_batch[i, 1: len(prot_complex_encode)+1] = seq
160 | prot_batch[i, len(prot_complex_encode)+1] = prot_alphabet.eos_idx
161 | prot_linker_start = 1
162 | if len(item['prot_seqs']) > 1 and len(prot_linker) > 0:
163 | # print("Combining Protein...", len(item['prot_seqs']))
164 | for j, p_seq in enumerate(item['prot_seqs']):
165 | if j == len(item['prot_seqs']) - 1:
166 | break
167 | seq_len = len(p_seq)
168 | prot_linker_start += seq_len
169 | linker_len = len(prot_linker)
170 | prot_batch[i, prot_linker_start: prot_linker_start + linker_len] = prot_alphabet.padding_idx
171 | # print("Done!", i)
172 | prot_linker_start += linker_len
173 | # na_batch[i, 0] = na_alphabet.cls_idx
174 |             na_complex_encode = na_alphabet.encode(na_linker.join(item['na_seqs'])) # re-join per item
175 | seq = torch.tensor(na_complex_encode, dtype=torch.int64)
176 | na_batch[i, :len(seq)] = seq
177 | na_linker_start = 1
178 | if len(item['na_seqs']) > 1 and len(na_linker) > 0:
179 | # print("Combining NA...", len(item['na_seqs']))
180 | for j, n_seq in enumerate(item['na_seqs']):
181 | if j == len(item['na_seqs']) - 1:
182 | break
183 | seq_len = len(n_seq)
184 | na_linker_start += seq_len
185 | linker_len = len(na_linker)
186 | na_batch[i, na_linker_start: na_linker_start + linker_len] = na_alphabet.pad_idx
187 | na_linker_start += linker_len
188 | # na_batch[i, len(na_complex_encode)+1] = na_alphabet.eos_idx
189 | else:
190 | raise ValueError
191 | prot_mask = torch.zeros_like(prot_batch)
192 | na_mask = torch.zeros_like(na_batch)
193 | prot_mask[(prot_batch!=prot_alphabet.padding_idx) & (prot_batch!=prot_alphabet.eos_idx) & (prot_batch!=prot_alphabet.cls_idx)] = 1
194 | na_mask[(na_batch!=na_alphabet.pad_idx) & (na_batch!=na_alphabet.eos_idx) & (na_batch!=na_alphabet.cls_idx)] = 1
195 |
196 | data = {
197 | 'size': size,
198 | 'labels': labels,
199 | 'prot': prot_batch.long(),
200 | 'prot_chains': prot_chains,
201 | 'protein_mask': prot_mask,
202 | 'na': na_batch.long(),
203 | 'na_chains': na_chains,
204 | 'na_mask': na_mask,
205 | 'strategy': self.strategy
206 | }
207 | # print(data)
208 | return data
--------------------------------------------------------------------------------
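For readers tracing the `separate` strategy of `CustomSeqCollate` above: every chain of every complex becomes one row of a `(total_chains, max_len + 2)` token matrix framed by CLS/EOS and padded with PAD, and the residue mask excludes the special tokens. The sketch below reproduces that layout with a toy integer alphabet (CLS=0, PAD=1, EOS=2, residues from 3 upward, mirroring the index constants at the top of the file); it illustrates only the padding scheme, not the real ESM/RiNALMo alphabets.

import torch

CLS, PAD, EOS = 0, 1, 2  # toy token ids mirroring CLS_TOKEN_IDX / PAD_TOKEN_IDX / EOS_TOKEN_IDX

def collate_chains(complexes):
    # complexes: list of complexes, each a list of integer-encoded chains
    chains = [c for cplx in complexes for c in cplx]
    max_len = max(len(c) for c in chains)
    batch = torch.full((len(chains), max_len + 2), PAD, dtype=torch.long)
    for i, c in enumerate(chains):
        batch[i, 0] = CLS
        batch[i, 1:len(c) + 1] = torch.tensor(c, dtype=torch.long)
        batch[i, len(c) + 1] = EOS
    mask = (batch != PAD) & (batch != CLS) & (batch != EOS)   # 1 only on real residues
    return batch, mask.long(), [len(cplx) for cplx in complexes]

batch, mask, chains_per_complex = collate_chains([[[3, 4, 5], [6, 7]], [[8, 9, 10, 11]]])
print(batch.shape, chains_per_complex)  # torch.Size([3, 6]) [2, 1]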
/models/components/valina_attn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | import math
5 |
6 | from rinalmo.model.rope import RotaryPositionEmbedding
7 |
8 | from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func
9 | from flash_attn.layers.rotary import RotaryEmbedding
10 |
11 | from flash_attn.bert_padding import unpad_input, pad_input
12 |
13 | from einops import rearrange
14 |
15 | def dot_product_attention(q, k, v, attn_mask=None, key_pad_mask=None, dropout=None):
16 | c = q.shape[-1]
17 | attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(c)
18 |
19 | if attn_mask is not None:
20 | attn = attn.masked_fill(attn_mask, float("-inf"))
21 |
22 | if key_pad_mask is not None:
23 | attn = attn.masked_fill(key_pad_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
24 |
25 | attn = attn.softmax(dim=-1)
26 | if dropout is not None:
27 | attn = dropout(attn)
28 |
29 | output = torch.matmul(attn, v)
30 | if torch.isnan(output).any():
31 | print("Found nan in attn!")
32 | return output, attn
33 |
34 | class MultiHeadAttention(nn.Module):
35 | def __init__(self, c_in, num_heads, attention_dropout=0.0, use_rot_emb=True, bias=False):
36 | super().__init__()
37 | assert c_in % num_heads == 0, "Embedding dimensionality must be divisible with number of attention heads!"
38 |
39 | self.c_in = c_in
40 | self.num_heads = num_heads
41 |
42 | self.c_head = c_in // self.num_heads
43 | self.c_qkv = self.c_head * num_heads
44 |
45 | self.use_rot_emb = use_rot_emb
46 | if self.use_rot_emb:
47 | self.rotary_emb = RotaryPositionEmbedding(self.c_head)
48 |
49 | self.to_q = nn.Linear(self.c_in, self.c_qkv, bias=bias)
50 | self.to_k = nn.Linear(self.c_in, self.c_qkv, bias=bias)
51 | self.to_v = nn.Linear(self.c_in, self.c_qkv, bias=bias)
52 |
53 | self.attention_dropout = nn.Dropout(p=attention_dropout)
54 |
55 | self.out_proj = nn.Linear(c_in, c_in, bias=bias)
56 |
57 | def forward(self, q, k, v, attn_mask=None, key_pad_mask=None):
58 | bs = q.shape[0]
59 |
60 | q = self.to_q(q).view(bs, -1, self.num_heads, self.c_head).transpose(-2, -3)
61 | k = self.to_k(k).view(bs, -1, self.num_heads, self.c_head).transpose(-2, -3)
62 | v = self.to_v(v).view(bs, -1, self.num_heads, self.c_head).transpose(-2, -3)
63 |
64 | if self.use_rot_emb:
65 | q, k = self.rotary_emb(q, k)
66 |
67 | output, attn = dot_product_attention(q, k, v, attn_mask, key_pad_mask, self.attention_dropout)
68 |
69 | output = output.transpose(-2, -3).contiguous().view(bs, -1, self.num_heads * self.c_head)
70 | output = self.out_proj(output)
71 | if torch.isnan(output).any():
72 | print("Found nan in attn!")
73 | return output, attn
74 |
75 | class MultiHeadSelfAttention(nn.Module):
76 | def __init__(self, c_in, num_heads, attention_dropout=0.0, use_rot_emb=True, bias=False):
77 | super().__init__()
78 |
79 | self.mh_attn = MultiHeadAttention(c_in, num_heads, attention_dropout, use_rot_emb, bias)
80 |
81 | def forward(self, x, attn_mask=None, key_pad_mask=None):
82 | return self.mh_attn(x, x, x, attn_mask, key_pad_mask)
83 |
84 | class FlashAttention(nn.Module):
85 | """
86 | Implement the scaled dot product attention with softmax.
87 | """
88 | def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
89 | """
90 | Args:
91 | causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
92 | softmax_scale: float. The scaling of QK^T before applying softmax.
93 | Default to 1 / sqrt(headdim).
94 | attention_dropout: float. The dropout rate to apply to the attention
95 | (default: 0.0)
96 | """
97 | super().__init__()
98 | self.softmax_scale = softmax_scale
99 | self.attention_dropout = attention_dropout
100 | self.causal = causal
101 |
102 | def forward(self, qkv, cu_seqlens=None, max_seqlen=None, return_attn_probs=False):
103 | """
104 | Arguments
105 | ---------
106 | qkv: The tensor containing the query, key, and value.
107 | If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
108 | If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
109 | (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
110 | cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
111 | of the sequences in the batch, used to index into qkv.
112 | max_seqlen: int. Maximum sequence length in the batch.
113 | return_attn_probs: bool. Whether to return the attention probabilities. This option is for
114 | testing only. The returned probabilities are not guaranteed to be correct
115 | (they might not have the right scaling).
116 | Returns:
117 | --------
118 | out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
119 | else (B, S, H, D).
120 | """
121 | assert qkv.dtype in [torch.float16, torch.bfloat16]
122 | assert qkv.is_cuda
123 |
124 | unpadded = cu_seqlens is not None
125 |
126 | if unpadded:
127 | assert cu_seqlens.dtype == torch.int32
128 | assert max_seqlen is not None
129 | assert isinstance(max_seqlen, int)
130 | return flash_attn_varlen_qkvpacked_func(
131 | qkv,
132 | cu_seqlens,
133 | max_seqlen,
134 | self.attention_dropout if self.training else 0.0,
135 | softmax_scale=self.softmax_scale,
136 | causal=self.causal,
137 | return_attn_probs=return_attn_probs
138 | )
139 | else:
140 | return flash_attn_qkvpacked_func(
141 | qkv,
142 | self.attention_dropout if self.training else 0.0,
143 | softmax_scale=self.softmax_scale,
144 | causal=self.causal,
145 | return_attn_probs=return_attn_probs
146 | )
147 |
148 | class FlashMultiHeadSelfAttention(nn.Module):
149 | """
150 | Multi-head self-attention implemented using FlashAttention.
151 | """
152 | def __init__(self, embed_dim, num_heads, attention_dropout=0.0, causal=False, use_rot_emb=True, bias=False):
153 | super().__init__()
154 | assert embed_dim % num_heads == 0, "Embedding dimensionality must be divisible with number of attention heads!"
155 |
156 | self.causal = causal
157 |
158 | self.embed_dim = embed_dim
159 | self.num_heads = num_heads
160 |
161 | self.head_dim = self.embed_dim // self.num_heads
162 | self.qkv_dim = self.head_dim * num_heads * 3
163 |
164 | self.rotary_emb_dim = self.head_dim
165 | self.use_rot_emb = use_rot_emb
166 | if self.use_rot_emb:
167 | self.rotary_emb = RotaryEmbedding(
168 | dim=self.rotary_emb_dim,
169 | base=10000.0,
170 | interleaved=False,
171 | scale_base=None,
172 | pos_idx_in_fp32=True, # fp32 RoPE precision
173 | device=None
174 | )
175 | self.flash_self_attn = FlashAttention(causal=self.causal, softmax_scale=None, attention_dropout=attention_dropout)
176 |
177 | self.Wqkv = nn.Linear(self.embed_dim, self.qkv_dim, bias=bias)
178 |
179 | self.attention_dropout = nn.Dropout(p=attention_dropout)
180 |
181 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
182 |
183 | def forward(self, x, key_padding_mask=None, return_attn_probs=False):
184 | """
185 | Arguments:
186 | x: (batch, seqlen, hidden_dim) (where hidden_dim = num_heads * head_dim)
187 | key_pad_mask: boolean mask, True means to keep, False means to mask out.
188 | (batch, seqlen)
189 | return_attn_probs: whether to return attention masks (False by default)
190 | """
191 |
192 | qkv = self.Wqkv(x)
193 | qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
194 |
195 | if self.use_rot_emb:
196 | qkv = self.rotary_emb(qkv, seqlen_offset=0)
197 |
198 | if return_attn_probs:
199 | bs = qkv.shape[0]
200 | qkv = torch.permute(qkv, (0, 3, 2, 1, 4))
201 | q = qkv[:, :, 0, :, :]
202 | k = qkv[:, :, 1, :, :]
203 | v = qkv[:, :, 2, :, :]
204 | out, attn = dot_product_attention(q, k, v, key_pad_mask=torch.logical_not(key_padding_mask) if key_padding_mask is not None else None, dropout=self.attention_dropout)
205 | output = out.transpose(-2, -3).contiguous().view(bs, -1, self.num_heads * self.head_dim)
206 | output = self.out_proj(output)
207 | return output, attn
208 |
209 | if key_padding_mask is not None:
210 | batch_size = qkv.shape[0]
211 | seqlen = qkv.shape[1]
212 | x_unpad, indices, cu_seqlens, max_s = unpad_input(qkv, key_padding_mask)
213 | output_unpad = self.flash_self_attn(
214 | x_unpad,
215 | cu_seqlens=cu_seqlens,
216 | max_seqlen=max_s,
217 | return_attn_probs=return_attn_probs
218 | )
219 | out = pad_input(rearrange(output_unpad, '... h d -> ... (h d)'), indices, batch_size, seqlen)
220 | else:
221 | output = self.flash_self_attn(
222 | qkv,
223 | cu_seqlens=None,
224 | max_seqlen=None,
225 | return_attn_probs=return_attn_probs
226 | )
227 | out = rearrange(output, '... h d -> ... (h d)')
228 |
229 | out = self.out_proj(out)
230 | return out, None
231 |
--------------------------------------------------------------------------------
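FlashAttention only runs on CUDA with fp16/bf16 inputs, so a CPU-friendly way to check the math of `dot_product_attention` above is to compare it against PyTorch's built-in kernel. The sketch below assumes PyTorch >= 2.0 (for `F.scaled_dot_product_attention`) and mirrors the file's masking convention, where True in the key padding mask means "mask out".

import math
import torch
import torch.nn.functional as F

def plain_attention(q, k, v, key_pad_mask=None):
    attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.shape[-1])
    if key_pad_mask is not None:                               # True = padded key -> mask out
        attn = attn.masked_fill(key_pad_mask[:, None, None, :], float("-inf"))
    return torch.matmul(attn.softmax(dim=-1), v)

B, H, L, D = 2, 4, 6, 8
q, k, v = torch.randn(3, B, H, L, D).unbind(0)
pad = torch.zeros(B, L, dtype=torch.bool)
pad[:, -2:] = True                                             # last two keys are padding

ref = plain_attention(q, k, v, key_pad_mask=pad)
sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=~pad[:, None, None, :])
print(torch.allclose(ref, sdpa, atol=1e-6))                    # expected: True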
/models/ipa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from models.encoders.single import PerResidueEncoder
5 | from models.encoders.pair import ResiduePairEncoder
6 | from models.encoders.attn import GAEncoder
7 | from models.register import ModelRegister
8 | from models.model import load_esm, load_rinalmo, segment_cat_pad, cat_pad
9 | from models.lora_tune import LoRAESM, LoRARiNALMo, ESMConfig, RiNALMoConfig
10 | from peft import (
11 | LoraConfig,
12 | get_peft_model,
13 | )
14 | R = ModelRegister()
15 |
16 | @R.register('ipa')
17 | class InvariantPointAttention(nn.Module):
18 | def __init__(self,
19 | node_feat_dim=640,
20 | rinalmo_weights='./weights/rinalmo_giga_pretrained.pt',
21 | esm_type='650M',
22 | use_lm=False,
23 | fix_lms=True,
24 | pair_feat_dim=64,
25 | num_layers=3,
26 | pooling='mean',
27 | output_dim=1,
28 | representation_layer=33,
29 | lora_tune=False,
30 | lora_rank=16,
31 | lora_alpha=32,
32 | **kwargs):
33 |
34 | super().__init__()
35 | self.use_lm = use_lm
36 | self.fix_lms = fix_lms
37 | self.proj = 0
38 | self.representation_layer = representation_layer
39 | if self.use_lm:
40 | self.esm, esm_feat_size = load_esm(esm_type)
41 | self.rinalmo, rinalmo_feat_size = load_rinalmo(rinalmo_weights)
42 | if esm_feat_size != rinalmo_feat_size:
43 | self.proj = 1
44 | self.project_feat= nn.Linear(esm_feat_size, rinalmo_feat_size)
45 | self.feat_size = rinalmo_feat_size
46 | self.proj_cplx= nn.Linear(self.feat_size, node_feat_dim)
47 | if lora_tune:
48 | print("Getting Lora!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
49 | # copied from LongLoRA
50 | rinalmo_lora_config = LoraConfig(
51 | r=lora_rank,
52 | bias="none",
53 | lora_alpha=lora_alpha
54 | )
55 | esm_lora_config = LoraConfig(
56 | r=lora_rank,
57 | bias="none",
58 | lora_alpha=lora_alpha
59 | )
60 | rinalmo_config = RiNALMoConfig()
61 | esm_config = ESMConfig()
62 | self.rinalmo = LoRARiNALMo(self.rinalmo, rinalmo_config)
63 | self.esm = LoRAESM(self.esm, esm_config)
64 | self.rinalmo = get_peft_model(self.rinalmo, rinalmo_lora_config)
65 | print("Get RINALMO DONE!!!!!")
66 | self.esm = get_peft_model(self.esm, esm_lora_config)
67 | print("Get ESM DONE!!!!!")
68 | elif fix_lms:
69 | for p in self.rinalmo.parameters():
70 | p.requires_grad_(False)
71 | for p in self.esm.parameters():
72 | p.requires_grad_(False)
73 | # Encoding
74 | else:
75 | self.single_encoder = PerResidueEncoder(
76 | feat_dim=node_feat_dim,
77 |                 max_num_atoms=4, # N, CA, C, O
78 | )
79 | self.masked_bias = nn.Embedding(
80 | num_embeddings=2,
81 | embedding_dim=node_feat_dim,
82 | padding_idx=0,
83 | )
84 | self.pair_encoder = ResiduePairEncoder(
85 | feat_dim=pair_feat_dim,
86 |                 max_num_atoms=4, # N, CA, C, O
87 | )
88 | self.pooling = pooling
89 | self.attn_encoder = GAEncoder(node_feat_dim=node_feat_dim, pair_feat_dim=pair_feat_dim, num_layers=num_layers)
90 | self.pred_head = nn.Sequential(
91 | nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
92 | nn.Linear(node_feat_dim, node_feat_dim), nn.ReLU(),
93 | nn.Linear(node_feat_dim, output_dim)
94 | )
95 |
96 | def forward(self, input, strategy='separate'):
97 | mask_residue = input['mask_atoms'][:, :, 1] #CA
98 | if not self.use_lm:
99 | x = self.single_encoder(
100 | aa=input['restype'],
101 | # phi=batch['phi'], phi_mask=batch['phi_mask'],
102 | # psi=batch['psi'], psi_mask=batch['psi_mask'],
103 | # chi=batch['chi'], chi_mask=batch['chi_mask'],
104 | mask_residue=mask_residue,
105 | )
106 | else:
107 | prot_input = input['prot']
108 | prot_chains = input['prot_chains']
109 | prot_mask = input['protein_mask']
110 | na_input = input['na']
111 | na_chains = input['na_chains']
112 | na_mask = input['na_mask']
113 | with torch.cuda.amp.autocast():
114 | prot_embedding = self.esm(prot_input, repr_layers=[self.representation_layer], return_contacts=False)['representations'][self.representation_layer]
115 | na_embedding = self.rinalmo(na_input)['representation']
116 | if self.proj:
117 |                     prot_embedding = self.project_feat(prot_embedding) # project ESM features to the RiNALMo width before concatenation
118 | prot_embedding = prot_embedding.float()
119 | na_embedding = na_embedding.float()
120 | # print("Original Embedding:", prot_embedding, na_embedding)
121 | max_len = input['pos_atoms'].shape[1]
122 | if 'patch_idx' in input:
123 | patch_idx = input['patch_idx']
124 | else:
125 | patch_idx = None
126 | # Adjust the embeddings from LMs for CFormer
127 | if strategy == 'separate':
128 | # input shape [N', L], where N' is flexible in every batch
129 | out_embedding, masks = segment_cat_pad(prot_embedding, prot_chains, prot_mask, na_embedding, na_chains, na_mask, max_len, patch_idx)
130 | assert out_embedding.shape[0] == input['size']
131 | else:
132 | out_embedding, masks = cat_pad(prot_embedding, prot_mask, na_embedding, na_mask, max_len, patch_idx)
133 | assert out_embedding.shape[0] == input['size']
134 | x = self.proj_cplx(out_embedding)
135 |
136 | aa=input['restype']
137 | res_nb=input['res_nb']
138 | chain_nb=input['chain_nb']
139 | pos_atoms=input['pos_atoms']
140 | mask_atoms=input['mask_atoms']
141 |
142 | if self.pooling == 'token':
143 | mask_special = torch.zeros((len(out_embedding), 1), device=out_embedding.device, dtype=key_padding_mask.dtype)
144 | cplx_embed = self.complex_embedding.repeat(len(out_embedding), 1, 1)
145 | prot_embed = self.prot_embedding.repeat(len(out_embedding), 1, 1)
146 | rna_embed = self.rna_embedding.repeat(len(out_embedding), 1, 1)
147 |
148 | out_embedding = torch.cat([cplx_embed, prot_embed, rna_embed, out_embedding], dim=1)
149 | key_padding_mask = torch.cat([mask_special, mask_special, mask_special, key_padding_mask], dim=1)
150 |
151 | cplx_type = torch.ones_like(mask_special, device=out_embedding.device, dtype=aa.dtype) * SUPER_CPLX_IDX
152 | prot_type = torch.ones_like(mask_special, device=out_embedding.device, dtype=aa.dtype) * SUPER_PROT_IDX
153 | rna_type = torch.ones_like(mask_special, device=out_embedding.device, dtype=aa.dtype) * SUPER_RNA_IDX
154 | aa = torch.cat([cplx_type, prot_type, rna_type, aa], dim=1)
155 |
156 | res_nb_cplx = torch.ones_like(mask_special, device=out_embedding.device, dtype=res_nb.dtype) * 0
157 | res_nb_prot = torch.ones_like(mask_special, device=out_embedding.device, dtype=res_nb.dtype) * 1
158 | res_nb_rna = torch.ones_like(mask_special, device=out_embedding.device, dtype=res_nb.dtype) * 2
159 |
160 | res_nb = torch.cat([res_nb_cplx, res_nb_prot, res_nb_rna, res_nb], dim=1)
161 | super_chain_id = torch.ones_like(mask_special, device=out_embedding.device, dtype=chain_nb.dtype) * SUPER_CHAIN_IDX
162 | chain_nb = torch.cat([super_chain_id, super_chain_id, super_chain_id, chain_nb], dim=1)
163 |
164 | center_cplx = torch.zeros((len(out_embedding), 1, pos_atoms.shape[2], 3), device=out_embedding.device, dtype=pos_atoms.dtype)
165 | center_prot = ((pos_atoms * input['identifier'][:, :, None, None] * mask_atoms.unsqueeze(-1)).reshape([len(out_embedding), -1, 3]).sum(dim=1) / (input['identifier'][:, :, None] * mask_atoms).reshape([len(out_embedding), -1]).sum(dim=-1).unsqueeze(-1))[:, None, None, :].repeat(1, 1, 4, 1)
166 | center_rna = ((pos_atoms * (1-input['identifier'][:, :, None, None]) * mask_atoms.unsqueeze(-1)).reshape([len(out_embedding), -1, 3]).sum(dim=1) / ((1-input['identifier'][:, :, None]) * mask_atoms).reshape([len(out_embedding), -1]).sum(dim=-1).unsqueeze(-1))[:, None, None, :].repeat(1, 1, 4, 1)
167 | pos_atoms = torch.cat([center_cplx, center_prot, center_rna, pos_atoms], dim=1)
168 | mask_atom = torch.zeros((len(out_embedding), 1, pos_atoms.shape[2]), device=out_embedding.device, dtype=mask_atoms.dtype)
169 | mask_atom[:,:,0] = 1
170 | mask_atoms = torch.cat([mask_atom, mask_atom, mask_atom, mask_atoms], dim=1)
171 |
172 |
173 | z = self.pair_encoder(
174 | aa=aa,
175 | res_nb=res_nb,
176 | chain_nb=chain_nb,
177 | pos_atoms=pos_atoms,
178 | mask_atoms=mask_atoms,
179 | )
180 |
181 | x = self.attn_encoder(
182 | pos_atoms=input['pos_atoms'],
183 | res_feat=x, pair_feat=z,
184 | mask=mask_residue
185 | )
186 | if self.pooling == 'token':
187 | complex_embedding = x[:, 0, :]
188 | else:
189 | complex_embedding = (x * input['seq_mask'].unsqueeze(-1)).sum(dim=1)
190 | if self.pooling == 'mean':
191 | # Prot_mask: [N, L]
192 | complex_mask_sum = input['seq_mask'].sum(dim=1, keepdim=True)
193 | complex_embedding = complex_embedding / (complex_mask_sum + 1e-10)
194 |
195 | output = self.pred_head(complex_embedding)
196 | output = output.squeeze(1)
197 |
198 | return output
--------------------------------------------------------------------------------
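The token-pooling branch above places the virtual protein and RNA tokens at the centre of mass of the corresponding partner, averaging only over valid atoms. The sketch below isolates that masked-centroid computation with toy tensors; it assumes, as the indexing above suggests, that `identifier` is 1 for protein residues and 0 for RNA residues.

import torch

def partner_centroid(pos_atoms, mask_atoms, identifier, protein=True):
    # pos_atoms: (N, L, A, 3), mask_atoms: (N, L, A), identifier: (N, L)
    sel = identifier if protein else 1 - identifier            # select one binding partner
    w = (sel[:, :, None] * mask_atoms).unsqueeze(-1)           # (N, L, A, 1) atom weights
    num = (pos_atoms * w).sum(dim=(1, 2))                      # summed coordinates, (N, 3)
    den = w.sum(dim=(1, 2)).clamp_min(1e-8)                    # number of selected atoms, (N, 1)
    return num / den                                           # centroid per complex, (N, 3)

N, L, A = 2, 5, 4
pos = torch.randn(N, L, A, 3)
mask = torch.ones(N, L, A)
ident = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]], dtype=torch.float32)
print(partner_centroid(pos, mask, ident, protein=True).shape)  # torch.Size([2, 3])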
/models/esm_rinalmo_seq.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import esm
4 | from rinalmo.config import model_config
5 | from rinalmo.model.model import RiNALMo
6 | from models.register import ModelRegister
7 | from peft import (
8 | LoraConfig,
9 | get_peft_model,
10 | )
11 | from models.lora_tune import LoRAESM, LoRARiNALMo, ESMConfig, RiNALMoConfig
12 | from models.components.valina_transformer import Transformer
13 | from models.model import cat_pad, segment_cat_pad
14 | R = ModelRegister()
15 |
16 | def load_esm(esm_type):
17 | if esm_type == '650M':
18 | model, _ = esm.pretrained.esm2_t33_650M_UR50D()
19 | elif esm_type == '3B':
20 | model, _ = esm.pretrained.esm2_t36_3B_UR50D()
21 | elif esm_type == '15B':
22 | model, _ = esm.pretrained.esm2_t48_15B_UR50D()
23 | else:
24 | raise NotImplementedError
25 | feat_size = model.embed_dim
26 | return model, feat_size
27 |
28 | def load_rinalmo(rinalmo_weights):
29 | config = model_config('giga')
30 | model = RiNALMo(config)
31 | # alphabet = Alphabet(**config['alphabet'])
32 | model.load_state_dict(torch.load(rinalmo_weights))
33 | feat_size = config.globals.embed_dim
34 | return model, feat_size
35 |
36 | def segment_pool(input, chains, mask, pooling):
37 | # input shape [N', L, E], mask_shape [N', L]
38 | result = input.new_full([len(chains), input.shape[-1]], 0) # (N, E)
39 | mask_result = mask.new_full([len(chains), 1], 0) #(N, 1)
40 | input_flattened = input.reshape((-1, input.shape[-1])) #(N'*L, E)
41 | mask_flattened = mask.reshape((-1, 1)) #(N'*L, 1)
42 | # print("Shapes:", result.shape, mask_result.shape, input_flattened.shape, mask_flattened.shape)
43 | # segment_id shape (N', )
44 | segment_id = torch.tensor(sum([[i] * chain for i, chain in enumerate(chains)], start=[]), device=result.device, dtype=torch.int64)
45 | segment_id = segment_id.repeat_interleave(input.shape[1]) #(N'*L)
46 | result.scatter_add_(0, segment_id.unsqueeze(1).expand_as(input_flattened), input_flattened*mask_flattened)
47 | mask_result.scatter_add_(0, segment_id.unsqueeze(1), mask_flattened)
48 |     mask_result.reshape((-1, )) # no-op (result discarded); mask_result stays (N, 1) so the division below broadcasts correctly
49 |
50 | if pooling == 'mean':
51 | result = result / (mask_result + 1e-10)
52 |
53 | return result
54 |
55 | @R.register('esm2_rinalmo_seq')
56 | class ESM2RiNALMo(nn.Module):
57 | def __init__(self,
58 | rinalmo_weights='./weights/rinalmo_giga_pretrained.pt',
59 | esm_type='650M',
60 | pooling='token',
61 | output_dim=1,
62 | fix_lms=True,
63 | lora_tune=False,
64 | lora_rank=16,
65 | lora_alpha=32,
66 | representation_layer=33,
67 | vallina=True,
68 | **kwargs
69 | ):
70 | super(ESM2RiNALMo, self).__init__()
71 | self.esm, esm_feat_size = load_esm(esm_type)
72 | self.rinalmo, rinalmo_feat_size = load_rinalmo(rinalmo_weights)
73 | self.vallina=vallina
74 | # if esm_feat_size != rinalmo_feat_size:
75 | # self.project_layer = nn.Linear(esm_feat_size, rinalmo_feat_size)
76 | self.cat_size = esm_feat_size + rinalmo_feat_size
77 | self.feat_size = rinalmo_feat_size
78 | self.representation_layer = representation_layer
79 | self.transformer = Transformer(**kwargs['transformer'])
80 | self.complex_dim = kwargs['transformer']['embed_dim']
81 | self.proj_cplx= nn.Linear(self.feat_size, self.complex_dim)
82 | self.pooling = pooling
83 | if self.pooling == 'token':
84 | self.prot_embedding = nn.Parameter(torch.zeros((1, self.complex_dim), dtype=torch.float32))
85 | self.rna_embedding = nn.Parameter(torch.zeros((1, self.complex_dim), dtype=torch.float32))
86 | self.complex_embedding = nn.Parameter(torch.zeros((1, self.complex_dim), dtype=torch.float32))
87 | nn.init.normal_(self.prot_embedding)
88 | nn.init.normal_(self.rna_embedding)
89 | nn.init.normal_(self.complex_embedding)
90 | if lora_tune:
91 | # copied from LongLoRA
92 | rinalmo_lora_config = LoraConfig(
93 | r=lora_rank,
94 | bias="none",
95 | lora_alpha=lora_alpha
96 | )
97 | esm_lora_config = LoraConfig(
98 | r=lora_rank,
99 | bias="none",
100 | lora_alpha=lora_alpha
101 | )
102 | rinalmo_config = RiNALMoConfig()
103 | esm_config = ESMConfig()
104 | self.rinalmo = LoRARiNALMo(self.rinalmo, rinalmo_config)
105 | # print(esm_config)
106 | self.esm = LoRAESM(self.esm, esm_config)
107 | # print("ESM:", self.esm)
108 | self.rinalmo = get_peft_model(self.rinalmo, rinalmo_lora_config)
109 | # print("Get RINALMO DONE!!!!!")
110 | self.esm = get_peft_model(self.esm, esm_lora_config)
111 | # print("Get ESM DONE!!!!!")
112 |
113 | elif fix_lms:
114 | for p in self.rinalmo.parameters():
115 | p.requires_grad_(False)
116 | for p in self.esm.parameters():
117 | p.requires_grad_(False)
118 | self.pooling = pooling
119 | self.pred_head = nn.Sequential(
120 | nn.Linear(self.complex_dim, self.feat_size), nn.ReLU(),
121 | nn.Linear(self.feat_size, output_dim)
122 | )
123 | if self.vallina:
124 | print("Using vallina version!")
125 | self.cat_pred_head = nn.Sequential(
126 | nn.Linear(self.cat_size, self.feat_size), nn.ReLU(),
127 | nn.Linear(self.feat_size, output_dim)
128 | )
129 |
130 |
131 | def forward(self, input, strategy='separate'):
132 | prot_input = input['prot']
133 | prot_chains = input['prot_chains']
134 | prot_mask = input['protein_mask']
135 | na_input = input['na']
136 | na_chains = input['na_chains']
137 | na_mask = input['na_mask']
138 | # print("Input Shape:", prot_input.shape, na_input.shape, prot_chains, na_chains)
139 | with torch.cuda.amp.autocast():
140 | prot_embedding = self.esm(prot_input, repr_layers=[self.representation_layer], return_contacts=False)['representations'][self.representation_layer]
141 | na_embedding = self.rinalmo(na_input)['representation']
142 |
143 | # Vallina implementation with mean pooling
144 | # print("Original Embedding:", prot_embedding.shape, na_embedding.shape)
145 | if self.vallina:
146 | if strategy == 'separate':
147 | # input shape [N', L], where N' is flexible in every batch
148 | prot_embedding = segment_pool(prot_embedding, prot_chains, prot_mask, pooling=self.pooling)
149 | na_embedding = segment_pool(na_embedding, na_chains, na_mask, pooling=self.pooling)
150 | else:
151 | if self.pooling == 'max':
152 | prot_embedding = (prot_embedding * prot_mask.unsqueeze(-1)).max(dim=1)[0]
153 | na_embedding = (na_embedding * na_mask.unsqueeze(-1)).max(dim=1)[0]
154 | else:
155 | prot_embedding = (prot_embedding * prot_mask.unsqueeze(-1)).sum(dim=1)
156 | na_embedding = (na_embedding * na_mask.unsqueeze(-1)).sum(dim=1)
157 | if self.pooling == 'mean':
158 | # Prot_mask: [N, L]
159 | prot_mask_sum = prot_mask.sum(dim=1, keepdim=True)
160 | na_mask_sum = na_mask.sum(dim=1, keepdim=True)
161 | prot_embedding = prot_embedding / (prot_mask_sum + 1e-10)
162 | na_embedding = na_embedding / (na_mask_sum + 1e-10)
163 | complex_embedding = torch.cat([prot_embedding, na_embedding], dim=1)
164 | output = self.cat_pred_head(complex_embedding)
165 | output = output.squeeze(1)
166 | else:
167 | prot_embedding = prot_embedding.float()
168 | na_embedding = na_embedding.float()
169 | # print("Original Embedding:", prot_embedding, na_embedding)
170 | max_len = input['pos_atoms'].shape[1]
171 | # Adjust the embeddings from LMs for CFormer
172 | if 'patch_idx' in input:
173 | patch_idx = input['patch_idx']
174 | else:
175 | patch_idx = None
176 |
177 | if strategy == 'separate':
178 | # input shape [N', L], where N' is flexible in every batch
179 | out_embedding, masks = segment_cat_pad(prot_embedding, prot_chains, prot_mask, na_embedding, na_chains, na_mask, max_len, patch_idx)
180 | assert out_embedding.shape[0] == input['size']
181 | else:
182 | out_embedding, masks = cat_pad(prot_embedding, prot_mask, na_embedding, na_mask, max_len, patch_idx)
183 | assert out_embedding.shape[0] == input['size']
184 |
185 | out_embedding = self.proj_cplx(out_embedding)
186 | key_padding_mask = ~masks
187 |
188 | if self.pooling == 'token':
189 | mask_special = torch.zeros((len(out_embedding), 1), device=out_embedding.device, dtype=key_padding_mask.dtype)
190 | cplx_embed = self.complex_embedding.repeat(len(out_embedding), 1, 1)
191 | prot_embed = self.prot_embedding.repeat(len(out_embedding), 1, 1)
192 | rna_embed = self.rna_embedding.repeat(len(out_embedding), 1, 1)
193 | out_embedding = torch.cat([cplx_embed, prot_embed, rna_embed, out_embedding], dim=1)
194 | key_padding_mask = torch.cat([mask_special, mask_special, mask_special, key_padding_mask], dim=1)
195 |
196 | output, _ = self.transformer(out_embedding, key_padding_mask=key_padding_mask, need_attn_weights=False)
197 |
198 | if self.pooling == 'token':
199 | complex_embedding = output[:, 0, :].squeeze(1)
200 | else:
201 | complex_embedding = (output * (~key_padding_mask).unsqueeze(-1)).sum(dim=1)
202 | if self.pooling == 'mean':
203 | # Prot_mask: [N, L]
204 | seq_mask_sum = (~key_padding_mask).sum(dim=1, keepdim=True)
205 | complex_embedding = complex_embedding / (seq_mask_sum + 1e-10)
206 |
207 | output = self.pred_head(complex_embedding)
208 | output = output.squeeze(1)
209 |
210 | return output
211 |
212 |
213 |
214 |
215 |
216 |
--------------------------------------------------------------------------------
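`segment_pool` above relies on `scatter_add_` to sum chain-level token embeddings into one row per complex before dividing by the token count. The snippet below is a self-contained sketch of the same idea (plain PyTorch, toy shapes), not the repo function itself.

import torch

def segment_mean(x, chains, mask):
    # x: (N', L, E) per-chain embeddings; mask: (N', L); chains: chains per complex, sum(chains) == N'
    E = x.shape[-1]
    seg = torch.repeat_interleave(torch.arange(len(chains)), torch.tensor(chains))  # complex id per chain
    seg = seg.repeat_interleave(x.shape[1])                                         # (N' * L,)
    flat = (x * mask.unsqueeze(-1)).reshape(-1, E)                                  # masked tokens, flattened
    out = torch.zeros(len(chains), E).scatter_add_(0, seg[:, None].expand(-1, E), flat)
    counts = torch.zeros(len(chains), 1).scatter_add_(0, seg[:, None], mask.reshape(-1, 1))
    return out / counts.clamp_min(1e-10)

x = torch.randn(3, 4, 8)                       # 3 chains in total (a 2-chain and a 1-chain complex)
mask = torch.ones(3, 4)
print(segment_mean(x, [2, 1], mask).shape)     # torch.Size([2, 8])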
/pl_modules/model_module.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import pandas as pd
8 | import pytorch_lightning as pl
9 | from models import ModelRegister
10 | from utils.metrics import ScalarMetricAccumulator, cal_pearson, cal_spearman, cal_rmse, cal_mae, get_loss
11 | def get_model(model_args:dict=None):
12 | register = ModelRegister()
13 | model_args_ori = {}
14 | model_args_ori.update(model_args)
15 | model_cls = register[model_args['model_type']]
16 | model = model_cls(**model_args_ori)
17 | return model
18 |
19 | class ModelModule(pl.LightningModule):
20 | def __init__(self, output_dir=None, model_args=None, data_args=None, run_args=None):
21 | super().__init__()
22 | self.save_hyperparameters()
23 | if model_args is None:
24 | model_args = {}
25 | if data_args is None:
26 | data_args = {}
27 | self.output_dir = output_dir
28 | if self.output_dir is not None:
29 | self.output_dir = Path(self.output_dir) / 'pred'
30 | self.output_dir.mkdir(parents=True, exist_ok=True)
31 | self.l_type = data_args.loss_type
32 | self.model = get_model(model_args=model_args.model)
33 | self.model_args = model_args
34 | self.data_args = data_args
35 | self.run_args = run_args
36 | self.optimizers_cfg = self.model_args.train.optimizer
37 | self.scheduler_cfg = self.model_args.train.scheduler
38 | self.valid_it = 0
39 | self.batch_size = data_args.batch_size
40 |
41 | self.train_loss = None
42 |
43 | def get_progress_bar_dict(self):
44 | tqdm_dict = super().get_progress_bar_dict()
45 | tqdm_dict.pop('v_num', None)
46 | return tqdm_dict
47 |
48 | def configure_optimizers(self):
49 | if self.optimizers_cfg.type == 'adam':
50 | optimizer = torch.optim.Adam(self.parameters(),
51 | lr=self.optimizers_cfg.lr,
52 | betas=(self.optimizers_cfg.beta1, self.optimizers_cfg.beta2, ))
53 | elif self.optimizers_cfg.type == 'sgd':
54 | optimizer = torch.optim.SGD(self.parameters(), lr=self.optimizers_cfg.lr)
55 | elif self.optimizers_cfg.type == 'rmsprop':
56 | optimizer = torch.optim.RMSprop(self.parameters(), lr=self.optimizers_cfg.lr)
57 | else:
58 | raise NotImplementedError('Optimizer not supported: %s' % self.optimizers_cfg.type)
59 |
60 | if self.scheduler_cfg.type == 'plateau':
61 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
62 | factor=self.scheduler_cfg.factor,
63 | patience=self.scheduler_cfg.patience,
64 | min_lr=self.scheduler_cfg.min_lr)
65 | elif self.scheduler_cfg.type == 'multistep':
66 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
67 | milestones=self.scheduler_cfg.milestones,
68 | gamma=self.scheduler_cfg.gamma)
69 | elif self.scheduler_cfg.type == 'exp':
70 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
71 | gamma=self.scheduler_cfg.gamma)
72 | else:
73 | raise NotImplementedError('Scheduler not supported: %s' % self.scheduler_cfg.type)
74 |
75 | if self.model_args.resume is not None:
76 |             print("Resuming from checkpoint: %s" % self.model_args.resume)
77 | ckpt = torch.load(self.model_args.resume, map_location=self.model_args.device)
78 | it_first = ckpt['iteration']
79 | lsd_result = self.model.load_state_dict(ckpt['state_dict'], strict=False)
80 | print('Missing keys (%d): %s' % (len(lsd_result.missing_keys), ', '.join(lsd_result.missing_keys)))
81 | print(
82 | 'Unexpected keys (%d): %s' % (len(lsd_result.unexpected_keys), ', '.join(lsd_result.unexpected_keys)))
83 |
84 | print('Resuming optimizer states...')
85 | optimizer.load_state_dict(ckpt['optimizer'])
86 | print('Resuming scheduler states...')
87 | scheduler.load_state_dict(ckpt['scheduler'])
88 |
89 | if self.scheduler_cfg.type == 'plateau':
90 | optim_dict = {
91 | "optimizer": optimizer,
92 | "lr_scheduler": {
93 | "scheduler": scheduler,
94 | "monitor": 'val_loss'
95 | }
96 | }
97 | else:
98 | optim_dict = {
99 | "optimizer": optimizer,
100 | "lr_scheduler": {
101 | "scheduler": scheduler,
102 | }
103 | }
104 | return optim_dict
105 |
106 | def on_train_start(self):
107 | log_hyperparams = {'model_args':self.model_args, 'data_args': self.data_args, 'run_args': self.run_args}
108 | self.logger.log_hyperparams(log_hyperparams)
109 |
110 | def on_before_optimizer_step(self, optimizer) -> None:
111 | pass
112 | # for name, param in self.named_parameters():
113 | # if param.grad is None:
114 | # print(name)
115 | # print("Found Unused Parameters")
116 |
117 | def training_step(self, batch, batch_idx):
118 | y = batch['labels']
119 | pred = self.model(batch, self.data_args.strategy)
120 | # print(y.shape, pred.shape)
121 | loss = get_loss(self.l_type, pred, y, reduction='mean')
122 | # if torch.isnan(loss).any():
123 | # print("Found nan in loss!", input)
124 | # exit()
125 | self.train_loss = loss.detach()
126 | self.log("train_loss", float(self.train_loss), batch_size=self.batch_size, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
127 | return loss
128 |
129 | def on_validation_epoch_start(self):
130 | self.scalar_accum = ScalarMetricAccumulator()
131 | self.results = []
132 |
133 | def validation_step(self, batch, batch_idx):
134 | y = batch['labels']
135 | pred = self.model(batch, self.data_args.strategy)
136 | val_loss = get_loss(self.l_type, pred, y, reduction='mean')
137 | self.scalar_accum.add(name='val_loss', value=val_loss, batchsize=self.batch_size, mode='mean')
138 | self.log("val_loss_step", val_loss, batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
139 |
140 | for y_true, y_pred in zip(batch['labels'], pred):
141 | result = {}
142 | result['y_true'] = y_true.item()
143 | result['y_pred'] = y_pred.item()
144 | self.results.append(result)
145 | return val_loss
146 |
147 | def on_validation_epoch_end(self):
148 | results = pd.DataFrame(self.results)
149 | # print("Validation:", results)
150 | if self.output_dir is not None:
151 | results.to_csv(os.path.join(self.output_dir, f'results_{self.valid_it}.csv'), index=False)
152 | y_pred = np.array(results[f'y_pred'])
153 | y_true = np.array(results[f'y_true'])
154 | pearson_all = np.abs(cal_pearson(y_pred, y_true))
155 | spearman_all = np.abs(cal_spearman(y_pred, y_true))
156 | rmse_all = cal_rmse(y_pred, y_true)
157 | mae_all = cal_mae(y_pred, y_true)
158 | print(f'[All_Task] Pearson {pearson_all:.6f} Spearman {spearman_all:.6f} RMSE {rmse_all:.6f} MAE {mae_all:.6f}')
159 |
160 | self.log(f'val/all_pearson', pearson_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
161 | self.log(f'val/all_spearman', spearman_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
162 | self.log(f'val/all_rmse', rmse_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
163 | self.log(f'val/all_mae', mae_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
164 |
165 | val_loss = rmse_all * rmse_all
166 | self.log('val_loss', val_loss, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
167 | # Trigger scheduler
168 | self.valid_it += 1
169 | return val_loss
170 |
171 | def on_test_epoch_start(self) -> None:
172 | self.results = []
173 | self.scalar_accum = ScalarMetricAccumulator()
174 |
175 | def test_step(self, batch, batch_idx):
176 | y = batch['labels']
177 | pred = self.model(batch, self.data_args.strategy)
178 | test_loss = get_loss(self.l_type, pred, y, reduction='mean')
179 | self.scalar_accum.add(name='loss', value = test_loss, batchsize=self.batch_size, mode='mean')
180 | self.log("test_loss_step", test_loss, batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
181 |
182 | for y_true, y_pred in zip(batch['labels'], pred):
183 | result = {}
184 | result['y_true'] = y_true.item()
185 | result['y_pred'] = y_pred.item()
186 | self.results.append(result)
187 | return test_loss
188 |
189 | def on_test_epoch_end(self):
190 | results = pd.DataFrame(self.results)
191 | if self.output_dir is not None:
192 | results.to_csv(os.path.join(self.output_dir, f'results_test.csv'), index=False)
193 | y_pred = np.array(results[f'y_pred'])
194 | y_true = np.array(results[f'y_true'])
195 | pearson_all = np.abs(cal_pearson(y_pred, y_true))
196 | spearman_all = np.abs(cal_spearman(y_pred, y_true))
197 | rmse_all = cal_rmse(y_pred, y_true)
198 | mae_all = cal_mae(y_pred, y_true)
199 | print(f'[All_Task] Pearson {pearson_all:.6f} Spearman {spearman_all:.6f} RMSE {rmse_all:.6f} MAE {mae_all:.6f}')
200 |
201 | self.log(f'test/all_pearson', pearson_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
202 | self.log(f'test/all_spearman', spearman_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
203 | self.log(f'test/all_rmse', rmse_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
204 | self.log(f'test/all_mae', mae_all, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
205 | self.res = {"pearson": pearson_all,"spearman": spearman_all, "rmse": rmse_all, "mae": mae_all}
206 | print("Self.Res:", self.res)
207 | # test_loss = self.scalar_accum.get_average('loss')
208 | test_loss = rmse_all * rmse_all
209 | self.log('test_loss', test_loss, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
210 |
211 | return test_loss
--------------------------------------------------------------------------------
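For reference, the epoch-end metrics logged above (absolute Pearson and Spearman correlation, RMSE, MAE, with `val_loss = RMSE**2`) can be reproduced offline from the saved results CSV. The sketch below uses `numpy` and `scipy.stats` as stand-ins for the repo's `utils.metrics` helpers, which are assumed to compute the same quantities.

import numpy as np
import pandas as pd
from scipy.stats import pearsonr, spearmanr

results = pd.DataFrame({
    "y_true": [1.2, -0.5, 0.3, 2.1],   # toy values; in practice, read results_{it}.csv
    "y_pred": [1.0, -0.2, 0.4, 1.8],
})
y_true = results["y_true"].to_numpy()
y_pred = results["y_pred"].to_numpy()

pearson, _ = pearsonr(y_pred, y_true)
spearman, _ = spearmanr(y_pred, y_true)
rmse = float(np.sqrt(np.mean((y_pred - y_true) ** 2)))
mae = float(np.mean(np.abs(y_pred - y_true)))
print(f"Pearson {abs(pearson):.4f} Spearman {abs(spearman):.4f} RMSE {rmse:.4f} MAE {mae:.4f}")
val_loss = rmse ** 2  # the quantity monitored by the plateau scheduler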
/models/components/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | import math
5 |
6 | from models.components.rope import RotaryPositionEmbedding
7 |
8 | from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func
9 | from flash_attn.layers.rotary import RotaryEmbedding
10 |
11 | from flash_attn.bert_padding import unpad_input, pad_input
12 | import torch
13 | from einops import rearrange
14 |
15 | def dot_product_attention(q, k, v, struct_embed, attn_mask=None, key_pad_mask=None, dropout=None):
16 | c = q.shape[-1]
17 | attn = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(c)
18 | # Add Structure into attn
19 | attn += struct_embed
20 | if torch.isnan(attn).any():
21 | print("Found nan 0!")
22 | if attn_mask is not None:
23 | if len(attn_mask.shape) < len(attn.shape):
24 | attn_mask = attn_mask[:, None, ...]
25 | attn = attn.masked_fill(attn_mask, float("-inf"))
26 | if torch.isnan(attn).any():
27 | print("Found nan 1!")
28 | # attn shape is [B, H, L, L]
29 | if key_pad_mask is not None:
30 | attn = attn.masked_fill(key_pad_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
31 | if torch.isnan(attn).any():
32 | print("Found nan 2!")
33 | attn = attn.softmax(dim=-1)
34 | if dropout is not None:
35 | attn = dropout(attn)
36 |
37 | output = torch.matmul(attn, v)
38 | return output, attn
39 |
40 | class MultiHeadAttention(nn.Module):
41 | def __init__(self, c_in, pair_dim, num_heads, attention_dropout=0.0, use_rot_emb=True, bias=False, c_squeeze=8):
42 | super().__init__()
43 | assert c_in % num_heads == 0, "Embedding dimensionality must be divisible with number of attention heads!"
44 |
45 | self.c_in = c_in
46 | self.num_heads = num_heads
47 |
48 | self.c_head = c_in // self.num_heads
49 | self.c_qkv = self.c_head * num_heads
50 |
51 | self.use_rot_emb = use_rot_emb
52 | if self.use_rot_emb:
53 | self.rotary_emb = RotaryPositionEmbedding(self.c_head)
54 |
55 | self.struct_bias = nn.Linear(pair_dim, self.num_heads, bias=bias)
56 | self.outer_linear = nn.Linear(c_squeeze * c_squeeze, pair_dim)
57 | self.outer_squeeze = nn.Linear(self.c_in, c_squeeze)
58 | self.to_q = nn.Linear(self.c_in, self.c_qkv, bias=bias)
59 | self.to_k = nn.Linear(self.c_in, self.c_qkv, bias=bias)
60 | self.to_v = nn.Linear(self.c_in, self.c_qkv, bias=bias)
61 |
62 | self.attention_dropout = nn.Dropout(p=attention_dropout)
63 |
64 | self.out_proj = nn.Linear(c_in, c_in, bias=bias)
65 | self.struct_out_proj = nn.Linear(pair_dim, pair_dim, bias=bias)
66 |
67 | def forward(self, q, k, v, struct_embed, attn_mask=None, key_pad_mask=None):
68 | bs = q.shape[0]
69 |
70 | q = self.to_q(q).view(bs, -1, self.num_heads, self.c_head).transpose(-2, -3)
71 | k = self.to_k(k).view(bs, -1, self.num_heads, self.c_head).transpose(-2, -3)
72 | v = self.to_v(v).view(bs, -1, self.num_heads, self.c_head).transpose(-2, -3)
73 |
74 | struct_attr = self.struct_bias(struct_embed).permute(0, 3, 1, 2) # (N, H, L, L)
75 | # print("Shape of Structure attribute:", struct_attr.shape)
76 | if self.use_rot_emb:
77 | q, k = self.rotary_emb(q, k)
78 |
79 | output, attn = dot_product_attention(q, k, v, struct_attr, attn_mask, key_pad_mask, self.attention_dropout)
80 |
81 | output = output.transpose(-2, -3).contiguous().view(bs, -1, self.num_heads * self.c_head)
82 |
83 | a = self.outer_squeeze(output)
84 | outer_product = torch.einsum('...bc, ...de -> ...bdce', a, a) # N, L, L, C, C
85 | struct_output = struct_embed + self.outer_linear(outer_product.reshape(list(outer_product.shape[:-2]) + [-1]))
86 |
87 | output = self.out_proj(output)
88 | struct_output = self.struct_out_proj(struct_output)
89 |
90 | return output, struct_output, attn
91 |
92 | class MultiHeadSelfAttention(nn.Module):
93 | def __init__(self, c_in, pair_dim, num_heads, attention_dropout=0.0, use_rot_emb=True, bias=False):
94 | super().__init__()
95 |
96 | self.mh_attn = MultiHeadAttention(c_in, pair_dim, num_heads, attention_dropout, use_rot_emb, bias)
97 |
98 | def forward(self, x, struct_embed, attn_mask=None, key_pad_mask=None):
99 | return self.mh_attn(x, x, x, struct_embed, attn_mask, key_pad_mask)
100 |
101 | class FlashAttention(nn.Module):
102 | """
103 | Implement the scaled dot product attention with softmax.
104 | """
105 | def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
106 | """
107 | Args:
108 | causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
109 | softmax_scale: float. The scaling of QK^T before applying softmax.
110 | Default to 1 / sqrt(headdim).
111 | attention_dropout: float. The dropout rate to apply to the attention
112 | (default: 0.0)
113 | """
114 | super().__init__()
115 | self.softmax_scale = softmax_scale
116 | self.attention_dropout = attention_dropout
117 | self.causal = causal
118 |
119 | def forward(self, qkv, cu_seqlens=None, max_seqlen=None, return_attn_probs=False):
120 | """
121 | Arguments
122 | ---------
123 | qkv: The tensor containing the query, key, and value.
124 | If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
125 | If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
126 | (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
127 | cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
128 | of the sequences in the batch, used to index into qkv.
129 | max_seqlen: int. Maximum sequence length in the batch.
130 | return_attn_probs: bool. Whether to return the attention probabilities. This option is for
131 | testing only. The returned probabilities are not guaranteed to be correct
132 | (they might not have the right scaling).
133 | Returns:
134 | --------
135 | out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
136 | else (B, S, H, D).
137 | """
138 | assert qkv.dtype in [torch.float16, torch.bfloat16]
139 | assert qkv.is_cuda
140 |
141 | unpadded = cu_seqlens is not None
142 |
143 | if unpadded:
144 | assert cu_seqlens.dtype == torch.int32
145 | assert max_seqlen is not None
146 | assert isinstance(max_seqlen, int)
147 | return flash_attn_varlen_qkvpacked_func(
148 | qkv,
149 | cu_seqlens,
150 | max_seqlen,
151 | self.attention_dropout if self.training else 0.0,
152 | softmax_scale=self.softmax_scale,
153 | causal=self.causal,
154 | return_attn_probs=return_attn_probs
155 | )
156 | else:
157 | return flash_attn_qkvpacked_func(
158 | qkv,
159 | self.attention_dropout if self.training else 0.0,
160 | softmax_scale=self.softmax_scale,
161 | causal=self.causal,
162 | return_attn_probs=return_attn_probs
163 | )
164 |
165 | class FlashMultiHeadSelfAttention(nn.Module):
166 | """
167 | Multi-head self-attention implemented using FlashAttention.
168 | """
169 | def __init__(self, embed_dim, num_heads, attention_dropout=0.0, causal=False, use_rot_emb=True, bias=False):
170 | super().__init__()
171 | assert embed_dim % num_heads == 0, "Embedding dimensionality must be divisible with number of attention heads!"
172 |
173 | self.causal = causal
174 |
175 | self.embed_dim = embed_dim
176 | self.num_heads = num_heads
177 |
178 | self.head_dim = self.embed_dim // self.num_heads
179 | self.qkv_dim = self.head_dim * num_heads * 3
180 |
181 | self.rotary_emb_dim = self.head_dim
182 | self.use_rot_emb = use_rot_emb
183 | if self.use_rot_emb:
184 | self.rotary_emb = RotaryEmbedding(
185 | dim=self.rotary_emb_dim,
186 | base=10000.0,
187 | interleaved=False,
188 | scale_base=None,
189 | pos_idx_in_fp32=True, # fp32 RoPE precision
190 | device=None
191 | )
192 | self.flash_self_attn = FlashAttention(causal=self.causal, softmax_scale=None, attention_dropout=attention_dropout)
193 |
194 | self.Wqkv = nn.Linear(self.embed_dim, self.qkv_dim, bias=bias)
195 |
196 | self.attention_dropout = nn.Dropout(p=attention_dropout)
197 |
198 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
199 |
200 | def forward(self, x, key_padding_mask=None, return_attn_probs=False):
201 | """
202 | Arguments:
203 | x: (batch, seqlen, hidden_dim) (where hidden_dim = num_heads * head_dim)
204 |             key_padding_mask: boolean mask, True means to keep, False means to mask out.
205 | (batch, seqlen)
206 | return_attn_probs: whether to return attention masks (False by default)
207 | """
208 |
209 | qkv = self.Wqkv(x)
210 | qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
211 |
212 | if self.use_rot_emb:
213 | qkv = self.rotary_emb(qkv, seqlen_offset=0)
214 |
215 | if return_attn_probs:
216 | bs = qkv.shape[0]
217 | qkv = torch.permute(qkv, (0, 3, 2, 1, 4))
218 | q = qkv[:, :, 0, :, :]
219 | k = qkv[:, :, 1, :, :]
220 | v = qkv[:, :, 2, :, :]
221 |             out, attn = dot_product_attention(q, k, v, 0.0, key_pad_mask=torch.logical_not(key_padding_mask) if key_padding_mask is not None else None, dropout=self.attention_dropout) # struct_embed is required by this file's signature; 0.0 applies no pair bias
222 | output = out.transpose(-2, -3).contiguous().view(bs, -1, self.num_heads * self.head_dim)
223 | output = self.out_proj(output)
224 | return output, attn
225 |
226 | if key_padding_mask is not None:
227 | batch_size = qkv.shape[0]
228 | seqlen = qkv.shape[1]
229 | x_unpad, indices, cu_seqlens, max_s = unpad_input(qkv, key_padding_mask)
230 | output_unpad = self.flash_self_attn(
231 | x_unpad,
232 | cu_seqlens=cu_seqlens,
233 | max_seqlen=max_s,
234 | return_attn_probs=return_attn_probs
235 | )
236 | out = pad_input(rearrange(output_unpad, '... h d -> ... (h d)'), indices, batch_size, seqlen)
237 | else:
238 | output = self.flash_self_attn(
239 | qkv,
240 | cu_seqlens=None,
241 | max_seqlen=None,
242 | return_attn_probs=return_attn_probs
243 | )
244 | out = rearrange(output, '... h d -> ... (h d)')
245 |
246 | out = self.out_proj(out)
247 | return out, None
248 |
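
A minimal usage sketch for FlashMultiHeadSelfAttention (not part of the repository; it assumes the flash-attn package and a CUDA GPU are available, and the import path and dimensions embed_dim=640, num_heads=20 are illustrative). The FlashAttention kernels above require half-precision tensors on a CUDA device, so the module and its inputs are cast accordingly.

import torch
# Path assumed from the repository layout (models/components/attention.py).
from models.components.attention import FlashMultiHeadSelfAttention

embed_dim, num_heads = 640, 20  # head_dim = 32; values chosen for illustration only
attn = FlashMultiHeadSelfAttention(embed_dim, num_heads, attention_dropout=0.1).cuda().half()

x = torch.randn(2, 128, embed_dim, device='cuda', dtype=torch.float16)
# True marks real tokens, False marks padding (see the forward() docstring above).
key_padding_mask = torch.ones(2, 128, dtype=torch.bool, device='cuda')
key_padding_mask[1, 100:] = False

out, _ = attn(x, key_padding_mask=key_padding_mask)  # out: (2, 128, embed_dim)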
--------------------------------------------------------------------------------
/data/transforms/patch.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 | from ._base import _index_select_data, register_transform, _get_CB_positions, _index_select_complex
5 | import math
6 |
7 | @register_transform('focused_random_patch')
8 | class FocusedRandomPatch(object):
9 |
10 | def __init__(self, focus_attr, seed_nbh_size=32, patch_size=128):
11 | super().__init__()
12 | self.focus_attr = focus_attr
13 | self.seed_nbh_size = seed_nbh_size
14 | self.patch_size = patch_size
15 |
16 | def __call__(self, data):
17 | focus_flag = (data[self.focus_attr] > 0) # (L, )
18 | if focus_flag.sum() == 0:
19 |             # If there are no active residues, randomly pick one.
20 | focus_flag[random.randint(0, focus_flag.size(0) - 1)] = True
21 | seed_idx = torch.multinomial(focus_flag.float(), num_samples=1).item()
22 |
23 |         pos_CB = _get_CB_positions(data['pos_atoms'], data['mask_atoms'])  # (L, 3)
24 |         pos_seed = pos_CB[seed_idx:seed_idx + 1]  # (1, 3)
25 | dist_from_seed = torch.cdist(pos_CB, pos_seed)[:, 0] # (L, 1) -> (L, )
26 | nbh_seed_idx = dist_from_seed.argsort()[:self.seed_nbh_size] # (Nb, )
27 |
28 | core_idx = nbh_seed_idx[focus_flag[nbh_seed_idx]] # (Ac, ), the core-set must be a subset of the focus-set
29 | dist_from_core = torch.cdist(pos_CB, pos_CB[core_idx]).min(dim=1)[0] # (L, )
30 |         patch_idx = dist_from_core.argsort()[:self.patch_size]  # (P, ); a residue's distance to itself is zero, so the core residues are always included.
31 | patch_idx = patch_idx.sort()[0]
32 |
33 | core_flag = torch.zeros([data['aa'].size(0), ], dtype=torch.bool)
34 | core_flag[core_idx] = True
35 | data['core_flag'] = core_flag
36 |
37 | data_patch = _index_select_data(data, patch_idx)
38 | return data_patch
39 |
40 |
41 | @register_transform('random_patch')
42 | class RandomPatch(object):
43 |
44 | def __init__(self, seed_nbh_size=32, patch_size=128):
45 | super().__init__()
46 | self.seed_nbh_size = seed_nbh_size
47 | self.patch_size = patch_size
48 |
49 | def __call__(self, data):
50 | seed_idx = random.randint(0, data['aa'].size(0) - 1)
51 |
52 |         pos_CB = _get_CB_positions(data['pos_atoms'], data['mask_atoms'])  # (L, 3)
53 |         pos_seed = pos_CB[seed_idx:seed_idx + 1]  # (1, 3)
54 | dist_from_seed = torch.cdist(pos_CB, pos_seed)[:, 0] # (L, 1) -> (L, )
55 | core_idx = dist_from_seed.argsort()[:self.seed_nbh_size] # (Nb, )
56 |
57 | dist_from_core = torch.cdist(pos_CB, pos_CB[core_idx]).min(dim=1)[0] # (L, )
58 | patch_idx = dist_from_core.argsort()[:self.patch_size] # (P, )
59 | patch_idx = patch_idx.sort()[0]
60 |
61 | core_flag = torch.zeros([data['aa'].size(0), ], dtype=torch.bool)
62 | core_flag[core_idx] = True
63 | data['core_flag'] = core_flag
64 |
65 | data_patch = _index_select_data(data, patch_idx)
66 | return data_patch
67 |
68 |
69 | @register_transform('selected_region_with_padding_patch')
70 | class SelectedRegionWithPaddingPatch(object):
71 |
72 | def __init__(self, select_attr, each_residue_nbh_size, patch_size_limit):
73 | super().__init__()
74 | self.select_attr = select_attr
75 | self.each_residue_nbh_size = each_residue_nbh_size
76 | self.patch_size_limit = patch_size_limit
77 |
78 | def __call__(self, data):
79 | select_flag = (data[self.select_attr] > 0)
80 |
81 | pos_CB = _get_CB_positions(data['pos_atoms'], data['mask_atoms']) # (L, 3)
82 | pos_sel = pos_CB[select_flag] # (S, 3)
83 | dist_from_sel = torch.cdist(pos_CB, pos_sel) # (L, S)
84 | nbh_sel_idx = torch.argsort(dist_from_sel, dim=0)[:self.each_residue_nbh_size, :] # (nbh, S)
85 | patch_idx = nbh_sel_idx.view(-1).unique() # (patchsize,)
86 |
87 | data_patch = _index_select_data(data, patch_idx)
88 | return data_patch
89 |
90 |
91 | @register_transform('selected_region_fixed_size_patch')
92 | class SelectedRegionFixedSizePatch(object):
93 |
94 | def __init__(self, select_attr, patch_size):
95 | super().__init__()
96 | self.select_attr = select_attr
97 | self.patch_size = patch_size
98 |
99 | def __call__(self, data):
100 | select_flag = (data[self.select_attr] > 0)
101 |
102 | pos_CB = _get_CB_positions(data['pos_atoms'], data['mask_atoms']) # (L, 3)
103 | pos_sel = pos_CB[select_flag] # (S, 3)
104 | # print("Pos CB and sel:", pos_sel.shape, pos_CB.shape, select_flag.shape)
105 | dist_from_sel = torch.cdist(pos_CB, pos_sel).min(dim=1)[0] # (L, )
106 | # print(self.patch_size)
107 | patch_idx = torch.argsort(dist_from_sel)[:self.patch_size]
108 |
109 | data_patch = _index_select_data(data, patch_idx)
110 | return data_patch
111 |
112 | @register_transform('selected_region_with_distmap')
113 | class SelectedRegionWithDistmap(object):
114 |
115 | def __init__(self, patch_size):
116 | super().__init__()
117 | self.patch_size = patch_size
118 |
119 | def __call__(self, data):
120 | atoms_dist_min = data['atom_min_dist']
121 |
122 | identifier = data['identifier']
123 | tmp = atoms_dist_min[identifier==0]
124 | interface_distance = tmp[:, identifier==1]
125 | prot_min_dist = interface_distance.min(dim=1)[0]
126 | rna_min_dist = interface_distance.transpose(0, 1).min(dim=1)[0]
127 | total_min = torch.cat([prot_min_dist, rna_min_dist], dim=0)
128 | patch_idx = torch.argsort(total_min)[:self.patch_size]
129 | patch_idx, _ = torch.sort(patch_idx)
130 | # print(self.patch_size)
131 | data_patch = _index_select_complex(data, patch_idx)
132 | data_patch['patch_idx'] = patch_idx
133 | return data_patch
134 |
135 | @register_transform('interface_max_size_patch')
136 | class InterfaceFixedMaxSizePatch(object):
137 | def __init__(self, max_size, interface_dist=7.5, kernel_size=100):
138 | super().__init__()
139 | self.max_size = max_size
140 | self.interface_dist = interface_dist
141 | self.kernel_size = kernel_size
142 |
143 | def segment_fill_zero(self, mask, chain_nb):
144 | unique_segs = torch.unique(chain_nb)
145 | result = mask.clone()
146 | can_fills = []
147 | for seg in unique_segs:
148 | seg_mask = chain_nb == seg
149 | seg_x = mask[seg_mask]
150 |
151 | ones_mask = seg_x == 1
152 | seg_result = ones_mask.clone()
153 | # cumsum = torch.cumsum(ones_mask, dim=0)
154 | indices = torch.nonzero(ones_mask).flatten()
155 | start = indices.min()
156 | end = indices.max() + 1
157 | seg_result[start: end] = True
158 | if len(seg_result) <= 30:
159 | seg_result[:] = True
160 | # seg_result = (cumsum > 0).bool()
161 | can_fill = (~seg_result).sum()
162 | can_fills.append(can_fill.item())
163 | result[seg_mask] = seg_result
164 | return result, can_fills
165 |
166 | def segment_fill_to_max(self, mask, chain_nb, length_to_fill, can_fill, prot_seqs, na_seqs):
167 | unique_segs = torch.unique(chain_nb)
168 | result = mask.clone()
169 | # assert sum(can_fill) >= length_to_fill
170 | remainders = length_to_fill
171 | prot_seqs_new = []
172 | na_seqs_new = []
173 | for i, seg in enumerate(unique_segs):
174 |             to_fill = math.ceil((can_fill[i] / (sum(can_fill) + 1e-5)) * length_to_fill)
175 | if remainders < to_fill:
176 | to_fill = remainders
177 | remainders -= to_fill
178 | seg_mask = chain_nb == seg
179 | seg_x = mask[seg_mask]
180 |
181 | indices = torch.nonzero(seg_x).flatten()
182 |             if indices.numel() == 0:
183 |                 # No selected residue in this segment; there is nothing to anchor the fill region on.
184 |                 raise ValueError(f"segment_fill_to_max: chain segment {int(seg)} has no selected residues")
185 |             left_len = indices.min()
186 |             right_len = len(seg_x) - indices.max() - 1
187 |
188 | if to_fill >= left_len + right_len:
189 | seg_x[:] = 1
190 | start = 0
191 | end = indices.max() + 1
192 | else:
193 | left_fill = random.randint(max(to_fill-right_len, 0), min(left_len,to_fill))
194 | right_fill = to_fill - left_fill
195 | assert left_fill + right_fill == to_fill
196 | start = indices.min() - left_fill
197 | end = indices.max() + right_fill + 1
198 | seg_x[start: end] = 1
199 |
200 | result[seg_mask] = seg_x
201 |
202 | if i < len(prot_seqs):
203 | prot_seq_new = prot_seqs[i][start: end]
204 | prot_seqs_new.append(prot_seq_new)
205 | else:
206 | na_seq_new = na_seqs[i-len(prot_seqs)][start: end]
207 | na_seqs_new.append(na_seq_new)
208 | return result, prot_seqs_new, na_seqs_new
209 |
210 | def __call__(self, data):
211 | if len(data['pos_atoms']) <= self.max_size:
212 | # print("Short data, no need to process!")
213 | return data
214 | pos_CB = _get_CB_positions(data['pos_atoms'], data['mask_atoms'])
215 | identifier = data['identifier']
216 | chain_nb = data['chain_nb']
217 | dist_map = torch.cdist(pos_CB, pos_CB)
218 | dist_map[identifier[:, None] == identifier[None, :]] = 10000
219 | # print("Original data:", data['id'], len(''.join(data['prot_seqs'])), len(''.join(data['rna_seqs'])))
220 | contact_dis_min = torch.min(dist_map, dim=-1)[0]
221 | kernel_area = torch.zeros_like(contact_dis_min).bool()
222 | # kernel_area = contact_dis_min <= self.interface_dist
223 | contact_indices = torch.argsort(contact_dis_min)
224 | kernel_indices = contact_indices[: self.kernel_size]
225 | kernel_area.index_fill_(dim=0, index=kernel_indices, value=True)
226 | continuous_kernel_area, can_fill = self.segment_fill_zero(kernel_area, chain_nb)
227 | # if torch.sum(continuous_kernel_area) > self.max_size:
228 | # print("Max size:", torch.sum(continuous_kernel_area))
229 | to_fill = max((self.max_size - torch.sum(continuous_kernel_area)).item(), 0)
230 | continuous_kernel_area, prot_seq_new, na_seq_new = self.segment_fill_to_max(continuous_kernel_area, chain_nb,
231 | to_fill, can_fill,
232 | data['prot_seqs'], data['rna_seqs'])
233 | select_idx = torch.nonzero(continuous_kernel_area).flatten()
234 | # print("New selected data:", len(select_idx))
235 | data = _index_select_complex(data, select_idx)
236 | data['prot_seqs'] = prot_seq_new
237 | data['rna_seqs'] = na_seq_new
238 | prot_lengths = [len(item) for item in prot_seq_new]
239 | na_lengths = [len(item) for item in na_seq_new]
240 | data['max_prot_length'] = max(prot_lengths)
241 | data['max_na_length'] = max(na_lengths)
242 | return data
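
A standalone sketch (synthetic tensors, not repository data) of the ranking performed by SelectedRegionWithDistmap above: every residue is scored by its minimum heavy-atom distance to the other molecule, and the patch_size closest positions across both chains are kept as the interface patch.

import torch

L, patch_size = 10, 6
atom_min_dist = torch.rand(L, L) * 30.0                     # residue-residue minimum atom distances (Angstrom)
identifier = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])   # 0 = protein residue, 1 = RNA nucleotide

prot_rows = atom_min_dist[identifier == 0]                  # (L_prot, L)
interface = prot_rows[:, identifier == 1]                   # (L_prot, L_rna) cross-molecule block
prot_min = interface.min(dim=1).values                      # closest RNA distance per protein residue
rna_min = interface.transpose(0, 1).min(dim=1).values       # closest protein distance per nucleotide

total_min = torch.cat([prot_min, rna_min])                  # scores in protein-then-RNA order
patch_idx = torch.argsort(total_min)[:patch_size].sort().values
print(patch_idx)                                            # indices of the residues kept in the patch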
--------------------------------------------------------------------------------
/data/pri30k_dataset.py:
--------------------------------------------------------------------------------
1 | from data.register import DataRegister
2 | from torch.utils.data import Dataset
3 | import pandas as pd
4 | import esm
5 | import torch
6 | from rinalmo.data.constants import *
7 | from rinalmo.data.alphabet import Alphabet
8 | from tqdm import tqdm
9 | import diskcache
10 | import os
11 | import math
12 | import time
13 | from data.transforms import get_transform
14 | from torch.utils.data._utils.collate import default_collate
15 | from typing import Optional, Dict
16 | from easydict import EasyDict
17 | from data.structure_dataset import _process_structure
18 |
19 | na_alphabet_config = {
20 | "standard_tkns": RNA_TOKENS,
21 | "special_tkns": [CLS_TKN, PAD_TKN, EOS_TKN, UNK_TKN, MASK_TKN],
22 | }
23 |
24 | R = DataRegister()
25 | # ATOM_N, ATOM_CA, ATOM_C, ATOM_O, ATOM_CB = 0, 1, 2, 3, 4
26 | # ATOM_P, ATOM_C4, ATOM_NB = 37, 38,
27 | @R.register('pri30k_dataset')
28 | class PRI30kDataset(Dataset):
29 | '''
30 |     The implementation of the PRI30k protein-RNA complex structure dataset.
31 | '''
32 | def __init__(self,
33 | dataframe,
34 | data_root,
35 | col_prot_name='PDB',
36 | col_prot_chain='Protein chains',
37 | col_na_chain='RNA chains',
38 | col_binding_site='Binding site renumbered merged',
39 | col_ligand='Binding ligands',
40 | diskcache=None,
41 | transform=None,
42 | **kwargs
43 | ):
44 | self.data_root = data_root
45 | self.df: pd.DataFrame = dataframe.copy()
46 | self.df.reset_index(drop=True, inplace=True)
47 | self.col_prot_name = col_prot_name
48 | self.col_prot_chain = col_prot_chain
49 | self.col_na_chain = col_na_chain
50 | self.col_binding_site = col_binding_site
51 | self.col_ligand = col_ligand
52 | self.diskcache = diskcache
53 | self.prot_alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
54 | self.na_alphabet = Alphabet(**na_alphabet_config)
55 |
56 | self.transform = get_transform(transform)
57 |
58 | # self.load_data()
59 |
60 | def load_data(self, idx):
61 | row = self.df.loc[idx]
62 | structure_id = row[self.col_prot_name]
63 | prot_chains = [row[self.col_prot_chain]]
64 | na_chains = [row[self.col_na_chain]]
65 | structure_id = structure_id + '_' + prot_chains[0] + '_' + na_chains[0]
66 | if self.diskcache is None or structure_id not in self.diskcache:
67 | structure_name = structure_id + '.cif'
68 | pdb_path = os.path.join(self.data_root, structure_name)
69 |
70 | cplx = _process_structure(pdb_path, structure_id, prot_chains, na_chains, gpu='cuda:0')
71 |
72 | ligand_id = row[self.col_prot_name] + '_' + row[self.col_na_chain]
73 | L = len(cplx['seq'])
74 | gpu_atoms = cplx['pos_heavyatom']
75 | gpu_masks = cplx['mask_heavyatom']
76 | distance_map = torch.linalg.norm(gpu_atoms[:, None, :, None, :]- gpu_atoms[None, :, None, :, :], dim=-1, ord=2).reshape(L, L, -1)
77 | mask = (gpu_masks[:, None, :, None] * gpu_masks[None, :, None, :]).reshape(L, L, -1)
78 | distance_map[~mask] = torch.inf
79 | atom_min_dist = torch.min(distance_map, dim=-1)[0]
80 |
81 | max_prot_length = 0
82 | max_na_length = 0
83 | for prot_seq in cplx.prot_seqs:
84 | if len(prot_seq) > max_prot_length:
85 | max_prot_length = len(prot_seq)
86 | for na_seq in cplx.rna_seqs:
87 | if len(na_seq) > max_na_length:
88 | max_na_length = len(na_seq)
89 |
90 | item = {
91 | 'ligand_id': ligand_id, # no need to pad
92 | 'atom_min_dist': atom_min_dist, # needs 2D padding
93 | 'max_prot_length': max_prot_length, # will be ignored in batching
94 | 'max_na_length': max_na_length, # will be ignored in batching
95 | 'can_bind': row[self.col_ligand] # no need to pad
96 | }
97 |
98 | cplx.update(item)
99 | if self.diskcache is not None:
100 | for key in cplx:
101 | if isinstance(cplx[key], torch.Tensor):
102 | cplx[key] = cplx[key].detach().cpu()
103 | self.diskcache[structure_id] = cplx
104 | return cplx
105 | else:
106 | return self.diskcache[structure_id]
107 |
108 | def __len__(self):
109 | return len(self.df)
110 |
111 | def __getitem__(self, idx):
112 | data = self.load_data(idx)
113 | if self.transform is not None:
114 | data = self.transform(data)
115 | return data
116 |
117 | EXCLUDE_KEYS = []
118 | DEFAULT_PAD_VALUES = {
119 | 'restype': 26,
120 | 'mask_atoms': 0,
121 | 'chain_nb': -1,
122 | 'identifier': -1
123 | }
124 |
125 | class PRI30kStructCollate(object):
126 | def __init__(self, strategy='separate', length_ref_key='restype', pad_values=DEFAULT_PAD_VALUES, exclude_keys=EXCLUDE_KEYS, eight=True):
127 | super().__init__()
128 | self.strategy = strategy
129 | self.length_ref_key = length_ref_key
130 | self.pad_values = pad_values
131 | self.exclude_keys = exclude_keys
132 | self.eight = eight
133 |
134 | @staticmethod
135 | def _pad_last(x, n, value=0):
136 | if isinstance(x, torch.Tensor):
137 | assert x.size(0) <= n
138 | if x.size(0) == n:
139 | return x
140 | pad_size = [n - x.size(0)] + list(x.shape[1:])
141 | pad = torch.full(pad_size, fill_value=value).to(x)
142 | return torch.cat([x, pad], dim=0)
143 | elif isinstance(x, list):
144 | pad = [value] * (n - len(x))
145 | return x + pad
146 | else:
147 | return x
148 |
149 | @staticmethod
150 | def _get_pad_mask(l, n):
151 | return torch.cat([
152 | torch.ones([l], dtype=torch.bool),
153 | torch.zeros([n - l], dtype=torch.bool)
154 | ], dim=0)
155 |
156 | @staticmethod
157 | def _get_common_keys(list_of_dict):
158 | keys = set(list_of_dict[0].keys())
159 | for d in list_of_dict[1:]:
160 | keys = keys.intersection(d.keys())
161 | return keys
162 |
163 | def _get_pad_value(self, key):
164 | if key not in self.pad_values:
165 | return 0
166 | return self.pad_values[key]
167 |
168 | def pad_2d(self, x, n, value=10000):
169 | assert isinstance(x, torch.Tensor)
170 | assert x.shape[0] == x.shape[1]
171 | if x.size(0) == n and x.size(1) == n:
172 | return x
173 | pad_size_1 = [n - x.size(0)] + list(x.shape[1:])
174 | pad = torch.full(pad_size_1, fill_value=value).to(x)
175 | x_padded_1 = torch.cat([x, pad], dim=0)
176 | pad_size_2 = [x_padded_1.shape[0]] + [n - x_padded_1.size(1)] + list(x_padded_1.shape[2:])
177 | pad = torch.full(pad_size_2, fill_value=value).to(x)
178 | x_padded = torch.cat([x_padded_1, pad], dim=1)
179 | return x_padded
180 |
181 |
182 |
183 | def collate_complex(self, data_list):
184 | max_length = max([data[self.length_ref_key].size(0) for data in data_list])
185 | keys_inter = self._get_common_keys(data_list)
186 | keys = []
187 | keys_not_pad = []
188 | keys_ignore = ['prot_seqs', 'rna_seqs', 'max_prot_length', 'max_na_length', 'can_bind', 'ligand_id', 'atom_min_dist']
189 | pad_2d = ['atom_min_dist']
190 | for key in keys_inter:
191 | if key in keys_ignore:
192 | continue
193 | elif key not in self.exclude_keys:
194 | keys.append(key)
195 | else:
196 | keys_not_pad.append(key)
197 |
198 | if self.eight:
199 | max_length = math.ceil(max_length / 8) * 8
200 | data_list_padded = []
201 |
202 | for data in data_list:
203 | data_padded = {
204 | k: self._pad_last(v, max_length, value=self._get_pad_value(k))
205 | for k, v in data.items()
206 | if k in keys
207 | }
208 |
209 | for k in pad_2d:
210 | data_padded[k] = self.pad_2d(data[k], max_length)
211 |
212 | for k in keys_not_pad:
213 | data_padded[k] = data[k]
214 |
215 | data_padded['mask'] = self._get_pad_mask(data[self.length_ref_key].size(0), max_length)
216 | data_list_padded.append(data_padded)
217 | return data_list_padded
218 |
219 | def pad_for_berts(self, batch):
220 | prot_alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
221 | na_alphabet = Alphabet(**na_alphabet_config)
222 | prot_chains = [len(item['prot_seqs']) for item in batch]
223 | na_chains = [len(item['rna_seqs']) for item in batch]
224 | max_item_prot_length = [item['max_prot_length'] for item in batch]
225 | max_item_na_length = [item['max_na_length'] for item in batch]
226 | max_prot_length = max(max_item_prot_length)
227 | max_na_length = max(max_item_na_length)
228 | total_prot_chains = sum(prot_chains)
229 | total_na_chains = sum(na_chains)
230 | if self.eight:
231 | max_prot_length = math.ceil((max_prot_length + 2) / 8) * 8
232 | max_na_length = math.ceil((max_na_length + 2) / 8) * 8
233 | else:
234 | max_prot_length = max_prot_length + 2
235 | max_na_length = max_na_length + 2
236 | prot_batch = torch.empty([total_prot_chains, max_prot_length])
237 | prot_batch.fill_(prot_alphabet.padding_idx)
238 | na_batch = torch.empty([total_na_chains, max_na_length])
239 | na_batch.fill_(na_alphabet.pad_idx)
240 | curr_prot_idx = 0
241 | curr_na_idx = 0
242 | for item in batch:
243 | prot_seqs = item['prot_seqs']
244 | na_seqs = item['rna_seqs']
245 | for prot_seq in prot_seqs:
246 | prot_batch[curr_prot_idx, 0] = prot_alphabet.cls_idx
247 | prot_seq_encode = prot_alphabet.encode(prot_seq)
248 | seq = torch.tensor(prot_seq_encode, dtype=torch.int64)
249 | prot_batch[curr_prot_idx, 1: len(prot_seq_encode)+1] = seq
250 | prot_batch[curr_prot_idx, len(prot_seq_encode)+1] = prot_alphabet.eos_idx
251 | curr_prot_idx += 1
252 | for na_seq in na_seqs:
253 | na_seq_encode = na_alphabet.encode(na_seq)
254 | seq = torch.tensor(na_seq_encode, dtype=torch.int64)
255 | na_batch[curr_na_idx, :len(seq)] = seq
256 | curr_na_idx += 1
257 |
258 | prot_mask = torch.zeros_like(prot_batch)
259 | na_mask = torch.zeros_like(na_batch)
260 | prot_mask[(prot_batch!=prot_alphabet.padding_idx) & (prot_batch!=prot_alphabet.eos_idx) & (prot_batch!=prot_alphabet.cls_idx)] = 1
261 | na_mask[(na_batch!=na_alphabet.pad_idx) & (na_batch!=na_alphabet.eos_idx) & (na_batch!=na_alphabet.cls_idx)] = 1
262 | return prot_batch.long(), prot_chains, prot_mask, na_batch.long(), na_chains, na_mask
263 |
264 | def gen_clip_label(self, ligand_ids, can_bind_info):
265 | clip_label = torch.eye(len(ligand_ids))
266 | for i, can_bind in enumerate(can_bind_info):
267 | candidates = can_bind.split(',')
268 | for j, ligand_id in enumerate(ligand_ids):
269 | if ligand_id in candidates:
270 | clip_label[i, j] = 1
271 | return clip_label.long()
272 |
273 |
274 | def __call__(self, data_list):
275 | data_list_padded = self.collate_complex(data_list)
276 | batch = default_collate(data_list_padded)
277 | batch['size'] = len(data_list_padded)
278 | prot_batch, prot_chains, prot_mask, na_batch, na_chains, na_mask = self.pad_for_berts(data_list)
279 | batch['prot'] = prot_batch
280 | batch['prot_chains'] = prot_chains
281 | batch['protein_mask'] = prot_mask
282 | batch['na'] = na_batch
283 | batch['na_chains'] = na_chains
284 | batch['na_mask'] = na_mask
285 | batch['strategy'] = self.strategy
286 | ligand_ids = [item['ligand_id'] for item in data_list]
287 | can_bind_info = [item['can_bind'] for item in data_list]
288 | batch['clip_label'] = self.gen_clip_label(ligand_ids, can_bind_info)
289 | return batch
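
A hypothetical wiring sketch showing how the dataset and collate above could be combined; the CSV index path, data_root, and batch size are illustrative assumptions, not values taken from the repository configs, and building items requires the structure files plus a CUDA device (load_data hard-codes 'cuda:0'). The collate pads per-residue features to a shared length, pads atom_min_dist in two dimensions, and builds separate padded token batches for the protein and RNA language models.

import pandas as pd
from torch.utils.data import DataLoader
from data.pri30k_dataset import PRI30kDataset, PRI30kStructCollate

df = pd.read_csv('pri30k_index.csv')                 # assumed index file with the columns named in __init__
dataset = PRI30kDataset(df, data_root='data/pri30k_structures', transform=None)
collate = PRI30kStructCollate(strategy='separate', eight=True)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate)

batch = next(iter(loader))
print(batch['prot'].shape, batch['na'].shape, batch['atom_min_dist'].shape)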
--------------------------------------------------------------------------------
/pl_modules/pretune_module.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import pandas as pd
8 | import pytorch_lightning as pl
9 | from models import ModelRegister
10 | from utils.metrics import ScalarMetricAccumulator
11 |
12 | def get_model(model_args:dict=None):
13 | register = ModelRegister()
14 | model_args_ori = {}
15 | model_args_ori.update(model_args)
16 | model_cls = register[model_args['model_type']]
17 | model = model_cls(**model_args_ori)
18 | return model
19 |
20 | class PretuneModule(pl.LightningModule):
21 | def __init__(self, output_dir=None, model_args=None, data_args=None, run_args=None):
22 | super().__init__()
23 | self.save_hyperparameters()
24 | if model_args is None:
25 | model_args = {}
26 | if data_args is None:
27 | data_args = {}
28 | self.output_dir = output_dir
29 | if self.output_dir is not None:
30 | self.output_dir = Path(self.output_dir) / 'pred'
31 | self.output_dir.mkdir(parents=True, exist_ok=True)
32 | self.l_type = data_args.loss_type
33 | self.model = get_model(model_args=model_args.model)
34 | self.model_args = model_args
35 | self.data_args = data_args
36 | self.run_args = run_args
37 | self.optimizers_cfg = self.model_args.train.optimizer
38 | self.scheduler_cfg = self.model_args.train.scheduler
39 | self.valid_it = 0
40 | self.temperature = model_args.train.temperature
41 | self.batch_size = data_args.batch_size
42 |
43 | self.train_loss = None
44 |
45 | def get_progress_bar_dict(self):
46 | tqdm_dict = super().get_progress_bar_dict()
47 | tqdm_dict.pop('v_num', None)
48 | return tqdm_dict
49 |
50 | def configure_optimizers(self):
51 | if self.optimizers_cfg.type == 'adam':
52 | optimizer = torch.optim.Adam(self.parameters(),
53 | lr=self.optimizers_cfg.lr,
54 | betas=(self.optimizers_cfg.beta1, self.optimizers_cfg.beta2, ))
55 | elif self.optimizers_cfg.type == 'sgd':
56 | optimizer = torch.optim.SGD(self.parameters(), lr=self.optimizers_cfg.lr)
57 | elif self.optimizers_cfg.type == 'rmsprop':
58 | optimizer = torch.optim.RMSprop(self.parameters(), lr=self.optimizers_cfg.lr)
59 | else:
60 | raise NotImplementedError('Optimizer not supported: %s' % self.optimizers_cfg.type)
61 |
62 | if self.scheduler_cfg.type == 'plateau':
63 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
64 | factor=self.scheduler_cfg.factor,
65 | patience=self.scheduler_cfg.patience,
66 | min_lr=self.scheduler_cfg.min_lr)
67 | elif self.scheduler_cfg.type == 'multistep':
68 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
69 | milestones=self.scheduler_cfg.milestones,
70 | gamma=self.scheduler_cfg.gamma)
71 | elif self.scheduler_cfg.type == 'exp':
72 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
73 | gamma=self.scheduler_cfg.gamma)
74 | else:
75 | raise NotImplementedError('Scheduler not supported: %s' % self.scheduler_cfg.type)
76 |
77 | if self.model_args.resume is not None:
78 |             print("Resuming from checkpoint: %s" % self.model_args.resume)
79 | ckpt = torch.load(self.model_args.resume, map_location=self.model_args.device)
80 | it_first = ckpt['iteration']
81 | lsd_result = self.model.load_state_dict(ckpt['state_dict'], strict=False)
82 | print('Missing keys (%d): %s' % (len(lsd_result.missing_keys), ', '.join(lsd_result.missing_keys)))
83 | print(
84 | 'Unexpected keys (%d): %s' % (len(lsd_result.unexpected_keys), ', '.join(lsd_result.unexpected_keys)))
85 |
86 | print('Resuming optimizer states...')
87 | optimizer.load_state_dict(ckpt['optimizer'])
88 | print('Resuming scheduler states...')
89 | scheduler.load_state_dict(ckpt['scheduler'])
90 |
91 | if self.scheduler_cfg.type == 'plateau':
92 | optim_dict = {
93 | "optimizer": optimizer,
94 | "lr_scheduler": {
95 | "scheduler": scheduler,
96 | "monitor": 'val_loss'
97 | }
98 | }
99 | else:
100 | optim_dict = {
101 | "optimizer": optimizer,
102 | "lr_scheduler": {
103 | "scheduler": scheduler,
104 | }
105 | }
106 | return optim_dict
107 |
108 | def on_train_start(self):
109 | log_hyperparams = {'model_args':self.model_args, 'data_args': self.data_args, 'run_args': self.run_args}
110 | self.logger.log_hyperparams(log_hyperparams)
111 |
112 | def on_before_optimizer_step(self, optimizer) -> None:
113 | pass
114 | # for name, param in self.named_parameters():
115 | # if param.grad is None:
116 | # print(name)
117 | # print("Found Unused Parameters")
118 |
119 | def continuous_to_discrete_tensor(self, values):
120 | discrete_values = torch.zeros_like(values, dtype=torch.long, device=values.device)
121 |
122 | mask_0_8 = values < 8
123 | discrete_values[mask_0_8] = (values[mask_0_8] * 2 + 0.5).long()
124 |
125 | mask_8_32 = (values >= 8) & (values < 32)
126 | discrete_values[mask_8_32] = (16 + (values[mask_8_32] - 8)).long()
127 |
128 | mask_ge_32 = (values >= 32)
129 | discrete_values[mask_ge_32] = 39
130 | discrete_values = torch.clamp(discrete_values, 0, 39).long()
131 | return discrete_values
132 |
133 | def cal_loss(self, pred_clip, y_clip, pred_dist, y_dist, identifier):
134 | # y_dist = torch.clamp(y_dist.long(), 0, 31).long()
135 | y_dist = self.continuous_to_discrete_tensor(y_dist)
136 | total_y_dists = []
137 | total_pred_dists = []
138 | total_y_dists_inv = []
139 | total_pred_dists_inv = []
140 | for i in range(identifier.shape[0]):
141 | item_identifier = identifier[i].squeeze()
142 | y_tmp = y_dist[i, item_identifier==0]
143 | y_interface_dist = y_tmp[:, item_identifier==1].flatten()
144 | total_y_dists.append(y_interface_dist)
145 |
146 | y_pred_tmp = pred_dist[i, item_identifier==0]
147 | y_interface_pred = y_pred_tmp[:, item_identifier==1].reshape([-1, pred_dist.shape[-1]])
148 | total_pred_dists.append(y_interface_pred)
149 |
150 | y_tmp_inv = y_dist[i, item_identifier==1]
151 | y_interface_dist_inv = y_tmp_inv[:, item_identifier==0].flatten()
152 | total_y_dists_inv.append(y_interface_dist_inv)
153 |
154 | y_pred_tmp_inv = pred_dist[i, item_identifier==1]
155 | y_interface_pred_inv = y_pred_tmp_inv[:, item_identifier==0].reshape([-1, pred_dist.shape[-1]])
156 | total_pred_dists_inv.append(y_interface_pred_inv)
157 |
158 | total_pred_dists = torch.cat(total_pred_dists, dim=0)
159 | total_pred_dists_inv = torch.cat(total_pred_dists_inv, dim=0)
160 | total_y_dists = torch.cat(total_y_dists)
161 | total_y_dists_inv = torch.cat(total_y_dists_inv)
162 | # y_dist_one_hot = F.one_hot(indices, num_classes=32).float()
163 | loss_clip_forward = F.cross_entropy(pred_clip / self.temperature, y_clip.long())
164 | loss_clip_inverse = F.cross_entropy(pred_clip.transpose(0, 1) / self.temperature, y_clip.long())
165 | loss_clip = 0.5 * (loss_clip_forward + loss_clip_inverse)
166 | # print("Loss CLIP:", loss_clip)
167 | # loss_dist_forward = F.cross_entropy(y_interface_pred.permute(0, 3, 1, 2) / self.temperature , y_interface_dist.long())
168 | # loss_dist_inv = F.cross_entropy(y_interface_pred_inv.permute(0, 3, 1, 2) / self.temperature , y_interface_dist_inv.long())
169 |         loss_dist_forward = F.cross_entropy(total_pred_dists / self.temperature, total_y_dists.long())
170 | loss_dist_inverse = F.cross_entropy(total_pred_dists_inv / self.temperature, total_y_dists_inv.long())
171 |
172 | loss_dist = 0.5 * (loss_dist_forward + loss_dist_inverse)
173 | # print("Loss Dist:", loss_dist)
174 | return loss_clip, loss_dist
175 |
176 | def training_step(self, batch, batch_idx):
177 | y_dist = batch['atom_min_dist']
178 | res_identifier = batch['identifier']
179 | # y_clip = batch['clip_label']
180 | y_clip = torch.arange(batch['size'], dtype=torch.long, device=y_dist.device)
181 | pred_dist, pred_clip = self.model(batch, self.data_args.strategy, stage='pretune', need_mask=True)
182 | loss_clip, loss_dist = self.cal_loss(pred_clip, y_clip, pred_dist, y_dist, res_identifier)
183 | loss = loss_clip + loss_dist
184 | if torch.isnan(loss).any():
185 |             print("Found NaN in loss!")
186 | exit()
187 | self.train_loss = loss.detach()
188 | self.log("train_loss", float(self.train_loss), batch_size=self.batch_size, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
189 | self.log("train_clip_loss", float(loss_clip.detach()), batch_size=self.batch_size, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
190 | self.log("train_dist_loss", float(loss_dist.detach()), batch_size=self.batch_size, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
191 | return loss
192 |
193 | def on_validation_epoch_start(self):
194 | self.scalar_accum = ScalarMetricAccumulator()
195 | self.results = []
196 |
197 | def validation_step(self, batch, batch_idx):
198 | y_dist = batch['atom_min_dist']
199 | res_identifier = batch['identifier']
200 | y_clip = torch.arange(batch['size'], dtype=torch.long, device=y_dist.device)
201 | pred_dist, pred_clip = self.model(batch, self.data_args.strategy, stage='pretune', need_mask=True)
202 |
203 | loss_clip, loss_dist = self.cal_loss(pred_clip, y_clip, pred_dist, y_dist, res_identifier)
204 | val_loss = loss_clip + loss_dist
205 | self.scalar_accum.add(name='val_loss', value=val_loss, batchsize=batch['size'], mode='mean')
206 | self.log("val_loss_step", val_loss, batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
207 | self.log("val_clip_loss", float(loss_clip.detach()), batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
208 | self.log("val_dist_loss", float(loss_dist.detach()), batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
209 | return val_loss
210 |
211 | def on_validation_epoch_end(self):
212 | val_loss = self.scalar_accum.get_average('val_loss')
213 | self.log('val_loss', val_loss, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
214 | # Trigger scheduler
215 | self.valid_it += 1
216 | return val_loss
217 |
218 | def on_test_epoch_start(self) -> None:
219 | self.results = []
220 | self.scalar_accum = ScalarMetricAccumulator()
221 |
222 | def test_step(self, batch, batch_idx):
223 | y_dist = batch['atom_min_dist']
224 | res_identifier = batch['identifier']
225 | y_clip = torch.arange(batch['size'], dtype=torch.long, device=y_dist.device)
226 | pred_dist, pred_clip = self.model(batch, self.data_args.strategy, stage='pretune', need_mask=True)
227 | loss_clip, loss_dist = self.cal_loss(pred_clip, y_clip, pred_dist, y_dist, res_identifier)
228 | test_loss = loss_clip + loss_dist
229 | self.scalar_accum.add(name='loss', value = test_loss, batchsize=batch['size'], mode='mean')
230 | self.log("test_loss_step", test_loss, batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
231 | self.log("test_clip_loss", float(loss_clip.detach()), batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
232 | self.log("test_dist_loss", float(loss_dist.detach()), batch_size=self.batch_size, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
233 | self.res = {"test_loss_step": float(test_loss.detach()),"test_clip_loss": float(loss_clip.detach()), "test_dist_loss": float(loss_dist.detach())}
234 | return test_loss
235 |
236 | def on_test_epoch_end(self):
237 | test_loss = self.scalar_accum.get_average('loss')
238 | self.log('test_loss', test_loss, batch_size=self.batch_size, on_epoch=True, sync_dist=True)
239 |
240 | return test_loss
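
A small numeric check of the distance discretisation used by continuous_to_discrete_tensor above (same arithmetic, synthetic values): distances below 8 Angstrom fall into roughly 0.5 Angstrom-wide bins, distances between 8 and 32 Angstrom map onto bins 16-39, and everything at or beyond 32 Angstrom collapses into the last bin.

import torch

d = torch.tensor([0.1, 3.9, 7.9, 8.0, 15.5, 31.9, 32.0, 100.0])   # distances in Angstrom
bins = torch.zeros_like(d, dtype=torch.long)
bins[d < 8] = (d[d < 8] * 2 + 0.5).long()                          # fine bins for short distances
between = (d >= 8) & (d < 32)
bins[between] = (16 + (d[between] - 8)).long()                     # 1 Angstrom bins for 8-32
bins[d >= 32] = 39                                                 # single catch-all bin
print(torch.clamp(bins, 0, 39))                                    # tensor([ 0,  8, 16, 16, 23, 39, 39, 39])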
--------------------------------------------------------------------------------