├── .gitignore ├── LICENSE ├── README.md ├── developer ├── README.md ├── src │ ├── __init__.py │ ├── config.py │ ├── data.py │ ├── dataset.py │ ├── model.py │ ├── network │ │ ├── __init__.py │ │ ├── head.py │ │ ├── ligand_encoder.py │ │ └── pharmacophore_encoder.py │ └── trainer.py └── train_example.py ├── environment.yml ├── examples ├── 6OIM_D_MOV.pdb ├── 6OIM_protein.pdb └── library.tar ├── feature_extraction.py ├── images └── overview.png ├── modeling.py ├── pyproject.toml ├── screening.py ├── src ├── pmnet │ ├── __init__.py │ ├── api │ │ ├── __init__.py │ │ └── typing.py │ ├── data │ │ ├── __init__.py │ │ ├── constant.py │ │ ├── extract_pocket.py │ │ ├── objects │ │ │ ├── __init__.py │ │ │ ├── atom_classes.py │ │ │ ├── objects.py │ │ │ └── utils.py │ │ ├── parser.py │ │ ├── pointcloud.py │ │ └── token_inference.py │ ├── module.py │ ├── network │ │ ├── __init__.py │ │ ├── backbones │ │ │ ├── __init__.py │ │ │ ├── swin.py │ │ │ ├── swinv2.py │ │ │ └── timm.py │ │ ├── builder.py │ │ ├── cavity_head.py │ │ ├── decoders │ │ │ ├── __init__.py │ │ │ └── fpn_decoder.py │ │ ├── detector.py │ │ ├── feature_embedding.py │ │ ├── mask_head.py │ │ ├── necks │ │ │ ├── __init__.py │ │ │ └── center_crop.py │ │ ├── nn │ │ │ ├── __init__.py │ │ │ └── layers.py │ │ └── token_head.py │ ├── pharmacophore_model.py │ ├── scoring │ │ ├── __init__.py │ │ ├── graph_match.py │ │ ├── ligand.py │ │ ├── ligand_utils.py │ │ ├── match_utils.py │ │ ├── match_utils_numba.py │ │ └── tree.py │ └── utils │ │ ├── __init__.py │ │ ├── density_map.py │ │ ├── download_weight.py │ │ └── smoothing.py └── pmnet_appl │ ├── README.md │ ├── __init__.py │ ├── base │ ├── __init__.py │ └── proxy.py │ ├── keys │ ├── test.txt │ └── train.txt │ ├── sbddreward │ ├── __init__.py │ ├── data.py │ ├── get_cache.py │ ├── network │ │ ├── __init__.py │ │ ├── block.py │ │ ├── head.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── one_hot.py │ │ │ ├── pair_transition.py │ │ │ ├── triangular_attention.py │ │ │ └── triangular_multiplicative_update.py │ │ ├── ligand_encoder.py │ │ └── pharmacophore_encoder.py │ └── proxy.py │ └── tacogfn_reward │ ├── __init__.py │ ├── data.py │ ├── db_keys │ ├── test.txt │ └── train.txt │ ├── get_cache.py │ └── proxy.py └── utils ├── parse_rcsb_pdb.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # USER 2 | .DS_Store 3 | weights/ 4 | run.sh 5 | result/ 6 | examples/library/ 7 | nogit/ 8 | maintain_test/ 9 | uv.lock 10 | 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | cover/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | .pybuilder/ 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | # For a library or package, you might want to ignore these files since the code is 98 | # intended to run in multiple environments; otherwise, check them in: 99 | # .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # poetry 109 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 110 | # This is especially recommended for binary packages to ensure reproducibility, and is more 111 | # commonly ignored for libraries. 112 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 113 | #poetry.lock 114 | 115 | # pdm 116 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 117 | #pdm.lock 118 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 119 | # in version control. 120 | # https://pdm.fming.dev/#use-with-ide 121 | .pdm.toml 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # PyCharm 167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 169 | # and can be added to the global gitignore or merged into this file. For a more nuclear 170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
171 | #.idea/ 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Seonghwan Seo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /developer/README.md: -------------------------------------------------------------------------------- 1 | ## Using PharmacoNet's protein representation 2 | 3 | Example scripts for using PharmacoNet's protein pharmacophore representation; they depend on `torch-geometric`. 4 | 5 | ```bash 6 | # construct conda environment; pymol-open-source is not required. 7 | conda create -n pmnet-dev python=3.10 openbabel=3.1.1 8 | conda activate pmnet-dev 9 | # install PharmacoNet & torch_geometric & wandb & tensorboard 10 | pip install -e '.[dev]' --find-links https://data.pyg.org/whl/torch-2.3.1+cu121.html 11 | ``` 12 | 13 | Example datasets (100 or 1,000 pockets) are available at [Google Drive](https://drive.google.com/drive/folders/1o8tDCsjIqaPRoJhs5SKW4yi0geA9h_Nv?usp=sharing); they were constructed from CrossDocked2020 using QuickVina 2.1.
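For orientation, here is a minimal sketch of the developer API these scripts build on (`get_pmnet_dev`, `ProteinParser`, and `run_extraction` from `pmnet.api`, all shown later in this dump); the protein path and center coordinates are placeholders:

```python
# Minimal sketch of the developer API (path and center are placeholders).
import torch

from pmnet.api import ProteinParser, get_pmnet_dev

device = "cuda" if torch.cuda.is_available() else "cpu"
module = get_pmnet_dev(device)  # frozen PharmacoNet module (see src/pmnet/api)
parser = ProteinParser(0.0)     # first argument: center noise for augmentation

# Voxelize a pocket around a known binding-site center, then extract features.
pharmacophore_info = parser("path/to/protein.pdb", center=(10.0, 20.0, 30.0))
multi_scale_features, hotspot_infos = module.run_extraction(pharmacophore_info)
```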
14 | -------------------------------------------------------------------------------- /developer/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeonghwanSeo/PharmacoNet/6595694cdf910b52c9fa0512c35f8139f5be2cf5/developer/src/__init__.py -------------------------------------------------------------------------------- /developer/src/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | from omegaconf import MISSING 4 | 5 | 6 | @dataclass 7 | class ModelConfig: 8 | hidden_dim: int = 128 9 | ligand_num_convs: int = 4 10 | 11 | 12 | @dataclass 13 | class DataConfig: 14 | protein_info_path: str = MISSING 15 | train_protein_code_path: str = MISSING 16 | protein_dir: str = MISSING 17 | ligand_path: str = MISSING 18 | 19 | 20 | @dataclass 21 | class LrSchedulerConfig: 22 | scheduler: str = "lambdalr" 23 | lr_decay: int = 50_000 24 | 25 | 26 | @dataclass 27 | class OptimizerConfig: 28 | opt: str = "adam" 29 | lr: float = 1e-3 30 | eps: float = 1e-8 31 | betas: tuple[float, float] = (0.9, 0.999) 32 | weight_decay: float = 0.05 33 | clip_grad: float = 1.0 34 | 35 | 36 | @dataclass 37 | class TrainConfig: 38 | val_every: int = 2_000 39 | log_every: int = 10 40 | print_every: int = 100 41 | save_every: int = 1_000 42 | max_iterations: int = 300_000 43 | batch_size: int = 4 44 | num_workers: int = 4 45 | 46 | opt: OptimizerConfig = field(default_factory=OptimizerConfig) 47 | lr_scheduler: LrSchedulerConfig = field(default_factory=LrSchedulerConfig) 48 | 49 | # NOTE: HYPER PARAMETER 50 | split_ratio: float = 0.9 51 | center_noise: float = 3.0 52 | 53 | 54 | @dataclass 55 | class Config: 56 | log_dir: str = MISSING 57 | model: ModelConfig = field(default_factory=ModelConfig) 58 | train: TrainConfig = field(default_factory=TrainConfig) 59 | data: DataConfig = field(default_factory=DataConfig) 60 | 61 | def to_dict(self): 62 | return config_to_dict(self) 63 | 64 | 65 | def config_to_dict(obj) -> dict: 66 | if not hasattr(obj, "__dataclass_fields__"): 67 | return obj 68 | result = {} 69 | for field in obj.__dataclass_fields__.values(): 70 | value = getattr(obj, field.name) 71 | result[field.name] = config_to_dict(value) 72 | return result 73 | 
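A quick sketch of how this schema round-trips through OmegaConf, mirroring what `Trainer` does when it writes `config.yaml`; the `log_dir` value here is illustrative:

```python
# Sketch: the dataclass config round-trips through OmegaConf (as in Trainer).
from omegaconf import OmegaConf

from src.config import Config

config = Config()
config.log_dir = "./result/example"  # placeholder; fill the MISSING fields you need
config.train.opt.lr = 3e-4

dictconfig = OmegaConf.create(config.to_dict())  # plain nested dict -> DictConfig
print(OmegaConf.to_yaml(dictconfig))             # top-level keys: log_dir / model / train / data
```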
-------------------------------------------------------------------------------- /developer/src/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from openbabel import pybel 3 | from openbabel.pybel import ob 4 | from torch_geometric.data import Data as Data 5 | 6 | pybel.ob.obErrorLog.SetOutputLevel(0) # 0: None (silence the global Open Babel error log) 7 | 8 | 9 | ATOM_DICT = { 10 | 6: 0, # C 11 | 7: 1, # N 12 | 8: 2, # O 13 | 9: 3, # F 14 | 15: 4, # P 15 | 16: 5, # S 16 | 17: 6, # Cl 17 | 35: 7, # Br 18 | 53: 8, # I 19 | -1: 9, # UNKNOWN 20 | } 21 | NUM_ATOM_FEATURES = 10 + 2 + 2 22 | 23 | BOND_DICT = { 24 | 1: 0, 25 | 2: 1, 26 | 3: 2, 27 | 1.5: 3, # AROMATIC 28 | -1: 4, # UNKNOWN 29 | } 30 | NUM_BOND_FEATURES = 5 31 | 32 | 33 | def smi2graph(smiles: str) -> Data: 34 | return Data(**smi2graphdata(smiles)) 35 | 36 | 37 | def smi2graphdata(smiles: str) -> dict[str, torch.Tensor]: 38 | pbmol = pybel.readstring("smi", smiles) 39 | atom_features = get_atom_features(pbmol) 40 | edge_attr, edge_index = get_bond_features(pbmol) 41 | return dict( 42 | x=torch.FloatTensor(atom_features), 43 | edge_index=torch.LongTensor(edge_index), 44 | edge_attr=torch.FloatTensor(edge_attr), 45 | ) 46 | 47 | 48 | def get_atom_features(pbmol: pybel.Molecule) -> list[list[float]]: 49 | facade = pybel.ob.OBStereoFacade(pbmol.OBMol) 50 | features = [] 51 | for atom in pbmol.atoms: 52 | feat = [0] * NUM_ATOM_FEATURES 53 | feat[ATOM_DICT.get(atom.atomicnum, 9)] = 1 54 | 55 | mid = atom.OBAtom.GetId() 56 | if facade.HasTetrahedralStereo(mid): 57 | stereo = facade.GetTetrahedralStereo(mid).GetConfig().winding 58 | if stereo == pybel.ob.OBStereo.Clockwise: 59 | feat[10] = 1 60 | else: 61 | feat[11] = 1 62 | charge = atom.formalcharge 63 | if charge > 0: 64 | feat[12] = 1 65 | elif charge < 0: 66 | feat[13] = 1 67 | features.append(feat) 68 | return features 69 | 70 | 71 | def get_bond_features( 72 | pbmol: pybel.Molecule, 73 | ) -> tuple[list[list[float]], tuple[list[int], list[int]]]: 74 | edge_index_row = [] 75 | edge_index_col = [] 76 | edge_attr = [] 77 | obmol: ob.OBMol = pbmol.OBMol 78 | for obbond in ob.OBMolBondIter(obmol): 79 | obbond: ob.OBBond 80 | edge_index_row.append(obbond.GetBeginAtomIdx() - 1) 81 | edge_index_col.append(obbond.GetEndAtomIdx() - 1) 82 | 83 | feat = [0] * NUM_BOND_FEATURES 84 | if obbond.IsAromatic(): 85 | feat[3] = 1 86 | else: 87 | feat[BOND_DICT.get(obbond.GetBondOrder(), 4)] = 1 88 | edge_attr.append(feat) 89 | edge_index = (edge_index_row, edge_index_col) 90 | return edge_attr, edge_index 91 | 
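A tiny usage sketch for the featurizer above (run from the `developer/` directory, since the scripts import via `src.*`; the SMILES is an arbitrary example):

```python
# Sketch: featurize one SMILES with the helpers above (arbitrary molecule).
from src.data import smi2graph

data = smi2graph("CC(=O)Oc1ccccc1C(=O)O")  # aspirin
print(data.x.shape)           # [num_heavy_atoms, 14]: element one-hot + stereo + formal charge
print(data.edge_index.shape)  # [2, num_bonds]
print(data.edge_attr.shape)   # [num_bonds, 5]: bond order / aromatic / unknown
```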
-------------------------------------------------------------------------------- /developer/src/dataset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch import Tensor 6 | from torch.utils.data import Dataset 7 | from torch_geometric.data import Batch, Data 8 | 9 | from pmnet.api import ProteinParser 10 | 11 | from .data import smi2graphdata 12 | 13 | 14 | class BaseDataset(Dataset): 15 | def __init__( 16 | self, 17 | code_list: list[str], 18 | protein_info: dict[str, tuple[float, float, float]], 19 | protein_dir: Path | str, 20 | ligand_path: Path | str, 21 | center_noise: float = 0.0, 22 | ): 23 | self.parser: ProteinParser = ProteinParser(center_noise) 24 | 25 | self.code_list: list[str] = code_list 26 | self.protein_info = protein_info 27 | self.protein_dir = Path(protein_dir) 28 | self.center_noise = center_noise 29 | with open(ligand_path, "rb") as f: 30 | self.ligand_data: dict[str, list[tuple[str, str, float]]] = pickle.load(f) 31 | 32 | def __len__(self): 33 | return len(self.code_list) 34 | 35 | def __getitem__(self, index: int) -> tuple[tuple[Tensor, Tensor, Tensor, Tensor], Batch]: 36 | code = self.code_list[index] 37 | protein_path: str = str(self.protein_dir / f"{code}.pdb") 38 | center: tuple[float, float, float] = self.protein_info[code] 39 | pharmacophore_info = self.parser(protein_path, center=center) 40 | ligands = self.ligand_data[code] 41 | ligand_graphs: Batch = Batch.from_data_list(list(map(self.get_ligand_data, ligands))) 42 | return pharmacophore_info, ligand_graphs 43 | 44 | @staticmethod 45 | def get_ligand_data(args: tuple[str, str, float]) -> Data: 46 | ligand_id, smiles, affinity = args 47 | data = smi2graphdata(smiles) 48 | x, edge_index, edge_attr = data["x"], data["edge_index"], data["edge_attr"] 49 | affinity = min(float(affinity), 0.0) 50 | return Data( 51 | x, 52 | edge_index, 53 | edge_attr, 54 | affinity=torch.FloatTensor([affinity]), 55 | ) 56 | -------------------------------------------------------------------------------- /developer/src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from omegaconf import DictConfig 4 | from torch import Tensor 5 | 6 | from pmnet.api import PharmacoNet, get_pmnet_dev 7 | 8 | from .config import Config 9 | from .data import NUM_ATOM_FEATURES, NUM_BOND_FEATURES 10 | from .network import AffinityHead, GraphEncoder, PharmacophoreEncoder 11 | 12 | Cache = tuple[Tensor, Tensor, Tensor] 13 | 14 | 15 | class AffinityModel(nn.Module): 16 | def __init__(self, config: Config | DictConfig): 17 | super().__init__() 18 | self.pmnet: PharmacoNet = get_pmnet_dev() 19 | self.global_cfg = config 20 | self.cfg = config.model 21 | self.pharmacophore_encoder: PharmacophoreEncoder = PharmacophoreEncoder(self.cfg.hidden_dim) 22 | self.ligand_encoder: GraphEncoder = GraphEncoder( 23 | NUM_ATOM_FEATURES, 24 | NUM_BOND_FEATURES, 25 | self.cfg.hidden_dim, 26 | self.cfg.hidden_dim, 27 | self.cfg.ligand_num_convs, 28 | ) 29 | self.head: AffinityHead = AffinityHead(self.cfg.hidden_dim) 30 | self.l2_loss: nn.MSELoss = nn.MSELoss() 31 | self.initialize_weights() 32 | 33 | def initialize_weights(self): 34 | self.pharmacophore_encoder.initialize_weights() 35 | self.ligand_encoder.initialize_weights() 36 | self.head.initialize_weights() 37 | 38 | # NOTE: Model training 39 | def forward_train(self, batch) -> Tensor: 40 | if self.pmnet.device != self.device: 41 | self.pmnet.to(self.device) 42 | 43 | loss_list = [] 44 | for pharmacophore_info, ligand_graphs in batch: 45 | # NOTE: Run PharmacoNet Feature Extraction 46 | # (Model is frozen; method `run_extraction` is decorated by torch.no_grad()) 47 | pmnet_attr = self.pmnet.run_extraction(pharmacophore_info) 48 | del pharmacophore_info 49 | 50 | # NOTE: Binding Affinity Prediction 51 | x_protein, pos_protein, Z_protein = self.pharmacophore_encoder.forward(pmnet_attr) 52 | x_ligand = self.ligand_encoder.forward(ligand_graphs.to(self.device)) 53 | affinity = self.head.forward(x_protein, x_ligand, ligand_graphs.batch, ligand_graphs.num_graphs) 54 | 55 | loss_list.append(self.l2_loss.forward(affinity, ligand_graphs.affinity)) 56 | loss = torch.stack(loss_list).mean() 57 | return loss 58 | 59 | @property 60 | def device(self) -> torch.device: 61 | return next(self.parameters()).device 62 | 
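`AffinityModel` delegates scoring to `AffinityHead` (defined in `head.py` below), which builds one feature vector per (ligand atom, protein hotspot) pair via an outer product over the feature dimension and then pads the ligand axis with `to_dense_batch`. A toy shape check with made-up sizes:

```python
# Toy shape check (made-up sizes) of the pairwise trick used in AffinityHead.
import torch
from torch_geometric.utils import to_dense_batch

Fh = 8
x_ligand = torch.randn(5, Fh)                 # 5 ligand atoms across 2 molecules
x_protein = torch.randn(3, Fh)                # 3 protein hotspots
ligand_batch = torch.tensor([0, 0, 0, 1, 1])  # molecule id per ligand atom

Z = torch.einsum("ik,jk->ijk", x_ligand, x_protein)  # [5, 3, Fh] pair features
Z_dense, mask = to_dense_batch(Z, ligand_batch, batch_size=2)
print(Z_dense.shape)  # torch.Size([2, 3, 3, 8]) -> [mol, max_atoms, hotspots, Fh]
print(mask.shape)     # torch.Size([2, 3])      -> which atom slots are real
```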
-------------------------------------------------------------------------------- /developer/src/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .head import AffinityHead 2 | from .ligand_encoder import GraphEncoder 3 | from .pharmacophore_encoder import PharmacophoreEncoder 4 | -------------------------------------------------------------------------------- /developer/src/network/head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | from torch_geometric.utils import to_dense_batch 4 | 5 | 6 | class AffinityHead(nn.Module): 7 | def __init__(self, hidden_dim: int, p_dropout: float = 0.1): 8 | super().__init__() 9 | self.interaction_mlp: nn.Module = nn.Sequential( 10 | nn.Linear(hidden_dim, hidden_dim), 11 | nn.LeakyReLU(), 12 | ) 13 | self.mlp_affinity: nn.Module = nn.Sequential( 14 | nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(), nn.Linear(hidden_dim, 1) 15 | ) 16 | self.dropout = nn.Dropout(p_dropout) 17 | 18 | def initialize_weights(self): 19 | def _init_weight(m): 20 | if isinstance(m, (nn.Linear)): 21 | nn.init.xavier_normal_(m.weight) 22 | if m.bias is not None: 23 | nn.init.constant_(m.bias, 0) 24 | 25 | self.apply(_init_weight) 26 | 27 | def forward( 28 | self, 29 | x_protein: Tensor, 30 | x_ligand: Tensor, 31 | ligand_batch: Tensor, 32 | num_ligands: int, 33 | ) -> Tensor: 34 | """ 35 | affinity prediction head for one protein and multiple ligands 36 | output: (N_ligand,) 37 | """ 38 | Z_complex = torch.einsum("ik,jk->ijk", x_ligand, x_protein) # [Vlig, Vprot, Fh] 39 | Z_complex, mask_complex = to_dense_batch(Z_complex, ligand_batch, batch_size=num_ligands) 40 | mask_complex = mask_complex.unsqueeze(-1) # [N, Vlig, 1] 41 | Z_complex = self.interaction_mlp(self.dropout(Z_complex)) 42 | pair_affinity = self.mlp_affinity(Z_complex).squeeze(-1) * mask_complex 43 | return pair_affinity.sum((1, 2)) 44 | -------------------------------------------------------------------------------- /developer/src/network/ligand_encoder.py: -------------------------------------------------------------------------------- 1 | import torch_geometric.nn as gnn 2 | from torch import Tensor, nn 3 | from torch_geometric.data import Batch, Data 4 | 5 | 6 | class GraphEncoder(nn.Module): 7 | def __init__( 8 | self, 9 | input_node_dim: int, 10 | input_edge_dim: int, 11 | hidden_dim: int, 12 | out_dim: int, 13 | num_convs: int, 14 | ): 15 | super().__init__() 16 | self.graph_channels: int = out_dim 17 | self.atom_channels: int = out_dim 18 | 19 | # Ligand Encoding 20 | self.node_layer = nn.Linear(input_node_dim, hidden_dim) 21 | self.edge_layer = nn.Linear(input_edge_dim, hidden_dim) 22 | self.conv_list = nn.ModuleList( 23 | [ 24 | gnn.GINEConv( 25 | nn=nn.Sequential(gnn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()), 26 | edge_dim=hidden_dim, 27 | ) 28 | for _ in range(num_convs) 29 | ] 30 | ) 31 | 32 | self.head = nn.Sequential(nn.Linear(hidden_dim, out_dim), nn.LayerNorm(out_dim)) 33 | 34 | def initialize_weights(self): 35 | def _init_weight(m): 36 | if isinstance(m, (nn.Linear)): 37 | nn.init.xavier_normal_(m.weight) 38 | if m.bias is not None: 39 | nn.init.constant_(m.bias, 0) 40 | elif isinstance(m, nn.Embedding): 41 | m.weight.data.uniform_(-1, 1) 42 | 43 | self.apply(_init_weight) 44 | 45 | def forward(self, data: Data | Batch) -> Tensor: 46 | """Encode a ligand graph into per-atom embeddings. 47 | 48 | Args: 49 | data: `Data` or `Batch` holding node features `x`, 50 | edge features `edge_attr`, 51 | and connectivity `edge_index` 52 | 53 | Returns: 54 | Tensor of shape [num_atoms, out_dim] (atom-wise embeddings) 55 | """ 56 | x: Tensor = self.node_layer(data.x) 57 | edge_attr: Tensor = self.edge_layer(data.edge_attr) 58 | 59 | skip_x = x 60 | edge_index = data.edge_index 61 | for layer in self.conv_list: 62 | x = layer(x, edge_index, edge_attr) 63 | x = skip_x + x 64 | return self.head(x) 65 | -------------------------------------------------------------------------------- /developer/src/network/pharmacophore_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | from pmnet.api.typing import HotspotInfo, MultiScaleFeature 5 | 6 | 7 | class PharmacophoreEncoder(nn.Module): 8 | def __init__(self, hidden_dim: int): 9 | super().__init__() 10 | self.multi_scale_dims = [96, 96, 96, 96, 96] 11 | self.hotspot_dim = 192 12 | self.hidden_dim = hidden_dim 13 | self.hotspot_mlp: nn.Module = nn.Sequential(nn.SiLU(), nn.Linear(self.hotspot_dim, hidden_dim)) 14 | self.pocket_mlp_list: nn.ModuleList = nn.ModuleList( 15 | [nn.Sequential(nn.SiLU(), nn.Conv3d(channels, hidden_dim, 3)) for channels in self.multi_scale_dims] 16 | ) 17 | self.pocket_layer: nn.Module = nn.Sequential( 18 | nn.SiLU(), 19 | nn.Linear(5 * hidden_dim, hidden_dim), 20 | nn.SiLU(), 21 | nn.Linear(hidden_dim, hidden_dim), 22 | ) 23 | 24 | def 
initialize_weights(self): 25 | def _init_weight(m): 26 | if isinstance(m, nn.Linear | nn.Conv3d): 27 | nn.init.normal_(m.weight, std=0.01) 28 | if m.bias is not None: 29 | nn.init.constant_(m.bias, 0) 30 | 31 | self.apply(_init_weight) 32 | 33 | def forward(self, pmnet_attr: tuple[MultiScaleFeature, list[HotspotInfo]]) -> tuple[Tensor, Tensor, Tensor]: 34 | """ 35 | Out: 36 | - hotspot_features: FloatTensor (V, Fh) 37 | - hotspot_positions: FloatTensor (V, 3) (* Real value.) 38 | - pocket_features: FloatTensor (Fh,) 39 | """ 40 | 41 | multi_scale_features, hotspot_infos = pmnet_attr 42 | dev = multi_scale_features[0].device 43 | 44 | # NOTE: Node features 45 | if len(hotspot_infos) > 0: 46 | hotspot_positions = torch.tensor([info["hotspot_position"] for info in hotspot_infos], device=dev) 47 | hotspot_features = torch.stack([info["hotspot_feature"] for info in hotspot_infos]) 48 | hotspot_features = self.hotspot_mlp(hotspot_features) 49 | else: 50 | hotspot_positions = torch.zeros((0, 3), device=dev) 51 | hotspot_features = torch.zeros((0, self.hidden_dim), device=dev) 52 | 53 | # NOTE: Global features 54 | pocket_features: Tensor = torch.cat( 55 | [ 56 | mlp(feat.squeeze(0)).mean((-1, -2, -3)) 57 | for mlp, feat in zip(self.pocket_mlp_list, multi_scale_features, strict=True) 58 | ], 59 | dim=-1, 60 | ) 61 | pocket_features = self.pocket_layer(pocket_features) 62 | 63 | return hotspot_features, hotspot_positions, pocket_features 64 | -------------------------------------------------------------------------------- /developer/src/trainer.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | import random 4 | import sys 5 | import time 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | import torch.multiprocessing 11 | import torch.utils.tensorboard 12 | import wandb 13 | from omegaconf import OmegaConf 14 | from torch.utils.data import DataLoader 15 | 16 | from pmnet.api import PharmacoNet 17 | 18 | from .config import Config 19 | from .dataset import BaseDataset 20 | from .model import AffinityModel 21 | 22 | torch.multiprocessing.set_sharing_strategy("file_system") 23 | 24 | 25 | class Trainer: 26 | def __init__(self, config: Config, device: str = "cuda"): 27 | self.config = config 28 | self.device = device 29 | self.log_dir = Path(config.log_dir) 30 | self.log_dir.mkdir(parents=True) 31 | self.save_dir = self.log_dir / "save" 32 | self.save_dir.mkdir(parents=True) 33 | 34 | self.dictconfig = OmegaConf.create(config.to_dict()) 35 | OmegaConf.save(self.dictconfig, self.log_dir / "config.yaml") 36 | self.logger = create_logger(logfile=self.log_dir / "train.log") 37 | if wandb.run is None: 38 | self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.log_dir) 39 | 40 | self.model = AffinityModel(config) 41 | self.model.to(device) 42 | self.pmnet: PharmacoNet = self.model.pmnet 43 | self.setup_data() 44 | self.setup_train() 45 | 46 | def fit(self): 47 | it = 1 48 | epoch = 0 49 | best_loss = float("inf") 50 | self.model.train() 51 | while it <= self.config.train.max_iterations: 52 | for batch in self.train_dataloader: 53 | if it > self.config.train.max_iterations: 54 | break 55 | if it % 1024 == 0: 56 | gc.collect() 57 | torch.cuda.empty_cache() 58 | 59 | tick = time.time() 60 | info = self.train_batch(batch) 61 | info["time"] = time.time() - tick 62 | 63 | if it % self.config.train.print_every == 0: 64 | self.logger.info( 65 | f"epoch {epoch} iteration {it} train : " + " 
".join(f"{k}:{v:.2f}" for k, v in info.items()) 66 | ) 67 | if it % self.config.train.log_every == 0: 68 | self.log(info, it, epoch, "train") 69 | if it % self.config.train.save_every == 0: 70 | self.save_checkpoint(f"epoch-{epoch}-it-{it}.pth") 71 | if it % self.config.train.val_every == 0: 72 | tick = time.time() 73 | info = self.evaluate() 74 | info["time"] = time.time() - tick 75 | self.logger.info( 76 | f"epoch {epoch} iteration {it} valid : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items()) 77 | ) 78 | self.log(info, it, epoch, "valid") 79 | if info["loss"] < best_loss: 80 | torch.save(self.model.state_dict(), self.save_dir / "best.pth") 81 | best_loss = info["loss"] 82 | it += 1 83 | epoch += 1 84 | torch.save(self.model.state_dict(), self.save_dir / "last.pth") 85 | 86 | def log(self, info, index, epoch, key): 87 | info.update({"step": index, "epoch": epoch}) 88 | if wandb.run is not None: 89 | wandb.log({f"{key}/{k}": v for k, v in info.items()}, step=index) 90 | else: 91 | for k, v in info.items(): 92 | self._summary_writer.add_scalar(f"{key}/{k}", v, index) 93 | 94 | def train_batch(self, batch) -> dict[str, float]: 95 | loss = self.model.forward_train(batch) 96 | loss.backward() 97 | torch.nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), self.config.train.opt.clip_grad) 98 | self.optimizer.step() 99 | self.optimizer.zero_grad() 100 | self.lr_scheduler.step() 101 | return {"loss": loss.item()} 102 | 103 | @torch.no_grad() 104 | def evaluate(self) -> dict[str, float]: 105 | self.model.eval() 106 | logs = {"loss": []} 107 | for batch in self.val_dataloader: 108 | loss = self.model.forward_train(batch) 109 | logs["loss"].append(loss.item()) 110 | self.model.train() 111 | return {k: float(np.mean(v)) for k, v in logs.items()} 112 | 113 | def setup_data(self): 114 | config = self.config 115 | protein_info = {} 116 | with open(config.data.protein_info_path) as f: 117 | lines = f.readlines() 118 | for line in lines: 119 | code, x, y, z = line.strip().split(",") 120 | protein_info[code] = (float(x), float(y), float(z)) 121 | 122 | with open(config.data.train_protein_code_path) as f: 123 | codes = [ln.strip() for ln in f.readlines()] 124 | random.seed(0) 125 | random.shuffle(codes) 126 | split_offset = int(len(codes) * config.train.split_ratio) 127 | train_codes = codes[:split_offset] 128 | val_codes = codes[split_offset:] 129 | 130 | self.train_dataset = BaseDataset( 131 | train_codes, 132 | protein_info, 133 | config.data.protein_dir, 134 | config.data.ligand_path, 135 | config.train.center_noise, 136 | ) 137 | 138 | self.val_dataset = BaseDataset( 139 | val_codes, 140 | protein_info, 141 | config.data.protein_dir, 142 | config.data.ligand_path, 143 | ) 144 | 145 | self.train_dataloader: DataLoader = DataLoader( 146 | self.train_dataset, 147 | batch_size=config.train.batch_size, 148 | shuffle=True, 149 | num_workers=config.train.num_workers, 150 | drop_last=True, 151 | collate_fn=collate_fn, 152 | ) 153 | 154 | self.val_dataloader: DataLoader = DataLoader( 155 | self.val_dataset, 156 | batch_size=config.train.batch_size, 157 | shuffle=False, 158 | num_workers=config.train.num_workers, 159 | collate_fn=collate_fn, 160 | ) 161 | 162 | self.logger.info(f"train set: {len(self.train_dataset)}") 163 | self.logger.info(f"valid set: {len(self.val_dataset)}") 164 | 165 | def setup_train(self): 166 | self.optimizer = torch.optim.Adam( 167 | self.model.parameters(), 168 | self.config.train.opt.lr, 169 | eps=self.config.train.opt.eps, 170 | ) 171 | 172 | self.lr_scheduler = 
torch.optim.lr_scheduler.LambdaLR( 173 | self.optimizer, 174 | lr_lambda=lambda steps: 2 ** (-steps / self.config.train.lr_scheduler.lr_decay), 175 | ) 176 | 177 | def save_checkpoint(self, filename: str): 178 | ckpt = { 179 | "model_state_dict": self.model.state_dict(), 180 | "config": self.dictconfig, 181 | } 182 | torch.save(ckpt, self.save_dir / filename) 183 | 184 | 185 | def collate_fn(batch): 186 | return batch 187 | 188 | 189 | def create_logger(name="logger", loglevel=logging.INFO, logfile=None, streamHandle=True): 190 | logger = logging.getLogger(name) 191 | logger.setLevel(loglevel) 192 | formatter = logging.Formatter( 193 | fmt="%(asctime)s - %(message)s", 194 | datefmt="%d/%m/%Y %H:%M:%S", 195 | ) 196 | handlers = [] 197 | if logfile is not None: 198 | handlers.append(logging.FileHandler(logfile, mode="a")) 199 | if streamHandle: 200 | handlers.append(logging.StreamHandler(stream=sys.stdout)) 201 | 202 | for handler in logger.handlers[:]: 203 | logger.removeHandler(handler) 204 | 205 | for handler in handlers: 206 | handler.setFormatter(formatter) 207 | logger.addHandler(handler) 208 | 209 | return logger 210 | -------------------------------------------------------------------------------- /developer/train_example.py: -------------------------------------------------------------------------------- 1 | from src.config import Config 2 | from src.trainer import Trainer 3 | 4 | config = Config() 5 | config.data.protein_info_path = "./dataset/protein_info.csv" # CSV lines: code,x,y,z (binding-site centers) 6 | config.data.protein_dir = "./dataset/protein/" 7 | config.data.train_protein_code_path = "./dataset/train_key.txt" 8 | config.data.ligand_path = "./dataset/ligand.pkl" 9 | config.train.max_iterations = 100 10 | config.train.batch_size = 16 11 | config.train.num_workers = 4 12 | config.train.log_every = 1 13 | config.train.print_every = 1 14 | config.train.val_every = 10 15 | config.log_dir = "./result/debug" 16 | trainer = Trainer(config, device="cuda") 17 | trainer.fit() 18 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pmnet 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - python=3.11 6 | - pip>=24.0 7 | - pymol-open-source 8 | -------------------------------------------------------------------------------- /examples/6OIM_D_MOV.pdb: -------------------------------------------------------------------------------- 1 | CRYST1 40.868 58.417 65.884 90.00 90.00 90.00 P 21 21 21 1 2 | HETATM 1 C1 MOV A 303 1.642 -7.717 -1.656 1.00 25.39 C 3 | HETATM 2 N1 MOV A 303 3.595 -9.086 -1.744 1.00 25.47 N 4 | HETATM 3 O1 MOV A 303 -0.606 -11.017 -1.525 1.00 29.03 O 5 | HETATM 4 C2 MOV A 303 2.258 -8.966 -1.680 1.00 25.81 C 6 | HETATM 5 N2 MOV A 303 -0.516 -6.457 -1.522 1.00 25.26 N 7 | HETATM 6 O2 MOV A 303 -2.563 -2.305 0.154 1.00 25.22 O 8 | HETATM 7 C3 MOV A 303 4.408 -8.024 -1.795 1.00 26.04 C 9 | HETATM 8 N3 MOV A 303 -0.485 -8.794 -1.541 1.00 26.50 N 10 | HETATM 9 O3 MOV A 303 5.556 -9.536 -3.856 1.00 29.54 O 11 | HETATM 10 C4 MOV A 303 3.854 -6.742 -1.782 1.00 25.72 C 12 | HETATM 11 N4 MOV A 303 1.468 -10.112 -1.635 1.00 26.41 N 13 | HETATM 12 C5 MOV A 303 2.467 -6.591 -1.708 1.00 25.83 C 14 | HETATM 13 N5 MOV A 303 2.882 -13.267 -0.492 1.00 28.14 N 15 | HETATM 14 C6 MOV A 303 5.789 -8.217 -1.865 1.00 26.76 C 16 | HETATM 15 N6 MOV A 303 -2.128 -4.145 -1.011 1.00 25.51 N 17 | HETATM 16 C7 MOV A 303 0.232 -7.655 -1.586 1.00 25.83 C 18 | HETATM 17 C8 MOV A 303 0.088 -10.003 -1.564 1.00 27.12 C 19 | HETATM 18 C9 MOV A 303 2.047 -11.378 
-1.648 1.00 27.82 C 20 | HETATM 19 C10 MOV A 303 2.320 -12.047 -0.454 1.00 27.79 C 21 | HETATM 20 C11 MOV A 303 3.178 -13.859 -1.638 1.00 28.92 C 22 | HETATM 21 C12 MOV A 303 2.924 -13.240 -2.857 1.00 28.92 C 23 | HETATM 22 C13 MOV A 303 2.347 -11.976 -2.860 1.00 28.87 C 24 | HETATM 23 C14 MOV A 303 2.006 -11.420 0.913 1.00 27.52 C 25 | HETATM 24 C15 MOV A 303 3.303 -11.287 1.711 1.00 27.35 C 26 | HETATM 25 C16 MOV A 303 0.986 -12.270 1.683 1.00 27.17 C 27 | HETATM 26 C17 MOV A 303 0.121 -5.160 -1.149 1.00 25.83 C 28 | HETATM 27 C18 MOV A 303 -0.822 -4.289 -0.337 1.00 25.87 C 29 | HETATM 28 C19 MOV A 303 -2.541 -5.147 -2.020 1.00 25.62 C 30 | HETATM 29 C20 MOV A 303 -1.985 -6.526 -1.667 1.00 25.42 C 31 | HETATM 30 C21 MOV A 303 -2.610 -7.064 -0.371 1.00 25.93 C 32 | HETATM 31 C22 MOV A 303 2.037 -11.248 -4.165 1.00 29.50 C 33 | HETATM 32 C23 MOV A 303 -2.919 -3.127 -0.691 1.00 25.94 C 34 | HETATM 33 C24 MOV A 303 -4.277 -2.959 -1.369 1.00 25.39 C 35 | HETATM 34 C25 MOV A 303 -5.364 -2.168 -0.643 1.00 25.97 C 36 | HETATM 35 C26 MOV A 303 6.343 -8.975 -2.901 1.00 27.88 C 37 | HETATM 36 C27 MOV A 303 7.719 -9.162 -2.966 1.00 27.85 C 38 | HETATM 37 C28 MOV A 303 8.550 -8.602 -2.005 1.00 27.87 C 39 | HETATM 38 C29 MOV A 303 8.011 -7.850 -0.968 1.00 27.43 C 40 | HETATM 39 C30 MOV A 303 6.635 -7.660 -0.895 1.00 26.79 C 41 | HETATM 40 F1 MOV A 303 4.658 -5.663 -1.836 1.00 26.54 F 42 | HETATM 41 F2 MOV A 303 6.144 -6.931 0.129 1.00 26.28 F 43 | CONECT 1 12 4 16 44 | CONECT 2 7 4 45 | CONECT 3 17 46 | CONECT 4 1 2 11 47 | CONECT 5 16 26 29 48 | CONECT 6 32 49 | CONECT 7 2 14 10 50 | CONECT 8 17 16 51 | CONECT 9 35 52 | CONECT 10 7 40 12 53 | CONECT 11 4 18 17 54 | CONECT 12 1 10 55 | CONECT 13 19 20 56 | CONECT 14 7 35 39 57 | CONECT 15 27 32 28 58 | CONECT 16 5 1 8 59 | CONECT 17 8 3 11 60 | CONECT 18 11 19 22 61 | CONECT 19 13 18 23 62 | CONECT 20 13 21 63 | CONECT 21 20 22 64 | CONECT 22 18 21 31 65 | CONECT 23 19 24 25 66 | CONECT 24 23 67 | CONECT 25 23 68 | CONECT 26 5 27 69 | CONECT 27 15 26 70 | CONECT 28 15 29 71 | CONECT 29 5 28 30 72 | CONECT 30 29 73 | CONECT 31 22 74 | CONECT 32 15 6 33 75 | CONECT 33 32 34 76 | CONECT 34 33 77 | CONECT 35 14 9 36 78 | CONECT 36 35 37 79 | CONECT 37 36 38 80 | CONECT 38 37 39 81 | CONECT 39 14 38 41 82 | CONECT 40 10 83 | CONECT 41 39 84 | END 85 | -------------------------------------------------------------------------------- /feature_extraction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | from pmnet.api import get_pmnet_dev 6 | 7 | 8 | class ArgParser(argparse.ArgumentParser): 9 | def __init__(self): 10 | super().__init__("PharmacoNet Feature Extraction Script") 11 | self.formatter_class = argparse.ArgumentDefaultsHelpFormatter 12 | self.add_argument( 13 | "-p", 14 | "--protein", 15 | type=str, 16 | help="custom path of protein pdb file (.pdb)", 17 | required=True, 18 | ) 19 | self.add_argument( 20 | "-o", 21 | "--out", 22 | type=str, 23 | help="save path of features (torch object)", 24 | required=True, 25 | ) 26 | self.add_argument( 27 | "--ref_ligand", 28 | type=str, 29 | help="path of ligand to define the center of box (.sdf, .pdb, .mol2)", 30 | ) 31 | self.add_argument("--center", nargs="+", type=float, help="coordinate of the center") 32 | self.add_argument("--cuda", action="store_true", help="use gpu acceleration with CUDA") 33 | 34 | 35 | def main(args): 36 | """ 37 | return tuple[multi_scale_features, hotspot_info] 38 | multi_scale_features: 
list[torch.Tensor]: 39 | - [96, 4, 4, 4], [96, 8, 8, 8], [96, 16, 16, 16], [96, 32, 32, 32], [96, 64, 64, 64] 40 | hotspot_info: list of dicts, each with 41 | - hotspot_feature: torch.Tensor (192,) 42 | - hotspot_position: tuple[float, float, float] - (x, y, z) 43 | - hotspot_score: float in [0, 1] 44 | 45 | - nci_type: str (10 types) 46 | 'Hydrophobic': Hydrophobic interaction 47 | 'PiStacking_P': PiStacking (Parallel) 48 | 'PiStacking_T': PiStacking (T-shaped) 49 | 'PiCation_lring': Interaction btw Protein Cation & Ligand Aromatic Ring 50 | 'PiCation_pring': Interaction btw Protein Aromatic Ring & Ligand Cation 51 | 'SaltBridge_pneg': SaltBridge btw Protein Anion & Ligand Cation 52 | 'SaltBridge_lneg': SaltBridge btw Protein Cation & Ligand Anion 53 | 'XBond': Halogen Bond 54 | 'HBond_pdon': Hydrogen Bond btw Protein Donor & Ligand Acceptor 55 | 'HBond_ldon': Hydrogen Bond btw Protein Acceptor & Ligand Donor 56 | 57 | - hotspot_type: str (7 types) 58 | {'Hydrophobic', 'Aromatic', 'Cation', 'Anion', 59 | 'Halogen', 'HBond_donor', 'HBond_acceptor'} 60 | *** `hotspot_type` is derived from `nci_type`. 61 | - point_type: str (7 types) 62 | {'Hydrophobic', 'Aromatic', 'Cation', 'Anion', 63 | 'Halogen', 'HBond_donor', 'HBond_acceptor'} 64 | *** `point_type` is derived from `nci_type`. 65 | 66 | """ 67 | device = "cuda" if args.cuda else "cpu" 68 | module = get_pmnet_dev(device) 69 | multi_scale_features, hotspot_infos = module.feature_extraction(args.protein, args.ref_ligand, args.center) 70 | torch.save([multi_scale_features, hotspot_infos], args.out) 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = ArgParser() 75 | args = parser.parse_args() 76 | main(args) 77 | 
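A sketch of consuming the saved file; `out.pt` stands in for whatever path was passed to `--out`:

```python
# Sketch: load the saved features and check them against the docstring above.
import torch

multi_scale_features, hotspot_infos = torch.load("out.pt")  # path given to --out
for feat in multi_scale_features:
    print(tuple(feat.shape))  # (96, 4, 4, 4) ... (96, 64, 64, 64)
for info in hotspot_infos[:3]:
    print(info["nci_type"], info["hotspot_score"], info["hotspot_position"])
```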
default: {PDBID}") 35 | cfg_args.add_argument( 36 | "--suffix", 37 | choices=("pm", "json"), 38 | type=str, 39 | help="extension of pharmacophore model (pm (default) | json)", 40 | default="pm", 41 | ) 42 | 43 | # system config 44 | env_args = self.add_argument_group("environment") 45 | env_args.add_argument("--weight_path", type=str, help="(Optional) custom pharmaconet weight path") 46 | env_args.add_argument("--cuda", action="store_true", help="use gpu acceleration with CUDA") 47 | env_args.add_argument("--force", action="store_true", help="force to save the pharmacophore model") 48 | env_args.add_argument("-v", "--verbose", action="store_true", help="verbose") 49 | 50 | # config 51 | adv_args = self.add_argument_group("Advanced Setting") 52 | adv_args.add_argument( 53 | "--ref_ligand", 54 | type=str, 55 | help="path of ligand to define the center of box (.sdf, .pdb, .mol2)", 56 | ) 57 | adv_args.add_argument("--center", nargs="+", type=float, help="coordinate of the center") 58 | 59 | 60 | def main(args): 61 | logging.info(pmnet.__description__) 62 | assert args.prefix is not None or args.pdb is not None, "MISSING PREFIX: `--prefix` or `--pdb`" 63 | PREFIX = args.prefix if args.prefix else args.pdb 64 | 65 | # NOTE: Setting 66 | if args.out_dir is None: 67 | SAVE_DIR = Path("./result") / PREFIX 68 | else: 69 | SAVE_DIR = Path(args.out_dir) 70 | SAVE_DIR.mkdir(exist_ok=True, parents=True) 71 | 72 | # NOTE: Load PharmacoNet 73 | module = PharmacoNet("cuda" if args.cuda else "cpu", weight_path=args.weight_path) 74 | logging.info("Load PharmacoNet finish") 75 | 76 | # NOTE: Set Protein 77 | protein_path: str 78 | if isinstance(args.pdb, str): 79 | protein_path = str(SAVE_DIR / f"{PREFIX}.pdb") 80 | if not os.path.exists(protein_path): 81 | logging.info(f"Download {args.pdb} to {protein_path}") 82 | download_pdb(args.pdb, protein_path) 83 | else: 84 | logging.info(f"Load {protein_path}") 85 | elif isinstance(args.protein, str): 86 | protein_path = args.protein 87 | assert os.path.exists(protein_path) 88 | logging.info(f"Load {protein_path}") 89 | else: 90 | raise Exception("Missing protein: `--pdb` or `--protein`") 91 | 92 | # NOTE: Functions 93 | def run_pmnet(filename, ligand_path=None, center=None) -> PharmacophoreModel: 94 | model_path = SAVE_DIR / f"{filename}.{args.suffix}" 95 | pymol_path = SAVE_DIR / f"{filename}_pymol.pse" 96 | if (not args.force) and os.path.exists(model_path): 97 | logging.warning(f"Modeling Pass - {model_path} exists") 98 | pharmacophore_model = PharmacophoreModel.load(str(model_path)) 99 | else: 100 | pharmacophore_model = module.run(protein_path, ref_ligand_path=ligand_path, center=center) 101 | pharmacophore_model.save(str(model_path)) 102 | logging.info(f"Save Pharmacophore Model to {model_path}") 103 | if (not args.force) and os.path.exists(pymol_path): 104 | logging.warning(f"Visualizing Pass - {pymol_path} exists\n") 105 | else: 106 | visualize.visualize_single(pharmacophore_model, protein_path, ligand_path, PREFIX, str(pymol_path)) 107 | logging.info(f"Save Pymol Visualization Session to {pymol_path}\n") 108 | return pharmacophore_model 109 | 110 | def run_pmnet_ref_ligand(ligand_path) -> PharmacophoreModel: 111 | logging.info(f"Using center of {ligand_path} as center of box") 112 | return run_pmnet(f"{PREFIX}_{Path(ligand_path).stem}_model", ligand_path) 113 | 114 | def run_pmnet_center(center) -> PharmacophoreModel: 115 | x, y, z = center 116 | logging.info(f"Using center {(x, y, z)}") 117 | return run_pmnet(f"{PREFIX}_{x}_{y}_{z}_model", center=(x, y, 
z)) 118 | 119 | def run_pmnet_inform(inform) -> PharmacophoreModel: 120 | logging.info(f"Running {inform.order}th Ligand...\n{str(inform)}") 121 | return run_pmnet( 122 | f"{PREFIX}_{inform.pdbchain}_{inform.id}_model", 123 | inform.file_path, 124 | inform.center, 125 | ) 126 | 127 | def run_pmnet_manual_center(): 128 | logging.info("Enter the center of binding site manually:") 129 | x = float(input("x: ")) 130 | y = float(input("y: ")) 131 | z = float(input("z: ")) 132 | return run_pmnet_center((x, y, z)) 133 | 134 | ############ 135 | # NOTE: Run!! 136 | 137 | # NOTE: Case 1: With Custom Autobox Ligand Center 138 | if args.ref_ligand is not None: 139 | assert os.path.exists( 140 | args.ref_ligand 141 | ), f"Wrong path: the reference ligand file does not exist ({args.ref_ligand})" 142 | run_pmnet_ref_ligand(args.ref_ligand) 143 | return SUCCESS 144 | 145 | # NOTE: Case 2: With Custom Center 146 | if args.center is not None: 147 | assert ( 148 | len(args.center) == 3 149 | ), "Wrong center: exactly 3 coordinates are required (e.g. --center 1.00 2.00 -1.50)" 150 | run_pmnet_center(args.center) 151 | return SUCCESS 152 | 153 | # NOTE: Case 3: With Detected Ligand(s) Center 154 | # NOTE: Ligand Detection 155 | inform_list = parse_pdb(PREFIX, protein_path, SAVE_DIR) 156 | 157 | # NOTE: Case 3-1: No detected Ligand 158 | if len(inform_list) == 0: 159 | logging.warning("No ligand is detected!") 160 | run_pmnet_manual_center() 161 | return SUCCESS 162 | 163 | # NOTE: Case 3-2: with `all` option 164 | if args.all: 165 | logging.info("Use All Binding Site (-a | --all)") 166 | model_dict = {} 167 | for inform in inform_list: 168 | model_dict[f"{PREFIX}_{inform.pdbchain}_{inform.id}"] = ( 169 | run_pmnet_inform(inform), 170 | inform.file_path, 171 | ) 172 | pymol_path = SAVE_DIR / f"{PREFIX}.pse" 173 | logging.info("Visualize all pharmacophore models...") 174 | if (not args.force) and os.path.exists(pymol_path): 175 | logging.warning(f"Visualizing Pass - {pymol_path} exists\n") 176 | else: 177 | visualize.visualize_multiple(model_dict, protein_path, PREFIX, str(pymol_path)) 178 | logging.info(f"Save Pymol Visualization Session to {pymol_path}\n") 179 | return SUCCESS 180 | 181 | inform_list_text = "\n\n".join(str(inform) for inform in inform_list) 182 | logging.info(f"A total of {len(inform_list)} ligand(s) are detected!\n{inform_list_text}\n") 183 | 184 | # NOTE: Case 3-3: pattern matching 185 | if args.ligand_id is not None or args.chain is not None: 186 | logging.info(f"Filtering with matching pattern - ligand id: {args.ligand_id}, chain: {args.chain}") 187 | filtered_inform_list = [] 188 | for inform in inform_list: 189 | if args.ligand_id is not None and args.ligand_id.upper() != inform.id: 190 | continue 191 | if args.chain is not None and args.chain.upper() not in [ 192 | inform.pdbchain, 193 | inform.authchain, 194 | ]: 195 | continue 196 | filtered_inform_list.append(inform) 197 | inform_list = filtered_inform_list 198 | del filtered_inform_list 199 | 200 | if len(inform_list) == 0: 201 | logging.warning("No matching pattern!") 202 | return FAIL 203 | if len(inform_list) > 1: 204 | inform_list_text = "\n\n".join(str(inform) for inform in inform_list) 205 | logging.info(f"A total of {len(inform_list)} ligands are selected!\n{inform_list_text}\n") 206 | 207 | if len(inform_list) == 1: 208 | run_pmnet_inform(inform_list[0]) 209 | return SUCCESS 210 | 211 | logging.info( 212 | f"Select the ligand number(s) (ex. 
{inform_list[-1].order} ; {inform_list[0].order},{inform_list[-1].order} ; manual ; all ; exit)" 213 | ) 214 | inform_dic = {str(inform.order): inform for inform in inform_list} 215 | answer = ask_prompt(inform_dic) 216 | if answer == "exit": 217 | return EXIT 218 | if answer == "manual": 219 | run_pmnet_manual_center() 220 | return SUCCESS 221 | if answer == "all": 222 | filtered_inform_list = inform_list 223 | else: 224 | number_list = answer.split(",") 225 | filtered_inform_list = [] 226 | for number in number_list: 227 | filtered_inform_list.append(inform_dic[number.strip()]) 228 | for inform in filtered_inform_list: 229 | run_pmnet_inform(inform) 230 | return SUCCESS 231 | 232 | 233 | def ask_prompt(number_dic): 234 | flag = False 235 | while not flag: 236 | answer = input("ligand number: ") 237 | if answer in ["all", "exit", "manual"]: 238 | break 239 | number_list = answer.split(",") 240 | for number in number_list: 241 | if number.strip() not in number_dic: 242 | flag = False 243 | logging.warning(f"Invalid number: {number}") 244 | break 245 | else: 246 | flag = True 247 | return answer 248 | 249 | 250 | if __name__ == "__main__": 251 | parser = Modeling_ArgParser() 252 | args = parser.parse_args() 253 | if args.verbose: 254 | logging.basicConfig(level=logging.DEBUG) 255 | else: 256 | logging.basicConfig(level=logging.INFO) 257 | main(args) 258 | 
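The same pipeline can also be driven without the CLI; a minimal sketch using the bundled example files (the output path is a placeholder):

```python
# Sketch: pharmacophore modeling without the CLI, using the bundled examples.
from pmnet import PharmacophoreModel
from pmnet.module import PharmacoNet

module = PharmacoNet("cpu")  # or "cuda"
model = module.run(
    "examples/6OIM_protein.pdb",
    ref_ligand_path="examples/6OIM_D_MOV.pdb",  # defines the box center
)
model.save("6OIM_model.pm")  # placeholder output path
reloaded = PharmacophoreModel.load("6OIM_model.pm")
```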
69 | "__init__.py" = [ 70 | "F401", # imported but unused 71 | "E402", # Module level import not at top of file 72 | ] 73 | 74 | [tool.basedpyright] 75 | pythonVersion = "3.10" 76 | typeCheckingMode = "standard" 77 | include = ["src/"] 78 | -------------------------------------------------------------------------------- /screening.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | from functools import partial 4 | from pathlib import Path 5 | 6 | from pmnet import PharmacophoreModel 7 | 8 | 9 | class Screening_ArgParser(argparse.ArgumentParser): 10 | def __init__(self): 11 | super().__init__("scoring") 12 | self.formatter_class = argparse.ArgumentDefaultsHelpFormatter 13 | cfg_args = self.add_argument_group("config") 14 | cfg_args.add_argument( 15 | "-p", 16 | "--pharmacophore_model", 17 | type=str, 18 | help="path of pharmacophore model (.pm | .json)", 19 | required=True, 20 | ) 21 | cfg_args.add_argument( 22 | "-d", 23 | "--library_dir", 24 | type=str, 25 | help="molecular library directory path", 26 | required=True, 27 | ) 28 | cfg_args.add_argument("-o", "--out", type=str, help="result file path", required=True) 29 | cfg_args.add_argument("--cpus", type=int, help="number of cpus", default=1) 30 | 31 | param_args = self.add_argument_group("parameter") 32 | param_args.add_argument( 33 | "--hydrophobic", 34 | type=float, 35 | help="weight for hydrophobic carbon", 36 | default=1.0, 37 | ) 38 | param_args.add_argument("--aromatic", type=float, help="weight for aromatic ring", default=4.0) 39 | param_args.add_argument("--hba", type=float, help="weight for hbond acceptor", default=4.0) 40 | param_args.add_argument("--hbd", type=float, help="weight for hbond donor", default=4.0) 41 | param_args.add_argument("--halogen", type=float, help="weight for halogen atom", default=4.0) 42 | param_args.add_argument("--anion", type=float, help="weight for anion", default=8.0) 43 | param_args.add_argument("--cation", type=float, help="weight for cation", default=8.0) 44 | 45 | 46 | def func(file, model, weight): 47 | return file, model.scoring_file(file, weight) 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = Screening_ArgParser() 52 | args = parser.parse_args() 53 | model = PharmacophoreModel.load(args.pharmacophore_model) 54 | weight = dict( 55 | Cation=args.cation, 56 | Anion=args.anion, 57 | Aromatic=args.aromatic, 58 | HBond_donor=args.hbd, 59 | HBond_acceptor=args.hba, 60 | Halogen=args.halogen, 61 | Hydrophobic=args.hydrophobic, 62 | ) 63 | library_path = Path(args.library_dir) 64 | file_list = list(library_path.rglob("*.sdf")) + list(library_path.rglob("*.mol2")) 65 | print(f"find {len(file_list)} molecules") 66 | f = partial(func, model=model, weight=weight) 67 | with multiprocessing.Pool(args.cpus) as pool: 68 | result = pool.map(f, file_list) 69 | 70 | result.sort(key=lambda x: x[1], reverse=True) 71 | 72 | with open(args.out, "w") as w: 73 | w.write("path,score\n") 74 | for filename, score in result: 75 | w.write(f"{filename},{score}\n") 76 | -------------------------------------------------------------------------------- /src/pmnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .pharmacophore_model import PharmacophoreModel 2 | 3 | __version__ = "2.1.2" 4 | __citation_information__ = "Seo, S., & Kim, W. Y. (2024). PharmacoNet: deep learning-guided pharmacophore modeling for ultra-large-scale virtual screening. 
-------------------------------------------------------------------------------- /src/pmnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .pharmacophore_model import PharmacophoreModel 2 | 3 | __version__ = "2.1.2" 4 | __citation_information__ = "Seo, S., & Kim, W. Y. (2024). PharmacoNet: deep learning-guided pharmacophore modeling for ultra-large-scale virtual screening. Chemical Science, 15(46), 19473-19487." 5 | __maintainer__ = "https://github.com/SeonghwanSeo/PharmacoNet" 6 | 7 | __description__ = ( 8 | f"PharmacoNet v{__version__} - Open-source Protein-based Pharmacophore Modeling Tool\n" 9 | f"PharmacoNet is a deep learning based tool for Protein-based Pharmacophore Modeling.\n" 10 | f"If you use PharmacoNet in your work, please cite: '{__citation_information__}'\n" 11 | f"Supported and maintained by: Seonghwan Seo ({__maintainer__})\n" 12 | ) 13 | -------------------------------------------------------------------------------- /src/pmnet/api/__init__.py: -------------------------------------------------------------------------------- 1 | # NOTE: For DL Model Training 2 | __all__ = ["PharmacoNet", "ProteinParser", "get_pmnet_dev"] 3 | 4 | import torch 5 | 6 | from pmnet.data.parser import ProteinParser 7 | from pmnet.module import PharmacoNet 8 | 9 | from . import typing 10 | 11 | 12 | def get_pmnet_dev( 13 | device: str | torch.device = "cpu", 14 | score_threshold: float = 0.5, 15 | molvoxel_library: str = "numpy", 16 | compile: bool = False, 17 | ) -> PharmacoNet: 18 | """ 19 | device: 'cpu' | 'cuda' 20 | score_threshold: float | dict[str, float] | None 21 | custom threshold to identify hotspots; 22 | for feature extraction, the recommended value is 0.5 23 | molvoxel_library: str 24 | for DL model training, 'numpy' is recommended 25 | compile: bool 26 | whether to wrap `run_extraction` with torch.compile 27 | """ 28 | pm_net: PharmacoNet = PharmacoNet(device, score_threshold, False, molvoxel_library) 29 | if compile: 30 | assert torch.__version__ >= "2.0.0" 31 | pm_net.run_extraction = torch.compile(pm_net.run_extraction) 32 | return pm_net 33 | -------------------------------------------------------------------------------- /src/pmnet/api/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from torch import Tensor 4 | 5 | MultiScaleFeature = tuple[Tensor, Tensor, Tensor, Tensor, Tensor] 6 | HotspotInfo = dict[str, Any] 7 | -------------------------------------------------------------------------------- /src/pmnet/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .constant import INTERACTION_LIST 2 | from .pointcloud import PROTEIN_CHANNEL_LIST 3 | -------------------------------------------------------------------------------- /src/pmnet/data/constant.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | INTERACTION_LIST: Sequence[str] = ( 4 | "Hydrophobic", 5 | "PiStacking_P", 6 | "PiStacking_T", 7 | "PiCation_lring", 8 | "PiCation_pring", 9 | "HBond_ldon", 10 | "HBond_pdon", 11 | "SaltBridge_lneg", 12 | "SaltBridge_pneg", 13 | "XBond", 14 | ) 15 | 16 | NUM_INTERACTION_TYPES: int = 10 17 | 18 | HYDROPHOBIC = 0 19 | PISTACKING_P = 1 20 | PISTACKING_T = 2 21 | PICATION_LRING = 3 22 | PICATION_PRING = 4 23 | HBOND_LDON = 5 24 | HBOND_PDON = 6 25 | SALTBRIDGE_LNEG = 7 26 | SALTBRIDGE_PNEG = 8 27 | XBOND = 9 28 | 29 | # PLIP Distance + 0.5 A 30 | INTERACTION_DIST = { 31 | HYDROPHOBIC: 4.5, # 4.0 + 0.5 32 | PISTACKING_P: 6.0, # 5.5 + 0.5 33 | PISTACKING_T: 6.0, # 5.5 + 0.5 34 | PICATION_LRING: 6.5, # 6.0 + 0.5 35 | PICATION_PRING: 6.5, # 6.0 + 0.5 36 | HBOND_LDON: 4.5, # 4.1 + 0.5 - 0.1 (to be divisible by 0.5) 37 | HBOND_PDON: 4.5, # 4.1 + 0.5 - 0.1 38 | SALTBRIDGE_LNEG: 6.0, # 5.5 + 0.5 39 | SALTBRIDGE_PNEG: 6.0, # 5.5 + 0.5 40 | XBOND: 4.5, # 4.0 + 0.5
0.5 41 | } 42 | 43 | LONG_INTERACTION: set[int] = { 44 | PISTACKING_P, 45 | PISTACKING_T, 46 | PICATION_PRING, 47 | PICATION_LRING, 48 | SALTBRIDGE_LNEG, 49 | SALTBRIDGE_PNEG, 50 | } 51 | 52 | SHORT_INTERACTION: set[int] = { 53 | HYDROPHOBIC, 54 | HBOND_LDON, 55 | HBOND_PDON, 56 | XBOND, 57 | } 58 | -------------------------------------------------------------------------------- /src/pmnet/data/extract_pocket.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import warnings 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | from Bio.PDB import PDBIO, PDBParser 8 | from Bio.PDB.PDBIO import Select 9 | from numpy.typing import ArrayLike 10 | 11 | warnings.filterwarnings("ignore") 12 | 13 | AMINO_ACID = [ 14 | "GLY", 15 | "ALA", 16 | "VAL", 17 | "LEU", 18 | "ILE", 19 | "PRO", 20 | "PHE", 21 | "TYR", 22 | "TRP", 23 | "SER", 24 | "THR", 25 | "CYS", 26 | "MET", 27 | "ASN", 28 | "GLN", 29 | "ASP", 30 | "GLU", 31 | "LYS", 32 | "ARG", 33 | "HIS", 34 | "HIP", 35 | "HIE", 36 | "TPO", 37 | "HID", 38 | "LEV", 39 | "MEU", 40 | "PTR", 41 | "GLV", 42 | "CYT", 43 | "SEP", 44 | "HIZ", 45 | "CYM", 46 | "GLM", 47 | "ASQ", 48 | "TYS", 49 | "CYX", 50 | "GLZ", 51 | "MSE", 52 | "CSO", 53 | "KCX", 54 | "CSD", 55 | "MLY", 56 | "PCA", 57 | "LLP", 58 | ] 59 | 60 | 61 | class DistSelect(Select): 62 | def __init__(self, center, cutoff: float = 40.0): 63 | self.center = np.array(center).reshape(1, 3) 64 | self.cutoff = cutoff 65 | 66 | def accept_residue(self, residue): 67 | if super().accept_residue(residue) == 0: 68 | return 0 69 | if residue.get_resname() not in AMINO_ACID: 70 | return 0 71 | residue_positions = np.array( 72 | [list(atom.get_vector()) for atom in residue.get_atoms() if "H" not in atom.get_id()] 73 | ) 74 | if residue_positions.shape[0] == 0: 75 | return 0 76 | min_dis = np.min(np.linalg.norm(residue_positions - self.center, axis=-1)) 77 | if min_dis < self.cutoff: 78 | return 1 79 | else: 80 | return 0 81 | 82 | 83 | DEFAULT_CUTOFF = 16 * math.sqrt(3) + 5.0 84 | 85 | 86 | def extract_pocket( 87 | protein_pdb_path: str | Path, 88 | out_pocket_pdb_path: str, 89 | center: ArrayLike, 90 | cutoff: float = DEFAULT_CUTOFF, 91 | ): 92 | parser = PDBParser() 93 | structure = parser.get_structure("protein", str(protein_pdb_path)) 94 | io = PDBIO() 95 | io.set_structure(structure) 96 | io.save(out_pocket_pdb_path, DistSelect(center, cutoff)) 97 | command = f"obabel {out_pocket_pdb_path} -O {out_pocket_pdb_path} -d 2>/dev/null" 98 | os.system(command) 99 | -------------------------------------------------------------------------------- /src/pmnet/data/objects/__init__.py: -------------------------------------------------------------------------------- 1 | from .objects import Protein 2 | -------------------------------------------------------------------------------- /src/pmnet/data/objects/atom_classes.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from dataclasses import dataclass, field 3 | from functools import cached_property 4 | 5 | import numpy as np 6 | from numpy.typing import NDArray 7 | from openbabel.pybel import ob 8 | 9 | from . 
import utils 10 | 11 | Tuple3D = tuple[float, float, float] 12 | 13 | 14 | @dataclass 15 | class Point3D(Sequence): 16 | x: float 17 | y: float 18 | z: float 19 | 20 | @classmethod 21 | def from_obatom(cls, obatom: ob.OBAtom): 22 | return cls(*utils.ob_coords(obatom)) 23 | 24 | @classmethod 25 | def from_array(cls, array: NDArray): 26 | x, y, z = array 27 | return cls(x, y, z) 28 | 29 | def __array__(self): 30 | return np.array((self.x, self.y, self.z)) 31 | 32 | def __iter__(self): 33 | yield self.x 34 | yield self.y 35 | yield self.z 36 | 37 | def __getitem__(self, idx: int): 38 | if idx == 0: 39 | return self.x 40 | elif idx == 1: 41 | return self.y 42 | elif idx == 2: 43 | return self.z 44 | raise IndexError  # Sequence protocol: out-of-range indices must raise IndexError 45 | 46 | def __len__(self): 47 | return 3 48 | 49 | 50 | @dataclass 51 | class BaseInteractablePart: 52 | __small = None  # lazy cache for to_small(); class-level default avoids AttributeError on first access 53 | @property 54 | def small(self): 55 | if self.__small is None: 56 | self.__small = self.to_small() 57 | return self.__small 58 | 59 | def to_small(self): 60 | raise NotImplementedError 61 | 62 | 63 | @dataclass 64 | class BaseHydrophobicAtom(BaseInteractablePart): 65 | obatom: ob.OBAtom 66 | coords: Point3D = field(init=False) 67 | 68 | def __post_init__(self): 69 | self.coords = Point3D.from_obatom(self.obatom) 70 | 71 | @property 72 | def index(self) -> int: 73 | return self.obatom.GetIdx() - 1 74 | 75 | 76 | @dataclass 77 | class BaseHBondAcceptor(BaseInteractablePart): 78 | obatom: ob.OBAtom 79 | coords: Point3D = field(init=False) 80 | 81 | def __post_init__(self): 82 | self.coords = Point3D.from_obatom(self.obatom) 83 | 84 | @property 85 | def index(self) -> int: 86 | return self.obatom.GetIdx() - 1 87 | 88 | 89 | @dataclass 90 | class BaseHBondDonor(BaseInteractablePart): 91 | obatom: ob.OBAtom 92 | coords: Point3D = field(init=False) 93 | hydrogens: Sequence[ob.OBAtom] = field(init=False) 94 | hydrogen_coords_list: Sequence[Point3D] = field(init=False) 95 | 96 | def __post_init__(self): 97 | self.coords = Point3D.from_obatom(self.obatom) 98 | hydrogens = [neigh for neigh in ob.OBAtomAtomIter(self.obatom) if neigh.GetAtomicNum() == 1] 99 | self.hydrogens = hydrogens 100 | self.hydrogen_coords_list = [Point3D.from_obatom(h) for h in hydrogens] 101 | 102 | @property 103 | def index(self) -> int: 104 | return self.obatom.GetIdx() - 1 105 | 106 | 107 | @dataclass 108 | class BaseRing(BaseInteractablePart): 109 | obatoms: Sequence[ob.OBAtom] 110 | center: Point3D = field(init=False) 111 | normal: NDArray = field(init=False) 112 | 113 | def __post_init__(self): 114 | coords_list = np.array([utils.ob_coords(obatom) for obatom in self.obatoms]) 115 | self.center = Point3D.from_array(np.mean(coords_list, axis=0)) 116 | p1, p2, p3 = coords_list[0], coords_list[2], coords_list[4] 117 | v1, v2 = utils.vector(p1, p2), utils.vector(p1, p3) 118 | self.normal = utils.normalize(np.cross(v1, v2)) 119 | 120 | @property 121 | def indices(self) -> list[int]: 122 | return [obatom.GetIdx() - 1 for obatom in self.obatoms] 123 | 124 | 125 | @dataclass 126 | class BaseCharged(BaseInteractablePart): 127 | obatoms: Sequence[ob.OBAtom] 128 | center: Point3D | None = None 129 | 130 | def __post_init__(self): 131 | if self.center is None: 132 | if len(self.obatoms) == 1: 133 | self.center = Point3D.from_obatom(self.obatoms[0]) 134 | else: 135 | coords_list = np.array([utils.ob_coords(obatom) for obatom in self.obatoms]) 136 | self.center = Point3D.from_array(np.mean(coords_list, axis=0)) 137 | 138 | @property 139 | def indices(self) -> list[int]: 140 | return [obatom.GetIdx() - 1 for obatom in
self.obatoms] 141 | 142 | 143 | @dataclass 144 | class BasePosCharged(BaseCharged): 145 | pass 146 | 147 | 148 | @dataclass 149 | class BaseNegCharged(BaseCharged): 150 | pass 151 | 152 | 153 | @dataclass 154 | class BaseXBondDonor(BaseInteractablePart): 155 | X: ob.OBAtom 156 | C: ob.OBAtom = field(init=False) 157 | X_coords: Point3D = field(init=False) 158 | C_coords: Point3D = field(init=False) 159 | 160 | def __post_init__(self): 161 | for neigh in ob.OBAtomAtomIter(self.X): 162 | if neigh.GetAtomicNum() == 6: 163 | self.C = neigh 164 | break 165 | assert self.C is not None 166 | self.X_coords = Point3D.from_obatom(self.X) 167 | self.C_coords = Point3D.from_obatom(self.C) 168 | 169 | @property 170 | def X_index(self) -> int: 171 | return self.X.GetIdx() - 1 172 | 173 | @property 174 | def C_index(self) -> int: 175 | return self.C.GetIdx() - 1 176 | 177 | @property 178 | def indices(self) -> list[int]: 179 | return [self.X_index, self.C_index] 180 | 181 | 182 | @dataclass 183 | class BaseXBondAcceptor(BaseInteractablePart): 184 | O: ob.OBAtom 185 | Y: ob.OBAtom 186 | O_coords: Point3D = field(init=False) 187 | Y_coords: Point3D = field(init=False) 188 | 189 | def __post_init__(self): 190 | self.O_coords = Point3D.from_obatom(self.O) 191 | self.Y_coords = Point3D.from_obatom(self.Y) 192 | 193 | @property 194 | def O_index(self) -> int: 195 | return self.O.GetIdx() - 1 196 | 197 | @property 198 | def Y_index(self) -> int: 199 | return self.Y.GetIdx() - 1 200 | 201 | @property 202 | def indices(self) -> list[int]: 203 | return [self.O_index, self.Y_index] 204 | 205 | 206 | """ PROTEIN """ 207 | 208 | 209 | class ProteinAtom: 210 | @cached_property 211 | def obresidue_index(self) -> int: 212 | return self.obresidue.GetIdx() - 1 213 | 214 | @cached_property 215 | def obresidue_name(self) -> str: 216 | return self.obresidue.GetName() 217 | 218 | @cached_property 219 | def obresidue(self) -> ob.OBResidue: 220 | if hasattr(self, "obatom"): 221 | obatom = self.obatom 222 | elif hasattr(self, "obatoms"): 223 | obatom = self.obatoms[0] 224 | else: 225 | obatom = self.Y 226 | return obatom.GetResidue() 227 | 228 | 229 | @dataclass 230 | class HydrophobicAtom_P(ProteinAtom, BaseHydrophobicAtom): 231 | pass 232 | 233 | 234 | @dataclass 235 | class HBondAcceptor_P(ProteinAtom, BaseHBondAcceptor): 236 | pass 237 | 238 | 239 | @dataclass 240 | class HBondDonor_P(ProteinAtom, BaseHBondDonor): 241 | pass 242 | 243 | 244 | @dataclass 245 | class Ring_P(ProteinAtom, BaseRing): 246 | pass 247 | 248 | 249 | @dataclass 250 | class PosCharged_P(ProteinAtom, BasePosCharged): 251 | pass 252 | 253 | 254 | @dataclass 255 | class NegCharged_P(ProteinAtom, BaseNegCharged): 256 | pass 257 | 258 | 259 | @dataclass 260 | class XBondAcceptor_P(ProteinAtom, BaseXBondAcceptor): 261 | pass 262 | -------------------------------------------------------------------------------- /src/pmnet/data/objects/objects.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from openbabel import pybel 4 | from openbabel.pybel import ob 5 | 6 | from .atom_classes import ( 7 | HBondAcceptor_P, 8 | HBondDonor_P, 9 | HydrophobicAtom_P, 10 | NegCharged_P, 11 | PosCharged_P, 12 | Ring_P, 13 | XBondAcceptor_P, 14 | ) 15 | 16 | pybel.ob.obErrorLog.StopLogging() 17 | 18 | 19 | class Protein: 20 | def __init__( 21 | self, 22 | pbmol: pybel.Molecule, 23 | addh: bool = True, 24 | ): 25 | """ 26 | pbmol: Pybel Mol 27 | addh: if True, call OBMol.AddPolarHydrogens() 28 | 
Interactable parts (hydrophobic atoms, rings, charged groups, H-bond and halogen-bond partners) are identified automatically during initialization. 29 | 30 | """ 31 | 32 | self.addh: bool = addh 33 | 34 | self.pbmol = pbmol.clone 35 | self.pbmol.removeh() 36 | self.obmol = self.pbmol.OBMol 37 | self.obatoms: list[ob.OBAtom] = list(ob.OBMolAtomIter(self.obmol)) 38 | self.num_heavyatoms = len(self.obatoms) 39 | 40 | self.pbmol_hyd: pybel.Molecule 41 | if addh: 42 | self.pbmol_hyd = self.pbmol.clone 43 | self.pbmol_hyd.OBMol.AddPolarHydrogens() 44 | else: 45 | self.pbmol_hyd = pbmol 46 | self.obmol_hyd = self.pbmol_hyd.OBMol 47 | self.obatoms_hyd: list[ob.OBAtom] = list(ob.OBMolAtomIter(self.obmol_hyd))[: self.num_heavyatoms] 48 | self.obatoms_hyd_nonwater: list[ob.OBAtom] = [ 49 | obatom 50 | for obatom in self.obatoms_hyd 51 | if obatom.GetResidue().GetName() != "HOH" and obatom.GetAtomicNum() in [6, 7, 8, 16] 52 | ] 53 | self.obresidues_hyd: list[ob.OBResidue] = list(ob.OBResidueIter(self.obmol_hyd)) 54 | 55 | self.hydrophobic_atoms_all: list[HydrophobicAtom_P] 56 | self.rings_all: list[Ring_P] 57 | self.pos_charged_atoms_all: list[PosCharged_P] 58 | self.neg_charged_atoms_all: list[NegCharged_P] 59 | self.hbond_donors_all: list[HBondDonor_P] 60 | self.hbond_acceptors_all: list[HBondAcceptor_P] 61 | self.xbond_acceptors_all: list[XBondAcceptor_P] 62 | 63 | self.hydrophobic_atoms_all = self.__find_hydrophobic_atoms() 64 | self.rings_all = self.__find_rings() 65 | self.pos_charged_atoms_all, self.neg_charged_atoms_all = self.__find_charged_atoms() 66 | self.hbond_donors_all = self.__find_hbond_donors() 67 | self.hbond_acceptors_all = self.__find_hbond_acceptors() 68 | self.xbond_acceptors_all = self.__find_xbond_acceptors() 69 | 70 | @classmethod 71 | def from_pdbfile(cls, path, addh=True, **kwargs): 72 | pbmol = next(pybel.readfile("pdb", path)) 73 | return cls(pbmol, addh, **kwargs) 74 | 75 | # Search Interactable Part 76 | def __find_hydrophobic_atoms(self) -> list[HydrophobicAtom_P]: 77 | hydrophobics = [ 78 | HydrophobicAtom_P(obatom) 79 | for obatom in self.obatoms_hyd_nonwater 80 | if obatom.GetAtomicNum() == 6 and all(neigh.GetAtomicNum() in (1, 6) for neigh in ob.OBAtomAtomIter(obatom)) 81 | ] 82 | return hydrophobics 83 | 84 | def __find_hbond_acceptors(self) -> list[HBondAcceptor_P]: 85 | acceptors = [HBondAcceptor_P(obatom) for obatom in self.obatoms_hyd_nonwater if obatom.IsHbondAcceptor()] 86 | return acceptors 87 | 88 | def __find_hbond_donors(self) -> list[HBondDonor_P]: 89 | donors = [HBondDonor_P(obatom) for obatom in self.obatoms_hyd_nonwater if obatom.IsHbondDonor()] 90 | return donors 91 | 92 | def __find_rings(self) -> list[Ring_P]: 93 | rings = [] 94 | ring_candidates = self.pbmol_hyd.sssr 95 | for ring in ring_candidates: 96 | if not 4 < len(ring._path) <= 6: 97 | continue 98 | obatoms = [self.obatoms_hyd[idx - 1] for idx in sorted(ring._path)] 99 | residue = obatoms[0].GetResidue() 100 | if residue.GetName() not in ["TYR", "TRP", "HIS", "PHE"]: 101 | continue 102 | rings.append(Ring_P(obatoms)) 103 | return rings 104 | 105 | def __find_charged_atoms(self) -> tuple[list[PosCharged_P], list[NegCharged_P]]: 106 | pos_charged = [] 107 | neg_charged = [] 108 | 109 | for obresidue in self.obresidues_hyd: 110 | obresname = obresidue.GetName() 111 | if obresname in ("ARG", "HIS", "LYS"): 112 | obatoms = [ 113 | obatom 114 | for obatom in ob.OBResidueAtomIter(obresidue) 115 | if obatom.GetAtomicNum() == 7 and obresidue.GetAtomProperty(obatom, ob.SIDECHAIN) 116 | ] 117 | if len(obatoms) > 0: 118 |
pos_charged.append(PosCharged_P(obatoms)) 119 | 120 | elif obresname in ("GLU", "ASP"): 121 | obatoms = [ 122 | obatom 123 | for obatom in ob.OBResidueAtomIter(obresidue) 124 | if obatom.GetAtomicNum() == 8 and obresidue.GetAtomProperty(obatom, ob.SIDECHAIN) 125 | ] 126 | if len(obatoms) > 0: 127 | neg_charged.append(NegCharged_P(obatoms)) 128 | 129 | return pos_charged, neg_charged 130 | 131 | def __find_xbond_acceptors(self) -> list[XBondAcceptor_P]: 132 | """Look for halogen-bond acceptors: an O, N, or S bonded to exactly one neighbor Y (Y = C, N, or S)""" 133 | acceptors = [] 134 | for obatom in self.obatoms_hyd_nonwater: 135 | if obatom.GetAtomicNum() not in [8, 7, 16]: 136 | continue 137 | neighbors = [neigh for neigh in ob.OBAtomAtomIter(obatom) if neigh.GetAtomicNum() in [6, 7, 16]] 138 | if len(neighbors) == 1: 139 | O, Y = obatom, neighbors[0] 140 | acceptors.append(XBondAcceptor_P(O, Y)) 141 | return acceptors 142 | -------------------------------------------------------------------------------- /src/pmnet/data/objects/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections.abc import Sequence 3 | 4 | import numpy as np 5 | from numpy.typing import NDArray 6 | from openbabel.pybel import ob 7 | 8 | 9 | def check_in_cutoff(coords, neighbor_coords_list, cutoff: float): 10 | """ 11 | coords: (3,) 12 | neighbor_coords_list: (N, 3) 13 | cutoff: scalar 14 | """ 15 | x1, y1, z1 = coords 16 | cutoff_square = cutoff**2 17 | for neighbor_coords in neighbor_coords_list: 18 | x2, y2, z2 = neighbor_coords 19 | distance_sq = (x1 - x2) ** 2 + (y1 - y2) ** 2 + (z1 - z2) ** 2 20 | if distance_sq < cutoff_square: 21 | return True 22 | return False 23 | 24 | 25 | def angle_btw_vectors(vec1: NDArray, vec2: NDArray, degree=True, normalized=False) -> float: 26 | if np.array_equal(vec1, vec2): 27 | return 0.0 28 | if normalized: 29 | cosval = np.dot(vec1, vec2) 30 | else: 31 | cosval = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) 32 | angle = math.acos(np.clip(cosval, -1, 1)) 33 | return math.degrees(angle) if degree else angle 34 | 35 | 36 | def vector(p1: Sequence[float] | NDArray, p2: Sequence[float] | NDArray) -> NDArray: 37 | return np.subtract(p2, p1) 38 | 39 | 40 | def euclidean3d(p1: Sequence[float] | NDArray, p2: Sequence[float] | NDArray) -> float: 41 | return math.sqrt(sum([(a - b) ** 2 for a, b in zip(p1, p2, strict=False)])) 42 | 43 | 44 | def normalize(vec: NDArray) -> NDArray: 45 | norm = np.linalg.norm(vec) 46 | assert norm > 0, "vector size is zero" 47 | return vec / norm 48 | 49 | 50 | def projection(point: Sequence[float] | NDArray, origin: Sequence[float] | NDArray, normal: NDArray) -> NDArray: 51 | """ 52 | point: point to be projected 53 | normal, origin: normal vector & origin of projection plane 54 | """ 55 | c = np.dot(normal, np.subtract(point, origin)) 56 | return np.subtract(point, c * normal) 57 | 58 | 59 | def ob_coords(obatom: ob.OBAtom) -> tuple[float, float, float]: 60 | return (obatom.x(), obatom.y(), obatom.z()) 61 | -------------------------------------------------------------------------------- /src/pmnet/data/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from molvoxel import Voxelizer, create_voxelizer 8 | from numpy.typing import NDArray 9 | from openbabel import pybel 10 | from torch import Tensor 11 | 12 | from pmnet.data import pointcloud, token_inference 13 |
from pmnet.data.extract_pocket import extract_pocket 14 | from pmnet.data.objects import Protein 15 | 16 | 17 | class ProteinParser: 18 | def __init__( 19 | self, 20 | center_noise: float = 0.0, 21 | pocket_extract: bool = True, 22 | molvoxel_library: str = "numpy", 23 | ): 24 | """ 25 | center_noise: for data augmentation 26 | pocket_extract: if True, we read pocket instead of entire protein. (faster) 27 | """ 28 | self.voxelizer = create_voxelizer(0.5, 64, sigma=1 / 3, library=molvoxel_library) 29 | self.noise: float = center_noise 30 | self.extract: bool = pocket_extract 31 | 32 | ob_log_handler = pybel.ob.OBMessageHandler() 33 | ob_log_handler.SetOutputLevel(0) # 0: None 34 | 35 | def __call__( 36 | self, 37 | protein_pdb_path: str | Path, 38 | ref_ligand_path: str | Path | None = None, 39 | center: NDArray[np.float32] | tuple[float, float, float] | None = None, 40 | ) -> tuple[Tensor, Tensor, Tensor, Tensor]: 41 | return self.parse(protein_pdb_path, ref_ligand_path, center) 42 | 43 | def parse( 44 | self, 45 | protein_pdb_path: str | Path, 46 | ref_ligand_path: str | Path | None = None, 47 | center: NDArray[np.float32] | tuple[float, float, float] | None = None, 48 | ) -> tuple[Tensor, Tensor, Tensor, Tensor]: 49 | assert (ref_ligand_path is not None) or (center is not None) 50 | _center = self.get_center(ref_ligand_path, center) 51 | return parse_protein(self.voxelizer, protein_pdb_path, _center, self.noise, self.extract) 52 | 53 | @staticmethod 54 | def get_center( 55 | ref_ligand_path: str | Path | None = None, 56 | center: tuple[float, float, float] | NDArray | None = None, 57 | ) -> tuple[float, float, float]: 58 | if center is not None: 59 | assert len(center) == 3 60 | x, y, z = center 61 | else: 62 | assert ref_ligand_path is not None 63 | extension = os.path.splitext(ref_ligand_path)[1] 64 | assert extension in [".sdf", ".pdb", ".mol2"] 65 | ref_ligand = next(pybel.readfile(extension[1:], str(ref_ligand_path))) 66 | x, y, z = np.mean([atom.coords for atom in ref_ligand.atoms], axis=0, dtype=np.float32).tolist() 67 | return float(x), float(y), float(z) 68 | 69 | 70 | def parse_protein( 71 | voxelizer: Voxelizer, 72 | protein_pdb_path: str | Path, 73 | center: NDArray[np.float32] | tuple[float, float, float], 74 | center_noise: float = 0.0, 75 | pocket_extract: bool = True, 76 | ) -> tuple[Tensor, Tensor, Tensor, Tensor]: 77 | if isinstance(center, tuple): 78 | center = np.array(center, dtype=np.float32) 79 | if center_noise > 0: 80 | center = center + (np.random.rand(3) * 2 - 1) * center_noise 81 | 82 | if pocket_extract: 83 | with tempfile.TemporaryDirectory() as dirname: 84 | pocket_path = os.path.join(dirname, "pocket.pdb") 85 | extract_pocket(protein_pdb_path, pocket_path, center) 86 | protein_obj: Protein = Protein.from_pdbfile(pocket_path) 87 | else: 88 | protein_obj: Protein = Protein.from_pdbfile(protein_pdb_path) 89 | 90 | token_positions, token_classes = token_inference.get_token_informations(protein_obj) 91 | tokens, filter = token_inference.get_token_and_filter(token_positions, token_classes, center) 92 | token_positions = token_positions[filter] 93 | 94 | protein_positions, protein_features = pointcloud.get_protein_pointcloud(protein_obj) 95 | protein_image = np.asarray( 96 | voxelizer.forward_features(protein_positions, center, protein_features, radii=1.5), 97 | np.float32, 98 | ) 99 | mask = np.logical_not(np.asarray(voxelizer.forward_single(protein_positions, center, radii=1.0), np.bool_)) 100 | del protein_obj 101 | return ( 102 | 
torch.from_numpy(protein_image).to(torch.float), 103 | torch.from_numpy(mask).to(torch.bool), 104 | torch.from_numpy(token_positions).to(torch.float), 105 | torch.from_numpy(tokens).to(torch.long), 106 | ) 107 | -------------------------------------------------------------------------------- /src/pmnet/data/pointcloud.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import numpy as np 4 | from numpy.typing import NDArray 5 | from openbabel.pybel import ob 6 | 7 | from .objects import Protein 8 | 9 | protein_atom_num_list = (6, 7, 8, 16, -1) 10 | protein_atom_symbol_list = ("C", "N", "O", "S", "UNK_ATOM") 11 | protein_aminoacid_list = ( 12 | "GLY", 13 | "ALA", 14 | "VAL", 15 | "LEU", 16 | "ILE", 17 | "PRO", 18 | "PHE", 19 | "TYR", 20 | "TRP", 21 | "SER", 22 | "THR", 23 | "CYS", 24 | "MET", 25 | "ASN", 26 | "GLN", 27 | "ASP", 28 | "GLU", 29 | "LYS", 30 | "ARG", 31 | "HIS", 32 | "UNK_AA", 33 | ) 34 | protein_interactable_list = ( 35 | "HydrophobicAtom", 36 | "Ring", 37 | "HBondDonor", 38 | "HBondAcceptor", 39 | "Cation", 40 | "Anion", 41 | "XBondAcceptor", 42 | ) 43 | 44 | NUM_PROTEIN_ATOMIC_NUM = len(protein_atom_num_list) 45 | NUM_PROTEIN_AMINOACID_NUM = len(protein_aminoacid_list) 46 | NUM_PROTEIN_INTERACTABLE_NUM = len(protein_interactable_list) 47 | 48 | PROTEIN_CHANNEL_LIST: Sequence[str] = protein_atom_symbol_list + protein_aminoacid_list + protein_interactable_list 49 | NUM_PROTEIN_CHANNEL = len(PROTEIN_CHANNEL_LIST) 50 | 51 | 52 | def get_position(obatom: ob.OBAtom) -> tuple[float, float, float]: 53 | return (obatom.x(), obatom.y(), obatom.z()) 54 | 55 | 56 | def protein_atom_function(atom: ob.OBAtom, out: NDArray, **kwargs) -> NDArray[np.float32]: 57 | atomicnum = atom.GetAtomicNum() 58 | if atomicnum in protein_atom_num_list: 59 | out[protein_atom_num_list.index(atomicnum)] = 1 60 | else: 61 | out[NUM_PROTEIN_ATOMIC_NUM - 1] = 1 62 | residue_type = atom.GetResidue().GetName() 63 | if residue_type in protein_aminoacid_list: 64 | out[NUM_PROTEIN_ATOMIC_NUM + protein_aminoacid_list.index(residue_type)] = 1 65 | else: 66 | out[NUM_PROTEIN_ATOMIC_NUM + NUM_PROTEIN_AMINOACID_NUM - 1] = 1 67 | return out 68 | 69 | 70 | def get_protein_pointcloud( 71 | pocket_obj: Protein, 72 | ) -> tuple[NDArray[np.float32], NDArray[np.float32]]: 73 | positions = np.array( 74 | [(obatom.x(), obatom.y(), obatom.z()) for obatom in pocket_obj.obatoms], 75 | dtype=np.float32, 76 | ) 77 | 78 | channels = np.zeros((pocket_obj.num_heavyatoms, NUM_PROTEIN_CHANNEL), dtype=np.float32) 79 | for i, atom in enumerate(pocket_obj.obatoms): 80 | protein_atom_function(atom, channels[i]) 81 | 82 | offset = NUM_PROTEIN_ATOMIC_NUM + NUM_PROTEIN_AMINOACID_NUM 83 | for hydrop in pocket_obj.hydrophobic_atoms_all: 84 | channels[hydrop.index, offset] = 1 85 | for ring in pocket_obj.rings_all: 86 | channels[ring.indices, offset + 1] = 1 87 | for donor in pocket_obj.hbond_donors_all: 88 | channels[donor.index, offset + 2] = 1 89 | for acceptor in pocket_obj.hbond_acceptors_all: 90 | channels[acceptor.index, offset + 3] = 1 91 | for cation in pocket_obj.pos_charged_atoms_all: 92 | channels[cation.indices, offset + 4] = 1 93 | for anion in pocket_obj.neg_charged_atoms_all: 94 | channels[anion.indices, offset + 5] = 1 95 | for acceptor in pocket_obj.xbond_acceptors_all: 96 | channels[acceptor.indices, offset + 6] = 1 97 | return positions, channels 98 | -------------------------------------------------------------------------------- 
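[Editorial note - the following is an illustrative usage sketch, not part of the repository source. The file name and center coordinates are hypothetical; shapes follow the constants hardcoded in these modules (0.5 A resolution, a 64^3 voxel grid, and the 33 protein channels defined in pointcloud.py).]

# Hypothetical usage sketch (assumes pmnet and its OpenBabel/molvoxel dependencies are installed):
from pmnet.data.parser import ProteinParser

parser = ProteinParser(center_noise=0.0, pocket_extract=True)
# the pocket center may instead be derived from a reference ligand via ref_ligand_path
image, mask, token_positions, tokens = parser("protein.pdb", center=(10.0, 20.0, 30.0))
# image:           float tensor over the 64^3 voxel grid with 33 channels
#                  (atom type + residue type + interactable-part flags)
# mask:            bool tensor over the same grid, True where no protein atom is nearby
# token_positions: float tensor [Ntoken, 3] of token center coordinates in angstroms
# tokens:          long tensor [Ntoken, 4] of (x, y, z, interaction-class) grid indices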
/src/pmnet/data/token_inference.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | from numpy.typing import ArrayLike, NDArray 5 | 6 | from . import constant as C 7 | from .objects import Protein 8 | 9 | 10 | def get_token_informations( 11 | protein_obj: Protein, 12 | ) -> tuple[NDArray[np.float32], NDArray[np.int16]]: 13 | """get token information 14 | 15 | Args: 16 | protein_obj: Protein 17 | 18 | Returns: 19 | token_positions: [float, (N, 3)] token center positions 20 | token_classes: [int, (N,)] token interaction type 21 | """ 22 | num_tokens = ( 23 | len(protein_obj.hydrophobic_atoms_all) 24 | + len(protein_obj.rings_all) * 3 25 | + len(protein_obj.hbond_donors_all) 26 | + len(protein_obj.hbond_acceptors_all) 27 | + len(protein_obj.pos_charged_atoms_all) * 2 28 | + len(protein_obj.neg_charged_atoms_all) 29 | + len(protein_obj.xbond_acceptors_all) 30 | ) 31 | 32 | positions: list[tuple[float, float, float]] = [] 33 | classes: list[int] = [] 34 | 35 | # NOTE: Hydrophobic 36 | positions.extend(tuple(hydrop.coords) for hydrop in protein_obj.hydrophobic_atoms_all) 37 | classes.extend([C.HYDROPHOBIC] * len(protein_obj.hydrophobic_atoms_all)) 38 | 39 | # NOTE: PiStacking_P 40 | positions.extend(tuple(ring.center) for ring in protein_obj.rings_all) 41 | classes.extend([C.PISTACKING_P] * len(protein_obj.rings_all)) 42 | 43 | # NOTE: PiStacking_T 44 | positions.extend(tuple(ring.center) for ring in protein_obj.rings_all) 45 | classes.extend([C.PISTACKING_T] * len(protein_obj.rings_all)) 46 | 47 | # NOTE: PiCation_lring 48 | positions.extend(tuple(cation.center) for cation in protein_obj.pos_charged_atoms_all) 49 | classes.extend([C.PICATION_LRING] * len(protein_obj.pos_charged_atoms_all)) 50 | 51 | # NOTE: PiCation_pring 52 | positions.extend(tuple(ring.center) for ring in protein_obj.rings_all) 53 | classes.extend([C.PICATION_PRING] * len(protein_obj.rings_all)) 54 | 55 | # NOTE: HBond_ldon 56 | positions.extend(tuple(acceptor.coords) for acceptor in protein_obj.hbond_acceptors_all) 57 | classes.extend([C.HBOND_LDON] * len(protein_obj.hbond_acceptors_all)) 58 | 59 | # NOTE: HBond_pdon 60 | positions.extend(tuple(donor.coords) for donor in protein_obj.hbond_donors_all) 61 | classes.extend([C.HBOND_PDON] * len(protein_obj.hbond_donors_all)) 62 | 63 | # NOTE: Saltbridge_lneg 64 | positions.extend(tuple(cation.center) for cation in protein_obj.pos_charged_atoms_all) 65 | classes.extend([C.SALTBRIDGE_LNEG] * len(protein_obj.pos_charged_atoms_all)) 66 | 67 | # NOTE: Saltbridge_pneg 68 | positions.extend(tuple(anion.center) for anion in protein_obj.neg_charged_atoms_all) 69 | classes.extend([C.SALTBRIDGE_PNEG] * len(protein_obj.neg_charged_atoms_all)) 70 | 71 | # NOTE: XBond 72 | positions.extend(tuple(acceptor.O_coords) for acceptor in protein_obj.xbond_acceptors_all) 73 | classes.extend([C.XBOND] * len(protein_obj.xbond_acceptors_all)) 74 | 75 | assert len(positions) == len(classes) == num_tokens 76 | return ( 77 | np.array(positions, dtype=np.float32), 78 | np.array(classes, dtype=np.int16), 79 | ) 80 | 81 | 82 | def get_token_and_filter( 83 | positions: NDArray[np.float32], 84 | classes: NDArray[np.int16], 85 | center: NDArray[np.float32], 86 | ) -> tuple[NDArray[np.int16], NDArray[np.int16]]: 87 | """Create tokens and filter valid instances 88 | 89 | Args: 90 | positions: [float, (N, 3)] token center positions 91 | classes: [int, (N,)] token interaction type 92 | center: [float, (3,)] voxel image center 93 | resolution: voxel image resolution (fixed internally at 0.5) 94 | dimension: voxel image dimension (size; fixed internally at 64) 95 | 96 | Returns: 97 | token: [int, (N_token, 4)] 98 | filter: [int, (N_token,)] 99 | """ 100 | resolution, dimension = 0.5, 64 101 | filter = [] 102 | tokens = [] 103 | x_center, y_center, z_center = center 104 | x_start = x_center - (dimension / 2) * resolution 105 | y_start = y_center - (dimension / 2) * resolution 106 | z_start = z_center - (dimension / 2) * resolution 107 | for i, ((x, y, z), c) in enumerate(zip(positions, classes, strict=False)): 108 | _x = int((x - x_start) // resolution) 109 | _y = int((y - y_start) // resolution) 110 | _z = int((z - z_start) // resolution) 111 | if (0 <= _x < dimension) and (0 <= _y < dimension) and (0 <= _z < dimension): 112 | filter.append(i) 113 | tokens.append((_x, _y, _z, c)) 114 | 115 | return np.array(tokens, dtype=np.int16), np.array(filter, dtype=np.int16) 116 | 117 | 118 | def get_box_area(tokens: ArrayLike) -> NDArray[np.bool_]: 119 | """Create Box Area 120 | 121 | Args: 122 | tokens: [Ntoken, 4] 123 | resolution: float, fixed internally at 0.5 124 | dimension: int, fixed internally at 64 125 | 126 | Returns: 127 | box_areas: BoolArray [Ntoken, D, H, W] D=H=W=dimension 128 | """ 129 | resolution, dimension, pharmacophore_size = 0.5, 64, 1.0 130 | num_tokens = len(tokens) 131 | box_areas = np.zeros((num_tokens, dimension, dimension, dimension), dtype=np.bool_) 132 | grids = np.stack( 133 | np.meshgrid( 134 | np.arange(dimension), 135 | np.arange(dimension), 136 | np.arange(dimension), 137 | indexing="ij", 138 | ), 139 | 3, 140 | ) 141 | for i, (x, y, z, t) in enumerate(tokens): 142 | x, y, z, t = int(x), int(y), int(z), int(t) 143 | distances = np.linalg.norm(grids - np.array([[x, y, z]]), axis=-1) 144 | threshold = math.ceil((C.INTERACTION_DIST[int(t)] + pharmacophore_size) / resolution) 145 | box_areas[i] = distances < threshold 146 | return box_areas 147 | -------------------------------------------------------------------------------- /src/pmnet/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .
import ( 2 | backbones, 3 | decoders, 4 | detector, 5 | feature_embedding, 6 | mask_head, 7 | necks, 8 | token_head, 9 | ) 10 | from .builder import build_model 11 | -------------------------------------------------------------------------------- /src/pmnet/network/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .swin import SwinTransformer 2 | from .swinv2 import SwinTransformerV2 3 | -------------------------------------------------------------------------------- /src/pmnet/network/backbones/timm.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | def to_3tuple( 9 | v: float | int | tuple[float, float, float] | tuple[int, int, int], 10 | ) -> tuple[float, float, float] | tuple[int, int, int]: 11 | if isinstance(v, (int | float)): 12 | return (v, v, v) 13 | else: 14 | return v 15 | 16 | 17 | def _trunc_normal_(tensor, mean, std, a, b): 18 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 19 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 20 | def norm_cdf(x): 21 | # Computes standard normal cumulative distribution function 22 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 23 | 24 | if (mean < a - 2 * std) or (mean > b + 2 * std): 25 | warnings.warn( 26 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 27 | "The distribution of values may be incorrect.", 28 | stacklevel=2, 29 | ) 30 | 31 | # Values are generated by using a truncated uniform distribution and 32 | # then using the inverse CDF for the normal distribution. 33 | # Get upper and lower cdf values 34 | u = norm_cdf((a - mean) / std) 35 | v = norm_cdf((b - mean) / std) 36 | 37 | # Uniformly fill tensor with values from [u, v], then translate to 38 | # [2u-1, 2v-1]. 39 | tensor.uniform_(2 * u - 1, 2 * v - 1) 40 | 41 | # Use inverse cdf transform for normal distribution to get truncated 42 | # standard normal 43 | tensor.erfinv_() 44 | 45 | # Transform to proper mean, std 46 | tensor.mul_(std * math.sqrt(2.0)) 47 | tensor.add_(mean) 48 | 49 | # Clamp to ensure it's in the proper range 50 | tensor.clamp_(min=a, max=b) 51 | return tensor 52 | 53 | 54 | def trunc_normal_( 55 | tensor: torch.Tensor, 56 | mean: float = 0.0, 57 | std: float = 1.0, 58 | a: float = -2.0, 59 | b: float = 2.0, 60 | ) -> torch.Tensor: 61 | r"""Fills the input Tensor with values drawn from a truncated 62 | normal distribution. The values are effectively drawn from the 63 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 64 | with values outside :math:`[a, b]` redrawn until they are within 65 | the bounds. The method used for generating the random values works 66 | best when :math:`a \leq \text{mean} \leq b`. 67 | 68 | NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are 69 | applied while sampling the normal with mean/std applied, therefore a, b args 70 | should be adjusted to match the range of mean, std args. 
71 | 72 | Args: 73 | tensor: an n-dimensional `torch.Tensor` 74 | mean: the mean of the normal distribution 75 | std: the standard deviation of the normal distribution 76 | a: the minimum cutoff value 77 | b: the maximum cutoff value 78 | Examples: 79 | >>> w = torch.empty(3, 5) 80 | >>> nn.init.trunc_normal_(w) 81 | """ 82 | with torch.no_grad(): 83 | return _trunc_normal_(tensor, mean, std, a, b) 84 | 85 | 86 | def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): 87 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 88 | 89 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 90 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 91 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 92 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 93 | 'survival rate' as the argument. 94 | 95 | """ 96 | if drop_prob == 0.0 or not training: 97 | return x 98 | keep_prob = 1 - drop_prob 99 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 100 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 101 | if keep_prob > 0.0 and scale_by_keep: 102 | random_tensor.div_(keep_prob) 103 | return x * random_tensor 104 | 105 | 106 | class DropPath(nn.Module): 107 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 108 | 109 | def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): 110 | super(__class__, self).__init__() 111 | self.drop_prob = drop_prob 112 | self.scale_by_keep = scale_by_keep 113 | 114 | def forward(self, x): 115 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 116 | 117 | def extra_repr(self): 118 | return f"drop_prob={round(self.drop_prob,3):0.3f}" 119 | -------------------------------------------------------------------------------- /src/pmnet/network/builder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from pmnet.network.backbones.swinv2 import SwinTransformerV2 4 | from pmnet.network.cavity_head import CavityHead 5 | from pmnet.network.decoders.fpn_decoder import FPNDecoder 6 | from pmnet.network.detector import PharmacoNetModel 7 | from pmnet.network.feature_embedding import FeaturePyramidNetwork 8 | from pmnet.network.mask_head import MaskHead 9 | from pmnet.network.token_head import TokenHead 10 | 11 | 12 | def build_model(config: dict) -> nn.Module: 13 | # embedding 14 | embedding = FeaturePyramidNetwork( 15 | backbone=SwinTransformerV2( 16 | in_channels=33, 17 | image_size=64, 18 | patch_size=2, 19 | embed_dim=96, 20 | depths=(2, 6, 2, 2), 21 | num_heads=(3, 6, 12, 24), 22 | window_size=4, 23 | out_indices=(0, 1, 2, 3), 24 | ), 25 | decoder=FPNDecoder( 26 | feature_channels=(33, 96, 192, 384, 768), 27 | num_convs=(1, 2, 2, 2, 2), 28 | channels=96, 29 | interpolate_mode="nearest", 30 | ), 31 | ) 32 | cavity_head = CavityHead( 33 | feature_dim=96, 34 | hidden_dim=96, 35 | ) 36 | 37 | token_head = TokenHead( 38 | feature_dim=96, 39 | num_interactions=10, 40 | token_feature_dim=192, 41 | num_feature_mlp_layers=3, 42 | num_score_mlp_layers=3, 43 | ) 44 | 45 | mask_head = MaskHead( 46 | token_feature_dim=192, 47 | decoder=FPNDecoder( 48 | feature_channels=(96, 96, 96, 96, 96), 49 | num_convs=(1, 2, 2, 
2, 2), 50 | channels=96, 51 | ), 52 | ) 53 | 54 | return PharmacoNetModel(embedding, cavity_head, token_head, mask_head, num_interactions=10) 55 | -------------------------------------------------------------------------------- /src/pmnet/network/cavity_head.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from torch import Tensor, nn 4 | 5 | from .nn import BaseConv3d 6 | 7 | 8 | class CavityHead(nn.Module): 9 | def __init__( 10 | self, 11 | feature_dim: int = 96, 12 | hidden_dim: int = 96, 13 | norm_layer: type[nn.Module] | None = nn.BatchNorm3d, 14 | act_layer: type[nn.Module] | None = partial(nn.ReLU, inplace=True), # noqa 15 | ): 16 | super().__init__() 17 | 18 | self.short_head = nn.Sequential( 19 | BaseConv3d( 20 | feature_dim, 21 | hidden_dim, 22 | kernel_size=3, 23 | norm_layer=norm_layer, 24 | act_layer=act_layer, 25 | ), 26 | BaseConv3d(hidden_dim, 1, kernel_size=1, norm_layer=None, act_layer=None), 27 | ) 28 | self.long_head = nn.Sequential( 29 | BaseConv3d( 30 | feature_dim, 31 | hidden_dim, 32 | kernel_size=3, 33 | norm_layer=norm_layer, 34 | act_layer=act_layer, 35 | ), 36 | BaseConv3d(hidden_dim, 1, kernel_size=1, norm_layer=None, act_layer=None), 37 | ) 38 | 39 | def initialize_weights(self): 40 | for m in self.short_head.children(): 41 | m.initialize_weights() 42 | for m in self.long_head.children(): 43 | m.initialize_weights() 44 | 45 | def forward( 46 | self, 47 | features: Tensor, 48 | ) -> tuple[Tensor, Tensor]: 49 | """Pocket Extraction Function 50 | 51 | Args: 52 | features: FloatTensor [N, F, D, H, W] 53 | 54 | Returns: 55 | focus_area_short: FloatTensor [N, 1, D, H, W] 56 | focus_area_long: FloatTensor [N, 1, D, H, W] 57 | """ 58 | focus_area_short = self.short_head(features) 59 | focus_area_long = self.long_head(features) 60 | return focus_area_short, focus_area_long 61 | -------------------------------------------------------------------------------- /src/pmnet/network/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .fpn_decoder import FPNDecoder 2 | -------------------------------------------------------------------------------- /src/pmnet/network/decoders/fpn_decoder.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | 7 | from ..nn.layers import BaseConv3d 8 | 9 | 10 | class FPNDecoder(nn.Module): 11 | """ 12 | Modified FPN Structure [https://arxiv.org/abs/1807.10221] 13 | feature_channels: Bottom-Up Manner. 
14 | """ 15 | 16 | def __init__( 17 | self, 18 | feature_channels: Sequence[int] = (33, 96, 192, 384, 768), 19 | num_convs: Sequence[int] = (1, 2, 2, 2, 2), 20 | channels: int = 96, 21 | interpolate_mode: str = "nearest", 22 | align_corners: bool = False, 23 | norm_layer: type[nn.Module] | None = nn.BatchNorm3d, 24 | act_layer: type[nn.Module] | None = nn.ReLU, 25 | ): 26 | super().__init__() 27 | self.feature_channels = feature_channels 28 | self.interpolate_mode = interpolate_mode 29 | if interpolate_mode == "trilinear": 30 | self.align_corners = align_corners 31 | else: 32 | self.align_corners = None 33 | self.channels = channels 34 | 35 | lateral_conv_list = [] 36 | fpn_convs_list = [] 37 | for level, (channels, num_conv) in enumerate(zip(self.feature_channels, num_convs, strict=False)): 38 | if level == (len(self.feature_channels) - 1): # Lowest-Resolution Channels (Top) 39 | lateral_conv = nn.Identity() 40 | fpn_convs = nn.Sequential( 41 | *[ 42 | BaseConv3d( 43 | channels if i == 0 else self.channels, 44 | self.channels, 45 | kernel_size=3, 46 | norm_layer=norm_layer, 47 | act_layer=act_layer, 48 | ) 49 | for i in range(num_conv) 50 | ] 51 | ) 52 | else: 53 | lateral_conv = BaseConv3d( 54 | channels, 55 | self.channels, 56 | kernel_size=1, 57 | norm_layer=norm_layer, 58 | act_layer=act_layer, 59 | ) 60 | fpn_convs = nn.Sequential( 61 | *[ 62 | BaseConv3d( 63 | self.channels, 64 | self.channels, 65 | kernel_size=3, 66 | norm_layer=norm_layer, 67 | act_layer=act_layer, 68 | ) 69 | for _ in range(num_conv) 70 | ] 71 | ) 72 | lateral_conv_list.append(lateral_conv) 73 | fpn_convs_list.append(fpn_convs) 74 | 75 | self.lateral_conv_list = nn.ModuleList(lateral_conv_list) 76 | self.fpn_convs_list = nn.ModuleList(fpn_convs_list) 77 | 78 | def initialize_weights(self): 79 | for m in self.lateral_conv_list: 80 | if isinstance(m, BaseConv3d): 81 | m.initialize_weights() 82 | for seqm in self.fpn_convs_list: 83 | for m in seqm.children(): 84 | m.initialize_weights() 85 | 86 | def forward(self, features: Sequence[Tensor]) -> list[Tensor]: 87 | """Forward function. 
88 | Args: 89 | features: Bottom-Up, [Highest-Resolution Feature Map, ..., Lowest-Resolution Feature Map] 90 | Returns: 91 | features: Top-Down, [Lowest-Resolution Feature Map, ..., Highest-Resolution Feature Map] 92 | """ 93 | num_levels = len(features) 94 | assert num_levels == len(self.feature_channels) 95 | fpn = None 96 | multi_scale_features = [] 97 | for level in range(num_levels - 1, -1, -1): 98 | feature = features[level] 99 | lateral_conv = self.lateral_conv_list[level] 100 | fpn_convs = self.fpn_convs_list[level] 101 | current_fpn = lateral_conv(feature) 102 | if level == (num_levels - 1): # Top 103 | assert fpn is None 104 | fpn = current_fpn 105 | else: 106 | assert fpn is not None 107 | fpn = current_fpn + F.interpolate( 108 | fpn, 109 | size=current_fpn.size()[-3:], 110 | mode=self.interpolate_mode, 111 | align_corners=self.align_corners, 112 | ) 113 | fpn = fpn_convs(fpn) 114 | multi_scale_features.append(fpn) 115 | return multi_scale_features 116 | -------------------------------------------------------------------------------- /src/pmnet/network/detector.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import torch.nn as nn 4 | from torch import IntTensor, Tensor 5 | 6 | from .cavity_head import CavityHead 7 | from .feature_embedding import FeaturePyramidNetwork 8 | from .mask_head import MaskHead 9 | from .token_head import TokenHead 10 | 11 | 12 | class PharmacoNetModel(nn.Module): 13 | def __init__( 14 | self, 15 | embedding: FeaturePyramidNetwork, 16 | cavity_head: CavityHead, 17 | token_head: TokenHead, 18 | mask_head: MaskHead, 19 | num_interactions: int, 20 | ): 21 | super().__init__() 22 | self.num_interactions = num_interactions 23 | self.embedding = embedding 24 | self.cavity_head = cavity_head 25 | self.token_head = token_head 26 | self.mask_head = mask_head 27 | 28 | def initialize_weights(self): 29 | self.embedding.initialize_weights() 30 | self.token_head.initialize_weights() 31 | self.mask_head.initialize_weights() 32 | 33 | def setup_train(self, criterion: nn.Module): 34 | self.criterion = criterion 35 | 36 | def forward_feature(self, in_image: Tensor) -> tuple[Tensor, ...]: 37 | """Feature Embedding 38 | Args: 39 | in_image: FloatTensor [N, C, Din, Hin, Win] 40 | Returns: 41 | multi-scale features: [FloatTensor [N, F, D, H, W]] 42 | """ 43 | return tuple(self.embedding.forward(in_image)) 44 | 45 | def forward_cavity_extraction(self, features: Tensor) -> tuple[Tensor, Tensor]: 46 | """Cavity Extraction 47 | Args: 48 | features: FloatTensor [N, F, Dout, Hout, Wout] 49 | Returns: 50 | cavity_narrow: FloatTensor [N, 1, Dout, Hout, Wout] 51 | cavity_wide: FloatTensor [N, 1, Dout, Hout, Wout] 52 | """ 53 | return self.cavity_head.forward(features) 54 | 55 | def forward_token_prediction( 56 | self, 57 | features: Tensor, 58 | tokens_list: Sequence[IntTensor], 59 | ) -> tuple[list[Tensor], list[Tensor]]: 60 | """token Selection Network 61 | 62 | Args: 63 | features: FloatTensor [N, F, Dout, Hout, Wout] 64 | tokens_list: List[IntTensor [Ntoken, 4] - (x, y, z, i)] 65 | 66 | Returns: 67 | token_scores_list: List[FloatTensor [Ntoken,] $\\in$ [0, 1]] 68 | token_features_list: List[FloatTensor [Ntoken, F]] 69 | """ 70 | token_scores_list, token_features_list = self.token_head.forward(features, tokens_list) 71 | return token_scores_list, token_features_list 72 | 73 | def forward_segmentation( 74 | self, 75 | multi_scale_features: tuple[Tensor, ...], 76 | box_tokens_list: 
Sequence[IntTensor], 77 | box_token_features_list: Sequence[Tensor], 78 | return_aux: bool = False, 79 | ) -> tuple[list[Tensor], list[list[Tensor]] | None]: 80 | """Mask Prediction 81 | 82 | Args: 83 | multi_scale_features: List[FloatTensor [N, F, D_scale, H_scale, W_scale]] 84 | box_tokens_list: List[IntTensor [Nbox, 4] - (x, y, z, i)] 85 | box_token_features_list: List[FloatTensor [Nbox, F]] 86 | 87 | Returns: 88 | box_masks_list: List[FloatTensor [Nbox, D, H, W]] 89 | aux_box_masks_list: List[List[FloatTensor [Nbox, D_scale, H_scale, W_scale]]] 90 | """ 91 | return self.mask_head.forward(multi_scale_features, box_tokens_list, box_token_features_list, return_aux) 92 | -------------------------------------------------------------------------------- /src/pmnet/network/feature_embedding.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | from torch import Tensor, nn 4 | 5 | from pmnet.network.backbones.swinv2 import SwinTransformerV2 6 | from pmnet.network.decoders.fpn_decoder import FPNDecoder 7 | 8 | 9 | class FeaturePyramidNetwork(nn.Module): 10 | def __init__( 11 | self, 12 | backbone: SwinTransformerV2, 13 | decoder: FPNDecoder, 14 | neck: nn.Module | None = None, 15 | feature_indices: tuple[int, ...] = (0, 1, 2, 3), 16 | set_input_to_bottom: bool = True, 17 | ): 18 | super().__init__() 19 | self.backbone: SwinTransformerV2 = backbone 20 | self.decoder: FPNDecoder = decoder 21 | self.feature_indices: tuple[int, ...] = feature_indices 22 | self.input_is_bottom = set_input_to_bottom 23 | 24 | if neck is not None: 25 | self.with_neck = True 26 | self.neck = neck 27 | else: 28 | self.with_neck = False 29 | 30 | def initialize_weights(self): 31 | self.backbone.initialize_weights() 32 | self.decoder.initialize_weights() 33 | if self.with_neck: 34 | self.neck.initialize_weights() 35 | 36 | def forward(self, in_image: Tensor) -> Sequence[Tensor]: 37 | """Feature Pyramid Network -> return multi-scale feature maps 38 | Args: 39 | in_image: (N, C, D, H, W) 40 | Returns: 41 | multi-scale features (top_down): [(N, F, D, H, W)] 42 | """ 43 | bottom_up_features: Sequence[Tensor] = self.backbone(in_image) 44 | if self.feature_indices is not None: 45 | bottom_up_features = [bottom_up_features[index] for index in self.feature_indices] 46 | if self.input_is_bottom: 47 | bottom_up_features = [in_image, *bottom_up_features] 48 | if self.with_neck: 49 | bottom_up_features = self.neck(bottom_up_features) 50 | top_down_features = self.decoder(bottom_up_features) 51 | return top_down_features 52 | -------------------------------------------------------------------------------- /src/pmnet/network/mask_head.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | 6 | from pmnet.network.decoders.fpn_decoder import FPNDecoder 7 | 8 | 9 | class MaskHead(nn.Module): 10 | def __init__( 11 | self, 12 | decoder: FPNDecoder, 13 | token_feature_dim: int = 192, 14 | ): 15 | super().__init__() 16 | feature_channels_list: list[int] = list(decoder.feature_channels) 17 | self.point_mlp_list = nn.ModuleList( 18 | [nn.Linear(token_feature_dim, channels) for channels in feature_channels_list] 19 | ) 20 | self.background_mlp_list = nn.ModuleList( 21 | [nn.Linear(token_feature_dim, channels) for channels in feature_channels_list] 22 | ) 23 | self.decoder = decoder 24 | self.conv_logits = nn.Conv3d(decoder.channels, 1,
kernel_size=1) 25 | 26 | def initialize_weights(self): 27 | def _init_weight(m): 28 | if isinstance(m, nn.Linear | nn.Conv3d): 29 | nn.init.normal_(m.weight, std=0.01) 30 | if m.bias is not None: 31 | nn.init.constant_(m.bias, 0) 32 | 33 | for m in [self.conv_logits, self.point_mlp_list, self.background_mlp_list]: 34 | m.apply(_init_weight) 35 | 36 | self.decoder.initialize_weights() 37 | 38 | def forward( 39 | self, 40 | multi_scale_features: Sequence[Tensor], 41 | tokens_list: Sequence[Tensor], 42 | token_features_list: Sequence[Tensor], 43 | return_aux: bool = False, 44 | ) -> tuple[list[Tensor], list[list[Tensor]] | None]: 45 | """Box Predicting Function 46 | 47 | Args: 48 | multi_scale_features: Top-Down, List[FloatTensor [N, F_scale, D_scale, H_scale, W_scale]] 49 | tokens_list: List[IntTensor [Nbox, 4] - (x, y, z, i)] 50 | token_features_list: List[FloatTensor [Nbox, Ftoken]] 51 | 52 | Returns: 53 | masks_list: List[FloatTensor [Nbox, D, H, W]] 54 | aux_masks_list (optional): List[List[FloatTensor [Nbox, D, H, W]]] 55 | """ 56 | num_images = len(tokens_list) 57 | assert len(multi_scale_features[0]) == num_images 58 | multi_scale_features = multi_scale_features[::-1] # Top-Down -> Bottom-Up 59 | if return_aux: 60 | out_masks_list = [] 61 | aux_masks_list = [] 62 | for image_idx in range(num_images): 63 | out = self.do_predict_w_aux( 64 | [multi_scale_features[level][image_idx] for level in range(len(multi_scale_features))], 65 | tokens_list[image_idx], 66 | token_features_list[image_idx], 67 | ) 68 | out_masks_list.append(out[-1]) 69 | aux_masks_list.append(out[:-1]) 70 | else: 71 | aux_masks_list = None 72 | out_masks_list = [ 73 | self.do_predict_single( 74 | [multi_scale_features[level][image_idx] for level in range(len(multi_scale_features))], 75 | tokens_list[image_idx], 76 | token_features_list[image_idx], 77 | ) 78 | for image_idx in range(num_images) 79 | ] 80 | return out_masks_list, aux_masks_list 81 | 82 | def do_predict_w_aux( 83 | self, 84 | multi_scale_features: Sequence[Tensor], 85 | tokens: Tensor, 86 | token_features: Tensor, 87 | ) -> list[Tensor]: 88 | """Box Predicting Function 89 | 90 | Args: 91 | multi_scale_features: Bottom-Up, List[FloatTensor [F_scale, D_scale, H_scale, W_scale]] 92 | token_features: FloatTensor [Nbox, Ftoken] 93 | tokens: IntTensor [Nbox, 4] - (x, y, z, i) 94 | 95 | Returns: 96 | multi_scale_box_masks: List[FloatTensor [Nbox, D_scale, H_scale, W_scale]] 97 | """ 98 | Nbox = tokens.size(0) 99 | multi_scale_size = [features.size()[1:] for features in multi_scale_features] 100 | if Nbox > 0: 101 | Dout, Hout, Wout = multi_scale_size[0] 102 | token_indices = torch.split(tokens, 1, dim=1) # (x_list, y_list, z_list, i_list) 103 | xs, ys, zs, _ = token_indices 104 | 105 | bottom_up_box_features = [] 106 | for level in range(len(multi_scale_features)): 107 | features = multi_scale_features[level] 108 | _, D, H, W = features.shape 109 | _xs = torch.div(xs, Dout // D, rounding_mode="trunc") 110 | _ys = torch.div(ys, Hout // H, rounding_mode="trunc") 111 | _zs = torch.div(zs, Wout // W, rounding_mode="trunc") 112 | box_features = self.get_box_features(features, (_xs, _ys, _zs), token_features, level) 113 | bottom_up_box_features.append(box_features) 114 | 115 | top_down_features = self.decoder(bottom_up_box_features) 116 | top_down_box_masks = [self.conv_logits(features).squeeze(1) for features in top_down_features] 117 | return top_down_box_masks 118 | else: 119 | return [ 120 | torch.empty( 121 | (0, *size), 122 | dtype=multi_scale_features[0].dtype,
device=tokens.device, 124 | ) 125 | for size in multi_scale_size[::-1] 126 | ] 127 | 128 | def do_predict_single( 129 | self, 130 | multi_scale_features: Sequence[Tensor], 131 | tokens: Tensor, 132 | token_features: Tensor, 133 | ) -> Tensor: 134 | """Box Predicting Function 135 | 136 | Args: 137 | multi_scale_features: Bottom-Up, List[FloatTensor [F_scale, D_scale, H_scale, W_scale]] 138 | token_features: FloatTensor [Nbox, Ftoken] 139 | tokens: IntTensor [Nbox, 4] - (x, y, z, i) 140 | 141 | Returns: 142 | box_masks: FloatTensor [Nbox, D_out, H_out, W_out] 143 | """ 144 | Nbox = tokens.size(0) 145 | multi_scale_size = [features.size()[1:] for features in multi_scale_features] 146 | Dout, Hout, Wout = multi_scale_size[0] 147 | if Nbox > 0: 148 | token_indices = torch.split(tokens, 1, dim=1) # (x_list, y_list, z_list, i_list) 149 | xs, ys, zs, _ = token_indices 150 | 151 | bottom_up_box_features = [] 152 | for level in range(len(multi_scale_features)): 153 | features = multi_scale_features[level] 154 | _, D, H, W = features.shape 155 | _xs = torch.div(xs, Dout // D, rounding_mode="trunc") 156 | _ys = torch.div(ys, Hout // H, rounding_mode="trunc") 157 | _zs = torch.div(zs, Wout // W, rounding_mode="trunc") 158 | box_features = self.get_box_features(features, (_xs, _ys, _zs), token_features, level) 159 | bottom_up_box_features.append(box_features) 160 | 161 | top_down_features = self.decoder(bottom_up_box_features) 162 | return self.conv_logits(top_down_features[-1]).squeeze(1) 163 | else: 164 | return torch.empty( 165 | (0, Dout, Hout, Wout), 166 | dtype=multi_scale_features[0].dtype, 167 | device=tokens.device, 168 | ) 169 | 170 | def get_box_features( 171 | self, 172 | features: Tensor, 173 | token_indices: tuple[Tensor, Tensor, Tensor], 174 | token_features: Tensor, 175 | level: int, 176 | ) -> Tensor: 177 | """Extract token features 178 | 179 | Args: 180 | features: FloatTensor [F_scale, D_scale, H_scale, W_scale] 181 | token_indices: Tuple[IntTensor [Nbox,], IntTensor [Nbox,], IntTensor[Nbox,]] - (xs, ys, zs) 182 | token_features: FloatTensor [Nbox, Ftoken] 183 | 184 | Returns: 185 | box_features: FloatTensor [Nbox, F_scale, D_scale, H_scale, W_scale] 186 | """ 187 | F, D, H, W = features.shape 188 | xs, ys, zs = token_indices 189 | Nbox = token_features.size(0) 190 | Nboxs = torch.arange(Nbox, dtype=xs.dtype, device=xs.device) 191 | background_features = self.background_mlp_list[level](token_features) # [Nbox, F] 192 | point_features = self.point_mlp_list[level](token_features) # [Nbox, F] 193 | box_features = background_features.view(Nbox, F, 1, 1, 1).repeat(1, 1, D, H, W) # [Nbox, F, D, H, W] 194 | box_features[Nboxs, :, xs, ys, zs] += point_features 195 | features = features.unsqueeze(0) + box_features 196 | return features 197 | -------------------------------------------------------------------------------- /src/pmnet/network/necks/__init__.py: -------------------------------------------------------------------------------- 1 | from .center_crop import CenterCrop, MultipleCenterCrop 2 | -------------------------------------------------------------------------------- /src/pmnet/network/necks/center_crop.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, nn 2 | 3 | 4 | class MultipleCenterCrop(nn.Module): 5 | def __init__(self, crop_sizes: list[int]): 6 | super().__init__() 7 | self.crop_sizes: list[tuple[int, int, int]] = [(size, size, size) for size in crop_sizes] 8 | 9 | def forward(self, inputs: list[Tensor]) -> 
list[Tensor]: 10 | outputs: list[Tensor] = [] 11 | for size, tensor in zip(self.crop_sizes, inputs, strict=True): 12 | Dc, Hc, Wc = size 13 | _, _, D, H, W = tensor.size() 14 | d, h, w = (D - Dc) // 2, (H - Hc) // 2, (W - Wc) // 2 15 | assert d >= 0 and h >= 0 and w >= 0 16 | if d == 0 and h == 0 and w == 0: 17 | outputs.append(tensor) 18 | else: 19 | outputs.append(tensor[:, :, d : D - d, h : H - h, w : W - w].contiguous()) 20 | return outputs 21 | 22 | def initialize_weights(self): 23 | pass 24 | 25 | 26 | class CenterCrop(nn.Module): 27 | def __init__(self, crop_size: int): 28 | super().__init__() 29 | self.crop_size: tuple[int, int, int] = (crop_size, crop_size, crop_size) 30 | 31 | def forward(self, input: Tensor) -> Tensor: 32 | Dc, Hc, Wc = self.crop_size 33 | _, _, D, H, W = input.size() 34 | d, h, w = (D - Dc) // 2, (H - Hc) // 2, (W - Wc) // 2 35 | assert d >= 0 and h >= 0 and w >= 0 36 | if d != 0 or h != 0 or w != 0: 37 | return input[:, :, d : D - d, h : H - h, w : W - w].contiguous() 38 | else: 39 | return input 40 | 41 | def initialize_weights(self): 42 | pass 43 | -------------------------------------------------------------------------------- /src/pmnet/network/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import BaseConv3d 2 | -------------------------------------------------------------------------------- /src/pmnet/network/nn/layers.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class BaseConv3d(nn.Module): 5 | def __init__( 6 | self, 7 | in_channels: int, 8 | out_channels: int, 9 | kernel_size: int, 10 | stride: int = 1, 11 | padding: int | None = None, 12 | dilation: int = 1, 13 | groups: int = 1, 14 | norm_layer: type[nn.Module] | None = nn.BatchNorm3d, 15 | act_layer: type[nn.Module] | None = nn.ReLU, 16 | ): 17 | super().__init__() 18 | if padding is None: 19 | padding = (kernel_size - 1) // 2 20 | bias = norm_layer is None 21 | self._conv = nn.Conv3d( 22 | in_channels, 23 | out_channels, 24 | kernel_size, 25 | stride, 26 | padding, 27 | dilation, 28 | groups, 29 | bias, 30 | ) 31 | self._norm = norm_layer(out_channels) if norm_layer is not None else nn.Identity() 32 | self._act = act_layer() if act_layer is not None else nn.Identity() 33 | 34 | def initialize_weights(self): 35 | if isinstance(self._act, nn.LeakyReLU): 36 | a = self._act.negative_slope 37 | nn.init.kaiming_normal_(self._conv.weight, a, mode="fan_out", nonlinearity="leaky_relu") 38 | else: 39 | nn.init.kaiming_normal_(self._conv.weight, mode="fan_out", nonlinearity="relu") 40 | if self._conv.bias is not None: 41 | nn.init.constant_(self._conv.bias, 0.0) 42 | if not isinstance(self._norm, nn.Identity): 43 | nn.init.constant_(self._norm.weight, 1.0) 44 | 45 | def forward(self, x): 46 | return self._act(self._norm(self._conv(x))) 47 | -------------------------------------------------------------------------------- /src/pmnet/network/token_head.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | 6 | 7 | class TokenHead(nn.Module): 8 | def __init__( 9 | self, 10 | feature_dim: int = 96, 11 | num_interactions: int = 10, 12 | token_feature_dim: int = 192, 13 | num_feature_mlp_layers: int = 3, 14 | num_score_mlp_layers: int = 3, 15 | ): 16 | super().__init__() 17 | self.interaction_embedding = nn.Embedding(num_interactions, feature_dim) 
18 |         self.token_feature_dim = token_feature_dim
19 | 
20 |         feature_mlp = []
21 |         dim = 2 * feature_dim
22 |         for _ in range(num_feature_mlp_layers):
23 |             feature_mlp.append(nn.Linear(dim, token_feature_dim))
24 |             feature_mlp.append(nn.SiLU(inplace=True))
25 |             dim = token_feature_dim
26 |         self.feature_mlp = nn.Sequential(*feature_mlp)
27 |         if 2 * feature_dim != token_feature_dim:
28 |             self.skip = nn.Linear(2 * feature_dim, token_feature_dim)
29 |         else:
30 |             self.skip = nn.Identity()
31 | 
32 |         score_mlp = []
33 |         for _ in range(num_score_mlp_layers - 1):
34 |             score_mlp.append(nn.Linear(token_feature_dim, token_feature_dim))
35 |             score_mlp.append(nn.ReLU(inplace=True))
36 |         score_mlp.append(nn.Linear(token_feature_dim, 1))
37 |         self.score_mlp = nn.Sequential(*score_mlp)
38 | 
39 |     def initialize_weights(self):
40 |         def _init_weight(m):
41 |             if isinstance(m, nn.Linear):
42 |                 nn.init.normal_(m.weight, std=0.01)
43 |                 if m.bias is not None:
44 |                     nn.init.constant_(m.bias, 0)
45 | 
46 |         nn.init.uniform_(self.interaction_embedding.weight, -1.0, 1.0)
47 |         for m in [self.feature_mlp, self.score_mlp, self.skip]:
48 |             m.apply(_init_weight)
49 | 
50 |     def forward(self, features: Tensor, tokens_list: Sequence[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
51 |         """Token Scoring Function
52 | 
53 |         Args:
54 |             features: FloatTensor [N, F, D, H, W]
55 |             tokens_list: List[IntTensor [Ntoken, 4] - (x, y, z, i)]
56 | 
57 |         Returns:
58 |             token_scores_list: List[FloatTensor [Ntoken,]]
59 |             token_features_list: List[FloatTensor [Ntoken, Fh]]
60 |         """
61 |         num_images = len(tokens_list)
62 |         token_features_list = [
63 |             self.extract_token_features(features[idx], tokens_list[idx]) for idx in range(num_images)
64 |         ]
65 |         token_scores_list = [self.score_mlp(token_features).squeeze(-1) for token_features in token_features_list]
66 |         return token_scores_list, token_features_list
67 | 
68 |     def extract_token_features(self, features: Tensor, tokens: Tensor) -> Tensor:
69 |         """Extract token features
70 | 
71 |         Args:
72 |             features: FloatTensor [F, D, H, W]
73 |             tokens: IntTensor [Ntoken, 4] - (x, y, z, i)
74 | 
75 |         Returns:
76 |             token_features: FloatTensor [Ntoken, Fh]
77 |         """
78 |         if tokens.size(0) == 0:
79 |             return torch.empty([0, self.token_feature_dim], dtype=torch.float, device=features.device)
80 |         else:
81 |             features = features.permute(1, 2, 3, 0).contiguous()  # [D, H, W, F]
82 |             x_list, y_list, z_list, i_list = torch.split(tokens, 1, dim=1)  # each: [Ntoken, 1]
83 |             token_features = features[x_list, y_list, z_list].squeeze(1)  # [Ntoken, F]
84 |             embeddings = self.interaction_embedding(i_list).squeeze(1)  # [Ntoken, F]
85 |             token_features = torch.cat([token_features, embeddings], dim=1)  # [Ntoken, 2F]
86 |             return self.skip(token_features) + self.feature_mlp(token_features)  # [Ntoken, Fh]
87 | 
--------------------------------------------------------------------------------
/src/pmnet/scoring/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SeonghwanSeo/PharmacoNet/6595694cdf910b52c9fa0512c35f8139f5be2cf5/src/pmnet/scoring/__init__.py
--------------------------------------------------------------------------------
/src/pmnet/scoring/ligand_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.typing import NDArray
3 | from openbabel import pybel
4 | from openbabel.pybel import ob
5 | 
6 | 
7 | class PharmacophoreNode:
8 |     def __init__(
9 |         self,
10 |         atom_indices: int | tuple[int, ...],
11 |         center_indices: None | int | tuple[int, ...] = None,
12 |     ):
13 |         if center_indices is None:
14 |             center_indices = atom_indices
15 |         self.atom_indices: int | tuple[int, ...] = atom_indices
16 |         self.center_indices: int | tuple[int, ...] = center_indices
17 | 
18 |     def get_center(self, atom_positions: NDArray) -> NDArray:
19 |         if isinstance(self.center_indices, int):
20 |             return atom_positions[self.center_indices]
21 |         else:
22 |             return np.mean(atom_positions[self.center_indices, :], axis=0)
23 | 
24 | 
25 | def get_pharmacophore_nodes(
26 |     pbmol: pybel.Molecule,
27 | ) -> dict[str, list[PharmacophoreNode]]:
28 |     obmol = pbmol.OBMol
29 |     obatoms: list[ob.OBAtom] = list(ob.OBMolAtomIter(obmol))
30 |     pbmol_hyd = pbmol.clone
31 |     pbmol_hyd.OBMol.AddPolarHydrogens()
32 |     obmol_hyd = pbmol_hyd.OBMol
33 |     num_heavy_atoms = len(obatoms)
34 |     obatoms_hyd: list[ob.OBAtom] = list(ob.OBMolAtomIter(obmol_hyd))[:num_heavy_atoms]
35 | 
36 |     hydrophobics = [
37 |         PharmacophoreNode(idx)
38 |         for idx, obatom in enumerate(obatoms)
39 |         if obatom.GetAtomicNum() == 6 and all(neigh.GetAtomicNum() in (1, 6) for neigh in ob.OBAtomAtomIter(obatom))
40 |     ]
41 |     hbond_acceptors = [
42 |         PharmacophoreNode(idx)
43 |         for idx, obatom in enumerate(obatoms)
44 |         if obatom.GetAtomicNum() not in [9, 17, 35, 53] and obatom.IsHbondAcceptor()
45 |     ]
46 |     hbond_donors = [PharmacophoreNode(idx) for idx, obatom in enumerate(obatoms_hyd) if obatom.IsHbondDonor()]
47 |     rings = [
48 |         PharmacophoreNode(tuple(sorted(idx - 1 for idx in ring._path)))  # ring atom indices start from 1 -> subtract 1
49 |         for ring in pbmol.sssr
50 |         if ring.IsAromatic()
51 |     ]
52 |     rings.sort(key=lambda ring: ring.atom_indices)
53 | 
54 |     pos_charged = [
55 |         PharmacophoreNode(idx)
56 |         for idx, obatom in enumerate(obatoms)
57 |         if is_quartamine_N(obatom) or is_tertamine_N(obatom) or is_sulfonium_S(obatom)
58 |     ]
59 |     neg_charged = []
60 | 
61 |     for idx, obatom in enumerate(obatoms):
62 |         if is_guanidine_C(obatom):
63 |             nitrogens = tuple(neigh.GetIdx() - 1 for neigh in ob.OBAtomAtomIter(obatom) if neigh.GetAtomicNum() == 7)
64 |             pos_charged.append(PharmacophoreNode((idx,) + nitrogens, idx))
65 | 
66 |         elif is_phosphate_P(obatom) or is_sulfate_S(obatom):
67 |             neighbors = tuple(neigh.GetIdx() - 1 for neigh in ob.OBAtomAtomIter(obatom))
68 |             neg_charged.append(PharmacophoreNode((idx,) + neighbors, idx))
69 | 
70 |         elif is_sulfonicacid_S(obatom):
71 |             oxygens = tuple(neigh.GetIdx() - 1 for neigh in ob.OBAtomAtomIter(obatom) if neigh.GetAtomicNum() == 8)
72 |             neg_charged.append(PharmacophoreNode((idx,) + oxygens, idx))
73 | 
74 |         elif is_carboxylate_C(obatom):
75 |             oxygens = tuple(neigh.GetIdx() - 1 for neigh in ob.OBAtomAtomIter(obatom) if neigh.GetAtomicNum() == 8)
76 |             neg_charged.append(PharmacophoreNode((idx,) + oxygens, oxygens))
77 | 
78 |     xbond_donors = [PharmacophoreNode(idx) for idx, obatom in enumerate(obatoms) if is_halocarbon_X(obatom)]
79 | 
80 |     return {
81 |         "Hydrophobic": hydrophobics,
82 |         "Aromatic": rings,
83 |         "Cation": pos_charged,
84 |         "Anion": neg_charged,
85 |         "HBond_donor": hbond_donors,
86 |         "HBond_acceptor": hbond_acceptors,
87 |         "Halogen": xbond_donors,
88 |     }
89 | 
90 | 
91 | """ FUNCTIONAL GROUP """
92 | 
93 | 
94 | def is_quartamine_N(obatom: ob.OBAtom):
95 |     # It's a nitrogen, so could be a protonated amine or quaternary ammonium
96 |     if obatom.GetAtomicNum() != 7:  # Nitrogen
97 |         return False
98 |     if obatom.GetExplicitDegree() != 4:
99 |         return False
100 |     for neigh in ob.OBAtomAtomIter(obatom):
101 |         if neigh.GetAtomicNum() == 1:  # an H neighbor -> protonated amine, not a quaternary ammonium (N with 4 residues != H)
102 |             return False
103 |     return True
104 | 
105 | 
106 | def is_tertamine_N(obatom: ob.OBAtom):  # Nitrogen
107 |     return obatom.GetAtomicNum() == 7 and obatom.GetHyb() == 3 and obatom.GetHvyDegree() == 3
108 | 
109 | 
110 | def is_sulfonium_S(obatom: ob.OBAtom):
111 |     if obatom.GetAtomicNum() != 16:  # Sulfur
112 |         return False
113 |     if obatom.GetExplicitDegree() != 3:
114 |         return False
115 |     for neigh in ob.OBAtomAtomIter(obatom):
116 |         if neigh.GetAtomicNum() == 1:  # an H neighbor -> not a sulfonium (S with 3 residues != H)
117 |             return False
118 |     return True
119 | 
120 | 
121 | def is_guanidine_C(obatom: ob.OBAtom):
122 |     if obatom.GetAtomicNum() != 6:  # It's a carbon atom
123 |         return False
124 |     numNs = 0
125 |     numN_with_only_C = 0
126 |     for neigh in ob.OBAtomAtomIter(obatom):
127 |         if neigh.GetAtomicNum() == 7:
128 |             numNs += 1
129 |             if neigh.GetHvyDegree() == 1:
130 |                 numN_with_only_C += 1
131 |         else:
132 |             return False
133 |     return numNs == 3 and numN_with_only_C > 0
134 | 
135 | 
136 | def is_sulfonicacid_S(obatom: ob.OBAtom):
137 |     if obatom.GetAtomicNum() != 16:  # Sulfur
138 |         return False
139 |     numOs = 0
140 |     for neigh in ob.OBAtomAtomIter(obatom):
141 |         if neigh.GetAtomicNum() == 8:
142 |             numOs += 1
143 |     return numOs == 3
144 | 
145 | 
146 | def is_sulfate_S(obatom: ob.OBAtom):
147 |     if obatom.GetAtomicNum() != 16:  # Sulfur
148 |         return False
149 |     numOs = 0
150 |     for neigh in ob.OBAtomAtomIter(obatom):
151 |         if neigh.GetAtomicNum() == 8:
152 |             numOs += 1
153 |     return numOs == 4
154 | 
155 | 
156 | def is_phosphate_P(obatom: ob.OBAtom):
157 |     if obatom.GetAtomicNum() != 15:  # Phosphorus
158 |         return False
159 |     for neigh in ob.OBAtomAtomIter(obatom):
160 |         if neigh.GetAtomicNum() != 8:  # a phosphate has only O neighbors
161 |             return False
162 |     return True
163 | 
164 | 
165 | def is_carboxylate_C(obatom: ob.OBAtom):
166 |     if obatom.GetAtomicNum() != 6:  # It's a carbon atom
167 |         return False
168 |     numOs = numCs = 0
169 |     for neigh in ob.OBAtomAtomIter(obatom):
170 |         neigh_z = neigh.GetAtomicNum()
171 |         if neigh_z == 8:
172 |             numOs += 1
173 |         elif neigh_z == 6:
174 |             numCs += 1
175 |     return numOs == 2 and numCs == 1
176 | 
177 | 
178 | def is_halocarbon_X(obatom: ob.OBAtom) -> bool:
179 |     if obatom.GetAtomicNum() not in [9, 17, 35, 53]:  # Halogen atoms
180 |         return False
181 |     for neigh in ob.OBAtomAtomIter(obatom):
182 |         if neigh.GetAtomicNum() == 6:
183 |             return True
184 |     return False
185 | 
--------------------------------------------------------------------------------
/src/pmnet/scoring/match_utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | 
3 | import numpy as np
4 | 
5 | DISTANCE_SIGMA_THRESHOLD = 2.0
6 | PASS_THRESHOLD = 0.5
7 | 
8 | 
9 | def scoring_matching_pair(
10 |     cluster_node_match_list1,
11 |     cluster_node_match_list2,
12 |     num_conformers: int,
13 | ) -> tuple[float, ...]:
14 |     """
15 |     cluster_node_match_list1: list[tuple[LigandNode, list[ModelNode], NDArray[np.float32]]],
16 |     cluster_node_match_list2: list[tuple[LigandNode, list[ModelNode], NDArray[np.float32]]],
17 |     num_conformers: int,
18 |     """
19 |     match_scores = np.zeros((num_conformers,), dtype=np.float32)
20 |     num_fails = np.zeros((num_conformers,), dtype=np.int16)
21 | 
22 |     match_threshold = len(cluster_node_match_list1) * len(cluster_node_match_list2) * (1 - PASS_THRESHOLD)
23 | 
24 |     num_pass = np.empty((num_conformers,), dtype=np.int16)
25 |     likelihood = np.empty((num_conformers,), dtype=np.float32)
26 |     for cluster_node_match1, cluster_node_match2 in itertools.product(
27 |         cluster_node_match_list1, cluster_node_match_list2
28 |     ):
29 |         ligand_node1, model_node_list1, weights1 = cluster_node_match1
30 |         ligand_node2, model_node_list2, weights2 = cluster_node_match2
31 |         ligand_edge = ligand_node1.neighbor_edge_dict[ligand_node2]
32 |         distances = ligand_edge.distances
33 | 
34 |         num_match = len(model_node_list1) * len(model_node_list2)
35 |         means = np.array(
36 |             [
37 |                 [model_node1.neighbor_edge_dict[model_node2].distance_mean]
38 |                 for model_node1, model_node2 in itertools.product(model_node_list1, model_node_list2)
39 |             ],
40 |             dtype=np.float32,
41 |         )  # [M*N, 1]
42 |         stds = np.array(
43 |             [
44 |                 [model_node1.neighbor_edge_dict[model_node2].distance_std]
45 |                 for model_node1, model_node2 in itertools.product(model_node_list1, model_node_list2)
46 |             ],
47 |             dtype=np.float32,
48 |         )  # [M*N, 1]
49 |         weights = (weights1.reshape(-1, 1) * weights2.reshape(1, -1)).reshape(-1)  # [M * N]
50 | 
51 |         weights_sum = sum(weights)
52 |         normalize_coeff = 1 / weights_sum  # the 1 / sqrt(2 * pi) factor is skipped
53 |         score_coeff = weights_sum / num_match
54 | 
55 |         distance_sigma_array = (distances.reshape(1, num_conformers) - means) / stds
56 |         np.sum(
57 |             np.abs(distance_sigma_array) < DISTANCE_SIGMA_THRESHOLD,
58 |             axis=0,
59 |             out=num_pass,
60 |         )
61 |         num_fails += num_pass < (num_match * PASS_THRESHOLD)
62 |         if min(num_fails) > match_threshold:
63 |             return (-1,) * num_conformers
64 |         np.dot(
65 |             weights / stds.reshape(-1),
66 |             np.exp(-0.5 * distance_sigma_array**2),
67 |             out=likelihood,
68 |         )
69 |         match_scores += likelihood * normalize_coeff * score_coeff
70 | 
71 |     return tuple(
72 |         float(score) if num_fail <= match_threshold else -1
73 |         for score, num_fail in zip(match_scores, num_fails, strict=False)
74 |     )
75 | 
76 | 
77 | def scoring_matching_self(
78 |     cluster_node_match_list,
79 |     num_conformers: int,
80 | ) -> tuple[float, ...]:
81 |     """
82 |     cluster_node_match_list: list[tuple[LigandNode, list[ModelNode], NDArray[np.float32]]]
83 |     num_conformers: int
84 |     """
85 |     match_scores = np.zeros((num_conformers,), dtype=np.float32)
86 |     likelihood = np.empty((num_conformers,), dtype=np.float32)
87 |     for cluster_node_match1, cluster_node_match2 in itertools.combinations(cluster_node_match_list, 2):
88 |         ligand_node1, model_node_list1, weights1 = cluster_node_match1
89 |         ligand_node2, model_node_list2, weights2 = cluster_node_match2
90 |         ligand_edge = ligand_node1.neighbor_edge_dict[ligand_node2]
91 |         distances = ligand_edge.distances
92 | 
93 |         num_match = len(model_node_list1) * len(model_node_list2)
94 |         means = np.array(
95 |             [
96 |                 [model_node1.neighbor_edge_dict[model_node2].distance_mean]
97 |                 for model_node1, model_node2 in itertools.product(model_node_list1, model_node_list2)
98 |             ],
99 |             dtype=np.float32,
100 |         )  # [M*N, 1]
101 |         stds = np.array(
102 |             [
103 |                 [model_node1.neighbor_edge_dict[model_node2].distance_std]
104 |                 for model_node1, model_node2 in itertools.product(model_node_list1, model_node_list2)
105 |             ],
106 |             dtype=np.float32,
107 |         )  # [M*N, 1]
108 |         weights = (weights1.reshape(-1, 1) * weights2.reshape(1, -1)).reshape(-1)  # [M*N]
109 |         weights_sum = sum(weights)
110 |         normalize_coeff = 1 / weights_sum  # the 1 / sqrt(2 * pi) factor is skipped
111 |         score_coeff = weights_sum / num_match
112 | 
113 |         distance_sigma_array = (distances.reshape(1, num_conformers) - means) / stds
114 |         np.dot(
115 |             weights / stds.reshape(-1),
116 |             np.exp(-0.5 * distance_sigma_array**2),
117 |             out=likelihood,
118 |         )
119 | 
120 |         match_scores += likelihood * normalize_coeff * score_coeff
121 | 
122 |     return tuple(match_scores.tolist())
123 | 
--------------------------------------------------------------------------------
/src/pmnet/scoring/match_utils_numba.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import math
3 | 
4 | import numba as nb
5 | import numpy as np
6 | from numpy.typing import NDArray
7 | 
8 | DISTANCE_SIGMA_THRESHOLD = 2.0
9 | PASS_THRESHOLD = 0.5
10 | 
11 | 
12 | @nb.njit(
13 |     "void(float32[::1],float32[:, :, ::1],float32[::1],float32[::1],float32[::1],int16[::1])",
14 |     fastmath=True,
15 |     cache=True,
16 | )
17 | def __numba_run(
18 |     distances: NDArray[np.float32],
19 |     mean_stds: NDArray[np.float32],
20 |     weights1: NDArray[np.float32],
21 |     weights2: NDArray[np.float32],
22 |     score_array: NDArray[np.float32],
23 |     fail_array: NDArray[np.int16],
24 | ):
25 |     """Scoring Function
26 | 
27 |     Args:
28 |         distances: [C,]
29 |         mean_stds: [M, N, 2]
30 |         weights1: [M,]
31 |         weights2: [N,]
32 |         score_array: [C,]
33 |         fail_array: [C,]
34 |     """
35 |     assert mean_stds.shape[0] == weights1.shape[0]
36 |     assert mean_stds.shape[1] == weights2.shape[0]
37 |     assert distances.shape[0] == score_array.shape[0] == fail_array.shape[0]
38 | 
39 |     num_match: int
40 |     pass_threshold: float
41 | 
42 |     W1: float
43 |     W2: float
44 |     normalize_coeff: float
45 |     score_coeff: float
46 | 
47 |     num_pass: int
48 |     mu: float
49 |     std: float
50 |     sigma_sq: float
51 |     likelihood: float
52 |     _likelihood: float
53 | 
54 |     M = weights1.shape[0]
55 |     N = weights2.shape[0]
56 |     C = distances.shape[0]
57 | 
58 |     num_match = M * N
59 |     pass_threshold = (num_match + 1) // 2  # PASS_THRESHOLD
60 | 
61 |     # NOTE: Coefficient Calculation
62 |     W1 = sum(weights1)
63 |     W2 = sum(weights2)
64 |     normalize_coeff = 1 / (W1 * W2)
65 |     score_coeff = (W1 * W2) / num_match
66 | 
67 |     for c in range(C):
68 |         d = distances[c]
69 |         num_pass = 0
70 |         likelihood = 0.0
71 |         for m in range(M):
72 |             w1 = weights1[m]
73 |             _likelihood = 0.0
74 |             for n in range(N):
75 |                 w2 = weights2[n]
76 |                 mu = mean_stds[m, n, 0]
77 |                 std = mean_stds[m, n, 1]
78 |                 sigma_sq = ((d - mu) / std) ** 2
79 |                 _likelihood += w2 / std * math.exp(-0.5 * sigma_sq)
80 |                 if sigma_sq < 4.0:
81 |                     num_pass += 1
82 |             likelihood += w1 * _likelihood
83 | 
84 |         score_array[c] += likelihood * normalize_coeff * score_coeff
85 |         if num_pass < pass_threshold:
86 |             fail_array[c] += 1
87 | 
88 | 
89 | @nb.njit(
90 |     "void(float32[::1],float32[:, :, ::1],float32[::1],float32[::1],float32[::1])",
91 |     fastmath=True,
92 |     cache=True,
93 | )
94 | def __numba_run_self(
95 |     distances: NDArray[np.float32],
96 |     mean_stds: NDArray[np.float32],
97 |     weights1: NDArray[np.float32],
98 |     weights2: NDArray[np.float32],
99 |     score_array: NDArray[np.float32],
100 | ):
101 |     """Scoring Function
102 | 
103 |     Args:
104 |         distances: [C,]
105 |         mean_stds: [M, N, 2]
106 |         weights1: [M,]
107 |         weights2: [N,]
108 |         score_array: [C,]
109 |     """
110 |     assert mean_stds.shape[0] == weights1.shape[0]
111 |     assert mean_stds.shape[1] == weights2.shape[0]
112 | 
113 |     num_match: int
114 | 
115 |     W1: float
116 |     W2: float
117 |     normalize_coeff: float
118 |     score_coeff: float
119 | 
120 |     mu: float
121 |     std: float
122 |     sigma_sq: float
123 |     likelihood: float
124 |     _likelihood: float
125 | 
126 |     M = weights1.shape[0]
127 |     N = weights2.shape[0]
128 |     C = distances.shape[0]
129 | 
130 |     num_match = M * N
131 | 
132 |     # NOTE: Coefficient Calculation
133 |     W1 = sum(weights1)
134 |     W2 = sum(weights2)
135 | normalize_coeff = 1 / (W1 * W2) 136 | score_coeff = (W1 * W2) / num_match 137 | 138 | for c in range(C): 139 | d = distances[c] 140 | likelihood = 0.0 141 | for m in range(M): 142 | w1 = weights1[m] 143 | _likelihood = 0.0 144 | for n in range(N): 145 | w2 = weights2[n] 146 | mu = mean_stds[m, n, 0] 147 | std = mean_stds[m, n, 1] 148 | sigma_sq = ((d - mu) / std) ** 2 149 | _likelihood += w2 / std * math.exp(-0.5 * sigma_sq) 150 | likelihood += w1 * _likelihood 151 | score_array[c] += likelihood * normalize_coeff * score_coeff 152 | 153 | 154 | def __get_distance_mean_std(model_node1, model_node2) -> tuple[float, float]: 155 | """ 156 | model_node1: ModelNode 157 | model_node2: ModelNode 158 | """ 159 | edge = model_node1.neighbor_edge_dict[model_node2] 160 | return edge.distance_mean, edge.distance_std 161 | 162 | 163 | def scoring_matching_pair( 164 | cluster_node_match_list1, 165 | cluster_node_match_list2, 166 | num_conformers: int, 167 | ) -> tuple[float, ...]: 168 | """ 169 | cluster_node_match_list1: List[tuple[LigandNode, List[ModelNode], NDArray[np.float32]]], 170 | cluster_node_match_list2: List[tuple[LigandNode, List[ModelNode], NDArray[np.float32]]], 171 | num_conformers: int, 172 | """ 173 | 174 | match_threshold = len(cluster_node_match_list1) * len(cluster_node_match_list2) * (1 - PASS_THRESHOLD) 175 | 176 | match_scores = np.zeros((num_conformers,), dtype=np.float32) 177 | num_fails = np.zeros((num_conformers,), dtype=np.int16) 178 | for ligand_node1, model_node_list1, weights1 in cluster_node_match_list1: 179 | for ligand_node2, model_node_list2, weights2 in cluster_node_match_list2: 180 | ligand_edge = ligand_node1.neighbor_edge_dict[ligand_node2] 181 | distances = ligand_edge.distances 182 | 183 | mean_stds = np.array( 184 | [ 185 | [__get_distance_mean_std(model_node1, model_node2) for model_node2 in model_node_list2] 186 | for model_node1 in model_node_list1 187 | ], 188 | dtype=np.float32, 189 | ) # [M, N, 2] 190 | __numba_run(distances, mean_stds, weights1, weights2, match_scores, num_fails) 191 | if min(num_fails) > match_threshold: 192 | return (-1,) * num_conformers 193 | 194 | return tuple( 195 | float(score) if num_fail <= match_threshold else -1 196 | for score, num_fail in zip(match_scores, num_fails, strict=True) 197 | ) 198 | 199 | 200 | def scoring_matching_self( 201 | cluster_node_match_list, 202 | num_conformers: int, 203 | ) -> tuple[float, ...]: 204 | """ 205 | cluster_node_match_list: List[tuple[LigandNode, List[ModelNode], NDArray[np.float32]]], 206 | num_conformers: int, 207 | """ 208 | match_scores = np.zeros((num_conformers,), dtype=np.float32) 209 | for match1, match2 in itertools.combinations(cluster_node_match_list, 2): 210 | ligand_node1, model_node_list1, weights1 = match1 211 | ligand_node2, model_node_list2, weights2 = match2 212 | 213 | ligand_edge = ligand_node1.neighbor_edge_dict[ligand_node2] 214 | distances = ligand_edge.distances 215 | 216 | mean_stds = np.array( 217 | [ 218 | [__get_distance_mean_std(model_node1, model_node2) for model_node2 in model_node_list2] 219 | for model_node1 in model_node_list1 220 | ], 221 | dtype=np.float32, 222 | ) # [M, N, 2] 223 | __numba_run_self( 224 | distances, 225 | mean_stds, 226 | weights1, 227 | weights2, 228 | match_scores, 229 | ) 230 | 231 | return tuple(match_scores.tolist()) 232 | -------------------------------------------------------------------------------- /src/pmnet/scoring/tree.py: -------------------------------------------------------------------------------- 1 | from 
__future__ import annotations
2 | 
3 | from collections.abc import Iterator
4 | from typing import TYPE_CHECKING
5 | 
6 | if TYPE_CHECKING:
7 |     from pmnet.pharmacophore_model import ModelNodeCluster
8 | 
9 |     from .ligand import LigandNodeCluster
10 | 
11 | LigandClusterPair = tuple[LigandNodeCluster, LigandNodeCluster]
12 | ModelClusterPair = tuple[ModelNodeCluster, ModelNodeCluster]
13 | 
14 | 
15 | class ClusterMatchTree:
16 |     def __init__(
17 |         self,
18 |         model_cluster: ModelNodeCluster | None,
19 |         pair_scores: dict[int, float] | None,
20 |         parent: ClusterMatchTree,
21 |     ):
22 |         self.level: int = parent.level + 1
23 |         self.num_matches: int = parent.num_matches + (model_cluster is not None)
24 |         self.parent: ClusterMatchTree = parent
25 |         self.root: ClusterMatchTreeRoot = parent.root
26 |         self.children: list[ClusterMatchTree] = []
27 | 
28 |         ligand_cluster: LigandNodeCluster = self.root.ligand_cluster_list[self.level]
29 |         self.ligand_cluster: LigandNodeCluster = ligand_cluster
30 |         self.model_cluster: ModelNodeCluster | None = model_cluster
31 | 
32 |         self.pair_scores: dict[int, float]
33 |         if model_cluster is not None:
34 |             assert pair_scores is not None
35 |             self_pair_scores = self.root.matching_pair_scores_dict[ligand_cluster, ligand_cluster][
36 |                 model_cluster, model_cluster
37 |             ]
38 |             self.pair_scores = {
39 |                 conformer_id: parent.pair_scores[conformer_id] + self_pair_scores[conformer_id] + score
40 |                 for conformer_id, score in pair_scores.items()
41 |             }
42 |         else:
43 |             self.pair_scores = parent.pair_scores
44 | 
45 |     @property
46 |     def max_score(self) -> float:
47 |         if self.num_matches == 0:
48 |             return 0.0
49 |         return max(self.pair_scores.values())
50 | 
51 |     @property
52 |     def conformer_ids(self):
53 |         return self.pair_scores.keys()
54 | 
55 |     def dfs_run(
56 |         self,
57 |         match_dict: dict[LigandNodeCluster, dict[ModelNodeCluster, dict[int, float]]],
58 |     ) -> int:
59 |         """Recursive tree-expansion step (one level per ligand cluster).
60 | 
61 |         Args:
62 |             match_dict: candidate model clusters for each remaining ligand cluster,
63 |                 keyed as {LigandNodeCluster: {ModelNodeCluster:
64 |                 {conformer_id: accumulated score}}}
65 |         Returns: the maximum number of matches achievable in the subtree below this node
66 |         """
67 |         upd_match_dict: dict[LigandNodeCluster, dict[ModelNodeCluster, dict[int, float]]] = {}
68 |         if self.model_cluster is not None:
69 |             for ligand_cluster, model_cluster_dict in match_dict.items():
70 |                 upd_model_cluster_dict = {}
71 |                 matching_pair_scores_dict = self.root.matching_pair_scores_dict[self.ligand_cluster, ligand_cluster]
72 |                 for (
73 |                     model_cluster,
74 |                     conformer_pair_score_dict,
75 |                 ) in model_cluster_dict.items():
76 |                     pair_score_list: tuple[float, ...] = matching_pair_scores_dict[self.model_cluster, model_cluster]
77 |                     # NOTE: Update Model Cluster list according to Validity of Pair (Use only Valid Conformer)
78 |                     upd_conformer_pair_score_dict: dict[int, float] = {
79 |                         conformer_id: total_score + pair_score_list[conformer_id]
80 |                         for conformer_id, total_score in conformer_pair_score_dict.items()
81 |                         if conformer_id in self.conformer_ids and pair_score_list[conformer_id] > 0
82 |                     }
83 |                     if len(upd_conformer_pair_score_dict) > 0:
84 |                         upd_model_cluster_dict[model_cluster] = upd_conformer_pair_score_dict
85 |                 upd_match_dict[ligand_cluster] = upd_model_cluster_dict
86 |         else:
87 |             upd_match_dict = match_dict.copy()
88 | 
89 |         # NOTE: Add Child
90 |         if self.level < len(self.root.ligand_cluster_list) - 1:
91 |             child_ligand_cluster = self.root.ligand_cluster_list[self.level + 1]
92 |             model_cluster_dict = upd_match_dict.pop(child_ligand_cluster)
93 |             max_num_matches = 0
94 |             for model_cluster, conformer_pair_score_dict in model_cluster_dict.items():
95 |                 child = self.add_child(model_cluster, conformer_pair_score_dict)
96 |                 num_matches = child.dfs_run(upd_match_dict)
97 |                 max_num_matches = max(num_matches, max_num_matches)
98 |             if len(model_cluster_dict) == 0 or (self.num_matches + max_num_matches) < 5:
99 |                 child = self.add_child(None, None)
100 |                 num_matches = child.dfs_run(upd_match_dict)
101 |                 max_num_matches = max(num_matches, max_num_matches)
102 |             return max_num_matches + int(self.model_cluster is not None)
103 |         else:
104 |             return int(self.model_cluster is not None)
105 | 
106 |     def add_child(
107 |         self,
108 |         model_cluster: ModelNodeCluster | None,
109 |         pair_score_dict: dict[int, float] | None,
110 |     ):
111 |         child = ClusterMatchTree(model_cluster, pair_score_dict, self)
112 |         self.children.append(child)
113 |         return child
114 | 
115 |     def delete(self):
116 |         assert self.level >= 0
117 |         self.parent.children.remove(self)
118 |         del self
119 | 
120 |     @property
121 |     def size(self) -> int:
122 |         if len(self.children) == 0:
123 |             return 1
124 |         size = 0
125 |         for node in self.children:
126 |             size += node.size
127 |         return size
128 | 
129 |     @property
130 |     def key(self) -> list[ModelNodeCluster | None]:
131 |         key = []
132 |         node: ClusterMatchTree = self
133 |         while node is not self.root:
134 |             key.append(node.model_cluster)
135 |             node = node.parent
136 |         key.reverse()
137 |         return key
138 | 
139 |     @property
140 |     def item(self) -> dict[LigandNodeCluster, ModelNodeCluster | None]:
141 |         node: ClusterMatchTree = self
142 |         graph_match: dict[LigandNodeCluster, ModelNodeCluster | None] = {}
143 |         while node is not self.root:
144 |             graph_match[node.ligand_cluster] = node.model_cluster
145 |             node = node.parent
146 |         return graph_match
147 | 
148 |     def iteration(self, level: int | None = None) -> Iterator[ClusterMatchTree]:
149 |         if level is not None:
150 |             yield from self.iteration_level(level)
151 |         else:
152 |             yield from self.iteration_leaf()
153 | 
154 |     def iteration_level(self, level: int) -> Iterator[ClusterMatchTree]:
155 |         assert self.level <= level
156 |         if self.level < level:
157 |             for node in self.children:
158 |                 yield from node.iteration_level(level)
159 |         elif self.level == level:
160 |             yield self
161 | 
162 |     def iteration_leaf(self) -> Iterator[ClusterMatchTree]:
163 |         if len(self.children) > 0:
164 |             for node in self.children:
165 |                 yield from node.iteration_leaf()
166 |         else:
167 |             yield self
168 | 
169 |     def __tree_repr__(self):
170 |         repr = " " * (self.level + 1) + f"- {self.model_cluster}"
171 |         if len(self.children) > 0:
172 |             repr +=
" " * (self.level + 2) 173 | repr += f"(level {self.level + 1}) {self.children[0].ligand_cluster}\n" 174 | for child in self.children: 175 | repr += child.__tree_repr__() 176 | repr += "\n" 177 | return repr 178 | 179 | def __repr__(self): 180 | repr = "" 181 | tree = self 182 | while tree is not self.root: 183 | repr = f"({tree.ligand_cluster}, {tree.model_cluster})\n" + repr 184 | tree = tree.parent 185 | return repr 186 | 187 | 188 | class ClusterMatchTreeRoot(ClusterMatchTree): 189 | def __init__( 190 | self, 191 | ligand_cluster_list: list[LigandNodeCluster], 192 | cluster_match_dict: dict[LigandNodeCluster, list[ModelNodeCluster]], 193 | matching_pair_score_dict: dict[LigandClusterPair, dict[ModelClusterPair, tuple[float, ...]]], 194 | num_conformers: int, 195 | ): 196 | self.root = self 197 | self.level: int = -1 198 | self.num_matches: int = 0 199 | self.num_conformers: int = num_conformers 200 | self.children: list[ClusterMatchTree] = [] 201 | self.ligand_cluster_list: list[LigandNodeCluster] = ligand_cluster_list 202 | self.cluster_match_dict: dict[LigandNodeCluster, list[ModelNodeCluster]] = cluster_match_dict 203 | self.matching_pair_scores_dict: dict[ 204 | tuple[LigandNodeCluster, LigandNodeCluster], 205 | dict[tuple[ModelNodeCluster, ModelNodeCluster], tuple[float, ...]], 206 | ] = matching_pair_score_dict 207 | 208 | self.model_cluster = None 209 | self.pair_scores: dict[int, float] = {conformer_id: 0.0 for conformer_id in range(num_conformers)} 210 | 211 | def __repr__(self): 212 | repr = "Root\n" 213 | if len(self.children) > 0: 214 | repr += f"(level {self.level + 1}) {self.children[0].ligand_cluster}\n" 215 | for child in self.children: 216 | repr += child.__tree_repr__() 217 | return repr 218 | 219 | def run(self): 220 | match_dict = { 221 | ligand_cluster: { 222 | model_cluster: {conformer_id: 0.0 for conformer_id in range(self.num_conformers)} 223 | for model_cluster in self.cluster_match_dict[ligand_cluster] 224 | } 225 | for ligand_cluster in self.ligand_cluster_list 226 | } 227 | self.dfs_run(match_dict) 228 | -------------------------------------------------------------------------------- /src/pmnet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeonghwanSeo/PharmacoNet/6595694cdf910b52c9fa0512c35f8139f5be2cf5/src/pmnet/utils/__init__.py -------------------------------------------------------------------------------- /src/pmnet/utils/download_weight.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | 5 | def download_pretrained_model(weight_path, verbose): 6 | if not os.path.exists(weight_path): 7 | weight_path = Path(weight_path) 8 | weight_path.parent.mkdir(exist_ok=True) 9 | try: 10 | import gdown 11 | except ImportError: 12 | import subprocess 13 | import sys 14 | 15 | subprocess.check_call([sys.executable, "-m", "pip", "install", "gdown"]) 16 | import gdown 17 | if verbose: 18 | print(f"Download pre-trained model... 
(path: {weight_path})")
19 |         gdown.download(
20 |             "https://drive.google.com/uc?id=1gzjdM7bD3jPm23LBcDXtkSk18nETL04p",
21 |             str(weight_path),
22 |             quiet=False,
23 |         )
24 |         if verbose:
25 |             print("Download pre-trained model finished")
26 | 
--------------------------------------------------------------------------------
/src/pmnet/utils/smoothing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | 
5 | 
6 | def clip(value, lower, upper):
7 |     return lower if value < lower else upper if value > upper else value
8 | 
9 | 
10 | def to_3tuple(value):
11 |     if not isinstance(value, tuple):
12 |         return (value, value, value)
13 |     else:
14 |         return value
15 | 
16 | 
17 | class GaussianSmoothing(nn.Module):
18 |     def __init__(
19 |         self,
20 |         kernel_size: int | tuple[int, int, int],
21 |         sigma: float | tuple[float, float, float],
22 |     ):
23 |         super().__init__()
24 |         kernel_size = to_3tuple(kernel_size)
25 |         sigma = to_3tuple(sigma)
26 | 
27 |         # The gaussian kernel is the product of the
28 |         # gaussian function of each dimension.
29 |         meshgrids = torch.meshgrid(
30 |             [torch.arange(size, dtype=torch.float) for size in kernel_size],
31 |             indexing="ij",
32 |         )
33 |         kernel: torch.Tensor | None = None
34 |         for size, std, mgrid in zip(kernel_size, sigma, meshgrids, strict=False):
35 |             mean = (size - 1) / 2
36 |             # _kernel = 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
37 |             _kernel = torch.exp(-(((mgrid - mean) / (std)) ** 2) / 2)  # omit constant part
38 |             if kernel is None:
39 |                 kernel = _kernel
40 |             else:
41 |                 kernel *= _kernel
42 | 
43 |         # Make sure sum of values in gaussian kernel equals 1.
44 |         kernel /= torch.sum(kernel)  # (Kd, Kh, Kw), Kd, Kh, Kw: kernel_size
45 | 
46 |         # Reshape to depthwise convolutional weight
47 |         kernel = kernel.view(1, 1, *kernel.size())  # (1, 1, Kd, Kh, Kw)
48 |         kernel = kernel.repeat(1, 1, 1, 1, 1)  # (1, 1, Kd, Kh, Kw)
49 | 
50 |         self.register_buffer("weight", kernel)
51 |         self.pad: tuple[int, int, int, int, int, int] = (
52 |             kernel_size[0] // 2,
53 |             kernel_size[0] // 2,
54 |             kernel_size[1] // 2,
55 |             kernel_size[1] // 2,
56 |             kernel_size[2] // 2,
57 |             kernel_size[2] // 2,
58 |         )
59 | 
60 |     @torch.no_grad()
61 |     def forward(self, x):
62 |         """
63 |         Apply gaussian filter to input.
64 |         Arguments:
65 |             x (torch.Tensor): Input to apply gaussian filter on.
66 |         Returns:
67 |             filtered (torch.Tensor): Filtered output.
68 |         """
69 |         x = F.pad(x, self.pad, mode="constant", value=0.0)
70 |         weight = self.weight.repeat(x.shape[-4], 1, 1, 1, 1)
71 |         return F.conv3d(x, weight=weight, groups=x.shape[-4])
72 | 
--------------------------------------------------------------------------------
/src/pmnet_appl/README.md:
--------------------------------------------------------------------------------
1 | # Pre-trained Docking Proxy Models
2 | 
3 | Easy-to-use docking score prediction models.
4 | 
5 | Implementation List:
6 | 
7 | - TacoGFN: Target-conditioned GFlowNet for Structure-based Drug Design [[paper](https://arxiv.org/abs/2310.03223)]
8 | 
9 | If you use this implementation, please cite PharmacoNet and the related papers.
10 | 
11 | ## Install
12 | 
13 | To use the pre-trained proxy models, you need to install torch-geometric and its associated libraries.
14 | You can install them with one of the following commands at the root directory:
15 | 
16 | ```bash
17 | # Case 1. Install both PharmacoNet and torch-geometric
18 | pip install -e '.[appl]' --find-links https://data.pyg.org/whl/torch-2.3.1+cu121.html
19 | # Case 2. Install only PharmacoNet (torch-geometric is already installed)
20 | pip install -e .
21 | # Case 3. In your project (torch-geometric is already installed)
22 | pip install 'pharmaconet @ git+https://github.com/SeonghwanSeo/PharmacoNet.git'
23 | ```
24 | 
25 | ## Load Pretrained Model
26 | 
27 | ```python
28 | from pmnet_appl import get_docking_proxy
29 | from pmnet_appl.tacogfn_reward import TacoGFN_Proxy
30 | 
31 | device: str | torch.device = "cuda" | "cpu"
32 | 
33 | # Cache for CrossDocked2020 Targets: 15,201 training pockets + 100 test pockets
34 | cache_db = "train" | "test" | "all" | None
35 | 
36 | # TacoGFN Reward Function
37 | train_dataset = "ZINCDock15M" | "CrossDocked2020"
38 | proxy: TacoGFN_Proxy = get_docking_proxy("TacoGFN_Reward", "QVina", train_dataset, cache_db, device)
39 | proxy = TacoGFN_Proxy.load("QVina", train_dataset, cache_db, device)
40 | 
41 | # if cache_db is 'test' | 'all'
42 | print(proxy.scoring("14gs_A", "c1ccccc1"))
43 | print(proxy.scoring_list("14gs_A", ["c1ccccc1", "C1CCCCC1"]))
44 | ```
45 | 
46 | ## Use custom target cache
47 | 
48 | ```python
49 | proxy = get_docking_proxy("TacoGFN_Reward", "QVina", "ZINCDock15M", None, device)
50 | save_cache_path = "<custom_cache_path>"
51 | protein_info_dict = {
52 |     "key1": ("<protein_path1>", "<ref_ligand_path1>"),  # use center of reference ligand
53 |     "key2": ("<protein_path2>", (1.0, 2.0, 3.0)),  # use center coordinates
54 | }
55 | proxy.get_cache_database(protein_info_dict, save_cache_path, verbose=False)
56 | 
57 | # Load Custom Target Cache
58 | proxy = get_docking_proxy("TacoGFN_Reward", "QVina", "ZINCDock15M", save_cache_path, device)
59 | proxy.scoring("key1", "c1ccccc1")
60 | proxy.scoring_list("key2", ["c1ccccc1", "C1CCCCC1"])
61 | ```
--------------------------------------------------------------------------------
/src/pmnet_appl/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright: if you use this script, please cite:
3 | ```
4 | @article{seo2023pharmaconet,
5 |     title = {PharmacoNet: Accelerating Large-Scale Virtual Screening by Deep Pharmacophore Modeling},
6 |     author = {Seo, Seonghwan and Kim, Woo Youn},
7 |     journal = {arXiv preprint arXiv:2310.00681},
8 |     year = {2023},
9 |     url = {https://arxiv.org/abs/2310.00681},
10 | }
11 | ```
12 | """
13 | 
14 | from __future__ import annotations
15 | 
16 | from pathlib import Path
17 | 
18 | import torch
19 | 
20 | from pmnet_appl.base import BaseProxy
21 | 
22 | ALLOWED_MODEL_LIST = ["TacoGFN_Reward", "SBDDReward"]
23 | ALLOWED_DOCKING_LIST = ["QVina", "UniDock_Vina"]
24 | 
25 | 
26 | def get_docking_proxy(
27 |     model: str,
28 |     docking: str,
29 |     train_dataset: str,
30 |     db: str | Path | None,
31 |     device: str | torch.device,
32 | ) -> BaseProxy:
33 |     """Get Docking Proxy Model
34 | 
35 |     Parameters
36 |     ----------
37 |     model : str
38 |         Model name (Currently: ['TacoGFN_Reward', 'SBDDReward'])
39 |     docking : str
40 |         Docking program name
41 |     train_dataset : str
42 |         Dataset for model training
43 |     db : Path | str | None
44 |         cache database path ('train' | 'test' | 'all' | custom cache database path)
45 |         - 'train': CrossDocked2020 training pockets (15,201)
46 |         - 'test': CrossDocked2020 test pockets (100)
47 |         - 'all': train + test
48 |     device : str | torch.device
49 |         cuda | cpu
50 | 
51 |     Returns
52 |     -------
53 |     Proxy Model: BaseProxy
54 |     """
55 | 
56 |     assert model in ("TacoGFN_Reward", "SBDDReward"), f"model({model}) is not
allowed" 57 | if model == "TacoGFN_Reward": 58 | from pmnet_appl.tacogfn_reward import TacoGFN_Proxy 59 | 60 | assert docking in ["QVina"] 61 | assert train_dataset in ["ZINCDock15M", "CrossDocked2020"] 62 | return TacoGFN_Proxy.load(docking, train_dataset, db, device) 63 | elif model == "SBDDReward": 64 | from pmnet_appl.sbddreward import SBDDReward_Proxy 65 | 66 | assert docking in ["UniDock_Vina"] 67 | assert train_dataset in ["ZINC"] 68 | return SBDDReward_Proxy.load(docking, train_dataset, db, device) 69 | else: 70 | raise ValueError(docking) 71 | -------------------------------------------------------------------------------- /src/pmnet_appl/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .proxy import BaseProxy 2 | -------------------------------------------------------------------------------- /src/pmnet_appl/base/proxy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright: if you use this script, please cite: 3 | ``` 4 | @article{seo2023pharmaconet, 5 | title = {PharmacoNet: Accelerating Large-Scale Virtual Screening by Deep Pharmacophore Modeling}, 6 | author = {Seo, Seonghwan and Kim, Woo Youn}, 7 | journal = {arXiv preprint arXiv:2310.00681}, 8 | year = {2023}, 9 | url = {https://arxiv.org/abs/2310.00681}, 10 | } 11 | ``` 12 | """ 13 | 14 | from __future__ import annotations 15 | 16 | from pathlib import Path 17 | from typing import Any 18 | 19 | import gdown 20 | import torch 21 | import torch.nn as nn 22 | import tqdm 23 | from numpy.typing import NDArray 24 | from torch import Tensor 25 | 26 | from pmnet.api import PharmacoNet, get_pmnet_dev 27 | from pmnet.api.typing import HotspotInfo, MultiScaleFeature 28 | 29 | Cache = Any 30 | 31 | 32 | class BaseProxy(nn.Module): 33 | root_dir: Path = Path(__file__).parent 34 | cache_gdrive_link: dict[tuple[str, str], str] = {} 35 | model_gdrive_link: dict[str, str] = {} 36 | 37 | def __init__( 38 | self, 39 | ckpt_path: str | Path | None = None, 40 | device: str | torch.device = "cuda", 41 | compile_pmnet: bool = False, 42 | ): 43 | super().__init__() 44 | self.pmnet = None # NOTE: Lazy 45 | self.ckpt_path: str | Path | None = ckpt_path 46 | self._cache = {} 47 | self._setup_model() 48 | self.eval() 49 | self.to(device) 50 | if self.ckpt_path is not None: 51 | self._load_checkpoint(self.ckpt_path) 52 | self.compile_pmnet: bool = compile_pmnet 53 | 54 | # NOTE: Implement Here! 
55 |     def _setup_model(self):
56 |         pass
57 | 
58 |     def _load_checkpoint(self, ckpt_path: str | Path):
59 |         self.load_state_dict(torch.load(ckpt_path, self.device))
60 | 
61 |     @torch.no_grad()
62 |     def _scoring_list(self, cache: Cache, smiles_list: list[str]) -> Tensor:
63 |         raise NotImplementedError
64 | 
65 |     def _get_cache(self, pmnet_attr: tuple[MultiScaleFeature, list[HotspotInfo]]) -> Cache:
66 |         raise NotImplementedError
67 | 
68 |     @classmethod
69 |     def _download_model(cls, suffix: str):
70 |         weight_dir = cls.root_dir / "weights"
71 |         weight_dir.mkdir(parents=True, exist_ok=True)
72 |         model_path = weight_dir / f"model-{suffix}.pth"
73 |         if not model_path.exists():
74 |             id = cls.model_gdrive_link[suffix]
75 |             gdown.download(f"https://drive.google.com/uc?id={id}", str(model_path))
76 | 
77 |     @classmethod
78 |     def _download_cache(cls, suffix: str, label: str):
79 |         weight_dir = cls.root_dir / "weights"
80 |         cache_path = weight_dir / f"cache-{label}-{suffix}.pt"
81 |         if not cache_path.exists():
82 |             id = cls.cache_gdrive_link[(suffix, label)]
83 |             gdown.download(f"https://drive.google.com/uc?id={id}", str(cache_path))
84 | 
85 |     # NOTE: Python Method
86 |     @classmethod
87 |     def load(
88 |         cls,
89 |         docking: str,
90 |         train_dataset: str,
91 |         db: Path | str | None,
92 |         device: str | torch.device = "cpu",
93 |     ):
94 |         """Load Pretrained Proxy Model
95 | 
96 |         Parameters
97 |         ----------
98 |         docking : str
99 |             docking program name
100 |         train_dataset : str
101 |             training dataset name
102 |         db : Path | str | None
103 |             cache database path ('train' | 'test' | 'all' | custom cache database path)
104 |             - 'train': CrossDocked2020 training pockets (15,201)
105 |             - 'test': CrossDocked2020 test pockets (100)
106 |             - 'all': train + test
107 |         device : str | torch.device
108 |             cuda | cpu
109 |         """
110 |         weight_dir = cls.root_dir / "weights"
111 |         suffix = f"{docking}-{train_dataset}"
112 |         ckpt_path = weight_dir / f"model-{suffix}.pth"
113 |         cls._download_model(suffix)
114 | 
115 |         train_cache_path = weight_dir / f"cache-train-{suffix}.pt"
116 |         test_cache_path = weight_dir / f"cache-test-{suffix}.pt"
117 |         if db is None:
118 |             cache_dict = {}
119 |         elif db == "all":
120 |             cls._download_cache(suffix, "train")
121 |             cls._download_cache(suffix, "test")
122 |             cache_dict = torch.load(train_cache_path, "cpu") | torch.load(test_cache_path, "cpu")
123 |         elif db == "train":
124 |             cls._download_cache(suffix, "train")
125 |             cache_dict = torch.load(train_cache_path, "cpu")
126 |         elif db == "test":
127 |             cls._download_cache(suffix, "test")
128 |             cache_dict = torch.load(test_cache_path, "cpu")
129 |         else:
130 |             cache_dict = torch.load(db, "cpu")
131 | 
132 |         model = cls(ckpt_path, device)
133 |         model.update_cache(cache_dict)
134 |         return model
135 | 
136 |     def scoring(self, target: str, smiles: str) -> Tensor:
137 |         """Score a single molecule by its SMILES
138 | 
139 |         Parameters
140 |         ----------
141 |         target : str
142 |             target key
143 |         smiles : str
144 |             molecule smiles
145 | 
146 |         Returns
147 |         -------
148 |         Tensor [1,]
149 |             Estimated Docking Score (or Sigma)
150 | 
151 |         """
152 |         return self._scoring_list(self._cache[target], [smiles])
153 | 
154 |     def scoring_list(self, target: str, smiles_list: list[str]) -> Tensor:
155 |         """Score multiple molecules by their SMILES
156 | 
157 |         Parameters
158 |         ----------
159 |         target : str
160 |             target key
161 |         smiles_list : list[str]
162 |             molecule smiles list
163 | 
164 |         Returns
165 |         -------
166 |         Tensor [N,]
167 |             Estimated Docking Scores (or Sigma)
168 | 
169 |         """
170 |
return self._scoring_list(self._cache[target], smiles_list) 171 | 172 | def put_cache(self, key: str, cache: Cache): 173 | """Add Cache 174 | 175 | Parameters 176 | ---------- 177 | key : str 178 | Pocket Key 179 | cache : Cache 180 | Pocket Feature Cache 181 | """ 182 | self._cache[key] = cache 183 | 184 | def update_cache(self, cache_dict: dict[str, Cache]): 185 | """Add Multiple Cache 186 | 187 | Parameters 188 | ---------- 189 | cache_dict : dict[str, Cache] 190 | Pocket Key - Cache Dictionary 191 | """ 192 | self._cache.update(cache_dict) 193 | 194 | def get_cache_database( 195 | self, 196 | pocket_info: dict[str, tuple[str | Path, str | Path | tuple[float, float, float] | NDArray]], 197 | save_path: str | Path | None = None, 198 | verbose: bool = True, 199 | ) -> dict[str, Cache]: 200 | """Get Cache Database 201 | 202 | Parameters 203 | ---------- 204 | pocket_info : dict[str, tuple[str | Path, str | Path | tuple[float, float, float] | NDArray]] 205 | Key: Pocket Identification Key 206 | Item: (Protein Path, Pocket Center Information(ref_ligand_path or center coordinates)) 207 | - Protein Path: str | Path 208 | - Pocket Center Information: str | Path | tuple[float, float, float] | NDArray] 209 | - if str | Path: ref_ligand_path 210 | - if tuple[float, float, float] | NDArray: center coordinates 211 | 212 | save_path: str | Path | None (default = None) 213 | if save_path is not None, save database at the input path. 214 | 215 | verbose: bool (default=True) 216 | if True, use tqdm. 217 | 218 | Returns 219 | ------- 220 | cache_dict: dict[str, Cache] 221 | Cache Database 222 | """ 223 | cache_dict: dict[str, Cache] = {} 224 | for key, (protein_pdb_path, pocket_center) in tqdm.tqdm(pocket_info.items(), disable=not (verbose)): 225 | try: 226 | if isinstance(pocket_center, str | Path): 227 | cache = self.get_cache(protein_pdb_path, ref_ligand_path=pocket_center) 228 | else: 229 | cache = self.get_cache(protein_pdb_path, center=pocket_center) 230 | except Exception as e: 231 | print(key, e) 232 | else: 233 | cache_dict[key] = cache 234 | if save_path is not None: 235 | torch.save(cache_dict, save_path) 236 | return cache_dict 237 | 238 | @torch.no_grad() 239 | def get_cache( 240 | self, 241 | protein_pdb_path: str | Path, 242 | ref_ligand_path: str | Path | None = None, 243 | center: tuple[float, float, float] | NDArray | None = None, 244 | ) -> Cache: 245 | """Calculate Cache 246 | 247 | Parameters 248 | ---------- 249 | protein_pdb_path : str | Path 250 | Protein PDB Path 251 | ref_ligand_path : str | Path | None 252 | Reference Ligand Path (None if center is not None) 253 | center : tuple[float, float, float] | NDArray | None 254 | Pocket Center Coordinates (None if ref_ligand_path is not None) 255 | 256 | Returns 257 | ------- 258 | cache: Cache 259 | Pocket Information Cache (device: 'cpu') 260 | 261 | """ 262 | self.setup_pmnet() 263 | assert self.pmnet is not None 264 | pmnet_attr = self.pmnet.feature_extraction(protein_pdb_path, ref_ligand_path, center) 265 | cache = self._get_cache(pmnet_attr) 266 | cache = [v.cpu() if isinstance(v, Tensor) else v for v in cache] 267 | return cache 268 | 269 | def setup_pmnet(self): 270 | # NOTE: Lazy Load 271 | if self.pmnet is None: 272 | self.pmnet: PharmacoNet | None = get_pmnet_dev(self.device, compile=self.compile_pmnet) 273 | if self.pmnet.device != self.device: 274 | self.pmnet.to(self.device) 275 | 276 | @property 277 | def device(self) -> torch.device: 278 | return next(self.parameters()).device 279 | 
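
A minimal end-to-end sketch of the `BaseProxy` cache workflow defined above, assembled from the docstrings in this file: load a proxy without a cache database, build a cache for one pocket, register it under a key, and score molecules against it. The file paths and the pocket key are hypothetical placeholders, not files shipped with the repository.

```python
import torch

from pmnet_appl import get_docking_proxy

device = "cuda" if torch.cuda.is_available() else "cpu"
# db=None: start with an empty cache dictionary.
proxy = get_docking_proxy("TacoGFN_Reward", "QVina", "ZINCDock15M", None, device)

# Build the pocket cache once (PharmacoNet itself is lazily loaded inside
# get_cache), register it, then score any number of SMILES against that key.
# NOTE: "./protein.pdb", "./ref_ligand.sdf", and "my_pocket" are placeholders.
cache = proxy.get_cache("./protein.pdb", ref_ligand_path="./ref_ligand.sdf")
proxy.put_cache("my_pocket", cache)
scores = proxy.scoring_list("my_pocket", ["c1ccccc1", "C1CCCCC1"])  # Tensor [2,]
```

Caching the pocket features this way is what makes repeated scoring cheap: the expensive PharmacoNet feature extraction runs once per pocket, while each subsequent call only evaluates the lightweight proxy head.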
-------------------------------------------------------------------------------- /src/pmnet_appl/keys/test.txt: -------------------------------------------------------------------------------- 1 | 2z3h_A 2 | 4aaw_A 3 | 4yhj_A 4 | 14gs_A 5 | 2v3r_A 6 | 4rn0_B 7 | 1fmc_B 8 | 3daf_A 9 | 1a2g_A 10 | 5w2g_A 11 | 3dzh_A 12 | 3g51_A 13 | 1coy_A 14 | 2jjg_A 15 | 2rhy_A 16 | 2pqw_A 17 | 4g3d_B 18 | 5bur_A 19 | 3gs6_A 20 | 1r1h_A 21 | 1dxo_C 22 | 1gg5_A 23 | 5q0k_A 24 | 5b08_A 25 | 2azy_A 26 | 5i0b_A 27 | 1phk_A 28 | 4keu_A 29 | 4q8b_B 30 | 1djy_A 31 | 5l1v_A 32 | 4zfa_A 33 | 2rma_A 34 | 3b6h_A 35 | 2zen_A 36 | 4p6p_A 37 | 3u5y_B 38 | 4f1m_A 39 | 4tqr_A 40 | 4lfu_A 41 | 3jyh_A 42 | 4iwq_A 43 | 1l3l_A 44 | 5ngz_A 45 | 1e8h_A 46 | 2e24_A 47 | 2hcj_B 48 | 3kc1_A 49 | 1d7j_A 50 | 4ja8_B 51 | 4u5s_A 52 | 4iiy_A 53 | 3v4t_A 54 | 3tym_A 55 | 4d7o_A 56 | 3ej8_A 57 | 1rs9_A 58 | 4kcq_A 59 | 3pdh_A 60 | 1umd_B 61 | 4pxz_A 62 | 2gns_A 63 | 1ai4_A 64 | 5mma_A 65 | 2cy0_A 66 | 3w83_B 67 | 2e6d_A 68 | 4rv4_A 69 | 5d7n_D 70 | 5mgl_A 71 | 1h36_A 72 | 4gvd_A 73 | 4tos_A 74 | 5aeh_A 75 | 4h3c_A 76 | 4rlu_A 77 | 4xli_B 78 | 3l3n_A 79 | 5tjn_A 80 | 5liu_X 81 | 3o96_A 82 | 4qlk_A 83 | 3hy9_B 84 | 4bel_A 85 | 3nfb_A 86 | 4m7t_A 87 | 3u9f_C 88 | 4aua_A 89 | 2f2c_B 90 | 3chc_B 91 | 1k9t_A 92 | 1h0i_A 93 | 4z2g_A 94 | 3af2_A 95 | 1jn2_P 96 | 3li4_A 97 | 3pnm_A 98 | 1afs_A 99 | 4azf_A 100 | 2pc8_A 101 | -------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/__init__.py: -------------------------------------------------------------------------------- 1 | from .proxy import SBDDReward_Proxy 2 | -------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from openbabel import pybel 3 | from openbabel.pybel import ob 4 | from torch_geometric.data import Data as Data 5 | 6 | ob_log_handler = pybel.ob.OBMessageHandler() 7 | ob_log_handler.SetOutputLevel(0) # 0: None 8 | 9 | 10 | ATOM_DICT = { 11 | 6: 0, # C 12 | 7: 1, # N 13 | 8: 2, # O 14 | 9: 3, # F 15 | 15: 4, # P 16 | 16: 5, # S 17 | 17: 6, # Cl 18 | 35: 7, # Br 19 | 53: 8, # I 20 | -1: 9, # UNKNOWN 21 | } 22 | _NUM_ATOM_TYPES = 10 23 | _NUM_ATOM_CHIRAL = 2 24 | _NUM_ATOM_CHARGE = 2 25 | NUM_ATOM_FEATURES = _NUM_ATOM_TYPES + _NUM_ATOM_CHIRAL + _NUM_ATOM_CHARGE 26 | 27 | BOND_DICT = { 28 | 1: 0, 29 | 2: 1, 30 | 3: 2, 31 | 1.5: 3, # AROMATIC 32 | -1: 4, # UNKNOWN 33 | } 34 | _NUM_BOND_TYPES = 5 35 | NUM_BOND_FEATURES = _NUM_BOND_TYPES 36 | 37 | 38 | def get_atom_features(pbmol: pybel.Molecule) -> list[list[float]]: 39 | facade = pybel.ob.OBStereoFacade(pbmol.OBMol) 40 | features = [] 41 | for atom in pbmol.atoms: 42 | feat = [0] * NUM_ATOM_FEATURES 43 | feat[ATOM_DICT.get(atom.atomicnum, 9)] = 1 44 | 45 | offset = _NUM_ATOM_TYPES 46 | mid = atom.OBAtom.GetId() 47 | if facade.HasTetrahedralStereo(mid): 48 | stereo = facade.GetTetrahedralStereo(mid).GetConfig().winding 49 | if stereo == pybel.ob.OBStereo.Clockwise: 50 | feat[offset + 0] = 1 51 | else: 52 | feat[offset + 1] = 1 53 | 54 | offset += _NUM_ATOM_CHIRAL 55 | charge = atom.formalcharge 56 | if charge > 0: 57 | feat[offset + 0] = 1 58 | elif charge < 0: 59 | feat[offset + 1] = 1 60 | features.append(feat) 61 | return features 62 | 63 | 64 | def get_bond_features( 65 | pbmol: pybel.Molecule, 66 | ) -> tuple[list[list[float]], tuple[list[int], list[int]]]: 67 | edge_index_row = [] 68 | edge_index_col = [] 69 | 
edge_attr = [] 70 | obmol: ob.OBMol = pbmol.OBMol 71 | for obbond in ob.OBMolBondIter(obmol): 72 | obbond: ob.OBBond 73 | edge_index_row.append(obbond.GetBeginAtomIdx() - 1) 74 | edge_index_col.append(obbond.GetEndAtomIdx() - 1) 75 | 76 | feat = [0] * NUM_BOND_FEATURES 77 | if obbond.IsAromatic(): 78 | feat[3] = 1 79 | else: 80 | feat[BOND_DICT.get(obbond.GetBondOrder(), 4)] = 1 81 | edge_attr.append(feat) 82 | edge_index = (edge_index_row, edge_index_col) 83 | return edge_attr, edge_index 84 | 85 | 86 | def smi2graph(smiles: str) -> Data: 87 | pbmol = pybel.readstring("smi", smiles) 88 | atom_features = get_atom_features(pbmol) 89 | edge_attr, edge_index = get_bond_features(pbmol) 90 | return Data( 91 | x=torch.FloatTensor(atom_features), 92 | edge_index=torch.LongTensor(edge_index), 93 | edge_attr=torch.FloatTensor(edge_attr), 94 | ) 95 | -------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/get_cache.py: -------------------------------------------------------------------------------- 1 | from pmnet_appl.sbddreward import SBDDReward_Proxy 2 | 3 | docking = "UniDock_Vina" 4 | train_dataset = "ZINC" 5 | proxy = SBDDReward_Proxy.load(docking, train_dataset, None, "cuda") 6 | 7 | save_database_path = "./tmp_db.pt" 8 | protein_info_dict = { 9 | "key1": ("./tmp1.pdb", "./ref_ligand1.sdf"), # reference ligand path 10 | "key2": ("./tmp2.pdb", (1.0, 2.0, 3.0)), # pocket center 11 | } 12 | 13 | cache_dict = proxy.get_cache_database(protein_info_dict, save_database_path, verbose=False) 14 | proxy.update_cache(cache_dict) 15 | proxy.scoring(list(cache_dict.keys())[0], "c1ccccc1") 16 | proxy.scoring_list(list(cache_dict.keys())[0], ["c1ccccc1", "C1CCCCC1"]) 17 | -------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .head import AffinityHead 2 | from .ligand_encoder import GraphEncoder 3 | from .pharmacophore_encoder import PharmacophoreEncoder 4 | -------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/network/block.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from .layers.pair_transition import PairTransition 7 | from .layers.triangular_attention import TriangleAttention 8 | from .layers.triangular_multiplicative_update import DirectTriangleMultiplicativeUpdate 9 | 10 | 11 | class ComplexFormerBlock(nn.Module): 12 | def __init__( 13 | self, 14 | c_hidden: int, 15 | c_head: int, 16 | n_heads: int, 17 | n_transition: int, 18 | dropout: float, 19 | ): 20 | super().__init__() 21 | self.tri_mul_update = DirectTriangleMultiplicativeUpdate(c_hidden, c_hidden) 22 | self.tri_attention = TriangleAttention(c_hidden, c_head, n_heads) 23 | self.transition = PairTransition(c_hidden, n_transition) 24 | self.dropout = nn.Dropout2d(p=dropout) 25 | 26 | def forward( 27 | self, 28 | z_complex: torch.Tensor, 29 | zpair_protein: torch.Tensor, 30 | mask_complex: torch.Tensor, 31 | ) -> torch.Tensor: 32 | z_complex = z_complex + self.dropout(self.tri_mul_update.forward(z_complex, zpair_protein, mask_complex)) 33 | z_complex = z_complex + self.dropout(self.tri_attention.forward(z_complex, mask_complex)) 34 | z_complex = self.transition.forward(z_complex, mask_complex) 35 | return z_complex 36 | 
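
A shape sketch for `ComplexFormerBlock`, mirroring how `AffinityHead` in the following `head.py` constructs and calls it (`c_head = hidden_dim // 4`, 4 heads, 4-fold transition). All tensor sizes below are arbitrary assumptions chosen for illustration; the mask layout follows `to_dense_batch` usage in `AffinityHead._embedding`.

```python
import torch

from pmnet_appl.sbddreward.network.block import ComplexFormerBlock

hidden_dim = 64  # assumed size for illustration
block = ComplexFormerBlock(c_hidden=hidden_dim, c_head=hidden_dim // 4, n_heads=4, n_transition=4, dropout=0.1)
block.eval()  # disable dropout for a deterministic forward pass

n_ligands, max_ligand_atoms, n_hotspots = 2, 9, 5
z_complex = torch.randn(n_ligands, max_ligand_atoms, n_hotspots, hidden_dim)  # ligand-pocket pair features
zpair_protein = torch.randn(1, n_hotspots, n_hotspots, hidden_dim)  # embedded pocket pairwise distances
mask_complex = torch.ones(n_ligands, max_ligand_atoms, 1)  # valid-atom mask (cf. to_dense_batch in head.py)
out = block(z_complex, zpair_protein, mask_complex)  # same shape as z_complex
```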
-------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/network/head.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | from torch_geometric.utils import to_dense_batch 6 | 7 | from .block import ComplexFormerBlock 8 | from .layers.one_hot import OneHotEncoding 9 | 10 | 11 | class AffinityHead(nn.Module): 12 | def __init__(self, hidden_dim: int, n_blocks: int, p_dropout: float = 0.1): 13 | super().__init__() 14 | # Complex Embedding 15 | self.interaction_mlp: nn.Module = nn.Sequential( 16 | nn.Linear(hidden_dim, hidden_dim), 17 | nn.LeakyReLU(), 18 | ) 19 | self.one_hot = OneHotEncoding(0, 30, 16) 20 | self.protein_pair_embedding = nn.Linear(16, hidden_dim) 21 | self.blocks = nn.ModuleList( 22 | [ComplexFormerBlock(hidden_dim, hidden_dim // 4, 4, 4, 0.1) for _ in range(n_blocks)] 23 | ) 24 | 25 | self.mlp_mu: nn.Module = nn.Sequential( 26 | nn.Linear(hidden_dim, hidden_dim), 27 | nn.LeakyReLU(), 28 | nn.Linear(hidden_dim, 1), 29 | nn.Sigmoid(), 30 | ) 31 | self.mlp_std: nn.Module = nn.Sequential( 32 | nn.Linear(hidden_dim, hidden_dim), 33 | nn.LeakyReLU(), 34 | nn.Linear(hidden_dim, 1), 35 | nn.Sigmoid(), 36 | ) 37 | 38 | self.mlp_sigma_bias: nn.Module = nn.Sequential( 39 | nn.Linear(hidden_dim * 2, hidden_dim), 40 | nn.LeakyReLU(), 41 | nn.Linear(hidden_dim, 1), 42 | ) 43 | self.mlp_sigma: nn.Module = nn.Linear(hidden_dim, 1) 44 | self.gate_sigma: nn.Module = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) 45 | 46 | self.linear_distance = nn.Linear(hidden_dim, 1) 47 | self.dropout = nn.Dropout(p_dropout) 48 | 49 | def scoring( 50 | self, 51 | X_protein: Tensor, 52 | pos_protein: Tensor, 53 | Z_protein: Tensor, 54 | X_ligand: Tensor, 55 | Z_ligand: Tensor, 56 | ligand_batch: Tensor, 57 | return_sigma: bool = False, 58 | ) -> Tensor: 59 | sigma = self.cal_sigma(X_protein, pos_protein, Z_protein, X_ligand, Z_ligand, ligand_batch) 60 | if return_sigma: 61 | return sigma 62 | mu, std = self.cal_mu(Z_protein), self.cal_std(Z_protein) 63 | return sigma * std + mu 64 | 65 | def cal_mu(self, Z_protein) -> torch.Tensor: 66 | return self.mlp_mu(self.dropout(Z_protein)).view(1) * -15 67 | 68 | def cal_std(self, Z_protein) -> torch.Tensor: 69 | return self.mlp_std(self.dropout(Z_protein)).view(1) * 5 70 | 71 | def cal_sigma(self, X_protein, pos_protein, Z_protein, X_ligand, Z_ligand, ligand_batch) -> torch.Tensor: 72 | Z_complex, mask_complex = self._embedding(X_protein, pos_protein, X_ligand, ligand_batch, Z_ligand.shape[0]) 73 | Z_protein, Z_ligand, Z_complex = ( 74 | self.dropout(Z_protein), 75 | self.dropout(Z_ligand), 76 | self.dropout(Z_complex), 77 | ) 78 | z_sigma = self.mlp_sigma(Z_complex) * self.gate_sigma(Z_complex) 79 | sigma = (z_sigma.squeeze(-1) * mask_complex).sum((1, 2)) 80 | bias = self.mlp_sigma_bias(torch.cat([Z_protein.view(1, -1).repeat(Z_ligand.size(0), 1), Z_ligand], dim=-1)) 81 | return sigma.view(-1) + bias.view(-1) 82 | 83 | def _embedding(self, X_protein, pos_protein, X_ligand, ligand_batch, num_ligands) -> tuple[Tensor, Tensor]: 84 | Z_complex = torch.einsum("ik,jk->ijk", X_ligand, X_protein) # [Vlig, Vprot, Fh] 85 | Z_complex = self.interaction_mlp(self.dropout(Z_complex)) 86 | Z_complex, mask_complex = to_dense_batch(Z_complex, ligand_batch, batch_size=num_ligands) 87 | 88 | mask_complex = mask_complex.unsqueeze(-1) # [N, Vlig, 1] 89 | if X_protein.shape[0] > 0: 90 | pdist_protein = 
torch.cdist(pos_protein, pos_protein, compute_mode="donot_use_mm_for_euclid_dist") 91 | pdist_protein = self.one_hot(pdist_protein).unsqueeze(0) 92 | Zpair_protein = self.protein_pair_embedding(pdist_protein.to(device=X_ligand.device, dtype=torch.float)) 93 | Z_complex_init = Z_complex 94 | for block in self.blocks: 95 | Z_complex = block(Z_complex, Zpair_protein, mask_complex) 96 | Z_complex = Z_complex_init + Z_complex 97 | return Z_complex, mask_complex 98 | -------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/network/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeonghwanSeo/PharmacoNet/6595694cdf910b52c9fa0512c35f8139f5be2cf5/src/pmnet_appl/sbddreward/network/layers/__init__.py -------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/network/layers/one_hot.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class OneHotEncoding(nn.Module): 8 | def __init__( 9 | self, 10 | bin_min: float = 0.0, 11 | bin_max: float = 15.0, 12 | num_classes: int = 16, 13 | rounding_mode: str = "floor", 14 | ): 15 | super().__init__() 16 | assert num_classes > 1 17 | self.bin_min: float = bin_min 18 | self.bin_size: int = int((bin_max - bin_min) / (num_classes - 1)) 19 | self.bin_max: float = bin_max + (self.bin_size / 2) # to prevent float error. 20 | self.num_classes: int = num_classes 21 | self.rounding_mode: str = rounding_mode 22 | 23 | def forward(self, x) -> torch.Tensor: 24 | x = x.clip(self.bin_min, self.bin_max) 25 | idx = torch.div(x - self.bin_min, self.bin_size, rounding_mode=self.rounding_mode).long() 26 | out = torch.nn.functional.one_hot(idx, num_classes=self.num_classes).float() 27 | return out 28 | -------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/network/layers/pair_transition.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from torch import nn 4 | 5 | 6 | class PairTransition(nn.Module): 7 | def __init__(self, c_hidden, expand: int = 4): 8 | super().__init__() 9 | self.layer_norm = nn.LayerNorm(c_hidden) 10 | self.linear_1 = nn.Linear(c_hidden, expand * c_hidden) 11 | self.relu = nn.ReLU() 12 | self.linear_2 = nn.Linear(expand * c_hidden, c_hidden) 13 | 14 | def forward(self, z, mask): 15 | z = self.layer_norm(z) 16 | z = self.linear_1(z) 17 | z = self.relu(z) 18 | z = self.linear_2(z) 19 | z = z * mask.unsqueeze(-1) 20 | return z 21 | -------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/network/layers/triangular_attention.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from collections.abc import Sequence 5 | 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class TriangleAttention(nn.Module): 11 | def __init__(self, c_in, c_hidden, num_heads, inf=1e9): 12 | """ 13 | Args: 14 | c_in: 15 | Input channel dimension 16 | c_hidden: 17 | Overall hidden channel dimension (not per-head) 18 | num_heads: 19 | Number of attention heads 20 | """ 21 | super().__init__() 22 | 23 | self.c_in = c_in 24 | self.c_hidden = c_hidden 25 | self.num_heads = num_heads 26 | self.inf = inf 27 | 
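# NOTE (added comment; not in the original source): AffinityHead builds this
# module as TriangleAttention(hidden_dim, hidden_dim // 4, 4), so the c_hidden
# received here is in practice the per-head width that is passed straight
# through to the Attention module below, whose own docstring names it the
# per-head hidden dimension.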
28 | self.layer_norm = nn.LayerNorm(self.c_in) 29 | self.mha = Attention(self.c_in, self.c_in, self.c_in, self.c_hidden, self.num_heads) 30 | 31 | def forward( 32 | self, 33 | x: torch.Tensor, 34 | mask: torch.Tensor | None = None, 35 | ) -> torch.Tensor: 36 | if mask is None: 37 | mask = x.new_ones(x.shape[:-1]) 38 | else: 39 | mask = mask.float() 40 | 41 | # [*, I, J, C_in] 42 | x = self.layer_norm(x) 43 | 44 | biases = [] 45 | 46 | # [*, I, 1, 1, J] 47 | mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] 48 | biases.append(mask_bias) 49 | 50 | # [*, H, I, J] 51 | # triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1)) 52 | 53 | # [*, 1, H, I, J] 54 | # triangle_bias = triangle_bias.unsqueeze(-4) 55 | # biases.append(triangle_bias) 56 | 57 | x = self.mha(q_x=x, kv_x=x, biases=biases) 58 | 59 | return x 60 | 61 | 62 | class Attention(nn.Module): 63 | """ 64 | Standard multi-head attention using AlphaFold's default layer 65 | initialization. Allows multiple bias vectors. 66 | """ 67 | 68 | def __init__( 69 | self, 70 | c_q: int, 71 | c_k: int, 72 | c_v: int, 73 | c_hidden: int, 74 | num_heads: int, 75 | gating: bool = True, 76 | ): 77 | """ 78 | Args: 79 | c_q: Input dimension of query data 80 | c_k: Input dimension of key data 81 | c_v: Input dimension of value data 82 | c_hidden: Per-head hidden dimension 83 | num_heads: Number of attention heads 84 | gating: Whether the output should be gated using query data 85 | """ 86 | super().__init__() 87 | 88 | self.c_q = c_q 89 | self.c_k = c_k 90 | self.c_v = c_v 91 | self.c_hidden = c_hidden 92 | self.num_heads = num_heads 93 | 94 | self.linear_q = nn.Linear(self.c_q, self.c_hidden * self.num_heads, bias=False) 95 | self.linear_k = nn.Linear(self.c_k, self.c_hidden * self.num_heads, bias=False) 96 | self.linear_v = nn.Linear(self.c_v, self.c_hidden * self.num_heads, bias=False) 97 | 98 | self.linear_o = nn.Linear(self.c_hidden * self.num_heads, self.c_q) 99 | if gating: 100 | self.linear_g = nn.Linear(self.c_q, self.c_hidden * self.num_heads) 101 | else: 102 | self.linear_g = None 103 | 104 | self.sigmoid = nn.Sigmoid() 105 | 106 | def init_weight(self): 107 | for module in [self.linear_q, self.linear_k, self.linear_v]: 108 | nn.init.xavier_uniform_(module.weight, gain=1) 109 | with torch.no_grad(): 110 | for module in [self.linear_o, self.linear_g]: 111 | if module is not None: 112 | module.weight.fill_(0.0) 113 | 114 | def _prep_qkv( 115 | self, q_x: torch.Tensor, kv_x: torch.Tensor, apply_scale: bool = True 116 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 117 | # [*, Q/K/V, H * C_hidden] 118 | q = self.linear_q(q_x) 119 | k = self.linear_k(kv_x) 120 | v = self.linear_v(kv_x) 121 | 122 | # [*, Q/K, H, C_hidden] 123 | q = q.view(q.shape[:-1] + (self.num_heads, -1)) 124 | k = k.view(k.shape[:-1] + (self.num_heads, -1)) 125 | v = v.view(v.shape[:-1] + (self.num_heads, -1)) 126 | 127 | # [*, H, Q/K, C_hidden] 128 | q = q.transpose(-2, -3) 129 | k = k.transpose(-2, -3) 130 | v = v.transpose(-2, -3) 131 | 132 | if apply_scale: 133 | q /= math.sqrt(self.c_hidden) 134 | 135 | return q, k, v 136 | 137 | def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor: 138 | if self.linear_g is not None: 139 | g = self.sigmoid(self.linear_g(q_x)) 140 | 141 | # [*, Q, H, C_hidden] 142 | g = g.view(g.shape[:-1] + (self.num_heads, -1)) 143 | o = o * g 144 | 145 | # [*, Q, H * C_hidden] 146 | o = o.reshape(o.shape[:-2] + (-1,)) 147 | 148 | # [*, Q, C_q] 149 | o = self.linear_o(o) 150 | 151 | return o 152 | 153 | def 
forward( 154 | self, 155 | q_x: torch.Tensor, 156 | kv_x: torch.Tensor, 157 | biases: list[torch.Tensor] | None = None, 158 | ) -> torch.Tensor: 159 | q, k, v = self._prep_qkv(q_x, kv_x) 160 | if biases is None: 161 | biases = [] 162 | o = attention(q, k, v, biases) 163 | o = o.transpose(-2, -3) 164 | 165 | o = self._wrap_up(o, q_x) 166 | 167 | return o 168 | 169 | 170 | def attention( 171 | query: torch.Tensor, 172 | key: torch.Tensor, 173 | value: torch.Tensor, 174 | biases: list[torch.Tensor], 175 | ) -> torch.Tensor: 176 | # [*, H, C_hidden, K] 177 | key = permute_final_dims(key, [1, 0]) 178 | 179 | # [*, H, Q, K] 180 | a = torch.matmul(query, key) 181 | for b in biases: 182 | a = a + b 183 | a = a.softmax(-1) 184 | 185 | # [*, H, Q, C_hidden] 186 | a = torch.matmul(a, value) 187 | 188 | return a 189 | 190 | 191 | def permute_final_dims(tensor: torch.Tensor, inds: Sequence[int]) -> torch.Tensor: 192 | zero_index = -1 * len(inds) 193 | first_inds = list(range(len(tensor.shape[:zero_index]))) 194 | return tensor.permute(first_inds + [zero_index + i for i in inds]) 195 | -------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/network/layers/triangular_multiplicative_update.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class DirectTriangleMultiplicativeUpdate(nn.Module): 8 | def __init__(self, c_in: int, c_hidden: int): 9 | super().__init__() 10 | self.linear_b = nn.Linear(c_in, c_hidden) 11 | self.linear_b_g = nn.Sequential(nn.Linear(c_in, c_hidden), nn.Sigmoid()) 12 | 13 | self.layernorm_z = nn.LayerNorm(c_in) 14 | self.linear_z = nn.Linear(c_in, c_hidden) 15 | self.linear_z_g = nn.Sequential(nn.Linear(c_in, c_hidden), nn.Sigmoid()) 16 | 17 | self.linear_o = nn.Linear(c_hidden, c_in) 18 | self.linear_o_g = nn.Sequential(nn.Linear(c_hidden, c_in), nn.Sigmoid()) 19 | 20 | def forward(self, z, b, z_mask): 21 | """ 22 | z: [N, A, B, C] 23 | b: [N, B, B, C] 24 | z_mask: [N, A, B] 25 | a -> b 26 | """ 27 | b = self.linear_b(b) * self.linear_b_g(b) 28 | z = self.layernorm_z(z) 29 | _z = self.linear_z(z) * self.linear_z_g(z) * z_mask.unsqueeze(-1) 30 | 31 | message = torch.einsum("bikc,bjkc->bijc", _z, b) 32 | 33 | z = self.linear_o_g(z) * self.linear_o(message) * z_mask.unsqueeze(-1) 34 | return z 35 | -------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/network/ligand_encoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | import torch_geometric.nn as pygnn 5 | from torch import Tensor, nn 6 | from torch_geometric.data import Batch, Data 7 | from torch_scatter import scatter_mean, scatter_sum 8 | 9 | 10 | class GraphEncoder(nn.Module): 11 | def __init__( 12 | self, 13 | input_node_dim: int, 14 | input_edge_dim: int, 15 | hidden_dim: int, 16 | out_dim: int, 17 | num_convs: int, 18 | ): 19 | super().__init__() 20 | self.graph_channels: int = out_dim 21 | self.atom_channels: int = out_dim 22 | 23 | # Ligand Encoding 24 | self.node_layer = nn.Linear(input_node_dim, hidden_dim) 25 | self.edge_layer = nn.Linear(input_edge_dim, hidden_dim) 26 | self.conv_list = nn.ModuleList( 27 | [ 28 | pygnn.GINEConv( 29 | nn=nn.Sequential(pygnn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()), 30 | edge_dim=hidden_dim, 31 | ) 32 | for _ in range(num_convs) 33 | ] 34 | ) 35 | 
self.readout_layer = nn.Linear(hidden_dim * 2, out_dim) 36 | self.readout_gate = nn.Linear(hidden_dim * 2, out_dim) 37 | 38 | self.head = nn.Sequential(nn.Linear(hidden_dim, out_dim), nn.LayerNorm(out_dim)) 39 | 40 | def init_weight(self): 41 | def _init_weight(m): 42 | if isinstance(m, nn.Linear): 43 | nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu") 44 | if m.bias is not None: 45 | nn.init.constant_(m.bias, 0) 46 | elif isinstance(m, nn.LayerNorm): 47 | nn.init.constant_(m.weight, 1.0) 48 | 49 | self.apply(_init_weight) 50 | 51 | def forward( 52 | self, 53 | data: Data | Batch, 54 | ) -> tuple[Tensor, Tensor]: 55 | """Encode a ligand graph for affinity prediction 56 | 57 | Args: 58 | data: ligand molecular graph (Data or Batch) 59 | 60 | Returns: 61 | X: atom-level features [V, Fh] 62 | Z: graph-level embedding [N, Fh] 63 | 64 | """ 65 | x: Tensor = self.node_layer(data.x) 66 | edge_attr: Tensor = self.edge_layer(data.edge_attr) 67 | 68 | skip_x = x 69 | edge_index = data.edge_index 70 | for layer in self.conv_list: 71 | x = layer(x, edge_index, edge_attr) 72 | 73 | x = skip_x + x 74 | X = self.head(x) 75 | 76 | if isinstance(data, Batch): 77 | Z1 = scatter_sum(x, data.batch, dim=0, dim_size=data.num_graphs) # V, Fh -> N, Fh 78 | Z2 = scatter_mean(x, data.batch, dim=0, dim_size=data.num_graphs) # V, Fh -> N, Fh 79 | else: 80 | Z1 = x.sum(0, keepdim=True) # V, Fh -> 1, Fh 81 | Z2 = x.mean(0, keepdim=True) # V, Fh -> 1, Fh 82 | Z = torch.cat([Z1, Z2], dim=-1) 83 | Z = self.readout_gate(Z) * self.readout_layer(Z) # [N, Fh] 84 | return X, Z 85 | -------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/network/pharmacophore_encoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from torch import Tensor, nn 5 | 6 | from pmnet.api.typing import HotspotInfo, MultiScaleFeature 7 | 8 | 9 | class PharmacophoreEncoder(nn.Module): 10 | def __init__(self, hidden_dim: int): 11 | super().__init__() 12 | self.multi_scale_dims = (96, 96, 96, 96, 96) 13 | self.hotspot_dim = 192 14 | self.hidden_dim = hidden_dim 15 | self.hotspot_mlp: nn.Module = nn.Sequential(nn.SiLU(), nn.Linear(self.hotspot_dim, hidden_dim)) 16 | self.pocket_mlp_list: nn.ModuleList = nn.ModuleList( 17 | [nn.Sequential(nn.SiLU(), nn.Conv3d(channels, hidden_dim, 3)) for channels in self.multi_scale_dims] 18 | ) 19 | self.pocket_layer: nn.Module = nn.Sequential( 20 | nn.SiLU(), 21 | nn.Linear(5 * hidden_dim, hidden_dim), 22 | nn.SiLU(), 23 | nn.Linear(hidden_dim, hidden_dim), 24 | ) 25 | 26 | def init_weight(self): 27 | def _init_weight(m): 28 | if isinstance(m, nn.Linear | nn.Conv3d): 29 | nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu") 30 | if m.bias is not None: 31 | nn.init.constant_(m.bias, 0) 32 | 33 | self.apply(_init_weight) 34 | 35 | def forward(self, pmnet_attr: tuple[MultiScaleFeature, list[HotspotInfo]]) -> tuple[Tensor, Tensor, Tensor]: 36 | multi_scale_features, hotspot_infos = pmnet_attr 37 | dev = multi_scale_features[0].device 38 | if len(hotspot_infos) > 0: 39 | hotspot_positions = torch.tensor([info["hotspot_position"] for info in hotspot_infos], device=dev) 40 | hotspot_features = torch.stack([info["hotspot_feature"] for info in hotspot_infos]) 41 | hotspot_features = self.hotspot_mlp(hotspot_features) 42 | else: 43 | hotspot_positions = torch.zeros((0, 3), device=dev) 44 | hotspot_features = torch.zeros((0, self.hidden_dim), device=dev) 45 | pocket_features: Tensor
= torch.cat( 46 | [ 47 | mlp(feat.squeeze(0)).mean((-1, -2, -3)) 48 | for mlp, feat in zip(self.pocket_mlp_list, multi_scale_features, strict=False) 49 | ], 50 | dim=-1, 51 | ) 52 | pocket_features = self.pocket_layer(pocket_features) 53 | return hotspot_features, hotspot_positions, pocket_features 54 | -------------------------------------------------------------------------------- /src/pmnet_appl/sbddreward/proxy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright: if you use this script, please cite: 3 | ``` 4 | @article{seo2023pharmaconet, 5 | title = {PharmacoNet: Accelerating Large-Scale Virtual Screening by Deep Pharmacophore Modeling}, 6 | author = {Seo, Seonghwan and Kim, Woo Youn}, 7 | journal = {arXiv preprint arXiv:2310.00681}, 8 | year = {2023}, 9 | url = {https://arxiv.org/abs/2310.00681}, 10 | } 11 | ``` 12 | """ 13 | 14 | from __future__ import annotations 15 | 16 | from pathlib import Path 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch_geometric.data as gd 21 | from torch import Tensor 22 | 23 | from pmnet.api.typing import HotspotInfo, MultiScaleFeature 24 | from pmnet_appl.base.proxy import BaseProxy 25 | from pmnet_appl.sbddreward.data import NUM_ATOM_FEATURES, NUM_BOND_FEATURES, smi2graph 26 | from pmnet_appl.sbddreward.network import ( 27 | AffinityHead, 28 | GraphEncoder, 29 | PharmacophoreEncoder, 30 | ) 31 | 32 | Cache = tuple[Tensor, Tensor, Tensor, float, float] 33 | 34 | 35 | class SBDDReward_Proxy(BaseProxy): 36 | root_dir = Path(__file__).parent 37 | 38 | def _setup_model(self): 39 | self.model = _RewardNetwork() 40 | 41 | def _get_cache(self, pmnet_attr: tuple[MultiScaleFeature, list[HotspotInfo]]) -> Cache: 42 | return self.model.get_cache(pmnet_attr) 43 | 44 | @torch.no_grad() 45 | def _scoring_list(self, cache: Cache, smiles_list: list[str], return_sigma: bool = False) -> Tensor: 46 | cache = ( 47 | cache[0].to(self.device), 48 | cache[1].to(self.device), 49 | cache[2].to(self.device), 50 | cache[3], 51 | cache[4], 52 | ) 53 | 54 | ligand_graphs = [] 55 | flag = [] 56 | for smi in smiles_list: 57 | try: 58 | graph = smi2graph(smi) 59 | except Exception: 60 | flag.append(False) 61 | else: 62 | flag.append(True) 63 | ligand_graphs.append(graph) 64 | if not any(flag): 65 | return torch.zeros(len(smiles_list), dtype=torch.float32, device=self.device) 66 | ligand_batch: gd.Batch = gd.Batch.from_data_list(ligand_graphs).to(self.device) 67 | if all(flag): 68 | return self.model.scoring(cache, ligand_batch, return_sigma) 69 | else: 70 | result = torch.zeros(len(smiles_list), dtype=torch.float32, device=self.device) 71 | result[flag] = self.model.scoring(cache, ligand_batch, return_sigma) 72 | return result 73 | 74 | @classmethod 75 | def load( 76 | cls, 77 | docking: str, 78 | train_dataset: str, 79 | db: Path | str | None, 80 | device: torch.device | str = "cpu", 81 | ): 82 | """Load Pretrained Proxy Model 83 | 84 | Parameters 85 | ---------- 86 | docking : str 87 | docking program (currently: ['UniDock_Vina']) 88 | train_dataset : str 89 | training dataset name (currently: ['ZINC']) 90 | db : Path | str | None 91 | cache database path ('train' | 'test' | 'all' | custom cache database path) 92 | - 'train': CrossDocked2020 training pockets (15,201) 93 | - 'test': CrossDocked2020 test pockets (100) 94 | - 'all': train + test 95 | device : str 96 | cuda | cpu 97 | """ 98 | assert docking in ["UniDock_Vina"] 99 | assert train_dataset in ["ZINC"] 100 | return super().load(docking,
train_dataset, db, device) 101 | 102 | def scoring(self, target: str, smiles: str, return_sigma: bool = False) -> Tensor: 103 | """Score a single molecule from its SMILES 104 | 105 | Parameters 106 | ---------- 107 | target : str 108 | target key 109 | smiles : str 110 | molecule SMILES 111 | return_sigma : bool (default = False) 112 | if True, return sigma instead of absolute affinity 113 | 114 | Returns 115 | ------- 116 | Tensor [1,] 117 | Estimated Docking Score (or Sigma) 118 | 119 | """ 120 | return self._scoring_list(self._cache[target], [smiles], return_sigma) 121 | 122 | def scoring_list(self, target: str, smiles_list: list[str], return_sigma: bool = False) -> Tensor: 123 | """Score multiple molecules from their SMILES 124 | 125 | Parameters 126 | ---------- 127 | target : str 128 | target key 129 | smiles_list : list[str] 130 | molecule SMILES list 131 | return_sigma : bool (default = False) 132 | if True, return sigma instead of absolute affinity 133 | 134 | Returns 135 | ------- 136 | Tensor [N,] 137 | Estimated Docking Scores (or Sigma) 138 | 139 | """ 140 | return self._scoring_list(self._cache[target], smiles_list, return_sigma) 141 | 142 | def get_statistic(self, target: str) -> tuple[float, float]: 143 | cache: Cache = self._cache[target] 144 | return cache[-2], cache[-1] 145 | 146 | 147 | class _RewardNetwork(nn.Module): 148 | def __init__(self): 149 | super().__init__() 150 | self.pharmacophore_encoder: PharmacophoreEncoder = PharmacophoreEncoder(128) 151 | self.ligand_encoder: GraphEncoder = GraphEncoder(NUM_ATOM_FEATURES, NUM_BOND_FEATURES, 128, 128, 4) 152 | self.head: AffinityHead = AffinityHead(128, 3) 153 | 154 | def get_cache(self, pmnet_attr) -> Cache: 155 | X_protein, pos_protein, Z_protein = self.pharmacophore_encoder.forward(pmnet_attr) 156 | mu, std = self.head.cal_mu(Z_protein), self.head.cal_std(Z_protein) 157 | return ( 158 | X_protein.cpu(), 159 | pos_protein.cpu(), 160 | Z_protein.cpu(), 161 | mu.item(), 162 | std.item(), 163 | ) 164 | 165 | def scoring(self, cache: Cache, ligand_batch: gd.Batch, return_sigma: bool = False): 166 | X_protein, pos_protein, Z_protein, mu, std = cache 167 | X_ligand, Z_ligand = self.ligand_encoder.forward(ligand_batch) 168 | sigma = self.head.cal_sigma(X_protein, pos_protein, Z_protein, X_ligand, Z_ligand, ligand_batch.batch) 169 | if return_sigma: 170 | return sigma 171 | else: 172 | return sigma * std + mu 173 | 174 | def get_info(self, cache: Cache, ligand_batch: gd.Batch) -> tuple[float, float, Tensor]: 175 | X_protein, pos_protein, Z_protein, mu, std = cache 176 | X_ligand, Z_ligand = self.ligand_encoder.forward(ligand_batch) 177 | sigma = self.head.cal_sigma(X_protein, pos_protein, Z_protein, X_ligand, Z_ligand, ligand_batch.batch) 178 | return mu, std, sigma 179 | 180 | 181 | if __name__ == "__main__": 182 | print("start!") 183 | proxy = SBDDReward_Proxy.load("UniDock_Vina", "ZINC", "test", "cpu") 184 | print("proxy is loaded") 185 | print(proxy.scoring("14gs_A", "c1ccccc1")) 186 | print(proxy.scoring("14gs_A", "c11"))  # invalid SMILES; _scoring_list falls back to 0.0 187 | print(proxy.scoring_list("14gs_A", ["c1ccccc1", "C1CCCCC1", "c11"])) 188 | -------------------------------------------------------------------------------- /src/pmnet_appl/tacogfn_reward/__init__.py: -------------------------------------------------------------------------------- 1 | from .proxy import TacoGFN_Proxy 2 | -------------------------------------------------------------------------------- /src/pmnet_appl/tacogfn_reward/data.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch_geometric.data as gd 3 | from openbabel import pybel 4 | from openbabel.pybel import ob 5 | 6 | ATOM_DICT = { 7 | 6: 0, 8 | 7: 1, 9 | 8: 2, 10 | 9: 3, 11 | 15: 4, 12 | 16: 5, 13 | 17: 6, 14 | 35: 7, 15 | 53: 8, 16 | -1: 9, # UNKNOWN 17 | } 18 | 19 | BOND_DICT = { 20 | 1: 0, 21 | 2: 1, 22 | 3: 2, 23 | 1.5: 3, # AROMATIC 24 | -1: 4, # UNKNOWN 25 | } 26 | 27 | 28 | def smi2graph(smiles: str) -> gd.Data: 29 | pbmol = pybel.readstring("smi", smiles) 30 | obmol: ob.OBMol = pbmol.OBMol 31 | atom_features = [] 32 | pos = [] 33 | for pbatom in pbmol.atoms: 34 | atom_features.append(ATOM_DICT.get(pbatom.atomicnum, 9)) 35 | pos.append(pbatom.coords) 36 | 37 | edge_index = [] 38 | edge_type = [] 39 | for obbond in ob.OBMolBondIter(obmol): 40 | obbond: ob.OBBond 41 | edge_index.append((obbond.GetBeginAtomIdx() - 1, obbond.GetEndAtomIdx() - 1)) 42 | if obbond.IsAromatic(): 43 | edge_type.append(3) 44 | else: 45 | edge_type.append(BOND_DICT.get(obbond.GetBondOrder(), 4)) 46 | 47 | return gd.Data( 48 | x=torch.LongTensor(atom_features), 49 | edge_index=torch.LongTensor(edge_index).T, 50 | edge_attr=torch.LongTensor(edge_type), 51 | ) 52 | -------------------------------------------------------------------------------- /src/pmnet_appl/tacogfn_reward/db_keys/test.txt: -------------------------------------------------------------------------------- 1 | 2z3h_A 2 | 4aaw_A 3 | 4yhj_A 4 | 14gs_A 5 | 2v3r_A 6 | 4rn0_B 7 | 1fmc_B 8 | 3daf_A 9 | 1a2g_A 10 | 5w2g_A 11 | 3dzh_A 12 | 3g51_A 13 | 1coy_A 14 | 2jjg_A 15 | 2rhy_A 16 | 2pqw_A 17 | 4g3d_B 18 | 5bur_A 19 | 3gs6_A 20 | 1r1h_A 21 | 1dxo_C 22 | 1gg5_A 23 | 5q0k_A 24 | 5b08_A 25 | 2azy_A 26 | 5i0b_A 27 | 1phk_A 28 | 4keu_A 29 | 4q8b_B 30 | 1djy_A 31 | 5l1v_A 32 | 4zfa_A 33 | 2rma_A 34 | 3b6h_A 35 | 2zen_A 36 | 4p6p_A 37 | 3u5y_B 38 | 4f1m_A 39 | 4tqr_A 40 | 4lfu_A 41 | 3jyh_A 42 | 4iwq_A 43 | 1l3l_A 44 | 5ngz_A 45 | 1e8h_A 46 | 2e24_A 47 | 2hcj_B 48 | 3kc1_A 49 | 1d7j_A 50 | 4ja8_B 51 | 4u5s_A 52 | 4iiy_A 53 | 3v4t_A 54 | 3tym_A 55 | 4d7o_A 56 | 3ej8_A 57 | 1rs9_A 58 | 4kcq_A 59 | 3pdh_A 60 | 1umd_B 61 | 4pxz_A 62 | 2gns_A 63 | 1ai4_A 64 | 5mma_A 65 | 2cy0_A 66 | 3w83_B 67 | 2e6d_A 68 | 4rv4_A 69 | 5d7n_D 70 | 5mgl_A 71 | 1h36_A 72 | 4gvd_A 73 | 4tos_A 74 | 5aeh_A 75 | 4h3c_A 76 | 4rlu_A 77 | 4xli_B 78 | 3l3n_A 79 | 5tjn_A 80 | 5liu_X 81 | 3o96_A 82 | 4qlk_A 83 | 3hy9_B 84 | 4bel_A 85 | 3nfb_A 86 | 4m7t_A 87 | 3u9f_C 88 | 4aua_A 89 | 2f2c_B 90 | 3chc_B 91 | 1k9t_A 92 | 1h0i_A 93 | 4z2g_A 94 | 3af2_A 95 | 1jn2_P 96 | 3li4_A 97 | 3pnm_A 98 | 1afs_A 99 | 4azf_A 100 | 2pc8_A 101 | -------------------------------------------------------------------------------- /src/pmnet_appl/tacogfn_reward/get_cache.py: -------------------------------------------------------------------------------- 1 | from pmnet_appl.tacogfn_reward import TacoGFN_Proxy 2 | 3 | docking = "QVina" 4 | train_dataset = "ZincDock15M" # or "CrossDocked2020" 5 | proxy = TacoGFN_Proxy.load(docking, train_dataset, None, "cuda") 6 | 7 | save_database_path = "./tmp_db.pt" 8 | protein_info_dict = { 9 | "key1": ("./tmp1.pdb", "./ref_ligand1.sdf"), # reference ligand path 10 | "key2": ("./tmp2.pdb", (1.0, 2.0, 3.0)), # pocket center 11 | } 12 | 13 | cache_dict = proxy.get_cache_database(protein_info_dict, save_database_path, verbose=False) 14 | proxy.update_cache(cache_dict) 15 | proxy.scoring(list(cache_dict.keys())[0], "c1ccccc1") 16 | proxy.scoring_list(list(cache_dict.keys())[0], ["c1ccccc1", 
"C1CCCCC1"]) 17 | -------------------------------------------------------------------------------- /utils/parse_rcsb_pdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from urllib.request import urlopen 5 | 6 | import numpy as np 7 | import pymol 8 | from openbabel import pybel 9 | 10 | PathLike = str | Path 11 | 12 | 13 | @dataclass 14 | class LigandInform: 15 | order: int 16 | id: str 17 | pdbchain: str 18 | authchain: str 19 | residx: int 20 | center: tuple[float, float, float] 21 | file_path: PathLike 22 | name: str | None 23 | synonyms: str | None 24 | 25 | def __str__(self) -> str: 26 | x, y, z = self.center 27 | string = ( 28 | f"Ligand {self.order}\n" 29 | f"- ID : {self.id} (Chain: {self.pdbchain} [auth {self.authchain}])\n" 30 | f"- Center : {x:.3f}, {y:.3f}, {z:.3f}" 31 | ) 32 | if self.name is not None: 33 | string += f"\n- Name : {self.name}" 34 | if self.synonyms is not None: 35 | string += f"\n- Synonyms: {self.synonyms}" 36 | return string 37 | 38 | 39 | def download_pdb(pdb_code: str, output_file: PathLike): 40 | url = f"https://files.rcsb.org/download/{pdb_code.lower()}.pdb" 41 | try: 42 | with urlopen(url) as response: 43 | content = response.read().decode("utf-8") 44 | with open(output_file, "w") as file: 45 | file.write(content) 46 | except Exception as e: 47 | print(f"Error downloading PDB file: {e}") 48 | 49 | 50 | def parse_pdb(pdb_code: str, protein_path: PathLike, save_dir: PathLike) -> list[LigandInform]: 51 | protein: pybel.Molecule = next(pybel.readfile("pdb", str(protein_path))) 52 | 53 | if "HET" not in protein.data.keys(): 54 | return [] 55 | het_lines: list[str] = protein.data["HET"].split("\n") 56 | hetnam_lines: list[str] = protein.data["HETNAM"].split("\n") 57 | if "HETSYN" in protein.data.keys(): 58 | hetsyn_lines = protein.data["HETSYN"].split("\n") 59 | else: 60 | hetsyn_lines = [] 61 | 62 | het_id_list = tuple(line.strip().split()[0] for line in het_lines) 63 | 64 | ligand_name_dict = {} 65 | for line in hetnam_lines: 66 | line = line.strip() 67 | if line.startswith(het_id_list): 68 | key, *strings = line.split() 69 | assert key not in ligand_name_dict 70 | ligand_name_dict[key] = " ".join(strings) 71 | else: 72 | _, key, *strings = line.split() 73 | assert key in ligand_name_dict 74 | if ligand_name_dict[key][-1] == "-": 75 | ligand_name_dict[key] += " ".join(strings) 76 | else: 77 | ligand_name_dict[key] += " " + " ".join(strings) 78 | 79 | ligand_syn_dict = {} 80 | for line in hetsyn_lines: 81 | line: str = line.strip() 82 | if line.startswith(het_id_list): 83 | key, *strings = line.split() 84 | assert key not in ligand_syn_dict 85 | ligand_syn_dict[key] = " ".join(strings) 86 | else: 87 | _, key, *strings = line.split() 88 | assert key in ligand_syn_dict 89 | if ligand_syn_dict[key][-1] == "-": 90 | ligand_syn_dict[key] += " ".join(strings) 91 | else: 92 | ligand_syn_dict[key] += " " + " ".join(strings) 93 | 94 | pymol.finish_launching(["pymol", "-cq"]) 95 | pymol.cmd.load(str(protein_path)) 96 | 97 | ligand_inform_list = [] 98 | last_chain = protein.data["SEQRES"].split("\n")[-1].split()[1] 99 | for idx, line in enumerate(het_lines): 100 | vs = line.strip().split() 101 | if len(vs) == 4: 102 | ligid, authchain, residue_idx, _ = vs 103 | else: 104 | ( 105 | ligid, 106 | authchain, 107 | residue_idx, 108 | ) = ( 109 | vs[0], 110 | vs[1][0], 111 | vs[1][1:], 112 | ) 113 | pdbchain = chr(ord(last_chain) + idx + 1) 114 | 
identify_key = f"{pdb_code}_{pdbchain}_{ligid}" 115 | ligand_path = os.path.join(save_dir, f"{identify_key}.pdb") 116 | 117 | if not os.path.exists(ligand_path): 118 | pymol.cmd.select( 119 | identify_key, 120 | f"resn {ligid} and resi {residue_idx} and chain {authchain}", 121 | ) 122 | pymol.cmd.save(ligand_path, identify_key) 123 | 124 | ligand = next(pybel.readfile("pdb", ligand_path)) 125 | x, y, z = np.mean([atom.coords for atom in ligand.atoms], axis=0).tolist() 126 | 127 | inform = LigandInform( 128 | idx + 1, 129 | ligid, 130 | pdbchain, 131 | authchain, 132 | int(residue_idx), 133 | (x, y, z), 134 | ligand_path, 135 | ligand_name_dict.get(ligid), 136 | ligand_syn_dict.get(ligid), 137 | ) 138 | ligand_inform_list.append(inform) 139 | # pymol.cmd.quit() 140 | return ligand_inform_list 141 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import tempfile 4 | 5 | import pymol 6 | from pymol import cmd 7 | 8 | from pmnet import PharmacophoreModel 9 | 10 | 11 | class Visualize_ArgParser(argparse.ArgumentParser): 12 | def __init__(self): 13 | super().__init__("visualize") 14 | self.formatter_class = argparse.ArgumentDefaultsHelpFormatter 15 | self.add_argument("model", type=str, help="path of the saved pharmacophore model (.json | .pkl)") 16 | self.add_argument("-p", "--protein", type=str, help="path of protein file") 17 | self.add_argument("-l", "--ligand", type=str, help="path of reference ligand file") 18 | self.add_argument( 19 | "-o", 20 | "--out", 21 | type=str, 22 | help="output path of PyMOL session file (.pse)", 23 | required=True, 24 | ) 25 | self.add_argument("--prefix", type=str, help="prefix for PyMOL object names") 26 | 27 | 28 | PHARMACOPHORE_COLOR_DICT = { 29 | "Hydrophobic": "orange", 30 | "Aromatic": "deeppurple", 31 | "Cation": "blue", 32 | "Anion": "red", 33 | "HBond_acceptor": "magenta", 34 | "HBond_donor": "cyan", 35 | "Halogen": "yellow", 36 | } 37 | 38 | INTERACTION_COLOR_DICT = { 39 | "Hydrophobic": "orange", 40 | "PiStacking_P": "deeppurple", 41 | "PiStacking_T": "deeppurple", 42 | "PiCation_lring": "blue", 43 | "PiCation_pring": "deeppurple", 44 | "HBond_ldon": "magenta", 45 | "HBond_pdon": "cyan", 46 | "SaltBridge_lneg": "blue", 47 | "SaltBridge_pneg": "red", 48 | "XBond": "yellow", 49 | } 50 | 51 | 52 | def visualize_single( 53 | model: PharmacophoreModel, 54 | protein_path: str | None, 55 | ligand_path: str | None, 56 | prefix: str, 57 | save_path: str, 58 | ): 59 | pymol.pymol_argv = ["pymol", "-pcq"] 60 | pymol.finish_launching(args=["pymol", "-pcq", "-K"]) 61 | pymol.cmd.reinitialize() 62 | pymol.cmd.feedback("disable", "all", "everything") 63 | 64 | prefix = f"{prefix}_" if prefix else "" 65 | 66 | # NOTE: Draw Molecule 67 | if protein_path: 68 | cmd.load(protein_path) 69 | else: 70 | with tempfile.TemporaryDirectory() as direc: 71 | protein_path = f"{direc}/pocket.pdb" 72 | with open(protein_path, "w") as w: 73 | w.write(model.pocket_pdbblock) 74 | cmd.load(protein_path) 75 | cmd.set_name(os.path.splitext(os.path.basename(protein_path))[0], f"{prefix}Protein") 76 | cmd.remove("hetatm") 77 | 78 | if ligand_path: 79 | cmd.load(ligand_path) 80 | cmd.set_name(os.path.splitext(os.path.basename(ligand_path))[0], f"{prefix}Ligand") 81 | 82 | # NOTE: Pharmacophore Model 83 | nci_dict = {} 84 | for node in model.nodes: 85 | hotspot_color = INTERACTION_COLOR_DICT[node.interaction_type] 86 | pharmacophore_color =
PHARMACOPHORE_COLOR_DICT[node.type] 87 | hotspot_id = f"{prefix}hotspot{node.index}" 88 | cmd.pseudoatom(hotspot_id, pos=node.hotspot_position, color=hotspot_color) 89 | cmd.set("sphere_color", hotspot_color, hotspot_id) 90 | 91 | pharmacophore_id = f"{prefix}point{node.index}" 92 | cmd.pseudoatom(pharmacophore_id, pos=node.center, color=hotspot_color) 93 | cmd.set("sphere_color", pharmacophore_color, pharmacophore_id) 94 | cmd.set("sphere_scale", node.radius, pharmacophore_id) 95 | 96 | interaction_id = f"{prefix}interaction{node.index}" 97 | cmd.distance(interaction_id, hotspot_id, pharmacophore_id) 98 | cmd.set("dash_color", pharmacophore_color, interaction_id) 99 | 100 | result_id = f"{prefix}NCI{node.index}" 101 | cmd.group(result_id, f"{hotspot_id} {pharmacophore_id} {interaction_id}") 102 | nci_dict.setdefault(node.interaction_type, []).append(result_id) 103 | 104 | for interaction_type, nci_list in nci_dict.items(): 105 | cmd.group(f"{prefix}{interaction_type}", " ".join(nci_list)) 106 | cmd.group(f"{prefix}Model", f"{prefix}{interaction_type}") 107 | 108 | cmd.set("stick_transparency", 0.6, f"{prefix}Protein") 109 | cmd.set("cartoon_transparency", 0.6, f"{prefix}Protein") 110 | cmd.color("gray90", f"{prefix}Protein and (name C*)") 111 | 112 | cmd.set("sphere_scale", 0.3, "*hotspot*") 113 | cmd.set("sphere_transparency", 0.2, "*point*") 114 | cmd.set("dash_gap", 0.2, "*interaction*") 115 | cmd.set("dash_length", 0.4, "*interaction*") 116 | cmd.hide("label", "*interaction*") 117 | 118 | cmd.bg_color("white") 119 | cmd.show("sticks", f"{prefix}Protein") 120 | cmd.show("sphere", f"{prefix}Model") 121 | cmd.show("dash", f"{prefix}Model") 122 | cmd.disable(f"{prefix}Protein") 123 | cmd.enable(f"{prefix}Protein") 124 | 125 | cmd.save(save_path) 126 | 127 | 128 | def visualize_multiple( 129 | model_dict: dict[str, tuple[PharmacophoreModel, str]], 130 | protein_path: str, 131 | pdb: str, 132 | save_path: str, 133 | ): 134 | pymol.pymol_argv = ["pymol", "-pcq"] 135 | pymol.finish_launching(args=["pymol", "-pcq", "-K"]) 136 | cmd.reinitialize() 137 | cmd.feedback("disable", "all", "everything") 138 | 139 | # NOTE: Draw Molecule 140 | cmd.load(protein_path) 141 | cmd.set_name(os.path.splitext(os.path.basename(protein_path))[0], pdb) 142 | cmd.remove("hetatm") 143 | 144 | for prefix, (model, ligand_path) in model_dict.items(): 145 | if ligand_path: 146 | cmd.load(ligand_path) 147 | cmd.set_name(os.path.splitext(os.path.basename(ligand_path))[0], f"{prefix}_Ligand") 148 | 149 | # NOTE: Pharmacophore Model 150 | nci_dict = {} 151 | for node in model.nodes: 152 | hotspot_color = INTERACTION_COLOR_DICT[node.interaction_type] 153 | pharmacophore_color = PHARMACOPHORE_COLOR_DICT[node.type] 154 | hotspot_id = f"{prefix}_hotspot{node.index}" 155 | cmd.pseudoatom(hotspot_id, pos=node.hotspot_position, color=hotspot_color) 156 | cmd.set("sphere_color", hotspot_color, hotspot_id) 157 | 158 | pharmacophore_id = f"{prefix}_point{node.index}" 159 | cmd.pseudoatom(pharmacophore_id, pos=node.center, color=hotspot_color) 160 | cmd.set("sphere_color", pharmacophore_color, pharmacophore_id) 161 | cmd.set("sphere_scale", node.radius, pharmacophore_id) 162 | 163 | interaction_id = f"{prefix}_interaction{node.index}" 164 | cmd.distance(interaction_id, hotspot_id, pharmacophore_id) 165 | cmd.set("dash_color", pharmacophore_color, interaction_id) 166 | 167 | result_id = f"{prefix}_NCI{node.index}" 168 | cmd.group(result_id, f"{hotspot_id} {pharmacophore_id} {interaction_id}") 169 | 
nci_dict.setdefault(node.interaction_type, []).append(result_id) 170 | 171 | for interaction_type, nci_list in nci_dict.items(): 172 | cmd.group(f"{prefix}_{interaction_type}", " ".join(nci_list)) 173 | cmd.group(f"{prefix}_Model", f"{prefix}_{interaction_type}") 174 | cmd.group(prefix, f"{prefix}_Model {prefix}_Ligand") 175 | 176 | cmd.set("stick_transparency", 0.6, pdb) 177 | cmd.set("cartoon_transparency", 0.6, pdb) 178 | cmd.color("gray90", f"{pdb} and (name C*)") 179 | 180 | cmd.set("sphere_scale", 0.3, "*hotspot*") 181 | cmd.set("sphere_transparency", 0.2, "*point*") 182 | cmd.set("dash_gap", 0.2, "*interaction*") 183 | cmd.set("dash_length", 0.4, "*interaction*") 184 | cmd.hide("label", "*interaction*") 185 | 186 | cmd.bg_color("white") 187 | cmd.show("sphere", "*Model") 188 | cmd.show("dash", "*Model") 189 | cmd.show("sticks", pdb) 190 | cmd.disable(pdb) 191 | cmd.enable(pdb) 192 | cmd.save(save_path) 193 | 194 | 195 | if __name__ == "__main__": 196 | parser = Visualize_ArgParser() 197 | args = parser.parse_args() 198 | visualize_single( 199 | PharmacophoreModel.load(args.model), 200 | args.protein, 201 | args.ligand, 202 | args.prefix, 203 | args.out, 204 | ) 205 | --------------------------------------------------------------------------------
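For reference, utils/visualize.py is a standalone script; a typical invocation, shown here with an illustrative model path and the example files shipped in examples/, would be:

python utils/visualize.py result/model.json -p examples/6OIM_protein.pdb -l examples/6OIM_D_MOV.pdb -o model.pse --prefix 6OIM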