├── src ├── interface │ ├── __init__.py │ ├── data_interface.py │ └── model_interface.py ├── datasets │ ├── __init__.py │ ├── afdb_dataset.py │ ├── inference_dataset.py │ ├── ts_dataset.py │ ├── cath_dataset.py │ ├── cathafdb_dataset.py │ └── featurizer.py ├── models │ ├── __init__.py │ ├── configs │ │ └── UBC2Model.yaml │ └── UBC2_model.py ├── tools │ ├── __init__.py │ ├── logger.py │ └── config_utils.py ├── version.py └── __init__.py ├── TMscore ├── assets ├── BC-Design.png └── BC-Design-overview.png ├── environment.yml ├── .gitignore ├── CONTRIBUTING.md ├── train ├── main_eval.py ├── data_interface.py └── main.py ├── LICENSE ├── README.md └── pdb2jsonpkl.py /src/interface/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /TMscore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gersteinlab/BC-Design/HEAD/TMscore -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) CAIRI AI Lab. All rights reserved 2 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) CAIRI AI Lab. 
All rights reserved 2 | -------------------------------------------------------------------------------- /assets/BC-Design.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gersteinlab/BC-Design/HEAD/assets/BC-Design.png -------------------------------------------------------------------------------- /assets/BC-Design-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gersteinlab/BC-Design/HEAD/assets/BC-Design-overview.png -------------------------------------------------------------------------------- /src/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) CAIRI AI Lab. All rights reserved 2 | 3 | from .affine_tools import Rigid, Rotation, get_interact_feats 4 | -------------------------------------------------------------------------------- /src/interface/data_interface.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | 3 | 4 | class DInterface_base(pl.LightningDataModule): 5 | def __init__(self, **kwargs): 6 | super().__init__() 7 | self.save_hyperparameters() 8 | self.batch_size = self.hparams.batch_size 9 | print("batch_size", self.batch_size) 10 | self.load_data_module() 11 | -------------------------------------------------------------------------------- /src/models/configs/UBC2Model.yaml: -------------------------------------------------------------------------------- 1 | res_dir: ./train/results 2 | # ex_name: UBC2Model 3 | dataset: CATH4.2 4 | model_name: UBC2Model 5 | # lr: 0.0002 6 | # lr_scheduler: onecycle 7 | # lr_scheduler: cosine 8 | offline: 1 9 | seed: 112 10 | pretrained_path: '' 11 | batch_size: 2 12 | accumulate_grad_batches: 1 13 | num_workers: 0 14 | min_length: 40 15 | data_root: ./data/ 16 | # epoch: 50 17 | augment_eps: 0.0 18 | geo_layer: 3 19 | 
attn_layer: 3 20 | node_layer: 3 21 | edge_layer: 3 22 | encoder_layer: 12 23 | hidden_dim: 128 24 | dropout: 0.0 25 | k_neighbors: 30 26 | virtual_atom_num: 24 27 | bc_encoder_layer: 2 28 | 29 | mask_rate: 0.1 30 | bc_mask_how: token 31 | modal_mask_ratio: 0. 32 | 33 | contrastive_pretrain_both: false 34 | contrastive_loss_global_alpha: 0.01 35 | contrastive_loss_local_alpha: 1. 36 | 37 | if_struc_only: false 38 | # checkpoint_path: "./UBC2Model.ckpt" 39 | -------------------------------------------------------------------------------- /src/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) CAIRI AI Lab. All rights reserved 2 | 3 | __version__ = '0.1.0' 4 | 5 | 6 | def parse_version_info(version_str): 7 | """Parse a version string into a tuple. 8 | 9 | Args: 10 | version_str (str): The version string. 11 | Returns: 12 | tuple[int | str]: The version info, e.g., "0.1.0" is parsed into 13 | (0, 1, 0), and "2.0.0rcx" is parsed into (2, 0, 0, 'rcx'). 
14 | """ 15 | version_info = [] 16 | for x in version_str.split('.'): 17 | if x.isdigit(): 18 | version_info.append(int(x)) 19 | elif x.find('rc') != -1: 20 | patch_version = x.split('rc') 21 | version_info.append(int(patch_version[0])) 22 | version_info.append(f'rc{patch_version[1]}') 23 | return tuple(version_info) 24 | 25 | 26 | version_info = parse_version_info(__version__) 27 | 28 | __all__ = ['__version__', 'version_info', 'parse_version_info'] 29 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: bcdesign 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - nvidia # Kept for nvidia drivers/libs if not covered by pytorch channel for all needs 6 | - bioconda # Kept for biopython and potentially other bio-related tools 7 | - defaults 8 | dependencies: 9 | # Core Python 10 | - python=3.12.4 11 | 12 | - pytorch=2.3.0 13 | - pytorch-cuda=12.1 # This will pull in CUDA toolkit libs for PyTorch 14 | - lightning=2.3.3 # or pytorch-lightning, 'lightning' is the newer name 15 | - scikit-learn=1.5.1 16 | - numpy=2.0.1 17 | - pyyaml=6.0.1 18 | - tqdm=4.66.4 # Progress bars 19 | 20 | - pip=24.0 21 | 22 | - pip: 23 | - torch-scatter==2.1.2 24 | - torch-cluster==1.6.3 25 | - torch-geometric==2.5.3 26 | - "-f https://data.pyg.org/whl/torch-2.3.0+cu121.html" 27 | 28 | - transformers==4.43.3 # pip 29 | - huggingface-hub==0.24.2 # pip 30 | 31 | - scipy==1.14.0 # pip 32 | - pandas==2.2.2 # pip 33 | 34 | # Bioinformatics 35 | - biopython==1.84 # pip 36 | 37 | # Experiment tracking 38 | - wandb>=0.12.10 39 | 40 | # Configuration & Utilities 41 | - omegaconf==2.3.0 # pip 42 | - requests==2.32.3 # pip 43 | 44 | # Core utilities often managed by pip or with specific versions 45 | - torcheval==0.0.7 46 | - setuptools==71.1.0 # Build system 47 | -------------------------------------------------------------------------------- /src/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) CAIRI AI Lab. All rights reserved 2 | 3 | import warnings 4 | from packaging.version import parse 5 | 6 | from .version import __version__ 7 | 8 | 9 | def digit_version(version_str: str, length: int = 4): 10 | """Convert a version string into a tuple of integers. 11 | 12 | This method is usually used for comparing two versions. For pre-release 13 | versions: alpha < beta < rc. 14 | 15 | Args: 16 | version_str (str): The version string. 17 | length (int): The maximum number of version levels. Default: 4. 18 | 19 | Returns: 20 | tuple[int]: The version info in digits (integers). 21 | """ 22 | version = parse(version_str) 23 | assert version.release, f'failed to parse version {version_str}' 24 | release = list(version.release) 25 | release = release[:length] 26 | if len(release) < length: 27 | release = release + [0] * (length - len(release)) 28 | if version.is_prerelease: 29 | mapping = {'a': -3, 'b': -2, 'rc': -1} 30 | val = -4 31 | # version.pre can be None 32 | if version.pre: 33 | if version.pre[0] not in mapping: 34 | warnings.warn(f'unknown prerelease version {version.pre[0]}, ' 35 | 'version checking may go wrong') 36 | else: 37 | val = mapping[version.pre[0]] 38 | release.extend([val, version.pre[-1]]) 39 | else: 40 | release.extend([val, 0]) 41 | 42 | elif version.is_postrelease: 43 | release.extend([1, version.post]) 44 | else: 45 | release.extend([0, 0]) 46 | return tuple(release) 47 | 48 | 49 | __all__ = ['__version__', 'digit_version'] 50 | -------------------------------------------------------------------------------- /src/tools/logger.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import Callback 2 | import os 3 | import shutil 4 | from omegaconf import OmegaConf 5 | 6 | class SetupCallback(Callback): 7 | def __init__(self, now, logdir, ckptdir, cfgdir, config, 
argv_content=None): 8 | super().__init__() 9 | self.now = now 10 | self.logdir = logdir 11 | self.ckptdir = ckptdir 12 | self.cfgdir = cfgdir 13 | self.config = config 14 | 15 | self.argv_content = argv_content 16 | 17 | # 在pretrain例程开始时调用。 18 | def on_fit_start(self, trainer, pl_module): 19 | # Create logdirs and save configs 20 | os.makedirs(self.logdir, exist_ok=True) 21 | os.makedirs(self.ckptdir, exist_ok=True) 22 | os.makedirs(self.cfgdir, exist_ok=True) 23 | 24 | print("Project config") 25 | print(OmegaConf.to_yaml(self.config)) 26 | OmegaConf.save(self.config, 27 | os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) 28 | 29 | with open(os.path.join(self.logdir, "argv_content.txt"), "w") as f: 30 | f.write(str(self.argv_content)) 31 | 32 | class BackupCodeCallback(Callback): 33 | def __init__(self, source_dir, backup_dir, ignore_patterns=None): 34 | super().__init__() 35 | self.source_dir = source_dir 36 | self.backup_dir = backup_dir 37 | self.ignore_patterns = ignore_patterns 38 | 39 | def on_train_start(self, trainer, pl_module): 40 | try: 41 | os.makedirs(self.backup_dir, exist_ok=True) 42 | if os.path.exists(self.backup_dir+'/code'): 43 | shutil.rmtree(self.backup_dir+'/code') 44 | shutil.copytree(self.source_dir, self.backup_dir+'/code', ignore=self.ignore_patterns) 45 | 46 | print(f"Code file backed up to {self.backup_dir}") 47 | except: 48 | print(f"Fail in copying file backed up to {self.backup_dir}") -------------------------------------------------------------------------------- /src/datasets/afdb_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import torch.utils.data as data 5 | import pickle 6 | 7 | 8 | def normalize_coordinates(surface): 9 | """ 10 | Normalize the coordinates of the surface. 
11 | """ 12 | surface = np.array(surface) 13 | center = np.mean(surface, axis=0) 14 | max_ = np.max(surface, axis=0) 15 | min_ = np.min(surface, axis=0) 16 | length = np.max(max_ - min_) 17 | normalized_surface = (surface - center) / length 18 | return normalized_surface 19 | 20 | 21 | class AFDB2000Dataset(data.Dataset): 22 | def __init__(self, path = './', split='test'): 23 | self.path = path 24 | if not os.path.exists(path): 25 | raise "no such file:{} !!!".format(path) 26 | else: 27 | afdb2000_data = json.load(open(path+'/afdb2000.json')) 28 | 29 | self.data_dict = self._load_data_dict() 30 | 31 | self.data = [] 32 | for temp in afdb2000_data: 33 | title = temp['name'] 34 | data = self.data_dict[title] 35 | seq_length = len(temp['seq']) 36 | coords = np.array(temp['coords']) 37 | self.data.append({'title':title, 38 | 'seq':temp['seq'], 39 | 'CA':coords[:,1,:], 40 | 'C':coords[:,2,:], 41 | 'O':coords[:,3,:], 42 | 'N':coords[:,0,:], 43 | 'category': 'afdb2000', 44 | 'chain_mask': np.ones(seq_length), 45 | 'chain_encoding': np.ones(seq_length), 46 | 'orig_surface': data['surface'], 47 | 'surface': normalize_coordinates(data['surface']), 48 | 'features': data['features'][:, :2], 49 | }) 50 | 51 | def _load_data_dict(self): 52 | with open(self.path + f'/afdb2000.pkl', 'rb') as f: 53 | return pickle.load(f) 54 | 55 | def __len__(self): 56 | return len(self.data) 57 | 58 | def __getitem__(self, index): 59 | return self.data[index] -------------------------------------------------------------------------------- /src/interface/model_interface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | import torch.nn as nn 4 | import os 5 | import torch.optim.lr_scheduler as lrs 6 | import inspect 7 | 8 | class MInterface_base(pl.LightningModule): 9 | def __init__(self, model_name=None, loss=None, lr=None, **kargs): 10 | super().__init__() 11 | self.save_hyperparameters() 12 | 
self.load_model() 13 | self.configure_loss() 14 | os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True) 15 | 16 | def on_validation_epoch_end(self): 17 | # Make the Progress Bar leave there 18 | self.print('') 19 | 20 | def get_schedular(self, optimizer, lr_scheduler='onecycle'): 21 | if lr_scheduler == 'step': 22 | scheduler = lrs.StepLR(optimizer, 23 | step_size=self.hparams.lr_decay_steps, 24 | gamma=self.hparams.lr_decay_rate) 25 | elif lr_scheduler == 'cosine': 26 | scheduler = lrs.CosineAnnealingLR(optimizer, 27 | T_max=self.hparams.steps_per_epoch*self.hparams.epoch, 28 | eta_min=self.hparams.lr / 100) 29 | elif lr_scheduler == 'onecycle': 30 | scheduler = lrs.OneCycleLR(optimizer, max_lr=self.hparams.lr, steps_per_epoch=self.hparams.steps_per_epoch, epochs=self.hparams.epoch, three_phase=False, final_div_factor=1., 31 | ) 32 | else: 33 | raise ValueError('Invalid lr_scheduler type!') 34 | 35 | return scheduler 36 | 37 | def configure_optimizers(self): 38 | if hasattr(self.hparams, 'weight_decay'): 39 | weight_decay = self.hparams.weight_decay 40 | else: 41 | weight_decay = 0 42 | 43 | optimizer_g = torch.optim.AdamW(self.model.parameters(), lr=self.hparams.lr, weight_decay=weight_decay, betas=(0.9, 0.98), eps=1e-8) 44 | 45 | schecular_g = self.get_schedular(optimizer_g, self.hparams.lr_scheduler) 46 | 47 | return [optimizer_g], [{"scheduler": schecular_g, "interval": "step"}] 48 | 49 | def lr_scheduler_step(self, *args, **kwargs): 50 | scheduler = self.lr_schedulers() 51 | scheduler.step() 52 | 53 | -------------------------------------------------------------------------------- /src/datasets/inference_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import torch.utils.data as data 5 | import pickle 6 | 7 | 8 | def normalize_coordinates(surface): 9 | """ 10 | Normalize the coordinates of the surface. 
11 | """ 12 | surface = np.array(surface) 13 | center = np.mean(surface, axis=0) 14 | max_ = np.max(surface, axis=0) 15 | min_ = np.min(surface, axis=0) 16 | length = np.max(max_ - min_) 17 | normalized_surface = (surface - center) / length 18 | return normalized_surface 19 | 20 | 21 | class InferenceDataset(data.Dataset): 22 | def __init__(self, path = './', split='test'): 23 | self.path = path 24 | dataset_name = os.path.basename(path) 25 | self.json_path = os.path.join(path, dataset_name + '.json') 26 | self.pkl_path = os.path.join(path, dataset_name + '.pkl') 27 | if not os.path.exists(path): 28 | raise "no such file:{} !!!".format(path) 29 | else: 30 | data = json.load(open(self.json_path)) 31 | 32 | self.data_dict = self._load_data_dict() 33 | 34 | self.data = [] 35 | for temp in data: 36 | title = temp['name'] 37 | data = self.data_dict[title] 38 | seq_length = len(temp['seq']) 39 | coords = np.array(temp['coords']) 40 | self.data.append({'title':title, 41 | 'seq':temp['seq'], 42 | 'CA':coords[:,1,:], 43 | 'C':coords[:,2,:], 44 | 'O':coords[:,3,:], 45 | 'N':coords[:,0,:], 46 | 'category': 'inference', 47 | 'chain_mask': np.ones(seq_length), 48 | 'chain_encoding': np.ones(seq_length), 49 | 'orig_surface': data['surface'], 50 | 'surface': normalize_coordinates(data['surface']), 51 | 'features': data['features'][:, :2], 52 | }) 53 | 54 | def _load_data_dict(self): 55 | with open(self.pkl_path, 'rb') as f: 56 | return pickle.load(f) 57 | 58 | def __len__(self): 59 | return len(self.data) 60 | 61 | def __getitem__(self, index): 62 | return self.data[index] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 
16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | apex/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | .vscode 108 | .idea 109 | 110 | # custom 111 | /data.tar.gz 112 | *.pkl 113 | *.pkl.json 114 | *.log.json 115 | *.ckpt 116 | *.zip 117 | *.ipynb 118 | *.jpg 119 | bash 120 | data 121 | data-before0601 122 | /configs 123 | data_set 124 | results/ 125 | gaozhangyang/ 126 | src/models/SurfPro/ 127 | src/modules/ 128 | src/models/__pycache__/ 129 | src/models/*.py 130 | !src/models/__init__.py 131 | !src/models/UBC2_model.py 132 | 133 | src/datasets/dataloader.py 134 | src/datasets/mpnn_dataset.py 135 | 
src/datasets/casp_dataset.py 136 | src/datasets/alphafold_dataset.py 137 | src/datasets/fast_dataloader.py 138 | 139 | src/interface/pretrain_interface.py 140 | 141 | src/models/configs/*.yaml 142 | !src/models/configs/UBC2Model.yaml 143 | 144 | src/tools/main_utils.py 145 | src/tools/config_utils.py 146 | src/tools/metrics.py 147 | src/tools/parser.py 148 | src/tools/utils.py 149 | 150 | train/results 151 | train/ig* 152 | train/train.sh 153 | train/try_multi_train.py 154 | train/main_train.py 155 | 156 | run.sh 157 | output/ 158 | work_dirs/ 159 | workspace/ 160 | tools/exp_bash/ 161 | pretrains 162 | cache/ 163 | cath_classes/ 164 | gt_pdb/ 165 | ig_results/ 166 | ig_biochem_results_steps50/ 167 | ig_results_steps50/ 168 | requirements/ 169 | tools/prepare_data/ 170 | lightning_logs/ 171 | predicted_pdb/ 172 | /logits 173 | raw_test_data/ 174 | test_results/ 175 | alphafolddb/ 176 | cath_test_82/ 177 | figures* 178 | 179 | TMscore.cpp 180 | TMscore.f 181 | 182 | antonia* 183 | calc_diversity.py 184 | create_cath42test_json.py 185 | json2pkl.py 186 | 187 | environment-deprecated.yml 188 | environment0601.yml 189 | environment-full-deprecated.yml 190 | test.py 191 | visualization* 192 | create_cath42test_json.py 193 | pred_pdb2jsonpkl.py 194 | pred_pdb2jsonpkl-thr.py 195 | 196 | 197 | # Pytorch 198 | *.pth 199 | 200 | *.swp 201 | .DS_Store 202 | *.json 203 | run/wandb/ 204 | wandb/ 205 | esm/ 206 | figs/ 207 | esm/* 208 | sampling/* 209 | model_zoom/* 210 | run/ 211 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to *BC-Design* 2 | 3 | We welcome contributions from everyone to help improve and expand *BC-Design*. This document outlines the process for contributing to the project. 4 | 5 | ## Table of Contents 6 | 1. [Environment Setup](#environment-setup) 7 | 2. [Coding Standards](#coding-standards) 8 | 3. 
[Pull Request Process](#pull-request-process) 9 | 4. [Pull Request Template](#pull-request-template) 10 | 11 | ## Environment Setup 12 | 13 | To contribute to *BC-Design*, follow these steps to set up your development environment: 14 | 15 | 1. Clone the repository: 16 | ``` 17 | git clone https://github.com/gersteinlab/BC-Design.git 18 | cd BC-Design 19 | ``` 20 | 2. Create the Conda environment defined in `environment.yml` (Python 3.12): 21 | ``` 22 | conda env create -f environment.yml 23 | conda activate bcdesign 24 | ``` 25 | 3. Upgrade pip. The repository has no `setup.py` or `pyproject.toml`, so run the scripts from the repository root instead of installing the project with `pip install -e .`: 26 | ``` 27 | python3 -m pip install --upgrade pip 28 | ``` 29 | 30 | 31 | ## Coding Standards 32 | 33 | We strive to maintain clean and consistent code throughout the project. Please adhere to the following guidelines: 34 | 35 | 1. Follow PEP 8 guidelines for Python code. 36 | 2. Use meaningful variable and function names. 37 | 3. Write docstrings for functions and classes. 38 | 4. Keep functions small and focused on a single task. 39 | 5. Use type hints where appropriate. 40 | 41 | ### Code Formatting 42 | 43 | We use `black` for code formatting. To ensure your code is properly formatted: 44 | 45 | 1. Install black: 46 | ``` 47 | pip install black 48 | ``` 49 | 2. Run black on the codebase: 50 | ``` 51 | black . 52 | ``` 53 | 54 | ## Pull Request Process 55 | 56 | 1. Create a new branch for your change, using the `feature/` prefix for new functionality or `bugfix/` for bug fixes: 57 | ``` 58 | git checkout -b feature/your-feature-name 59 | ``` 60 | 2. Make your changes and commit them with clear, concise commit messages. 61 | 1. Check which files are modified or untracked: 62 | ``` 63 | git status 64 | ``` 65 | 2. Stage the files you changed: 66 | ``` 67 | git add path/to/changed_file.py 68 | ``` 69 | 3. Commit your changes with a descriptive message: 70 | ``` 71 | git commit -m "message" 72 | ``` 73 | 4. 
Push your branch to the repository: 74 | ``` 75 | git push origin feature/your-feature-name 76 | ``` 77 | 5. Open a pull request against the `main` branch on the website. 78 | 6. Fill out the pull request template (see below). 79 | 7. Address any feedback or comments from reviewers. 80 | 81 | ## Pull Request Template 82 | 83 | When you open a new pull request, please use the following template: 84 | 85 | ```markdown 86 | ## Description 87 | 88 | ### Changes 89 | [Provide a detailed list of the changes made in this PR] 90 | 91 | ### Design 92 | [Explain the design decisions and architectural changes, if any] 93 | 94 | ### Example Code 95 | [If applicable, provide example code demonstrating the usage of new features or fixes] 96 | 97 | ## Related Issue 98 | [Link to the issue this PR addresses, if applicable] 99 | 100 | ## Type of Change 101 | - [ ] Bug fix (non-breaking change which fixes an issue) 102 | - [ ] New feature (non-breaking change which adds functionality) 103 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 104 | - [ ] This change requires a documentation update 105 | 106 | ## How Has This Been Tested? 107 | [Describe the tests you ran to verify your changes] 108 | 109 | ## Additional Notes 110 | [Add any additional information or context about the PR here] 111 | ``` 112 | 113 | Thank you for contributing to *PLACEHOLDER*! 114 | -------------------------------------------------------------------------------- /src/datasets/ts_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import torch.utils.data as data 5 | import pickle 6 | 7 | 8 | def normalize_coordinates(surface): 9 | """ 10 | Normalize the coordinates of the surface. 
11 | """ 12 | surface = np.array(surface) 13 | center = np.mean(surface, axis=0) 14 | max_ = np.max(surface, axis=0) 15 | min_ = np.min(surface, axis=0) 16 | length = np.max(max_ - min_) 17 | normalized_surface = (surface - center) / length 18 | return normalized_surface 19 | 20 | 21 | class TS50Dataset(data.Dataset): 22 | def __init__(self, path = './', split='test'): 23 | self.path = path 24 | if not os.path.exists(path): 25 | raise "no such file:{} !!!".format(path) 26 | else: 27 | ts50_data = json.load(open(path+'/ts50.json')) 28 | 29 | self.data_dict = self._load_data_dict() 30 | 31 | # TS500 has proteins with lengths of 500+ 32 | # TS50 only contains proteins with lengths less than 500 33 | self.data = [] 34 | for temp in ts50_data: 35 | title = temp['name'] 36 | data = self.data_dict[title] 37 | seq_length = len(temp['seq']) 38 | coords = np.array(temp['coords']) 39 | self.data.append({'title':title, 40 | 'seq':temp['seq'], 41 | 'CA':coords[:,1,:], 42 | 'C':coords[:,2,:], 43 | 'O':coords[:,3,:], 44 | 'N':coords[:,0,:], 45 | 'category': 'ts50', 46 | 'chain_mask': np.ones(seq_length), 47 | 'chain_encoding': np.ones(seq_length), 48 | 'orig_surface': data['surface'], 49 | 'surface': normalize_coordinates(data['surface']), 50 | 'features': data['features'][:, :2], 51 | }) 52 | 53 | def _load_data_dict(self): 54 | with open(self.path + f'/ts50.pkl', 'rb') as f: 55 | return pickle.load(f) 56 | 57 | def __len__(self): 58 | return len(self.data) 59 | 60 | def __getitem__(self, index): 61 | return self.data[index] 62 | 63 | 64 | class TS500Dataset(data.Dataset): 65 | def __init__(self, path = './', split='test'): 66 | self.path = path 67 | if not os.path.exists(path): 68 | raise "no such file:{} !!!".format(path) 69 | else: 70 | ts500_data = json.load(open(path+'/ts500.json')) 71 | 72 | self.data_dict = self._load_data_dict() 73 | 74 | # TS500 has proteins with lengths of 500+ 75 | # TS50 only contains proteins with lengths less than 500 76 | self.data = [] 77 | for 
temp in ts500_data: 78 | title = temp['name'] 79 | data = self.data_dict[title] 80 | seq_length = len(temp['seq']) 81 | coords = np.array(temp['coords']) 82 | self.data.append({'title':title, 83 | 'seq':temp['seq'], 84 | 'CA':coords[:,1,:], 85 | 'C':coords[:,2,:], 86 | 'O':coords[:,3,:], 87 | 'N':coords[:,0,:], 88 | 'category': 'ts500', 89 | 'chain_mask': np.ones(seq_length), 90 | 'chain_encoding': np.ones(seq_length), 91 | 'orig_surface': data['surface'], 92 | 'surface': normalize_coordinates(data['surface']), 93 | 'features': data['features'][:, :2], 94 | }) 95 | 96 | def _load_data_dict(self): 97 | with open(self.path + f'/ts500.pkl', 'rb') as f: 98 | return pickle.load(f) 99 | 100 | def __len__(self): 101 | return len(self.data) 102 | 103 | def __getitem__(self, index): 104 | return self.data[index] -------------------------------------------------------------------------------- /src/tools/config_utils.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import tempfile 3 | import re 4 | import shutil 5 | import sys 6 | import ast 7 | from importlib import import_module 8 | 9 | ''' 10 | Thanks the code from https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py wrote by Open-MMLab. 11 | The `Config` class here uses some parts of this reference. 
12 | ''' 13 | 14 | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): 15 | if not osp.isfile(filename): 16 | raise FileNotFoundError(msg_tmpl.format(filename)) 17 | 18 | 19 | class Config: 20 | def __init__(self, cfg_dict=None, filename=None): 21 | if cfg_dict is None: 22 | cfg_dict = dict() 23 | elif not isinstance(cfg_dict, dict): 24 | raise TypeError('cfg_dict must be a dict, but ' 25 | f'got {type(cfg_dict)}') 26 | 27 | if filename is not None: 28 | cfg_dict = self._file2dict(filename, True) 29 | filename = filename 30 | 31 | super(Config, self).__setattr__('_cfg_dict', cfg_dict) 32 | super(Config, self).__setattr__('_filename', filename) 33 | 34 | @staticmethod 35 | def _validate_py_syntax(filename): 36 | with open(filename, 'r') as f: 37 | content = f.read() 38 | try: 39 | ast.parse(content) 40 | except SyntaxError as e: 41 | raise SyntaxError('There are syntax errors in config ' 42 | f'file {filename}: {e}') 43 | 44 | @staticmethod 45 | def _substitute_predefined_vars(filename, temp_config_name): 46 | file_dirname = osp.dirname(filename) 47 | file_basename = osp.basename(filename) 48 | file_basename_no_extension = osp.splitext(file_basename)[0] 49 | file_extname = osp.splitext(filename)[1] 50 | support_templates = dict( 51 | fileDirname=file_dirname, 52 | fileBasename=file_basename, 53 | fileBasenameNoExtension=file_basename_no_extension, 54 | fileExtname=file_extname) 55 | with open(filename, 'r') as f: 56 | config_file = f.read() 57 | for key, value in support_templates.items(): 58 | regexp = r'\{\{\s*' + str(key) + r'\s*\}\}' 59 | value = value.replace('\\', '/') 60 | config_file = re.sub(regexp, value, config_file) 61 | with open(temp_config_name, 'w') as tmp_config_file: 62 | tmp_config_file.write(config_file) 63 | 64 | @staticmethod 65 | def _file2dict(filename, use_predefined_variables=True): 66 | filename = osp.abspath(osp.expanduser(filename)) 67 | check_file_exist(filename) 68 | fileExtname = osp.splitext(filename)[1] 69 | if 
fileExtname not in ['.py']: 70 | raise IOError('Only py type are supported now!') 71 | 72 | with tempfile.TemporaryDirectory() as temp_config_dir: 73 | temp_config_file = tempfile.NamedTemporaryFile( 74 | dir=temp_config_dir, suffix=fileExtname) 75 | temp_config_name = osp.basename(temp_config_file.name) 76 | 77 | # Substitute predefined variables 78 | if use_predefined_variables: 79 | Config._substitute_predefined_vars(filename, 80 | temp_config_file.name) 81 | else: 82 | shutil.copyfile(filename, temp_config_file.name) 83 | 84 | if filename.endswith('.py'): 85 | temp_module_name = osp.splitext(temp_config_name)[0] 86 | sys.path.insert(0, temp_config_dir) 87 | Config._validate_py_syntax(filename) 88 | mod = import_module(temp_module_name) 89 | sys.path.pop(0) 90 | cfg_dict = { 91 | name: value 92 | for name, value in mod.__dict__.items() 93 | if not name.startswith('__') 94 | } 95 | # delete imported module 96 | del sys.modules[temp_module_name] 97 | # close temp file 98 | temp_config_file.close() 99 | return cfg_dict 100 | 101 | @staticmethod 102 | def fromfile(filename, use_predefined_variables=True): 103 | cfg_dict = Config._file2dict(filename, use_predefined_variables) 104 | return Config(cfg_dict, filename=filename) 105 | -------------------------------------------------------------------------------- /train/main_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.getcwd()) 4 | 5 | import warnings 6 | warnings.filterwarnings("ignore") 7 | 8 | import argparse 9 | import torch 10 | from model_interface import MInterface 11 | from data_interface import DInterface 12 | 13 | import pytorch_lightning as pl 14 | from pytorch_lightning.trainer import Trainer 15 | torch.autograd.set_detect_anomaly(True) 16 | 17 | def create_parser(): 18 | checkpoint_path = './UBC2Model.ckpt' 19 | ex_name = 'UBC2Model' 20 | batch_size = 2 21 | 22 | parser = argparse.ArgumentParser() 23 | 
parser.add_argument('--res_dir', default='./train/results', type=str) 24 | parser.add_argument('--ex_name', default=ex_name, type=str) 25 | parser.add_argument('--check_val_every_n_epoch', default=1, type=int) 26 | parser.add_argument('--dataset', default='CATH4.2') 27 | parser.add_argument('--model_name', default='UBC2Model') 28 | parser.add_argument('--lr', default=0.0002, type=float, help='Learning rate') 29 | parser.add_argument('--lr_scheduler', default='onecycle') 30 | parser.add_argument('--offline', default=0, type=int) 31 | parser.add_argument('--seed', default=111, type=int) 32 | 33 | # dataset parameters 34 | parser.add_argument('--batch_size', default=batch_size, type=int) 35 | parser.add_argument('--num_workers', default=0, type=int) 36 | parser.add_argument('--pad', default=1024, type=int) 37 | parser.add_argument('--min_length', default=40, type=int) 38 | parser.add_argument('--data_root', default='./data/') 39 | 40 | # Testing specific parameters 41 | parser.add_argument('--epoch', default=50, type=int, help='end epoch') 42 | parser.add_argument('--augment_eps', default=0.0, type=float, help='noise level') 43 | 44 | # Model parameters 45 | parser.add_argument('--use_dist', default=1, type=int) 46 | parser.add_argument('--use_product', default=0, type=int) 47 | 48 | # Checkpoint parameter 49 | parser.add_argument('--checkpoint_path', default=checkpoint_path, type=str, help='Path to a checkpoint to resume testing') 50 | 51 | parser.add_argument('--contrastive_pretrain', default=False, type=bool) 52 | parser.add_argument('--contrastive_learning', default=False, type=bool) 53 | parser.add_argument('--if_strucenc_only', default=False, type=bool) 54 | parser.add_argument('--if_warmup_train', default=False, type=bool) 55 | 56 | parser.add_argument('--if_struc_only', default=False, type=bool) 57 | parser.add_argument('--exp_bc_mask_rate', default=0., type=float) 58 | parser.add_argument('--bc_mask_max_rate', default=0., type=float) 59 | 
parser.add_argument('--exp_hydro_mask_rate', default=0., type=float) 60 | parser.add_argument('--exp_charge_mask_rate', default=0., type=float) 61 | parser.add_argument('--exp_v_mask_rate', default=0., type=float) 62 | parser.add_argument('--exp_e_mask_rate', default=0., type=float) 63 | parser.add_argument('--exp_backbone_noise_sd', default=0., type=float) 64 | parser.add_argument('--exp_wo_bcgraph', default=False, type=bool) 65 | 66 | parser.add_argument('--partial_design', default=False, type=bool) 67 | parser.add_argument('--design_region_path', default='') 68 | 69 | parser.add_argument('--bc_indices', nargs='+', type=int, default=[0, 1]) 70 | 71 | args = parser.parse_args() 72 | return args 73 | 74 | def load_callbacks(args): 75 | callbacks = [] 76 | return callbacks 77 | 78 | if __name__ == "__main__": 79 | args = create_parser() 80 | pl.seed_everything(args.seed) 81 | 82 | # Initialize data module and setup test data 83 | data_module = DInterface(**vars(args)) 84 | data_module.setup() # Ensure the test dataset is loaded 85 | 86 | gpu_count = 1 87 | print(f"Using {gpu_count} GPUs for testing") 88 | 89 | # Initialize the model 90 | model = MInterface(**vars(args)) 91 | 92 | # Trainer configuration 93 | trainer_config = { 94 | 'devices': gpu_count, 95 | 'num_nodes': 1, # Number of nodes to use for distributed training 96 | 'precision': 32, 97 | 'accelerator': 'gpu', 98 | 'callbacks': load_callbacks(args), 99 | } 100 | 101 | trainer_opt = argparse.Namespace(**trainer_config) 102 | trainer_dict = vars(trainer_opt) 103 | trainer = Trainer(**trainer_dict) 104 | # Perform testing 105 | if args.checkpoint_path: 106 | print(f"Resuming from checkpoint: {args.checkpoint_path}") 107 | trainer.test(model, datamodule=data_module, ckpt_path=args.checkpoint_path) 108 | else: 109 | print("No checkpoint provided, testing with current model state") 110 | 111 | print(trainer_config) 112 | -------------------------------------------------------------------------------- 
/train/data_interface.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from torch.utils.data import DataLoader 3 | from src.interface.data_interface import DInterface_base 4 | import torch 5 | import os.path as osp 6 | 7 | class MyDataLoader(DataLoader): 8 | def __init__(self, dataset, model_name, batch_size=64, num_workers=8, *args, **kwargs): 9 | super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, *args, **kwargs) 10 | self.pretrain_device = 'cuda:0' 11 | self.model_name = model_name 12 | 13 | def __iter__(self): 14 | for batch in super().__iter__(): 15 | # 在这里对batch进行处理 16 | # ... 17 | try: 18 | self.pretrain_device = f'cuda:{torch.distributed.get_rank()}' 19 | except: 20 | self.pretrain_device = 'cuda:0' 21 | 22 | stream = torch.cuda.Stream( 23 | self.pretrain_device 24 | ) 25 | with torch.cuda.stream(stream): 26 | if self.model_name=='GVP': 27 | batch = batch.cuda(non_blocking=True, device=self.pretrain_device) 28 | yield batch 29 | else: 30 | for key, val in batch.items(): 31 | if type(val) == torch.Tensor: 32 | batch[key] = batch[key].cuda(non_blocking=True, device=self.pretrain_device) 33 | 34 | yield batch 35 | 36 | 37 | class DInterface(DInterface_base): 38 | def __init__(self,**kwargs): 39 | super().__init__(**kwargs) 40 | self.save_hyperparameters() 41 | self.load_data_module() 42 | self.exp_backbone_noise_sd = kwargs.get('exp_backbone_noise_sd', 0.0) 43 | self.partial_design = kwargs.get('partial_design', False) 44 | self.design_region_path = kwargs.get('design_region_path', '') 45 | self.ig_baseline_data = kwargs.get('ig_baseline_data', False) 46 | 47 | def setup(self, stage=None): 48 | from src.datasets.featurizer import (featurize_UBC2Model) 49 | if self.hparams.model_name == 'UBC2Model' or self.hparams.model_name == 'UBC2Large' or self.hparams.model_name == 'UBC01234': 50 | self.collate_fn = featurize_UBC2Model( 51 | exp_backbone_noise_sd=self.exp_backbone_noise_sd, 52 
| partial_design=self.partial_design, 53 | design_region_path=self.design_region_path, 54 | ig_baseline_data=self.ig_baseline_data 55 | ).featurize 56 | 57 | # Assign train/val datasets for use in dataloaders 58 | if stage == 'fit' or stage is None: 59 | self.trainset = self.instancialize(split = 'train') 60 | self.valset = self.instancialize(split='valid') 61 | 62 | # Assign test dataset for use in dataloader(s) 63 | if stage == 'test' or stage is None: 64 | self.testset = self.instancialize(split='test') 65 | 66 | def train_dataloader(self): 67 | return MyDataLoader(self.trainset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=True, prefetch_factor=None, pin_memory=True, collate_fn=self.collate_fn) 68 | # return MyDataLoader(self.trainset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=True, prefetch_factor=8, pin_memory=True, collate_fn=self.collate_fn) 69 | 70 | def val_dataloader(self): 71 | return MyDataLoader(self.valset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, prefetch_factor=None, pin_memory=True, collate_fn=self.collate_fn) 72 | # return MyDataLoader(self.valset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, pin_memory=True, collate_fn=self.collate_fn) 73 | 74 | def test_dataloader(self): 75 | return MyDataLoader(self.testset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, prefetch_factor=None, pin_memory=True, collate_fn=self.collate_fn) 76 | # return MyDataLoader(self.testset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, pin_memory=True, collate_fn=self.collate_fn) 77 | 78 | def load_data_module(self): 79 | name = self.hparams.dataset 80 | if 
name == 'CATH4.2': 81 | from src.datasets.cath_dataset import CATHDatasetSurfProPiFoldDenseLarge 82 | self.data_module = CATHDatasetSurfProPiFoldDenseLarge 83 | self.hparams['version'] = 4.2 84 | self.hparams['path'] = osp.join(self.hparams.data_root, 'cath4.2') 85 | 86 | elif name == 'CATHAFDB': 87 | from src.datasets.cathafdb_dataset import CATHAFDBDataset 88 | self.data_module = CATHAFDBDataset 89 | self.hparams['version'] = 4.2 90 | self.hparams['path_cath'] = osp.join(self.hparams.data_root, 'cath4.2') 91 | self.hparams['path_afdb'] = osp.join(self.hparams.data_root, 'afdb-large4000') 92 | 93 | elif name == 'TS50': 94 | from src.datasets.ts_dataset import TS50Dataset 95 | self.data_module = TS50Dataset 96 | self.hparams['path'] = osp.join(self.hparams.data_root, 'ts50') 97 | 98 | elif name == 'TS500': 99 | from src.datasets.ts_dataset import TS500Dataset 100 | self.data_module = TS500Dataset 101 | self.hparams['path'] = osp.join(self.hparams.data_root, 'ts500') 102 | 103 | elif name == 'AFDB2000': 104 | from src.datasets.afdb_dataset import AFDB2000Dataset 105 | self.data_module = AFDB2000Dataset 106 | self.hparams['path'] = osp.join(self.hparams.data_root, 'afdb2000') 107 | 108 | else: 109 | from src.datasets.inference_dataset import InferenceDataset 110 | self.data_module = InferenceDataset 111 | self.hparams['path'] = osp.join(self.hparams.data_root, name) 112 | 113 | def instancialize(self, **other_args): 114 | """ Instancialize a model using the corresponding parameters 115 | from self.hparams dictionary. You can also input any args 116 | to overwrite the corresponding value in self.kwargs. 
117 | """ 118 | 119 | class_args = list(inspect.signature(self.data_module.__init__).parameters)[1:] 120 | inkeys = self.hparams.keys() 121 | args1 = {} 122 | for arg in class_args: 123 | if arg in inkeys: 124 | args1[arg] = self.hparams[arg] 125 | args1.update(other_args) 126 | # print('finish instancialize') 127 | return self.data_module(**args1) -------------------------------------------------------------------------------- /train/main.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import sys 4 | sys.path.append(os.getcwd()) 5 | 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | 9 | import argparse 10 | import yaml 11 | import torch 12 | from model_interface import MInterface 13 | from data_interface import DInterface 14 | from src.tools.logger import SetupCallback,BackupCodeCallback 15 | import math 16 | from shutil import ignore_patterns 17 | 18 | import pytorch_lightning as pl 19 | from pytorch_lightning.trainer import Trainer 20 | import pytorch_lightning.callbacks as plc 21 | import pytorch_lightning.loggers as plog 22 | torch.autograd.set_detect_anomaly(True) 23 | 24 | def create_parser(): 25 | parser = argparse.ArgumentParser() 26 | # Set-up parameters 27 | parser.add_argument('--res_dir', default='./train/results', type=str) 28 | parser.add_argument('--ex_name', default='BC-Design-reproduce', type=str) 29 | parser.add_argument('--check_val_every_n_epoch', default=1, type=int) 30 | 31 | parser.add_argument('--dataset', default='CATH4.2') 32 | parser.add_argument('--model_name', default='UBC2Model') 33 | parser.add_argument('--lr', default=0.0002, type=float, help='Learning rate') 34 | parser.add_argument('--lr_scheduler', default='onecycle') 35 | parser.add_argument('--offline', default=1, type=int) 36 | parser.add_argument('--seed', default=111, type=int) 37 | 38 | # dataset parameters 39 | parser.add_argument('--batch_size', default=2, type=int) 40 | 
parser.add_argument('--num_workers', default=0, type=int) 41 | parser.add_argument('--pad', default=1024, type=int) 42 | parser.add_argument('--min_length', default=40, type=int) 43 | parser.add_argument('--data_root', default='./data/') 44 | 45 | # Training parameters 46 | parser.add_argument('--epoch', default=50, type=int, help='end epoch') 47 | parser.add_argument('--augment_eps', default=0.0, type=float, help='noise level') 48 | 49 | # Model parameters 50 | parser.add_argument('--use_dist', default=1, type=int) 51 | parser.add_argument('--use_product', default=0, type=int) 52 | 53 | # Checkpoint parameter 54 | parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to a checkpoint to resume training') 55 | 56 | parser.add_argument('--contrastive_pretrain', default=False, type=bool) 57 | parser.add_argument('--contrastive_learning', default=False, type=bool) 58 | parser.add_argument('--if_strucenc_only', default=False, type=bool) 59 | parser.add_argument('--if_warmup_train', default=False, type=bool) 60 | 61 | parser.add_argument('--if_struc_only', default=False, type=bool) 62 | parser.add_argument('--exp_bc_mask_rate', default=0., type=float) 63 | parser.add_argument('--bc_mask_max_rate', default=0., type=float) 64 | parser.add_argument('--exp_backbone_noise_sd', default=0., type=float) 65 | 66 | parser.add_argument('--partial_design', default=False, type=bool) 67 | parser.add_argument('--design_region_path', default='') 68 | 69 | args = parser.parse_args() 70 | return args 71 | 72 | 73 | def load_yaml_config_simple(args): 74 | """Load YAML config and update args for existing parameters only.""" 75 | yaml_path = f"./src/models/configs/{args.model_name}.yaml" 76 | 77 | if os.path.exists(yaml_path): 78 | print(f"Loading config from {yaml_path}") 79 | with open(yaml_path, 'r') as f: 80 | yaml_config = yaml.safe_load(f) 81 | 82 | # Only update args that already exist 83 | updated_params = [] 84 | for key, value in yaml_config.items(): 85 | if 
hasattr(args, key): 86 | setattr(args, key, value) 87 | updated_params.append(key) 88 | 89 | if updated_params: 90 | print(f"Updated parameters from YAML: {updated_params}") 91 | else: 92 | print(f"Config file {yaml_path} not found, using defaults") 93 | 94 | return args 95 | 96 | 97 | def load_callbacks(args): 98 | callbacks = [] 99 | 100 | logdir = str(os.path.join(args.res_dir, args.ex_name)) 101 | 102 | ckptdir = os.path.join(logdir, "checkpoints") 103 | 104 | callbacks.append(BackupCodeCallback(os.path.dirname(args.res_dir),logdir, ignore_patterns=ignore_patterns('results*', 'pdb*', 'metadata*', 'vq_dataset*'))) 105 | 106 | metric = "recovery" 107 | sv_filename = 'best-{epoch:02d}-{recovery:.3f}' 108 | callbacks.append(plc.ModelCheckpoint( 109 | monitor=metric, 110 | filename=sv_filename, 111 | save_top_k=15, 112 | mode='max', 113 | save_last=True, 114 | dirpath = ckptdir, 115 | verbose = True, 116 | every_n_epochs = args.check_val_every_n_epoch, 117 | )) 118 | 119 | now = datetime.datetime.now().strftime("%m-%dT%H-%M-%S") 120 | cfgdir = os.path.join(logdir, "configs") 121 | callbacks.append( 122 | SetupCallback( 123 | now = now, 124 | logdir = logdir, 125 | ckptdir = ckptdir, 126 | cfgdir = cfgdir, 127 | config = args.__dict__, 128 | argv_content = sys.argv + ["gpus: {}".format(torch.cuda.device_count())],) 129 | ) 130 | 131 | if args.lr_scheduler: 132 | callbacks.append(plc.LearningRateMonitor( 133 | logging_interval=None)) 134 | return callbacks 135 | 136 | 137 | if __name__ == "__main__": 138 | args = create_parser() 139 | 140 | # Load YAML config and update existing parameters 141 | args = load_yaml_config_simple(args) 142 | 143 | pl.seed_everything(args.seed) 144 | 145 | data_module = DInterface(**vars(args)) 146 | data_module.setup() 147 | 148 | gpu_count = torch.cuda.device_count() 149 | args.steps_per_epoch = math.ceil(len(data_module.trainset)/args.batch_size/gpu_count) 150 | print(f"steps_per_epoch {args.steps_per_epoch}, gpu_count {gpu_count}, 
batch_size{args.batch_size}") 151 | 152 | model = MInterface(**vars(args)) 153 | 154 | trainer_config = { 155 | 'devices': gpu_count, 156 | 'max_epochs': args.epoch, # Maximum number of epochs to train for 157 | 'num_nodes': 1, # Number of nodes to use for distributed training 158 | "strategy": 'ddp_find_unused_parameters_true', 159 | 'precision': 32, 160 | 'accelerator': 'gpu', # Use distributed data parallel 161 | 'callbacks': load_callbacks(args), 162 | 'logger': plog.WandbLogger( 163 | project = 'BC-Design', 164 | name=args.ex_name, 165 | save_dir=str(os.path.join(args.res_dir, args.ex_name)), 166 | offline = args.offline, 167 | id = "_".join(args.ex_name.split("/")),), 168 | 'gradient_clip_val':1.0 169 | } 170 | 171 | trainer_opt = argparse.Namespace(**trainer_config) 172 | trainer_dict = vars(trainer_opt) 173 | trainer = Trainer(**trainer_dict) 174 | 175 | trainer.fit(model, data_module) 176 | 177 | print(trainer_config) 178 | -------------------------------------------------------------------------------- /src/datasets/cath_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import numpy as np 5 | from tqdm import tqdm 6 | import random 7 | import torch.utils.data as data 8 | from transformers import AutoTokenizer 9 | 10 | def normalize_coordinates(surface): 11 | """ 12 | Normalize the coordinates of the surface. 
13 | """ 14 | surface = np.array(surface) 15 | center = np.mean(surface, axis=0) 16 | max_ = np.max(surface, axis=0) 17 | min_ = np.min(surface, axis=0) 18 | length = np.max(max_ - min_) 19 | normalized_surface = (surface - center) / length 20 | return normalized_surface 21 | 22 | 23 | class CATHDatasetSurfProPiFoldDenseLarge(data.Dataset): 24 | def __init__(self, path='./', split='train', max_length=500, test_name='All', data=None, removeTS=0, version=4.2, bc_indices=None): 25 | self.version = version 26 | self.path = path 27 | self.mode = split 28 | self.max_length = max_length 29 | self.test_name = test_name 30 | self.removeTS = removeTS 31 | if bc_indices is None: 32 | self.bc_indices = [0, 1] 33 | else: 34 | self.bc_indices = bc_indices 35 | 36 | if self.removeTS: 37 | self.remove = json.load(open(self.path + '/remove.json', 'r'))['remove'] 38 | 39 | if data is None: 40 | self.metadata = self._load_metadata() 41 | else: 42 | self.metadata = data 43 | 44 | self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="gaozhangyang/model_zoom/transformers") 45 | 46 | # Load the entire dictionary corresponding to the current mode 47 | self.data_dict = self._load_data_dict() 48 | def _load_metadata(self): 49 | alphabet = 'ACDEFGHIKLMNPQRSTVWY' 50 | alphabet_set = set([a for a in alphabet]) 51 | metadata = [] 52 | 53 | # Load the split JSON files 54 | with open(self.path + '/chain_set_splits.json') as f: 55 | dataset_splits = json.load(f) 56 | 57 | # Handle specific test splits if needed 58 | if self.test_name == 'L100': 59 | with open(self.path + '/test_split_L100.json') as f: 60 | test_splits = json.load(f) 61 | dataset_splits['test'] = test_splits['test'] 62 | 63 | if self.test_name == 'sc': 64 | with open(self.path + '/test_split_sc.json') as f: 65 | test_splits = json.load(f) 66 | dataset_splits['test'] = test_splits['test'] 67 | 68 | # Select the appropriate split 69 | if self.mode == 'valid': 70 | valid_titles = 
set(dataset_splits['validation']) 71 | else: 72 | valid_titles = set(dataset_splits[self.mode]) 73 | 74 | if not os.path.exists(self.path): 75 | raise FileNotFoundError("No such file: {} !!!".format(self.path)) 76 | else: 77 | with open(self.path + '/chain_set.jsonl') as f: 78 | lines = f.readlines() 79 | for line in tqdm(lines): 80 | entry = json.loads(line) 81 | if self.removeTS and entry['name'] in self.remove: 82 | continue 83 | 84 | bad_chars = set([s for s in entry['seq']]).difference(alphabet_set) 85 | if len(bad_chars) == 0 and len(entry['seq']) <= self.max_length and entry['name'] in valid_titles: 86 | entry['coords']['CA'] = np.array(entry['coords']['CA']) 87 | entry['coords']['C'] = np.array(entry['coords']['C']) 88 | entry['coords']['O'] = np.array(entry['coords']['O']) 89 | entry['coords']['N'] = np.array(entry['coords']['N']) 90 | # create a mask representing whether the position of any value of entry['coords']['CA'] or entry['coords']['C'] or entry['coords']['O'] or entry['coords']['N'] is nan or infinite 91 | # sum them up and check if the values are inf or nan 92 | coords = np.stack([ 93 | entry['coords']['CA'], 94 | entry['coords']['C'], 95 | entry['coords']['O'], 96 | entry['coords']['N'] 97 | ], axis=1) # shape: (L, 4, 3) 98 | mask = np.isnan(coords).sum(axis=(1,2)) > 0 99 | mask = mask | (np.isinf(coords).sum(axis=(1,2)) > 0) 100 | # remove the positions where the mask is True 101 | entry['coords']['CA'] = entry['coords']['CA'][~mask] 102 | entry['coords']['C'] = entry['coords']['C'][~mask] 103 | entry['coords']['O'] = entry['coords']['O'][~mask] 104 | entry['coords']['N'] = entry['coords']['N'][~mask] 105 | idx = np.where(~mask)[0] 106 | entry['seq'] = ''.join([entry['seq'][i] for i in idx]) 107 | metadata.append({ 108 | 'title': entry['name'], 109 | 'seq_length': len(entry['seq']), 110 | 'seq': entry['seq'], 111 | 'coords': entry['coords'], 112 | }) 113 | return metadata 114 | 115 | def _load_data_dict(self): 116 | # Load the appropriate 
pickle file based on the mode and keep it in memory 117 | if self.mode == 'train': 118 | with open(self.path + f'/cath42_pc_train_sorted.pkl', 'rb') as f: 119 | return pickle.load(f) 120 | elif self.mode == 'valid': 121 | with open(self.path + f'/cath42_pc_validation_sorted.pkl', 'rb') as f: 122 | return pickle.load(f) 123 | elif self.mode == 'test': 124 | with open(self.path + f'/cath42_pc_test.pkl', 'rb') as f: 125 | return pickle.load(f) 126 | 127 | def __len__(self): 128 | return len(self.metadata) 129 | 130 | def _load_data_on_the_fly(self, index): 131 | entry = self.metadata[index] 132 | title = entry['title'] 133 | seq_length = entry['seq_length'] 134 | 135 | if title in self.data_dict: 136 | data = self.data_dict[title] 137 | data_entry = { 138 | 'title': title, 139 | 'seq': entry['seq'], 140 | 'CA': entry['coords']['CA'], 141 | 'C': entry['coords']['C'], 142 | 'O': entry['coords']['O'], 143 | 'N': entry['coords']['N'], 144 | 'chain_mask': np.ones(seq_length), 145 | 'chain_encoding': np.ones(seq_length), 146 | 'orig_surface': data['surface'], 147 | 'surface': normalize_coordinates(data['surface']), 148 | 'features': data['features'][:, self.bc_indices], 149 | } 150 | 151 | if self.mode == 'test': 152 | data_entry['category'] = 'Unknown' 153 | data_entry['score'] = 100.0 154 | 155 | return data_entry 156 | else: 157 | raise ValueError(f"Data for title {title} not found in the {self.mode} dictionary") 158 | 159 | def __getitem__(self, index): 160 | item = self._load_data_on_the_fly(index) 161 | L = len(item['seq']) 162 | if L > self.max_length: 163 | max_index = L - self.max_length 164 | truncate_index = random.randint(0, max_index) 165 | item['seq'] = item['seq'][truncate_index:truncate_index+self.max_length] 166 | item['CA'] = item['CA'][truncate_index:truncate_index+self.max_length] 167 | item['C'] = item['C'][truncate_index:truncate_index+self.max_length] 168 | item['O'] = item['O'][truncate_index:truncate_index+self.max_length] 169 | item['N'] = 
item['N'][truncate_index:truncate_index+self.max_length] 170 | item['chain_mask'] = item['chain_mask'][truncate_index:truncate_index+self.max_length] 171 | item['chain_encoding'] = item['chain_encoding'][truncate_index:truncate_index+self.max_length] 172 | return item 173 | 174 | 175 | -------------------------------------------------------------------------------- /src/datasets/cathafdb_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import numpy as np 5 | from tqdm import tqdm 6 | import random 7 | import torch.utils.data as data 8 | from transformers import AutoTokenizer 9 | 10 | def normalize_coordinates(surface): 11 | """ 12 | Normalize the coordinates of the surface. 13 | """ 14 | surface = np.array(surface) 15 | center = np.mean(surface, axis=0) 16 | max_ = np.max(surface, axis=0) 17 | min_ = np.min(surface, axis=0) 18 | length = np.max(max_ - min_) 19 | normalized_surface = (surface - center) / length 20 | return normalized_surface 21 | 22 | 23 | class CATHAFDBDataset(data.Dataset): 24 | def __init__(self, path_cath='./', path_afdb='./', split='train', max_length=1000, test_name='All', data=None, removeTS=0, version=4.2): 25 | self.version = version 26 | self.path_cath = path_cath 27 | self.path_afdb = path_afdb 28 | self.mode = split 29 | self.max_length = max_length 30 | self.test_name = test_name 31 | self.removeTS = removeTS 32 | 33 | if self.removeTS: 34 | self.remove = json.load(open(self.path_cath + '/remove.json', 'r'))['remove'] 35 | 36 | if data is None: 37 | self.metadata = self._load_metadata() 38 | else: 39 | self.metadata = data 40 | 41 | self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="gaozhangyang/model_zoom/transformers") 42 | 43 | # Load the entire dictionary corresponding to the current mode 44 | self.data_dict = self._load_data_dict() 45 | def _load_metadata(self): 46 | alphabet = 
'ACDEFGHIKLMNPQRSTVWY' 47 | alphabet_set = set([a for a in alphabet]) 48 | metadata = [] 49 | 50 | # Load the split JSON files 51 | with open(self.path_cath + '/chain_set_splits.json') as f: 52 | dataset_splits = json.load(f) 53 | 54 | # Handle specific test splits if needed 55 | if self.test_name == 'L100': 56 | with open(self.path_cath + '/test_split_L100.json') as f: 57 | test_splits = json.load(f) 58 | dataset_splits['test'] = test_splits['test'] 59 | 60 | if self.test_name == 'sc': 61 | with open(self.path_cath + '/test_split_sc.json') as f: 62 | test_splits = json.load(f) 63 | dataset_splits['test'] = test_splits['test'] 64 | 65 | # Select the appropriate split 66 | if self.mode == 'valid': 67 | valid_titles = set(dataset_splits['validation']) 68 | else: 69 | valid_titles = set(dataset_splits[self.mode]) 70 | 71 | if not os.path.exists(self.path_cath): 72 | raise FileNotFoundError("No such file: {} !!!".format(self.path_cath)) 73 | else: 74 | with open(self.path_cath + '/chain_set.jsonl') as f: 75 | lines = f.readlines() 76 | for line in tqdm(lines): 77 | entry = json.loads(line) 78 | if self.removeTS and entry['name'] in self.remove: 79 | continue 80 | 81 | bad_chars = set([s for s in entry['seq']]).difference(alphabet_set) 82 | if len(bad_chars) == 0 and len(entry['seq']) <= self.max_length and entry['name'] in valid_titles: 83 | entry['coords']['CA'] = np.array(entry['coords']['CA']) 84 | entry['coords']['C'] = np.array(entry['coords']['C']) 85 | entry['coords']['O'] = np.array(entry['coords']['O']) 86 | entry['coords']['N'] = np.array(entry['coords']['N']) 87 | # create a mask representing whether the position of any value of entry['coords']['CA'] or entry['coords']['C'] or entry['coords']['O'] or entry['coords']['N'] is nan or infinite 88 | # sum them up and check if the values are inf or nan 89 | coords = np.stack([ 90 | entry['coords']['CA'], 91 | entry['coords']['C'], 92 | entry['coords']['O'], 93 | entry['coords']['N'] 94 | ], axis=1) # shape: (L, 4, 
3) 95 | mask = np.isnan(coords).sum(axis=(1,2)) > 0 96 | mask = mask | (np.isinf(coords).sum(axis=(1,2)) > 0) 97 | # remove the positions where the mask is True 98 | entry['coords']['CA'] = entry['coords']['CA'][~mask] 99 | entry['coords']['C'] = entry['coords']['C'][~mask] 100 | entry['coords']['O'] = entry['coords']['O'][~mask] 101 | entry['coords']['N'] = entry['coords']['N'][~mask] 102 | idx = np.where(~mask)[0] 103 | entry['seq'] = ''.join([entry['seq'][i] for i in idx]) 104 | metadata.append({ 105 | 'title': entry['name'], 106 | 'seq_length': len(entry['seq']), 107 | 'seq': entry['seq'], 108 | 'coords': entry['coords'], 109 | }) 110 | 111 | if self.mode == 'train': 112 | if not os.path.exists(self.path_afdb): 113 | raise "no such file:{} !!!".format(self.path_afdb) 114 | else: 115 | afdb_data = json.load(open(self.path_afdb+'/afdb-large4000.json')) 116 | 117 | for temp in tqdm(afdb_data): 118 | title = temp['name'] 119 | seq_length = len(temp['seq']) 120 | coords = np.array(temp['coords']) 121 | coords_dict = { 122 | 'CA': coords[:,1,:], 123 | 'C': coords[:,2,:], 124 | 'O': coords[:,3,:], 125 | 'N': coords[:,0,:], 126 | } 127 | metadata.append({'title':title, 128 | 'seq':temp['seq'], 129 | 'seq_length': seq_length, 130 | 'coords': coords_dict, 131 | }) 132 | return metadata 133 | 134 | def _load_data_dict(self): 135 | # Load the appropriate pickle file based on the mode and keep it in memory 136 | def _downcast_float32_inplace(d): 137 | # Convert numeric arrays to float32 to reduce memory 138 | for _title, item in d.items(): 139 | if not isinstance(item, dict): 140 | continue 141 | if 'surface' in item: 142 | item['surface'] = np.asarray(item['surface'], dtype=np.float32) 143 | if 'features' in item: 144 | item['features'] = np.asarray(item['features'], dtype=np.float32) 145 | 146 | if self.mode == 'train': 147 | with open(self.path_cath + f'/cath42_pc_train_sorted.pkl', 'rb') as f: 148 | data_dict_cath = pickle.load(f) 149 | 
_downcast_float32_inplace(data_dict_cath) 150 | with open(self.path_afdb + f'/afdb-large4000.pkl', 'rb') as f: 151 | data_dict_afdb = pickle.load(f) 152 | _downcast_float32_inplace(data_dict_afdb) 153 | data_dict = {**data_dict_cath, **data_dict_afdb} 154 | return data_dict 155 | elif self.mode == 'valid': 156 | with open(self.path_cath + f'/cath42_pc_validation_sorted.pkl', 'rb') as f: 157 | data_dict = pickle.load(f) 158 | _downcast_float32_inplace(data_dict) 159 | return data_dict 160 | elif self.mode == 'test': 161 | with open(self.path_cath + f'/cath42_pc_test.pkl', 'rb') as f: 162 | data_dict = pickle.load(f) 163 | _downcast_float32_inplace(data_dict) 164 | return data_dict 165 | 166 | def __len__(self): 167 | return len(self.metadata) 168 | 169 | def _load_data_on_the_fly(self, index): 170 | entry = self.metadata[index] 171 | title = entry['title'] 172 | seq_length = entry['seq_length'] 173 | 174 | if title in self.data_dict: 175 | data = self.data_dict[title] 176 | data_entry = { 177 | 'title': title, 178 | 'seq': entry['seq'], 179 | 'CA': entry['coords']['CA'], 180 | 'C': entry['coords']['C'], 181 | 'O': entry['coords']['O'], 182 | 'N': entry['coords']['N'], 183 | 'chain_mask': np.ones(seq_length), 184 | 'chain_encoding': np.ones(seq_length), 185 | 'orig_surface': data['surface'], 186 | 'surface': normalize_coordinates(data['surface']), 187 | 'features': data['features'][:, :2], 188 | } 189 | 190 | if self.mode == 'test': 191 | data_entry['category'] = 'Unknown' 192 | data_entry['score'] = 100.0 193 | 194 | return data_entry 195 | else: 196 | raise ValueError(f"Data for title {title} not found in the {self.mode} dictionary") 197 | 198 | def __getitem__(self, index): 199 | item = self._load_data_on_the_fly(index) 200 | L = len(item['seq']) 201 | if L > self.max_length: 202 | max_index = L - self.max_length 203 | truncate_index = random.randint(0, max_index) 204 | item['seq'] = item['seq'][truncate_index:truncate_index+self.max_length] 205 | item['CA'] = 
item['CA'][truncate_index:truncate_index+self.max_length] 206 | item['C'] = item['C'][truncate_index:truncate_index+self.max_length] 207 | item['O'] = item['O'][truncate_index:truncate_index+self.max_length] 208 | item['N'] = item['N'][truncate_index:truncate_index+self.max_length] 209 | item['chain_mask'] = item['chain_mask'][truncate_index:truncate_index+self.max_length] 210 | item['chain_encoding'] = item['chain_encoding'][truncate_index:truncate_index+self.max_length] 211 | return item -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BC-Design: A Biochemistry-Aware Framework for Inverse Protein Design 2 |

3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |

11 | 12 | 16 | 17 | This repository contains the implementation code for the paper: 18 | 19 | [**BC-Design: A Biochemistry-Aware Framework for Inverse Protein Design**] 20 | 21 | Xiangru Tang, Xinwu Ye, Fang Wu, Yimeng Liu, Anna Su, Antonia Panescu, Guanlue Li, Daniel Shao, Dong Xu, and Mark Gerstein*. 22 | 23 | Equal contribution 24 | 25 | 26 | 27 | 28 | ![image](./assets/BC-Design.png) 29 | 30 | 42 | 43 | ## Overview 44 | 45 |
46 | Code Structures 47 | 48 | ![image](./assets/BC-Design-overview.png) 49 | 50 | - `src/datasets` contains the datasets, the featurizer, and related utilities. 51 | - `src/interface` contains customized PyTorch Lightning data modules and model modules. 52 | - `src/models/` contains the main BC-Design model architecture. 53 | - `src/tools` contains utility scripts (e.g., logging and configuration helpers). 54 | - `train` contains the training and inference scripts. 55 | 56 |
57 | 58 | ## News and Updates 59 | 60 | - [🆕 2025-11-23] Major updates: 61 | - Implemented a complete **backbone-only inference pipeline**. 62 | - Added **partial-information testing** script with controllable biochemical-feature masking (0–100% masking), enabling tunable recovery–diversity trade-offs. 63 | - Added full **PDB preprocessing utilities** (`pdb2jsonpkl.py`) to convert arbitrary protein structures into the BC-Design input format. 64 | - Cleaned and consolidated training/evaluation code, environment files, and documentation. 65 | - [🚀 2024-10-30] The official code is released. 66 | 67 | 68 | ## ⚙️ Installation 69 | 70 | This section guides you through setting up the necessary environment and dependencies to run BC-Design. 71 | 72 | ### Step 1: Prerequisites - CUDA and GCC 73 | 74 | Before creating the Conda environment, please ensure your system meets the following requirements. While other versions might also work, our code was developed and tested using the specific versions listed below: 75 | 76 | 1. **CUDA Version:** This codebase has been validated on **CUDA 12.8 with NVIDIA driver 570.133.20**, so running on that (or an equivalent, compatible setup) is recommended. 77 | 2. **GCC Compiler:** A C/C++ compiler is needed, specifically **GCC version 12.2.0** or a compatible version. This codebase has been validated on **GCC version 12.2.0**. 78 | * **Linux:** You can typically install GCC using your system's package manager. For example, on Debian/Ubuntu-based systems, you might use: 79 | ```shell 80 | sudo apt update 81 | sudo apt install gcc-12 g++-12 82 | ``` 83 | On other distributions, use the appropriate package manager (e.g., `yum`, `dnf`). You may need to configure your system to use this specific version if multiple GCC versions are installed. 84 | * **HPC Environments:** If you are using a High-Performance Computing (HPC) cluster, GCC is often managed via environment modules. 
You might load it using a command like: 85 | ```shell 86 | module load GCC/12.2.0 87 | ``` 88 | (The exact command may vary based on your HPC's module system.) 89 | * **Other Systems (macOS, Windows via WSL2):** Ensure you have a compatible C/C++ compiler. For macOS, Xcode Command Line Tools provide Clang, which is often compatible. For Windows, WSL2 with a Linux distribution is recommended. 90 | 3. **Reference OS:** Development and testing took place on **Red Hat Enterprise Linux 8.10 (Ootpa)**. Other modern Linux distributions should work fine as long as the CUDA/GCC requirements above are satisfied. 91 | 92 | ### Step 2: Create Conda Environment 93 | 94 | This project provides an environment file (`environment.yml`) for **Miniconda3**. You can reproduce the Python environment with the following commands: 95 | 96 | ```shell 97 | git clone https://github.com/gersteinlab/BC-Design.git 98 | cd BC-Design 99 | conda env create -f environment.yml -n [your-env-name] 100 | conda activate [your-env-name] 101 | ``` 102 | 103 | Replace `[your-env-name]` with your preferred name for the Conda environment (e.g., `bcdn`). 104 | 105 | ### Step 3: Download Data and Model Checkpoint 106 | 107 | To train the model, you need to download the preprocessed data. 108 | To test with the released model weights, you should also download the checkpoint. 109 | 110 | 1. Navigate to the Hugging Face project page: [https://huggingface.co/datasets/XinwuYe/BC-Design/tree/main](https://huggingface.co/datasets/XinwuYe/BC-Design/tree/main) 111 | 2. Download the following files into the `BC-Design` folder (the main directory cloned from GitHub): 112 | * `data.zip` (contains data for training and testing) 113 | * `UBC2Model.ckpt` (the checkpoint for testing; download it only if you want to test with the released model weights) 114 | 3. Once downloaded, unzip the data file: 115 | ```shell 116 | unzip data.zip 117 | ``` 118 | This should create a `data/` directory inside your `BC-Design` folder.
119 | 120 | As an alternative, you can also run the following commands: 121 | ```shell 122 | wget "https://huggingface.co/datasets/XinwuYe/BC-Design/resolve/main/data.zip?download=true" -O data.zip 123 | unzip data.zip 124 | wget "https://huggingface.co/datasets/XinwuYe/BC-Design/resolve/main/UBC2Model.ckpt?download=true" -O UBC2Model.ckpt 125 | ``` 126 | 127 | After completing these steps, your environment should be ready, and you'll have the necessary data (and model checkpoint) to proceed with using BC-Design. 128 | 129 | 130 | ## Getting Started 131 | 132 | ### Evaluate on CATH 4.2: 133 | 134 | The `train/main_eval.py` script is used to evaluate the trained BC-Design model on test datasets. It loads the specified dataset and the model checkpoint (`UBC2Model.ckpt` by default) to perform inference and report evaluation metrics. 135 | 136 | Note: `train/main_eval.py` computes structure-level metrics via ESMFold. For very large proteins, ESMFold may run out of GPU memory and fall back to CPU-based structure prediction, which significantly increases runtime. The commands below include rough runtime estimates; TS50 is the fastest dataset to reproduce the evaluation. 137 | 138 | To test on the test set of CATH4.2: 139 | 140 | ```shell 141 | python train/main_eval.py --dataset CATH4.2 # ~3.5 hours on 1 A100 GPU 142 | # Expected output: evaluation metrics (test loss, sequence recovery, perplexity, pLDDT, TM-score) 143 | ``` 144 | 145 | 146 | To test on TS50, TS500, or AFDB2000: 147 | ```shell 148 | python train/main_eval.py --dataset TS50 # ~2 mins on 1 A100 GPU 149 | python train/main_eval.py --dataset TS500 # ~9 hours on 1 A100 GPU 150 | python train/main_eval.py --dataset AFDB2000 151 | ``` 152 | 153 | **Testing in backbone-only setting:** 154 | 155 | BC-Design now includes a complete **structure-only inference mode**, 156 | which uses *only* backbone coordinates as input and excludes all biochemical features.
157 | 158 | ```shell 159 | python train/main_eval.py --if_struc_only True --dataset [dataset-name] 160 | ``` 161 | 162 | **Testing in partial-information setting:** 163 | 164 | BC-Design supports biochemical-feature masking, enabling controlled removal of biochemical information at inference time. 165 | 166 | Example (mask 60% of biochemical feature points): 167 | ```shell 168 | python train/main_eval.py --exp_bc_mask_rate 0.6 --dataset [dataset-name] # mask 60% of biochemical features in the input 169 | ``` 170 | This mechanism allows users to reproduce intermediate recovery–diversity trade-offs. 171 | 172 | **Key functionalities of `main_eval.py`:** 173 | - **Dataset Selection:** You can specify the dataset for evaluation using the `--dataset` argument (e.g., `CATH4.2`, `TS50`, `TS500`, `AFDB2000`). 174 | - **Checkpoint Loading:** It loads a pre-trained model from the path specified by `--checkpoint_path` (defaults to `./UBC2Model.ckpt`). 175 | - **Evaluation Metrics:** The script calculates and displays various performance metrics such as test loss, sequence recovery, perplexity, pLDDT, and TM-score. 176 | - **Configurable Parameters:** Several aspects of the evaluation can be configured through command-line arguments, including: 177 | * `--res_dir`: Directory to store results. 178 | * `--batch_size`: Batch size for evaluation. 179 | * `--data_root`: Root directory of the dataset. 180 | * `--num_workers`: Number of workers for data loading. 181 | * For a full list of arguments and their default values, you can refer to the `create_parser()` function within the `train/main_eval.py` script. 182 | 183 | The predicted protein sequences will be saved under `predicted_pdb/[ex_name]/[dataset]`. 184 | 185 | ### Training Model 186 | 187 | Run the following commands to reproduce the four-stage training of BC-Design on the CATH 4.2 training set; each stage starts from the checkpoint of the previous one. The final model checkpoint will be saved as `./train/results/UBC2ModelReproduced/checkpoints/last.ckpt`.
188 | 189 | ```shell 190 | python train/main.py \ 191 | --lr 0.001 \ 192 | --if_strucenc_only True \ 193 | --ex_name UBC2ModelStage1 # stage 1 194 | 195 | python train/main.py \ 196 | --lr 0.0005 \ 197 | --contrastive_learning True \ 198 | --contrastive_pretrain True \ 199 | --checkpoint_path "./train/results/UBC2ModelStage1/checkpoints/last.ckpt" \ 200 | --ex_name UBC2ModelStage2 # stage 2 201 | 202 | python train/main.py \ 203 | --lr 0.0005 \ 204 | --if_warmup_train True \ 205 | --checkpoint_path "./train/results/UBC2ModelStage2/checkpoints/last.ckpt" \ 206 | --ex_name UBC2ModelStage3 # stage 3 207 | 208 | python train/main.py \ 209 | --lr 0.00002 \ 210 | --lr_scheduler cosine \ 211 | --bc_mask_max_rate 3.0 \ 212 | --checkpoint_path "./train/results/UBC2ModelStage3/checkpoints/last.ckpt" \ 213 | --ex_name UBC2ModelReproduced # stage 4 214 | ``` 215 | 216 | ### Data Preparation 217 | 218 | If you’d like to use BC-Design on your own data, run this command to convert your .pdb files into the format BC-Design expects (before running it, install MSMS and set `msms_exec` near the top of `pdb2jsonpkl.py` to the path of your MSMS executable; the surface-smoothing step also requires a CUDA GPU): 219 | ```shell 220 | python pdb2jsonpkl.py --pdb_folder [dir-of-pdb-files] --dataset_name [dataset-name] 221 | ``` 222 | After running it, the processed data will be saved in `./data/[dataset-name]` (if `--dataset_name` is omitted, the name of the PDB folder is used), and `[dataset-name]` can be used directly as the `--dataset` argument for `train/main_eval.py`. 223 | 224 |

(back to top)

225 | 226 | ## License 227 | 228 | This project is released under the [Apache 2.0 license](LICENSE). See `LICENSE` for more information. 229 | 230 | 231 | 232 | 233 | ## Contribution and Contact 234 | 235 | To propose new features, ask for help, or report bugs in `BC-Design`, please open a [GitHub issue](https://github.com/gersteinlab/BC-Design/issues) or [pull request](https://github.com/gersteinlab/BC-Design/pulls) tagged "new features", "help wanted", or "enhancement". Please ensure that all pull requests meet the requirements outlined in our [contribution guidelines](https://github.com/gersteinlab/BC-Design/blob/public-release/CONTRIBUTING.md). Following these guidelines helps streamline the review process and maintain code quality across the project. 236 | Feel free to contact us by email if you have any questions. 237 | 238 | 239 |

(back to top)

240 | -------------------------------------------------------------------------------- /pdb2jsonpkl.py: -------------------------------------------------------------------------------- 1 | from Bio import PDB 2 | import os 3 | import json 4 | import numpy as np 5 | import pickle 6 | import torch 7 | from tqdm import tqdm 8 | from Bio.PDB import PDBParser, Structure, Model, Chain, Residue, Atom 9 | from Bio.PDB.ResidueDepth import get_surface 10 | from scipy.spatial import cKDTree, Delaunay 11 | from Bio.SeqUtils import seq1 12 | from Bio.PDB.Polypeptide import is_aa 13 | from Bio.PDB.PDBExceptions import PDBConstructionWarning 14 | import warnings 15 | import argparse 16 | 17 | # Suppress PDBConstructionWarning 18 | warnings.simplefilter('ignore', PDBConstructionWarning) 19 | 20 | # Define MSMS executable path 21 | msms_exec = '/gpfs/gibbs/pi/gerstein/xt86/surface/msms/msms.x86_64Linux2.2.6.1' # replace with your own path 22 | os.chmod(msms_exec, 0o755) 23 | 24 | # Define the biochemical features dictionary 25 | bio_feat_dict = { 26 | "hydrophobicity": { 27 | "I": 4.5, "V": 4.2, "L": 3.8, "F": 2.8, "C": 2.5, "M": 1.9, "A": 1.8, 28 | "W": -0.9, "G": -0.4, "T": -0.7, "S": -0.8, "Y": -1.3, "P": -1.6, "H": -3.2, 29 | "N": -3.5, "D": -3.5, "Q": -3.5, "E": -3.5, "K": -3.9, "R": -4.5 30 | }, 31 | "charge": { 32 | "R": 1, "K": 1, "D": -1, "E": -1, "H": 0.1, "A": 0, "C": 0, "F": 0, "G": 0, "I": 0, 33 | "L": 0, "M": 0, "N": 0, "P": 0, "Q": 0, "S": 0, "T": 0, "V": 0, "W": 0, "Y": 0 34 | }, 35 | "polarity": { 36 | "R": 1, "N": 1, "D": 1, "Q": 1, "E": 1, "H": 1, "K": 1, "S": 1, "T": 1, "Y": 1, 37 | "A": 0, "C": 0, "F": 0, "G": 0, "I": 0, "L": 0, "M": 0, "P": 0, "V": 0, "W": 0 38 | }, 39 | "acceptor": { 40 | "D": 1, "E": 1, "N": 1, "Q": 1, "H": 1, "S": 1, "T": 1, "Y": 1, 41 | "A": 0, "C": 0, "F": 0, "G": 0, "I": 0, "K": 0, "L": 0, "M": 0, "P": 0, "R": 0, "V": 0, "W": 0 42 | }, 43 | "donor": { 44 | "R": 1, "K": 1, "W": 1, "N": 1, "Q": 1, "H": 1, "S": 1, "T": 1, "Y": 1, 45 | "A": 
0, "C": 0, "D": 0, "E": 0, "F": 0, "G": 0, "I": 0, "L": 0, "M": 0, "P": 0, "V": 0 46 | } 47 | } 48 | 49 | # Mapping from three-letter codes to one-letter codes 50 | three_to_one = { 51 | "ALA": "A", "CYS": "C", "ASP": "D", "GLU": "E", "PHE": "F", "GLY": "G", 52 | "HIS": "H", "ILE": "I", "LYS": "K", "LEU": "L", "MET": "M", "ASN": "N", 53 | "PRO": "P", "GLN": "Q", "ARG": "R", "SER": "S", "THR": "T", "VAL": "V", 54 | "TRP": "W", "TYR": "Y" 55 | } 56 | 57 | 58 | def parse_pdb(file_path): 59 | name = os.path.basename(file_path).replace('.pdb', '') 60 | parser = PDB.PDBParser(QUIET=True) 61 | structure = parser.get_structure(name, file_path) 62 | 63 | seq = '' 64 | coords = [] 65 | 66 | for model in structure: 67 | for chain in model: 68 | seq += ''.join([three_to_one[res.get_resname()] for res in chain if res.get_id()[0] == ' ']) 69 | atom_names = ['N', 'CA', 'C', 'O'] 70 | 71 | for res in chain: 72 | if res.get_resname() in three_to_one.keys(): 73 | coord_dict = {atom.get_name(): atom.get_coord().tolist() for atom in res if atom.get_name() in atom_names} 74 | if all(atom in coord_dict for atom in atom_names): # Ensure all atoms are present 75 | temp_coords = [coord_dict[atom] for atom in atom_names] 76 | if len(temp_coords) == 4: # Collect 4 sets of coordinates 77 | coords.append(temp_coords) 78 | 79 | return {'name': name, 'seq': seq, 'coords': coords} 80 | 81 | 82 | # Step 1: Create PDB structure from protein dict 83 | def create_pdb_structure(protein_data): 84 | structure_id = protein_data['name'] 85 | sequence = protein_data['seq'] 86 | coords = protein_data['coords'] 87 | 88 | structure = Structure.Structure(structure_id) 89 | model = Model.Model(0) 90 | chain = Chain.Chain('A') 91 | 92 | aa_map = {'A': 'ALA', 'C': 'CYS', 'D': 'ASP', 'E': 'GLU', 'F': 'PHE', 'G': 'GLY', 'H': 'HIS', 93 | 'I': 'ILE', 'K': 'LYS', 'L': 'LEU', 'M': 'MET', 'N': 'ASN', 'P': 'PRO', 'Q': 'GLN', 94 | 'R': 'ARG', 'S': 'SER', 'T': 'THR', 'V': 'VAL', 'W': 'TRP', 'Y': 'TYR'} 95 | 96 | atom_names 
= ['N', 'CA', 'C', 'O'] 97 | 98 | for res_index, (res, coord_set) in enumerate(zip(sequence, coords), start=1): 99 | residue = Residue.Residue((' ', res_index, ' '), aa_map[res], ' ') 100 | for atom_index, (atom_name, coord) in enumerate(zip(atom_names, coord_set)): 101 | atom = Atom.Atom(atom_name, coord, 1.0, 0.0, ' ', atom_name, atom_index, atom_name[0]) 102 | residue.add(atom) 103 | chain.add(residue) 104 | 105 | model.add(chain) 106 | structure.add(model) 107 | return structure 108 | 109 | # Step 2: Feature assignment 110 | # Function to get atom coordinates and residue types 111 | def get_atom_coords_and_residues(structure): 112 | coords = [] 113 | residues = [] 114 | for model in structure: 115 | for chain in model: 116 | for residue in chain: 117 | for atom in residue: 118 | coords.append(atom.coord) 119 | residues.append(three_to_one.get(residue.get_resname(), '')) 120 | return np.array(coords), residues 121 | 122 | 123 | # Process each PDB file in the input directory 124 | def assign_features(surface, structure): 125 | atom_coords, residue_types = get_atom_coords_and_residues(structure) 126 | 127 | # Build k-D tree for atom coordinates 128 | kdtree = cKDTree(atom_coords) 129 | 130 | # Assign biochemical features to each vertex in the surface 131 | features = [] 132 | for vertex in surface: 133 | dist, idx = kdtree.query(vertex) 134 | residue_type = residue_types[idx] 135 | residue_features = [bio_feat_dict[feat].get(residue_type, 0) for feat in bio_feat_dict] 136 | features.append(residue_features) 137 | 138 | # Convert features to a numpy array 139 | features_array = np.array(features) 140 | 141 | return features_array 142 | 143 | 144 | # Step 3: Smooth the surface 145 | # Function to perform Gaussian kernel smoothing on all points using PyTorch 146 | def gaussian_kernel_smoothing(coords, k=8, eta=None): 147 | # print(coords.shape) 148 | if len(coords) > 20000: 149 | # Generate random permutation of indices 150 | indices = 
torch.randperm(len(coords))[:20000] 151 | # Select the random indices along the 0-th axis 152 | coords = coords[indices] 153 | # Convert numpy array to PyTorch tensor and move to GPU 154 | coords = torch.tensor(coords, dtype=torch.float32).cuda(0) 155 | 156 | # Compute the full pairwise distance matrix 157 | dists = torch.cdist(coords, coords, p=2) 158 | 159 | if eta is None: 160 | eta = torch.max(dists).item() 161 | 162 | nearest_neighbors = torch.argsort(dists, dim=1)[:, 1:k+1] 163 | 164 | # Get the distances of the k-nearest neighbors 165 | nearest_dists = torch.gather(dists, 1, nearest_neighbors) 166 | 167 | # Compute weights using the Gaussian kernel 168 | weights = torch.exp(-nearest_dists**2 / eta) 169 | weights /= torch.sum(weights, dim=1, keepdim=True) 170 | 171 | # Compute the smoothed coordinates 172 | smoothed_coords = torch.sum(weights[:, :, None] * coords[nearest_neighbors], dim=1) 173 | 174 | return smoothed_coords.cpu().numpy() 175 | 176 | 177 | # Step 4: Compress the surface and features using octree-based compression 178 | class OctreeNode: 179 | def __init__(self, points, indices): 180 | self.points = points 181 | self.indices = indices 182 | self.children = [] 183 | 184 | def create_octree(points, indices, min_points_per_cube): 185 | """ 186 | Create an octree for the given points. 
187 | """ 188 | def divide(points, indices): 189 | if len(points) <= min_points_per_cube: 190 | return OctreeNode(points, indices) 191 | centroid = np.mean(points, axis=0) 192 | partitions = [[] for _ in range(8)] 193 | partition_indices = [[] for _ in range(8)] 194 | for idx, point in enumerate(points): 195 | partition_index = 0 196 | if point[0] > centroid[0]: 197 | partition_index += 1 198 | if point[1] > centroid[1]: 199 | partition_index += 2 200 | if point[2] > centroid[2]: 201 | partition_index += 4 202 | partitions[partition_index].append(point) 203 | partition_indices[partition_index].append(indices[idx]) 204 | node = OctreeNode(None, None) 205 | node.children = [divide(part, part_idx) for part, part_idx in zip(partitions, partition_indices)] 206 | return node 207 | 208 | return divide(points, indices) 209 | 210 | def gather_points(node): 211 | """ 212 | Gather points and indices from the octree. 213 | """ 214 | if node.points is not None: 215 | return [(node.points, node.indices)] 216 | result = [] 217 | for child in node.children: 218 | result.extend(gather_points(child)) 219 | return result 220 | 221 | def compress_surface(points, features, down_sample_ratio, min_points_per_cube=32): 222 | """ 223 | Compress the surface and features using octree-based compression. 
224 | """ 225 | indices = np.arange(points.shape[0]) 226 | octree = create_octree(points, indices, min_points_per_cube) 227 | compressed_points = [] 228 | compressed_features = [] 229 | for cube_points, cube_indices in gather_points(octree): 230 | local_density = len(cube_points) 231 | num_points = int(local_density * down_sample_ratio) 232 | if num_points > 0: 233 | sampled_indices = np.random.choice(local_density, num_points, replace=False) 234 | sampled_points = np.array(cube_points)[sampled_indices] 235 | sampled_features = features[cube_indices][sampled_indices] 236 | compressed_points.extend(sampled_points) 237 | compressed_features.extend(sampled_features) 238 | return np.array(compressed_points), np.array(compressed_features) 239 | 240 | 241 | # Step 5: Add interior points 242 | # Function to get biochemical features from a residue 243 | def get_biochem_features(residue): 244 | # Define hydrophobicity scale (Kyte-Doolittle) 245 | hydrophobicity_scale = { 246 | 'A': 1.8, 'C': 2.5, 'D': -3.5, 'E': -3.5, 'F': 2.8, 247 | 'G': -0.4, 'H': -3.2, 'I': 4.5, 'K': -3.9, 'L': 3.8, 248 | 'M': 1.9, 'N': -3.5, 'P': -1.6, 'Q': -3.5, 'R': -4.5, 249 | 'S': -0.8, 'T': -0.7, 'V': 4.2, 'W': -0.9, 'Y': -1.3 250 | } 251 | 252 | # Define charge scale 253 | charge_scale = { 254 | 'D': -1, 'E': -1, 'K': 1, 'R': 1, 'H': 0.1 # Histidine is partially charged 255 | } 256 | 257 | # Define polarity, acceptor, and donor features as shown in the image 258 | polarity_scale = {'R': 1, 'N': 1, 'D': 1, 'Q': 1, 'E': 1, 'H': 1, 'K': 1, 'S': 1, 'T': 1, 'Y': 1} 259 | acceptor_scale = {'D': 1, 'E': 1, 'N': 1, 'Q': 1, 'H': 1, 'S': 1, 'T': 1, 'Y': 1} 260 | donor_scale = {'R': 1, 'K': 1, 'W': 1, 'N': 1, 'Q': 1, 'H': 1, 'S': 1, 'T': 1, 'Y': 1} 261 | 262 | res_3letter = residue.get_resname() # Get the three-letter code 263 | res_1letter = seq1(res_3letter) # Convert to one-letter code 264 | hydrophobicity = hydrophobicity_scale.get(res_1letter, 0) 265 | charge = charge_scale.get(res_1letter, 0) 266 | 
polarity = polarity_scale.get(res_1letter, 0) 267 | acceptor = acceptor_scale.get(res_1letter, 0) 268 | donor = donor_scale.get(res_1letter, 0) 269 | return np.array([hydrophobicity, charge, polarity, acceptor, donor]) 270 | 271 | 272 | def add_interior_points(surface_points, surface_features, structure): 273 | # Extract residue info and calculate biochemical features 274 | coords = [] 275 | features = [] 276 | 277 | for model in structure: 278 | for chain in model: 279 | for residue in chain: 280 | if is_aa(residue) and 'CA' in residue: 281 | res_coord = residue['CA'].get_coord() # Get alpha carbon coordinates 282 | res_features = get_biochem_features(residue) 283 | 284 | coords.append(res_coord) 285 | features.append(res_features) 286 | 287 | coords = np.array(coords) 288 | features = np.array(features) 289 | 290 | # Build a KDTree for fast nearest-neighbor search 291 | kdtree = cKDTree(coords) 292 | 293 | # Generate random points inside the surface 294 | min_coords = surface_points.min(axis=0) 295 | max_coords = surface_points.max(axis=0) 296 | num_samples = 5000 297 | random_points = np.random.uniform(min_coords, max_coords, (num_samples, 3)) 298 | 299 | tri = Delaunay(surface_points) 300 | 301 | def is_inside(point, tri): 302 | return tri.find_simplex(point) >= 0 303 | 304 | inside_points = np.array([p for p in random_points if is_inside(p, tri)]) 305 | 306 | # Assign biochemical features to random points based on nearest residue 307 | _, idx = kdtree.query(inside_points) 308 | inside_features = features[idx] 309 | 310 | # Concatenate surface and inside points and features 311 | new_surface = np.concatenate([surface_points, inside_points], axis=0) 312 | new_features = np.concatenate([surface_features, inside_features], axis=0) 313 | 314 | return new_surface, new_features 315 | 316 | 317 | # Step 6: Sample 318 | def sample_if_needed(data_dict, max_length=5000): 319 | for key, value in data_dict.items(): 320 | surface = value['surface'] 321 | features = 
value['features'] 322 | 323 | if len(surface) > max_length: 324 | indices = np.random.choice(len(surface), max_length, replace=False) 325 | value['surface'] = surface[indices] 326 | value['features'] = features[indices] 327 | 328 | return data_dict 329 | 330 | 331 | # Main function to run the pipeline 332 | def main(dataset='afdb2000'): 333 | input_json_path = f'data/{dataset}/{dataset}.json' 334 | output_pkl_path = f'data/{dataset}/{dataset}.pkl' 335 | 336 | # Ensure the output directory exists 337 | output_dir = os.path.dirname(output_pkl_path) 338 | os.makedirs(output_dir, exist_ok=True) 339 | 340 | with open(input_json_path, 'r') as f: 341 | protein_dicts = json.load(f) 342 | 343 | combined_data = {} 344 | for protein_data in tqdm(protein_dicts, desc="Processing proteins"): 345 | structure = create_pdb_structure(protein_data) 346 | try: 347 | surface = get_surface(structure[0], MSMS=msms_exec) 348 | except Exception as e: 349 | print(f"Failed to generate surface for {protein_data['name']}: {e}") 350 | continue 351 | features = assign_features(surface, structure) 352 | # Step 3: Smooth the surface 353 | smoothed_surface = gaussian_kernel_smoothing(surface) 354 | # Step 4: Compress the surface and features using octree-based compression 355 | if len(smoothed_surface) > 5000: 356 | down_sample_ratio = 5000 / len(smoothed_surface) 357 | compressed_points, compressed_features = compress_surface(smoothed_surface, features, down_sample_ratio) 358 | else: 359 | compressed_points, compressed_features = smoothed_surface, features # No down-sampling 360 | # Step 5: Add interior points 361 | final_surface, final_features = add_interior_points(compressed_points, compressed_features, structure) 362 | 363 | combined_data[protein_data['name']] = { 364 | 'surface': final_surface, 365 | 'features': final_features, 366 | 'seq': protein_data['seq'] 367 | } 368 | 369 | combined_data = sample_if_needed(combined_data) 370 | 371 | # Save the final data into a .pkl file 372 | with 
open(output_pkl_path, 'wb') as f: 373 | pickle.dump(combined_data, f) 374 | return combined_data 375 | 376 | 377 | if __name__ == "__main__": 378 | parser = argparse.ArgumentParser(description="Convert PDB files to JSON/PKL datasets.") 379 | parser.add_argument( 380 | "--pdb_folder", 381 | type=str, 382 | default=None, 383 | help="Directory containing the PDB files to process." 384 | ) 385 | parser.add_argument( 386 | "--dataset_name", 387 | type=str, 388 | default=None, 389 | help="Name of the output dataset (defaults to the pdb_folder name)." 390 | ) 391 | args = parser.parse_args() 392 | 393 | if not args.pdb_folder: 394 | raise ValueError("Please provide --pdb_folder pointing to the directory with PDB files.") 395 | 396 | pdb_folder = args.pdb_folder 397 | dataset_name = args.dataset_name or os.path.basename(os.path.normpath(pdb_folder)) 398 | 399 | pdb_files = [f for f in os.listdir(pdb_folder) if f.endswith('.pdb')] 400 | data = [] 401 | 402 | print("--- Creating initial dataset ---") 403 | for pdb_file in tqdm(pdb_files, desc="Parsing PDBs"): 404 | predicted_path = os.path.join(pdb_folder, pdb_file) 405 | 406 | combined_data = parse_pdb(predicted_path) 407 | 408 | if combined_data: 409 | data.append(combined_data) 410 | 411 | output_data_dir = os.path.join('./data', dataset_name) 412 | os.makedirs(output_data_dir, exist_ok=True) 413 | json_output_path = os.path.join(output_data_dir, dataset_name + '.json') 414 | 415 | with open(json_output_path, 'w') as json_file: 416 | json.dump(data, json_file, indent=4) 417 | print(f"\nInitial JSON saved to: {json_output_path}") 418 | 419 | main(dataset_name) 420 | -------------------------------------------------------------------------------- /src/datasets/featurizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | from torch_geometric.nn.pool import knn_graph 5 | from torch_scatter import scatter_sum 6 | from transformers 
import AutoTokenizer 7 | from sklearn.neighbors import NearestNeighbors 8 | from src.tools import Rigid, Rotation, get_interact_feats 9 | import copy 10 | import json 11 | 12 | tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="gaozhangyang/model_zoom/transformers") # mask token: 32 13 | 14 | 15 | def pad_ss_connections(ss_connections, max_residues, max_surface_atoms): 16 | """ Pad ss_connections to the maximum number of residues and surface atoms in the batch """ 17 | B = len(ss_connections) 18 | ss_connections_padded = torch.ones((B, max_residues, max_surface_atoms), dtype=torch.float32) 19 | for i, ss_connection in enumerate(ss_connections): 20 | ss_connections_padded[i, :ss_connection.shape[0], :ss_connection.shape[1]] = ss_connection 21 | return ss_connections_padded 22 | 23 | 24 | def rbf(values, v_min, v_max, n_bins=16): 25 | """ 26 | Returns RBF encodings in a new dimension at the end. 27 | """ 28 | rbf_centers = torch.linspace(v_min, v_max, n_bins, device=values.device, dtype=values.dtype) 29 | rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1]) 30 | rbf_std = (v_max - v_min) / n_bins 31 | z = (values.unsqueeze(-1) - rbf_centers) / rbf_std 32 | return torch.exp(-z ** 2) 33 | 34 | 35 | class MyTokenizer: 36 | def __init__(self): 37 | self.alphabet_protein = 'ACDEFGHIKLMNPQRSTVWY' # [X] for unknown token 38 | self.alphabet_RNA = 'AUGC' 39 | 40 | def encode(self, seq, RNA=False): 41 | if RNA: 42 | return [self.alphabet_RNA.index(s) for s in seq] 43 | else: 44 | return [self.alphabet_protein.index(s) for s in seq] 45 | 46 | def decode(self, indices, RNA=False): 47 | if RNA: 48 | return ' '.join([self.alphabet_RNA[i] for i in indices]) 49 | else: 50 | return ' '.join([self.alphabet_protein[i] for i in indices]) 51 | 52 | 53 | class featurize_UBC2Model: 54 | def __init__(self, **kwargs) -> None: 55 | self.tokenizer = MyTokenizer() 56 | self.virtual_frame_num = 3 57 | self.exp_backbone_noise_sd = 
kwargs.get('exp_backbone_noise_sd', 0.0) 58 | self.partial_design = kwargs.get('partial_design', False) 59 | self.design_region_path = kwargs.get('design_region_path', '') 60 | self.design_regions = None # Initialize as None 61 | self.ig_baseline_data = kwargs.get('ig_baseline_data', False) 62 | 63 | if self.partial_design: 64 | print(f"Partial design is enabled. Loading design regions from: {self.design_region_path}") 65 | try: 66 | with open(self.design_region_path, 'r') as f: 67 | self.design_regions = json.load(f) 68 | print("Successfully loaded design regions.") 69 | except FileNotFoundError: 70 | print(f"⚠️ WARNING: Design region file not found at {self.design_region_path}. Partial design will be disabled.") 71 | self.partial_design = False 72 | except json.JSONDecodeError: 73 | print(f"⚠️ WARNING: Could not decode JSON from {self.design_region_path}. Partial design will be disabled.") 74 | self.partial_design = False 75 | 76 | def _get_features_persample(self, batch): 77 | # uniif struc featurizer 78 | for key in batch: 79 | try: 80 | batch[key] = batch[key][None,...] 
81 | except: 82 | batch[key] = batch[key] 83 | S = [] 84 | for seq in batch['seq']: 85 | S.extend(self.tokenizer.encode(seq)) 86 | S = torch.tensor(S) 87 | 88 | X = torch.from_numpy(np.stack([np.concatenate(batch['N']), 89 | np.concatenate(batch['CA']), 90 | np.concatenate(batch['C']), 91 | np.concatenate(batch['O'])], axis=1)).float() 92 | 93 | chain_mask = torch.from_numpy(np.concatenate(batch['chain_mask'])).float() 94 | chain_encoding = torch.from_numpy(np.concatenate(batch['chain_encoding'])).float() 95 | 96 | X, S = X.unsqueeze(0), S.unsqueeze(0) 97 | mask = torch.isfinite(torch.sum(X,(2,3))).float() # atom mask 98 | numbers = torch.sum(mask, axis=1).int() 99 | S_new = torch.zeros_like(S) 100 | X_new = torch.zeros_like(X)+torch.nan 101 | for i, n in enumerate(numbers): 102 | X_new[i,:n,::] = X[i][mask[i]==1] 103 | S_new[i,:n] = S[i][mask[i]==1] 104 | 105 | X = X_new 106 | S = S_new 107 | isnan = torch.isnan(X) 108 | mask = torch.isfinite(torch.sum(X,(2,3))).float() 109 | X[isnan] = 0. 
110 | 111 | mask_bool = (mask==1) 112 | def node_mask_select(x): 113 | shape = x.shape 114 | x = x.reshape(shape[0], shape[1],-1) 115 | out = torch.masked_select(x, mask_bool.unsqueeze(-1)).reshape(-1, x.shape[-1]) 116 | out = out.reshape(-1,*shape[2:]) 117 | return out 118 | 119 | batch_id = torch.arange(mask_bool.shape[0], device=mask_bool.device)[:,None].expand_as(mask_bool) 120 | seq = node_mask_select(S) 121 | X = node_mask_select(X) 122 | batch_id = node_mask_select(batch_id) 123 | C_a = X[:,1,:] 124 | 125 | edge_idx = knn_graph(C_a, k=30, batch=batch_id, loop=True, flow='target_to_source') 126 | 127 | N, CA, C = X[:,0], X[:,1], X[:,2] 128 | 129 | T = Rigid.make_transform_from_reference(N.float(), CA.float(), C.float()) 130 | src_idx, dst_idx = edge_idx[0], edge_idx[1] 131 | T_ts = T[dst_idx,None].invert().compose(T[src_idx,None]) 132 | 133 | # global virtual frames 134 | num_global = self.virtual_frame_num 135 | 136 | ''' 137 | U的每一列,为原始空间中的坐标基向量 138 | R = U 139 | U2, S2, V2 = torch.svd((R@X_c.T)@(X_c@R.T)) 140 | R@U == U2 141 | ''' 142 | 143 | X_c = T._trans 144 | X_m = X_c.mean(dim=0, keepdim=True) 145 | X_c = X_c-X_m 146 | U,S,V = torch.svd(X_c.T@X_c) 147 | d = (torch.det(U) * torch.det(V)) < 0.0 148 | D = torch.zeros_like(V) 149 | D[ [0,1], [0,1]] = 1 150 | D[2,2] = -1*d+1*(~d) 151 | V = D@V 152 | R = torch.matmul(U, V.permute(0,1)) 153 | 154 | rot_g = [R]*num_global 155 | trans_g = [X_m]*num_global 156 | 157 | feat = get_interact_feats(T, T_ts, X.float(), edge_idx, batch_id) 158 | _V, _E = feat['_V'], feat['_E'] 159 | 160 | ''' 161 | global_src: N+1,N+1,N+2,N+2,..N+B, N+B+1,N+B+1,N+B+2,N+B+2,..N+B+B 162 | global_dst: 0, 1, 2, 3, ..N, 0, 1, 2, 3, ..N 163 | batch_id_g: 1, 1, 2, 2, ..B, 1, 1, 2, 2, ..B 164 | ''' 165 | T_g = Rigid(Rotation(torch.stack(rot_g)), torch.cat(trans_g,dim=0)) 166 | num_nodes = scatter_sum(torch.ones_like(batch_id), batch_id) 167 | global_src = torch.cat([batch_id +k*num_nodes.shape[0] for k in range(num_global)]) + num_nodes 168 | 
global_dst = torch.arange(batch_id.shape[0], device=batch_id.device).repeat(num_global) 169 | edge_idx_g = torch.stack([global_dst, global_src]) 170 | edge_idx_g_inv = torch.stack([global_src, global_dst]) 171 | edge_idx_g = torch.cat([edge_idx_g, edge_idx_g_inv], dim=1) 172 | 173 | batch_id_g = torch.zeros(num_global,dtype=batch_id.dtype) 174 | T_all = Rigid.cat([T, T_g], dim=0) 175 | 176 | idx, _ = edge_idx_g.min(dim=0) 177 | T_gs = T_all[idx,None].invert().compose(T_all[idx,None]) 178 | 179 | rbf_ts = rbf(T_ts._trans.norm(dim=-1), 0, 50, 16)[:,0].view(_E.shape[0],-1) 180 | rbf_gs = rbf(T_gs._trans.norm(dim=-1), 0, 50, 16)[:,0].view(edge_idx_g.shape[1],-1) 181 | 182 | _V_g = torch.arange(num_global) 183 | _E_g = torch.zeros([edge_idx_g.shape[1], 128]) 184 | 185 | mask = torch.masked_select(mask, mask_bool) 186 | chain_features = (chain_encoding[edge_idx[0]] == chain_encoding[edge_idx[1]]).int() 187 | 188 | batch={ 189 | 'T':T, 190 | 'T_g': T_g, 191 | 'T_ts': T_ts, 192 | 'T_gs': T_gs, 193 | 'rbf_ts': rbf_ts, 194 | 'rbf_gs': rbf_gs, 195 | 'X':X, 196 | 'chain_features': chain_features, 197 | '_V': _V, 198 | '_E': _E, 199 | '_V_g': _V_g, 200 | '_E_g': _E_g, 201 | 'S':seq, 202 | 'edge_idx':edge_idx, 203 | 'edge_idx_g': edge_idx_g, 204 | 'batch_id': batch_id, 205 | 'batch_id_g': batch_id_g, 206 | 'num_nodes': num_nodes, 207 | 'mask': mask, 208 | 'chain_mask': chain_mask, 209 | 'chain_encoding': chain_encoding, 210 | 'K_g': num_global} 211 | 212 | return batch 213 | 214 | def featurize(self,batch): 215 | if self.exp_backbone_noise_sd != 0: 216 | # Iterate over each protein sample in the batch list 217 | for protein_sample in batch: 218 | # List of keys corresponding to backbone atom coordinates 219 | coord_keys = ['N', 'CA', 'C', 'O'] 220 | for key in coord_keys: 221 | # Get the original coordinates (e.g., shape [num_residues, 3]) 222 | coords = protein_sample[key] 223 | # Generate Gaussian noise with the same shape as the coordinates. 
224 | # The noise is centered at 0.0 with the specified standard deviation. 225 | noise = np.random.normal(loc=0.0, scale=self.exp_backbone_noise_sd, size=coords.shape) 226 | # Add the noise to the original coordinates and update the sample in place 227 | protein_sample[key] = coords + noise 228 | 229 | if self.ig_baseline_data: 230 | for protein_sample in batch: 231 | # List of keys corresponding to backbone atom coordinates 232 | coord_keys = ['N', 'CA', 'C', 'O'] 233 | for key in coord_keys: 234 | coords = protein_sample[key] 235 | protein_sample[key] = np.random.normal(loc=0.0, scale=15., size=coords.shape) 236 | 237 | # deepcopy batch 238 | batch_copy = copy.deepcopy(batch) 239 | res = [] 240 | for one in batch: 241 | temp = self._get_features_persample(one) 242 | res.append(temp) 243 | res = self.custom_collate_fn(res) 244 | # sbc2 featurizer 245 | bc_batch = self.featurize_SBC2Model(batch_copy) 246 | # update bc_batch into res 247 | for key in bc_batch.keys(): 248 | res[key] = bc_batch[key] 249 | return res 250 | 251 | def custom_collate_fn(self, batch): 252 | batch = [one for one in batch if one is not None] 253 | num_nodes = torch.cat([one['num_nodes'] for one in batch]) 254 | shift = num_nodes.cumsum(dim=0) 255 | shift = torch.cat([torch.tensor([0], device=shift.device), shift], dim=0) 256 | def shift_node_idx(idx, num_node, shift_real, shift_virtual): 257 | mask = idx>=num_node 258 | shift_combine = (~mask)*(shift_real) + (mask)*(shift_virtual) 259 | return idx+shift_combine 260 | 261 | ret = {} 262 | for key in batch[0].keys(): 263 | if batch[0][key] is None: 264 | continue 265 | 266 | if key in ['T', 'T_g', 'T_ts', 'T_gs']: 267 | T = Rigid.cat([one[key] for one in batch], dim=0) 268 | ret[key+'_rot'] = T._rots._rot_mats 269 | ret[key+'_trans'] = T._trans 270 | elif key in ['edge_idx']: 271 | ret[key] = torch.cat([one[key] + shift[idx] for idx, one in enumerate(batch)], dim=1) 272 | elif key in ['edge_idx_g']: 273 | edge_idx_g = [] 274 | for idx, one in 
enumerate(batch): 275 | shift_virtual = shift[-1] + idx*one['K_g']-num_nodes[idx] 276 | src = shift_node_idx(one['edge_idx_g'][0], num_nodes[idx], shift[idx], shift_virtual) 277 | dst_g = shift_node_idx(one['edge_idx_g'][1], num_nodes[idx], shift[idx], shift_virtual) 278 | edge_idx_g.append(torch.stack([src, dst_g])) 279 | ret[key] = torch.cat(edge_idx_g, dim=1) 280 | elif key in ['batch_id', 'batch_id_g']: 281 | ret[key] = torch.cat([one[key] + idx for idx, one in enumerate(batch)]) 282 | elif key in ['K_g']: 283 | pass 284 | else: 285 | ret[key] = torch.cat([one[key] for one in batch], dim=0) 286 | 287 | return ret 288 | 289 | def featurize_SBC2Model(self, batch): 290 | """ Pack and pad batch into torch tensors with surface and orig_surface downsampling to the minimum size """ 291 | # batch = [one for one in batch if one is not None] 292 | B = len(batch) 293 | if B == 0: 294 | return None 295 | lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32) 296 | L_max = max(lengths) 297 | 298 | X = np.zeros([B, L_max, 4, 3]) 299 | S = np.zeros([B, L_max], dtype=np.int32) 300 | score = np.ones([B, L_max]) * 100.0 301 | chain_mask = np.zeros([B, L_max]) - 1 # 1:需要被预测的掩码部分 0:可见部分 302 | chain_encoding = np.zeros([B, L_max]) - 1 303 | 304 | # Build the batch 305 | surfaces = [] 306 | features = [] 307 | orig_surfaces = [] 308 | surface_lengths = [] 309 | ss_connections = [] 310 | correspondences = [] 311 | 312 | for i, b in enumerate(batch): 313 | x = np.stack([b[c] for c in ['N', 'CA', 'C', 'O']], 1) # [#atom, 4, 3] 314 | 315 | l = len(b['seq']) 316 | x_pad = np.pad(x, [[0, L_max - l], [0, 0], [0, 0]], 'constant', constant_values=(np.nan,)) # [#atom, 4, 3] 317 | X[i, :, :, :] = x_pad 318 | 319 | # Convert to labels 320 | indices = np.array(tokenizer.encode(b['seq'], add_special_tokens=False)) 321 | S[i, :l] = indices 322 | chain_mask[i, :l] = b['chain_mask'] 323 | chain_encoding[i, :l] = b['chain_encoding'] 324 | 325 | # Add surface, features, orig_surface 326 | 
surfaces.append(torch.tensor(b['surface'], dtype=torch.float32)) 327 | features.append(torch.tensor(b['features'], dtype=torch.float32)) 328 | orig_surfaces.append(torch.tensor(b['orig_surface'], dtype=torch.float32)) 329 | surface_lengths.append(b['surface'].shape[0]) 330 | 331 | if self.partial_design and self.design_regions: 332 | protein_name = b['title'] 333 | if protein_name in self.design_regions: 334 | # 1. Get the necessary data 335 | design_mask = torch.tensor(self.design_regions[protein_name], dtype=torch.bool) 336 | 337 | # Convert tensors to NumPy arrays for Scikit-learn (this is fast on CPU) 338 | ca_coords = torch.tensor(b['CA'], dtype=torch.float32).numpy() 339 | surface_coords = orig_surfaces[i].numpy() 340 | 341 | if len(design_mask) != len(ca_coords): 342 | print(f"⚠️ WARNING: Mismatch for '{protein_name}'. Mask length {len(design_mask)} != Residue count {len(ca_coords)}. Skipping masking.") 343 | else: 344 | # 2. Find the closest residue for each surface point (using NearestNeighbors) 345 | # Build the tree from the residue coordinates 346 | nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(ca_coords) 347 | 348 | # Find the index of the single nearest neighbor for each surface point 349 | distances, indices = nbrs.kneighbors(surface_coords) 350 | 351 | # `indices` has shape [num_surface_points, 1], so flatten it 352 | closest_residue_indices = indices.flatten() 353 | 354 | # 3. Create a mask for the surface points 355 | # Use the NumPy array of indices to look up values in the PyTorch design_mask 356 | surface_mask = design_mask[closest_residue_indices] 357 | 358 | # 4. Apply the mask to the features tensor 359 | features[i][surface_mask] = float('nan') 360 | 361 | else: 362 | print(f"⚠️ WARNING: Protein '{protein_name}' not found in design region file. 
Skipping masking for this sample.") 363 | 364 | if self.ig_baseline_data: 365 | features[i][:] = float('nan') 366 | 367 | # Find the minimum surface length in the batch 368 | min_surface_length = min(surface_lengths) 369 | 370 | # Downsample all surfaces, features, and orig_surfaces to the minimum surface length 371 | surfaces_downsampled = [] 372 | features_downsampled = [] 373 | orig_surfaces_downsampled = [] 374 | 375 | for i, surface in enumerate(surfaces): 376 | surface_len = surface.shape[0] 377 | if surface_len > min_surface_length: 378 | # Randomly sample indices without replacement 379 | sampled_indices = random.sample(range(surface_len), min_surface_length) 380 | surfaces_downsampled.append(surface[sampled_indices]) 381 | features_downsampled.append(features[i][sampled_indices]) 382 | orig_surfaces_downsampled.append(orig_surfaces[i][sampled_indices]) 383 | else: 384 | surfaces_downsampled.append(surface) 385 | features_downsampled.append(features[i]) 386 | orig_surfaces_downsampled.append(orig_surfaces[i]) 387 | 388 | # Stack the downsampled surfaces, features, and orig_surfaces 389 | surfaces_stacked = torch.stack(surfaces_downsampled, dim=0) 390 | features_stacked = torch.stack(features_downsampled, dim=0) 391 | orig_surfaces_stacked = torch.stack(orig_surfaces_downsampled, dim=0) 392 | 393 | mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32) # atom mask 394 | numbers = np.sum(mask, axis=1).astype(np.int32) 395 | S_new = np.zeros_like(S) 396 | X_new = np.zeros_like(X) + np.nan 397 | 398 | for i, n in enumerate(numbers): 399 | X_new[i, :n, ::] = X[i][mask[i] == 1] 400 | S_new[i, :n] = S[i][mask[i] == 1] 401 | 402 | X = X_new 403 | S = S_new 404 | isnan = np.isnan(X) 405 | mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32) 406 | X[isnan] = 0. 
407 | 408 | # Calculate ss_connection based on X_new and downsampled orig_surface 409 | for i in range(B): 410 | ca_coords = X[i, :, 1, :] # Extract CA coordinates from X_new (1 is for CA atom) 411 | surface_coords = orig_surfaces_stacked[i] 412 | 413 | # Use the mask to identify valid indices 414 | valid_indices = mask[i].astype(bool) # mask[i] is 1 for valid indices, 0 otherwise 415 | valid_ca_coords = ca_coords[valid_indices] 416 | 417 | # Nearest neighbors search 418 | n_neighbors = max(1, int(8 * 175 / lengths[i])) 419 | nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(surface_coords) 420 | # nbrs = NearestNeighbors(n_neighbors=8, algorithm='ball_tree').fit(surface_coords) 421 | distances, indices = nbrs.kneighbors(valid_ca_coords) 422 | 423 | ss_connection = np.zeros((ca_coords.shape[0], surface_coords.shape[0])) 424 | 425 | # Fill ss_connection for valid CA coordinates 426 | for j, neighbors in zip(np.where(valid_indices)[0], indices): 427 | ss_connection[j, neighbors] = 1 428 | 429 | # Fill ss_connection for invalid CA coordinates 430 | ss_connection[~valid_indices, :] = 1 431 | 432 | ss_connections.append(torch.tensor(ss_connection, dtype=torch.float32)) 433 | 434 | # 1. Calculate the distance matrix for valid_ca_coords 435 | ca_dist_matrix = np.linalg.norm(valid_ca_coords[:, None, :] - valid_ca_coords[None, :, :], axis=-1) 436 | max_dist = np.max(ca_dist_matrix) 437 | r = max_dist / 3 # 1/3 of max distance as radius 438 | 439 | # 2. 
Randomly sample 8 coords from valid_ca_coords 440 | sampled_indices = random.sample(range(valid_ca_coords.shape[0]), min(8, valid_ca_coords.shape[0])) 441 | 442 | batch_correspondences = [] 443 | for sampled_idx in sampled_indices: 444 | # Get indices of CA atoms within radius r 445 | ca_neighbors = np.where(ca_dist_matrix[sampled_idx] < r)[0] 446 | 447 | # Get distances between the sampled CA atom and surface points 448 | ca_surface_dist_matrix = np.linalg.norm(valid_ca_coords[sampled_idx] - surface_coords.numpy(), axis=-1) 449 | 450 | # Get indices of surface points within radius r 451 | surface_neighbors = np.where(ca_surface_dist_matrix < r)[0] 452 | 453 | # Store the two sets of indices as tensors 454 | batch_correspondences.append([ 455 | torch.tensor(ca_neighbors, dtype=torch.long), 456 | torch.tensor(surface_neighbors, dtype=torch.long) 457 | ]) 458 | 459 | correspondences.append(batch_correspondences) 460 | 461 | # Pad ss_connections 462 | ss_connections_padded = pad_ss_connections(ss_connections, L_max, min_surface_length) 463 | 464 | # Conversion 465 | S = torch.from_numpy(S).to(dtype=torch.long) 466 | score = torch.from_numpy(score).float() 467 | X = torch.from_numpy(X).to(dtype=torch.float32) 468 | mask = torch.from_numpy(mask).to(dtype=torch.float32) 469 | X_flattened = X[mask==1] 470 | lengths = torch.from_numpy(lengths) 471 | chain_mask = torch.from_numpy(chain_mask) 472 | chain_encoding = torch.from_numpy(chain_encoding) 473 | 474 | mask_bool = (mask==1) 475 | S = torch.masked_select(S, mask_bool) 476 | mask = torch.masked_select(mask, mask_bool) 477 | return { 478 | "title": [b['title'] for b in batch], 479 | "X": X, 480 | "X_flattened": X_flattened, 481 | "S": S, 482 | "score": score, 483 | "mask": mask, 484 | "lengths": lengths, 485 | "chain_mask": chain_mask, 486 | "chain_encoding": chain_encoding, 487 | "surface": surfaces_stacked, 488 | "features": features_stacked, 489 | 'ss_connection': ss_connections_padded, 490 | 'correspondences': 
correspondences, 491 | } 492 | -------------------------------------------------------------------------------- /src/models/UBC2_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import TransformerDecoder, TransformerDecoderLayer 5 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 6 | from torch_scatter import scatter_sum, scatter_softmax 7 | from src.tools import Rigid, Rotation 8 | from src.datasets.featurizer import rbf 9 | import numpy as np 10 | import math 11 | 12 | 13 | def build_MLP(n_layers,dim_in, dim_hid, dim_out, dropout = 0.0, activation=nn.ReLU, normalize=True): 14 | if normalize: 15 | layers = [nn.Linear(dim_in, dim_hid), 16 | nn.BatchNorm1d(dim_hid), 17 | nn.Dropout(dropout), 18 | activation()] 19 | else: 20 | layers = [nn.Linear(dim_in, dim_hid), 21 | nn.Dropout(dropout), 22 | activation()] 23 | for _ in range(n_layers - 2): 24 | layers.append(nn.Linear(dim_hid, dim_hid)) 25 | if normalize: 26 | layers.append(nn.BatchNorm1d(dim_hid)) 27 | layers.append(nn.Dropout(dropout)) 28 | layers.append(activation()) 29 | layers.append(nn.Linear(dim_hid, dim_out)) 30 | return nn.Sequential(*layers) 31 | 32 | 33 | class PointCloudMessagePassing(nn.Module): 34 | def __init__(self, args, feat_dim, edge_dim, l_max, num_scales, hidden_dim, aggregation='concat', num_heads=4, num_mha_layers=1, bc_dropout=0.0): 35 | super(PointCloudMessagePassing, self).__init__() 36 | self.l_max = l_max 37 | self.num_scales = num_scales 38 | self.aggregation = aggregation 39 | self.num_heads = num_heads 40 | 41 | self.per_layer_dim = hidden_dim // 4 42 | 43 | self.bc_mask_max_rate = args.bc_mask_max_rate 44 | self.bc_mask_how = args.bc_mask_how 45 | self.if_struc_only = args.if_struc_only 46 | self.exp_bc_mask_rate = args.exp_bc_mask_rate 47 | self.exp_hydro_mask_rate = getattr(args, 'exp_hydro_mask_rate', 0.) 
48 | self.exp_charge_mask_rate = getattr(args, 'exp_charge_mask_rate', 0.) 49 | 50 | # CLS token for biochemical features initialized with per_layer_dim 51 | self.biochem_cls_token = nn.Parameter(torch.randn(1 + 8, self.per_layer_dim)) # Adjusted dimension 52 | 53 | # bc mask token 54 | self.bc_mask_token = nn.Parameter(torch.randn(1, self.per_layer_dim)) 55 | 56 | # Linear layer for feature dimension adjustment 57 | self.input_fc = nn.Linear(feat_dim, self.per_layer_dim) 58 | 59 | encoder_layer = TransformerEncoderLayer( 60 | d_model=self.per_layer_dim, # 输入特征维度 61 | nhead=num_heads, # 多头注意力的头数 62 | dim_feedforward=self.per_layer_dim * 4, # FFN的隐藏层维度 63 | dropout=bc_dropout, 64 | batch_first=True 65 | ) 66 | self.attention_layers = TransformerEncoder(encoder_layer, num_layers=args.bc_encoder_layer) 67 | 68 | # fc for residue connection 69 | self.res_conn_mlp = nn.Sequential( 70 | nn.ReLU(), 71 | nn.Linear(self.per_layer_dim, hidden_dim) 72 | ) 73 | 74 | # Feature aggregation after MHA 75 | self.fc = nn.Linear(self.per_layer_dim * num_scales, hidden_dim) 76 | 77 | def forward(self, surfaces, biochem_feats, correspondences): 78 | B, N, _ = surfaces.shape 79 | 80 | ###### for inference with only backbone structure, bc input will be all nan 81 | # Find rows (over N) where any feature is nan, for each batch 82 | nan_rows = torch.any(torch.isnan(biochem_feats), dim=-1) # shape: (B, N) 83 | 84 | hydro_mask_indices = torch.rand(B, N, device=biochem_feats.device) < self.exp_hydro_mask_rate 85 | charge_mask_indices = torch.rand(B, N, device=biochem_feats.device) < self.exp_charge_mask_rate 86 | biochem_feats[..., 0][hydro_mask_indices] = biochem_feats[..., 0][hydro_mask_indices].mean() 87 | biochem_feats[..., 1][charge_mask_indices] = biochem_feats[..., 1][charge_mask_indices].mean() 88 | 89 | # Elevate the biochemical features 90 | biochem_feats = self.input_fc(biochem_feats) # BxNx(per_layer_dim) 91 | 92 | biochem_feats[nan_rows] = self.bc_mask_token 93 | 94 | if 
self.if_struc_only: 95 | if self.bc_mask_how == 'token': 96 | biochem_feats[:] = self.bc_mask_token 97 | elif self.bc_mask_how == 'gauss': 98 | biochem_feats[:] = torch.randn_like(biochem_feats) 99 | 100 | # select the indices of the biochemical features to be masked 101 | bc_mask_indices = torch.rand(B, N, device=biochem_feats.device) < self.exp_bc_mask_rate 102 | # mask the biochemical features 103 | if self.bc_mask_how == 'token': 104 | biochem_feats[bc_mask_indices] = self.bc_mask_token 105 | elif self.bc_mask_how == 'gauss': 106 | biochem_feats[bc_mask_indices] = torch.randn_like(biochem_feats[bc_mask_indices]) 107 | 108 | if self.training: 109 | # randomly select a probability between 0 and self.bc_mask_max_rate 110 | bc_mask_rate = torch.rand(B, device=biochem_feats.device) * self.bc_mask_max_rate 111 | # select the indices of the biochemical features to be masked 112 | bc_mask_indices = torch.rand(B, N, device=biochem_feats.device) < bc_mask_rate[:, None] 113 | # mask the biochemical features 114 | if self.bc_mask_how == 'token': 115 | biochem_feats[bc_mask_indices] = self.bc_mask_token 116 | elif self.bc_mask_how == 'gauss': 117 | biochem_feats[bc_mask_indices] = torch.randn_like(biochem_feats[bc_mask_indices]) 118 | 119 | # Add CLS token at the end of biochem_feats (Bx(N+1)x(per_layer_dim)) 120 | cls_tokens = self.biochem_cls_token.expand(B, -1, -1) # Expand CLS token for the batch 121 | biochem_feats = torch.cat([biochem_feats, cls_tokens], dim=1) # Concatenated CLS token 122 | 123 | # Add a last row of infs and a last column of 0s to distances 124 | distances = torch.cdist(surfaces, surfaces) # BxNxN 125 | 126 | # Compute the maximum distance to set dynamic radii 127 | max_distance = distances.max().item() 128 | thr_rs = [max_distance / 20 * i / 4 for i in range(1, 5)] # Different scales of radii 129 | 130 | # Add 9 rows to the bottom of distances, all set to inf 131 | inf_rows = torch.full((B, 9, N), float('inf'), device=surfaces.device) # (Bx9xN) 132 
| distances = torch.cat([distances, inf_rows], dim=1) # Bx(N+9)xN 133 | 134 | # Add 9 columns to the right of distances, with special handling 135 | inf_cols = torch.full((B, N + 9, 9), float('inf'), device=surfaces.device) # Bx(N+9)x9 136 | 137 | # First column (corresponding to global CLS token) is all 0s 138 | inf_cols[:, :, 0] = 0 139 | 140 | # Vectorized filling of subarea CLS distances based on correspondences 141 | for i in range(B): 142 | # Get the neighbors from correspondences and subarea indices 143 | corr = correspondences[i] 144 | surface_neighbors = torch.cat([surf for _, surf in corr], dim=0) # Concatenate all surface neighbors 145 | 146 | # Create indices for the subareas corresponding to surface neighbors 147 | subarea_idxs = torch.cat([torch.full_like(surf, j+1) for j, (_, surf) in enumerate(corr)], dim=0) 148 | 149 | # Assign distances for subarea CLS tokens to 0 where correspondences exist 150 | inf_cols[i, surface_neighbors, subarea_idxs] = 0 151 | 152 | # Concatenate the inf_cols to distances 153 | distances = torch.cat([distances, inf_cols], dim=2) # Bx(N+9)x(N+9) 154 | 155 | # Set the diagonal of the last 9x9 block to 0 156 | distances[:, -9:, -9:] = float('inf') # Set the entire 9x9 block to inf first 157 | distances[:, -9:, -9:].diagonal(dim1=-2, dim2=-1).fill_(0) # Set only the diagonal values to 0 158 | 159 | N += 9 # Adjust N to N+9 since CLS tokens are added 160 | 161 | features_list = [] 162 | 163 | for thr_r in thr_rs: 164 | # 1. Create a mask for points within the spherical region 165 | region_mask = distances < thr_r # Bx(N+1)x(N+1) boolean mask 166 | 167 | # 2. Compute the number of neighbors for each point in the region (Bx(N+1)) 168 | num_neighbors = region_mask.sum(dim=-1) # Bx(N+1) 169 | 170 | # 3. Find the maximum number of neighbors to pad all regions to the same size 171 | max_neighbors = num_neighbors.max().item() # The largest region size in this batch 172 | 173 | # 4. 
Downsample neighbors to 100 if max_neighbors > 100 174 | if max_neighbors > 100: 175 | # Step 1: Get the indices of the True values in region_mask (all neighbors) 176 | batch_idx, center_idx, neighbor_idx = torch.nonzero(region_mask, as_tuple=True) 177 | 178 | # Step 2: Create a mask for the center points (rows) that have more than 100 neighbors 179 | over_limit_mask = num_neighbors > 100 # Bx(N+1) boolean mask where num_neighbors > 100 180 | 181 | # Step 3: Find the batch and center indices that have more than 100 neighbors 182 | over_limit_batch_idx, over_limit_center_idx = torch.nonzero(over_limit_mask, as_tuple=True) 183 | 184 | # Step 4: For these rows, get the neighbor indices and randomly sample 100 neighbors for each row 185 | downsampled_mask = region_mask.clone() 186 | 187 | for b_idx, c_idx in zip(over_limit_batch_idx, over_limit_center_idx): 188 | # Find all neighbors for this center point 189 | neighbor_indices = torch.nonzero(region_mask[b_idx, c_idx], as_tuple=False).squeeze() # Get all neighbors 190 | 191 | # Randomly sample 100 neighbors 192 | random_indices = torch.randperm(neighbor_indices.size(0), device=biochem_feats.device)[:100] # Randomly select 100 193 | selected_neighbors = neighbor_indices[random_indices] # Select 100 neighbors 194 | 195 | # Reset region_mask for this point and update it with only the selected 100 neighbors 196 | downsampled_mask[b_idx, c_idx] = False 197 | downsampled_mask[b_idx, c_idx, selected_neighbors] = True 198 | 199 | # Update region_mask with the downsampled mask 200 | region_mask = downsampled_mask 201 | 202 | # Recompute num_neighbors and max_neighbors after downsampling 203 | num_neighbors = region_mask.sum(dim=-1) # Bx(N+1) 204 | max_neighbors = num_neighbors.max().item() # Limit max_neighbors to 100 205 | 206 | # 5. Get the indices of True values in region_mask 207 | batch_idx, center_idx, neighbor_idx = torch.nonzero(region_mask, as_tuple=True) # Extract indices of neighbors in the region 208 | 209 | # 6. 
Gather the biochemical features for these indices 210 | gathered_feats = biochem_feats[batch_idx, neighbor_idx] # Gather the corresponding features from biochem_feats 211 | 212 | # 7. Generate sequential indices for each neighbor 213 | neighbor_offsets = torch.arange(num_neighbors.sum()).to(num_neighbors.device) - torch.repeat_interleave(torch.cumsum(num_neighbors.view(-1), dim=0) - num_neighbors.view(-1), num_neighbors.view(-1)).to(num_neighbors.device) 214 | 215 | # 8. Create a tensor to hold padded features for each region 216 | padded_feats = torch.zeros(B, N, max_neighbors, biochem_feats.shape[-1], device=biochem_feats.device) 217 | 218 | # Create a mask to indicate which points are real and which are padding 219 | padding_mask = torch.zeros(B, N, max_neighbors, device=biochem_feats.device, dtype=torch.bool) 220 | 221 | # 9. Scatter the gathered features into the padded_feats tensor using the generated sequential indices 222 | padded_feats[batch_idx, center_idx, neighbor_offsets] = gathered_feats 223 | 224 | # Update padding mask where neighbors exist 225 | padding_mask[batch_idx, center_idx, neighbor_offsets] = 1 # Mark valid neighbors 226 | 227 | # 10. Perform Multi-Head Attention (MHA) 228 | padded_feats_flat = padded_feats.view(B * N, max_neighbors, -1) # (B*(N+1))xMaxNeighborsxFeatDim 229 | padding_mask_flat = ~padding_mask.view(B * N, max_neighbors) # (B*(N+1))xMaxNeighbors, invert mask for MHA 230 | 231 | # # Apply MHA over the padded regions 232 | attn_output = self.attention_layers(padded_feats_flat, src_key_padding_mask=padding_mask_flat) 233 | 234 | # 11. Perform pooling over the region (e.g., mean pooling over valid points) 235 | attn_output = attn_output.view(B, N, max_neighbors, -1) # Bx(N+1)xMaxNeighborsxFeatDim 236 | pooled_feats = attn_output.masked_fill(~padding_mask.unsqueeze(-1), 0).sum(dim=2) / num_neighbors.unsqueeze(-1) # Bx(N+1)xFeatDim 237 | 238 | features_list.append(pooled_feats) 239 | 240 | # 12. 
Concatenate features from different scales 241 | combined_feats = torch.cat(features_list, dim=-1) # Bx(N+1)x(num_scales * per_layer_dim) 242 | 243 | # Add the residual connection and final projection to hidden_dim 244 | combined_feats = combined_feats + self.res_conn_mlp(biochem_feats) 245 | output_feats = self.fc(combined_feats) # Bx(N+1)xhidden_dim 246 | 247 | return output_feats 248 | 249 | 250 | class GeoFeat(nn.Module): 251 | def __init__(self, geo_layer, num_hidden, virtual_atom_num, dropout=0.0): 252 | super(GeoFeat, self).__init__() 253 | self.__dict__.update(locals()) 254 | self.virtual_atom = nn.Linear(num_hidden, virtual_atom_num*3) 255 | self.virtual_direct = nn.Linear(num_hidden, virtual_atom_num*3) 256 | self.we_condition = build_MLP(geo_layer, 4*virtual_atom_num*3+9+16+32, num_hidden, num_hidden, dropout) 257 | self.MergeEG = nn.Linear(num_hidden+num_hidden, num_hidden) 258 | 259 | def forward(self, h_V, h_E, T_ts, edge_idx, h_E_0): 260 | src_idx = edge_idx[0] 261 | dst_idx = edge_idx[1] 262 | num_edge = src_idx.shape[0] 263 | num_atom = h_V.shape[0] 264 | 265 | # ==================== point cross attention ===================== 266 | V_local = self.virtual_atom(h_V).view(num_atom,-1,3) 267 | V_edge = self.virtual_direct(h_E).view(num_edge,-1,3) 268 | Ks = torch.cat([V_edge,V_local[src_idx].view(num_edge,-1,3)], dim=1) 269 | Qt = T_ts.apply(Ks) 270 | Ks = Ks.view(num_edge,-1) 271 | Qt = Qt.reshape(num_edge,-1) 272 | V_edge = V_edge.reshape(num_edge,-1) 273 | quat_st = T_ts._rots._rot_mats[:, 0].reshape(num_edge, -1) 274 | 275 | RKs = torch.einsum('eij,enj->eni', T_ts._rots._rot_mats[:,0], V_local[src_idx].view(num_edge,-1,3)) 276 | QRK = torch.einsum('enj,enj->en', V_local[dst_idx].view(num_edge,-1,3), RKs) 277 | 278 | H = torch.cat([Ks, Qt, quat_st, T_ts.rbf, QRK], dim=1) 279 | G_e = self.we_condition(H) 280 | h_E = self.MergeEG(torch.cat([h_E, G_e], dim=-1)) 281 | return h_E 282 | 283 | 284 | class PiFoldAttn(nn.Module): 285 | def __init__(self, 
attn_layer, num_hidden, num_V, num_E, dropout=0.0): 286 | super(PiFoldAttn, self).__init__() 287 | self.__dict__.update(locals()) 288 | self.num_heads = 4 289 | self.W_V = nn.Sequential(nn.Linear(num_E, num_hidden), 290 | nn.GELU()) 291 | 292 | self.Bias = nn.Sequential( 293 | nn.Linear(2*num_V+num_E, num_hidden), 294 | nn.ReLU(), 295 | nn.Linear(num_hidden,num_hidden), 296 | nn.ReLU(), 297 | nn.Linear(num_hidden,self.num_heads)) 298 | self.W_O = nn.Linear(num_hidden, num_V, bias=False) 299 | self.gate = nn.Linear(num_hidden, num_V) 300 | 301 | def forward(self, h_V, h_E, edge_idx): 302 | src_idx = edge_idx[0] 303 | dst_idx = edge_idx[1] 304 | h_V_skip = h_V 305 | 306 | E = h_E.shape[0] 307 | n_heads = self.num_heads 308 | d = int(self.num_hidden / n_heads) 309 | num_nodes = h_V.shape[0] 310 | 311 | w = self.Bias(torch.cat([h_V[src_idx], h_E, h_V[dst_idx]],dim=-1)).view(E, n_heads, 1) 312 | attend_logits = w/np.sqrt(d) 313 | 314 | V = self.W_V(h_E).view(-1,n_heads, d) 315 | attend = scatter_softmax(attend_logits, index=src_idx, dim=0) 316 | h_V = scatter_sum(attend*V, src_idx, dim=0).view([num_nodes, -1]) 317 | 318 | h_V_gate = F.sigmoid(self.gate(h_V)) 319 | dh = self.W_O(h_V)*h_V_gate 320 | 321 | h_V = h_V_skip + dh 322 | return h_V 323 | 324 | 325 | class UpdateNode(nn.Module): 326 | def __init__(self, num_hidden): 327 | super().__init__() 328 | self.dense = nn.Sequential( 329 | nn.BatchNorm1d(num_hidden), 330 | nn.Linear(num_hidden, num_hidden*4), 331 | nn.ReLU(), 332 | nn.Linear(num_hidden*4, num_hidden), 333 | nn.BatchNorm1d(num_hidden) 334 | ) 335 | self.V_MLP_g = nn.Sequential( 336 | nn.Linear(num_hidden, num_hidden), 337 | nn.ReLU(), 338 | nn.Linear(num_hidden,num_hidden), 339 | nn.ReLU(), 340 | nn.Linear(num_hidden,num_hidden)) 341 | 342 | def forward(self, h_V, batch_id): 343 | dh = self.dense(h_V) 344 | h_V = h_V + dh 345 | 346 | # # ============== global attn - virtual frame 347 | uni = batch_id.unique() 348 | mat = (uni[:,None] == 
batch_id[None]).to(h_V.dtype) 349 | mat = mat/mat.sum(dim=1, keepdim=True) 350 | c_V = mat@h_V 351 | 352 | h_V = h_V * F.sigmoid(self.V_MLP_g(c_V))[batch_id] 353 | return h_V 354 | 355 | 356 | class UpdateEdge(nn.Module): 357 | def __init__(self, edge_layer, num_hidden, dropout=0.1): 358 | super(UpdateEdge, self).__init__() 359 | self.W = build_MLP(edge_layer, num_hidden*3, num_hidden, num_hidden, dropout, activation=nn.GELU, normalize=False) 360 | self.norm = nn.BatchNorm1d(num_hidden) 361 | self.pred_quat = nn.Linear(num_hidden,8) 362 | 363 | def forward(self, h_V, h_E, T_ts, edge_idx, batch_id): 364 | src_idx = edge_idx[0] 365 | dst_idx = edge_idx[1] 366 | 367 | h_EV = torch.cat([h_V[src_idx], h_E, h_V[dst_idx]], dim=-1) 368 | h_E = self.norm(h_E + self.W(h_EV)) 369 | 370 | return h_E 371 | 372 | 373 | class GeneralGNN(nn.Module): 374 | def __init__(self, 375 | geo_layer, 376 | attn_layer, 377 | ffn_layer, 378 | edge_layer, 379 | num_hidden, 380 | virtual_atom_num=32, 381 | dropout=0.1, 382 | mask_rate=0.15, 383 | exp_v_mask_rate=0., 384 | exp_e_mask_rate=0.): 385 | super(GeneralGNN, self).__init__() 386 | self.__dict__.update(locals()) 387 | self.geofeat = GeoFeat(geo_layer, num_hidden, virtual_atom_num, dropout) 388 | self.attention = PiFoldAttn(attn_layer, num_hidden, num_hidden, num_hidden, dropout) 389 | self.update_node = UpdateNode(num_hidden) 390 | self.update_edge = UpdateEdge(edge_layer, num_hidden, dropout) 391 | self.mask_token = nn.Embedding(2, num_hidden) 392 | 393 | def get_rand_idx(self, h_V, mask_rate): 394 | num_N = int(h_V.shape[0] * mask_rate) # 要选择的样本数量,即15% 395 | indices = torch.randperm(h_V.shape[0], device=h_V.device) 396 | selected_indices = indices[:num_N] 397 | return selected_indices 398 | 399 | def forward(self, h_V, h_E, T_ts, edge_idx, batch_id, h_E_0): 400 | if self.training: 401 | selected_indices = self.get_rand_idx(h_V, self.mask_rate) 402 | h_V[selected_indices] = self.mask_token.weight[0] 403 | 404 | selected_indices = 
self.get_rand_idx(h_E, self.mask_rate) 405 | h_E[selected_indices] = self.mask_token.weight[1] 406 | 407 | if not self.training: # for ablation study 408 | selected_indices = self.get_rand_idx(h_V, self.exp_v_mask_rate) 409 | h_V[selected_indices] = self.mask_token.weight[0] 410 | 411 | selected_indices = self.get_rand_idx(h_E, self.exp_e_mask_rate) 412 | h_E[selected_indices] = self.mask_token.weight[1] 413 | 414 | h_E = self.geofeat(h_V, h_E, T_ts, edge_idx, h_E_0) 415 | h_V = self.attention(h_V, h_E, edge_idx) 416 | h_V = self.update_node(h_V, batch_id) 417 | h_E = self.update_edge( h_V, h_E, T_ts, edge_idx, batch_id ) 418 | return h_V, h_E 419 | 420 | 421 | class StructureEncoder(nn.Module): 422 | def __init__(self, 423 | geo_layer, 424 | attn_layer, 425 | ffn_layer, 426 | edge_layer, 427 | encoder_layer, 428 | hidden_dim, 429 | dropout=0, 430 | mask_rate=0.15, 431 | exp_v_mask_rate=0., 432 | exp_e_mask_rate=0.): 433 | """ Graph labeling network """ 434 | super(StructureEncoder, self).__init__() 435 | self.__dict__.update(locals()) 436 | self.encoder_layers = nn.ModuleList([GeneralGNN(geo_layer, 437 | attn_layer, 438 | ffn_layer, 439 | edge_layer, 440 | hidden_dim, 441 | dropout=dropout, 442 | mask_rate=mask_rate, 443 | exp_v_mask_rate=exp_v_mask_rate, 444 | exp_e_mask_rate=exp_e_mask_rate) for i in range(encoder_layer)]) 445 | self.s = nn.Linear(hidden_dim, 1) 446 | 447 | def forward(self, h_S, 448 | T, 449 | h_V, 450 | h_E, 451 | T_ts, 452 | edge_idx, 453 | batch_id, h_E_0): 454 | # No global frame handling needed - work only with local components 455 | outputs = [] 456 | for layer in self.encoder_layers: 457 | h_V, h_E = layer(h_V, h_E, T_ts, edge_idx, batch_id, h_E_0) 458 | outputs.append(h_V.unsqueeze(1)) 459 | 460 | outputs = torch.cat(outputs, dim=1) 461 | S = F.sigmoid(self.s(outputs)) 462 | output = torch.einsum('nkc, nkb -> nbc', outputs, S).squeeze(1) 463 | return output 464 | 465 | 466 | class UniIFEncoder(nn.Module): 467 | def __init__(self, args, 
**kwargs): 468 | """ Graph labeling network """ 469 | super(UniIFEncoder, self).__init__() 470 | self.__dict__.update(locals()) 471 | self.hidden_dim = args.hidden_dim 472 | geo_layer, attn_layer, node_layer, edge_layer, encoder_layer, hidden_dim, dropout, mask_rate = args.geo_layer, args.attn_layer, args.node_layer, args.edge_layer, args.encoder_layer, args.hidden_dim, args.dropout, args.mask_rate 473 | 474 | exp_v_mask_rate = getattr(args, 'exp_v_mask_rate', 0.) 475 | exp_e_mask_rate = getattr(args, 'exp_e_mask_rate', 0.) 476 | 477 | self.node_embedding = build_MLP(2, 76, hidden_dim, hidden_dim) 478 | self.edge_embedding = build_MLP(2, 196+16, hidden_dim, hidden_dim) 479 | self.encoder = StructureEncoder(geo_layer, attn_layer, node_layer, edge_layer, encoder_layer, hidden_dim, dropout, mask_rate, 480 | exp_v_mask_rate, exp_e_mask_rate) 481 | self.chain_embeddings = nn.Embedding(2, 16) 482 | 483 | # CLS token for structural features 484 | self.struct_cls_token = nn.Parameter(torch.randn(1 + 8, hidden_dim)) 485 | 486 | self._init_params() 487 | 488 | def _init_params(self): 489 | for name, p in self.named_parameters(): 490 | if p.dim() > 1: 491 | nn.init.xavier_uniform_(p) 492 | 493 | def forward(self, batch, num_global=3): 494 | h_V, h_E, edge_idx, batch_id, chain_features = batch['_V'], batch['_E'], batch['edge_idx'], batch['batch_id'], batch['chain_features'] 495 | correspondences = batch['correspondences'] 496 | # Remove global virtual frame variables 497 | T = Rigid(Rotation(batch['T_rot']), batch['T_trans']) 498 | T_ts = Rigid(Rotation(batch['T_ts_rot']), batch['T_ts_trans']) 499 | h_E = torch.cat([h_E, self.chain_embeddings(chain_features)], dim=-1) 500 | 501 | h_E_0 = h_E 502 | 503 | node_embeds = self.node_embedding(h_V) 504 | 505 | # Prepare for adding CLS tokens 506 | B = len(batch_id.unique()) # Batch size 507 | max_nodes = 9 + max([(batch_id == i).sum().item() for i in range(B)]) # 9 CLS + max residues per batch 508 | # Add CLS token embeddings to node 
embeddings 509 | # Directly add CLS tokens to the beginning of each batch in the original flattened node_embeds 510 | cls_tokens = self.struct_cls_token.expand(B, -1, -1) # (B, 9, hidden_dim) 511 | # For each batch, prepend 9 CLS tokens to the corresponding node embeddings 512 | node_embeds_with_cls = [] 513 | for i in range(B): 514 | node_indices = (batch_id == i).nonzero(as_tuple=True)[0] 515 | this_node_embeds = node_embeds[node_indices] # (num_nodes_i, hidden_dim) 516 | this_cls_tokens = cls_tokens[i] # (9, hidden_dim) 517 | node_embeds_with_cls.append(torch.cat([this_cls_tokens, this_node_embeds], dim=0)) # (9 + num_nodes_i, hidden_dim) 518 | 519 | h_V = torch.cat(node_embeds_with_cls, dim=0) # (sum_i (9 + num_nodes_i), hidden_dim) 520 | 521 | batch_id_with_cls = [] 522 | # batch_id identifies the which batch each node belongs to 523 | # get unique batch_id 524 | unique_batch_id = batch_id.unique() 525 | for i in unique_batch_id: 526 | # 9 i's before the original i's, use tensor 527 | batch_id_with_cls.append(torch.full((9 + (batch_id == i).sum().item(),), i, device=batch_id.device)) 528 | batch_id_with_cls = torch.cat(batch_id_with_cls, dim=0) 529 | 530 | h_E = self.edge_embedding(h_E) 531 | h_E, edge_idx, T_ts = self._pad_and_stack_edges(h_E, batch_id, edge_idx, T_ts, max_nodes, correspondences) # Shape: (total_edges, hidden_dim), (2, total_edges) 532 | 533 | h_S = None 534 | 535 | # Get structural node embeddings from encoder (without global frames) 536 | node_embeds = self.encoder(h_S, 537 | T, 538 | h_V, 539 | h_E, 540 | T_ts, 541 | edge_idx, 542 | batch_id_with_cls, h_E_0) 543 | 544 | unflattened_node_embeds = self._pad_and_stack(node_embeds, batch_id_with_cls, max_nodes) 545 | 546 | return unflattened_node_embeds 547 | 548 | 549 | def _get_features(self, batch): 550 | return batch 551 | 552 | def _pad_and_stack(self, features, batch_id, max_nodes): 553 | """Pad and stack node features.""" 554 | B = batch_id.max().item() + 1 # Batch size 555 | padded = 
torch.zeros((B, max_nodes, self.hidden_dim), device=features.device) 556 | 557 | for i in range(B): 558 | node_indices = (batch_id == i).nonzero(as_tuple=True)[0] 559 | padded[i, :len(node_indices), :] = features[node_indices] 560 | 561 | return padded 562 | 563 | def _pad_and_stack_edges(self, edge_weights, batch_id, E_idx, T_ts, max_nodes, correspondences): 564 | """Convert edge features to include CLS tokens and return in 2D format with updated edge indices.""" 565 | B = batch_id.max().item() + 1 # Batch size 566 | 567 | is_vector = True 568 | hidden_dim = edge_weights.shape[-1] 569 | 570 | # Collect all new edge features and indices 571 | new_edge_features = [] 572 | new_edge_indices = [] 573 | new_rot_mats = [] 574 | new_trans = [] 575 | # shape of T_ts._rots._rot_mats torch.Size([#edges, 1, 3, 3]) 576 | # shape of T_ts._trans torch.Size([#edges, 1, 3]) 577 | 578 | # Calculate batch offsets for global node indexing 579 | batch_sizes = [(batch_id == i).sum().item() for i in range(B)] 580 | batch_offsets = [0] 581 | for i in range(B): 582 | batch_offsets.append(batch_offsets[-1] + batch_sizes[i] + 9) # +9 for CLS tokens per batch 583 | 584 | for i in range(B): 585 | node_indices = (batch_id == i).nonzero(as_tuple=True)[0] 586 | min_node_id = node_indices.min().item() 587 | num_nodes = len(node_indices) 588 | 589 | src, dst = E_idx[0, :], E_idx[1, :] 590 | local_edges_mask = (src >= min_node_id) & (src < min_node_id + num_nodes) 591 | 592 | batch_offset = batch_offsets[i] 593 | 594 | # 1. 
Process original edges (shifted by +9 for CLS tokens) 595 | if local_edges_mask.any(): 596 | src_local = src[local_edges_mask] - min_node_id + 9 # +9 for CLS tokens 597 | dst_local = dst[local_edges_mask] - min_node_id + 9 # +9 for CLS tokens 598 | 599 | # Add to global indices 600 | global_src = src_local + batch_offset 601 | global_dst = dst_local + batch_offset 602 | 603 | # Original edge features 604 | original_edge_features = edge_weights[local_edges_mask] 605 | new_edge_features.append(original_edge_features) 606 | 607 | new_edge_indices.append(torch.stack([global_src, global_dst])) 608 | # T_ts is a list, so we need to use nonzero indices to select elements 609 | local_edge_indices = local_edges_mask.nonzero(as_tuple=True)[0] 610 | new_rot_mats.append(T_ts._rots._rot_mats[local_edge_indices]) 611 | new_trans.append(T_ts._trans[local_edge_indices]) 612 | 613 | # 2. Determine CLS edge feature 614 | cls_edge_feature = torch.zeros(hidden_dim, device=edge_weights.device) 615 | # determine the T_ts for the CLS edge 616 | cls_rot_mats = torch.zeros_like(T_ts._rots._rot_mats[0]) 617 | cls_trans = torch.zeros_like(T_ts._trans[0]) 618 | 619 | # 3. 
Add global CLS token connections (index 0) 620 | global_cls_idx = batch_offset + 0 # Global CLS index for this batch 621 | global_node_indices = torch.arange(9, 9 + num_nodes, device=edge_weights.device) + batch_offset 622 | 623 | # CLS -> Nodes 624 | cls_to_nodes_src = torch.full((num_nodes,), global_cls_idx, device=edge_weights.device) 625 | cls_to_nodes_dst = global_node_indices 626 | cls_to_nodes_features = cls_edge_feature.unsqueeze(0).repeat(num_nodes, 1) 627 | 628 | new_edge_indices.append(torch.stack([cls_to_nodes_src, cls_to_nodes_dst])) 629 | new_edge_features.append(cls_to_nodes_features) 630 | new_rot_mats.append(cls_rot_mats.unsqueeze(0).repeat(num_nodes, 1, 1, 1)) 631 | new_trans.append(cls_trans.unsqueeze(0).repeat(num_nodes, 1, 1)) 632 | 633 | # Nodes -> CLS 634 | new_edge_indices.append(torch.stack([cls_to_nodes_dst, cls_to_nodes_src])) 635 | new_edge_features.append(cls_to_nodes_features) 636 | new_rot_mats.append(cls_rot_mats.unsqueeze(0).repeat(num_nodes, 1, 1, 1)) 637 | new_trans.append(cls_trans.unsqueeze(0).repeat(num_nodes, 1, 1)) 638 | 639 | # 4. 
Add subarea CLS token connections (indices 1-8) 640 | if len(correspondences) > i and len(correspondences[i]) > 0: 641 | for sub_idx, (ca_neighbors, _) in enumerate(correspondences[i], start=1): # Limit to 8 subareas 642 | if len(ca_neighbors) > 0: 643 | global_subarea_idx = batch_offset + sub_idx # Global subarea CLS index 644 | global_ca_neighbors = torch.tensor(ca_neighbors, device=edge_weights.device) + 9 + batch_offset 645 | 646 | # Subarea CLS -> CA neighbors 647 | subarea_to_ca_src = torch.full((len(ca_neighbors),), global_subarea_idx, device=edge_weights.device) 648 | subarea_to_ca_dst = global_ca_neighbors 649 | subarea_to_ca_features = cls_edge_feature.unsqueeze(0).repeat(len(ca_neighbors), 1) 650 | 651 | new_edge_indices.append(torch.stack([subarea_to_ca_src, subarea_to_ca_dst])) 652 | new_edge_features.append(subarea_to_ca_features) 653 | new_rot_mats.append(cls_rot_mats.unsqueeze(0).repeat(len(ca_neighbors), 1, 1, 1)) 654 | new_trans.append(cls_trans.unsqueeze(0).repeat(len(ca_neighbors), 1, 1)) 655 | 656 | # CA neighbors -> Subarea CLS 657 | new_edge_indices.append(torch.stack([subarea_to_ca_dst, subarea_to_ca_src])) 658 | new_edge_features.append(subarea_to_ca_features) 659 | new_rot_mats.append(cls_rot_mats.unsqueeze(0).repeat(len(ca_neighbors), 1, 1, 1)) 660 | new_trans.append(cls_trans.unsqueeze(0).repeat(len(ca_neighbors), 1, 1)) 661 | 662 | # Concatenate all edge features and indices 663 | if new_edge_features: 664 | final_edge_features = torch.cat(new_edge_features, dim=0) # (total_edges, hidden_dim) 665 | final_edge_indices = torch.cat(new_edge_indices, dim=1) # (2, total_edges) 666 | final_rot_mats = torch.cat(new_rot_mats, dim=0) 667 | final_trans = torch.cat(new_trans, dim=0) 668 | T_ts._rots._rot_mats = final_rot_mats 669 | T_ts._trans = final_trans 670 | rbf_ts = rbf(T_ts._trans.norm(dim=-1), 0, 50, 16)[:,0].view(final_edge_features.shape[0],-1) 671 | T_ts.rbf = rbf_ts 672 | else: 673 | # Handle empty case 674 | final_edge_features = 
torch.zeros((0, hidden_dim), device=edge_weights.device) 675 | final_edge_indices = torch.zeros((2, 0), device=edge_weights.device, dtype=torch.long) 676 | 677 | return final_edge_features, final_edge_indices, T_ts 678 | 679 | 680 | class PositionalEncoding(nn.Module): 681 | def __init__(self, d_model, dropout=0.1, max_len=5000): 682 | super(PositionalEncoding, self).__init__() 683 | self.dropout = nn.Dropout(p=dropout) 684 | 685 | pe = torch.zeros(max_len, d_model) 686 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 687 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 688 | pe[:, 0::2] = torch.sin(position * div_term) 689 | pe[:, 1::2] = torch.cos(position * div_term) 690 | pe = pe.unsqueeze(0).transpose(0, 1) 691 | self.register_buffer('pe', pe) 692 | 693 | def forward(self, x): 694 | x = x + self.pe[:x.size(0), :] 695 | return self.dropout(x) 696 | 697 | 698 | class UBC2Model(nn.Module): 699 | def __init__(self, args, queue_size=64, **kwargs): 700 | """ Graph labeling network """ 701 | super(UBC2Model, self).__init__() 702 | self.args = args 703 | hidden_dim = args.hidden_dim 704 | dropout = args.dropout 705 | self.modal_mask_ratio = args.modal_mask_ratio 706 | self.contrastive_pretrain = args.contrastive_pretrain 707 | self.contrastive_pretrain_both = args.contrastive_pretrain_both 708 | self.contrastive_loss_global_alpha = args.contrastive_loss_global_alpha 709 | self.contrastive_loss_local_alpha = args.contrastive_loss_local_alpha 710 | 711 | self.if_strucenc_only = args.if_strucenc_only 712 | 713 | self.if_warmup_train = args.if_warmup_train 714 | 715 | self.bc_indices = getattr(args, 'bc_indices', [0, 1]) 716 | self.exp_wo_bcgraph = getattr(args, 'exp_wo_bcgraph', False) 717 | 718 | self.encoder = UniIFEncoder(args) 719 | 720 | l_max = 2 721 | num_scales = 4 722 | self.surface_encoder = PointCloudMessagePassing(args, len(self.bc_indices), 1, l_max, num_scales, hidden_dim) 723 | 724 | # New 
Transformer decoder and MLP for final prediction 725 | decoder_layer = TransformerDecoderLayer(d_model=hidden_dim, nhead=8, dropout=dropout, batch_first=True) 726 | self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=3) 727 | self.mlp = nn.Sequential( 728 | nn.Linear(hidden_dim, hidden_dim), 729 | nn.ReLU(), 730 | nn.Linear(hidden_dim, 33) 731 | ) 732 | 733 | # Positional encoding 734 | self.positional_encoding = PositionalEncoding(hidden_dim, dropout) 735 | 736 | self.contrastive_learning = args.contrastive_learning 737 | 738 | # Temperature for contrastive learning 739 | self.temperature = 0.1 740 | self.queue_size = queue_size 741 | 742 | # Initialize queues for structural and biochemical CLS tokens 743 | self.struct_queue = nn.Parameter(torch.zeros(queue_size, hidden_dim), requires_grad=False) 744 | self.biochem_queue = nn.Parameter(torch.zeros(queue_size, hidden_dim), requires_grad=False) 745 | self.queue_ptr = nn.Parameter(torch.zeros(1, dtype=torch.long), requires_grad=False) 746 | 747 | self._init_params() 748 | 749 | if self.contrastive_pretrain: 750 | modules_to_freeze = { 751 | "encoder": self.encoder, 752 | "transformer_decoder": self.transformer_decoder, 753 | "mlp": self.mlp 754 | } 755 | 756 | for name, module in modules_to_freeze.items(): 757 | print(f"--- Freezing module: '{name}'") 758 | for param in module.parameters(): 759 | param.requires_grad = False 760 | module.eval() 761 | elif self.contrastive_pretrain_both: 762 | modules_to_freeze = { 763 | "transformer_decoder": self.transformer_decoder, 764 | "mlp": self.mlp 765 | } 766 | 767 | for name, module in modules_to_freeze.items(): 768 | print(f"--- Freezing module: '{name}'") 769 | for param in module.parameters(): 770 | param.requires_grad = False 771 | module.eval() 772 | elif self.if_strucenc_only: 773 | modules_to_freeze = { 774 | "surface_encoder": self.surface_encoder, 775 | "transformer_decoder": self.transformer_decoder, 776 | } 777 | 778 | for name, module in 
modules_to_freeze.items(): 779 | print(f"--- Freezing module: '{name}'") 780 | for param in module.parameters(): 781 | param.requires_grad = False 782 | module.eval() 783 | elif self.if_warmup_train: 784 | modules_to_freeze = { 785 | "encoder": self.encoder, 786 | "surface_encoder": self.surface_encoder 787 | } 788 | for name, module in modules_to_freeze.items(): 789 | print(f"--- Freezing module: '{name}'") 790 | for param in module.parameters(): 791 | param.requires_grad = False 792 | module.eval() 793 | 794 | def forward(self, batch): 795 | batch_id = batch['batch_id'] 796 | h_V_unflattened = self.encoder(batch) 797 | 798 | # Manually extract the CLS tokens (global + subarea) from the structure encoder 799 | struct_cls_tokens = h_V_unflattened[:, :9, :] # First 9 tokens: global (0th) + subarea (1st to 8th) 800 | h_V_unflattened = h_V_unflattened[:, 9:, :] # The rest of the node embeddings 801 | 802 | # Unflatten h_V and mask to have batch dimension 803 | max_length = batch['lengths'].max().item() 804 | batch_size = len(batch['lengths']) 805 | mask_unflattened = torch.zeros(batch_size, max_length, device=h_V_unflattened.device) 806 | 807 | # Efficiently assign values to h_V_unflattened and mask_unflattened 808 | for idx in torch.unique(batch_id): 809 | mask = (batch_id == idx) 810 | mask_unflattened[idx, :mask.sum()] = 1 811 | # Create padding masks 812 | target_padding_mask = ~mask_unflattened.bool() 813 | 814 | ### surface encoder 815 | surfaces, biochem_feats, correspondences = batch['surface'], batch['features'], batch['correspondences'] 816 | 817 | if self.training: 818 | # generate a random number between 0 and 1 819 | random_number = torch.rand(1).item() 820 | if random_number < self.modal_mask_ratio: 821 | biochem_feats = torch.randn_like(biochem_feats) 822 | 823 | if self.if_strucenc_only: 824 | decoder_output = h_V_unflattened 825 | # Flatten decoder_output and remove padding 826 | mask = mask_unflattened.bool() 827 | decoder_output = 
decoder_output[mask] 828 | 829 | # Predict labels using MLP 830 | logits = self.mlp(decoder_output) 831 | log_probs = F.log_softmax(logits, dim=-1) 832 | 833 | elif self.contrastive_pretrain or self.contrastive_pretrain_both: 834 | h_surface = self.surface_encoder(surfaces, biochem_feats, correspondences) 835 | 836 | # Manually extract the CLS tokens (global + subarea) from the biochemical encoder 837 | biochem_cls_tokens = h_surface[:, -9:, :] # Last 9 tokens: global (0th) + subarea (1st to 8th) 838 | h_surface = h_surface[:, :-9, :] # The rest of the biochemical node embeddings 839 | 840 | logits = 0 841 | log_probs = 0 842 | else: 843 | h_surface = self.surface_encoder(surfaces, biochem_feats, correspondences) 844 | 845 | # Manually extract the CLS tokens (global + subarea) from the biochemical encoder 846 | biochem_cls_tokens = h_surface[:, -9:, :] # Last 9 tokens: global (0th) + subarea (1st to 8th) 847 | h_surface = h_surface[:, :-9, :] # The rest of the biochemical node embeddings 848 | 849 | ss_connection_mask = batch['ss_connection'] 850 | ss_connection_mask = ~ss_connection_mask.bool().repeat(8, 1, 1) 851 | 852 | # Transformer decoder to fuse h_V_unflattened and h_surface 853 | # Add positional encoding to the inputs of the Transformer decoder 854 | h_V_unflattened = self.positional_encoding(h_V_unflattened) 855 | 856 | if self.exp_wo_bcgraph: 857 | decoder_output = self.transformer_decoder( 858 | h_V_unflattened, h_surface, 859 | tgt_key_padding_mask=target_padding_mask, 860 | ) 861 | else: 862 | decoder_output = self.transformer_decoder( 863 | h_V_unflattened, h_surface, 864 | tgt_key_padding_mask=target_padding_mask, 865 | memory_mask=ss_connection_mask 866 | ) 867 | 868 | # Flatten decoder_output and remove padding 869 | mask = mask_unflattened.bool() 870 | decoder_output = decoder_output[mask] 871 | 872 | # Predict labels using MLP 873 | logits = self.mlp(decoder_output) 874 | log_probs = F.log_softmax(logits, dim=-1) 875 | 876 | # Contrastive 
learning 877 | if (self.training and random_number < self.modal_mask_ratio) or not self.contrastive_learning: 878 | contrastive_loss = 0 879 | else: 880 | contrastive_loss_global = self._contrastive_loss(struct_cls_tokens[:, 0, :], biochem_cls_tokens[:, 0, :]) # Global CLS 881 | contrastive_loss_subarea = self._contrastive_loss_subarea(struct_cls_tokens[:, 1:, :], biochem_cls_tokens[:, 1:, :]) # Subarea CLS 882 | contrastive_loss = self.contrastive_loss_global_alpha * contrastive_loss_global + self.contrastive_loss_local_alpha * contrastive_loss_subarea 883 | 884 | # Update queues with current batch global CLS tokens 885 | self._dequeue_and_enqueue(struct_cls_tokens[:, 0, :], biochem_cls_tokens[:, 0, :]) 886 | 887 | return {'log_probs': log_probs, 'contrastive_loss': contrastive_loss, 'logits': logits} 888 | 889 | @torch.no_grad() 890 | def _dequeue_and_enqueue(self, struct_cls_token, biochem_cls_token): 891 | """Append new CLS tokens to the queue and dequeue older ones.""" 892 | batch_size = struct_cls_token.size(0) 893 | 894 | # Get current position in the queue 895 | ptr = int(self.queue_ptr) 896 | 897 | # Replace oldest entries with the new ones 898 | if ptr + batch_size > self.queue_size: 899 | ptr = 0 900 | self.struct_queue[ptr:ptr + batch_size, :] = struct_cls_token 901 | self.biochem_queue[ptr:ptr + batch_size, :] = biochem_cls_token 902 | 903 | # Move pointer and wrap-around if necessary 904 | ptr = (ptr + batch_size) % self.queue_size 905 | self.queue_ptr[0] = ptr 906 | 907 | def _contrastive_loss(self, struct_cls_token, biochem_cls_token): 908 | """Compute NT-Xent contrastive loss using queue-based negative sampling.""" 909 | batch_size = struct_cls_token.size(0) 910 | 911 | # Normalize CLS tokens 912 | z_i = F.normalize(struct_cls_token, dim=-1) 913 | z_j = F.normalize(biochem_cls_token, dim=-1) 914 | 915 | # Normalize queue embeddings 916 | struct_queue_norm = F.normalize(self.struct_queue.clone().detach(), dim=-1) 917 | biochem_queue_norm = 
F.normalize(self.biochem_queue.clone().detach(), dim=-1) 918 | 919 | # Cosine similarity between current CLS tokens 920 | sim_ij = torch.matmul(z_i, z_j.T) / self.temperature # (batch_size, batch_size) 921 | 922 | # Cosine similarity with negative samples from the queue 923 | sim_i_struct_queue = torch.matmul(z_i, biochem_queue_norm.T) / self.temperature # (batch_size, queue_size) 924 | sim_j_biochem_queue = torch.matmul(z_j, struct_queue_norm.T) / self.temperature # (batch_size, queue_size) 925 | 926 | # Combine positive and negative samples 927 | sim_matrix_i = torch.cat([sim_ij, sim_i_struct_queue], dim=1) # (batch_size, batch_size + queue_size) 928 | sim_matrix_j = torch.cat([sim_ij.T, sim_j_biochem_queue], dim=1) # (batch_size, batch_size + queue_size) 929 | 930 | # Create labels (positive samples on the diagonal) 931 | labels = torch.arange(batch_size).long().to(sim_matrix_i.device) 932 | 933 | # Contrastive loss for both modalities 934 | loss_i = F.cross_entropy(sim_matrix_i, labels) 935 | loss_j = F.cross_entropy(sim_matrix_j, labels) 936 | 937 | loss = (loss_i + loss_j) / 2.0 938 | return loss 939 | 940 | def _contrastive_loss_subarea(self, struct_subarea_cls_tokens, biochem_subarea_cls_tokens): 941 | """Compute contrastive loss for the subarea CLS tokens without using a queue, using only the current batch.""" 942 | batch_size, num_subareas, hidden_dim = struct_subarea_cls_tokens.size() 943 | 944 | # Normalize CLS tokens 945 | z_i = F.normalize(struct_subarea_cls_tokens, dim=-1) 946 | z_j = F.normalize(biochem_subarea_cls_tokens, dim=-1) 947 | 948 | # Cosine similarity within the batch for subarea CLS tokens 949 | sim_ij = torch.matmul(z_i, z_j.transpose(1, 2)) / self.temperature # (batch_size, num_subareas, num_subareas) 950 | 951 | # Create labels (positive samples on the diagonal) 952 | labels = torch.arange(num_subareas).long().to(sim_ij.device).unsqueeze(0).expand(batch_size, -1) 953 | 954 | # Reshape sim_ij and labels for efficient cross-entropy 
calculation 955 | sim_ij = sim_ij.view(batch_size * num_subareas, num_subareas) # (batch_size * num_subareas, num_subareas) 956 | labels = labels.reshape(batch_size * num_subareas) # (batch_size * num_subareas,) 957 | 958 | # Compute contrastive loss in one step 959 | loss = F.cross_entropy(sim_ij, labels) 960 | 961 | return loss 962 | 963 | def _init_params(self): 964 | for name, p in self.named_parameters(): 965 | if p.dim() > 1: 966 | nn.init.xavier_uniform_(p) 967 | 968 | def _get_features(self, batch): 969 | return batch 970 | 971 | 972 | --------------------------------------------------------------------------------