├── src
│   ├── interface
│   │   ├── __init__.py
│   │   ├── data_interface.py
│   │   └── model_interface.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── afdb_dataset.py
│   │   ├── inference_dataset.py
│   │   ├── ts_dataset.py
│   │   ├── cath_dataset.py
│   │   ├── cathafdb_dataset.py
│   │   └── featurizer.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── configs
│   │   │   └── UBC2Model.yaml
│   │   └── UBC2_model.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── logger.py
│   │   └── config_utils.py
│   ├── version.py
│   └── __init__.py
├── TMscore
├── assets
│   ├── BC-Design.png
│   └── BC-Design-overview.png
├── environment.yml
├── .gitignore
├── CONTRIBUTING.md
├── train
│   ├── main_eval.py
│   ├── data_interface.py
│   └── main.py
├── LICENSE
├── README.md
└── pdb2jsonpkl.py
/src/interface/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/TMscore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gersteinlab/BC-Design/HEAD/TMscore
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) CAIRI AI Lab. All rights reserved
2 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) CAIRI AI Lab. All rights reserved
2 |
--------------------------------------------------------------------------------
/assets/BC-Design.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gersteinlab/BC-Design/HEAD/assets/BC-Design.png
--------------------------------------------------------------------------------
/assets/BC-Design-overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gersteinlab/BC-Design/HEAD/assets/BC-Design-overview.png
--------------------------------------------------------------------------------
/src/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) CAIRI AI Lab. All rights reserved
2 |
3 | from .affine_tools import Rigid, Rotation, get_interact_feats
4 |
--------------------------------------------------------------------------------
/src/interface/data_interface.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 |
3 |
4 | class DInterface_base(pl.LightningDataModule):
5 | def __init__(self, **kwargs):
6 | super().__init__()
7 | self.save_hyperparameters()
8 | self.batch_size = self.hparams.batch_size
9 | print("batch_size", self.batch_size)
10 | self.load_data_module()
11 |
--------------------------------------------------------------------------------
/src/models/configs/UBC2Model.yaml:
--------------------------------------------------------------------------------
1 | res_dir: ./train/results
2 | # ex_name: UBC2Model
3 | dataset: CATH4.2
4 | model_name: UBC2Model
5 | # lr: 0.0002
6 | # lr_scheduler: onecycle
7 | # lr_scheduler: cosine
8 | offline: 1
9 | seed: 112
10 | pretrained_path: ''
11 | batch_size: 2
12 | accumulate_grad_batches: 1
13 | num_workers: 0
14 | min_length: 40
15 | data_root: ./data/
16 | # epoch: 50
17 | augment_eps: 0.0
18 | geo_layer: 3
19 | attn_layer: 3
20 | node_layer: 3
21 | edge_layer: 3
22 | encoder_layer: 12
23 | hidden_dim: 128
24 | dropout: 0.0
25 | k_neighbors: 30
26 | virtual_atom_num: 24
27 | bc_encoder_layer: 2
28 |
29 | mask_rate: 0.1
30 | bc_mask_how: token
31 | modal_mask_ratio: 0.
32 |
33 | contrastive_pretrain_both: false
34 | contrastive_loss_global_alpha: 0.01
35 | contrastive_loss_local_alpha: 1.
36 |
37 | if_struc_only: false
38 | # checkpoint_path: "./UBC2Model.ckpt"
39 |
--------------------------------------------------------------------------------
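Note: `train/main.py` (shown further below) loads this YAML with `yaml.safe_load` and copies any key that matches an existing argparse argument onto `args`. A minimal sketch of that round trip, assuming it is run from the repository root:

```
import yaml

with open('./src/models/configs/UBC2Model.yaml') as f:
    cfg = yaml.safe_load(f)
print(cfg['model_name'], cfg['hidden_dim'], cfg['k_neighbors'])  # UBC2Model 128 30
```
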
/src/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) CAIRI AI Lab. All rights reserved
2 |
3 | __version__ = '0.1.0'
4 |
5 |
6 | def parse_version_info(version_str):
7 | """Parse a version string into a tuple.
8 |
9 | Args:
10 | version_str (str): The version string.
11 | Returns:
12 | tuple[int | str]: The version info, e.g., "0.1.0" is parsed into
13 | (0, 1, 0), and "2.0.0rcx" is parsed into (2, 0, 0, 'rcx').
14 | """
15 | version_info = []
16 | for x in version_str.split('.'):
17 | if x.isdigit():
18 | version_info.append(int(x))
19 | elif x.find('rc') != -1:
20 | patch_version = x.split('rc')
21 | version_info.append(int(patch_version[0]))
22 | version_info.append(f'rc{patch_version[1]}')
23 | return tuple(version_info)
24 |
25 |
26 | version_info = parse_version_info(__version__)
27 |
28 | __all__ = ['__version__', 'version_info', 'parse_version_info']
29 |
--------------------------------------------------------------------------------
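For reference, a quick check of `parse_version_info` on the two formats its docstring mentions (a standalone sketch, not part of the repo):

```
from src.version import parse_version_info

assert parse_version_info('0.1.0') == (0, 1, 0)
assert parse_version_info('2.0.0rc1') == (2, 0, 0, 'rc1')
```
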
/environment.yml:
--------------------------------------------------------------------------------
1 | name: bcdesign
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - nvidia # Kept for nvidia drivers/libs if not covered by pytorch channel for all needs
6 | - bioconda # Kept for biopython and potentially other bio-related tools
7 | - defaults
8 | dependencies:
9 | # Core Python
10 | - python=3.12.4
11 |
12 | - pytorch=2.3.0
13 | - pytorch-cuda=12.1 # This will pull in CUDA toolkit libs for PyTorch
14 | - lightning=2.3.3 # or pytorch-lightning, 'lightning' is the newer name
15 | - scikit-learn=1.5.1
16 | - numpy=2.0.1
17 | - pyyaml=6.0.1
18 | - tqdm=4.66.4 # Progress bars
19 |
20 | - pip=24.0
21 |
22 | - pip:
23 | - torch-scatter==2.1.2
24 | - torch-cluster==1.6.3
25 | - torch-geometric==2.5.3
26 | - "-f https://data.pyg.org/whl/torch-2.3.0+cu121.html"
27 |
28 | - transformers==4.43.3 # pip
29 | - huggingface-hub==0.24.2 # pip
30 |
31 | - scipy==1.14.0 # pip
32 | - pandas==2.2.2 # pip
33 |
34 | # Bioinformatics
35 | - biopython==1.84 # pip
36 |
37 | # Experiment tracking
38 | - wandb>=0.12.10
39 |
40 | # Configuration & Utilities
41 | - omegaconf==2.3.0 # pip
42 | - requests==2.32.3 # pip
43 |
44 | # Core utilities often managed by pip or with specific versions
45 | - torcheval==0.0.7
46 | - setuptools==71.1.0 # Build system
47 |
--------------------------------------------------------------------------------
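The pip block pins the PyG extensions against the PyTorch 2.3.0 + CUDA 12.1 wheel index via the `-f` line. If the conda solve succeeds but the pip step fails, the equivalent manual install (same pins, same index) is:

```
pip install torch-scatter==2.1.2 torch-cluster==1.6.3 torch-geometric==2.5.3 \
    -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
```
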
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) CAIRI AI Lab. All rights reserved
2 |
3 | import warnings
4 | from packaging.version import parse
5 |
6 | from .version import __version__
7 |
8 |
9 | def digit_version(version_str: str, length: int = 4):
10 | """Convert a version string into a tuple of integers.
11 |
12 | This method is usually used for comparing two versions. For pre-release
13 | versions: alpha < beta < rc.
14 |
15 | Args:
16 | version_str (str): The version string.
17 | length (int): The maximum number of version levels. Default: 4.
18 |
19 | Returns:
20 | tuple[int]: The version info in digits (integers).
21 | """
22 | version = parse(version_str)
23 | assert version.release, f'failed to parse version {version_str}'
24 | release = list(version.release)
25 | release = release[:length]
26 | if len(release) < length:
27 | release = release + [0] * (length - len(release))
28 | if version.is_prerelease:
29 | mapping = {'a': -3, 'b': -2, 'rc': -1}
30 | val = -4
31 | # version.pre can be None
32 | if version.pre:
33 | if version.pre[0] not in mapping:
34 | warnings.warn(f'unknown prerelease version {version.pre[0]}, '
35 | 'version checking may go wrong')
36 | else:
37 | val = mapping[version.pre[0]]
38 | release.extend([val, version.pre[-1]])
39 | else:
40 | release.extend([val, 0])
41 |
42 | elif version.is_postrelease:
43 | release.extend([1, version.post])
44 | else:
45 | release.extend([0, 0])
46 | return tuple(release)
47 |
48 |
49 | __all__ = ['__version__', 'digit_version']
50 |
--------------------------------------------------------------------------------
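As a sanity check on the prerelease ordering (`alpha < beta < rc < release`) that `digit_version` encodes, the following standalone sketch compares a few tuples:

```
from src import digit_version

assert digit_version('2.0.0a1') < digit_version('2.0.0b1')   # alpha < beta
assert digit_version('2.0.0b1') < digit_version('2.0.0rc1')  # beta < rc
assert digit_version('2.0.0rc1') < digit_version('2.0.0')    # rc < final release
```
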
/src/tools/logger.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.callbacks import Callback
2 | import os
3 | import shutil
4 | from omegaconf import OmegaConf
5 |
6 | class SetupCallback(Callback):
7 | def __init__(self, now, logdir, ckptdir, cfgdir, config, argv_content=None):
8 | super().__init__()
9 | self.now = now
10 | self.logdir = logdir
11 | self.ckptdir = ckptdir
12 | self.cfgdir = cfgdir
13 | self.config = config
14 |
15 | self.argv_content = argv_content
16 |
17 |     # Called at the start of the fit/pretrain routine.
18 | def on_fit_start(self, trainer, pl_module):
19 | # Create logdirs and save configs
20 | os.makedirs(self.logdir, exist_ok=True)
21 | os.makedirs(self.ckptdir, exist_ok=True)
22 | os.makedirs(self.cfgdir, exist_ok=True)
23 |
24 | print("Project config")
25 | print(OmegaConf.to_yaml(self.config))
26 | OmegaConf.save(self.config,
27 | os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
28 |
29 | with open(os.path.join(self.logdir, "argv_content.txt"), "w") as f:
30 | f.write(str(self.argv_content))
31 |
32 | class BackupCodeCallback(Callback):
33 | def __init__(self, source_dir, backup_dir, ignore_patterns=None):
34 | super().__init__()
35 | self.source_dir = source_dir
36 | self.backup_dir = backup_dir
37 | self.ignore_patterns = ignore_patterns
38 |
39 | def on_train_start(self, trainer, pl_module):
40 | try:
41 | os.makedirs(self.backup_dir, exist_ok=True)
42 | if os.path.exists(self.backup_dir+'/code'):
43 | shutil.rmtree(self.backup_dir+'/code')
44 | shutil.copytree(self.source_dir, self.backup_dir+'/code', ignore=self.ignore_patterns)
45 |
46 |             print(f"Code backed up to {self.backup_dir}")
47 |         except Exception:
48 |             print(f"Failed to back up code to {self.backup_dir}")
--------------------------------------------------------------------------------
/src/datasets/afdb_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import torch.utils.data as data
5 | import pickle
6 |
7 |
8 | def normalize_coordinates(surface):
9 | """
10 | Normalize the coordinates of the surface.
11 | """
12 | surface = np.array(surface)
13 | center = np.mean(surface, axis=0)
14 | max_ = np.max(surface, axis=0)
15 | min_ = np.min(surface, axis=0)
16 | length = np.max(max_ - min_)
17 | normalized_surface = (surface - center) / length
18 | return normalized_surface
19 |
20 |
21 | class AFDB2000Dataset(data.Dataset):
22 | def __init__(self, path = './', split='test'):
23 | self.path = path
24 | if not os.path.exists(path):
25 |             raise FileNotFoundError("no such file: {}".format(path))
26 | else:
27 | afdb2000_data = json.load(open(path+'/afdb2000.json'))
28 |
29 | self.data_dict = self._load_data_dict()
30 |
31 | self.data = []
32 | for temp in afdb2000_data:
33 | title = temp['name']
34 | data = self.data_dict[title]
35 | seq_length = len(temp['seq'])
36 | coords = np.array(temp['coords'])
37 | self.data.append({'title':title,
38 | 'seq':temp['seq'],
39 | 'CA':coords[:,1,:],
40 | 'C':coords[:,2,:],
41 | 'O':coords[:,3,:],
42 | 'N':coords[:,0,:],
43 | 'category': 'afdb2000',
44 | 'chain_mask': np.ones(seq_length),
45 | 'chain_encoding': np.ones(seq_length),
46 | 'orig_surface': data['surface'],
47 | 'surface': normalize_coordinates(data['surface']),
48 | 'features': data['features'][:, :2],
49 | })
50 |
51 | def _load_data_dict(self):
52 | with open(self.path + f'/afdb2000.pkl', 'rb') as f:
53 | return pickle.load(f)
54 |
55 | def __len__(self):
56 | return len(self.data)
57 |
58 | def __getitem__(self, index):
59 | return self.data[index]
--------------------------------------------------------------------------------
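`normalize_coordinates` (defined identically in each dataset module) centers the surface point cloud at its centroid and divides by the largest bounding-box edge, giving unit extent along the longest axis. A quick numeric check of that behavior, with the function copied inline so the sketch is self-contained:

```
import numpy as np

def normalize_coordinates(surface):
    surface = np.array(surface)
    center = np.mean(surface, axis=0)
    length = np.max(np.max(surface, axis=0) - np.min(surface, axis=0))
    return (surface - center) / length

pts = np.array([[0., 0., 0.], [10., 2., 4.]])
print(normalize_coordinates(pts))  # [[-0.5 -0.1 -0.2]  [ 0.5  0.1  0.2]]
```
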
/src/interface/model_interface.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | import torch.nn as nn
4 | import os
5 | import torch.optim.lr_scheduler as lrs
6 | import inspect
7 |
8 | class MInterface_base(pl.LightningModule):
9 |     def __init__(self, model_name=None, loss=None, lr=None, **kwargs):
10 | super().__init__()
11 | self.save_hyperparameters()
12 | self.load_model()
13 | self.configure_loss()
14 | os.makedirs(os.path.join(self.hparams.res_dir, self.hparams.ex_name), exist_ok=True)
15 |
16 | def on_validation_epoch_end(self):
17 |         # Print a newline so the progress bar stays on its own line
18 | self.print('')
19 |
20 | def get_schedular(self, optimizer, lr_scheduler='onecycle'):
21 | if lr_scheduler == 'step':
22 | scheduler = lrs.StepLR(optimizer,
23 | step_size=self.hparams.lr_decay_steps,
24 | gamma=self.hparams.lr_decay_rate)
25 | elif lr_scheduler == 'cosine':
26 | scheduler = lrs.CosineAnnealingLR(optimizer,
27 | T_max=self.hparams.steps_per_epoch*self.hparams.epoch,
28 | eta_min=self.hparams.lr / 100)
29 | elif lr_scheduler == 'onecycle':
30 | scheduler = lrs.OneCycleLR(optimizer, max_lr=self.hparams.lr, steps_per_epoch=self.hparams.steps_per_epoch, epochs=self.hparams.epoch, three_phase=False, final_div_factor=1.,
31 | )
32 | else:
33 | raise ValueError('Invalid lr_scheduler type!')
34 |
35 | return scheduler
36 |
37 | def configure_optimizers(self):
38 | if hasattr(self.hparams, 'weight_decay'):
39 | weight_decay = self.hparams.weight_decay
40 | else:
41 | weight_decay = 0
42 |
43 | optimizer_g = torch.optim.AdamW(self.model.parameters(), lr=self.hparams.lr, weight_decay=weight_decay, betas=(0.9, 0.98), eps=1e-8)
44 |
45 |         scheduler_g = self.get_schedular(optimizer_g, self.hparams.lr_scheduler)
46 | 
47 |         return [optimizer_g], [{"scheduler": scheduler_g, "interval": "step"}]
48 |
49 | def lr_scheduler_step(self, *args, **kwargs):
50 | scheduler = self.lr_schedulers()
51 | scheduler.step()
52 |
53 |
--------------------------------------------------------------------------------
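`configure_optimizers` pairs AdamW with a scheduler stepped once per batch (`"interval": "step"`); for the default `onecycle` choice, `steps_per_epoch` and `epoch` come from the hparams set in `train/main.py`. A standalone sketch of the same wiring outside Lightning (the linear layer and step counts are illustrative):

```
import torch
import torch.optim.lr_scheduler as lrs

model = torch.nn.Linear(8, 8)  # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4,
                              weight_decay=0, betas=(0.9, 0.98), eps=1e-8)
scheduler = lrs.OneCycleLR(optimizer, max_lr=2e-4, steps_per_epoch=100,
                           epochs=50, three_phase=False, final_div_factor=1.)
for _ in range(5):   # one optimizer step + one scheduler step per training batch
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())
```
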
/src/datasets/inference_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import torch.utils.data as data
5 | import pickle
6 |
7 |
8 | def normalize_coordinates(surface):
9 | """
10 | Normalize the coordinates of the surface.
11 | """
12 | surface = np.array(surface)
13 | center = np.mean(surface, axis=0)
14 | max_ = np.max(surface, axis=0)
15 | min_ = np.min(surface, axis=0)
16 | length = np.max(max_ - min_)
17 | normalized_surface = (surface - center) / length
18 | return normalized_surface
19 |
20 |
21 | class InferenceDataset(data.Dataset):
22 | def __init__(self, path = './', split='test'):
23 | self.path = path
24 | dataset_name = os.path.basename(path)
25 | self.json_path = os.path.join(path, dataset_name + '.json')
26 | self.pkl_path = os.path.join(path, dataset_name + '.pkl')
27 | if not os.path.exists(path):
28 |             raise FileNotFoundError("no such file: {}".format(path))
29 | else:
30 |             inference_data = json.load(open(self.json_path))
31 | 
32 |         self.data_dict = self._load_data_dict()
33 | 
34 |         self.data = []
35 |         for temp in inference_data:
36 | title = temp['name']
37 | data = self.data_dict[title]
38 | seq_length = len(temp['seq'])
39 | coords = np.array(temp['coords'])
40 | self.data.append({'title':title,
41 | 'seq':temp['seq'],
42 | 'CA':coords[:,1,:],
43 | 'C':coords[:,2,:],
44 | 'O':coords[:,3,:],
45 | 'N':coords[:,0,:],
46 | 'category': 'inference',
47 | 'chain_mask': np.ones(seq_length),
48 | 'chain_encoding': np.ones(seq_length),
49 | 'orig_surface': data['surface'],
50 | 'surface': normalize_coordinates(data['surface']),
51 | 'features': data['features'][:, :2],
52 | })
53 |
54 | def _load_data_dict(self):
55 | with open(self.pkl_path, 'rb') as f:
56 | return pickle.load(f)
57 |
58 | def __len__(self):
59 | return len(self.data)
60 |
61 | def __getitem__(self, index):
62 | return self.data[index]
--------------------------------------------------------------------------------
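`InferenceDataset` derives both file names from the directory's basename, so a dataset directory must contain a matching `<name>.json` (sequences and backbone coordinates) and `<name>.pkl` (surface points and features). Hypothetical usage, with `./data/mydataset` standing in for a real directory:

```
from src.datasets.inference_dataset import InferenceDataset

# expects ./data/mydataset/mydataset.json and ./data/mydataset/mydataset.pkl
ds = InferenceDataset(path='./data/mydataset')
sample = ds[0]
print(sample['title'], len(sample['seq']), sample['surface'].shape)
```
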
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | apex/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 |
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 |
64 | # Scrapy stuff:
65 | .scrapy
66 |
67 | # Sphinx documentation
68 | docs/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
107 | .vscode
108 | .idea
109 |
110 | # custom
111 | /data.tar.gz
112 | *.pkl
113 | *.pkl.json
114 | *.log.json
115 | *.ckpt
116 | *.zip
117 | *.ipynb
118 | *.jpg
119 | bash
120 | data
121 | data-before0601
122 | /configs
123 | data_set
124 | results/
125 | gaozhangyang/
126 | src/models/SurfPro/
127 | src/modules/
128 | src/models/__pycache__/
129 | src/models/*.py
130 | !src/models/__init__.py
131 | !src/models/UBC2_model.py
132 |
133 | src/datasets/dataloader.py
134 | src/datasets/mpnn_dataset.py
135 | src/datasets/casp_dataset.py
136 | src/datasets/alphafold_dataset.py
137 | src/datasets/fast_dataloader.py
138 |
139 | src/interface/pretrain_interface.py
140 |
141 | src/models/configs/*.yaml
142 | !src/models/configs/UBC2Model.yaml
143 |
144 | src/tools/main_utils.py
145 | src/tools/config_utils.py
146 | src/tools/metrics.py
147 | src/tools/parser.py
148 | src/tools/utils.py
149 |
150 | train/results
151 | train/ig*
152 | train/train.sh
153 | train/try_multi_train.py
154 | train/main_train.py
155 |
156 | run.sh
157 | output/
158 | work_dirs/
159 | workspace/
160 | tools/exp_bash/
161 | pretrains
162 | cache/
163 | cath_classes/
164 | gt_pdb/
165 | ig_results/
166 | ig_biochem_results_steps50/
167 | ig_results_steps50/
168 | requirements/
169 | tools/prepare_data/
170 | lightning_logs/
171 | predicted_pdb/
172 | /logits
173 | raw_test_data/
174 | test_results/
175 | alphafolddb/
176 | cath_test_82/
177 | figures*
178 |
179 | TMscore.cpp
180 | TMscore.f
181 |
182 | antonia*
183 | calc_diversity.py
184 | create_cath42test_json.py
185 | json2pkl.py
186 |
187 | environment-deprecated.yml
188 | environment0601.yml
189 | environment-full-deprecated.yml
190 | test.py
191 | visualization*
192 |
193 | pred_pdb2jsonpkl.py
194 | pred_pdb2jsonpkl-thr.py
195 |
196 |
197 | # Pytorch
198 | *.pth
199 |
200 | *.swp
201 | .DS_Store
202 | *.json
203 | run/wandb/
204 | wandb/
205 | esm/
206 | figs/
207 | esm/*
208 | sampling/*
209 | model_zoom/*
210 | run/
211 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to *PLACEHOLDER*
2 |
3 | We welcome contributions from everyone to help improve and expand *PLACEHOLDER*. This document outlines the process for contributing to the project.
4 |
5 | ## Table of Contents
6 | 1. [Environment Setup](#environment-setup)
7 | 2. [Coding Standards](#coding-standards)
8 | 3. [Pull Request Process](#pull-request-process)
9 | 4. [Pull Request Template](#pull-request-template)
10 |
11 | ## Environment Setup
12 |
13 | To contribute to *PLACEHOLDER*, follow these steps to set up your development environment:
14 |
15 | 1. Clone the repository:
16 | ```
17 | git clone https://github.com/gersteinlab/placeholder.git
18 | cd placeholder
19 | ```
20 | 2. Create a Conda environment:
21 | ```
22 | conda create -n placeholder python=3.10
23 | conda activate placeholder
24 | ```
25 | 3. Install the project in editable mode with development dependencies:
26 | ```
27 | python3 -m pip install --upgrade pip
28 | pip install -e .
29 | ```
30 |
31 | ## Coding Standards
32 |
33 | We strive to maintain clean and consistent code throughout the project. Please adhere to the following guidelines:
34 |
35 | 1. Follow PEP 8 guidelines for Python code.
36 | 2. Use meaningful variable and function names.
37 | 3. Write docstrings for functions and classes.
38 | 4. Keep functions small and focused on a single task.
39 | 5. Use type hints where appropriate.
40 |
41 | ### Code Formatting
42 |
43 | We use `black` for code formatting. To ensure your code is properly formatted:
44 |
45 | 1. Install black:
46 | ```
47 | pip install black
48 | ```
49 | 2. Run black on the codebase:
50 | ```
51 | black .
52 | ```
53 |
54 | ## Pull Request Process
55 |
56 | 1. Create a new branch for your work (use a `feature/` prefix for new functionality and a `bugfix/` prefix for bug fixes):
57 | ```
58 | git checkout -b feature/your-feature-name
59 | ```
60 | 2. Make your changes and commit them with clear, concise commit messages.
61 |    1. Check the working tree status to see which files are modified or untracked
62 | ```
63 | git status
64 | ```
65 |    2. Stage the files you changed
66 | ```
67 | git add schema.py
68 | ```
69 |    3. Commit the staged changes with a clear message
70 | ```
71 | git commit -m "message"
72 | ```
73 | 4. Push your branch to the repository:
74 | ```
75 | git push origin feature/your-feature-name
76 | ```
77 | 5. Open a pull request against the `main` branch on GitHub.
78 | 6. Fill out the pull request template (see below).
79 | 7. Address any feedback or comments from reviewers.
80 |
81 | ## Pull Request Template
82 |
83 | When you open a new pull request, please use the following template:
84 |
85 | ```markdown
86 | ## Description
87 |
88 | ### Changes
89 | [Provide a detailed list of the changes made in this PR]
90 |
91 | ### Design
92 | [Explain the design decisions and architectural changes, if any]
93 |
94 | ### Example Code
95 | [If applicable, provide example code demonstrating the usage of new features or fixes]
96 |
97 | ## Related Issue
98 | [Link to the issue this PR addresses, if applicable]
99 |
100 | ## Type of Change
101 | - [ ] Bug fix (non-breaking change which fixes an issue)
102 | - [ ] New feature (non-breaking change which adds functionality)
103 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
104 | - [ ] This change requires a documentation update
105 |
106 | ## How Has This Been Tested?
107 | [Describe the tests you ran to verify your changes]
108 |
109 | ## Additional Notes
110 | [Add any additional information or context about the PR here]
111 | ```
112 |
113 | Thank you for contributing to *PLACEHOLDER*!
114 |
--------------------------------------------------------------------------------
/src/datasets/ts_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import torch.utils.data as data
5 | import pickle
6 |
7 |
8 | def normalize_coordinates(surface):
9 | """
10 | Normalize the coordinates of the surface.
11 | """
12 | surface = np.array(surface)
13 | center = np.mean(surface, axis=0)
14 | max_ = np.max(surface, axis=0)
15 | min_ = np.min(surface, axis=0)
16 | length = np.max(max_ - min_)
17 | normalized_surface = (surface - center) / length
18 | return normalized_surface
19 |
20 |
21 | class TS50Dataset(data.Dataset):
22 | def __init__(self, path = './', split='test'):
23 | self.path = path
24 | if not os.path.exists(path):
25 |             raise FileNotFoundError("no such file: {}".format(path))
26 | else:
27 | ts50_data = json.load(open(path+'/ts50.json'))
28 |
29 | self.data_dict = self._load_data_dict()
30 |
31 | # TS500 has proteins with lengths of 500+
32 | # TS50 only contains proteins with lengths less than 500
33 | self.data = []
34 | for temp in ts50_data:
35 | title = temp['name']
36 | data = self.data_dict[title]
37 | seq_length = len(temp['seq'])
38 | coords = np.array(temp['coords'])
39 | self.data.append({'title':title,
40 | 'seq':temp['seq'],
41 | 'CA':coords[:,1,:],
42 | 'C':coords[:,2,:],
43 | 'O':coords[:,3,:],
44 | 'N':coords[:,0,:],
45 | 'category': 'ts50',
46 | 'chain_mask': np.ones(seq_length),
47 | 'chain_encoding': np.ones(seq_length),
48 | 'orig_surface': data['surface'],
49 | 'surface': normalize_coordinates(data['surface']),
50 | 'features': data['features'][:, :2],
51 | })
52 |
53 | def _load_data_dict(self):
54 | with open(self.path + f'/ts50.pkl', 'rb') as f:
55 | return pickle.load(f)
56 |
57 | def __len__(self):
58 | return len(self.data)
59 |
60 | def __getitem__(self, index):
61 | return self.data[index]
62 |
63 |
64 | class TS500Dataset(data.Dataset):
65 | def __init__(self, path = './', split='test'):
66 | self.path = path
67 | if not os.path.exists(path):
68 |             raise FileNotFoundError("no such file: {}".format(path))
69 | else:
70 | ts500_data = json.load(open(path+'/ts500.json'))
71 |
72 | self.data_dict = self._load_data_dict()
73 |
74 | # TS500 has proteins with lengths of 500+
75 | # TS50 only contains proteins with lengths less than 500
76 | self.data = []
77 | for temp in ts500_data:
78 | title = temp['name']
79 | data = self.data_dict[title]
80 | seq_length = len(temp['seq'])
81 | coords = np.array(temp['coords'])
82 | self.data.append({'title':title,
83 | 'seq':temp['seq'],
84 | 'CA':coords[:,1,:],
85 | 'C':coords[:,2,:],
86 | 'O':coords[:,3,:],
87 | 'N':coords[:,0,:],
88 | 'category': 'ts500',
89 | 'chain_mask': np.ones(seq_length),
90 | 'chain_encoding': np.ones(seq_length),
91 | 'orig_surface': data['surface'],
92 | 'surface': normalize_coordinates(data['surface']),
93 | 'features': data['features'][:, :2],
94 | })
95 |
96 | def _load_data_dict(self):
97 | with open(self.path + f'/ts500.pkl', 'rb') as f:
98 | return pickle.load(f)
99 |
100 | def __len__(self):
101 | return len(self.data)
102 |
103 | def __getitem__(self, index):
104 | return self.data[index]
--------------------------------------------------------------------------------
/src/tools/config_utils.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import tempfile
3 | import re
4 | import shutil
5 | import sys
6 | import ast
7 | from importlib import import_module
8 |
9 | '''
10 | Adapted from https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py by OpenMMLab.
11 | The `Config` class here reuses parts of that reference implementation.
12 | '''
13 |
14 | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
15 | if not osp.isfile(filename):
16 | raise FileNotFoundError(msg_tmpl.format(filename))
17 |
18 |
19 | class Config:
20 | def __init__(self, cfg_dict=None, filename=None):
21 | if cfg_dict is None:
22 | cfg_dict = dict()
23 | elif not isinstance(cfg_dict, dict):
24 | raise TypeError('cfg_dict must be a dict, but '
25 | f'got {type(cfg_dict)}')
26 |
27 |         if filename is not None:
28 |             cfg_dict = self._file2dict(filename, True)
29 | 
30 |
31 | super(Config, self).__setattr__('_cfg_dict', cfg_dict)
32 | super(Config, self).__setattr__('_filename', filename)
33 |
34 | @staticmethod
35 | def _validate_py_syntax(filename):
36 | with open(filename, 'r') as f:
37 | content = f.read()
38 | try:
39 | ast.parse(content)
40 | except SyntaxError as e:
41 | raise SyntaxError('There are syntax errors in config '
42 | f'file {filename}: {e}')
43 |
44 | @staticmethod
45 | def _substitute_predefined_vars(filename, temp_config_name):
46 | file_dirname = osp.dirname(filename)
47 | file_basename = osp.basename(filename)
48 | file_basename_no_extension = osp.splitext(file_basename)[0]
49 | file_extname = osp.splitext(filename)[1]
50 | support_templates = dict(
51 | fileDirname=file_dirname,
52 | fileBasename=file_basename,
53 | fileBasenameNoExtension=file_basename_no_extension,
54 | fileExtname=file_extname)
55 | with open(filename, 'r') as f:
56 | config_file = f.read()
57 | for key, value in support_templates.items():
58 | regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
59 | value = value.replace('\\', '/')
60 | config_file = re.sub(regexp, value, config_file)
61 | with open(temp_config_name, 'w') as tmp_config_file:
62 | tmp_config_file.write(config_file)
63 |
64 | @staticmethod
65 | def _file2dict(filename, use_predefined_variables=True):
66 | filename = osp.abspath(osp.expanduser(filename))
67 | check_file_exist(filename)
68 | fileExtname = osp.splitext(filename)[1]
69 | if fileExtname not in ['.py']:
70 |             raise IOError('Only .py config files are supported for now!')
71 |
72 | with tempfile.TemporaryDirectory() as temp_config_dir:
73 | temp_config_file = tempfile.NamedTemporaryFile(
74 | dir=temp_config_dir, suffix=fileExtname)
75 | temp_config_name = osp.basename(temp_config_file.name)
76 |
77 | # Substitute predefined variables
78 | if use_predefined_variables:
79 | Config._substitute_predefined_vars(filename,
80 | temp_config_file.name)
81 | else:
82 | shutil.copyfile(filename, temp_config_file.name)
83 |
84 | if filename.endswith('.py'):
85 | temp_module_name = osp.splitext(temp_config_name)[0]
86 | sys.path.insert(0, temp_config_dir)
87 | Config._validate_py_syntax(filename)
88 | mod = import_module(temp_module_name)
89 | sys.path.pop(0)
90 | cfg_dict = {
91 | name: value
92 | for name, value in mod.__dict__.items()
93 | if not name.startswith('__')
94 | }
95 | # delete imported module
96 | del sys.modules[temp_module_name]
97 | # close temp file
98 | temp_config_file.close()
99 | return cfg_dict
100 |
101 | @staticmethod
102 | def fromfile(filename, use_predefined_variables=True):
103 | cfg_dict = Config._file2dict(filename, use_predefined_variables)
104 | return Config(cfg_dict, filename=filename)
105 |
--------------------------------------------------------------------------------
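`Config.fromfile` only accepts Python config files and substitutes `{{ fileDirname }}`-style template variables before importing the module. A hypothetical usage sketch (`toy_config.py` is illustrative, not part of the repo):

```
# toy_config.py might contain:
#     hidden_dim = 128
#     data_root = '{{ fileDirname }}/data'
from src.tools.config_utils import Config

cfg = Config.fromfile('toy_config.py')
print(cfg._cfg_dict['hidden_dim'], cfg._cfg_dict['data_root'])
```
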
/train/main_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.getcwd())
4 |
5 | import warnings
6 | warnings.filterwarnings("ignore")
7 |
8 | import argparse
9 | import torch
10 | from model_interface import MInterface
11 | from data_interface import DInterface
12 |
13 | import pytorch_lightning as pl
14 | from pytorch_lightning.trainer import Trainer
15 | torch.autograd.set_detect_anomaly(True)
16 |
17 | def create_parser():
18 | checkpoint_path = './UBC2Model.ckpt'
19 | ex_name = 'UBC2Model'
20 | batch_size = 2
21 |
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--res_dir', default='./train/results', type=str)
24 | parser.add_argument('--ex_name', default=ex_name, type=str)
25 | parser.add_argument('--check_val_every_n_epoch', default=1, type=int)
26 | parser.add_argument('--dataset', default='CATH4.2')
27 | parser.add_argument('--model_name', default='UBC2Model')
28 | parser.add_argument('--lr', default=0.0002, type=float, help='Learning rate')
29 | parser.add_argument('--lr_scheduler', default='onecycle')
30 | parser.add_argument('--offline', default=0, type=int)
31 | parser.add_argument('--seed', default=111, type=int)
32 |
33 | # dataset parameters
34 | parser.add_argument('--batch_size', default=batch_size, type=int)
35 | parser.add_argument('--num_workers', default=0, type=int)
36 | parser.add_argument('--pad', default=1024, type=int)
37 | parser.add_argument('--min_length', default=40, type=int)
38 | parser.add_argument('--data_root', default='./data/')
39 |
40 | # Testing specific parameters
41 | parser.add_argument('--epoch', default=50, type=int, help='end epoch')
42 | parser.add_argument('--augment_eps', default=0.0, type=float, help='noise level')
43 |
44 | # Model parameters
45 | parser.add_argument('--use_dist', default=1, type=int)
46 | parser.add_argument('--use_product', default=0, type=int)
47 |
48 | # Checkpoint parameter
49 | parser.add_argument('--checkpoint_path', default=checkpoint_path, type=str, help='Path to a checkpoint to resume testing')
50 |
51 | parser.add_argument('--contrastive_pretrain', default=False, type=bool)
52 | parser.add_argument('--contrastive_learning', default=False, type=bool)
53 | parser.add_argument('--if_strucenc_only', default=False, type=bool)
54 | parser.add_argument('--if_warmup_train', default=False, type=bool)
55 |
56 | parser.add_argument('--if_struc_only', default=False, type=bool)
57 | parser.add_argument('--exp_bc_mask_rate', default=0., type=float)
58 | parser.add_argument('--bc_mask_max_rate', default=0., type=float)
59 | parser.add_argument('--exp_hydro_mask_rate', default=0., type=float)
60 | parser.add_argument('--exp_charge_mask_rate', default=0., type=float)
61 | parser.add_argument('--exp_v_mask_rate', default=0., type=float)
62 | parser.add_argument('--exp_e_mask_rate', default=0., type=float)
63 | parser.add_argument('--exp_backbone_noise_sd', default=0., type=float)
64 | parser.add_argument('--exp_wo_bcgraph', default=False, type=bool)
65 |
66 | parser.add_argument('--partial_design', default=False, type=bool)
67 | parser.add_argument('--design_region_path', default='')
68 |
69 | parser.add_argument('--bc_indices', nargs='+', type=int, default=[0, 1])
70 |
71 | args = parser.parse_args()
72 | return args
73 |
74 | def load_callbacks(args):
75 | callbacks = []
76 | return callbacks
77 |
78 | if __name__ == "__main__":
79 | args = create_parser()
80 | pl.seed_everything(args.seed)
81 |
82 | # Initialize data module and setup test data
83 | data_module = DInterface(**vars(args))
84 | data_module.setup() # Ensure the test dataset is loaded
85 |
86 | gpu_count = 1
87 | print(f"Using {gpu_count} GPUs for testing")
88 |
89 | # Initialize the model
90 | model = MInterface(**vars(args))
91 |
92 | # Trainer configuration
93 | trainer_config = {
94 | 'devices': gpu_count,
95 | 'num_nodes': 1, # Number of nodes to use for distributed training
96 | 'precision': 32,
97 | 'accelerator': 'gpu',
98 | 'callbacks': load_callbacks(args),
99 | }
100 |
101 | trainer_opt = argparse.Namespace(**trainer_config)
102 | trainer_dict = vars(trainer_opt)
103 | trainer = Trainer(**trainer_dict)
104 | # Perform testing
105 | if args.checkpoint_path:
106 | print(f"Resuming from checkpoint: {args.checkpoint_path}")
107 | trainer.test(model, datamodule=data_module, ckpt_path=args.checkpoint_path)
108 |     else:
109 |         print("No checkpoint provided, testing with current model state")
110 |         trainer.test(model, datamodule=data_module)
111 | print(trainer_config)
112 |
--------------------------------------------------------------------------------
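One caveat with the boolean flags above: argparse's `type=bool` converts via Python truthiness, so `--partial_design False` still yields `True` (any non-empty string is truthy). A common workaround, sketched here as a suggestion rather than a change to the script, is an explicit converter:

```
import argparse

def str2bool(v):
    # Accept common spellings of true/false; reject anything else.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError(f'boolean value expected, got {v!r}')

parser = argparse.ArgumentParser()
parser.add_argument('--partial_design', default=False, type=str2bool)
print(parser.parse_args(['--partial_design', 'False']).partial_design)  # False
```
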
/train/data_interface.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from torch.utils.data import DataLoader
3 | from src.interface.data_interface import DInterface_base
4 | import torch
5 | import os.path as osp
6 |
7 | class MyDataLoader(DataLoader):
8 | def __init__(self, dataset, model_name, batch_size=64, num_workers=8, *args, **kwargs):
9 | super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, *args, **kwargs)
10 | self.pretrain_device = 'cuda:0'
11 | self.model_name = model_name
12 |
13 | def __iter__(self):
14 | for batch in super().__iter__():
15 |             # Per-batch processing can be added here
16 |             # ...
17 | try:
18 | self.pretrain_device = f'cuda:{torch.distributed.get_rank()}'
19 |             except Exception:
20 | self.pretrain_device = 'cuda:0'
21 |
22 | stream = torch.cuda.Stream(
23 | self.pretrain_device
24 | )
25 | with torch.cuda.stream(stream):
26 | if self.model_name=='GVP':
27 | batch = batch.cuda(non_blocking=True, device=self.pretrain_device)
28 | yield batch
29 | else:
30 | for key, val in batch.items():
31 |                         if isinstance(val, torch.Tensor):
32 | batch[key] = batch[key].cuda(non_blocking=True, device=self.pretrain_device)
33 |
34 | yield batch
35 |
36 |
37 | class DInterface(DInterface_base):
38 | def __init__(self,**kwargs):
39 | super().__init__(**kwargs)
40 | self.save_hyperparameters()
41 | self.load_data_module()
42 | self.exp_backbone_noise_sd = kwargs.get('exp_backbone_noise_sd', 0.0)
43 | self.partial_design = kwargs.get('partial_design', False)
44 | self.design_region_path = kwargs.get('design_region_path', '')
45 | self.ig_baseline_data = kwargs.get('ig_baseline_data', False)
46 |
47 | def setup(self, stage=None):
48 | from src.datasets.featurizer import (featurize_UBC2Model)
49 | if self.hparams.model_name == 'UBC2Model' or self.hparams.model_name == 'UBC2Large' or self.hparams.model_name == 'UBC01234':
50 | self.collate_fn = featurize_UBC2Model(
51 | exp_backbone_noise_sd=self.exp_backbone_noise_sd,
52 | partial_design=self.partial_design,
53 | design_region_path=self.design_region_path,
54 | ig_baseline_data=self.ig_baseline_data
55 | ).featurize
56 |
57 | # Assign train/val datasets for use in dataloaders
58 | if stage == 'fit' or stage is None:
59 | self.trainset = self.instancialize(split = 'train')
60 | self.valset = self.instancialize(split='valid')
61 |
62 | # Assign test dataset for use in dataloader(s)
63 | if stage == 'test' or stage is None:
64 | self.testset = self.instancialize(split='test')
65 |
66 | def train_dataloader(self):
67 | return MyDataLoader(self.trainset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=True, prefetch_factor=None, pin_memory=True, collate_fn=self.collate_fn)
68 | # return MyDataLoader(self.trainset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=True, prefetch_factor=8, pin_memory=True, collate_fn=self.collate_fn)
69 |
70 | def val_dataloader(self):
71 | return MyDataLoader(self.valset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, prefetch_factor=None, pin_memory=True, collate_fn=self.collate_fn)
72 | # return MyDataLoader(self.valset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, pin_memory=True, collate_fn=self.collate_fn)
73 |
74 | def test_dataloader(self):
75 | return MyDataLoader(self.testset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, prefetch_factor=None, pin_memory=True, collate_fn=self.collate_fn)
76 | # return MyDataLoader(self.testset, model_name=self.hparams.model_name, batch_size=self.batch_size, num_workers=self.hparams.num_workers, shuffle=False, pin_memory=True, collate_fn=self.collate_fn)
77 |
78 | def load_data_module(self):
79 | name = self.hparams.dataset
80 | if name == 'CATH4.2':
81 | from src.datasets.cath_dataset import CATHDatasetSurfProPiFoldDenseLarge
82 | self.data_module = CATHDatasetSurfProPiFoldDenseLarge
83 | self.hparams['version'] = 4.2
84 | self.hparams['path'] = osp.join(self.hparams.data_root, 'cath4.2')
85 |
86 | elif name == 'CATHAFDB':
87 | from src.datasets.cathafdb_dataset import CATHAFDBDataset
88 | self.data_module = CATHAFDBDataset
89 | self.hparams['version'] = 4.2
90 | self.hparams['path_cath'] = osp.join(self.hparams.data_root, 'cath4.2')
91 | self.hparams['path_afdb'] = osp.join(self.hparams.data_root, 'afdb-large4000')
92 |
93 | elif name == 'TS50':
94 | from src.datasets.ts_dataset import TS50Dataset
95 | self.data_module = TS50Dataset
96 | self.hparams['path'] = osp.join(self.hparams.data_root, 'ts50')
97 |
98 | elif name == 'TS500':
99 | from src.datasets.ts_dataset import TS500Dataset
100 | self.data_module = TS500Dataset
101 | self.hparams['path'] = osp.join(self.hparams.data_root, 'ts500')
102 |
103 | elif name == 'AFDB2000':
104 | from src.datasets.afdb_dataset import AFDB2000Dataset
105 | self.data_module = AFDB2000Dataset
106 | self.hparams['path'] = osp.join(self.hparams.data_root, 'afdb2000')
107 |
108 | else:
109 | from src.datasets.inference_dataset import InferenceDataset
110 | self.data_module = InferenceDataset
111 | self.hparams['path'] = osp.join(self.hparams.data_root, name)
112 |
113 | def instancialize(self, **other_args):
114 | """ Instancialize a model using the corresponding parameters
115 | from self.hparams dictionary. You can also input any args
116 | to overwrite the corresponding value in self.kwargs.
117 | """
118 |
119 | class_args = list(inspect.signature(self.data_module.__init__).parameters)[1:]
120 | inkeys = self.hparams.keys()
121 | args1 = {}
122 | for arg in class_args:
123 | if arg in inkeys:
124 | args1[arg] = self.hparams[arg]
125 | args1.update(other_args)
126 | # print('finish instancialize')
127 | return self.data_module(**args1)
--------------------------------------------------------------------------------
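`instancialize` forwards only the hyperparameters whose names appear in the dataset class's `__init__` signature, so unrelated hparams (learning rate, batch size, ...) are silently dropped. A minimal standalone illustration of the pattern (`ToyDataset` is illustrative):

```
import inspect

class ToyDataset:
    def __init__(self, path='./', split='train'):
        self.path, self.split = path, split

hparams = {'path': './data/cath4.2', 'batch_size': 2, 'lr': 2e-4}
class_args = list(inspect.signature(ToyDataset.__init__).parameters)[1:]
kwargs = {k: hparams[k] for k in class_args if k in hparams}
kwargs.update(split='valid')      # explicit overrides win, as in instancialize
ds = ToyDataset(**kwargs)
print(ds.path, ds.split)          # ./data/cath4.2 valid
```
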
/train/main.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import sys
4 | sys.path.append(os.getcwd())
5 |
6 | import warnings
7 | warnings.filterwarnings("ignore")
8 |
9 | import argparse
10 | import yaml
11 | import torch
12 | from model_interface import MInterface
13 | from data_interface import DInterface
14 | from src.tools.logger import SetupCallback,BackupCodeCallback
15 | import math
16 | from shutil import ignore_patterns
17 |
18 | import pytorch_lightning as pl
19 | from pytorch_lightning.trainer import Trainer
20 | import pytorch_lightning.callbacks as plc
21 | import pytorch_lightning.loggers as plog
22 | torch.autograd.set_detect_anomaly(True)
23 |
24 | def create_parser():
25 | parser = argparse.ArgumentParser()
26 | # Set-up parameters
27 | parser.add_argument('--res_dir', default='./train/results', type=str)
28 | parser.add_argument('--ex_name', default='BC-Design-reproduce', type=str)
29 | parser.add_argument('--check_val_every_n_epoch', default=1, type=int)
30 |
31 | parser.add_argument('--dataset', default='CATH4.2')
32 | parser.add_argument('--model_name', default='UBC2Model')
33 | parser.add_argument('--lr', default=0.0002, type=float, help='Learning rate')
34 | parser.add_argument('--lr_scheduler', default='onecycle')
35 | parser.add_argument('--offline', default=1, type=int)
36 | parser.add_argument('--seed', default=111, type=int)
37 |
38 | # dataset parameters
39 | parser.add_argument('--batch_size', default=2, type=int)
40 | parser.add_argument('--num_workers', default=0, type=int)
41 | parser.add_argument('--pad', default=1024, type=int)
42 | parser.add_argument('--min_length', default=40, type=int)
43 | parser.add_argument('--data_root', default='./data/')
44 |
45 | # Training parameters
46 | parser.add_argument('--epoch', default=50, type=int, help='end epoch')
47 | parser.add_argument('--augment_eps', default=0.0, type=float, help='noise level')
48 |
49 | # Model parameters
50 | parser.add_argument('--use_dist', default=1, type=int)
51 | parser.add_argument('--use_product', default=0, type=int)
52 |
53 | # Checkpoint parameter
54 | parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to a checkpoint to resume training')
55 |
56 | parser.add_argument('--contrastive_pretrain', default=False, type=bool)
57 | parser.add_argument('--contrastive_learning', default=False, type=bool)
58 | parser.add_argument('--if_strucenc_only', default=False, type=bool)
59 | parser.add_argument('--if_warmup_train', default=False, type=bool)
60 |
61 | parser.add_argument('--if_struc_only', default=False, type=bool)
62 | parser.add_argument('--exp_bc_mask_rate', default=0., type=float)
63 | parser.add_argument('--bc_mask_max_rate', default=0., type=float)
64 | parser.add_argument('--exp_backbone_noise_sd', default=0., type=float)
65 |
66 | parser.add_argument('--partial_design', default=False, type=bool)
67 | parser.add_argument('--design_region_path', default='')
68 |
69 | args = parser.parse_args()
70 | return args
71 |
72 |
73 | def load_yaml_config_simple(args):
74 | """Load YAML config and update args for existing parameters only."""
75 | yaml_path = f"./src/models/configs/{args.model_name}.yaml"
76 |
77 | if os.path.exists(yaml_path):
78 | print(f"Loading config from {yaml_path}")
79 | with open(yaml_path, 'r') as f:
80 | yaml_config = yaml.safe_load(f)
81 |
82 | # Only update args that already exist
83 | updated_params = []
84 | for key, value in yaml_config.items():
85 | if hasattr(args, key):
86 | setattr(args, key, value)
87 | updated_params.append(key)
88 |
89 | if updated_params:
90 | print(f"Updated parameters from YAML: {updated_params}")
91 | else:
92 | print(f"Config file {yaml_path} not found, using defaults")
93 |
94 | return args
95 |
96 |
97 | def load_callbacks(args):
98 | callbacks = []
99 |
100 | logdir = str(os.path.join(args.res_dir, args.ex_name))
101 |
102 | ckptdir = os.path.join(logdir, "checkpoints")
103 |
104 | callbacks.append(BackupCodeCallback(os.path.dirname(args.res_dir),logdir, ignore_patterns=ignore_patterns('results*', 'pdb*', 'metadata*', 'vq_dataset*')))
105 |
106 | metric = "recovery"
107 | sv_filename = 'best-{epoch:02d}-{recovery:.3f}'
108 | callbacks.append(plc.ModelCheckpoint(
109 | monitor=metric,
110 | filename=sv_filename,
111 | save_top_k=15,
112 | mode='max',
113 | save_last=True,
114 | dirpath = ckptdir,
115 | verbose = True,
116 | every_n_epochs = args.check_val_every_n_epoch,
117 | ))
118 |
119 | now = datetime.datetime.now().strftime("%m-%dT%H-%M-%S")
120 | cfgdir = os.path.join(logdir, "configs")
121 | callbacks.append(
122 | SetupCallback(
123 | now = now,
124 | logdir = logdir,
125 | ckptdir = ckptdir,
126 | cfgdir = cfgdir,
127 | config = args.__dict__,
128 | argv_content = sys.argv + ["gpus: {}".format(torch.cuda.device_count())],)
129 | )
130 |
131 | if args.lr_scheduler:
132 | callbacks.append(plc.LearningRateMonitor(
133 | logging_interval=None))
134 | return callbacks
135 |
136 |
137 | if __name__ == "__main__":
138 | args = create_parser()
139 |
140 | # Load YAML config and update existing parameters
141 | args = load_yaml_config_simple(args)
142 |
143 | pl.seed_everything(args.seed)
144 |
145 | data_module = DInterface(**vars(args))
146 | data_module.setup()
147 |
148 | gpu_count = torch.cuda.device_count()
149 | args.steps_per_epoch = math.ceil(len(data_module.trainset)/args.batch_size/gpu_count)
150 |     print(f"steps_per_epoch {args.steps_per_epoch}, gpu_count {gpu_count}, batch_size {args.batch_size}")
151 |
152 | model = MInterface(**vars(args))
153 |
154 | trainer_config = {
155 | 'devices': gpu_count,
156 | 'max_epochs': args.epoch, # Maximum number of epochs to train for
157 | 'num_nodes': 1, # Number of nodes to use for distributed training
158 | "strategy": 'ddp_find_unused_parameters_true',
159 | 'precision': 32,
160 |         'accelerator': 'gpu',  # Run on GPUs (the DDP strategy is set above)
161 | 'callbacks': load_callbacks(args),
162 | 'logger': plog.WandbLogger(
163 | project = 'BC-Design',
164 | name=args.ex_name,
165 | save_dir=str(os.path.join(args.res_dir, args.ex_name)),
166 | offline = args.offline,
167 | id = "_".join(args.ex_name.split("/")),),
168 | 'gradient_clip_val':1.0
169 | }
170 |
171 | trainer_opt = argparse.Namespace(**trainer_config)
172 | trainer_dict = vars(trainer_opt)
173 | trainer = Trainer(**trainer_dict)
174 |
175 | trainer.fit(model, data_module)
176 |
177 | print(trainer_config)
178 |
--------------------------------------------------------------------------------
/src/datasets/cath_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import numpy as np
5 | from tqdm import tqdm
6 | import random
7 | import torch.utils.data as data
8 | from transformers import AutoTokenizer
9 |
10 | def normalize_coordinates(surface):
11 | """
12 | Normalize the coordinates of the surface.
13 | """
14 | surface = np.array(surface)
15 | center = np.mean(surface, axis=0)
16 | max_ = np.max(surface, axis=0)
17 | min_ = np.min(surface, axis=0)
18 | length = np.max(max_ - min_)
19 | normalized_surface = (surface - center) / length
20 | return normalized_surface
21 |
22 |
23 | class CATHDatasetSurfProPiFoldDenseLarge(data.Dataset):
24 | def __init__(self, path='./', split='train', max_length=500, test_name='All', data=None, removeTS=0, version=4.2, bc_indices=None):
25 | self.version = version
26 | self.path = path
27 | self.mode = split
28 | self.max_length = max_length
29 | self.test_name = test_name
30 | self.removeTS = removeTS
31 | if bc_indices is None:
32 | self.bc_indices = [0, 1]
33 | else:
34 | self.bc_indices = bc_indices
35 |
36 | if self.removeTS:
37 | self.remove = json.load(open(self.path + '/remove.json', 'r'))['remove']
38 |
39 | if data is None:
40 | self.metadata = self._load_metadata()
41 | else:
42 | self.metadata = data
43 |
44 | self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="gaozhangyang/model_zoom/transformers")
45 |
46 | # Load the entire dictionary corresponding to the current mode
47 | self.data_dict = self._load_data_dict()
48 | def _load_metadata(self):
49 | alphabet = 'ACDEFGHIKLMNPQRSTVWY'
50 | alphabet_set = set([a for a in alphabet])
51 | metadata = []
52 |
53 | # Load the split JSON files
54 | with open(self.path + '/chain_set_splits.json') as f:
55 | dataset_splits = json.load(f)
56 |
57 | # Handle specific test splits if needed
58 | if self.test_name == 'L100':
59 | with open(self.path + '/test_split_L100.json') as f:
60 | test_splits = json.load(f)
61 | dataset_splits['test'] = test_splits['test']
62 |
63 | if self.test_name == 'sc':
64 | with open(self.path + '/test_split_sc.json') as f:
65 | test_splits = json.load(f)
66 | dataset_splits['test'] = test_splits['test']
67 |
68 | # Select the appropriate split
69 | if self.mode == 'valid':
70 | valid_titles = set(dataset_splits['validation'])
71 | else:
72 | valid_titles = set(dataset_splits[self.mode])
73 |
74 | if not os.path.exists(self.path):
75 | raise FileNotFoundError("No such file: {} !!!".format(self.path))
76 | else:
77 | with open(self.path + '/chain_set.jsonl') as f:
78 | lines = f.readlines()
79 | for line in tqdm(lines):
80 | entry = json.loads(line)
81 | if self.removeTS and entry['name'] in self.remove:
82 | continue
83 |
84 | bad_chars = set([s for s in entry['seq']]).difference(alphabet_set)
85 | if len(bad_chars) == 0 and len(entry['seq']) <= self.max_length and entry['name'] in valid_titles:
86 | entry['coords']['CA'] = np.array(entry['coords']['CA'])
87 | entry['coords']['C'] = np.array(entry['coords']['C'])
88 | entry['coords']['O'] = np.array(entry['coords']['O'])
89 | entry['coords']['N'] = np.array(entry['coords']['N'])
90 |                     # Build a per-residue mask that is True wherever any backbone
91 |                     # atom coordinate (CA, C, O, N) is NaN or infinite
92 | coords = np.stack([
93 | entry['coords']['CA'],
94 | entry['coords']['C'],
95 | entry['coords']['O'],
96 | entry['coords']['N']
97 | ], axis=1) # shape: (L, 4, 3)
98 | mask = np.isnan(coords).sum(axis=(1,2)) > 0
99 | mask = mask | (np.isinf(coords).sum(axis=(1,2)) > 0)
100 | # remove the positions where the mask is True
101 | entry['coords']['CA'] = entry['coords']['CA'][~mask]
102 | entry['coords']['C'] = entry['coords']['C'][~mask]
103 | entry['coords']['O'] = entry['coords']['O'][~mask]
104 | entry['coords']['N'] = entry['coords']['N'][~mask]
105 | idx = np.where(~mask)[0]
106 | entry['seq'] = ''.join([entry['seq'][i] for i in idx])
107 | metadata.append({
108 | 'title': entry['name'],
109 | 'seq_length': len(entry['seq']),
110 | 'seq': entry['seq'],
111 | 'coords': entry['coords'],
112 | })
113 | return metadata
114 |
115 | def _load_data_dict(self):
116 | # Load the appropriate pickle file based on the mode and keep it in memory
117 | if self.mode == 'train':
118 | with open(self.path + f'/cath42_pc_train_sorted.pkl', 'rb') as f:
119 | return pickle.load(f)
120 | elif self.mode == 'valid':
121 | with open(self.path + f'/cath42_pc_validation_sorted.pkl', 'rb') as f:
122 | return pickle.load(f)
123 | elif self.mode == 'test':
124 | with open(self.path + f'/cath42_pc_test.pkl', 'rb') as f:
125 | return pickle.load(f)
126 |
127 | def __len__(self):
128 | return len(self.metadata)
129 |
130 | def _load_data_on_the_fly(self, index):
131 | entry = self.metadata[index]
132 | title = entry['title']
133 | seq_length = entry['seq_length']
134 |
135 | if title in self.data_dict:
136 | data = self.data_dict[title]
137 | data_entry = {
138 | 'title': title,
139 | 'seq': entry['seq'],
140 | 'CA': entry['coords']['CA'],
141 | 'C': entry['coords']['C'],
142 | 'O': entry['coords']['O'],
143 | 'N': entry['coords']['N'],
144 | 'chain_mask': np.ones(seq_length),
145 | 'chain_encoding': np.ones(seq_length),
146 | 'orig_surface': data['surface'],
147 | 'surface': normalize_coordinates(data['surface']),
148 | 'features': data['features'][:, self.bc_indices],
149 | }
150 |
151 | if self.mode == 'test':
152 | data_entry['category'] = 'Unknown'
153 | data_entry['score'] = 100.0
154 |
155 | return data_entry
156 | else:
157 | raise ValueError(f"Data for title {title} not found in the {self.mode} dictionary")
158 |
159 | def __getitem__(self, index):
160 | item = self._load_data_on_the_fly(index)
161 | L = len(item['seq'])
162 | if L > self.max_length:
163 | max_index = L - self.max_length
164 | truncate_index = random.randint(0, max_index)
165 | item['seq'] = item['seq'][truncate_index:truncate_index+self.max_length]
166 | item['CA'] = item['CA'][truncate_index:truncate_index+self.max_length]
167 | item['C'] = item['C'][truncate_index:truncate_index+self.max_length]
168 | item['O'] = item['O'][truncate_index:truncate_index+self.max_length]
169 | item['N'] = item['N'][truncate_index:truncate_index+self.max_length]
170 | item['chain_mask'] = item['chain_mask'][truncate_index:truncate_index+self.max_length]
171 | item['chain_encoding'] = item['chain_encoding'][truncate_index:truncate_index+self.max_length]
172 | return item
173 |
174 |
175 |
--------------------------------------------------------------------------------
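The metadata loader above drops residues whose backbone atoms contain NaN or infinite coordinates. A compact check of that masking logic on synthetic coordinates (equivalent to the `isnan`/`isinf` sums in the code):

```
import numpy as np

coords = np.zeros((3, 4, 3))   # (L, backbone atoms, xyz): 3 residues
coords[1, 2, 0] = np.nan       # corrupt one atom of residue 1
mask = ~np.isfinite(coords).all(axis=(1, 2))
print(mask)                    # [False  True False]
print(coords[~mask].shape)     # (2, 4, 3)
```
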
/src/datasets/cathafdb_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import numpy as np
5 | from tqdm import tqdm
6 | import random
7 | import torch.utils.data as data
8 | from transformers import AutoTokenizer
9 |
10 | def normalize_coordinates(surface):
11 | """
12 | Normalize the coordinates of the surface.
13 | """
14 | surface = np.array(surface)
15 | center = np.mean(surface, axis=0)
16 | max_ = np.max(surface, axis=0)
17 | min_ = np.min(surface, axis=0)
18 | length = np.max(max_ - min_)
19 | normalized_surface = (surface - center) / length
20 | return normalized_surface
21 |
22 |
23 | class CATHAFDBDataset(data.Dataset):
24 | def __init__(self, path_cath='./', path_afdb='./', split='train', max_length=1000, test_name='All', data=None, removeTS=0, version=4.2):
25 | self.version = version
26 | self.path_cath = path_cath
27 | self.path_afdb = path_afdb
28 | self.mode = split
29 | self.max_length = max_length
30 | self.test_name = test_name
31 | self.removeTS = removeTS
32 |
33 | if self.removeTS:
34 | self.remove = json.load(open(self.path_cath + '/remove.json', 'r'))['remove']
35 |
36 | if data is None:
37 | self.metadata = self._load_metadata()
38 | else:
39 | self.metadata = data
40 |
41 | self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="gaozhangyang/model_zoom/transformers")
42 |
43 | # Load the entire dictionary corresponding to the current mode
44 | self.data_dict = self._load_data_dict()
45 | def _load_metadata(self):
46 | alphabet = 'ACDEFGHIKLMNPQRSTVWY'
47 | alphabet_set = set([a for a in alphabet])
48 | metadata = []
49 |
50 | # Load the split JSON files
51 | with open(self.path_cath + '/chain_set_splits.json') as f:
52 | dataset_splits = json.load(f)
53 |
54 | # Handle specific test splits if needed
55 | if self.test_name == 'L100':
56 | with open(self.path_cath + '/test_split_L100.json') as f:
57 | test_splits = json.load(f)
58 | dataset_splits['test'] = test_splits['test']
59 |
60 | if self.test_name == 'sc':
61 | with open(self.path_cath + '/test_split_sc.json') as f:
62 | test_splits = json.load(f)
63 | dataset_splits['test'] = test_splits['test']
64 |
65 | # Select the appropriate split
66 | if self.mode == 'valid':
67 | valid_titles = set(dataset_splits['validation'])
68 | else:
69 | valid_titles = set(dataset_splits[self.mode])
70 |
71 | if not os.path.exists(self.path_cath):
72 | raise FileNotFoundError("No such file: {} !!!".format(self.path_cath))
73 | else:
74 | with open(self.path_cath + '/chain_set.jsonl') as f:
75 | lines = f.readlines()
76 | for line in tqdm(lines):
77 | entry = json.loads(line)
78 | if self.removeTS and entry['name'] in self.remove:
79 | continue
80 |
81 | bad_chars = set([s for s in entry['seq']]).difference(alphabet_set)
82 | if len(bad_chars) == 0 and len(entry['seq']) <= self.max_length and entry['name'] in valid_titles:
83 | entry['coords']['CA'] = np.array(entry['coords']['CA'])
84 | entry['coords']['C'] = np.array(entry['coords']['C'])
85 | entry['coords']['O'] = np.array(entry['coords']['O'])
86 | entry['coords']['N'] = np.array(entry['coords']['N'])
87 | # create a mask representing whether the position of any value of entry['coords']['CA'] or entry['coords']['C'] or entry['coords']['O'] or entry['coords']['N'] is nan or infinite
88 | # sum them up and check if the values are inf or nan
89 | coords = np.stack([
90 | entry['coords']['CA'],
91 | entry['coords']['C'],
92 | entry['coords']['O'],
93 | entry['coords']['N']
94 | ], axis=1) # shape: (L, 4, 3)
95 | mask = np.isnan(coords).sum(axis=(1,2)) > 0
96 | mask = mask | (np.isinf(coords).sum(axis=(1,2)) > 0)
97 | # remove the positions where the mask is True
98 | entry['coords']['CA'] = entry['coords']['CA'][~mask]
99 | entry['coords']['C'] = entry['coords']['C'][~mask]
100 | entry['coords']['O'] = entry['coords']['O'][~mask]
101 | entry['coords']['N'] = entry['coords']['N'][~mask]
102 | idx = np.where(~mask)[0]
103 | entry['seq'] = ''.join([entry['seq'][i] for i in idx])
104 | metadata.append({
105 | 'title': entry['name'],
106 | 'seq_length': len(entry['seq']),
107 | 'seq': entry['seq'],
108 | 'coords': entry['coords'],
109 | })
110 |
111 | if self.mode == 'train':
112 | if not os.path.exists(self.path_afdb):
113 |                 raise FileNotFoundError("No such file: {} !!!".format(self.path_afdb))
114 | else:
115 | afdb_data = json.load(open(self.path_afdb+'/afdb-large4000.json'))
116 |
117 | for temp in tqdm(afdb_data):
118 | title = temp['name']
119 | seq_length = len(temp['seq'])
120 | coords = np.array(temp['coords'])
121 | coords_dict = {
122 | 'CA': coords[:,1,:],
123 | 'C': coords[:,2,:],
124 | 'O': coords[:,3,:],
125 | 'N': coords[:,0,:],
126 | }
127 | metadata.append({'title':title,
128 | 'seq':temp['seq'],
129 | 'seq_length': seq_length,
130 | 'coords': coords_dict,
131 | })
132 | return metadata
133 |
134 | def _load_data_dict(self):
135 | # Load the appropriate pickle file based on the mode and keep it in memory
136 | def _downcast_float32_inplace(d):
137 | # Convert numeric arrays to float32 to reduce memory
138 | for _title, item in d.items():
139 | if not isinstance(item, dict):
140 | continue
141 | if 'surface' in item:
142 | item['surface'] = np.asarray(item['surface'], dtype=np.float32)
143 | if 'features' in item:
144 | item['features'] = np.asarray(item['features'], dtype=np.float32)
145 |
146 | if self.mode == 'train':
147 | with open(self.path_cath + f'/cath42_pc_train_sorted.pkl', 'rb') as f:
148 | data_dict_cath = pickle.load(f)
149 | _downcast_float32_inplace(data_dict_cath)
150 | with open(self.path_afdb + f'/afdb-large4000.pkl', 'rb') as f:
151 | data_dict_afdb = pickle.load(f)
152 | _downcast_float32_inplace(data_dict_afdb)
153 | data_dict = {**data_dict_cath, **data_dict_afdb}
154 | return data_dict
155 | elif self.mode == 'valid':
156 | with open(self.path_cath + f'/cath42_pc_validation_sorted.pkl', 'rb') as f:
157 | data_dict = pickle.load(f)
158 | _downcast_float32_inplace(data_dict)
159 | return data_dict
160 | elif self.mode == 'test':
161 | with open(self.path_cath + f'/cath42_pc_test.pkl', 'rb') as f:
162 | data_dict = pickle.load(f)
163 | _downcast_float32_inplace(data_dict)
164 | return data_dict
165 |
166 | def __len__(self):
167 | return len(self.metadata)
168 |
169 | def _load_data_on_the_fly(self, index):
170 | entry = self.metadata[index]
171 | title = entry['title']
172 | seq_length = entry['seq_length']
173 |
174 | if title in self.data_dict:
175 | data = self.data_dict[title]
176 | data_entry = {
177 | 'title': title,
178 | 'seq': entry['seq'],
179 | 'CA': entry['coords']['CA'],
180 | 'C': entry['coords']['C'],
181 | 'O': entry['coords']['O'],
182 | 'N': entry['coords']['N'],
183 | 'chain_mask': np.ones(seq_length),
184 | 'chain_encoding': np.ones(seq_length),
185 | 'orig_surface': data['surface'],
186 | 'surface': normalize_coordinates(data['surface']),
187 | 'features': data['features'][:, :2],
188 | }
189 |
190 | if self.mode == 'test':
191 | data_entry['category'] = 'Unknown'
192 | data_entry['score'] = 100.0
193 |
194 | return data_entry
195 | else:
196 | raise ValueError(f"Data for title {title} not found in the {self.mode} dictionary")
197 |
198 | def __getitem__(self, index):
199 | item = self._load_data_on_the_fly(index)
200 | L = len(item['seq'])
201 | if L > self.max_length:
202 | max_index = L - self.max_length
203 | truncate_index = random.randint(0, max_index)
204 | item['seq'] = item['seq'][truncate_index:truncate_index+self.max_length]
205 | item['CA'] = item['CA'][truncate_index:truncate_index+self.max_length]
206 | item['C'] = item['C'][truncate_index:truncate_index+self.max_length]
207 | item['O'] = item['O'][truncate_index:truncate_index+self.max_length]
208 | item['N'] = item['N'][truncate_index:truncate_index+self.max_length]
209 | item['chain_mask'] = item['chain_mask'][truncate_index:truncate_index+self.max_length]
210 | item['chain_encoding'] = item['chain_encoding'][truncate_index:truncate_index+self.max_length]
211 | return item
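# Illustrative usage (paths are placeholders; the split JSONs and pickle files
# described above must exist under them):
#   ds = CATHAFDBDataset(path_cath='./data/cath4.2', path_afdb='./data/afdb', split='train')
#   item = ds[0]  # dict with 'seq', backbone coords ('N','CA','C','O'), 'surface', 'features', ...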
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BC-Design: A Biochemistry-Aware Framework for Inverse Protein Design
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
16 |
17 | This repository contains the implementation code for the paper:
18 |
19 | [**BC-Design: A Biochemistry-Aware Framework for Inverse Protein Design**]
20 |
21 | Xiangru Tang†, Xinwu Ye†, Fang Wu†, Yimeng Liu, Anna Su, Antonia Panescu, Guanlue Li, Daniel Shao, Dong Xu, and Mark Gerstein*.
22 |
23 | † Equal contribution
24 |
25 |
26 |
27 |
28 | 
29 |
30 |
42 |
43 | ## Overview
44 |
45 |
46 | Code Structures
47 |
48 | 
49 |
50 | - `src/datasets` contains the dataset classes and the featurizer.
51 | - `src/interface` contains customized PyTorch Lightning data and model interface modules.
52 | - `src/models/` contains the main BC-Design model architecture.
53 | - `src/tools` contains utility scripts (logging and configuration helpers).
54 | - `train` contains the training and evaluation scripts.
55 |
56 |
57 |
58 | ## News and Updates
59 |
60 | - [🆕 2025-11-23] Major updates:
61 | - Implemented a complete **backbone-only inference pipeline**.
62 | - Added **partial-information testing** script with controllable biochemical-feature masking (0–100% masking), enabling tunable recovery–diversity trade-offs.
63 | - Added full **PDB preprocessing utilities** (`pdb2jsonpkl.py`) to convert arbitrary protein structures into the BC-Design input format.
64 | - Cleaned and consolidated training/evaluation code, environment files, and documentation.
65 | - [🚀 2024-10-30] The official code is released.
66 |
67 |
68 | ## ⚙️ Installation
69 |
70 | This section guides you through setting up the necessary environment and dependencies to run BC-Design.
71 |
72 | ### Step 1: Prerequisites - CUDA and GCC
73 |
74 | Before creating the Conda environment, please ensure your system meets the following requirements. While other versions might also work, our code was developed and tested using the specific versions listed below:
75 |
76 | 1. **CUDA Version:** This codebase has been validated on **CUDA 12.8 with NVIDIA driver 570.133.20**, so running on that (or an equivalent, compatible setup) is recommended.
77 | 2. **GCC Compiler:** A C/C++ compiler is needed. This codebase has been validated on **GCC 12.2.0**; a compatible version should also work.
78 | * **Linux:** You can typically install GCC using your system's package manager. For example, on Debian/Ubuntu-based systems, you might use:
79 | ```shell
80 | sudo apt update
81 | sudo apt install gcc-12 g++-12
82 | ```
83 | On other distributions, use the appropriate package manager (e.g., `yum`, `dnf`). You may need to configure your system to use this specific version if multiple GCC versions are installed.
84 | * **HPC Environments:** If you are using a High-Performance Computing (HPC) cluster, GCC is often managed via environment modules. You might load it using a command like:
85 | ```shell
86 | module load GCC/12.2.0
87 | ```
88 | (The exact command may vary based on your HPC's module system.)
89 | * **Other Systems (macOS, Windows via WSL2):** Ensure you have a compatible C/C++ compiler. For macOS, Xcode Command Line Tools provide Clang, which is often compatible. For Windows, WSL2 with a Linux distribution is recommended.
90 | 3. **Reference OS:** Development and testing took place on **Red Hat Enterprise Linux 8.10 (Ootpa)**. Other modern Linux distributions should work fine as long as the CUDA/GCC requirements above are satisfied.
91 |
92 | ### Step 2: Create Conda Environment
93 |
94 | This project provides an environment file for **Miniconda3**. You can reproduce the Python environment with the following commands:
95 |
96 | ```shell
97 | git clone https://github.com/gersteinlab/BC-Design.git
98 | cd BC-Design
99 | conda env create -f environment.yml -n [your-env-name]
100 | conda activate [your-env-name]
101 | ```
102 |
103 | Replace `[your-env-name]` with your preferred name for the Conda environment (e.g., `bcdn`).
104 |
105 | ### Step 3: Download Data and Model Checkpoint
106 |
107 | To train the model, you need to download the preprocessed data.
108 | To test with the released model weights, you should also download the checkpoint.
109 |
110 | 1. Navigate to the Hugging Face project page: [https://huggingface.co/datasets/XinwuYe/BC-Design/tree/main](https://huggingface.co/datasets/XinwuYe/BC-Design/tree/main)
111 | 2. Download the following files into the `BC-Design` folder (the main directory cloned from GitHub):
112 | * `data.zip` (contains data for training and testing)
113 |    * `UBC2Model.ckpt` (the checkpoint for testing; download it only if you want to test with the released model weights)
114 | 3. Once downloaded, unzip the data file:
115 | ```shell
116 | unzip data.zip
117 | ```
118 | This should create a `data/` directory inside your `BC-Design` folder.
119 |
120 | As an alternative, you can also run the following commands:
121 | ```shell
122 | wget https://huggingface.co/datasets/XinwuYe/BC-Design/resolve/main/data.zip?download=true -O data.zip
123 | unzip data.zip
124 | wget "https://huggingface.co/datasets/XinwuYe/BC-Design/resolve/main/UBC2Model.ckpt?download=true" -O UBC2Model.ckpt
125 | ```
126 |
127 | After completing these steps, your environment should be ready, and you'll have the necessary data (and model checkpoint) to proceed with using BC-Design.
128 |
129 |
130 | ## Getting Started
131 |
132 | ### Evaluate on CATH 4.2:
133 |
134 | The `train/main_eval.py` script is used to evaluate the trained BC-Design model on test datasets. It loads the specified dataset and the model checkpoint (`UBC2Model.ckpt` by default) to perform inference and report evaluation metrics.
135 |
136 | Note: `train/main_eval.py` computes structure-level metrics via ESMFold. For very large proteins, ESMFold may run out of GPU memory and fall back to CPU-based structure prediction, which significantly increases runtime. The commands below include rough runtime estimates; TS50 is the fastest dataset to reproduce the evaluation.
137 |
138 | To test on the test set of CATH4.2:
139 |
140 | ```shell
141 | python train/main_eval.py --dataset CATH4.2 # ~3.5 hours on 1 A100 GPU
142 | # Expected output: evaluation metrics (e.g., recovery, perplexity, pLDDT, TM-score)
143 | ```
144 |
145 |
146 | To test on TS50, TS500, or AFDB2000:
147 | ```shell
148 | python train/main_eval.py --dataset TS50 # ~2 mins on 1 A100 GPU
149 | python train/main_eval.py --dataset TS500 # ~9 hours on 1 A100 GPU
150 | python train/main_eval.py --dataset AFDB2000
151 | ```
152 |
153 | **Testing in backbone-only setting:**
154 |
155 | BC-Design now includes a complete **structure-only inference mode**,
156 | which uses *only* backbone coordinates as input and excludes all biochemical features.
157 |
158 | ```shell
159 | python train/main_eval.py --if_struc_only True --dataset [dataset-name]
160 | ```
161 |
162 | **Testing in partial-information setting:**
163 |
164 | BC-Design supports biochemical-feature masking, enabling controlled removal of biochemical information at inference time.
165 |
166 | Example (mask 60% of biochemical feature points):
167 | ```shell
168 | python train/main_eval.py --exp_bc_mask_rate 0.6 --dataset [dataset-name] # mask 60% of biochemical features in the input
169 | ```
170 | This mechanism allows users to reproduce intermediate recovery–diversity trade-offs.
171 |
172 | **Key functionalities of `main_eval.py`:**
173 | - **Dataset Selection:** You can specify the dataset for evaluation using the `--dataset` argument (e.g., `CATH4.2`, `TS50`, `TS500`, `AFDB2000`).
174 | - **Checkpoint Loading:** It loads a pre-trained model from the path specified by `--checkpoint_path` (defaults to `./UBC2Model.ckpt`).
175 | - **Evaluation Metrics:** The script calculates and displays various performance metrics such as test loss, sequence recovery, perplexity, pLDDT, and TM-score.
176 | - **Configurable Parameters:** Several aspects of the evaluation can be configured through command-line arguments, including:
177 | * `--res_dir`: Directory to store results.
178 | * `--batch_size`: Batch size for evaluation.
179 | * `--data_root`: Root directory of the dataset.
180 | * `--num_workers`: Number of workers for data loading.
181 | * For a full list of arguments and their default values, you can refer to the `create_parser()` function within the `train/main_eval.py` script.
182 |
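For example, a fully specified evaluation run combining the arguments above might look like the following (paths are illustrative; the values mirror the defaults in `src/models/configs/UBC2Model.yaml`):

```shell
python train/main_eval.py \
    --dataset CATH4.2 \
    --checkpoint_path ./UBC2Model.ckpt \
    --res_dir ./train/results \
    --data_root ./data/ \
    --batch_size 2 \
    --num_workers 0
```
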
183 | The predicted protein sequences will be saved under `predicted_pdb/[ex_name]/[dataset]`.
184 |
185 | ### Training Model
186 |
187 | Run the following commands to reproduce the training of BC-Design on the CATH 4.2 training set. The model checkpoint will be saved as `./train/results/UBC2ModelReproduced/checkpoints/last.ckpt`.
188 |
189 | ```shell
190 | python train/main.py \
191 | --lr 0.001 \
192 | --if_strucenc_only True \
193 | --ex_name UBC2ModelStage1 # stage 1
194 |
195 | python train/main.py \
196 | --lr 0.0005 \
197 | --contrastive_learning True \
198 | --contrastive_pretrain True \
199 | --checkpoint_path "./train/results/UBC2ModelStage1/checkpoints/last.ckpt" \
200 | --ex_name UBC2ModelStage2 # stage 2
201 |
202 | python train/main.py \
203 | --lr 0.0005 \
204 | --if_warmup_train True \
205 | --checkpoint_path "./train/results/UBC2ModelStage2/checkpoints/last.ckpt" \
206 | --ex_name UBC2ModelStage3 # stage 3
207 |
208 | python train/main.py \
209 | --lr 0.00002 \
210 | --lr_scheduler cosine \
211 | --bc_mask_max_rate 3.0 \
212 | --checkpoint_path "./train/results/UBC2ModelStage3/checkpoints/last.ckpt" \
213 | --ex_name UBC2ModelReproduced # stage 4
214 | ```
215 |
216 | ### Data Preparation
217 |
218 | If you’d like to use BC-Design on your own data, run this command to convert your .pdb files into the format BC-Design expects:
219 | ```shell
220 | python pdb2jsonpkl.py --pdb_folder [dir-of-pdb-files] --dataset_name [dataset-name]
221 | ```
222 | After running it, the processed data will be saved in `./data/[dataset-name]`, and `[dataset-name]` can then be used directly as the `dataset` argument for `train/main_eval.py`.
223 |
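Note that `pdb2jsonpkl.py` computes molecular surfaces with MSMS, whose executable path is hard-coded near the top of the script; point it at your local MSMS binary before running (the path below is only an example):

```python
# In pdb2jsonpkl.py
msms_exec = '/path/to/msms/msms.x86_64Linux2.2.6.1'  # replace with your own path
```
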
225 |
226 | ## License
227 |
228 | This project is released under the [Apache 2.0 license](LICENSE). See `LICENSE` for more information.
229 |
230 |
231 |
232 |
233 | ## Contribution and Contact
234 |
235 | For adding new features, looking for help, or reporting bugs associated with `BC-Design`, please open a [GitHub issue](https://github.com/gersteinlab/BC-Design/issues) or [pull request](https://github.com/gersteinlab/BC-Design/pulls) with the tag "new features", "help wanted", or "enhancement". Please ensure that all pull requests meet the requirements outlined in our [contribution guidelines](https://github.com/gersteinlab/BC-Design/blob/public-release/CONTRIBUTING.md). Following these guidelines helps streamline the review process and maintain code quality across the project.
236 | Feel free to contact us through email if you have any questions.
237 |
238 |
240 |
--------------------------------------------------------------------------------
/pdb2jsonpkl.py:
--------------------------------------------------------------------------------
1 | from Bio import PDB
2 | import os
3 | import json
4 | import numpy as np
5 | import pickle
6 | import torch
7 | from tqdm import tqdm
8 | from Bio.PDB import PDBParser, Structure, Model, Chain, Residue, Atom
9 | from Bio.PDB.ResidueDepth import get_surface
10 | from scipy.spatial import cKDTree, Delaunay
11 | from Bio.SeqUtils import seq1
12 | from Bio.PDB.Polypeptide import is_aa
13 | from Bio.PDB.PDBExceptions import PDBConstructionWarning
14 | import warnings
15 | import argparse
16 |
17 | # Suppress PDBConstructionWarning
18 | warnings.simplefilter('ignore', PDBConstructionWarning)
19 |
20 | # Define MSMS executable path
21 | msms_exec = '/gpfs/gibbs/pi/gerstein/xt86/surface/msms/msms.x86_64Linux2.2.6.1' # replace with your own path
22 | if os.path.exists(msms_exec): os.chmod(msms_exec, 0o755)  # make MSMS executable; skip when the path is absent
23 |
24 | # Define the biochemical features dictionary
25 | bio_feat_dict = {
26 | "hydrophobicity": {
27 | "I": 4.5, "V": 4.2, "L": 3.8, "F": 2.8, "C": 2.5, "M": 1.9, "A": 1.8,
28 | "W": -0.9, "G": -0.4, "T": -0.7, "S": -0.8, "Y": -1.3, "P": -1.6, "H": -3.2,
29 | "N": -3.5, "D": -3.5, "Q": -3.5, "E": -3.5, "K": -3.9, "R": -4.5
30 | },
31 | "charge": {
32 | "R": 1, "K": 1, "D": -1, "E": -1, "H": 0.1, "A": 0, "C": 0, "F": 0, "G": 0, "I": 0,
33 | "L": 0, "M": 0, "N": 0, "P": 0, "Q": 0, "S": 0, "T": 0, "V": 0, "W": 0, "Y": 0
34 | },
35 | "polarity": {
36 | "R": 1, "N": 1, "D": 1, "Q": 1, "E": 1, "H": 1, "K": 1, "S": 1, "T": 1, "Y": 1,
37 | "A": 0, "C": 0, "F": 0, "G": 0, "I": 0, "L": 0, "M": 0, "P": 0, "V": 0, "W": 0
38 | },
39 | "acceptor": {
40 | "D": 1, "E": 1, "N": 1, "Q": 1, "H": 1, "S": 1, "T": 1, "Y": 1,
41 | "A": 0, "C": 0, "F": 0, "G": 0, "I": 0, "K": 0, "L": 0, "M": 0, "P": 0, "R": 0, "V": 0, "W": 0
42 | },
43 | "donor": {
44 | "R": 1, "K": 1, "W": 1, "N": 1, "Q": 1, "H": 1, "S": 1, "T": 1, "Y": 1,
45 | "A": 0, "C": 0, "D": 0, "E": 0, "F": 0, "G": 0, "I": 0, "L": 0, "M": 0, "P": 0, "V": 0
46 | }
47 | }
48 |
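# Example lookup (illustrative): the 5-dim feature vector for arginine ('R'),
# in the dictionary order above, is
#   [bio_feat_dict[f]['R'] for f in bio_feat_dict]
#   -> [-4.5, 1, 1, 0, 1]  # hydrophobicity, charge, polarity, acceptor, donor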
49 | # Mapping from three-letter codes to one-letter codes
50 | three_to_one = {
51 | "ALA": "A", "CYS": "C", "ASP": "D", "GLU": "E", "PHE": "F", "GLY": "G",
52 | "HIS": "H", "ILE": "I", "LYS": "K", "LEU": "L", "MET": "M", "ASN": "N",
53 | "PRO": "P", "GLN": "Q", "ARG": "R", "SER": "S", "THR": "T", "VAL": "V",
54 | "TRP": "W", "TYR": "Y"
55 | }
56 |
57 |
58 | def parse_pdb(file_path):
59 | name = os.path.basename(file_path).replace('.pdb', '')
60 | parser = PDB.PDBParser(QUIET=True)
61 | structure = parser.get_structure(name, file_path)
62 |
63 |     seq = ''
64 |     coords = []
65 |     atom_names = ['N', 'CA', 'C', 'O']
66 | 
67 |     for model in structure:
68 |         for chain in model:
69 |             for res in chain:
70 |                 # Keep only standard residues whose four backbone atoms are all
71 |                 # present, and grow seq and coords together so the sequence and
72 |                 # the coordinate array stay aligned one-to-one
73 |                 if res.get_id()[0] == ' ' and res.get_resname() in three_to_one:
74 |                     coord_dict = {atom.get_name(): atom.get_coord().tolist() for atom in res if atom.get_name() in atom_names}
75 |                     if all(atom in coord_dict for atom in atom_names): # Ensure all backbone atoms are present
76 |                         seq += three_to_one[res.get_resname()]
77 |                         coords.append([coord_dict[atom] for atom in atom_names])
78 | 
79 |     return {'name': name, 'seq': seq, 'coords': coords}
80 |
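# Example (illustrative): parse one file into the JSON-ready record format.
#   entry = parse_pdb('example.pdb')
#   np.array(entry['coords']).shape  # -> (L, 4, 3), atoms ordered N, CA, C, O
#   len(entry['seq'])                # -> L, aligned with the coordinates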
81 |
82 | # Step 1: Create PDB structure from protein dict
83 | def create_pdb_structure(protein_data):
84 | structure_id = protein_data['name']
85 | sequence = protein_data['seq']
86 | coords = protein_data['coords']
87 |
88 | structure = Structure.Structure(structure_id)
89 | model = Model.Model(0)
90 | chain = Chain.Chain('A')
91 |
92 | aa_map = {'A': 'ALA', 'C': 'CYS', 'D': 'ASP', 'E': 'GLU', 'F': 'PHE', 'G': 'GLY', 'H': 'HIS',
93 | 'I': 'ILE', 'K': 'LYS', 'L': 'LEU', 'M': 'MET', 'N': 'ASN', 'P': 'PRO', 'Q': 'GLN',
94 | 'R': 'ARG', 'S': 'SER', 'T': 'THR', 'V': 'VAL', 'W': 'TRP', 'Y': 'TYR'}
95 |
96 | atom_names = ['N', 'CA', 'C', 'O']
97 |
98 | for res_index, (res, coord_set) in enumerate(zip(sequence, coords), start=1):
99 | residue = Residue.Residue((' ', res_index, ' '), aa_map[res], ' ')
100 | for atom_index, (atom_name, coord) in enumerate(zip(atom_names, coord_set)):
101 | atom = Atom.Atom(atom_name, coord, 1.0, 0.0, ' ', atom_name, atom_index, atom_name[0])
102 | residue.add(atom)
103 | chain.add(residue)
104 |
105 | model.add(chain)
106 | structure.add(model)
107 | return structure
108 |
109 | # Step 2: Feature assignment
110 | # Function to get atom coordinates and residue types
111 | def get_atom_coords_and_residues(structure):
112 | coords = []
113 | residues = []
114 | for model in structure:
115 | for chain in model:
116 | for residue in chain:
117 | for atom in residue:
118 | coords.append(atom.coord)
119 | residues.append(three_to_one.get(residue.get_resname(), ''))
120 | return np.array(coords), residues
121 |
122 |
123 | # Process each PDB file in the input directory
124 | def assign_features(surface, structure):
125 | atom_coords, residue_types = get_atom_coords_and_residues(structure)
126 |
127 | # Build k-D tree for atom coordinates
128 | kdtree = cKDTree(atom_coords)
129 |
130 | # Assign biochemical features to each vertex in the surface
131 | features = []
132 | for vertex in surface:
133 | dist, idx = kdtree.query(vertex)
134 | residue_type = residue_types[idx]
135 | residue_features = [bio_feat_dict[feat].get(residue_type, 0) for feat in bio_feat_dict]
136 | features.append(residue_features)
137 |
138 | # Convert features to a numpy array
139 | features_array = np.array(features)
140 |
141 | return features_array
142 |
143 |
144 | # Step 3: Smooth the surface
145 | # Function to perform Gaussian kernel smoothing on all points using PyTorch
146 | def gaussian_kernel_smoothing(coords, k=8, eta=None):
147 | # print(coords.shape)
148 | if len(coords) > 20000:
149 | # Generate random permutation of indices
150 | indices = torch.randperm(len(coords))[:20000]
151 | # Select the random indices along the 0-th axis
152 | coords = coords[indices]
153 |     # Convert the numpy array to a PyTorch tensor (on GPU when available)
154 |     coords = torch.tensor(coords, dtype=torch.float32).to('cuda' if torch.cuda.is_available() else 'cpu')
155 |
156 | # Compute the full pairwise distance matrix
157 | dists = torch.cdist(coords, coords, p=2)
158 |
159 | if eta is None:
160 | eta = torch.max(dists).item()
161 |
162 | nearest_neighbors = torch.argsort(dists, dim=1)[:, 1:k+1]
163 |
164 | # Get the distances of the k-nearest neighbors
165 | nearest_dists = torch.gather(dists, 1, nearest_neighbors)
166 |
167 | # Compute weights using the Gaussian kernel
168 | weights = torch.exp(-nearest_dists**2 / eta)
169 | weights /= torch.sum(weights, dim=1, keepdim=True)
170 |
171 | # Compute the smoothed coordinates
172 | smoothed_coords = torch.sum(weights[:, :, None] * coords[nearest_neighbors], dim=1)
173 |
174 | return smoothed_coords.cpu().numpy()
175 |
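# Illustrative call: smooth a noisy point cloud with an 8-nearest-neighbor
# Gaussian kernel (runs on GPU when available, CPU otherwise).
#   noisy = np.random.rand(1000, 3)
#   smooth = gaussian_kernel_smoothing(noisy, k=8)  # -> (1000, 3) numpy array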
176 |
177 | # Step 4: Compress the surface and features using octree-based compression
178 | class OctreeNode:
179 | def __init__(self, points, indices):
180 | self.points = points
181 | self.indices = indices
182 | self.children = []
183 |
184 | def create_octree(points, indices, min_points_per_cube):
185 | """
186 | Create an octree for the given points.
187 | """
188 | def divide(points, indices):
189 | if len(points) <= min_points_per_cube:
190 | return OctreeNode(points, indices)
191 | centroid = np.mean(points, axis=0)
192 | partitions = [[] for _ in range(8)]
193 | partition_indices = [[] for _ in range(8)]
194 | for idx, point in enumerate(points):
195 | partition_index = 0
196 | if point[0] > centroid[0]:
197 | partition_index += 1
198 | if point[1] > centroid[1]:
199 | partition_index += 2
200 | if point[2] > centroid[2]:
201 | partition_index += 4
202 | partitions[partition_index].append(point)
203 | partition_indices[partition_index].append(indices[idx])
204 | node = OctreeNode(None, None)
205 | node.children = [divide(part, part_idx) for part, part_idx in zip(partitions, partition_indices)]
206 | return node
207 |
208 | return divide(points, indices)
209 |
210 | def gather_points(node):
211 | """
212 | Gather points and indices from the octree.
213 | """
214 | if node.points is not None:
215 | return [(node.points, node.indices)]
216 | result = []
217 | for child in node.children:
218 | result.extend(gather_points(child))
219 | return result
220 |
221 | def compress_surface(points, features, down_sample_ratio, min_points_per_cube=32):
222 | """
223 | Compress the surface and features using octree-based compression.
224 | """
225 | indices = np.arange(points.shape[0])
226 | octree = create_octree(points, indices, min_points_per_cube)
227 | compressed_points = []
228 | compressed_features = []
229 | for cube_points, cube_indices in gather_points(octree):
230 | local_density = len(cube_points)
231 | num_points = int(local_density * down_sample_ratio)
232 | if num_points > 0:
233 | sampled_indices = np.random.choice(local_density, num_points, replace=False)
234 | sampled_points = np.array(cube_points)[sampled_indices]
235 | sampled_features = features[cube_indices][sampled_indices]
236 | compressed_points.extend(sampled_points)
237 | compressed_features.extend(sampled_features)
238 | return np.array(compressed_points), np.array(compressed_features)
239 |
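# Illustrative usage of the octree compression above: down-sample a random
# cloud to roughly half its points while preserving local density.
#   pts = np.random.rand(10000, 3)
#   feats = np.random.rand(10000, 5)
#   small_pts, small_feats = compress_surface(pts, feats, down_sample_ratio=0.5)
#   assert len(small_pts) == len(small_feats)  # points and features stay paired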
240 |
241 | # Step 5: Add interior points
242 | # Function to get biochemical features from a residue
243 | def get_biochem_features(residue):
244 | # Define hydrophobicity scale (Kyte-Doolittle)
245 | hydrophobicity_scale = {
246 | 'A': 1.8, 'C': 2.5, 'D': -3.5, 'E': -3.5, 'F': 2.8,
247 | 'G': -0.4, 'H': -3.2, 'I': 4.5, 'K': -3.9, 'L': 3.8,
248 | 'M': 1.9, 'N': -3.5, 'P': -1.6, 'Q': -3.5, 'R': -4.5,
249 | 'S': -0.8, 'T': -0.7, 'V': 4.2, 'W': -0.9, 'Y': -1.3
250 | }
251 |
252 | # Define charge scale
253 | charge_scale = {
254 | 'D': -1, 'E': -1, 'K': 1, 'R': 1, 'H': 0.1 # Histidine is partially charged
255 | }
256 |
257 | # Define polarity, acceptor, and donor features as shown in the image
258 | polarity_scale = {'R': 1, 'N': 1, 'D': 1, 'Q': 1, 'E': 1, 'H': 1, 'K': 1, 'S': 1, 'T': 1, 'Y': 1}
259 | acceptor_scale = {'D': 1, 'E': 1, 'N': 1, 'Q': 1, 'H': 1, 'S': 1, 'T': 1, 'Y': 1}
260 | donor_scale = {'R': 1, 'K': 1, 'W': 1, 'N': 1, 'Q': 1, 'H': 1, 'S': 1, 'T': 1, 'Y': 1}
261 |
262 | res_3letter = residue.get_resname() # Get the three-letter code
263 | res_1letter = seq1(res_3letter) # Convert to one-letter code
264 | hydrophobicity = hydrophobicity_scale.get(res_1letter, 0)
265 | charge = charge_scale.get(res_1letter, 0)
266 | polarity = polarity_scale.get(res_1letter, 0)
267 | acceptor = acceptor_scale.get(res_1letter, 0)
268 | donor = donor_scale.get(res_1letter, 0)
269 | return np.array([hydrophobicity, charge, polarity, acceptor, donor])
270 |
271 |
272 | def add_interior_points(surface_points, surface_features, structure):
273 | # Extract residue info and calculate biochemical features
274 | coords = []
275 | features = []
276 |
277 | for model in structure:
278 | for chain in model:
279 | for residue in chain:
280 | if is_aa(residue) and 'CA' in residue:
281 | res_coord = residue['CA'].get_coord() # Get alpha carbon coordinates
282 | res_features = get_biochem_features(residue)
283 |
284 | coords.append(res_coord)
285 | features.append(res_features)
286 |
287 | coords = np.array(coords)
288 | features = np.array(features)
289 |
290 | # Build a KDTree for fast nearest-neighbor search
291 | kdtree = cKDTree(coords)
292 |
293 | # Generate random points inside the surface
294 | min_coords = surface_points.min(axis=0)
295 | max_coords = surface_points.max(axis=0)
296 | num_samples = 5000
297 | random_points = np.random.uniform(min_coords, max_coords, (num_samples, 3))
298 |
299 | tri = Delaunay(surface_points)
300 |
301 |     # find_simplex is vectorized and returns -1 for points that fall outside
302 |     # the convex hull of the surface
303 | 
304 |     inside_points = random_points[tri.find_simplex(random_points) >= 0]
305 |
306 | # Assign biochemical features to random points based on nearest residue
307 | _, idx = kdtree.query(inside_points)
308 | inside_features = features[idx]
309 |
310 | # Concatenate surface and inside points and features
311 | new_surface = np.concatenate([surface_points, inside_points], axis=0)
312 | new_features = np.concatenate([surface_features, inside_features], axis=0)
313 |
314 | return new_surface, new_features
315 |
316 |
317 | # Step 6: Sample
318 | def sample_if_needed(data_dict, max_length=5000):
319 | for key, value in data_dict.items():
320 | surface = value['surface']
321 | features = value['features']
322 |
323 | if len(surface) > max_length:
324 | indices = np.random.choice(len(surface), max_length, replace=False)
325 | value['surface'] = surface[indices]
326 | value['features'] = features[indices]
327 |
328 | return data_dict
329 |
330 |
331 | # Main function to run the pipeline
332 | def main(dataset='afdb2000'):
333 | input_json_path = f'data/{dataset}/{dataset}.json'
334 | output_pkl_path = f'data/{dataset}/{dataset}.pkl'
335 |
336 | # Ensure the output directory exists
337 | output_dir = os.path.dirname(output_pkl_path)
338 | os.makedirs(output_dir, exist_ok=True)
339 |
340 | with open(input_json_path, 'r') as f:
341 | protein_dicts = json.load(f)
342 |
343 | combined_data = {}
344 | for protein_data in tqdm(protein_dicts, desc="Processing proteins"):
345 | structure = create_pdb_structure(protein_data)
346 | try:
347 | surface = get_surface(structure[0], MSMS=msms_exec)
348 | except Exception as e:
349 | print(f"Failed to generate surface for {protein_data['name']}: {e}")
350 | continue
351 | features = assign_features(surface, structure)
352 | # Step 3: Smooth the surface
353 | smoothed_surface = gaussian_kernel_smoothing(surface)
354 | # Step 4: Compress the surface and features using octree-based compression
355 | if len(smoothed_surface) > 5000:
356 | down_sample_ratio = 5000 / len(smoothed_surface)
357 | compressed_points, compressed_features = compress_surface(smoothed_surface, features, down_sample_ratio)
358 | else:
359 | compressed_points, compressed_features = smoothed_surface, features # No down-sampling
360 | # Step 5: Add interior points
361 | final_surface, final_features = add_interior_points(compressed_points, compressed_features, structure)
362 |
363 | combined_data[protein_data['name']] = {
364 | 'surface': final_surface,
365 | 'features': final_features,
366 | 'seq': protein_data['seq']
367 | }
368 |
369 | combined_data = sample_if_needed(combined_data)
370 |
371 | # Save the final data into a .pkl file
372 | with open(output_pkl_path, 'wb') as f:
373 | pickle.dump(combined_data, f)
374 | return combined_data
375 |
376 |
377 | if __name__ == "__main__":
378 | parser = argparse.ArgumentParser(description="Convert PDB files to JSON/PKL datasets.")
379 | parser.add_argument(
380 | "--pdb_folder",
381 | type=str,
382 | default=None,
383 | help="Directory containing the PDB files to process."
384 | )
385 | parser.add_argument(
386 | "--dataset_name",
387 | type=str,
388 | default=None,
389 | help="Name of the output dataset (defaults to the pdb_folder name)."
390 | )
391 | args = parser.parse_args()
392 |
393 | if not args.pdb_folder:
394 | raise ValueError("Please provide --pdb_folder pointing to the directory with PDB files.")
395 |
396 | pdb_folder = args.pdb_folder
397 | dataset_name = args.dataset_name or os.path.basename(os.path.normpath(pdb_folder))
398 |
399 | pdb_files = [f for f in os.listdir(pdb_folder) if f.endswith('.pdb')]
400 | data = []
401 |
402 | print("--- Creating initial dataset ---")
403 | for pdb_file in tqdm(pdb_files, desc="Parsing PDBs"):
404 | predicted_path = os.path.join(pdb_folder, pdb_file)
405 |
406 | combined_data = parse_pdb(predicted_path)
407 |
408 | if combined_data:
409 | data.append(combined_data)
410 |
411 | output_data_dir = os.path.join('./data', dataset_name)
412 | os.makedirs(output_data_dir, exist_ok=True)
413 | json_output_path = os.path.join(output_data_dir, dataset_name + '.json')
414 |
415 | with open(json_output_path, 'w') as json_file:
416 | json.dump(data, json_file, indent=4)
417 | print(f"\nInitial JSON saved to: {json_output_path}")
418 |
419 | main(dataset_name)
420 |
--------------------------------------------------------------------------------
/src/datasets/featurizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 | from torch_geometric.nn.pool import knn_graph
5 | from torch_scatter import scatter_sum
6 | from transformers import AutoTokenizer
7 | from sklearn.neighbors import NearestNeighbors
8 | from src.tools import Rigid, Rotation, get_interact_feats
9 | import copy
10 | import json
11 |
12 | tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D", cache_dir="gaozhangyang/model_zoom/transformers") # mask token: 32
13 |
14 |
15 | def pad_ss_connections(ss_connections, max_residues, max_surface_atoms):
16 | """ Pad ss_connections to the maximum number of residues and surface atoms in the batch """
17 | B = len(ss_connections)
18 | ss_connections_padded = torch.ones((B, max_residues, max_surface_atoms), dtype=torch.float32)
19 | for i, ss_connection in enumerate(ss_connections):
20 | ss_connections_padded[i, :ss_connection.shape[0], :ss_connection.shape[1]] = ss_connection
21 | return ss_connections_padded
22 |
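# Shape sketch (illustrative): padding two connection maps of shapes (5, 100)
# and (3, 80) with max_residues=5, max_surface_atoms=100 yields a (2, 5, 100)
# float tensor; unfilled entries keep the initial value of 1.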
23 |
24 | def rbf(values, v_min, v_max, n_bins=16):
25 | """
26 | Returns RBF encodings in a new dimension at the end.
27 | """
28 | rbf_centers = torch.linspace(v_min, v_max, n_bins, device=values.device, dtype=values.dtype)
29 | rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1])
30 | rbf_std = (v_max - v_min) / n_bins
31 | z = (values.unsqueeze(-1) - rbf_centers) / rbf_std
32 | return torch.exp(-z ** 2)
33 |
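# Shape sketch for rbf() above (illustrative): a tensor of E distances maps to
# [E, n_bins] Gaussian radial-basis encodings.
#   d = torch.rand(7) * 50      # e.g., edge distances in Angstroms
#   enc = rbf(d, 0, 50, 16)     # -> torch.Size([7, 16])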
34 |
35 | class MyTokenizer:
36 | def __init__(self):
37 | self.alphabet_protein = 'ACDEFGHIKLMNPQRSTVWY' # [X] for unknown token
38 | self.alphabet_RNA = 'AUGC'
39 |
40 | def encode(self, seq, RNA=False):
41 | if RNA:
42 | return [self.alphabet_RNA.index(s) for s in seq]
43 | else:
44 | return [self.alphabet_protein.index(s) for s in seq]
45 |
46 | def decode(self, indices, RNA=False):
47 | if RNA:
48 | return ' '.join([self.alphabet_RNA[i] for i in indices])
49 | else:
50 | return ' '.join([self.alphabet_protein[i] for i in indices])
51 |
52 |
53 | class featurize_UBC2Model:
54 | def __init__(self, **kwargs) -> None:
55 | self.tokenizer = MyTokenizer()
56 | self.virtual_frame_num = 3
57 | self.exp_backbone_noise_sd = kwargs.get('exp_backbone_noise_sd', 0.0)
58 | self.partial_design = kwargs.get('partial_design', False)
59 | self.design_region_path = kwargs.get('design_region_path', '')
60 | self.design_regions = None # Initialize as None
61 | self.ig_baseline_data = kwargs.get('ig_baseline_data', False)
62 |
63 | if self.partial_design:
64 | print(f"Partial design is enabled. Loading design regions from: {self.design_region_path}")
65 | try:
66 | with open(self.design_region_path, 'r') as f:
67 | self.design_regions = json.load(f)
68 | print("Successfully loaded design regions.")
69 | except FileNotFoundError:
70 | print(f"⚠️ WARNING: Design region file not found at {self.design_region_path}. Partial design will be disabled.")
71 | self.partial_design = False
72 | except json.JSONDecodeError:
73 | print(f"⚠️ WARNING: Could not decode JSON from {self.design_region_path}. Partial design will be disabled.")
74 | self.partial_design = False
75 |
76 | def _get_features_persample(self, batch):
77 | # uniif struc featurizer
78 | for key in batch:
79 | try:
80 | batch[key] = batch[key][None,...]
81 |             except Exception:
82 |                 pass  # leave non-array entries (e.g., the sequence string) unchanged
83 | S = []
84 | for seq in batch['seq']:
85 | S.extend(self.tokenizer.encode(seq))
86 | S = torch.tensor(S)
87 |
88 | X = torch.from_numpy(np.stack([np.concatenate(batch['N']),
89 | np.concatenate(batch['CA']),
90 | np.concatenate(batch['C']),
91 | np.concatenate(batch['O'])], axis=1)).float()
92 |
93 | chain_mask = torch.from_numpy(np.concatenate(batch['chain_mask'])).float()
94 | chain_encoding = torch.from_numpy(np.concatenate(batch['chain_encoding'])).float()
95 |
96 | X, S = X.unsqueeze(0), S.unsqueeze(0)
97 | mask = torch.isfinite(torch.sum(X,(2,3))).float() # atom mask
98 | numbers = torch.sum(mask, axis=1).int()
99 | S_new = torch.zeros_like(S)
100 | X_new = torch.zeros_like(X)+torch.nan
101 | for i, n in enumerate(numbers):
102 | X_new[i,:n,::] = X[i][mask[i]==1]
103 | S_new[i,:n] = S[i][mask[i]==1]
104 |
105 | X = X_new
106 | S = S_new
107 | isnan = torch.isnan(X)
108 | mask = torch.isfinite(torch.sum(X,(2,3))).float()
109 | X[isnan] = 0.
110 |
111 | mask_bool = (mask==1)
112 | def node_mask_select(x):
113 | shape = x.shape
114 | x = x.reshape(shape[0], shape[1],-1)
115 | out = torch.masked_select(x, mask_bool.unsqueeze(-1)).reshape(-1, x.shape[-1])
116 | out = out.reshape(-1,*shape[2:])
117 | return out
118 |
119 | batch_id = torch.arange(mask_bool.shape[0], device=mask_bool.device)[:,None].expand_as(mask_bool)
120 | seq = node_mask_select(S)
121 | X = node_mask_select(X)
122 | batch_id = node_mask_select(batch_id)
123 | C_a = X[:,1,:]
124 |
125 | edge_idx = knn_graph(C_a, k=30, batch=batch_id, loop=True, flow='target_to_source')
126 |
127 | N, CA, C = X[:,0], X[:,1], X[:,2]
128 |
129 | T = Rigid.make_transform_from_reference(N.float(), CA.float(), C.float())
130 | src_idx, dst_idx = edge_idx[0], edge_idx[1]
131 | T_ts = T[dst_idx,None].invert().compose(T[src_idx,None])
132 |
133 | # global virtual frames
134 | num_global = self.virtual_frame_num
135 |
136 |         '''
137 |         Each column of U is a coordinate basis vector of the original space.
138 |         R = U
139 |         U2, S2, V2 = torch.svd((R@X_c.T)@(X_c@R.T))
140 |         R@U == U2
141 |         '''
142 |
143 | X_c = T._trans
144 | X_m = X_c.mean(dim=0, keepdim=True)
145 | X_c = X_c-X_m
146 | U,S,V = torch.svd(X_c.T@X_c)
147 | d = (torch.det(U) * torch.det(V)) < 0.0
148 | D = torch.zeros_like(V)
149 | D[ [0,1], [0,1]] = 1
150 | D[2,2] = -1*d+1*(~d)
151 | V = D@V
152 | R = torch.matmul(U, V.permute(0,1))
153 |
154 | rot_g = [R]*num_global
155 | trans_g = [X_m]*num_global
156 |
157 | feat = get_interact_feats(T, T_ts, X.float(), edge_idx, batch_id)
158 | _V, _E = feat['_V'], feat['_E']
159 |
160 | '''
161 | global_src: N+1,N+1,N+2,N+2,..N+B, N+B+1,N+B+1,N+B+2,N+B+2,..N+B+B
162 | global_dst: 0, 1, 2, 3, ..N, 0, 1, 2, 3, ..N
163 | batch_id_g: 1, 1, 2, 2, ..B, 1, 1, 2, 2, ..B
164 | '''
165 | T_g = Rigid(Rotation(torch.stack(rot_g)), torch.cat(trans_g,dim=0))
166 | num_nodes = scatter_sum(torch.ones_like(batch_id), batch_id)
167 | global_src = torch.cat([batch_id +k*num_nodes.shape[0] for k in range(num_global)]) + num_nodes
168 | global_dst = torch.arange(batch_id.shape[0], device=batch_id.device).repeat(num_global)
169 | edge_idx_g = torch.stack([global_dst, global_src])
170 | edge_idx_g_inv = torch.stack([global_src, global_dst])
171 | edge_idx_g = torch.cat([edge_idx_g, edge_idx_g_inv], dim=1)
172 |
173 | batch_id_g = torch.zeros(num_global,dtype=batch_id.dtype)
174 | T_all = Rigid.cat([T, T_g], dim=0)
175 |
176 | idx, _ = edge_idx_g.min(dim=0)
177 | T_gs = T_all[idx,None].invert().compose(T_all[idx,None])
178 |
179 | rbf_ts = rbf(T_ts._trans.norm(dim=-1), 0, 50, 16)[:,0].view(_E.shape[0],-1)
180 | rbf_gs = rbf(T_gs._trans.norm(dim=-1), 0, 50, 16)[:,0].view(edge_idx_g.shape[1],-1)
181 |
182 | _V_g = torch.arange(num_global)
183 | _E_g = torch.zeros([edge_idx_g.shape[1], 128])
184 |
185 | mask = torch.masked_select(mask, mask_bool)
186 | chain_features = (chain_encoding[edge_idx[0]] == chain_encoding[edge_idx[1]]).int()
187 |
188 | batch={
189 | 'T':T,
190 | 'T_g': T_g,
191 | 'T_ts': T_ts,
192 | 'T_gs': T_gs,
193 | 'rbf_ts': rbf_ts,
194 | 'rbf_gs': rbf_gs,
195 | 'X':X,
196 | 'chain_features': chain_features,
197 | '_V': _V,
198 | '_E': _E,
199 | '_V_g': _V_g,
200 | '_E_g': _E_g,
201 | 'S':seq,
202 | 'edge_idx':edge_idx,
203 | 'edge_idx_g': edge_idx_g,
204 | 'batch_id': batch_id,
205 | 'batch_id_g': batch_id_g,
206 | 'num_nodes': num_nodes,
207 | 'mask': mask,
208 | 'chain_mask': chain_mask,
209 | 'chain_encoding': chain_encoding,
210 | 'K_g': num_global}
211 |
212 | return batch
213 |
214 | def featurize(self,batch):
215 | if self.exp_backbone_noise_sd != 0:
216 | # Iterate over each protein sample in the batch list
217 | for protein_sample in batch:
218 | # List of keys corresponding to backbone atom coordinates
219 | coord_keys = ['N', 'CA', 'C', 'O']
220 | for key in coord_keys:
221 | # Get the original coordinates (e.g., shape [num_residues, 3])
222 | coords = protein_sample[key]
223 | # Generate Gaussian noise with the same shape as the coordinates.
224 | # The noise is centered at 0.0 with the specified standard deviation.
225 | noise = np.random.normal(loc=0.0, scale=self.exp_backbone_noise_sd, size=coords.shape)
226 | # Add the noise to the original coordinates and update the sample in place
227 | protein_sample[key] = coords + noise
228 |
229 | if self.ig_baseline_data:
230 | for protein_sample in batch:
231 | # List of keys corresponding to backbone atom coordinates
232 | coord_keys = ['N', 'CA', 'C', 'O']
233 | for key in coord_keys:
234 | coords = protein_sample[key]
235 | protein_sample[key] = np.random.normal(loc=0.0, scale=15., size=coords.shape)
236 |
237 | # deepcopy batch
238 | batch_copy = copy.deepcopy(batch)
239 | res = []
240 | for one in batch:
241 | temp = self._get_features_persample(one)
242 | res.append(temp)
243 | res = self.custom_collate_fn(res)
244 | # sbc2 featurizer
245 | bc_batch = self.featurize_SBC2Model(batch_copy)
246 | # update bc_batch into res
247 | for key in bc_batch.keys():
248 | res[key] = bc_batch[key]
249 | return res
250 |
251 | def custom_collate_fn(self, batch):
252 | batch = [one for one in batch if one is not None]
253 | num_nodes = torch.cat([one['num_nodes'] for one in batch])
254 | shift = num_nodes.cumsum(dim=0)
255 | shift = torch.cat([torch.tensor([0], device=shift.device), shift], dim=0)
256 | def shift_node_idx(idx, num_node, shift_real, shift_virtual):
257 | mask = idx>=num_node
258 | shift_combine = (~mask)*(shift_real) + (mask)*(shift_virtual)
259 | return idx+shift_combine
260 |
261 | ret = {}
262 | for key in batch[0].keys():
263 | if batch[0][key] is None:
264 | continue
265 |
266 | if key in ['T', 'T_g', 'T_ts', 'T_gs']:
267 | T = Rigid.cat([one[key] for one in batch], dim=0)
268 | ret[key+'_rot'] = T._rots._rot_mats
269 | ret[key+'_trans'] = T._trans
270 | elif key in ['edge_idx']:
271 | ret[key] = torch.cat([one[key] + shift[idx] for idx, one in enumerate(batch)], dim=1)
272 | elif key in ['edge_idx_g']:
273 | edge_idx_g = []
274 | for idx, one in enumerate(batch):
275 | shift_virtual = shift[-1] + idx*one['K_g']-num_nodes[idx]
276 | src = shift_node_idx(one['edge_idx_g'][0], num_nodes[idx], shift[idx], shift_virtual)
277 | dst_g = shift_node_idx(one['edge_idx_g'][1], num_nodes[idx], shift[idx], shift_virtual)
278 | edge_idx_g.append(torch.stack([src, dst_g]))
279 | ret[key] = torch.cat(edge_idx_g, dim=1)
280 | elif key in ['batch_id', 'batch_id_g']:
281 | ret[key] = torch.cat([one[key] + idx for idx, one in enumerate(batch)])
282 | elif key in ['K_g']:
283 | pass
284 | else:
285 | ret[key] = torch.cat([one[key] for one in batch], dim=0)
286 |
287 | return ret
288 |
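    # Index-shift sketch (illustrative): for two samples with 5 and 3 real nodes
    # and K_g = 3 virtual frames each, shift = [0, 5, 8]. Real indices of sample 1
    # shift by shift[1] = 5, landing at 5..7; sample 0's virtual frames land at
    # 8..10 and sample 1's at 11..13, i.e., past all 8 real nodes.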
289 | def featurize_SBC2Model(self, batch):
290 | """ Pack and pad batch into torch tensors with surface and orig_surface downsampling to the minimum size """
291 | # batch = [one for one in batch if one is not None]
292 | B = len(batch)
293 | if B == 0:
294 | return None
295 | lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)
296 | L_max = max(lengths)
297 |
298 | X = np.zeros([B, L_max, 4, 3])
299 | S = np.zeros([B, L_max], dtype=np.int32)
300 | score = np.ones([B, L_max]) * 100.0
301 |         chain_mask = np.zeros([B, L_max]) - 1  # 1: masked positions to be predicted, 0: visible positions
302 | chain_encoding = np.zeros([B, L_max]) - 1
303 |
304 | # Build the batch
305 | surfaces = []
306 | features = []
307 | orig_surfaces = []
308 | surface_lengths = []
309 | ss_connections = []
310 | correspondences = []
311 |
312 | for i, b in enumerate(batch):
313 | x = np.stack([b[c] for c in ['N', 'CA', 'C', 'O']], 1) # [#atom, 4, 3]
314 |
315 | l = len(b['seq'])
316 |             x_pad = np.pad(x, [[0, L_max - l], [0, 0], [0, 0]], 'constant', constant_values=(np.nan,)) # [L_max, 4, 3]
317 | X[i, :, :, :] = x_pad
318 |
319 | # Convert to labels
320 | indices = np.array(tokenizer.encode(b['seq'], add_special_tokens=False))
321 | S[i, :l] = indices
322 | chain_mask[i, :l] = b['chain_mask']
323 | chain_encoding[i, :l] = b['chain_encoding']
324 |
325 | # Add surface, features, orig_surface
326 | surfaces.append(torch.tensor(b['surface'], dtype=torch.float32))
327 | features.append(torch.tensor(b['features'], dtype=torch.float32))
328 | orig_surfaces.append(torch.tensor(b['orig_surface'], dtype=torch.float32))
329 | surface_lengths.append(b['surface'].shape[0])
330 |
331 | if self.partial_design and self.design_regions:
332 | protein_name = b['title']
333 | if protein_name in self.design_regions:
334 | # 1. Get the necessary data
335 | design_mask = torch.tensor(self.design_regions[protein_name], dtype=torch.bool)
336 |
337 | # Convert tensors to NumPy arrays for Scikit-learn (this is fast on CPU)
338 | ca_coords = torch.tensor(b['CA'], dtype=torch.float32).numpy()
339 | surface_coords = orig_surfaces[i].numpy()
340 |
341 | if len(design_mask) != len(ca_coords):
342 | print(f"⚠️ WARNING: Mismatch for '{protein_name}'. Mask length {len(design_mask)} != Residue count {len(ca_coords)}. Skipping masking.")
343 | else:
344 | # 2. Find the closest residue for each surface point (using NearestNeighbors)
345 | # Build the tree from the residue coordinates
346 | nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(ca_coords)
347 |
348 | # Find the index of the single nearest neighbor for each surface point
349 | distances, indices = nbrs.kneighbors(surface_coords)
350 |
351 | # `indices` has shape [num_surface_points, 1], so flatten it
352 | closest_residue_indices = indices.flatten()
353 |
354 | # 3. Create a mask for the surface points
355 | # Use the NumPy array of indices to look up values in the PyTorch design_mask
356 | surface_mask = design_mask[closest_residue_indices]
357 |
358 | # 4. Apply the mask to the features tensor
359 | features[i][surface_mask] = float('nan')
360 |
361 | else:
362 | print(f"⚠️ WARNING: Protein '{protein_name}' not found in design region file. Skipping masking for this sample.")
363 |
364 | if self.ig_baseline_data:
365 | features[i][:] = float('nan')
366 |
367 | # Find the minimum surface length in the batch
368 | min_surface_length = min(surface_lengths)
369 |
370 | # Downsample all surfaces, features, and orig_surfaces to the minimum surface length
371 | surfaces_downsampled = []
372 | features_downsampled = []
373 | orig_surfaces_downsampled = []
374 |
375 | for i, surface in enumerate(surfaces):
376 | surface_len = surface.shape[0]
377 | if surface_len > min_surface_length:
378 | # Randomly sample indices without replacement
379 | sampled_indices = random.sample(range(surface_len), min_surface_length)
380 | surfaces_downsampled.append(surface[sampled_indices])
381 | features_downsampled.append(features[i][sampled_indices])
382 | orig_surfaces_downsampled.append(orig_surfaces[i][sampled_indices])
383 | else:
384 | surfaces_downsampled.append(surface)
385 | features_downsampled.append(features[i])
386 | orig_surfaces_downsampled.append(orig_surfaces[i])
387 |
388 | # Stack the downsampled surfaces, features, and orig_surfaces
389 | surfaces_stacked = torch.stack(surfaces_downsampled, dim=0)
390 | features_stacked = torch.stack(features_downsampled, dim=0)
391 | orig_surfaces_stacked = torch.stack(orig_surfaces_downsampled, dim=0)
392 |
393 | mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32) # residue mask: 1 where all four backbone atoms are finite
394 | numbers = np.sum(mask, axis=1).astype(np.int32)
395 | S_new = np.zeros_like(S)
396 | X_new = np.zeros_like(X) + np.nan
397 |
398 | for i, n in enumerate(numbers):
399 | X_new[i, :n, ::] = X[i][mask[i] == 1]
400 | S_new[i, :n] = S[i][mask[i] == 1]
401 |
402 | X = X_new
403 | S = S_new
404 | isnan = np.isnan(X)
405 | mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
406 | X[isnan] = 0.
407 |
408 | # Calculate ss_connection based on X_new and downsampled orig_surface
409 | for i in range(B):
410 | ca_coords = X[i, :, 1, :] # Extract CA coordinates from X_new (1 is for CA atom)
411 | surface_coords = orig_surfaces_stacked[i]
412 |
413 | # Use the mask to identify valid indices
414 | valid_indices = mask[i].astype(bool) # mask[i] is 1 for valid indices, 0 otherwise
415 | valid_ca_coords = ca_coords[valid_indices]
416 |
417 | # Nearest neighbors search
418 | n_neighbors = max(1, int(8 * 175 / lengths[i]))
419 | nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='ball_tree').fit(surface_coords)
420 | # nbrs = NearestNeighbors(n_neighbors=8, algorithm='ball_tree').fit(surface_coords)
421 | distances, indices = nbrs.kneighbors(valid_ca_coords)
422 |
423 | ss_connection = np.zeros((ca_coords.shape[0], surface_coords.shape[0]))
424 |
425 | # Fill ss_connection for valid CA coordinates
426 | for j, neighbors in zip(np.where(valid_indices)[0], indices):
427 | ss_connection[j, neighbors] = 1
428 |
429 | # Fill ss_connection for invalid (padded) CA coordinates so no decoder row is fully masked
430 | ss_connection[~valid_indices, :] = 1
431 |
432 | ss_connections.append(torch.tensor(ss_connection, dtype=torch.float32))
433 |
434 | # 1. Calculate the distance matrix for valid_ca_coords
435 | ca_dist_matrix = np.linalg.norm(valid_ca_coords[:, None, :] - valid_ca_coords[None, :, :], axis=-1)
436 | max_dist = np.max(ca_dist_matrix)
437 | r = max_dist / 3 # 1/3 of max distance as radius
438 |
439 | # 2. Randomly sample 8 coords from valid_ca_coords
440 | sampled_indices = random.sample(range(valid_ca_coords.shape[0]), min(8, valid_ca_coords.shape[0]))
441 |
442 | batch_correspondences = []
443 | for sampled_idx in sampled_indices:
444 | # Get indices of CA atoms within radius r
445 | ca_neighbors = np.where(ca_dist_matrix[sampled_idx] < r)[0]
446 |
447 | # Get distances between the sampled CA atom and surface points
448 | ca_surface_dist_matrix = np.linalg.norm(valid_ca_coords[sampled_idx] - surface_coords.numpy(), axis=-1)
449 |
450 | # Get indices of surface points within radius r
451 | surface_neighbors = np.where(ca_surface_dist_matrix < r)[0]
452 |
453 | # Store the two sets of indices as tensors
454 | batch_correspondences.append([
455 | torch.tensor(ca_neighbors, dtype=torch.long),
456 | torch.tensor(surface_neighbors, dtype=torch.long)
457 | ])
458 |
459 | correspondences.append(batch_correspondences)
460 |
461 | # Pad ss_connections
462 | ss_connections_padded = pad_ss_connections(ss_connections, L_max, min_surface_length)
463 |
464 | # Conversion
465 | S = torch.from_numpy(S).to(dtype=torch.long)
466 | score = torch.from_numpy(score).float()
467 | X = torch.from_numpy(X).to(dtype=torch.float32)
468 | mask = torch.from_numpy(mask).to(dtype=torch.float32)
469 | X_flattened = X[mask==1]
470 | lengths = torch.from_numpy(lengths)
471 | chain_mask = torch.from_numpy(chain_mask)
472 | chain_encoding = torch.from_numpy(chain_encoding)
473 |
474 | mask_bool = (mask==1)
475 | S = torch.masked_select(S, mask_bool)
476 | mask = torch.masked_select(mask, mask_bool)
477 | return {
478 | "title": [b['title'] for b in batch],
479 | "X": X,
480 | "X_flattened": X_flattened,
481 | "S": S,
482 | "score": score,
483 | "mask": mask,
484 | "lengths": lengths,
485 | "chain_mask": chain_mask,
486 | "chain_encoding": chain_encoding,
487 | "surface": surfaces_stacked,
488 | "features": features_stacked,
489 | 'ss_connection': ss_connections_padded,
490 | 'correspondences': correspondences,
491 | }
492 |
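
Both the partial-design masking and the ss_connection construction above rely on mapping points to their nearest residues with scikit-learn. A toy standalone check (coordinates invented purely for illustration):

import numpy as np
from sklearn.neighbors import NearestNeighbors

ca_coords = np.array([[0., 0., 0.], [10., 0., 0.]])       # two CA positions
surface_coords = np.array([[1., 0., 0.], [9., 1., 0.]])   # two surface points
nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(ca_coords)
_, indices = nbrs.kneighbors(surface_coords)
print(indices.flatten())  # [0 1] -> each surface point maps to its closest residue
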
--------------------------------------------------------------------------------
/src/models/UBC2_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import TransformerDecoder, TransformerDecoderLayer
5 | from torch.nn import TransformerEncoder, TransformerEncoderLayer
6 | from torch_scatter import scatter_sum, scatter_softmax
7 | from src.tools import Rigid, Rotation
8 | from src.datasets.featurizer import rbf
9 | import numpy as np
10 | import math
11 |
12 |
13 | def build_MLP(n_layers, dim_in, dim_hid, dim_out, dropout=0.0, activation=nn.ReLU, normalize=True):
14 | if normalize:
15 | layers = [nn.Linear(dim_in, dim_hid),
16 | nn.BatchNorm1d(dim_hid),
17 | nn.Dropout(dropout),
18 | activation()]
19 | else:
20 | layers = [nn.Linear(dim_in, dim_hid),
21 | nn.Dropout(dropout),
22 | activation()]
23 | for _ in range(n_layers - 2):
24 | layers.append(nn.Linear(dim_hid, dim_hid))
25 | if normalize:
26 | layers.append(nn.BatchNorm1d(dim_hid))
27 | layers.append(nn.Dropout(dropout))
28 | layers.append(activation())
29 | layers.append(nn.Linear(dim_hid, dim_out))
30 | return nn.Sequential(*layers)
31 |
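
A quick shape check for build_MLP (a standalone sketch; assumes the function above is importable):

import torch

mlp = build_MLP(n_layers=3, dim_in=16, dim_hid=64, dim_out=8)  # 16 -> 64 -> 64 -> 8
x = torch.randn(5, 16)
print(mlp(x).shape)  # torch.Size([5, 8])
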
32 |
33 | class PointCloudMessagePassing(nn.Module):
34 | def __init__(self, args, feat_dim, edge_dim, l_max, num_scales, hidden_dim, aggregation='concat', num_heads=4, num_mha_layers=1, bc_dropout=0.0):
35 | super(PointCloudMessagePassing, self).__init__()
36 | self.l_max = l_max
37 | self.num_scales = num_scales
38 | self.aggregation = aggregation
39 | self.num_heads = num_heads
40 |
41 | self.per_layer_dim = hidden_dim // 4
42 |
43 | self.bc_mask_max_rate = args.bc_mask_max_rate
44 | self.bc_mask_how = args.bc_mask_how
45 | self.if_struc_only = args.if_struc_only
46 | self.exp_bc_mask_rate = args.exp_bc_mask_rate
47 | self.exp_hydro_mask_rate = getattr(args, 'exp_hydro_mask_rate', 0.)
48 | self.exp_charge_mask_rate = getattr(args, 'exp_charge_mask_rate', 0.)
49 |
50 | # CLS token for biochemical features initialized with per_layer_dim
51 | self.biochem_cls_token = nn.Parameter(torch.randn(1 + 8, self.per_layer_dim)) # 1 global + 8 subarea CLS tokens
52 |
53 | # bc mask token
54 | self.bc_mask_token = nn.Parameter(torch.randn(1, self.per_layer_dim))
55 |
56 | # Linear layer for feature dimension adjustment
57 | self.input_fc = nn.Linear(feat_dim, self.per_layer_dim)
58 |
59 | encoder_layer = TransformerEncoderLayer(
60 | d_model=self.per_layer_dim, # input feature dimension
61 | nhead=num_heads, # number of attention heads
62 | dim_feedforward=self.per_layer_dim * 4, # hidden dimension of the feed-forward network
63 | dropout=bc_dropout,
64 | batch_first=True
65 | )
66 | self.attention_layers = TransformerEncoder(encoder_layer, num_layers=args.bc_encoder_layer)
67 |
68 | # fc for residue connection
69 | self.res_conn_mlp = nn.Sequential(
70 | nn.ReLU(),
71 | nn.Linear(self.per_layer_dim, hidden_dim)
72 | )
73 |
74 | # Feature aggregation after MHA
75 | self.fc = nn.Linear(self.per_layer_dim * num_scales, hidden_dim)
76 |
77 | def forward(self, surfaces, biochem_feats, correspondences):
78 | B, N, _ = surfaces.shape
79 |
80 | ###### For inference with only the backbone structure, the bc input will be all NaN
81 | # Find rows (over N) where any feature is nan, for each batch
82 | nan_rows = torch.any(torch.isnan(biochem_feats), dim=-1) # shape: (B, N)
83 |
84 | hydro_mask_indices = torch.rand(B, N, device=biochem_feats.device) < self.exp_hydro_mask_rate
85 | charge_mask_indices = torch.rand(B, N, device=biochem_feats.device) < self.exp_charge_mask_rate
86 | biochem_feats[..., 0][hydro_mask_indices] = biochem_feats[..., 0][hydro_mask_indices].mean()
87 | biochem_feats[..., 1][charge_mask_indices] = biochem_feats[..., 1][charge_mask_indices].mean()
88 |
89 | # Elevate the biochemical features
90 | biochem_feats = self.input_fc(biochem_feats) # BxNx(per_layer_dim)
91 |
92 | biochem_feats[nan_rows] = self.bc_mask_token
93 |
94 | if self.if_struc_only:
95 | if self.bc_mask_how == 'token':
96 | biochem_feats[:] = self.bc_mask_token
97 | elif self.bc_mask_how == 'gauss':
98 | biochem_feats[:] = torch.randn_like(biochem_feats)
99 |
100 | # select the indices of the biochemical features to be masked
101 | bc_mask_indices = torch.rand(B, N, device=biochem_feats.device) < self.exp_bc_mask_rate
102 | # mask the biochemical features
103 | if self.bc_mask_how == 'token':
104 | biochem_feats[bc_mask_indices] = self.bc_mask_token
105 | elif self.bc_mask_how == 'gauss':
106 | biochem_feats[bc_mask_indices] = torch.randn_like(biochem_feats[bc_mask_indices])
107 |
108 | if self.training:
109 | # randomly select a probability between 0 and self.bc_mask_max_rate
110 | bc_mask_rate = torch.rand(B, device=biochem_feats.device) * self.bc_mask_max_rate
111 | # select the indices of the biochemical features to be masked
112 | bc_mask_indices = torch.rand(B, N, device=biochem_feats.device) < bc_mask_rate[:, None]
113 | # mask the biochemical features
114 | if self.bc_mask_how == 'token':
115 | biochem_feats[bc_mask_indices] = self.bc_mask_token
116 | elif self.bc_mask_how == 'gauss':
117 | biochem_feats[bc_mask_indices] = torch.randn_like(biochem_feats[bc_mask_indices])
118 |
119 | # Add CLS token at the end of biochem_feats (Bx(N+1)x(per_layer_dim))
120 | cls_tokens = self.biochem_cls_token.expand(B, -1, -1) # Expand CLS token for the batch
121 | biochem_feats = torch.cat([biochem_feats, cls_tokens], dim=1) # Concatenated CLS token
122 |
123 | # Pairwise distances between all surface points
124 | distances = torch.cdist(surfaces, surfaces) # BxNxN
125 |
126 | # Compute the maximum distance to set dynamic radii
127 | max_distance = distances.max().item()
128 | thr_rs = [max_distance / 20 * i / 4 for i in range(1, 5)] # Different scales of radii
129 |
130 | # Add 9 rows to the bottom of distances, all set to inf
131 | inf_rows = torch.full((B, 9, N), float('inf'), device=surfaces.device) # (Bx9xN)
132 | distances = torch.cat([distances, inf_rows], dim=1) # Bx(N+9)xN
133 |
134 | # Add 9 columns to the right of distances, with special handling
135 | inf_cols = torch.full((B, N + 9, 9), float('inf'), device=surfaces.device) # Bx(N+9)x9
136 |
137 | # First column (corresponding to global CLS token) is all 0s
138 | inf_cols[:, :, 0] = 0
139 |
140 | # Vectorized filling of subarea CLS distances based on correspondences
141 | for i in range(B):
142 | # Get the neighbors from correspondences and subarea indices
143 | corr = correspondences[i]
144 | surface_neighbors = torch.cat([surf for _, surf in corr], dim=0) # Concatenate all surface neighbors
145 |
146 | # Create indices for the subareas corresponding to surface neighbors
147 | subarea_idxs = torch.cat([torch.full_like(surf, j+1) for j, (_, surf) in enumerate(corr)], dim=0)
148 |
149 | # Assign distances for subarea CLS tokens to 0 where correspondences exist
150 | inf_cols[i, surface_neighbors, subarea_idxs] = 0
151 |
152 | # Concatenate the inf_cols to distances
153 | distances = torch.cat([distances, inf_cols], dim=2) # Bx(N+9)x(N+9)
154 |
155 | # Set the diagonal of the last 9x9 block to 0
156 | distances[:, -9:, -9:] = float('inf') # Set the entire 9x9 block to inf first
157 | distances[:, -9:, -9:].diagonal(dim1=-2, dim2=-1).fill_(0) # Set only the diagonal values to 0
158 |
159 | N += 9 # Adjust N to N+9 since CLS tokens are added
160 |
161 | features_list = []
162 |
163 | for thr_r in thr_rs:
164 | # 1. Create a mask for points within the spherical region
165 | region_mask = distances < thr_r # BxNxN boolean mask (N already includes the 9 CLS tokens)
166 |
167 | # 2. Compute the number of neighbors for each point in the region (BxN)
168 | num_neighbors = region_mask.sum(dim=-1) # BxN
169 |
170 | # 3. Find the maximum number of neighbors to pad all regions to the same size
171 | max_neighbors = num_neighbors.max().item() # The largest region size in this batch
172 |
173 | # 4. Downsample neighbors to 100 if max_neighbors > 100
174 | if max_neighbors > 100:
175 | # Step 1: Get the indices of the True values in region_mask (all neighbors)
176 | batch_idx, center_idx, neighbor_idx = torch.nonzero(region_mask, as_tuple=True)
177 |
178 | # Step 2: Create a mask for the center points (rows) that have more than 100 neighbors
179 | over_limit_mask = num_neighbors > 100 # BxN boolean mask where num_neighbors > 100
180 |
181 | # Step 3: Find the batch and center indices that have more than 100 neighbors
182 | over_limit_batch_idx, over_limit_center_idx = torch.nonzero(over_limit_mask, as_tuple=True)
183 |
184 | # Step 4: For these rows, get the neighbor indices and randomly sample 100 neighbors for each row
185 | downsampled_mask = region_mask.clone()
186 |
187 | for b_idx, c_idx in zip(over_limit_batch_idx, over_limit_center_idx):
188 | # Find all neighbors for this center point
189 | neighbor_indices = torch.nonzero(region_mask[b_idx, c_idx], as_tuple=False).squeeze() # Get all neighbors
190 |
191 | # Randomly sample 100 neighbors
192 | random_indices = torch.randperm(neighbor_indices.size(0), device=biochem_feats.device)[:100] # Randomly select 100
193 | selected_neighbors = neighbor_indices[random_indices] # Select 100 neighbors
194 |
195 | # Reset region_mask for this point and update it with only the selected 100 neighbors
196 | downsampled_mask[b_idx, c_idx] = False
197 | downsampled_mask[b_idx, c_idx, selected_neighbors] = True
198 |
199 | # Update region_mask with the downsampled mask
200 | region_mask = downsampled_mask
201 |
202 | # Recompute num_neighbors and max_neighbors after downsampling
203 | num_neighbors = region_mask.sum(dim=-1) # BxN
204 | max_neighbors = num_neighbors.max().item() # now at most 100 after downsampling
205 |
206 | # 5. Get the indices of True values in region_mask
207 | batch_idx, center_idx, neighbor_idx = torch.nonzero(region_mask, as_tuple=True) # Extract indices of neighbors in the region
208 |
209 | # 6. Gather the biochemical features for these indices
210 | gathered_feats = biochem_feats[batch_idx, neighbor_idx] # Gather the corresponding features from biochem_feats
211 |
212 | # 7. Generate sequential (row-local) slot indices for each neighbor; see the offset sketch after this class
213 | neighbor_offsets = torch.arange(num_neighbors.sum()).to(num_neighbors.device) - torch.repeat_interleave(torch.cumsum(num_neighbors.view(-1), dim=0) - num_neighbors.view(-1), num_neighbors.view(-1)).to(num_neighbors.device)
214 |
215 | # 8. Create a tensor to hold padded features for each region
216 | padded_feats = torch.zeros(B, N, max_neighbors, biochem_feats.shape[-1], device=biochem_feats.device)
217 |
218 | # Create a mask to indicate which points are real and which are padding
219 | padding_mask = torch.zeros(B, N, max_neighbors, device=biochem_feats.device, dtype=torch.bool)
220 |
221 | # 9. Scatter the gathered features into the padded_feats tensor using the generated sequential indices
222 | padded_feats[batch_idx, center_idx, neighbor_offsets] = gathered_feats
223 |
224 | # Update padding mask where neighbors exist
225 | padding_mask[batch_idx, center_idx, neighbor_offsets] = 1 # Mark valid neighbors
226 |
227 | # 10. Perform Multi-Head Attention (MHA)
228 | padded_feats_flat = padded_feats.view(B * N, max_neighbors, -1) # (B*N)xMaxNeighborsxFeatDim
229 | padding_mask_flat = ~padding_mask.view(B * N, max_neighbors) # (B*N)xMaxNeighbors; inverted so True marks padding for src_key_padding_mask
230 |
231 | # Apply MHA over the padded regions
232 | attn_output = self.attention_layers(padded_feats_flat, src_key_padding_mask=padding_mask_flat)
233 |
234 | # 11. Perform pooling over the region (e.g., mean pooling over valid points)
235 | attn_output = attn_output.view(B, N, max_neighbors, -1) # BxNxMaxNeighborsxFeatDim
236 | pooled_feats = attn_output.masked_fill(~padding_mask.unsqueeze(-1), 0).sum(dim=2) / num_neighbors.unsqueeze(-1) # BxNxFeatDim (mean over valid neighbors)
237 |
238 | features_list.append(pooled_feats)
239 |
240 | # 12. Concatenate features from different scales
241 | combined_feats = torch.cat(features_list, dim=-1) # BxNx(num_scales * per_layer_dim)
242 |
243 | # Add the residual connection and final projection to hidden_dim
244 | combined_feats = combined_feats + self.res_conn_mlp(biochem_feats)
245 | output_feats = self.fc(combined_feats) # BxNxhidden_dim
246 |
247 | return output_feats
248 |
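
The offset computation in step 7 of PointCloudMessagePassing.forward converts a ragged neighbor list into row-local slot positions, so gathered features can be scattered into a padded tensor. A standalone sketch with a toy neighbor count (same tensor ops, not part of the model):

import torch

num_neighbors = torch.tensor([2, 3, 1])  # neighbors found for three center points
row_start = torch.cumsum(num_neighbors, dim=0) - num_neighbors  # [0, 2, 5]
offsets = torch.arange(num_neighbors.sum()) - torch.repeat_interleave(row_start, num_neighbors)
print(offsets)  # tensor([0, 1, 0, 1, 2, 0]) -> slot index within each row
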
249 |
250 | class GeoFeat(nn.Module):
251 | def __init__(self, geo_layer, num_hidden, virtual_atom_num, dropout=0.0):
252 | super(GeoFeat, self).__init__()
253 | self.__dict__.update(locals())
254 | self.virtual_atom = nn.Linear(num_hidden, virtual_atom_num*3)
255 | self.virtual_direct = nn.Linear(num_hidden, virtual_atom_num*3)
256 | self.we_condition = build_MLP(geo_layer, 4*virtual_atom_num*3+9+16+32, num_hidden, num_hidden, dropout)
257 | self.MergeEG = nn.Linear(num_hidden+num_hidden, num_hidden)
258 |
259 | def forward(self, h_V, h_E, T_ts, edge_idx, h_E_0):
260 | src_idx = edge_idx[0]
261 | dst_idx = edge_idx[1]
262 | num_edge = src_idx.shape[0]
263 | num_atom = h_V.shape[0]
264 |
265 | # ==================== point cross attention =====================
266 | V_local = self.virtual_atom(h_V).view(num_atom,-1,3)
267 | V_edge = self.virtual_direct(h_E).view(num_edge,-1,3)
268 | Ks = torch.cat([V_edge,V_local[src_idx].view(num_edge,-1,3)], dim=1)
269 | Qt = T_ts.apply(Ks)
270 | Ks = Ks.view(num_edge,-1)
271 | Qt = Qt.reshape(num_edge,-1)
272 | V_edge = V_edge.reshape(num_edge,-1)
273 | quat_st = T_ts._rots._rot_mats[:, 0].reshape(num_edge, -1)
274 |
275 | RKs = torch.einsum('eij,enj->eni', T_ts._rots._rot_mats[:,0], V_local[src_idx].view(num_edge,-1,3))
276 | QRK = torch.einsum('enj,enj->en', V_local[dst_idx].view(num_edge,-1,3), RKs)
277 |
278 | H = torch.cat([Ks, Qt, quat_st, T_ts.rbf, QRK], dim=1)
279 | G_e = self.we_condition(H)
280 | h_E = self.MergeEG(torch.cat([h_E, G_e], dim=-1))
281 | return h_E
282 |
283 |
284 | class PiFoldAttn(nn.Module):
285 | def __init__(self, attn_layer, num_hidden, num_V, num_E, dropout=0.0):
286 | super(PiFoldAttn, self).__init__()
287 | self.__dict__.update(locals())
288 | self.num_heads = 4
289 | self.W_V = nn.Sequential(nn.Linear(num_E, num_hidden),
290 | nn.GELU())
291 |
292 | self.Bias = nn.Sequential(
293 | nn.Linear(2*num_V+num_E, num_hidden),
294 | nn.ReLU(),
295 | nn.Linear(num_hidden,num_hidden),
296 | nn.ReLU(),
297 | nn.Linear(num_hidden,self.num_heads))
298 | self.W_O = nn.Linear(num_hidden, num_V, bias=False)
299 | self.gate = nn.Linear(num_hidden, num_V)
300 |
301 | def forward(self, h_V, h_E, edge_idx):
302 | src_idx = edge_idx[0]
303 | dst_idx = edge_idx[1]
304 | h_V_skip = h_V
305 |
306 | E = h_E.shape[0]
307 | n_heads = self.num_heads
308 | d = int(self.num_hidden / n_heads)
309 | num_nodes = h_V.shape[0]
310 |
311 | w = self.Bias(torch.cat([h_V[src_idx], h_E, h_V[dst_idx]],dim=-1)).view(E, n_heads, 1)
312 | attend_logits = w/np.sqrt(d)
313 |
314 | V = self.W_V(h_E).view(-1,n_heads, d)
315 | attend = scatter_softmax(attend_logits, index=src_idx, dim=0)
316 | h_V = scatter_sum(attend*V, src_idx, dim=0).view([num_nodes, -1])
317 |
318 | h_V_gate = torch.sigmoid(self.gate(h_V))
319 | dh = self.W_O(h_V)*h_V_gate
320 |
321 | h_V = h_V_skip + dh
322 | return h_V
323 |
324 |
325 | class UpdateNode(nn.Module):
326 | def __init__(self, num_hidden):
327 | super().__init__()
328 | self.dense = nn.Sequential(
329 | nn.BatchNorm1d(num_hidden),
330 | nn.Linear(num_hidden, num_hidden*4),
331 | nn.ReLU(),
332 | nn.Linear(num_hidden*4, num_hidden),
333 | nn.BatchNorm1d(num_hidden)
334 | )
335 | self.V_MLP_g = nn.Sequential(
336 | nn.Linear(num_hidden, num_hidden),
337 | nn.ReLU(),
338 | nn.Linear(num_hidden,num_hidden),
339 | nn.ReLU(),
340 | nn.Linear(num_hidden,num_hidden))
341 |
342 | def forward(self, h_V, batch_id):
343 | dh = self.dense(h_V)
344 | h_V = h_V + dh
345 |
346 | # # ============== global attn - virtual frame
347 | uni = batch_id.unique()
348 | mat = (uni[:,None] == batch_id[None]).to(h_V.dtype)
349 | mat = mat/mat.sum(dim=1, keepdim=True)
350 | c_V = mat@h_V
351 |
352 | h_V = h_V * torch.sigmoid(self.V_MLP_g(c_V))[batch_id]
353 | return h_V
354 |
355 |
356 | class UpdateEdge(nn.Module):
357 | def __init__(self, edge_layer, num_hidden, dropout=0.1):
358 | super(UpdateEdge, self).__init__()
359 | self.W = build_MLP(edge_layer, num_hidden*3, num_hidden, num_hidden, dropout, activation=nn.GELU, normalize=False)
360 | self.norm = nn.BatchNorm1d(num_hidden)
361 | self.pred_quat = nn.Linear(num_hidden,8)
362 |
363 | def forward(self, h_V, h_E, T_ts, edge_idx, batch_id):
364 | src_idx = edge_idx[0]
365 | dst_idx = edge_idx[1]
366 |
367 | h_EV = torch.cat([h_V[src_idx], h_E, h_V[dst_idx]], dim=-1)
368 | h_E = self.norm(h_E + self.W(h_EV))
369 |
370 | return h_E
371 |
372 |
373 | class GeneralGNN(nn.Module):
374 | def __init__(self,
375 | geo_layer,
376 | attn_layer,
377 | ffn_layer,
378 | edge_layer,
379 | num_hidden,
380 | virtual_atom_num=32,
381 | dropout=0.1,
382 | mask_rate=0.15,
383 | exp_v_mask_rate=0.,
384 | exp_e_mask_rate=0.):
385 | super(GeneralGNN, self).__init__()
386 | self.__dict__.update(locals())
387 | self.geofeat = GeoFeat(geo_layer, num_hidden, virtual_atom_num, dropout)
388 | self.attention = PiFoldAttn(attn_layer, num_hidden, num_hidden, num_hidden, dropout)
389 | self.update_node = UpdateNode(num_hidden)
390 | self.update_edge = UpdateEdge(edge_layer, num_hidden, dropout)
391 | self.mask_token = nn.Embedding(2, num_hidden)
392 |
393 | def get_rand_idx(self, h_V, mask_rate):
394 | num_N = int(h_V.shape[0] * mask_rate) # number of rows to select for masking (e.g., 15%)
395 | indices = torch.randperm(h_V.shape[0], device=h_V.device)
396 | selected_indices = indices[:num_N]
397 | return selected_indices
398 |
399 | def forward(self, h_V, h_E, T_ts, edge_idx, batch_id, h_E_0):
400 | if self.training:
401 | selected_indices = self.get_rand_idx(h_V, self.mask_rate)
402 | h_V[selected_indices] = self.mask_token.weight[0]
403 |
404 | selected_indices = self.get_rand_idx(h_E, self.mask_rate)
405 | h_E[selected_indices] = self.mask_token.weight[1]
406 |
407 | if not self.training: # for ablation study
408 | selected_indices = self.get_rand_idx(h_V, self.exp_v_mask_rate)
409 | h_V[selected_indices] = self.mask_token.weight[0]
410 |
411 | selected_indices = self.get_rand_idx(h_E, self.exp_e_mask_rate)
412 | h_E[selected_indices] = self.mask_token.weight[1]
413 |
414 | h_E = self.geofeat(h_V, h_E, T_ts, edge_idx, h_E_0)
415 | h_V = self.attention(h_V, h_E, edge_idx)
416 | h_V = self.update_node(h_V, batch_id)
417 | h_E = self.update_edge(h_V, h_E, T_ts, edge_idx, batch_id)
418 | return h_V, h_E
419 |
420 |
421 | class StructureEncoder(nn.Module):
422 | def __init__(self,
423 | geo_layer,
424 | attn_layer,
425 | ffn_layer,
426 | edge_layer,
427 | encoder_layer,
428 | hidden_dim,
429 | dropout=0,
430 | mask_rate=0.15,
431 | exp_v_mask_rate=0.,
432 | exp_e_mask_rate=0.):
433 | """ Graph labeling network """
434 | super(StructureEncoder, self).__init__()
435 | self.__dict__.update(locals())
436 | self.encoder_layers = nn.ModuleList([GeneralGNN(geo_layer,
437 | attn_layer,
438 | ffn_layer,
439 | edge_layer,
440 | hidden_dim,
441 | dropout=dropout,
442 | mask_rate=mask_rate,
443 | exp_v_mask_rate=exp_v_mask_rate,
444 | exp_e_mask_rate=exp_e_mask_rate) for i in range(encoder_layer)])
445 | self.s = nn.Linear(hidden_dim, 1)
446 |
447 | def forward(self, h_S,
448 | T,
449 | h_V,
450 | h_E,
451 | T_ts,
452 | edge_idx,
453 | batch_id, h_E_0):
454 | # No global frame handling needed - work only with local components
455 | outputs = []
456 | for layer in self.encoder_layers:
457 | h_V, h_E = layer(h_V, h_E, T_ts, edge_idx, batch_id, h_E_0)
458 | outputs.append(h_V.unsqueeze(1))
459 |
460 | outputs = torch.cat(outputs, dim=1)
461 | S = torch.sigmoid(self.s(outputs))
462 | output = torch.einsum('nkc, nkb -> nbc', outputs, S).squeeze(1)
463 | return output
464 |
465 |
466 | class UniIFEncoder(nn.Module):
467 | def __init__(self, args, **kwargs):
468 | """ Graph labeling network """
469 | super(UniIFEncoder, self).__init__()
470 | self.__dict__.update(locals())
471 | self.hidden_dim = args.hidden_dim
472 | geo_layer, attn_layer, node_layer, edge_layer, encoder_layer, hidden_dim, dropout, mask_rate = args.geo_layer, args.attn_layer, args.node_layer, args.edge_layer, args.encoder_layer, args.hidden_dim, args.dropout, args.mask_rate
473 |
474 | exp_v_mask_rate = getattr(args, 'exp_v_mask_rate', 0.)
475 | exp_e_mask_rate = getattr(args, 'exp_e_mask_rate', 0.)
476 |
477 | self.node_embedding = build_MLP(2, 76, hidden_dim, hidden_dim)
478 | self.edge_embedding = build_MLP(2, 196+16, hidden_dim, hidden_dim)
479 | self.encoder = StructureEncoder(geo_layer, attn_layer, node_layer, edge_layer, encoder_layer, hidden_dim, dropout, mask_rate,
480 | exp_v_mask_rate, exp_e_mask_rate)
481 | self.chain_embeddings = nn.Embedding(2, 16)
482 |
483 | # CLS token for structural features
484 | self.struct_cls_token = nn.Parameter(torch.randn(1 + 8, hidden_dim))
485 |
486 | self._init_params()
487 |
488 | def _init_params(self):
489 | for name, p in self.named_parameters():
490 | if p.dim() > 1:
491 | nn.init.xavier_uniform_(p)
492 |
493 | def forward(self, batch, num_global=3):
494 | h_V, h_E, edge_idx, batch_id, chain_features = batch['_V'], batch['_E'], batch['edge_idx'], batch['batch_id'], batch['chain_features']
495 | correspondences = batch['correspondences']
496 | # Remove global virtual frame variables
497 | T = Rigid(Rotation(batch['T_rot']), batch['T_trans'])
498 | T_ts = Rigid(Rotation(batch['T_ts_rot']), batch['T_ts_trans'])
499 | h_E = torch.cat([h_E, self.chain_embeddings(chain_features)], dim=-1)
500 |
501 | h_E_0 = h_E
502 |
503 | node_embeds = self.node_embedding(h_V)
504 |
505 | # Prepare for adding CLS tokens
506 | B = len(batch_id.unique()) # Batch size
507 | max_nodes = 9 + max([(batch_id == i).sum().item() for i in range(B)]) # 9 CLS + max residues per batch
508 | # Add CLS token embeddings to node embeddings
509 | # Directly add CLS tokens to the beginning of each batch in the original flattened node_embeds
510 | cls_tokens = self.struct_cls_token.expand(B, -1, -1) # (B, 9, hidden_dim)
511 | # For each batch, prepend 9 CLS tokens to the corresponding node embeddings
512 | node_embeds_with_cls = []
513 | for i in range(B):
514 | node_indices = (batch_id == i).nonzero(as_tuple=True)[0]
515 | this_node_embeds = node_embeds[node_indices] # (num_nodes_i, hidden_dim)
516 | this_cls_tokens = cls_tokens[i] # (9, hidden_dim)
517 | node_embeds_with_cls.append(torch.cat([this_cls_tokens, this_node_embeds], dim=0)) # (9 + num_nodes_i, hidden_dim)
518 |
519 | h_V = torch.cat(node_embeds_with_cls, dim=0) # (sum_i (9 + num_nodes_i), hidden_dim)
520 |
521 | batch_id_with_cls = []
522 | # batch_id identifies which batch element each node belongs to
523 | # get unique batch_id
524 | unique_batch_id = batch_id.unique()
525 | for i in unique_batch_id:
526 | # prepend 9 entries of batch index i for the CLS tokens, then one per original node
527 | batch_id_with_cls.append(torch.full((9 + (batch_id == i).sum().item(),), i, device=batch_id.device))
528 | batch_id_with_cls = torch.cat(batch_id_with_cls, dim=0)
529 |
530 | h_E = self.edge_embedding(h_E)
531 | h_E, edge_idx, T_ts = self._pad_and_stack_edges(h_E, batch_id, edge_idx, T_ts, max_nodes, correspondences) # Shape: (total_edges, hidden_dim), (2, total_edges)
532 |
533 | h_S = None
534 |
535 | # Get structural node embeddings from encoder (without global frames)
536 | node_embeds = self.encoder(h_S,
537 | T,
538 | h_V,
539 | h_E,
540 | T_ts,
541 | edge_idx,
542 | batch_id_with_cls, h_E_0)
543 |
544 | unflattened_node_embeds = self._pad_and_stack(node_embeds, batch_id_with_cls, max_nodes)
545 |
546 | return unflattened_node_embeds
547 |
548 |
549 | def _get_features(self, batch):
550 | return batch
551 |
552 | def _pad_and_stack(self, features, batch_id, max_nodes):
553 | """Pad and stack node features."""
554 | B = batch_id.max().item() + 1 # Batch size
555 | padded = torch.zeros((B, max_nodes, self.hidden_dim), device=features.device)
556 |
557 | for i in range(B):
558 | node_indices = (batch_id == i).nonzero(as_tuple=True)[0]
559 | padded[i, :len(node_indices), :] = features[node_indices]
560 |
561 | return padded
562 |
563 | def _pad_and_stack_edges(self, edge_weights, batch_id, E_idx, T_ts, max_nodes, correspondences):
564 | """Convert edge features to include CLS tokens and return in 2D format with updated edge indices."""
565 | B = batch_id.max().item() + 1 # Batch size
566 |
567 | is_vector = True
568 | hidden_dim = edge_weights.shape[-1]
569 |
570 | # Collect all new edge features and indices
571 | new_edge_features = []
572 | new_edge_indices = []
573 | new_rot_mats = []
574 | new_trans = []
575 | # shape of T_ts._rots._rot_mats torch.Size([#edges, 1, 3, 3])
576 | # shape of T_ts._trans torch.Size([#edges, 1, 3])
577 |
578 | # Calculate batch offsets for global node indexing
579 | batch_sizes = [(batch_id == i).sum().item() for i in range(B)]
580 | batch_offsets = [0]
581 | for i in range(B):
582 | batch_offsets.append(batch_offsets[-1] + batch_sizes[i] + 9) # +9 for CLS tokens per batch
583 |
584 | for i in range(B):
585 | node_indices = (batch_id == i).nonzero(as_tuple=True)[0]
586 | min_node_id = node_indices.min().item()
587 | num_nodes = len(node_indices)
588 |
589 | src, dst = E_idx[0, :], E_idx[1, :]
590 | local_edges_mask = (src >= min_node_id) & (src < min_node_id + num_nodes)
591 |
592 | batch_offset = batch_offsets[i]
593 |
594 | # 1. Process original edges (shifted by +9 for CLS tokens)
595 | if local_edges_mask.any():
596 | src_local = src[local_edges_mask] - min_node_id + 9 # +9 for CLS tokens
597 | dst_local = dst[local_edges_mask] - min_node_id + 9 # +9 for CLS tokens
598 |
599 | # Add to global indices
600 | global_src = src_local + batch_offset
601 | global_dst = dst_local + batch_offset
602 |
603 | # Original edge features
604 | original_edge_features = edge_weights[local_edges_mask]
605 | new_edge_features.append(original_edge_features)
606 |
607 | new_edge_indices.append(torch.stack([global_src, global_dst]))
608 | # select this batch's edge transforms by index from the Rigid's rotation/translation tensors
609 | local_edge_indices = local_edges_mask.nonzero(as_tuple=True)[0]
610 | new_rot_mats.append(T_ts._rots._rot_mats[local_edge_indices])
611 | new_trans.append(T_ts._trans[local_edge_indices])
612 |
613 | # 2. Determine CLS edge feature
614 | cls_edge_feature = torch.zeros(hidden_dim, device=edge_weights.device)
615 | # determine the T_ts for the CLS edge
616 | cls_rot_mats = torch.zeros_like(T_ts._rots._rot_mats[0])
617 | cls_trans = torch.zeros_like(T_ts._trans[0])
618 |
619 | # 3. Add global CLS token connections (index 0)
620 | global_cls_idx = batch_offset + 0 # Global CLS index for this batch
621 | global_node_indices = torch.arange(9, 9 + num_nodes, device=edge_weights.device) + batch_offset
622 |
623 | # CLS -> Nodes
624 | cls_to_nodes_src = torch.full((num_nodes,), global_cls_idx, device=edge_weights.device)
625 | cls_to_nodes_dst = global_node_indices
626 | cls_to_nodes_features = cls_edge_feature.unsqueeze(0).repeat(num_nodes, 1)
627 |
628 | new_edge_indices.append(torch.stack([cls_to_nodes_src, cls_to_nodes_dst]))
629 | new_edge_features.append(cls_to_nodes_features)
630 | new_rot_mats.append(cls_rot_mats.unsqueeze(0).repeat(num_nodes, 1, 1, 1))
631 | new_trans.append(cls_trans.unsqueeze(0).repeat(num_nodes, 1, 1))
632 |
633 | # Nodes -> CLS
634 | new_edge_indices.append(torch.stack([cls_to_nodes_dst, cls_to_nodes_src]))
635 | new_edge_features.append(cls_to_nodes_features)
636 | new_rot_mats.append(cls_rot_mats.unsqueeze(0).repeat(num_nodes, 1, 1, 1))
637 | new_trans.append(cls_trans.unsqueeze(0).repeat(num_nodes, 1, 1))
638 |
639 | # 4. Add subarea CLS token connections (indices 1-8)
640 | if len(correspondences) > i and len(correspondences[i]) > 0:
641 | for sub_idx, (ca_neighbors, _) in enumerate(correspondences[i], start=1): # Limit to 8 subareas
642 | if len(ca_neighbors) > 0:
643 | global_subarea_idx = batch_offset + sub_idx # Global subarea CLS index
644 | global_ca_neighbors = ca_neighbors.to(edge_weights.device) + 9 + batch_offset # ca_neighbors is already a tensor, so .to() rather than torch.tensor()
645 |
646 | # Subarea CLS -> CA neighbors
647 | subarea_to_ca_src = torch.full((len(ca_neighbors),), global_subarea_idx, device=edge_weights.device)
648 | subarea_to_ca_dst = global_ca_neighbors
649 | subarea_to_ca_features = cls_edge_feature.unsqueeze(0).repeat(len(ca_neighbors), 1)
650 |
651 | new_edge_indices.append(torch.stack([subarea_to_ca_src, subarea_to_ca_dst]))
652 | new_edge_features.append(subarea_to_ca_features)
653 | new_rot_mats.append(cls_rot_mats.unsqueeze(0).repeat(len(ca_neighbors), 1, 1, 1))
654 | new_trans.append(cls_trans.unsqueeze(0).repeat(len(ca_neighbors), 1, 1))
655 |
656 | # CA neighbors -> Subarea CLS
657 | new_edge_indices.append(torch.stack([subarea_to_ca_dst, subarea_to_ca_src]))
658 | new_edge_features.append(subarea_to_ca_features)
659 | new_rot_mats.append(cls_rot_mats.unsqueeze(0).repeat(len(ca_neighbors), 1, 1, 1))
660 | new_trans.append(cls_trans.unsqueeze(0).repeat(len(ca_neighbors), 1, 1))
661 |
662 | # Concatenate all edge features and indices
663 | if new_edge_features:
664 | final_edge_features = torch.cat(new_edge_features, dim=0) # (total_edges, hidden_dim)
665 | final_edge_indices = torch.cat(new_edge_indices, dim=1) # (2, total_edges)
666 | final_rot_mats = torch.cat(new_rot_mats, dim=0)
667 | final_trans = torch.cat(new_trans, dim=0)
668 | T_ts._rots._rot_mats = final_rot_mats
669 | T_ts._trans = final_trans
670 | rbf_ts = rbf(T_ts._trans.norm(dim=-1), 0, 50, 16)[:,0].view(final_edge_features.shape[0],-1)
671 | T_ts.rbf = rbf_ts
672 | else:
673 | # Handle empty case
674 | final_edge_features = torch.zeros((0, hidden_dim), device=edge_weights.device)
675 | final_edge_indices = torch.zeros((2, 0), device=edge_weights.device, dtype=torch.long)
676 |
677 | return final_edge_features, final_edge_indices, T_ts
678 |
679 |
680 | class PositionalEncoding(nn.Module):
681 | def __init__(self, d_model, dropout=0.1, max_len=5000):
682 | super(PositionalEncoding, self).__init__()
683 | self.dropout = nn.Dropout(p=dropout)
684 |
685 | pe = torch.zeros(max_len, d_model)
686 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
687 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
688 | pe[:, 0::2] = torch.sin(position * div_term)
689 | pe[:, 1::2] = torch.cos(position * div_term)
690 | pe = pe.unsqueeze(0).transpose(0, 1)
691 | self.register_buffer('pe', pe)
692 |
693 | def forward(self, x):
694 | x = x + self.pe[:x.size(1), :].transpose(0, 1) # inputs here are batch_first (B, L, D): index the table by sequence length and broadcast over the batch
695 | return self.dropout(x)
696 |
697 |
698 | class UBC2Model(nn.Module):
699 | def __init__(self, args, queue_size=64, **kwargs):
700 | """ Graph labeling network """
701 | super(UBC2Model, self).__init__()
702 | self.args = args
703 | hidden_dim = args.hidden_dim
704 | dropout = args.dropout
705 | self.modal_mask_ratio = args.modal_mask_ratio
706 | self.contrastive_pretrain = args.contrastive_pretrain
707 | self.contrastive_pretrain_both = args.contrastive_pretrain_both
708 | self.contrastive_loss_global_alpha = args.contrastive_loss_global_alpha
709 | self.contrastive_loss_local_alpha = args.contrastive_loss_local_alpha
710 |
711 | self.if_strucenc_only = args.if_strucenc_only
712 |
713 | self.if_warmup_train = args.if_warmup_train
714 |
715 | self.bc_indices = getattr(args, 'bc_indices', [0, 1])
716 | self.exp_wo_bcgraph = getattr(args, 'exp_wo_bcgraph', False)
717 |
718 | self.encoder = UniIFEncoder(args)
719 |
720 | l_max = 2
721 | num_scales = 4
722 | self.surface_encoder = PointCloudMessagePassing(args, len(self.bc_indices), 1, l_max, num_scales, hidden_dim)
723 |
724 | # New Transformer decoder and MLP for final prediction
725 | decoder_layer = TransformerDecoderLayer(d_model=hidden_dim, nhead=8, dropout=dropout, batch_first=True)
726 | self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=3)
727 | self.mlp = nn.Sequential(
728 | nn.Linear(hidden_dim, hidden_dim),
729 | nn.ReLU(),
730 | nn.Linear(hidden_dim, 33)
731 | )
732 |
733 | # Positional encoding
734 | self.positional_encoding = PositionalEncoding(hidden_dim, dropout)
735 |
736 | self.contrastive_learning = args.contrastive_learning
737 |
738 | # Temperature for contrastive learning
739 | self.temperature = 0.1
740 | self.queue_size = queue_size
741 |
742 | # Initialize queues for structural and biochemical CLS tokens
743 | self.struct_queue = nn.Parameter(torch.zeros(queue_size, hidden_dim), requires_grad=False)
744 | self.biochem_queue = nn.Parameter(torch.zeros(queue_size, hidden_dim), requires_grad=False)
745 | self.queue_ptr = nn.Parameter(torch.zeros(1, dtype=torch.long), requires_grad=False)
746 |
747 | self._init_params()
748 |
749 | if self.contrastive_pretrain:
750 | modules_to_freeze = {
751 | "encoder": self.encoder,
752 | "transformer_decoder": self.transformer_decoder,
753 | "mlp": self.mlp
754 | }
755 |
756 | for name, module in modules_to_freeze.items():
757 | print(f"--- Freezing module: '{name}'")
758 | for param in module.parameters():
759 | param.requires_grad = False
760 | module.eval()
761 | elif self.contrastive_pretrain_both:
762 | modules_to_freeze = {
763 | "transformer_decoder": self.transformer_decoder,
764 | "mlp": self.mlp
765 | }
766 |
767 | for name, module in modules_to_freeze.items():
768 | print(f"--- Freezing module: '{name}'")
769 | for param in module.parameters():
770 | param.requires_grad = False
771 | module.eval()
772 | elif self.if_strucenc_only:
773 | modules_to_freeze = {
774 | "surface_encoder": self.surface_encoder,
775 | "transformer_decoder": self.transformer_decoder,
776 | }
777 |
778 | for name, module in modules_to_freeze.items():
779 | print(f"--- Freezing module: '{name}'")
780 | for param in module.parameters():
781 | param.requires_grad = False
782 | module.eval()
783 | elif self.if_warmup_train:
784 | modules_to_freeze = {
785 | "encoder": self.encoder,
786 | "surface_encoder": self.surface_encoder
787 | }
788 | for name, module in modules_to_freeze.items():
789 | print(f"--- Freezing module: '{name}'")
790 | for param in module.parameters():
791 | param.requires_grad = False
792 | module.eval()
793 |
794 | def forward(self, batch):
795 | batch_id = batch['batch_id']
796 | h_V_unflattened = self.encoder(batch)
797 |
798 | # Manually extract the CLS tokens (global + subarea) from the structure encoder
799 | struct_cls_tokens = h_V_unflattened[:, :9, :] # First 9 tokens: global (0th) + subarea (1st to 8th)
800 | h_V_unflattened = h_V_unflattened[:, 9:, :] # The rest of the node embeddings
801 |
802 | # Unflatten h_V and mask to have batch dimension
803 | max_length = batch['lengths'].max().item()
804 | batch_size = len(batch['lengths'])
805 | mask_unflattened = torch.zeros(batch_size, max_length, device=h_V_unflattened.device)
806 |
807 | # Efficiently assign values to h_V_unflattened and mask_unflattened
808 | for idx in torch.unique(batch_id):
809 | mask = (batch_id == idx)
810 | mask_unflattened[idx, :mask.sum()] = 1
811 | # Create padding masks
812 | target_padding_mask = ~mask_unflattened.bool()
813 |
814 | ### surface encoder
815 | surfaces, biochem_feats, correspondences = batch['surface'], batch['features'], batch['correspondences']
816 |
817 | if self.training:
818 | # generate a random number between 0 and 1
819 | random_number = torch.rand(1).item()
820 | if random_number < self.modal_mask_ratio:
821 | biochem_feats = torch.randn_like(biochem_feats)
822 |
823 | if self.if_strucenc_only:
824 | decoder_output = h_V_unflattened
825 | # Flatten decoder_output and remove padding
826 | mask = mask_unflattened.bool()
827 | decoder_output = decoder_output[mask]
828 |
829 | # Predict labels using MLP
830 | logits = self.mlp(decoder_output)
831 | log_probs = F.log_softmax(logits, dim=-1)
832 |
833 | elif self.contrastive_pretrain or self.contrastive_pretrain_both:
834 | h_surface = self.surface_encoder(surfaces, biochem_feats, correspondences)
835 |
836 | # Manually extract the CLS tokens (global + subarea) from the biochemical encoder
837 | biochem_cls_tokens = h_surface[:, -9:, :] # Last 9 tokens: global (0th) + subarea (1st to 8th)
838 | h_surface = h_surface[:, :-9, :] # The rest of the biochemical node embeddings
839 |
840 | logits = 0
841 | log_probs = 0
842 | else:
843 | h_surface = self.surface_encoder(surfaces, biochem_feats, correspondences)
844 |
845 | # Manually extract the CLS tokens (global + subarea) from the biochemical encoder
846 | biochem_cls_tokens = h_surface[:, -9:, :] # Last 9 tokens: global (0th) + subarea (1st to 8th)
847 | h_surface = h_surface[:, :-9, :] # The rest of the biochemical node embeddings
848 |
849 | ss_connection_mask = batch['ss_connection']
850 | ss_connection_mask = ~ss_connection_mask.bool().repeat_interleave(8, dim=0) # True = blocked; the (B*nhead, L, S) memory mask is batch-major per head, hence repeat_interleave rather than repeat
851 |
852 | # Transformer decoder to fuse h_V_unflattened and h_surface
853 | # Add positional encoding to the inputs of the Transformer decoder
854 | h_V_unflattened = self.positional_encoding(h_V_unflattened)
855 |
856 | if self.exp_wo_bcgraph:
857 | decoder_output = self.transformer_decoder(
858 | h_V_unflattened, h_surface,
859 | tgt_key_padding_mask=target_padding_mask,
860 | )
861 | else:
862 | decoder_output = self.transformer_decoder(
863 | h_V_unflattened, h_surface,
864 | tgt_key_padding_mask=target_padding_mask,
865 | memory_mask=ss_connection_mask
866 | )
867 |
868 | # Flatten decoder_output and remove padding
869 | mask = mask_unflattened.bool()
870 | decoder_output = decoder_output[mask]
871 |
872 | # Predict labels using MLP
873 | logits = self.mlp(decoder_output)
874 | log_probs = F.log_softmax(logits, dim=-1)
875 |
876 | # Contrastive learning
877 | if (self.training and random_number < self.modal_mask_ratio) or not self.contrastive_learning:
878 | contrastive_loss = 0
879 | else:
880 | contrastive_loss_global = self._contrastive_loss(struct_cls_tokens[:, 0, :], biochem_cls_tokens[:, 0, :]) # Global CLS
881 | contrastive_loss_subarea = self._contrastive_loss_subarea(struct_cls_tokens[:, 1:, :], biochem_cls_tokens[:, 1:, :]) # Subarea CLS
882 | contrastive_loss = self.contrastive_loss_global_alpha * contrastive_loss_global + self.contrastive_loss_local_alpha * contrastive_loss_subarea
883 |
884 | # Update queues with current batch global CLS tokens
885 | self._dequeue_and_enqueue(struct_cls_tokens[:, 0, :], biochem_cls_tokens[:, 0, :])
886 |
887 | return {'log_probs': log_probs, 'contrastive_loss': contrastive_loss, 'logits': logits}
888 |
889 | @torch.no_grad()
890 | def _dequeue_and_enqueue(self, struct_cls_token, biochem_cls_token):
891 | """Append new CLS tokens to the queue and dequeue older ones."""
892 | batch_size = struct_cls_token.size(0)
893 |
894 | # Get current position in the queue
895 | ptr = int(self.queue_ptr)
896 |
897 | # Replace oldest entries with the new ones
898 | if ptr + batch_size > self.queue_size:
899 | ptr = 0
900 | self.struct_queue[ptr:ptr + batch_size, :] = struct_cls_token
901 | self.biochem_queue[ptr:ptr + batch_size, :] = biochem_cls_token
902 |
903 | # Move pointer and wrap-around if necessary
904 | ptr = (ptr + batch_size) % self.queue_size
905 | self.queue_ptr[0] = ptr
906 |
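
One consequence of the wrap rule above: when queue_size is not a multiple of the batch size, the pointer resets before reaching the end, so the tail of the queue keeps its initial (zero) entries. A standalone trace of the pointer arithmetic (toy sizes, not part of the model):

queue_size, batch = 64, 24
ptr = 0
for step in range(4):
    if ptr + batch > queue_size:   # same wrap rule as _dequeue_and_enqueue
        ptr = 0
    start, ptr = ptr, (ptr + batch) % queue_size
    print(step, start, ptr)
# 0 0 24 / 1 24 48 / 2 0 24 / 3 24 48 -> slots 48..63 are never overwritten
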
907 | def _contrastive_loss(self, struct_cls_token, biochem_cls_token):
908 | """Compute NT-Xent contrastive loss using queue-based negative sampling."""
909 | batch_size = struct_cls_token.size(0)
910 |
911 | # Normalize CLS tokens
912 | z_i = F.normalize(struct_cls_token, dim=-1)
913 | z_j = F.normalize(biochem_cls_token, dim=-1)
914 |
915 | # Normalize queue embeddings
916 | struct_queue_norm = F.normalize(self.struct_queue.clone().detach(), dim=-1)
917 | biochem_queue_norm = F.normalize(self.biochem_queue.clone().detach(), dim=-1)
918 |
919 | # Cosine similarity between current CLS tokens
920 | sim_ij = torch.matmul(z_i, z_j.T) / self.temperature # (batch_size, batch_size)
921 |
922 | # Cosine similarity with negative samples from the queue
923 | sim_i_struct_queue = torch.matmul(z_i, biochem_queue_norm.T) / self.temperature # (batch_size, queue_size)
924 | sim_j_biochem_queue = torch.matmul(z_j, struct_queue_norm.T) / self.temperature # (batch_size, queue_size)
925 |
926 | # Combine positive and negative samples
927 | sim_matrix_i = torch.cat([sim_ij, sim_i_struct_queue], dim=1) # (batch_size, batch_size + queue_size)
928 | sim_matrix_j = torch.cat([sim_ij.T, sim_j_biochem_queue], dim=1) # (batch_size, batch_size + queue_size)
929 |
930 | # Create labels (positive samples on the diagonal)
931 | labels = torch.arange(batch_size).long().to(sim_matrix_i.device)
932 |
933 | # Contrastive loss for both modalities
934 | loss_i = F.cross_entropy(sim_matrix_i, labels)
935 | loss_j = F.cross_entropy(sim_matrix_j, labels)
936 |
937 | loss = (loss_i + loss_j) / 2.0
938 | return loss
939 |
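
The logits layout in _contrastive_loss is easy to get wrong, so here is a standalone sketch (random tensors with the same shapes as above; values are illustrative only) showing where the positives sit:

import torch
import torch.nn.functional as F

B, Q, D, temperature = 4, 8, 16, 0.1
z_i = F.normalize(torch.randn(B, D), dim=-1)     # structural CLS tokens
z_j = F.normalize(torch.randn(B, D), dim=-1)     # biochemical CLS tokens
queue = F.normalize(torch.randn(Q, D), dim=-1)   # stale negatives from the queue
logits = torch.cat([z_i @ z_j.T, z_i @ queue.T], dim=1) / temperature  # (B, B+Q)
labels = torch.arange(B)  # row b's positive is column b of the in-batch block
print(logits.shape, F.cross_entropy(logits, labels).item() > 0)
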
940 | def _contrastive_loss_subarea(self, struct_subarea_cls_tokens, biochem_subarea_cls_tokens):
941 | """Compute contrastive loss for the subarea CLS tokens without using a queue, using only the current batch."""
942 | batch_size, num_subareas, hidden_dim = struct_subarea_cls_tokens.size()
943 |
944 | # Normalize CLS tokens
945 | z_i = F.normalize(struct_subarea_cls_tokens, dim=-1)
946 | z_j = F.normalize(biochem_subarea_cls_tokens, dim=-1)
947 |
948 | # Cosine similarity within the batch for subarea CLS tokens
949 | sim_ij = torch.matmul(z_i, z_j.transpose(1, 2)) / self.temperature # (batch_size, num_subareas, num_subareas)
950 |
951 | # Create labels (positive samples on the diagonal)
952 | labels = torch.arange(num_subareas).long().to(sim_ij.device).unsqueeze(0).expand(batch_size, -1)
953 |
954 | # Reshape sim_ij and labels for efficient cross-entropy calculation
955 | sim_ij = sim_ij.view(batch_size * num_subareas, num_subareas) # (batch_size * num_subareas, num_subareas)
956 | labels = labels.reshape(batch_size * num_subareas) # (batch_size * num_subareas,)
957 |
958 | # Compute contrastive loss in one step
959 | loss = F.cross_entropy(sim_ij, labels)
960 |
961 | return loss
962 |
963 | def _init_params(self):
964 | for name, p in self.named_parameters():
965 | if p.dim() > 1:
966 | nn.init.xavier_uniform_(p)
967 |
968 | def _get_features(self, batch):
969 | return batch
970 |
971 |
972 |
--------------------------------------------------------------------------------