├── .gitignore ├── LICENSE ├── README.md ├── assets └── overview.png ├── configs ├── sampling │ ├── protac │ │ ├── protac_all_gui.yml │ │ ├── protac_anchor.yml │ │ ├── protac_dist.yml │ │ └── protac_no_gui.yml │ └── zinc.yml └── training │ ├── nn_variants │ ├── pos_eps_rot_eps.yml │ ├── pos_eps_rot_euler.yml │ └── pos_newton_rot_eps.yml │ ├── warhead_protac.yml │ └── zinc.yml ├── data ├── .gitignore └── protac │ ├── 3d_index.pkl │ ├── e3_ligand.csv │ ├── index.pkl │ ├── linker.csv │ ├── protac.csv │ ├── smi_protac.txt │ └── warhead.csv ├── datasets ├── __init__.py ├── linker_data.py └── linker_dataset.py ├── models ├── common.py ├── diff_protac_bond.py ├── encoders │ ├── __init__.py │ └── node_edge_net.py ├── eps_net.py └── transition.py ├── playground └── check_data.ipynb ├── scripts ├── baselines │ ├── eval_3dlinker.py │ ├── eval_delinker.py │ └── eval_difflinker.py ├── eval_protac.py ├── prepare_data.py ├── sample_protac.py └── train_protac.py └── utils ├── calc_SC_RDKit.py ├── const.py ├── data.py ├── eval_bond.py ├── evaluation.py ├── fpscores.pkl.gz ├── frag_utils.py ├── geometry.py ├── guidance_funcs.py ├── misc.py ├── prior_num_atoms.py ├── reconstruct_linker.py ├── sascorer.py ├── so3.py ├── train.py ├── train_linker_smiles.pkl ├── transforms.py ├── visualize.py └── wehi_pains.csv /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # IDE 132 | .idea/ 133 | 134 | # OS FILE 135 | .DS_Store 136 | test_* 137 | ckpts 138 | output* 139 | logs/ 140 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Jiaqi Guan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LinkerNet: Fragment Poses and Linker Co-Design with 3D Equivariant Diffusion 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/guanjq/targetdiff/blob/main/LICIENCE) 4 | 5 | 6 | This repository is the official implementation of LinkerNet: Fragment Poses and Linker Co-Design with 3D Equivariant Diffusion (NeurIPS 2023). [[PDF]](https://openreview.net/forum?id=6EaLIw3W7c) 7 | 8 |

9 | ![overview](assets/overview.png)
10 | 

11 | 
12 | ## Installation
13 | 
14 | ### Dependency
15 | 
16 | The code has been tested in the following environment:
17 | 
18 | 
19 | | Package           | Version   |
20 | |-------------------|-----------|
21 | | Python            | 3.8       |
22 | | PyTorch           | 1.13.1    |
23 | | CUDA              | 11.6      |
24 | | PyTorch Geometric | 2.2.0     |
25 | | RDKit             | 2022.03.2 |
26 | 
27 | ### Install via Conda and Pip
28 | ```bash
29 | conda create -n targetdiff python=3.8
30 | conda activate targetdiff
31 | conda install pytorch pytorch-cuda=11.6 -c pytorch -c nvidia
32 | conda install pyg -c pyg
33 | conda install rdkit openbabel tensorboard pyyaml easydict python-lmdb -c conda-forge
34 | ```
35 | 
36 | ---
37 | ## Data Preprocess
38 | 
39 | ### PROTAC-DB
40 | 
41 | We provide all data files related to the PROTAC-DB dataset in this repo.
42 | * The raw data (.csv files in the data/protac folder) are downloaded from [PROTAC-DB](http://cadd.zju.edu.cn/protacdb/).
43 | * The index.pkl file is generated by playground/check_data.ipynb.
44 | * The 3d_index.pkl file contains the conformations generated by RDKit; it is produced by running the following command:
45 | 
46 | ```bash
47 | python scripts/prepare_data.py --raw_path data/protac/index.pkl --dest data/protac/3d_index.pkl
48 | ```
49 | 
50 | Note that the RDKit version may affect PROTAC-DB dataset processing and splitting. We provide the processed data and split file [here](https://drive.google.com/drive/folders/1Nt37DO1PYwPNM0_uF2Zzz4QxWD3v8pF6?usp=drive_link).
51 | 
52 | ### ZINC
53 | 
54 | The raw ZINC data are the same as in [DiffLinker](https://zenodo.org/records/7121271).
55 | We preprocess the ZINC data into an index file by running:
56 | ```bash
57 | python scripts/prepare_data.py \
58 |     --raw_path data/zinc_difflinker \
59 |     --dest data/zinc_difflinker/index_full.pkl \
60 |     --dataset zinc_difflinker --mode full
61 | ```
62 | We also provide the preprocessed index file [here](https://drive.google.com/drive/folders/1C1srELCCNJLk8v1smjvmbE-xYvnog5jU?usp=sharing).
63 | 
64 | ---
65 | ## Training
66 |     python scripts/train_protac.py configs/training/zinc.yml
67 | 
68 | We provide [pretrained checkpoints](https://drive.google.com/drive/folders/1C1srELCCNJLk8v1smjvmbE-xYvnog5jU?usp=sharing) for both ZINC and PROTAC.
69 | 
70 | ## Sampling
71 |     python scripts/sample_protac.py configs/sampling/zinc.yml --subset test --start_id 0 --end_id -1 --num_samples 250 --outdir outputs/zinc
72 | 
73 | We have also provided the sampling results at the same link.
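
For a quick end-to-end PROTAC run, the commands above can be chained as sketched below. This is an illustrative sketch rather than a prescribed recipe: the sampling flags are copied from the ZINC command and assumed to apply to the PROTAC configs as well, and configs/training/warhead_protac.yml resumes from ckpts/zinc_model.pt (its train.ckpt_path), so train or download that checkpoint first.

```bash
# 1. Build the RDKit conformer index for PROTAC-DB
python scripts/prepare_data.py --raw_path data/protac/index.pkl --dest data/protac/3d_index.pkl

# 2. Pre-train on ZINC, then fine-tune on PROTAC-DB (warhead split)
python scripts/train_protac.py configs/training/zinc.yml
python scripts/train_protac.py configs/training/warhead_protac.yml

# 3. Sample linkers for the PROTAC test set with both guidance terms enabled
# (flag names assumed identical to the ZINC example above)
python scripts/sample_protac.py configs/sampling/protac/protac_all_gui.yml \
    --subset test --start_id 0 --end_id -1 --num_samples 100 --outdir outputs/protac
```
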
74 | 75 | 76 | ## Evaluation 77 | python scripts/eval_protac.py {SAMPLE_RESULT_DIR} 78 | 79 | 80 | ## Citation 81 | ``` 82 | @inproceedings{guan2023linkernet, 83 | title={LinkerNet: Fragment Poses and Linker Co-Design with 3D Equivariant Diffusion}, 84 | author={Guan, Jiaqi and Peng, Xingang and Jiang, PeiQi and Luo, Yunan and Peng, Jian and Ma, Jianzhu}, 85 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 86 | year={2023} 87 | } 88 | ``` -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanjq/LinkerNet/280759c16ccecece0d81ab9cebe3f44041b80e51/assets/overview.png -------------------------------------------------------------------------------- /configs/sampling/protac/protac_all_gui.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: protac 3 | path: ./data/protac 4 | version: v1 5 | split_mode: warhead # [warhead, ligase, random] 6 | index_name: 3d_index.pkl 7 | max_num_atoms: 70 8 | 9 | model: 10 | checkpoint: ckpts/protac_model.pt 11 | 12 | sample: 13 | seed: 2022 14 | num_samples: 100 15 | num_atoms: prior # [ref, prior] 16 | cand_bond_mask: True 17 | guidance_opt: 18 | - type: anchor_prox 19 | update: frag_rot 20 | min_d: 1.2 21 | max_d: 1.9 22 | decay: False 23 | - type: frag_distance 24 | mode: frag_center_distance 25 | constraint_mode: dynamic # [dynamic, const] 26 | sigma: 0.2 27 | min_d: 28 | max_d: 29 | -------------------------------------------------------------------------------- /configs/sampling/protac/protac_anchor.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: protac 3 | path: ./data/protac 4 | version: v1 5 | split_mode: warhead # [warhead, ligase, random] 6 | index_name: 3d_index.pkl 7 | max_num_atoms: 70 8 | 9 | model: 10 | checkpoint: ckpts/protac_model.pt 11 | 12 | sample: 13 | seed: 2022 14 | num_samples: 100 15 | num_atoms: prior # [ref, prior] 16 | cand_bond_mask: True 17 | guidance_opt: 18 | - type: anchor_prox 19 | update: frag_rot 20 | min_d: 1.2 21 | max_d: 1.9 22 | decay: False 23 | # - type: frag_distance 24 | # mode: frag_center_distance 25 | # constraint_mode: dynamic # [dynamic, const] 26 | # sigma: 0.2 27 | # min_d: 28 | # max_d: 29 | -------------------------------------------------------------------------------- /configs/sampling/protac/protac_dist.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: protac 3 | path: ./data/protac 4 | version: v1 5 | split_mode: warhead # [warhead, ligase, random] 6 | index_name: 3d_index.pkl 7 | max_num_atoms: 70 8 | 9 | model: 10 | checkpoint: ckpts/protac_model.pt 11 | 12 | sample: 13 | seed: 2022 14 | num_samples: 100 15 | num_atoms: prior # [ref, prior] 16 | cand_bond_mask: True 17 | guidance_opt: 18 | # - type: anchor_prox 19 | # update: frag_rot 20 | # min_d: 1.2 21 | # max_d: 1.9 22 | # decay: False 23 | - type: frag_distance 24 | mode: frag_center_distance 25 | constraint_mode: dynamic # [dynamic, const] 26 | sigma: 0.2 27 | min_d: 28 | max_d: 29 | -------------------------------------------------------------------------------- /configs/sampling/protac/protac_no_gui.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: protac 3 | path: ./data/protac 4 | version: v1 5 | split_mode: 
warhead # [warhead, ligase, random] 6 | index_name: 3d_index.pkl 7 | max_num_atoms: 70 8 | 9 | model: 10 | checkpoint: ckpts/protac_model.pt 11 | 12 | sample: 13 | seed: 2022 14 | num_samples: 100 15 | num_atoms: prior # [ref, prior] 16 | cand_bond_mask: True 17 | guidance_opt: 18 | # - type: anchor_prox 19 | # update: frag_rot 20 | # min_d: 1.2 21 | # max_d: 1.9 22 | # decay: False 23 | # - type: frag_distance 24 | # mode: frag_center_distance 25 | # constraint_mode: dynamic # [dynamic, const] 26 | # sigma: 0.2 27 | # min_d: 28 | # max_d: 29 | -------------------------------------------------------------------------------- /configs/sampling/zinc.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: zinc 3 | path: ./data/zinc_difflinker 4 | version: full 5 | index_name: index_full.pkl 6 | max_num_atoms: 30 7 | 8 | model: 9 | checkpoint: ckpts/zinc_model.pt 10 | 11 | sample: 12 | seed: 2022 13 | num_samples: 100 14 | num_atoms: ref 15 | guidance_opt: 16 | # num_steps: 100 17 | # energy_drift: 18 | # - type: frag_link_prox 19 | # min_d: 1.2 20 | # max_d: 1.9 21 | -------------------------------------------------------------------------------- /configs/training/nn_variants/pos_eps_rot_eps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: zinc 3 | path: ./data/zinc_difflinker 4 | version: full # [tiny, full] 5 | index_name: index_full.pkl 6 | 7 | model: 8 | num_steps: 500 9 | node_emb_dim: 256 10 | edge_emb_dim: 64 11 | time_emb_type: plain 12 | time_emb_dim: 1 13 | time_emb_scale: 1000 14 | 15 | train_frag_rot: True 16 | train_frag_pos: True 17 | train_link: True 18 | train_bond: True 19 | 20 | frag_pos_prior: 21 | known_anchor: False 22 | known_linker_bond: False 23 | rel_geometry: two_pos_and_rot 24 | 25 | diffusion: 26 | trans_rot_opt: 27 | sche_type: cosine 28 | s: 0.01 29 | trans_pos_opt: 30 | sche_type: cosine 31 | s: 0.01 32 | trans_link_pos_opt: 33 | sche_type: cosine 34 | s: 0.01 35 | trans_link_cls_opt: 36 | sche_type: cosine 37 | s: 0.01 38 | trans_link_bond_opt: 39 | sche_type: cosine 40 | s: 0.01 41 | 42 | eps_net: 43 | net_type: node_edge_net 44 | encoder: 45 | num_blocks: 6 46 | cutoff: 15. 
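    # cutoff: presumably in angstroms; used as the upper bound ('stop') of the
    # GaussianSmearing distance expansion in NodeEdgeNet (models/encoders/node_edge_net.py)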
47 | use_gate: True 48 | num_gaussians: 20 49 | expansion_mode: exp 50 | tr_output_type: invariant_eps 51 | rot_output_type: invariant_eps 52 | output_n_heads: 8 53 | separate_att: True 54 | sym_force: False 55 | 56 | train: 57 | seed: 2023 58 | loss_weights: 59 | frag_rot: 1.0 60 | frag_pos: 1.0 61 | link_pos: 1.0 62 | link_cls: 100.0 63 | link_bond: 100.0 64 | 65 | batch_size: 64 66 | num_workers: 8 67 | n_acc_batch: 1 68 | max_iters: 500000 69 | val_freq: 2000 70 | pos_noise_std: 0.05 71 | max_grad_norm: 50.0 72 | optimizer: 73 | type: adamw 74 | lr: 5.e-4 75 | weight_decay: 1.e-8 76 | beta1: 0.99 77 | beta2: 0.999 78 | scheduler: 79 | type: plateau 80 | factor: 0.6 81 | patience: 10 82 | min_lr: 1.e-6 -------------------------------------------------------------------------------- /configs/training/nn_variants/pos_eps_rot_euler.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: zinc 3 | path: ./data/zinc_difflinker 4 | version: full # [tiny, full] 5 | index_name: index_full.pkl 6 | 7 | model: 8 | num_steps: 500 9 | node_emb_dim: 256 10 | edge_emb_dim: 64 11 | time_emb_type: plain 12 | time_emb_dim: 1 13 | time_emb_scale: 1000 14 | 15 | train_frag_rot: True 16 | train_frag_pos: True 17 | train_link: True 18 | train_bond: True 19 | 20 | frag_pos_prior: 21 | known_anchor: False 22 | known_linker_bond: False 23 | rel_geometry: two_pos_and_rot 24 | 25 | diffusion: 26 | trans_rot_opt: 27 | sche_type: cosine 28 | s: 0.01 29 | trans_pos_opt: 30 | sche_type: cosine 31 | s: 0.01 32 | trans_link_pos_opt: 33 | sche_type: cosine 34 | s: 0.01 35 | trans_link_cls_opt: 36 | sche_type: cosine 37 | s: 0.01 38 | trans_link_bond_opt: 39 | sche_type: cosine 40 | s: 0.01 41 | 42 | eps_net: 43 | net_type: node_edge_net 44 | encoder: 45 | num_blocks: 6 46 | cutoff: 15. 
47 | use_gate: True 48 | num_gaussians: 20 49 | expansion_mode: exp 50 | tr_output_type: invariant_eps 51 | rot_output_type: euler_equation 52 | output_n_heads: 8 53 | separate_att: True 54 | sym_force: False 55 | 56 | train: 57 | seed: 2023 58 | loss_weights: 59 | frag_rot: 1.0 60 | frag_pos: 1.0 61 | link_pos: 1.0 62 | link_cls: 100.0 63 | link_bond: 100.0 64 | 65 | batch_size: 64 66 | num_workers: 8 67 | n_acc_batch: 1 68 | max_iters: 500000 69 | val_freq: 2000 70 | pos_noise_std: 0.05 71 | max_grad_norm: 50.0 72 | optimizer: 73 | type: adamw 74 | lr: 5.e-4 75 | weight_decay: 1.e-8 76 | beta1: 0.99 77 | beta2: 0.999 78 | scheduler: 79 | type: plateau 80 | factor: 0.6 81 | patience: 10 82 | min_lr: 1.e-6 -------------------------------------------------------------------------------- /configs/training/nn_variants/pos_newton_rot_eps.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: zinc 3 | path: ./data/zinc_difflinker 4 | version: full # [tiny, full] 5 | index_name: index_full.pkl 6 | 7 | model: 8 | num_steps: 500 9 | node_emb_dim: 256 10 | edge_emb_dim: 64 11 | time_emb_type: plain 12 | time_emb_dim: 1 13 | time_emb_scale: 1000 14 | 15 | train_frag_rot: True 16 | train_frag_pos: True 17 | train_link: True 18 | train_bond: True 19 | 20 | frag_pos_prior: 21 | known_anchor: False 22 | known_linker_bond: False 23 | rel_geometry: two_pos_and_rot 24 | 25 | diffusion: 26 | trans_rot_opt: 27 | sche_type: cosine 28 | s: 0.01 29 | trans_pos_opt: 30 | sche_type: cosine 31 | s: 0.01 32 | trans_link_pos_opt: 33 | sche_type: cosine 34 | s: 0.01 35 | trans_link_cls_opt: 36 | sche_type: cosine 37 | s: 0.01 38 | trans_link_bond_opt: 39 | sche_type: cosine 40 | s: 0.01 41 | 42 | eps_net: 43 | net_type: node_edge_net 44 | encoder: 45 | num_blocks: 6 46 | cutoff: 15. 
47 | use_gate: True 48 | num_gaussians: 20 49 | expansion_mode: exp 50 | tr_output_type: newton_equation 51 | rot_output_type: invariant_eps 52 | output_n_heads: 8 53 | separate_att: True 54 | sym_force: False 55 | 56 | train: 57 | seed: 2023 58 | loss_weights: 59 | frag_rot: 1.0 60 | frag_pos: 1.0 61 | link_pos: 1.0 62 | link_cls: 100.0 63 | link_bond: 100.0 64 | 65 | batch_size: 64 66 | num_workers: 8 67 | n_acc_batch: 1 68 | max_iters: 500000 69 | val_freq: 2000 70 | pos_noise_std: 0.05 71 | max_grad_norm: 50.0 72 | optimizer: 73 | type: adamw 74 | lr: 5.e-4 75 | weight_decay: 1.e-8 76 | beta1: 0.99 77 | beta2: 0.999 78 | scheduler: 79 | type: plateau 80 | factor: 0.6 81 | patience: 10 82 | min_lr: 1.e-6 -------------------------------------------------------------------------------- /configs/training/warhead_protac.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: protac 3 | path: ./data/protac 4 | version: v1 5 | split_mode: warhead # [warhead, ligase, random] 6 | index_name: 3d_index.pkl 7 | max_num_atoms: 70 8 | 9 | model: 10 | num_steps: 500 11 | node_emb_dim: 256 12 | edge_emb_dim: 64 13 | time_emb_type: plain 14 | time_emb_dim: 1 15 | time_emb_scale: 1000 16 | 17 | train_frag_rot: True 18 | train_frag_pos: True 19 | train_link: True 20 | train_bond: True 21 | 22 | frag_pos_prior: 23 | known_anchor: False 24 | known_linker_bond: False 25 | rel_geometry: two_pos_and_rot 26 | 27 | diffusion: 28 | trans_rot_opt: 29 | sche_type: cosine 30 | s: 0.01 31 | trans_pos_opt: 32 | sche_type: cosine 33 | s: 0.01 34 | trans_link_pos_opt: 35 | sche_type: cosine 36 | s: 0.01 37 | trans_link_cls_opt: 38 | sche_type: cosine 39 | s: 0.01 40 | trans_link_bond_opt: 41 | sche_type: cosine 42 | s: 0.01 43 | 44 | eps_net: 45 | net_type: node_edge_net 46 | encoder: 47 | num_blocks: 6 48 | cutoff: 15. 
49 | use_gate: True 50 | num_gaussians: 20 51 | expansion_mode: exp 52 | tr_output_type: newton_equation 53 | rot_output_type: euler_equation 54 | output_n_heads: 8 55 | 56 | train: 57 | ckpt_path: ckpts/zinc_model.pt 58 | seed: 2023 59 | loss_weights: 60 | frag_rot: 1.0 61 | frag_pos: 1.0 62 | link_pos: 1.0 63 | link_cls: 100.0 64 | link_bond: 100.0 65 | 66 | batch_size: 2 67 | num_workers: 8 68 | n_acc_batch: 1 69 | max_iters: 500000 70 | val_freq: 2000 71 | pos_noise_std: 0.05 72 | max_grad_norm: 50.0 73 | optimizer: 74 | type: adamw 75 | lr: 1.e-4 76 | weight_decay: 1.e-8 77 | beta1: 0.99 78 | beta2: 0.999 79 | scheduler: 80 | type: plateau 81 | factor: 0.6 82 | patience: 10 83 | min_lr: 1.e-6 -------------------------------------------------------------------------------- /configs/training/zinc.yml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: zinc 3 | path: ./data/zinc_difflinker 4 | version: full # [tiny, full] 5 | index_name: index_full.pkl 6 | 7 | model: 8 | num_steps: 500 9 | node_emb_dim: 256 10 | edge_emb_dim: 64 11 | time_emb_type: plain 12 | time_emb_dim: 1 13 | time_emb_scale: 1000 14 | 15 | train_frag_rot: True 16 | train_frag_pos: True 17 | train_link: True 18 | train_bond: True 19 | 20 | frag_pos_prior: 21 | known_anchor: False 22 | known_linker_bond: False 23 | rel_geometry: two_pos_and_rot 24 | 25 | diffusion: 26 | trans_rot_opt: 27 | sche_type: cosine 28 | s: 0.01 29 | trans_pos_opt: 30 | sche_type: cosine 31 | s: 0.01 32 | trans_link_pos_opt: 33 | sche_type: cosine 34 | s: 0.01 35 | trans_link_cls_opt: 36 | sche_type: cosine 37 | s: 0.01 38 | trans_link_bond_opt: 39 | sche_type: cosine 40 | s: 0.01 41 | 42 | eps_net: 43 | net_type: node_edge_net 44 | encoder: 45 | num_blocks: 6 46 | cutoff: 15. 
47 | use_gate: True 48 | num_gaussians: 20 49 | expansion_mode: exp 50 | tr_output_type: newton_equation_outer 51 | rot_output_type: euler_equation_outer 52 | output_n_heads: 8 53 | separate_att: True 54 | sym_force: False 55 | 56 | train: 57 | seed: 2023 58 | loss_weights: 59 | frag_rot: 1.0 60 | frag_pos: 1.0 61 | link_pos: 1.0 62 | link_cls: 100.0 63 | link_bond: 100.0 64 | 65 | batch_size: 64 66 | num_workers: 8 67 | n_acc_batch: 1 68 | max_iters: 500000 69 | val_freq: 2000 70 | pos_noise_std: 0.05 71 | max_grad_norm: 50.0 72 | optimizer: 73 | type: adamw 74 | lr: 5.e-4 75 | weight_decay: 1.e-8 76 | beta1: 0.99 77 | beta2: 0.999 78 | scheduler: 79 | type: plateau 80 | factor: 0.6 81 | patience: 10 82 | min_lr: 1.e-6 -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | #zinc_3dlinker/*.json 2 | #zinc_3dlinker/smi_train.txt 3 | #protac/protac.sdf 4 | zinc_difflinker/index_full.pkl 5 | !.gitignore -------------------------------------------------------------------------------- /data/protac/3d_index.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanjq/LinkerNet/280759c16ccecece0d81ab9cebe3f44041b80e51/data/protac/3d_index.pkl -------------------------------------------------------------------------------- /data/protac/index.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanjq/LinkerNet/280759c16ccecece0d81ab9cebe3f44041b80e51/data/protac/index.pkl -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanjq/LinkerNet/280759c16ccecece0d81ab9cebe3f44041b80e51/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/linker_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch_scatter 3 | import numpy as np 4 | from torch_geometric.data import Data, Batch 5 | from torch_geometric.loader import DataLoader 6 | 7 | FOLLOW_BATCH = () 8 | 9 | 10 | class FragLinkerData(Data): 11 | 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | 15 | # @staticmethod 16 | # def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, **kwargs): 17 | # instance = FragLinkerData(**kwargs) 18 | # 19 | # if protein_dict is not None: 20 | # for key, item in protein_dict.items(): 21 | # instance['protein_' + key] = item 22 | # 23 | # if ligand_dict is not None: 24 | # for key, item in ligand_dict.items(): 25 | # instance['ligand_' + key] = item 26 | # 27 | # instance['ligand_nbh_list'] = {i.item(): [j.item() for k, j in enumerate(instance.ligand_bond_index[1]) 28 | # if instance.ligand_bond_index[0, k].item() == i] 29 | # for i in instance.ligand_bond_index[0]} 30 | # return instance 31 | 32 | # def __inc__(self, key, value, *args, **kwargs): 33 | # if key == 'ligand_bond_index': 34 | # return self['ligand_element'].size(0) 35 | # # elif key == 'ligand_context_bond_index': 36 | # # return self['ligand_context_element'].size(0) 37 | # else: 38 | # return super().__inc__(key, value) 39 | 40 | 41 | def batch_from_data_list(data_list): 42 | return Batch.from_data_list(data_list, follow_batch=FOLLOW_BATCH) 43 | 44 | 45 | def 
torchify_dict(data): 46 | output = {} 47 | for k, v in data.items(): 48 | if isinstance(v, np.ndarray): 49 | output[k] = torch.from_numpy(v) 50 | else: 51 | output[k] = v 52 | return output 53 | -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import time 6 | from torch_geometric.nn import radius_graph, knn_graph 7 | import math 8 | 9 | 10 | class GaussianSmearing(nn.Module): 11 | # used to embed the edge distances 12 | def __init__(self, start=0.0, stop=5.0, num_gaussians=50, expansion_mode='linear'): 13 | super().__init__() 14 | self.start, self.stop, self.num_gaussians = start, stop, num_gaussians 15 | self.expansion_mode = expansion_mode 16 | if expansion_mode == 'exp': 17 | offset = torch.exp(torch.linspace(start=np.log(start+1), end=np.log(stop+1), steps=num_gaussians)) - 1 18 | diff = torch.diff(offset) 19 | diff = torch.cat([diff[:1], diff]) 20 | coeff = -0.5 / (diff ** 2) 21 | self.register_buffer('coeff', coeff) 22 | elif expansion_mode == 'linear': 23 | offset = torch.linspace(start=start, end=stop, steps=num_gaussians) 24 | self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2 25 | else: 26 | raise NotImplementedError('type_ must be either exp or linear') 27 | 28 | self.register_buffer('offset', offset) 29 | 30 | def __repr__(self): 31 | return f'GaussianSmearing(start={self.start}, stop={self.stop}, ' \ 32 | f'num_gaussians={self.num_gaussians}, expansion_mode={self.expansion_mode})' 33 | 34 | def forward(self, dist): 35 | dist = dist.view(-1, 1) - self.offset.view(1, -1) 36 | return torch.exp(self.coeff * torch.pow(dist, 2)) 37 | 38 | 39 | class AngleExpansion(nn.Module): 40 | def __init__(self, start=1.0, stop=5.0, half_expansion=10): 41 | super(AngleExpansion, self).__init__() 42 | l_mul = 1. 
/ torch.linspace(stop, start, half_expansion) 43 | r_mul = torch.linspace(start, stop, half_expansion) 44 | coeff = torch.cat([l_mul, r_mul], dim=-1) 45 | self.register_buffer('coeff', coeff) 46 | 47 | def forward(self, angle): 48 | return torch.cos(angle.view(-1, 1) * self.coeff.view(1, -1)) 49 | 50 | 51 | class Swish(nn.Module): 52 | def __init__(self): 53 | super(Swish, self).__init__() 54 | self.beta = nn.Parameter(torch.tensor(1.0)) 55 | 56 | def forward(self, x): 57 | return x * torch.sigmoid(self.beta * x) 58 | 59 | 60 | class ShiftedSoftplus(nn.Module): 61 | def __init__(self): 62 | super().__init__() 63 | self.shift = torch.log(torch.tensor(2.0)).item() 64 | 65 | def forward(self, x): 66 | return F.softplus(x) - self.shift 67 | 68 | 69 | NONLINEARITIES = { 70 | "tanh": nn.Tanh(), 71 | "relu": nn.ReLU(), 72 | "softplus": nn.Softplus(), 73 | "elu": nn.ELU(), 74 | "swish": Swish(), 75 | 'silu': nn.SiLU() 76 | } 77 | 78 | 79 | class MLP(nn.Module): 80 | """MLP with the same hidden dim across all layers.""" 81 | 82 | def __init__(self, in_dim, out_dim, hidden_dim, num_layer=2, norm=True, act_fn='relu', act_last=False): 83 | super().__init__() 84 | layers = [] 85 | for layer_idx in range(num_layer): 86 | if layer_idx == 0: 87 | layers.append(nn.Linear(in_dim, hidden_dim)) 88 | elif layer_idx == num_layer - 1: 89 | layers.append(nn.Linear(hidden_dim, out_dim)) 90 | else: 91 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 92 | if layer_idx < num_layer - 1 or act_last: 93 | if norm: 94 | layers.append(nn.LayerNorm(hidden_dim)) 95 | layers.append(NONLINEARITIES[act_fn]) 96 | self.net = nn.Sequential(*layers) 97 | 98 | def forward(self, x): 99 | return self.net(x) 100 | 101 | 102 | def outer_product(*vectors): 103 | for index, vector in enumerate(vectors): 104 | if index == 0: 105 | out = vector.unsqueeze(-1) 106 | else: 107 | out = out * vector.unsqueeze(1) 108 | out = out.view(out.shape[0], -1).unsqueeze(-1) 109 | return out.squeeze() 110 | 111 | 112 | def get_h_dist(dist_metric, hi, hj): 113 | if dist_metric == 'euclidean': 114 | h_dist = torch.sum((hi - hj) ** 2, -1, keepdim=True) 115 | return h_dist 116 | elif dist_metric == 'cos_sim': 117 | hi_norm = torch.norm(hi, p=2, dim=-1, keepdim=True) 118 | hj_norm = torch.norm(hj, p=2, dim=-1, keepdim=True) 119 | h_dist = torch.sum(hi * hj, -1, keepdim=True) / (hi_norm * hj_norm) 120 | return h_dist, hj_norm 121 | 122 | 123 | def compose_context(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand): 124 | # previous version has problems when ligand atom types are fixed 125 | # (due to sorting randomly in case of same element) 126 | 127 | batch_ctx = torch.cat([batch_protein, batch_ligand], dim=0) 128 | # sort_idx = batch_ctx.argsort() 129 | sort_idx = torch.sort(batch_ctx, stable=True).indices 130 | 131 | mask_ligand = torch.cat([ 132 | torch.zeros([batch_protein.size(0)], device=batch_protein.device).bool(), 133 | torch.ones([batch_ligand.size(0)], device=batch_ligand.device).bool(), 134 | ], dim=0)[sort_idx] 135 | 136 | batch_ctx = batch_ctx[sort_idx] 137 | if isinstance(h_protein, list): 138 | h_ctx_sca = torch.cat([h_protein[0], h_ligand[0]], dim=0)[sort_idx] # (N_protein+N_ligand, H) 139 | h_ctx_vec = torch.cat([h_protein[1], h_ligand[1]], dim=0)[sort_idx] # (N_protein+N_ligand, H, 3) 140 | h_ctx = [h_ctx_sca, h_ctx_vec] 141 | else: 142 | h_ctx = torch.cat([h_protein, h_ligand], dim=0)[sort_idx] # (N_protein+N_ligand, H) 143 | pos_ctx = torch.cat([pos_protein, pos_ligand], dim=0)[sort_idx] # (N_protein+N_ligand, 3) 
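    # NOTE: sorting with stable=True keeps protein atoms before ligand atoms within each
    # graph, making the composed ordering deterministic (a plain argsort may shuffle
    # entries that share a batch index), so mask_ligand, h_ctx and pos_ctx stay aligned.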
144 | 145 | return h_ctx, pos_ctx, batch_ctx, mask_ligand 146 | 147 | 148 | def hybrid_edge_connection(ligand_pos, protein_pos, k, ligand_index, protein_index): 149 | # fully-connected for ligand atoms 150 | dst = torch.repeat_interleave(ligand_index, len(ligand_index)) 151 | src = ligand_index.repeat(len(ligand_index)) 152 | mask = dst != src 153 | dst, src = dst[mask], src[mask] 154 | ll_edge_index = torch.stack([src, dst]) 155 | 156 | # knn for ligand-protein edges 157 | ligand_protein_pos_dist = torch.unsqueeze(ligand_pos, 1) - torch.unsqueeze(protein_pos, 0) 158 | ligand_protein_pos_dist = torch.norm(ligand_protein_pos_dist, p=2, dim=-1) 159 | knn_p_idx = torch.topk(ligand_protein_pos_dist, k=k, largest=False, dim=1).indices 160 | knn_p_idx = protein_index[knn_p_idx] 161 | knn_l_idx = torch.unsqueeze(ligand_index, 1) 162 | knn_l_idx = knn_l_idx.repeat(1, k) 163 | pl_edge_index = torch.stack([knn_p_idx, knn_l_idx], dim=0) 164 | pl_edge_index = pl_edge_index.view(2, -1) 165 | return ll_edge_index, pl_edge_index 166 | 167 | 168 | def batch_hybrid_edge_connection(x, k, mask_ligand, batch, add_p_index=False): 169 | batch_size = batch.max().item() + 1 170 | batch_ll_edge_index, batch_pl_edge_index, batch_p_edge_index = [], [], [] 171 | with torch.no_grad(): 172 | for i in range(batch_size): 173 | ligand_index = ((batch == i) & (mask_ligand == 1)).nonzero()[:, 0] 174 | protein_index = ((batch == i) & (mask_ligand == 0)).nonzero()[:, 0] 175 | # print(f'batch: {i}, ligand_index: {ligand_index} {len(ligand_index)}') 176 | ligand_pos, protein_pos = x[ligand_index], x[protein_index] 177 | ll_edge_index, pl_edge_index = hybrid_edge_connection( 178 | ligand_pos, protein_pos, k, ligand_index, protein_index) 179 | batch_ll_edge_index.append(ll_edge_index) 180 | batch_pl_edge_index.append(pl_edge_index) 181 | if add_p_index: 182 | all_pos = torch.cat([protein_pos, ligand_pos], 0) 183 | p_edge_index = knn_graph(all_pos, k=k, flow='source_to_target') 184 | p_edge_index = p_edge_index[:, p_edge_index[1] < len(protein_pos)] 185 | p_src, p_dst = p_edge_index 186 | all_index = torch.cat([protein_index, ligand_index], 0) 187 | # print('len protein index: ', len(protein_index)) 188 | # print('max index: ', max(p_src), max(p_dst)) 189 | p_edge_index = torch.stack([all_index[p_src], all_index[p_dst]], 0) 190 | batch_p_edge_index.append(p_edge_index) 191 | # edge_index.append(torch.cat([ll_edge_index, pl_edge_index], -1)) 192 | 193 | if add_p_index: 194 | edge_index = [torch.cat([ll, pl, p], -1) for ll, pl, p in zip( 195 | batch_ll_edge_index, batch_pl_edge_index, batch_p_edge_index)] 196 | else: 197 | edge_index = [torch.cat([ll, pl], -1) for ll, pl in zip(batch_ll_edge_index, batch_pl_edge_index)] 198 | edge_index = torch.cat(edge_index, -1) 199 | return edge_index 200 | 201 | 202 | # Time Embedding 203 | def basic_time_embedding(beta): 204 | emb = torch.stack([beta, torch.sin(beta), torch.cos(beta)], dim=-1) # (N, 3) 205 | return emb 206 | 207 | 208 | def sinusoidal_embedding(timesteps, embedding_dim, max_positions=10000): 209 | """ from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py """ 210 | assert len(timesteps.shape) == 1 211 | half_dim = embedding_dim // 2 212 | emb = math.log(max_positions) / (half_dim - 1) 213 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) 214 | emb = timesteps.float()[:, None] * emb[None, :] 215 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 216 | if embedding_dim % 2 == 1: # zero pad 217 | emb = 
F.pad(emb, (0, 1), mode='constant') 218 | assert emb.shape == (timesteps.shape[0], embedding_dim) 219 | return emb 220 | 221 | 222 | class GaussianFourierProjection(nn.Module): 223 | """Gaussian Fourier embeddings for noise levels. 224 | from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/models/layerspp.py#L32 225 | """ 226 | 227 | def __init__(self, embedding_size=256, scale=1.0): 228 | super().__init__() 229 | self.W = nn.Parameter(torch.randn(embedding_size // 2) * scale, requires_grad=False) 230 | 231 | def forward(self, x): 232 | x_proj = x[:, None] * self.W[None, :] * 2 * math.pi 233 | emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 234 | return emb 235 | 236 | 237 | def get_timestep_embedding(embedding_type, embedding_dim, embedding_scale=10000): 238 | if embedding_type == 'basic': 239 | emb_func = (lambda x: basic_time_embedding(x)) 240 | elif embedding_type == 'sin': 241 | emb_func = (lambda x: sinusoidal_embedding(embedding_scale * x, embedding_dim)) 242 | elif embedding_type == 'fourier': 243 | emb_func = GaussianFourierProjection(embedding_size=embedding_dim, scale=embedding_scale) 244 | else: 245 | raise NotImplemented 246 | return emb_func 247 | -------------------------------------------------------------------------------- /models/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .node_edge_net import NodeEdgeNet 2 | 3 | 4 | def get_refine_net(config, net_type, node_hidden_dim, edge_hidden_dim, train_link): 5 | if net_type == 'node_edge_net': 6 | refine_net = NodeEdgeNet( 7 | node_dim=node_hidden_dim, 8 | edge_dim=edge_hidden_dim, 9 | num_blocks=config.num_blocks, 10 | cutoff=config.cutoff, 11 | use_gate=config.use_gate, 12 | update_pos=train_link, 13 | expansion_mode=config.get('expansion_mode', 'linear') 14 | ) 15 | else: 16 | raise ValueError(net_type) 17 | return refine_net 18 | -------------------------------------------------------------------------------- /models/encoders/node_edge_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import Module, Linear, ModuleList 4 | from torch_scatter import scatter_sum 5 | from models.common import GaussianSmearing, MLP 6 | 7 | 8 | class NodeBlock(Module): 9 | 10 | def __init__(self, node_dim, edge_dim, hidden_dim, use_gate): 11 | super().__init__() 12 | self.use_gate = use_gate 13 | self.node_dim = node_dim 14 | 15 | self.node_net = MLP(node_dim, hidden_dim, hidden_dim) 16 | self.edge_net = MLP(edge_dim, hidden_dim, hidden_dim) 17 | self.msg_net = Linear(hidden_dim, hidden_dim) 18 | 19 | if self.use_gate: 20 | self.gate = MLP(edge_dim+node_dim+1, hidden_dim, hidden_dim) # add 1 for time 21 | 22 | self.centroid_lin = Linear(node_dim, hidden_dim) 23 | self.layer_norm = nn.LayerNorm(hidden_dim) 24 | self.act = nn.ReLU() 25 | self.out_transform = Linear(hidden_dim, node_dim) 26 | 27 | def forward(self, x, edge_index, edge_attr, node_time): 28 | """ 29 | Args: 30 | x: Node features, (N, H). 31 | edge_index: (2, E). 
32 | edge_attr: (E, H) 33 | """ 34 | N = x.size(0) 35 | row, col = edge_index # (E,) , (E,) 36 | 37 | h_node = self.node_net(x) # (N, H) 38 | 39 | # Compose messages 40 | h_edge = self.edge_net(edge_attr) # (E, H_per_head) 41 | msg_j = self.msg_net(h_edge * h_node[col]) 42 | 43 | if self.use_gate: 44 | gate = self.gate(torch.cat([edge_attr, x[col], node_time[col]], dim=-1)) 45 | msg_j = msg_j * torch.sigmoid(gate) 46 | 47 | # Aggregate messages 48 | aggr_msg = scatter_sum(msg_j, row, dim=0, dim_size=N) 49 | out = self.centroid_lin(x) + aggr_msg 50 | 51 | out = self.layer_norm(out) 52 | out = self.out_transform(self.act(out)) 53 | return out 54 | 55 | 56 | class BondFFN(Module): 57 | def __init__(self, bond_dim, node_dim, inter_dim, use_gate, out_dim=None): 58 | super().__init__() 59 | out_dim = bond_dim if out_dim is None else out_dim 60 | self.use_gate = use_gate 61 | self.bond_linear = Linear(bond_dim, inter_dim, bias=False) 62 | self.node_linear = Linear(node_dim, inter_dim, bias=False) 63 | self.inter_module = MLP(inter_dim, out_dim, inter_dim) 64 | if self.use_gate: 65 | self.gate = MLP(bond_dim+node_dim+1, out_dim, 32) # +1 for time 66 | 67 | def forward(self, bond_feat_input, node_feat_input, time): 68 | bond_feat = self.bond_linear(bond_feat_input) 69 | node_feat = self.node_linear(node_feat_input) 70 | inter_feat = bond_feat * node_feat 71 | inter_feat = self.inter_module(inter_feat) 72 | if self.use_gate: 73 | gate = self.gate(torch.cat([bond_feat_input, node_feat_input, time], dim=-1)) 74 | inter_feat = inter_feat * torch.sigmoid(gate) 75 | return inter_feat 76 | 77 | 78 | class QKVLin(Module): 79 | def __init__(self, h_dim, key_dim, num_heads): 80 | super().__init__() 81 | self.num_heads = num_heads 82 | self.q_lin = Linear(h_dim, key_dim) 83 | self.k_lin = Linear(h_dim, key_dim) 84 | self.v_lin = Linear(h_dim, h_dim) 85 | 86 | def forward(self, inputs): 87 | n = inputs.size(0) 88 | return [ 89 | self.q_lin(inputs).view(n, self.num_heads, -1), 90 | self.k_lin(inputs).view(n, self.num_heads, -1), 91 | self.v_lin(inputs).view(n, self.num_heads, -1), 92 | ] 93 | 94 | 95 | class EdgeBlock(Module): 96 | def __init__(self, edge_dim, node_dim, hidden_dim=None, use_gate=True): 97 | super().__init__() 98 | self.use_gate = use_gate 99 | inter_dim = edge_dim * 2 if hidden_dim is None else hidden_dim 100 | 101 | self.bond_ffn_left = BondFFN(edge_dim, node_dim, inter_dim=inter_dim, use_gate=use_gate) 102 | self.bond_ffn_right = BondFFN(edge_dim, node_dim, inter_dim=inter_dim, use_gate=use_gate) 103 | 104 | self.node_ffn_left = Linear(node_dim, edge_dim) 105 | self.node_ffn_right = Linear(node_dim, edge_dim) 106 | 107 | self.self_ffn = Linear(edge_dim, edge_dim) 108 | self.layer_norm = nn.LayerNorm(edge_dim) 109 | self.out_transform = Linear(edge_dim, edge_dim) 110 | self.act = nn.ReLU() 111 | 112 | def forward(self, h_bond, bond_index, h_node, bond_time): 113 | """ 114 | h_bond: (b, bond_dim) 115 | bond_index: (2, b) 116 | h_node: (n, node_dim) 117 | """ 118 | N = h_node.size(0) 119 | left_node, right_node = bond_index 120 | 121 | # message from neighbor bonds 122 | msg_bond_left = self.bond_ffn_left(h_bond, h_node[left_node], bond_time) 123 | msg_bond_left = scatter_sum(msg_bond_left, right_node, dim=0, dim_size=N) 124 | msg_bond_left = msg_bond_left[left_node] 125 | 126 | msg_bond_right = self.bond_ffn_right(h_bond, h_node[right_node], bond_time) 127 | msg_bond_right = scatter_sum(msg_bond_right, left_node, dim=0, dim_size=N) 128 | msg_bond_right = msg_bond_right[right_node] 129 | 130 | 
h_bond = ( 131 | msg_bond_left + msg_bond_right 132 | + self.node_ffn_left(h_node[left_node]) 133 | + self.node_ffn_right(h_node[right_node]) 134 | + self.self_ffn(h_bond) 135 | ) 136 | h_bond = self.layer_norm(h_bond) 137 | 138 | h_bond = self.out_transform(self.act(h_bond)) 139 | return h_bond 140 | 141 | 142 | class NodeEdgeNet(Module): 143 | def __init__(self, node_dim, edge_dim, num_blocks, cutoff, use_gate, **kwargs): 144 | super().__init__() 145 | self.node_dim = node_dim 146 | self.edge_dim = edge_dim 147 | self.num_blocks = num_blocks 148 | self.cutoff = cutoff 149 | self.use_gate = use_gate 150 | self.kwargs = kwargs 151 | 152 | if 'num_gaussians' not in kwargs: 153 | num_gaussians = 16 154 | else: 155 | num_gaussians = kwargs['num_gaussians'] 156 | if 'start' not in kwargs: 157 | start = 0 158 | else: 159 | start = kwargs['start'] 160 | self.distance_expansion = GaussianSmearing( 161 | start=start, stop=cutoff, num_gaussians=num_gaussians, expansion_mode=kwargs['expansion_mode']) 162 | print('distance expansion: ', self.distance_expansion) 163 | if ('update_edge' in kwargs) and (not kwargs['update_edge']): 164 | self.update_edge = False 165 | input_edge_dim = num_gaussians 166 | else: 167 | self.update_edge = True # default update edge 168 | input_edge_dim = edge_dim + num_gaussians 169 | 170 | if ('update_pos' in kwargs) and (not kwargs['update_pos']): 171 | self.update_pos = False 172 | else: 173 | self.update_pos = True # default update pos 174 | print('update pos: ', self.update_pos) 175 | # node network 176 | self.node_blocks_with_edge = ModuleList() 177 | self.edge_embs = ModuleList() 178 | self.edge_blocks = ModuleList() 179 | self.pos_blocks = ModuleList() 180 | for _ in range(num_blocks): 181 | self.node_blocks_with_edge.append(NodeBlock( 182 | node_dim=node_dim, edge_dim=edge_dim, hidden_dim=node_dim, use_gate=use_gate, 183 | )) 184 | self.edge_embs.append(Linear(input_edge_dim, edge_dim)) 185 | if self.update_edge: 186 | self.edge_blocks.append(EdgeBlock( 187 | edge_dim=edge_dim, node_dim=node_dim, use_gate=use_gate, 188 | )) 189 | if self.update_pos: 190 | self.pos_blocks.append(PosUpdate( 191 | node_dim=node_dim, edge_dim=edge_dim, hidden_dim=edge_dim, use_gate=use_gate, 192 | )) 193 | 194 | @property 195 | def node_hidden_dim(self): 196 | return self.node_dim 197 | 198 | @property 199 | def edge_hidden_dim(self): 200 | return self.edge_dim 201 | 202 | def forward(self, pos_node, h_node, h_edge, edge_index, linker_mask, node_time, edge_time): 203 | for i in range(self.num_blocks): 204 | # edge fetures before each block 205 | if self.update_pos or (i==0): 206 | h_edge_dist, relative_vec, distance = self._build_edges_dist(pos_node, edge_index) 207 | if self.update_edge: 208 | h_edge = torch.cat([h_edge, h_edge_dist], dim=-1) 209 | else: 210 | h_edge = h_edge_dist 211 | h_edge = self.edge_embs[i](h_edge) 212 | 213 | # node and edge feature updates 214 | h_node_with_edge = self.node_blocks_with_edge[i](h_node, edge_index, h_edge, node_time) 215 | if self.update_edge: 216 | h_edge = h_edge + self.edge_blocks[i](h_edge, edge_index, h_node, edge_time) 217 | h_node = h_node + h_node_with_edge 218 | # pos updates 219 | if self.update_pos: 220 | delta_x = self.pos_blocks[i](h_node, h_edge, edge_index, relative_vec, distance, edge_time) 221 | pos_node = pos_node + delta_x * linker_mask[:, None] # only linker positions will be updated 222 | return pos_node, h_node, h_edge 223 | 224 | def _build_edges_dist(self, pos, edge_index): 225 | # distance 226 | relative_vec = 
pos[edge_index[0]] - pos[edge_index[1]] 227 | distance = torch.norm(relative_vec, dim=-1, p=2) 228 | edge_dist = self.distance_expansion(distance) 229 | return edge_dist, relative_vec, distance 230 | 231 | 232 | class PosUpdate(Module): 233 | def __init__(self, node_dim, edge_dim, hidden_dim, use_gate): 234 | super().__init__() 235 | self.left_lin_edge = MLP(node_dim, edge_dim, hidden_dim) 236 | self.right_lin_edge = MLP(node_dim, edge_dim, hidden_dim) 237 | self.edge_lin = BondFFN(edge_dim, edge_dim, node_dim, use_gate, out_dim=1) 238 | 239 | def forward(self, h_node, h_edge, edge_index, relative_vec, distance, edge_time): 240 | edge_index_left, edge_index_right = edge_index 241 | 242 | left_feat = self.left_lin_edge(h_node[edge_index_left]) 243 | right_feat = self.right_lin_edge(h_node[edge_index_right]) 244 | weight_edge = self.edge_lin(h_edge, left_feat * right_feat, edge_time) 245 | 246 | # relative_vec = pos_node[edge_index_left] - pos_node[edge_index_right] 247 | # distance = torch.norm(relative_vec, dim=-1, keepdim=True) 248 | force_edge = weight_edge * relative_vec / distance.unsqueeze(-1) / (distance.unsqueeze(-1) + 1.) 249 | delta_pos = scatter_sum(force_edge, edge_index_left, dim=0, dim_size=h_node.shape[0]) 250 | 251 | return delta_pos 252 | 253 | 254 | class PosPredictor(Module): 255 | def __init__(self, node_dim, edge_dim, bond_dim, use_gate): 256 | super().__init__() 257 | self.left_lin_edge = MLP(node_dim, edge_dim, hidden_dim=edge_dim) 258 | self.right_lin_edge = MLP(node_dim, edge_dim, hidden_dim=edge_dim) 259 | self.edge_lin = BondFFN(edge_dim, edge_dim, node_dim, use_gate, out_dim=1) 260 | 261 | self.bond_dim = bond_dim 262 | if bond_dim > 0: 263 | self.left_lin_bond = MLP(node_dim, bond_dim, hidden_dim=bond_dim) 264 | self.right_lin_bond = MLP(node_dim, bond_dim, hidden_dim=bond_dim) 265 | self.bond_lin = BondFFN(bond_dim, bond_dim, node_dim, use_gate, out_dim=1) 266 | 267 | def forward(self, h_node, pos_node, h_bond, bond_index, h_edge, edge_index, is_frag): 268 | # 1 pos update through edges 269 | is_left_frag = is_frag[edge_index[0]] 270 | edge_index_left, edge_index_right = edge_index[:, is_left_frag] 271 | 272 | left_feat = self.left_lin_edge(h_node[edge_index_left]) 273 | right_feat = self.right_lin_edge(h_node[edge_index_right]) 274 | weight_edge = self.edge_lin(h_edge[is_left_frag], left_feat * right_feat) 275 | force_edge = weight_edge * (pos_node[edge_index_left] - pos_node[edge_index_right]) 276 | delta_pos = scatter_sum(force_edge, edge_index_left, dim=0, dim_size=h_node.shape[0]) 277 | 278 | # 2 pos update through bonds 279 | if self.bond_dim > 0: 280 | is_left_frag = is_frag[bond_index[0]] 281 | bond_index_left, bond_index_right = bond_index[:, is_left_frag] 282 | 283 | left_feat = self.left_lin_bond(h_node[bond_index_left]) 284 | right_feat = self.right_lin_bond(h_node[bond_index_right]) 285 | weight_bond = self.bond_lin(h_bond[is_left_frag], left_feat * right_feat) 286 | force_bond = weight_bond * (pos_node[bond_index_left] - pos_node[bond_index_right]) 287 | delta_pos = delta_pos + scatter_sum(force_bond, bond_index_left, dim=0, dim_size=h_node.shape[0]) 288 | 289 | pos_update = pos_node + delta_pos / 10. 
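        # NOTE: the division by 10 above damps the aggregated edge/bond forces;
        # presumably an empirical step size that keeps fragment-position updates stable.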
290 | return pos_update 291 | -------------------------------------------------------------------------------- /models/eps_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.encoders import get_refine_net 4 | from utils.geometry import apply_rotation_to_vector, quaternion_1ijk_to_rotation_matrix 5 | from utils.so3 import so3vec_to_rotation, rotation_to_so3vec 6 | from torch_scatter import scatter_mean, scatter_softmax, scatter_sum 7 | from models.common import MLP 8 | import numpy as np 9 | 10 | 11 | class ForceLayer(nn.Module): 12 | def __init__(self, input_dim, hidden_dim, output_dim, n_heads, edge_feat_dim, time_dim, 13 | separate_att=False, act_fn='relu', norm=True): 14 | super().__init__() 15 | self.input_dim = input_dim 16 | self.hidden_dim = hidden_dim 17 | self.output_dim = output_dim 18 | self.n_heads = n_heads 19 | self.edge_feat_dim = edge_feat_dim 20 | self.time_dim = time_dim 21 | self.act_fn = act_fn 22 | self.separate_att = separate_att 23 | 24 | kv_input_dim = input_dim * 2 + edge_feat_dim 25 | self.xk_func = MLP(kv_input_dim, output_dim, hidden_dim, norm=norm, act_fn=act_fn) 26 | self.xv_func = MLP(kv_input_dim + time_dim, self.n_heads, hidden_dim, norm=norm, act_fn=act_fn) 27 | self.xq_func = MLP(input_dim, output_dim, hidden_dim, norm=norm, act_fn=act_fn) 28 | 29 | def forward(self, h, rel_x, edge_feat, edge_index, inner_edge_mask, t, e_w=None): 30 | N = h.size(0) 31 | src, dst = edge_index 32 | hi, hj = h[dst], h[src] 33 | 34 | # multi-head attention 35 | kv_input = torch.cat([edge_feat, hi, hj], -1) 36 | 37 | k = self.xk_func(kv_input).view(-1, self.n_heads, self.output_dim // self.n_heads) 38 | v = self.xv_func(torch.cat([kv_input, t[dst]], -1)) 39 | e_w = e_w.view(-1, 1) if e_w is not None else 1. 
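        # e_w is an optional external per-edge weight; when it is None the factor
        # collapses to the scalar 1, i.e. the attention values pass through unscaled.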
40 | v = v * e_w 41 | 42 | v = v.unsqueeze(-1) * rel_x.unsqueeze(1) # (xi - xj) [n_edges, n_heads, 3] 43 | q = self.xq_func(h).view(-1, self.n_heads, self.output_dim // self.n_heads) 44 | 45 | if self.separate_att: 46 | k_in, v_in = k[inner_edge_mask], v[inner_edge_mask] 47 | alpha_in = scatter_softmax((q[dst[inner_edge_mask]] * k_in / np.sqrt(k_in.shape[-1])).sum(-1), 48 | dst[inner_edge_mask], dim=0) # (E, heads) 49 | m_in = alpha_in.unsqueeze(-1) * v_in # (E, heads, 3) 50 | inner_forces = scatter_sum(m_in, dst[inner_edge_mask], dim=0, dim_size=N).mean(1) 51 | 52 | k_out, v_out = k[~inner_edge_mask], v[~inner_edge_mask] 53 | alpha_out = scatter_softmax((q[dst[~inner_edge_mask]] * k_out / np.sqrt(k_out.shape[-1])).sum(-1), 54 | dst[~inner_edge_mask], dim=0) # (E, heads) 55 | m_out = alpha_out.unsqueeze(-1) * v_out # (E, heads, 3) 56 | outer_forces = scatter_sum(m_out, dst[~inner_edge_mask], dim=0, dim_size=N).mean(1) 57 | 58 | else: 59 | # Compute attention weights 60 | alpha = scatter_softmax((q[dst] * k / np.sqrt(k.shape[-1])).sum(-1), dst, dim=0) # (E, heads) 61 | 62 | # Perform attention-weighted message-passing 63 | m = alpha.unsqueeze(-1) * v # (E, heads, 3) 64 | inner_forces = scatter_sum(m[inner_edge_mask], dst[inner_edge_mask], dim=0, dim_size=N).mean(1) 65 | outer_forces = scatter_sum(m[~inner_edge_mask], dst[~inner_edge_mask], dim=0, dim_size=N).mean(1) 66 | # output = scatter_sum(m, dst, dim=0, dim_size=N) # (N, heads, 3) 67 | return inner_forces, outer_forces # [num_nodes, 3] 68 | 69 | 70 | class SymForceLayer(nn.Module): 71 | def __init__(self, input_dim, hidden_dim, edge_feat_dim, time_dim, 72 | act_fn='relu', norm=True): 73 | super().__init__() 74 | self.input_dim = input_dim 75 | self.hidden_dim = hidden_dim 76 | self.output_dim = 1 77 | self.edge_feat_dim = edge_feat_dim 78 | self.time_dim = time_dim 79 | self.act_fn = act_fn 80 | self.pred_layer = MLP(input_dim + edge_feat_dim + time_dim, 1, hidden_dim, 81 | num_layer=3, norm=norm, act_fn=act_fn) 82 | 83 | def forward(self, h, rel_x, edge_feat, edge_index, inner_edge_mask, t): 84 | N = h.size(0) 85 | src, dst = edge_index[:, ~inner_edge_mask] 86 | rel_x = rel_x[~inner_edge_mask] 87 | distance = torch.norm(rel_x, p=2, dim=-1) 88 | hi, hj = h[dst], h[src] 89 | 90 | feat = torch.cat([edge_feat[~inner_edge_mask], (hi + hj) / 2, t[dst]], -1) 91 | forces = self.pred_layer(feat) * rel_x / distance.unsqueeze(-1) / (distance.unsqueeze(-1) + 1.) 
92 | outer_forces = scatter_sum(forces, dst, dim=0, dim_size=N) 93 | return None, outer_forces # [num_nodes, 3] 94 | 95 | 96 | class EpsilonNet(nn.Module): 97 | 98 | def __init__(self, cfg, node_emb_dim, edge_emb_dim, time_emb_dim, num_classes, num_bond_classes, 99 | train_frag_rot, train_frag_pos, train_link, train_bond, pred_frag_dist=False, 100 | use_rel_geometry=False, softmax_last=True): 101 | super().__init__() 102 | self.encoder_type = cfg.net_type 103 | self.encoder = get_refine_net(cfg.encoder, cfg.net_type, node_emb_dim, edge_emb_dim, train_link) 104 | self.num_frags = 2 105 | self.pred_frag_dist = pred_frag_dist 106 | self.use_rel_geometry = use_rel_geometry 107 | self.train_frag_rot = train_frag_rot 108 | self.train_frag_pos = train_frag_pos 109 | self.train_link = train_link 110 | self.train_bond = train_bond 111 | self.tr_output_type = cfg.get('tr_output_type', 'invariant_eps') 112 | self.rot_output_type = cfg.rot_output_type 113 | print('EpsNet Softmax Last: ', softmax_last) 114 | 115 | if 'newton_equation' in self.tr_output_type or 'euler_equation' in self.rot_output_type: 116 | if cfg.get('sym_force', False): 117 | self.force_layer = SymForceLayer( 118 | input_dim=self.encoder.node_hidden_dim, 119 | hidden_dim=self.encoder.node_hidden_dim, 120 | edge_feat_dim=self.encoder.edge_hidden_dim, 121 | time_dim=time_emb_dim 122 | ) 123 | else: 124 | self.force_layer = ForceLayer( 125 | input_dim=self.encoder.node_hidden_dim, 126 | hidden_dim=self.encoder.node_hidden_dim, 127 | output_dim=self.encoder.node_hidden_dim, 128 | n_heads=cfg.output_n_heads, 129 | edge_feat_dim=self.encoder.edge_hidden_dim, 130 | time_dim=time_emb_dim, 131 | separate_att=cfg.get('separate_att', False), 132 | ) 133 | print(self.force_layer) 134 | if self.tr_output_type == 'invariant_eps' or self.rot_output_type == 'invariant_eps': 135 | self.frag_aggr = nn.Sequential( 136 | nn.Linear(node_emb_dim, node_emb_dim * 2), nn.ReLU(), 137 | nn.Linear(node_emb_dim * 2, node_emb_dim), nn.ReLU(), 138 | nn.Linear(node_emb_dim, node_emb_dim) 139 | ) 140 | if self.tr_output_type == 'invariant_eps': 141 | self.eps_crd_net = nn.Sequential( 142 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(), 143 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(), 144 | nn.Linear(node_emb_dim, 1 if pred_frag_dist else 3) 145 | ) 146 | if self.rot_output_type == 'invariant_eps': 147 | self.eps_rot_net = nn.Sequential( 148 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(), 149 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(), 150 | nn.Linear(node_emb_dim, 3) 151 | ) 152 | 153 | if softmax_last: 154 | self.eps_cls_net = nn.Sequential( 155 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(), 156 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(), 157 | nn.Linear(node_emb_dim, num_classes), nn.Softmax(dim=-1) 158 | ) 159 | else: 160 | self.eps_cls_net = nn.Sequential( 161 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(), 162 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(), 163 | nn.Linear(node_emb_dim, num_classes) 164 | ) 165 | if self.train_bond: 166 | if softmax_last: 167 | self.bond_pred_net = nn.Sequential( 168 | nn.Linear(edge_emb_dim, edge_emb_dim), nn.ReLU(), 169 | nn.Linear(edge_emb_dim, num_bond_classes), nn.Softmax(dim=-1) 170 | ) 171 | else: 172 | self.bond_pred_net = nn.Sequential( 173 | nn.Linear(edge_emb_dim, edge_emb_dim), nn.ReLU(), 174 | nn.Linear(edge_emb_dim, num_bond_classes) 175 | ) 176 | 177 | def update_frag_geometry(self, final_x, final_h, final_h_bond, 178 | edge_index, inner_edge_mask, batch_node, mask, 
node_t_emb, R): 179 | 180 | eps_d, eps_pos, recon_x = None, None, None 181 | v_next, R_next = None, None 182 | if self.tr_output_type == 'invariant_eps' or self.rot_output_type == 'invariant_eps': 183 | # Aggregate features for fragment update 184 | frag_h = final_h[mask] 185 | frag_batch = batch_node[mask] 186 | frag_h_mean = scatter_mean(frag_h, frag_batch, dim=0) 187 | frag_h_mean = self.frag_aggr(frag_h_mean) # (G, H) 188 | 189 | # Position / Distance changes 190 | if self.tr_output_type == 'invariant_eps': 191 | if self.pred_frag_dist: 192 | eps_d = self.eps_crd_net(frag_h_mean) # (G, 1) 193 | else: 194 | eps_crd = self.eps_crd_net(frag_h_mean) # local coordinates (G, 3) 195 | eps_pos = apply_rotation_to_vector(R, eps_crd) # (G, 3) 196 | 197 | if self.rot_output_type == 'invariant_eps': 198 | # New orientation 199 | eps_rot = self.eps_rot_net(frag_h_mean) # (G, 3) 200 | U = quaternion_1ijk_to_rotation_matrix(eps_rot) # (G, 3, 3) 201 | R_next = R @ U 202 | v_next = rotation_to_so3vec(R_next) # (G, 3) 203 | 204 | if 'newton_equation' in self.tr_output_type or 'euler_equation' in self.rot_output_type: 205 | src, dst = edge_index 206 | rel_x = final_x[dst] - final_x[src] 207 | inner_forces, outer_forces = self.force_layer( 208 | final_h, rel_x, final_h_bond, edge_index, inner_edge_mask, t=node_t_emb) # equivariant 209 | if 'outer' in self.tr_output_type: 210 | assert 'outer' in self.rot_output_type 211 | forces = outer_forces 212 | else: 213 | forces = inner_forces + outer_forces 214 | 215 | if 'newton_equation' in self.tr_output_type: 216 | frag_batch = batch_node[mask] 217 | frag_center = scatter_mean(final_x[mask], frag_batch, dim=0) # (G, 3) 218 | recon_x = frag_center + scatter_mean(forces[mask], frag_batch, dim=0) 219 | 220 | if 'euler_equation' in self.rot_output_type: 221 | x_f1, force_1, batch_node_f1 = final_x[mask], forces[mask], batch_node[mask] 222 | mu_1 = scatter_mean(x_f1, batch_node_f1, dim=0)[batch_node_f1] 223 | tau = scatter_sum(torch.cross(x_f1 - mu_1, force_1), batch_node_f1, dim=0) # (num_graphs, 3) 224 | inertia_mat = scatter_sum( 225 | torch.sum((x_f1 - mu_1) ** 2, dim=-1)[:, None, None] * torch.eye(3)[None].to(x_f1) - 226 | (x_f1 - mu_1).unsqueeze(-1) @ (x_f1 - mu_1).unsqueeze(-2), batch_node_f1, 227 | dim=0) # (num_graphs, 3, 3) 228 | omega = torch.linalg.solve(inertia_mat, tau.unsqueeze(-1)).squeeze(-1) # (num_graphs, 3) 229 | R_next = so3vec_to_rotation(-omega) @ R 230 | v_next = rotation_to_so3vec(R_next) 231 | 232 | assert (eps_pos is not None) or (eps_d is not None) or (recon_x is not None) 233 | return eps_pos, eps_d, recon_x, v_next, R_next 234 | 235 | def forward(self, 236 | x_noisy, node_attr, R1_noisy, R2_noisy, edge_index, edge_attr, 237 | fragment_mask, edge_mask, inner_edge_mask, batch_node, 238 | node_t_emb, edge_t_emb, graph_t_emb, rel_fragment_mask, rel_R_noisy 239 | ): 240 | """ 241 | Args: 242 | x_noisy: (N, 3) 243 | node_attr: (N, H) 244 | R1_noisy: (num_graphs, 3, 3) 245 | f_mask: (N, ) 246 | 247 | v_t: (F, 3). 248 | p_t: (F, 3). 249 | node_feat: (N, Hn). 250 | edge_feat: (E, He). 251 | edge_index: (2, E) 252 | mask_ligand: (N, ). 253 | mask_ll_edge: (E, ). 254 | beta: (F, ). 255 | Returns: 256 | v_next: UPDATED (not epsilon) SO3-vector of orientations, (F, 3). 257 | R_next: (F, 3, 3). 258 | eps_pos: (F, 3). 259 | c_denoised: (F, C). 
260 | """ 261 | linker_mask = (fragment_mask == 0) 262 | 263 | final_x, final_h, final_h_bond = self.encoder( 264 | pos_node=x_noisy, h_node=node_attr, h_edge=edge_attr, edge_index=edge_index, 265 | linker_mask=linker_mask, node_time=node_t_emb, edge_time=edge_t_emb) 266 | 267 | # Update linker 268 | outputs = {} 269 | if self.train_link: 270 | x_denoised = final_x[linker_mask] # (L, 3) 271 | c_denoised = self.eps_cls_net(final_h[linker_mask]) # may have softmax-ed, (L, K) 272 | outputs.update({ 273 | 'linker_x': x_denoised, 274 | 'linker_c': c_denoised 275 | }) 276 | 277 | if self.train_bond: 278 | in_bond_feat = final_h_bond[edge_mask] 279 | in_bond_feat = (in_bond_feat[::2] + in_bond_feat[1::2]).repeat_interleave(2, dim=0) 280 | bond_denoised = self.bond_pred_net(in_bond_feat) 281 | outputs['linker_bond'] = bond_denoised 282 | 283 | # Update fragment geometry 284 | if self.train_frag_rot or self.train_frag_pos: 285 | if self.use_rel_geometry: 286 | eps_pos, _, recon_x, v_next, R_next = self.update_frag_geometry( 287 | final_x, final_h, final_h_bond, 288 | edge_index, inner_edge_mask, batch_node, rel_fragment_mask, node_t_emb, rel_R_noisy) 289 | outputs.update({ 290 | 'frag_eps_pos': eps_pos, 291 | 'frag_v_next': v_next, 292 | 'frag_R_next': R_next 293 | }) 294 | else: 295 | eps_pos1, eps_d1, recon_x1, v_next1, R_next1 = self.update_frag_geometry( 296 | final_x, final_h, final_h_bond, 297 | edge_index, inner_edge_mask, batch_node, (fragment_mask == 1), node_t_emb, R1_noisy) 298 | eps_pos2, eps_d2, recon_x2, v_next2, R_next2 = self.update_frag_geometry( 299 | final_x, final_h, final_h_bond, 300 | edge_index, inner_edge_mask, batch_node, (fragment_mask == 2), node_t_emb, R2_noisy) 301 | 302 | # zero center 303 | if not self.pred_frag_dist: 304 | if self.tr_output_type == 'invariant_eps': 305 | center = (eps_pos1 + eps_pos2) / 2 306 | eps_pos1, eps_pos2 = eps_pos1 - center, eps_pos2 - center 307 | elif 'newton_equation' in self.tr_output_type: 308 | center = (recon_x1 + recon_x2) / 2 309 | recon_x1, recon_x2 = recon_x1 - center, recon_x2 - center 310 | else: 311 | raise ValueError(self.tr_output_type) 312 | 313 | outputs.update({ 314 | 'frag_eps_pos': (eps_pos1, eps_pos2), 315 | 'frag_eps_d': eps_d1 + eps_d2 if self.pred_frag_dist else None, 316 | 'frag_recon_x': (recon_x1, recon_x2), 317 | 'frag_v_next': (v_next1, v_next2), 318 | 'frag_R_next': (R_next1, R_next2) 319 | }) 320 | 321 | return outputs 322 | -------------------------------------------------------------------------------- /scripts/baselines/eval_3dlinker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from rdkit import RDLogger 6 | from torch_geometric.transforms import Compose 7 | from tqdm.auto import tqdm 8 | 9 | import utils.misc as misc 10 | import utils.transforms as trans 11 | from datasets.linker_dataset import get_linker_dataset 12 | from utils.evaluation import LinkerEvaluator 13 | from utils.visualize import * 14 | 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('sample_path', type=str) 19 | parser.add_argument('--num_eval', type=int, default=400) 20 | parser.add_argument('--save_path', type=str) 21 | parser.add_argument('--config_path', type=str, default='configs/sampling/zinc.yml') 22 | 23 | args = parser.parse_args() 24 | RDLogger.DisableLog('rdApp.*') 25 | 26 | mols = [] 27 | num_valid_mols = 0 28 | for fidx in range(args.num_eval): 29 | sdf_path = 
os.path.join(args.sample_path, f'data_{fidx}.sdf') 30 | if os.path.exists(sdf_path): 31 | m = Chem.SDMolSupplier(sdf_path, sanitize=False) 32 | mols.append(m) 33 | num_valid_mols += 1 34 | else: 35 | mols.append([]) 36 | print('Num datapoints: ', args.num_eval) 37 | print('Num valid mols: ', num_valid_mols) 38 | 39 | config = misc.load_config(args.config_path) 40 | # Transforms 41 | atom_featurizer = trans.FeaturizeAtom( 42 | config.dataset.name, known_anchor=False, add_atom_type=True, add_atom_feat=True) 43 | graph_builder = trans.BuildCompleteGraph(known_linker_bond=False) 44 | test_transform = Compose([ 45 | atom_featurizer, 46 | trans.SelectCandAnchors(mode='k-hop', k=1), 47 | graph_builder, 48 | trans.StackFragLocalPos(max_num_atoms=config.dataset.get('max_num_atoms', 30)), 49 | # trans.RelativeGeometry(mode=cfg_model.get('rel_geometry', 'distance_and_two_rot')) 50 | ]) 51 | dataset, subsets = get_linker_dataset( 52 | cfg=config.dataset, 53 | transform_map={'train': None, 'val': None, 'test': test_transform} 54 | ) 55 | test_set = subsets['test'] 56 | 57 | all_results = [] 58 | for i in range(args.num_eval): 59 | ref_data = test_set[i] 60 | gen_mols = mols[i] 61 | if len(gen_mols) == 0: 62 | continue 63 | 64 | num_frag_atoms = sum(ref_data.fragment_mask > 0) 65 | gen_data_list = [] 66 | for mol in gen_mols: 67 | if mol is None: 68 | gen_data_list.append(None) 69 | continue 70 | gen_data = ref_data.clone() 71 | gen_data['pos'] = torch.from_numpy(mol.GetConformer().GetPositions().astype(np.float32)) 72 | num_linker_atoms = len(gen_data['pos']) - num_frag_atoms 73 | if num_linker_atoms < 0: 74 | gen_data_list.append(None) 75 | continue 76 | gen_data['fragment_mask'] = torch.cat( 77 | [ref_data.fragment_mask[:num_frag_atoms], torch.zeros(num_linker_atoms).long()]) 78 | gen_data['linker_mask'] = (gen_data['fragment_mask'] == 0) 79 | gen_data_list.append(gen_data) 80 | 81 | results = { 82 | 'gen_mols': gen_mols, 83 | 'ref_data': ref_data, 84 | 'data_list': gen_data_list 85 | } 86 | all_results.append(results) 87 | 88 | evaluator = LinkerEvaluator(all_results, reconstruct=False) 89 | evaluator.evaluate() 90 | if args.save_path is not None: 91 | evaluator.save_metrics(args.save_path) 92 | -------------------------------------------------------------------------------- /scripts/baselines/eval_delinker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from rdkit import RDLogger 6 | from torch_geometric.transforms import Compose 7 | from tqdm.auto import tqdm 8 | 9 | import utils.misc as misc 10 | import utils.transforms as trans 11 | from datasets.linker_dataset import get_linker_dataset 12 | from utils.evaluation import LinkerEvaluator, standardise_linker, remove_dummys_mol 13 | from utils.visualize import * 14 | from utils import frag_utils 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('sample_path', type=str) 20 | parser.add_argument('--num_eval', type=int, default=400) 21 | parser.add_argument('--num_samples', type=int, default=250) 22 | parser.add_argument('--save_path', type=str) 23 | parser.add_argument('--config_path', type=str, default='configs/sampling/zinc.yml') 24 | 25 | args = parser.parse_args() 26 | RDLogger.DisableLog('rdApp.*') 27 | 28 | mols = [] 29 | with open(args.sample_path, 'r') as f: 30 | for line in tqdm(f.readlines()): 31 | parts = line.strip().split(' ') 32 | data = { 33 | 'fragments': parts[0], 34 | 
'true_molecule': parts[1], 35 | 'pred_molecule': parts[2], 36 | 'pred_linker': parts[3] if len(parts) > 3 else '', 37 | } 38 | mols.append(Chem.MolFromSmiles(data['pred_molecule'])) 39 | print('Num data: ', len(mols)) 40 | 41 | config = misc.load_config(args.config_path) 42 | # Transforms 43 | atom_featurizer = trans.FeaturizeAtom( 44 | config.dataset.name, known_anchor=False, add_atom_type=True, add_atom_feat=True) 45 | graph_builder = trans.BuildCompleteGraph(known_linker_bond=False) 46 | test_transform = Compose([ 47 | atom_featurizer, 48 | trans.SelectCandAnchors(mode='k-hop', k=1), 49 | graph_builder, 50 | trans.StackFragLocalPos(max_num_atoms=config.dataset.get('max_num_atoms', 30)), 51 | # trans.RelativeGeometry(mode=cfg_model.get('rel_geometry', 'distance_and_two_rot')) 52 | ]) 53 | dataset, subsets = get_linker_dataset( 54 | cfg=config.dataset, 55 | transform_map={'train': None, 'val': None, 'test': test_transform} 56 | ) 57 | test_set = subsets['test'] 58 | 59 | all_results = [] 60 | for i in range(args.num_eval): 61 | ref_data = test_set[i] 62 | gen_mols = [mols[idx] for idx in range(args.num_samples * i, args.num_samples * (i + 1))] 63 | 64 | num_frag_atoms = sum(ref_data.fragment_mask > 0) 65 | gen_data_list = [] 66 | for mol in gen_mols: 67 | if mol is None: 68 | gen_data_list.append(None) 69 | continue 70 | gen_data = ref_data.clone() 71 | gen_num_atoms = mol.GetNumAtoms() 72 | # gen_data['pos'] = torch.from_numpy(mol.GetConformer().GetPositions().astype(np.float32)) 73 | num_linker_atoms = gen_num_atoms - num_frag_atoms 74 | if num_linker_atoms < 0: 75 | gen_data_list.append(None) 76 | continue 77 | # gen_data['fragment_mask'] = torch.cat( 78 | # [ref_data.fragment_mask[:num_frag_atoms], torch.zeros(num_linker_atoms).long()]) 79 | # gen_data['linker_mask'] = (gen_data['fragment_mask'] == 0) 80 | gen_data_list.append(gen_data) 81 | 82 | results = { 83 | 'gen_mols': gen_mols, 84 | 'ref_data': ref_data, 85 | 'data_list': gen_data_list 86 | } 87 | all_results.append(results) 88 | 89 | # fix the mismatch between gen mols and ref mol 90 | data_list = [] 91 | with open(args.sample_path, 'r') as f: 92 | for line in tqdm(f.readlines()): 93 | parts = line.strip().split(' ') 94 | data = { 95 | 'fragments': parts[0], 96 | 'true_molecule': parts[1], 97 | 'pred_molecule': parts[2], 98 | 'pred_linker': parts[3] if len(parts) > 3 else '', 99 | } 100 | data_list.append(data) 101 | 102 | valid_all_results = [] 103 | for i in range(args.num_eval): 104 | valid_results = [] 105 | gen_data_list = [data_list[idx] for idx in range(args.num_samples * i, args.num_samples * (i + 1))] 106 | for gen_data in gen_data_list: 107 | gen_smi = gen_data['pred_molecule'] 108 | ref_smi = gen_data['true_molecule'] 109 | raw_frag_smi = gen_data['fragments'] 110 | frag_smi = Chem.MolToSmiles(remove_dummys_mol(Chem.MolFromSmiles(raw_frag_smi))) 111 | try: 112 | # gen_mols is chemically valid 113 | Chem.SanitizeMol(Chem.MolFromSmiles(gen_smi), sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES) 114 | except: 115 | continue 116 | 117 | # gen_mols should contain both fragments 118 | if len(Chem.MolFromSmiles(gen_smi).GetSubstructMatch(Chem.MolFromSmiles(frag_smi))) != Chem.MolFromSmiles( 119 | frag_smi).GetNumAtoms(): 120 | continue 121 | 122 | # Determine linkers of generated molecules 123 | try: 124 | linker = frag_utils.get_linker(Chem.MolFromSmiles(gen_smi), Chem.MolFromSmiles(frag_smi), frag_smi) 125 | linker_smi = standardise_linker(linker) 126 | except: 127 | continue 128 | 129 | valid_results.append({ 130 | 
'ref_smi': ref_smi, 131 | 'frag_smi': frag_smi, 132 | 'gen_smi': gen_smi, 133 | 'linker_smi': linker_smi, 134 | 'metrics': {} 135 | }) 136 | valid_all_results.append(valid_results) 137 | 138 | evaluator = LinkerEvaluator(all_results, reconstruct=False) 139 | evaluator.gen_all_results = valid_all_results 140 | evaluator.validity(args.num_samples) 141 | 142 | evaluator.evaluate(eval_3d=False) 143 | if args.save_path is not None: 144 | evaluator.save_metrics(args.save_path) 145 | 146 | 147 | -------------------------------------------------------------------------------- /scripts/baselines/eval_difflinker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from rdkit import RDLogger 6 | from torch_geometric.transforms import Compose 7 | from tqdm.auto import tqdm 8 | 9 | import utils.misc as misc 10 | import utils.transforms as trans 11 | from datasets.linker_dataset import get_linker_dataset 12 | from utils.evaluation import LinkerEvaluator 13 | from utils.visualize import * 14 | 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('sample_path', type=str) 19 | parser.add_argument('--num_eval', type=int, default=None) 20 | parser.add_argument('--num_samples', type=int, default=100) 21 | parser.add_argument('--save_path', type=str) 22 | parser.add_argument('--config_path', type=str, default='configs/sampling/zinc.yml') 23 | 24 | args = parser.parse_args() 25 | RDLogger.DisableLog('rdApp.*') 26 | 27 | if args.sample_path.endswith('.sdf'): 28 | mols = Chem.SDMolSupplier(args.sample_path) 29 | if args.num_eval is None: 30 | args.num_eval = len(mols) // args.num_samples 31 | print('Load sdf done! Total mols: ', len(mols)) 32 | else: 33 | mols = [] 34 | folders = sorted(os.listdir(args.sample_path), key=lambda x: int(x)) 35 | if args.num_eval: 36 | folders = folders[:args.num_eval] 37 | else: 38 | args.num_eval = len(folders) 39 | num_valid_mols = 0 40 | for folder in tqdm(folders, desc='Load molecules'): 41 | sdf_files = [fn for fn in os.listdir(os.path.join(args.sample_path, folder)) if fn.endswith('.sdf')] 42 | # for fname in sorted(sdf_files, key=lambda x: int(x[:-5])): 43 | for fname in range(args.num_samples): 44 | sdf_path = os.path.join(args.sample_path, folder, str(fname) + '_.sdf') 45 | if os.path.exists(sdf_path): 46 | supp = Chem.SDMolSupplier(sdf_path, sanitize=False) 47 | mol = list(supp)[0] 48 | mols.append(mol) 49 | num_valid_mols += 1 50 | else: 51 | mols.append(None) 52 | print('Num datapoints: ', args.num_eval) 53 | print('Num valid mols: ', num_valid_mols) 54 | 55 | config = misc.load_config(args.config_path) 56 | # Transforms 57 | atom_featurizer = trans.FeaturizeAtom( 58 | config.dataset.name, known_anchor=False, add_atom_type=True, add_atom_feat=True) 59 | graph_builder = trans.BuildCompleteGraph(known_linker_bond=False) 60 | test_transform = Compose([ 61 | atom_featurizer, 62 | trans.SelectCandAnchors(mode='k-hop', k=1), 63 | graph_builder, 64 | trans.StackFragLocalPos(max_num_atoms=config.dataset.get('max_num_atoms', 30)), 65 | # trans.RelativeGeometry(mode=cfg_model.get('rel_geometry', 'distance_and_two_rot')) 66 | ]) 67 | dataset, subsets = get_linker_dataset( 68 | cfg=config.dataset, 69 | transform_map={'train': None, 'val': None, 'test': test_transform} 70 | ) 71 | test_set = subsets['test'] 72 | 73 | all_results = [] 74 | for i in range(args.num_eval): 75 | ref_data = test_set[i] 76 | gen_mols = [mols[idx] for idx in 
range(args.num_samples * i, args.num_samples * (i + 1))] 77 | num_frag_atoms = sum(ref_data.fragment_mask > 0) 78 | gen_data_list = [] 79 | for mol in gen_mols: 80 | if mol is None: 81 | gen_data_list.append(None) 82 | continue 83 | gen_data = ref_data.clone() 84 | gen_data['pos'] = torch.from_numpy(mol.GetConformer().GetPositions().astype(np.float32)) 85 | num_linker_atoms = len(gen_data['pos']) - num_frag_atoms 86 | if num_linker_atoms < 0: 87 | gen_data_list.append(None) 88 | continue 89 | gen_data['fragment_mask'] = torch.cat( 90 | [ref_data.fragment_mask[:num_frag_atoms], torch.zeros(num_linker_atoms).long()]) 91 | gen_data['linker_mask'] = (gen_data['fragment_mask'] == 0) 92 | gen_data_list.append(gen_data) 93 | 94 | results = { 95 | 'gen_mols': gen_mols, 96 | 'ref_data': ref_data, 97 | 'data_list': gen_data_list 98 | } 99 | all_results.append(results) 100 | 101 | evaluator = LinkerEvaluator(all_results, reconstruct=False) 102 | evaluator.evaluate() 103 | if args.save_path is not None: 104 | evaluator.save_metrics(args.save_path) 105 | -------------------------------------------------------------------------------- /scripts/eval_protac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from rdkit import RDLogger 3 | import torch 4 | from utils.evaluation import LinkerEvaluator 5 | from glob import glob 6 | import os 7 | from tqdm.auto import tqdm 8 | from torch_geometric.transforms import Compose 9 | import utils.misc as misc 10 | import utils.transforms as trans 11 | from datasets.linker_dataset import get_linker_dataset 12 | from utils.evaluation import LinkerEvaluator 13 | from utils.visualize import * 14 | 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('sample_path', type=str) 19 | parser.add_argument('--num_eval', type=int, default=None) 20 | parser.add_argument('--recon', type=eval, default=False) 21 | parser.add_argument('--save', type=eval, default=True) 22 | args = parser.parse_args() 23 | RDLogger.DisableLog('rdApp.*') 24 | 25 | if args.sample_path.endswith('.pt'): 26 | print(f'There is only one sampling file to evaluate') 27 | results = [torch.load(args.sample_path)] 28 | else: 29 | results_file_list = sorted(glob(os.path.join(args.sample_path, 'sampling_*.pt'))) 30 | if args.num_eval: 31 | results_file_list = results_file_list[:args.num_eval] 32 | print(f'There are {len(results_file_list)} files to evaluate') 33 | results = [] 34 | for f in tqdm(results_file_list, desc='Load sampling files'): 35 | results.append(torch.load(f)) 36 | 37 | if 'data_list' not in results[0].keys(): 38 | print('Can not find data_list in result keys -- add data_list based on the test set') 39 | config = misc.load_config('configs/sampling/zinc.yml') 40 | # Transforms 41 | atom_featurizer = trans.FeaturizeAtom( 42 | config.dataset.name, known_anchor=False, add_atom_type=True, add_atom_feat=True) 43 | graph_builder = trans.BuildCompleteGraph(known_linker_bond=False) 44 | test_transform = Compose([ 45 | atom_featurizer, 46 | trans.SelectCandAnchors(mode='k-hop', k=2), 47 | graph_builder, 48 | trans.StackFragLocalPos(max_num_atoms=config.dataset.get('max_num_atoms', 30)), 49 | # trans.RelativeGeometry(mode=cfg_model.get('rel_geometry', 'distance_and_two_rot')) 50 | ]) 51 | dataset, subsets = get_linker_dataset( 52 | cfg=config.dataset, 53 | transform_map={'train': None, 'val': None, 'test': test_transform} 54 | ) 55 | test_set = subsets['test'] 56 | 57 | all_results = [] 58 | for i in 
range(len(results)): 59 | ref_data = test_set[i] 60 | gen_mols = results[i]['gen_mols'] 61 | num_frag_atoms = sum(ref_data.fragment_mask > 0) 62 | gen_data_list = [ref_data.clone() for _ in range(100)] 63 | 64 | new_results = { 65 | 'gen_mols': gen_mols, 66 | 'ref_data': ref_data, 67 | 'data_list': gen_data_list, 68 | 'final_x': results[i]['final_x'], 69 | 'final_c': results[i]['final_c'], 70 | 'final_bond': results[i]['final_bond'] 71 | } 72 | all_results.append(new_results) 73 | results = all_results 74 | 75 | evaluator = LinkerEvaluator(results, reconstruct=args.recon) 76 | evaluator.evaluate() 77 | save_path = os.path.join(args.sample_path, 'summary.csv') 78 | if args.save: 79 | evaluator.save_metrics(save_path) 80 | -------------------------------------------------------------------------------- /scripts/prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | from tqdm.auto import tqdm 4 | from rdkit import Chem 5 | import torch 6 | import os 7 | import numpy as np 8 | from utils.data import compute_3d_coors_multiple, set_rdmol_positions 9 | 10 | 11 | def read_protac_file(file_path): 12 | with open(file_path, 'r') as f: 13 | lines = f.readlines() 14 | data = [] 15 | for i, line in enumerate(lines): 16 | toks = line.strip().split(' ') 17 | if len(toks) == 4: 18 | smi_protac, smi_linker, smi_warhead, smi_ligase = toks 19 | else: 20 | raise ValueError("Incorrect input format.") 21 | data.append({'smi_protac': smi_protac, 22 | 'smi_linker': smi_linker, 23 | 'smi_warhead': smi_warhead, 24 | 'smi_ligase': smi_ligase}) 25 | return data 26 | 27 | 28 | def validity_check(num_atoms, m1, m2, m3): 29 | if len(set(m1).intersection(set(m2))) != 0 or len(set(m2).intersection(set(m3))) != 0 or \ 30 | len(set(m1).intersection(set(m3))) != 0 or len(m1) + len(m2) + len(m3) != num_atoms: 31 | return False 32 | else: 33 | return True 34 | 35 | 36 | def preprocess_protac(raw_file_path, save_path): 37 | if raw_file_path.endswith('.txt'): 38 | raw_data = read_protac_file(raw_file_path) 39 | else: 40 | with open(raw_file_path, 'rb') as f: 41 | raw_data = pickle.load(f) 42 | processed_data = [] 43 | total = len(raw_data) 44 | for i, d in enumerate(tqdm(raw_data, desc='Generate 3D conformer')): 45 | smi_protac, smi_linker, smi_warhead, smi_ligase = d['smi_protac'], d['smi_linker'], \ 46 | d['smi_warhead'], d['smi_ligase'] 47 | 48 | # frag_0 (3d pos), frag_1 (3d pos), linker (3d pos) 49 | mol = Chem.MolFromSmiles(smi_protac) 50 | # generate 3d coordinates of mols 51 | pos, _ = compute_3d_coors_multiple(mol, maxIters=1000) 52 | if pos is None: 53 | print('Generate conformer fail!') 54 | continue 55 | mol = set_rdmol_positions(mol, pos) 56 | 57 | if raw_file_path.endswith('.txt'): 58 | warhead_m = mol.GetSubstructMatch(Chem.MolFromSmiles(smi_warhead)) 59 | ligase_m = mol.GetSubstructMatch(Chem.MolFromSmiles(smi_ligase)) 60 | linker_m = mol.GetSubstructMatch(Chem.MolFromSmiles(smi_linker)) 61 | d.update({ 62 | 'atom_indices_warhead': warhead_m, 63 | 'atom_indices_ligase': ligase_m, 64 | 'atom_indices_linker': linker_m 65 | }) 66 | else: 67 | warhead_m, ligase_m, linker_m = d['atom_indices_warhead'], d['atom_indices_ligase'], d['atom_indices_linker'] 68 | 69 | valid = validity_check(mol.GetNumAtoms(), warhead_m, ligase_m, linker_m) 70 | if not valid: 71 | print('Validity check fail!') 72 | continue 73 | 74 | processed_data.append({ 75 | 'mol': mol, 76 | **d 77 | }) 78 | 79 | print('Saving data') 80 | with open(save_path, 'wb') as f: 
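        # Editor's note (schema inferred from the loop above): each entry of
        # `processed_data` is a dict holding the RDKit mol with its generated 3D
        # conformer plus the fields carried over from the raw record, roughly:
        #   {'mol': <rdkit.Chem.Mol>, 'smi_protac': ..., 'smi_linker': ...,
        #    'smi_warhead': ..., 'smi_ligase': ..., 'atom_indices_warhead': (...),
        #    'atom_indices_ligase': (...), 'atom_indices_linker': (...)}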
81 | pickle.dump(processed_data, f) 82 | print('Length raw data: \t%d' % total) 83 | print('Length processed data: \t%d' % len(processed_data)) 84 | 85 | 86 | def preprocess_zinc_from_difflinker(raw_file_dir, save_path, mode): 87 | datasets = {} 88 | if mode == 'full': 89 | all_subsets = ['train', 'val', 'test'] 90 | else: 91 | all_subsets = ['val', 'test'] 92 | for subset in all_subsets: 93 | data_list = [] 94 | n_fail = 0 95 | all_data = torch.load(os.path.join(raw_file_dir, f'zinc_final_{subset}.pt'), map_location='cpu') 96 | full_rdmols = Chem.SDMolSupplier(os.path.join(raw_file_dir, f'zinc_final_{subset}_mol.sdf'), sanitize=False) 97 | frag_rdmols = Chem.SDMolSupplier(os.path.join(raw_file_dir, f'zinc_final_{subset}_frag.sdf'), sanitize=False) 98 | link_rdmols = Chem.SDMolSupplier(os.path.join(raw_file_dir, f'zinc_final_{subset}_link.sdf'), sanitize=False) 99 | assert len(all_data) == len(frag_rdmols) == len(link_rdmols) 100 | for i in tqdm(range(len(all_data)), desc=subset): 101 | data = all_data[i] 102 | mol, frag_mol, link_mol = full_rdmols[i], frag_rdmols[i], link_rdmols[i] 103 | # if mol is None or frag_mol is None or link_mol is None: 104 | # print('Fail i: ', i) 105 | # n_fail += 1 106 | # continue 107 | pos = data['positions'] 108 | fragment_mask, linker_mask = data['fragment_mask'].bool(), data['linker_mask'].bool() 109 | # align full mol with positions, etc. 110 | mol_pos = mol.GetConformer().GetPositions() 111 | mapping = np.linalg.norm(mol_pos[None] - data['positions'].numpy()[:, None], axis=-1).argmin(axis=1) 112 | assert len(np.unique(mapping)) == len(mapping) 113 | new_mol = Chem.RenumberAtoms(mol, mapping.tolist()) 114 | Chem.SanitizeMol(new_mol) 115 | # check frag mol and link mol are aligned 116 | assert np.allclose(frag_mol.GetConformer().GetPositions(), pos[fragment_mask].numpy()) 117 | assert np.allclose(link_mol.GetConformer().GetPositions(), pos[linker_mask].numpy()) 118 | # print(mapping) 119 | 120 | # print(data['anchors'].nonzero(as_tuple=True)[0].tolist()) 121 | # Note: anchor atom index may be wrong! 
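            # Editor's note: the block below recovers a per-fragment labelling
            # (1 vs 2) that the raw DiffLinker mask does not provide. It splits
            # `frag_mol` into its two connected components and searches for a pair
            # of disjoint substructure matches in `new_mol` whose union is exactly
            # the fragment atom set. A minimal sketch of the acceptance criterion,
            # for hypothetical matches m1 and m2:
            #   ok = set(m1).isdisjoint(m2) and (set(m1) | set(m2)) == all_frag_atom_idx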
122 | frag_mols = Chem.GetMolFrags(frag_mol, asMols=True, sanitizeFrags=False) 123 | assert len(frag_mols) == 2 124 | 125 | all_frag_atom_idx = set((fragment_mask == 1).nonzero()[:, 0].tolist()) 126 | frag_atom_idx_list = [] 127 | for m1 in new_mol.GetSubstructMatches(frag_mols[0]): 128 | for m2 in new_mol.GetSubstructMatches(frag_mols[1]): 129 | if len(set(m1).intersection(set(m2))) == 0 and set(m1).union(set(m2)) == all_frag_atom_idx: 130 | frag_atom_idx_list = [m1, m2] 131 | break 132 | 133 | try: 134 | assert len(frag_atom_idx_list) == 2 and all([x is not None and len(x) > 0 for x in frag_atom_idx_list]) 135 | except: 136 | print('Fail i: ', i) 137 | n_fail += 1 138 | continue 139 | new_fragment_mask = torch.zeros_like(fragment_mask).long() 140 | new_fragment_mask[list(frag_atom_idx_list[0])] = 1 141 | new_fragment_mask[list(frag_atom_idx_list[1])] = 2 142 | 143 | # extract frag mol directly from new_mol, in case the Kekulize error 144 | bond_ids = [] 145 | for bond_idx, bond in enumerate(new_mol.GetBonds()): 146 | start = bond.GetBeginAtomIdx() 147 | end = bond.GetEndAtomIdx() 148 | if (new_fragment_mask[start] > 0) == (new_fragment_mask[end] == 0): 149 | bond_ids.append(bond_idx) 150 | assert len(bond_ids) == 2 151 | break_mol = Chem.FragmentOnBonds(new_mol, bond_ids, addDummies=False) 152 | frags = [f for f in Chem.GetMolFrags(break_mol, asMols=True) 153 | if f.GetNumAtoms() != link_mol.GetNumAtoms() 154 | or not np.allclose(f.GetConformer().GetPositions(), link_mol.GetConformer().GetPositions())] 155 | assert len(frags) == 2 156 | new_frag_mol = Chem.CombineMols(*frags) 157 | assert np.allclose(new_frag_mol.GetConformer().GetPositions(), frag_mol.GetConformer().GetPositions()) 158 | 159 | data_list.append({ 160 | 'id': data['uuid'], 161 | 'smiles': data['name'], 162 | 'mol': new_mol, 163 | 'frag_mol': new_frag_mol, 164 | 'link_mol': link_mol, 165 | 'fragment_mask': new_fragment_mask, 166 | 'atom_indices_f1': list(frag_atom_idx_list[0]), 167 | 'atom_indices_f2': list(frag_atom_idx_list[1]), 168 | 'linker_mask': linker_mask, 169 | # 'anchors': data['anchors'].bool() 170 | }) 171 | print('n fail: ', n_fail) 172 | datasets[subset] = data_list 173 | 174 | print('Saving data') 175 | with open(save_path, 'wb') as f: 176 | pickle.dump(datasets, f) 177 | print('Length processed data: ', [f'{x}: {len(datasets[x])}' for x in datasets]) 178 | 179 | 180 | if __name__ == "__main__": 181 | parser = argparse.ArgumentParser() 182 | parser.add_argument('--raw_path', type=str, required=True) 183 | parser.add_argument('--dataset', type=str, default='protac', choices=['protac', 'zinc_difflinker']) 184 | parser.add_argument('--dest', type=str, required=True) 185 | parser.add_argument('--mode', type=str, default='full', choices=['tiny', 'full']) 186 | args = parser.parse_args() 187 | 188 | if args.dataset == 'protac': 189 | preprocess_protac(args.raw_path, args.dest) 190 | elif args.dataset == 'zinc_difflinker': 191 | preprocess_zinc_from_difflinker(args.raw_path, args.dest, args.mode) 192 | else: 193 | raise NotImplementedError 194 | -------------------------------------------------------------------------------- /scripts/sample_protac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from rdkit import RDLogger 7 | from torch_geometric.transforms import Compose 8 | from tqdm.auto import tqdm 9 | 10 | import utils.misc as misc 11 | import utils.transforms as trans 12 | from 
datasets.linker_dataset import get_linker_dataset 13 | from models.diff_protac_bond import DiffPROTACModel 14 | from utils.reconstruct_linker import parse_sampling_result, parse_sampling_result_with_bond 15 | from torch_geometric.data import Batch 16 | from utils.evaluation import eval_success_rate 17 | from utils.prior_num_atoms import setup_configs, sample_atom_num 18 | 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('config', type=str) 23 | parser.add_argument('--ckpt_path', type=str) 24 | parser.add_argument('--subset', choices=['val', 'test'], default='val') 25 | parser.add_argument('--device', type=str, default='cuda:0') 26 | parser.add_argument('--start_id', type=int, default=0) 27 | parser.add_argument('--end_id', type=int, default=-1) 28 | parser.add_argument('--num_samples', type=int, default=100) 29 | parser.add_argument('--outdir', type=str, default='./outputs_test') 30 | parser.add_argument('--tag', type=str, default='') 31 | parser.add_argument('--save_traj', type=eval, default=False) 32 | parser.add_argument('--cand_anchors_mode', type=str, default='k-hop', choices=['k-hop', 'exact']) 33 | parser.add_argument('--cand_anchors_k', type=int, default=2) 34 | args = parser.parse_args() 35 | RDLogger.DisableLog('rdApp.*') 36 | 37 | # Load configs 38 | config = misc.load_config(args.config) 39 | config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')] 40 | misc.seed_all(config.sample.seed) 41 | logger, writer, log_dir, ckpt_dir, vis_dir = misc.setup_logdir( 42 | args.config, args.outdir, mode='eval', tag=args.tag, create_dir=True) 43 | logger.info(args) 44 | 45 | # Load checkpoint 46 | ckpt_path = config.model.checkpoint if args.ckpt_path is None else args.ckpt_path 47 | ckpt = torch.load(ckpt_path, map_location=args.device) 48 | logger.info(f'Successfully load the model! 
{ckpt_path}') 49 | cfg_model = ckpt['configs'].model 50 | 51 | # Transforms 52 | test_transform = Compose([ 53 | trans.FeaturizeAtom(config.dataset.name, add_atom_feat=False) 54 | ]) 55 | atom_featurizer = trans.FeaturizeAtom( 56 | config.dataset.name, known_anchor=cfg_model.known_anchor, add_atom_type=False, add_atom_feat=True) 57 | graph_builder = trans.BuildCompleteGraph(known_linker_bond=cfg_model.known_linker_bond, 58 | known_cand_anchors=config.sample.get('cand_bond_mask', False)) 59 | init_transform = Compose([ 60 | atom_featurizer, 61 | trans.SelectCandAnchors(mode=args.cand_anchors_mode, k=args.cand_anchors_k), 62 | graph_builder, 63 | trans.StackFragLocalPos(max_num_atoms=config.dataset.get('max_num_atoms', 30)), 64 | trans.RelativeGeometry(mode=cfg_model.get('rel_geometry', 'distance_and_two_rot')) 65 | ]) 66 | logger.info(f'Init transform: {init_transform}') 67 | 68 | # Datasets and loaders 69 | logger.info('Loading dataset...') 70 | if config.dataset.split_mode == 'full': 71 | dataset = get_linker_dataset( 72 | cfg=config.dataset, 73 | transform_map=test_transform 74 | ) 75 | test_set = dataset 76 | else: 77 | dataset, subsets = get_linker_dataset( 78 | cfg=config.dataset, 79 | transform_map={'train': None, 'val': test_transform, 'test': test_transform} 80 | ) 81 | test_set = subsets[args.subset] 82 | logger.info(f'Test: {len(test_set)}') 83 | 84 | FOLLOW_BATCH = ['edge_type'] 85 | COLLATE_EXCLUDE_KEYS = ['nbh_list'] 86 | 87 | # Model 88 | logger.info('Building model...') 89 | model = DiffPROTACModel( 90 | cfg_model, 91 | num_classes=atom_featurizer.num_classes, 92 | num_bond_classes=graph_builder.num_bond_classes, 93 | atom_feature_dim=atom_featurizer.feature_dim, 94 | edge_feature_dim=graph_builder.bond_feature_dim 95 | ).to(args.device) 96 | logger.info('Num of parameters is %.2f M' % (np.sum([p.numel() for p in model.parameters()]) / 1e6)) 97 | model.load_state_dict(ckpt['model']) 98 | logger.info(f'Load model weights done!') 99 | 100 | # Sampling 101 | logger.info(f'Begin sampling [{args.start_id}, {args.end_id})...') 102 | assert config.sample.num_atoms in ['ref', 'prior'] 103 | if config.sample.num_atoms == 'prior': 104 | num_atoms_config = setup_configs(mode='frag_center_distance') 105 | else: 106 | num_atoms_config = None 107 | 108 | if args.end_id == -1: 109 | args.end_id = len(test_set) 110 | for idx in tqdm(range(args.start_id, args.end_id)): 111 | raw_data = test_set[idx] 112 | data_list = [raw_data.clone() for _ in range(args.num_samples)] 113 | # modify data list 114 | if num_atoms_config is None: 115 | new_data_list = [init_transform(data) for data in data_list] 116 | else: 117 | new_data_list = [] 118 | for data in data_list: 119 | # sample num atoms 120 | dist = torch.floor(data.frags_d) # (B, ) 121 | num_linker_atoms = sample_atom_num(dist, num_atoms_config).astype(int) 122 | num_f1_atoms = len(data.atom_indices_f1) 123 | num_f2_atoms = len(data.atom_indices_f2) 124 | num_f_atoms = num_f1_atoms + num_f2_atoms 125 | frag_pos = data.pos[data.fragment_mask > 0] 126 | frag_atom_type = data.atom_type[data.fragment_mask > 0] 127 | frag_bond_idx = (data.bond_index[0] < num_f_atoms) & (data.bond_index[1] < num_f_atoms) 128 | 129 | data.fragment_mask = torch.LongTensor([1] * num_f1_atoms + [2] * num_f2_atoms + [0] * num_linker_atoms) 130 | data.linker_mask = (data.fragment_mask == 0) 131 | data.anchor_mask = torch.cat([data.anchor_mask[:num_f_atoms], torch.zeros([num_linker_atoms]).long()]) 132 | data.bond_index = data.bond_index[:, frag_bond_idx] 133 | 
data.bond_type = data.bond_type[frag_bond_idx] 134 | data.pos = torch.cat([frag_pos, torch.zeros([num_linker_atoms, 3])], dim=0) 135 | data.atom_type = torch.cat([frag_atom_type, torch.zeros([num_linker_atoms]).long()]) 136 | new_data = init_transform(data) 137 | new_data_list.append(new_data) 138 | 139 | batch = Batch.from_data_list( 140 | new_data_list, follow_batch=FOLLOW_BATCH, exclude_keys=COLLATE_EXCLUDE_KEYS).to(args.device) 141 | traj_batch, final_x, final_c, final_bond = model.sample( 142 | batch, 143 | p_init_mode=cfg_model.frag_pos_prior, 144 | guidance_opt=config.sample.guidance_opt 145 | ) 146 | if model.train_bond: 147 | gen_mols = parse_sampling_result_with_bond( 148 | new_data_list, final_x, final_c, final_bond, atom_featurizer, 149 | known_linker_bonds=cfg_model.known_linker_bond, check_validity=True) 150 | else: 151 | gen_mols = parse_sampling_result(new_data_list, final_x, final_c, atom_featurizer) 152 | save_path = os.path.join(log_dir, f'sampling_{idx:06d}.pt') 153 | save_dict = { 154 | 'ref_data': init_transform(raw_data), 155 | 'data_list': new_data_list, # don't save it to reduce the size of outputs 156 | 'final_x': final_x, 'final_c': final_c, 'final_bond': final_bond, 157 | 'gen_mols': gen_mols 158 | } 159 | if args.save_traj: 160 | save_dict['traj'] = traj_batch 161 | 162 | torch.save(save_dict, save_path) 163 | logger.info('Sample done!') 164 | 165 | # Quick Eval 166 | recon_rate, complete_rate = [], [] 167 | anchor_dists = [] 168 | for idx in range(args.start_id, args.end_id): 169 | load_path = os.path.join(log_dir, f'sampling_{idx:06d}.pt') 170 | results = torch.load(load_path) 171 | rr, cr = eval_success_rate(results['gen_mols']) 172 | logger.info(f'idx: {idx} recon rate: {rr} complete rate: {cr}') 173 | recon_rate.append(rr) 174 | complete_rate.append(cr) 175 | logger.info(f'recon rate: mean: {np.mean(recon_rate):.4f} median: {np.median(recon_rate):.4f}') 176 | logger.info(f'complete rate: mean: {np.mean(complete_rate):.4f} median: {np.median(complete_rate):.4f}') 177 | -------------------------------------------------------------------------------- /scripts/train_protac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import torch.utils.tensorboard 7 | from rdkit import RDLogger 8 | from torch.nn.utils import clip_grad_norm_ 9 | from torch_geometric.loader import DataLoader 10 | from torch_geometric.transforms import Compose 11 | from tqdm.auto import tqdm 12 | 13 | import utils.misc as misc 14 | import utils.transforms as trans 15 | from datasets.linker_dataset import get_linker_dataset 16 | from models.diff_protac_bond import DiffPROTACModel 17 | from utils.evaluation import eval_success_rate 18 | from utils.reconstruct_linker import parse_sampling_result_with_bond, parse_sampling_result 19 | from utils.train import * 20 | 21 | torch.multiprocessing.set_sharing_strategy('file_system') 22 | 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('config', type=str, default='./configs/training/zinc.yml') 27 | parser.add_argument('--overfit_one', type=eval, default=False) 28 | parser.add_argument('--device', type=str, default='cuda') 29 | parser.add_argument('--logdir', type=str, default='./logs') 30 | parser.add_argument('--train_report_iter', type=int, default=200) 31 | parser.add_argument('--sampling_iter', type=int, default=10000) 32 | args = parser.parse_args() 33 | 
RDLogger.DisableLog('rdApp.*') 34 | 35 | # Load configs 36 | config = misc.load_config(args.config) 37 | misc.seed_all(config.train.seed) 38 | logger, writer, log_dir, ckpt_dir, vis_dir = misc.setup_logdir(args.config, args.logdir) 39 | logger.info(args) 40 | 41 | # Transforms 42 | cfg_model = config.model 43 | atom_featurizer = trans.FeaturizeAtom(config.dataset.name, known_anchor=cfg_model.known_anchor) 44 | graph_builder = trans.BuildCompleteGraph(known_linker_bond=cfg_model.known_linker_bond) 45 | max_num_atoms = config.dataset.get('max_num_atoms', 30) 46 | train_transform = Compose([ 47 | atom_featurizer, 48 | graph_builder, 49 | trans.StackFragLocalPos(max_num_atoms=max_num_atoms), 50 | trans.RelativeGeometry(mode=cfg_model.rel_geometry) 51 | ]) 52 | test_transform = Compose([ 53 | atom_featurizer, 54 | graph_builder, 55 | trans.StackFragLocalPos(max_num_atoms=max_num_atoms), 56 | trans.RelativeGeometry(mode=cfg_model.rel_geometry) 57 | ]) 58 | logger.info(f'Relative Geometry: {cfg_model.rel_geometry}') 59 | logger.info(f'Train transform: {train_transform}') 60 | logger.info(f'Test transform: {test_transform}') 61 | if cfg_model.rel_geometry in ['two_pos_and_rot', 'relative_pos_and_rot']: 62 | assert cfg_model.frag_pos_prior is None 63 | 64 | # Datasets and loaders 65 | logger.info('Loading dataset...') 66 | dataset, subsets = get_linker_dataset( 67 | cfg=config.dataset, 68 | transform_map={'train': train_transform, 'val': test_transform, 'test': test_transform} 69 | ) 70 | 71 | if config.dataset.version == 'tiny': 72 | if args.overfit_one: 73 | train_set, val_set, test_set = [subsets['val'][0].clone() for _ in range(config.train.batch_size)], \ 74 | [subsets['val'][0]], [subsets['val'][0]] 75 | else: 76 | train_set, val_set, test_set = subsets['val'], subsets['val'], subsets['val'] 77 | else: 78 | train_set, val_set, test_set = subsets['train'], subsets['val'], subsets['test'] 79 | # train_set, val_set, test_set = [subsets['train'][697].clone() for _ in range(config.train.batch_size)], \ 80 | # [subsets['val'][0]], [subsets['val'][0]] 81 | logger.info(f'Training: {len(train_set)} Validation: {len(val_set)} Test: {len(test_set)}') 82 | 83 | FOLLOW_BATCH = ['edge_type'] 84 | COLLATE_EXCLUDE_KEYS = ['nbh_list'] 85 | train_iterator = inf_iterator(DataLoader( 86 | train_set, 87 | batch_size=config.train.batch_size, 88 | shuffle=True, 89 | num_workers=config.train.num_workers, 90 | prefetch_factor=8, 91 | persistent_workers=True, 92 | follow_batch=FOLLOW_BATCH, 93 | exclude_keys=COLLATE_EXCLUDE_KEYS 94 | )) 95 | val_loader = DataLoader(val_set, config.train.batch_size, shuffle=False, 96 | follow_batch=FOLLOW_BATCH, exclude_keys=COLLATE_EXCLUDE_KEYS) 97 | test_loader = DataLoader(test_set, config.train.batch_size, shuffle=False, 98 | follow_batch=FOLLOW_BATCH, exclude_keys=COLLATE_EXCLUDE_KEYS) 99 | 100 | # Model 101 | logger.info('Building model...') 102 | model = DiffPROTACModel( 103 | config.model, 104 | num_classes=atom_featurizer.num_classes, 105 | num_bond_classes=graph_builder.num_bond_classes, 106 | atom_feature_dim=atom_featurizer.feature_dim, 107 | edge_feature_dim=graph_builder.bond_feature_dim 108 | ).to(args.device) 109 | logger.info('Num of parameters is %.2f M' % (np.sum([p.numel() for p in model.parameters()]) / 1e6)) 110 | if config.train.get('ckpt_path', None): 111 | ckpt = torch.load(config.train.ckpt_path, map_location=args.device) 112 | model.load_state_dict(ckpt['model']) 113 | logger.info(f'load checkpoint from {config.train.ckpt_path}!') 114 | 115 | # Optimizer and 
scheduler 116 | optimizer = get_optimizer(config.train.optimizer, model) 117 | scheduler = get_scheduler(config.train.scheduler, optimizer) 118 | # if config.train.get('ckpt_path', None): 119 | # logger.info('Resuming optimizer states...') 120 | # optimizer.load_state_dict(ckpt['optimizer']) 121 | # logger.info('Resuming scheduler states...') 122 | # scheduler.load_state_dict(ckpt['scheduler']) 123 | 124 | def train(it, batch): 125 | optimizer.zero_grad() 126 | loss_dict = model.get_loss(batch, pos_noise_std=config.train.get('pos_noise_std', 0.)) 127 | loss = sum_weighted_losses(loss_dict, config.train.loss_weights) 128 | loss_dict['overall'] = loss 129 | loss.backward() 130 | orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm, 131 | error_if_nonfinite=True) # 5% running time 132 | optimizer.step() 133 | 134 | # Logging 135 | log_losses(loss_dict, it, 'train', args.train_report_iter, logger, writer, others={ 136 | 'grad': orig_grad_norm, 137 | 'lr': optimizer.param_groups[0]['lr'] 138 | }) 139 | 140 | def validate(it): 141 | loss_tape = ValidationLossTape() 142 | with torch.no_grad(): 143 | model.eval() 144 | for i, batch in enumerate(tqdm(val_loader, desc='Validate', dynamic_ncols=True)): 145 | for t in np.linspace(0, model.num_steps - 1, 10).astype(int): 146 | time_step = torch.tensor([t] * batch.num_graphs).to(args.device) 147 | batch = batch.to(args.device) 148 | loss_dict = model.get_loss(batch, t=time_step) 149 | loss = sum_weighted_losses(loss_dict, config.train.loss_weights) 150 | loss_dict['overall'] = loss 151 | loss_tape.update(loss_dict, 1) 152 | 153 | avg_loss = loss_tape.log(it, logger, writer, 'val') 154 | # Trigger scheduler 155 | if config.train.scheduler.type == 'plateau': 156 | scheduler.step(avg_loss) 157 | elif config.train.scheduler.type == 'warmup_plateau': 158 | scheduler.step_ReduceLROnPlateau(avg_loss) 159 | else: 160 | scheduler.step() 161 | return avg_loss 162 | 163 | def test(it): 164 | loss_tape = ValidationLossTape() 165 | with torch.no_grad(): 166 | model.eval() 167 | for i, batch in enumerate(tqdm(test_loader, desc='Test', dynamic_ncols=True)): 168 | for t in np.linspace(0, model.num_steps - 1, 10).astype(int): 169 | time_step = torch.tensor([t] * batch.num_graphs).to(args.device) 170 | batch = batch.to(args.device) 171 | loss_dict = model.get_loss(batch, t=time_step) 172 | loss = sum_weighted_losses(loss_dict, config.train.loss_weights) 173 | loss_dict['overall'] = loss 174 | loss_tape.update(loss_dict, 1) 175 | avg_loss = loss_tape.log(it, logger, writer, 'test') 176 | return avg_loss 177 | 178 | 179 | try: 180 | model.train() 181 | best_loss, best_iter = None, None 182 | for it in range(1, config.train.max_iters + 1): 183 | # try: 184 | t1 = time.time() 185 | batch = next(train_iterator).to(args.device) 186 | t2 = time.time() 187 | # print('data processing time: ', t2 - t1) 188 | train(it, batch) 189 | if it % args.sampling_iter == 0: 190 | data = test_set[0] 191 | data_list = [data.clone() for _ in range(100)] 192 | batch = Batch.from_data_list( 193 | data_list, follow_batch=FOLLOW_BATCH, exclude_keys=COLLATE_EXCLUDE_KEYS).to(args.device) 194 | traj_batch, final_x, final_c, final_bond = model.sample(batch, p_init_mode=cfg_model.frag_pos_prior) 195 | if model.train_bond: 196 | gen_mols = parse_sampling_result_with_bond( 197 | data_list, final_x, final_c, final_bond, atom_featurizer, 198 | known_linker_bonds=cfg_model.known_linker_bond, check_validity=True) 199 | else: 200 | gen_mols = parse_sampling_result(data_list, final_x, 
final_c, atom_featurizer)
201 |                 save_path = os.path.join(vis_dir, f'sampling_results_{it}.pt')
202 |                 torch.save({
203 |                     'data': data,
204 |                     # 'traj': traj_batch,
205 |                     'final_x': final_x, 'final_c': final_c, 'final_bond': final_bond,
206 |                     'gen_mols': gen_mols
207 |                 }, save_path)
208 |                 logger.info(f'Dumped sampling results to {save_path}!')
209 |                 recon_rate, complete_rate = eval_success_rate(gen_mols)
210 |                 logger.info(f'recon rate: {recon_rate:.4f}')
211 |                 logger.info(f'complete rate: {complete_rate:.4f}')
212 |                 writer.add_scalar('sampling/recon_rate', recon_rate, it)
213 |                 writer.add_scalar('sampling/complete_rate', complete_rate, it)
214 |                 writer.flush()
215 | 
216 |             if it % config.train.val_freq == 0 or it == config.train.max_iters:
217 |                 val_loss = validate(it)
218 |                 ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it)
219 |                 torch.save({
220 |                     'configs': config,
221 |                     'model': model.state_dict(),
222 |                     'optimizer': optimizer.state_dict(),
223 |                     'scheduler': scheduler.state_dict(),
224 |                     'iteration': it,
225 |                 }, ckpt_path)
226 |                 if best_loss is None or val_loss < best_loss:
227 |                     logger.info(f'[Validate] Best val loss achieved: {val_loss:.6f}')
228 |                     best_loss, best_iter = val_loss, it
229 |                 test(it)
230 |                 model.train()
231 | 
232 |     except KeyboardInterrupt:
233 |         logger.info('Terminating...')
234 | 
--------------------------------------------------------------------------------
/utils/calc_SC_RDKit.py:
--------------------------------------------------------------------------------
1 | import os
2 | from rdkit import Chem
3 | from rdkit.Chem import AllChem, rdShapeHelpers
4 | from rdkit.Chem.FeatMaps import FeatMaps
5 | from rdkit import RDConfig
6 | 
7 | # Set up features to use in FeatureMap
8 | fdefName = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
9 | fdef = AllChem.BuildFeatureFactory(fdefName)
10 | 
11 | fmParams = {}
12 | for k in fdef.GetFeatureFamilies():
13 |     fparams = FeatMaps.FeatMapParams()
14 |     fmParams[k] = fparams
15 | 
16 | keep = ('Donor', 'Acceptor', 'NegIonizable', 'PosIonizable',
17 |         'ZnBinder', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe')
18 | 
19 | 
20 | def get_FeatureMapScore(query_mol, ref_mol):
21 |     featLists = []
22 |     for m in [query_mol, ref_mol]:
23 |         rawFeats = fdef.GetFeaturesForMol(m)
24 |         # filter that list down to only include the ones we're interested in
25 |         featLists.append([f for f in rawFeats if f.GetFamily() in keep])
26 |     fms = [FeatMaps.FeatMap(feats=x, weights=[1] * len(x), params=fmParams) for x in featLists]
27 |     fms[0].scoreMode = FeatMaps.FeatMapScoreMode.Best
28 |     fm_score = fms[0].ScoreFeats(featLists[1]) / min(fms[0].GetNumFeatures(), len(featLists[1]))
29 | 
30 |     return fm_score
31 | 
32 | 
33 | def calc_SC_RDKit_score(query_mol, ref_mol):
34 |     fm_score = get_FeatureMapScore(query_mol, ref_mol)
35 | 
36 |     protrude_dist = rdShapeHelpers.ShapeProtrudeDist(query_mol, ref_mol,
37 |                                                      allowReordering=False)
38 |     SC_RDKit_score = 0.5 * fm_score + 0.5 * (1 - protrude_dist)
39 | 
40 |     return SC_RDKit_score
41 | 
--------------------------------------------------------------------------------
/utils/const.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from rdkit import Chem
4 | 
5 | 
6 | TORCH_FLOAT = torch.float32
7 | TORCH_INT = torch.int8
8 | 
9 | # #################################################################################### #
10 | # ####################################### ZINC ####################################### #
11 | # 
#################################################################################### # 12 | 13 | # Atom idx for one-hot encoding 14 | ATOM2IDX = {'C': 0, 'O': 1, 'N': 2, 'F': 3, 'S': 4, 'Cl': 5, 'Br': 6, 'I': 7} 15 | IDX2ATOM = {0: 'C', 1: 'O', 2: 'N', 3: 'F', 4: 'S', 5: 'Cl', 6: 'Br', 7: 'I'} 16 | 17 | # Atomic numbers (Z) 18 | CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53} 19 | 20 | # One-hot atom types 21 | NUMBER_OF_ATOM_TYPES = len(ATOM2IDX) 22 | 23 | 24 | # #################################################################################### # 25 | # ####################################### GEOM ####################################### # 26 | # #################################################################################### # 27 | 28 | # Atom idx for one-hot encoding 29 | GEOM_ATOM2IDX = {'C': 0, 'O': 1, 'N': 2, 'F': 3, 'S': 4, 'Cl': 5, 'Br': 6, 'I': 7, 'P': 8} 30 | GEOM_IDX2ATOM = {0: 'C', 1: 'O', 2: 'N', 3: 'F', 4: 'S', 5: 'Cl', 6: 'Br', 7: 'I', 8: 'P'} 31 | 32 | # Atomic numbers (Z) 33 | GEOM_CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53, 'P': 15} 34 | 35 | # One-hot atom types 36 | GEOM_NUMBER_OF_ATOM_TYPES = len(GEOM_ATOM2IDX) 37 | 38 | # Dataset keys 39 | DATA_LIST_ATTRS = { 40 | 'uuid', 'name', 'fragments_smi', 'linker_smi', 'num_atoms' 41 | } 42 | DATA_ATTRS_TO_PAD = { 43 | 'positions', 'one_hot', 'charges', 'anchors', 'fragment_mask', 'linker_mask', 'pocket_mask', 'fragment_only_mask' 44 | } 45 | DATA_ATTRS_TO_ADD_LAST_DIM = { 46 | 'charges', 'anchors', 'fragment_mask', 'linker_mask', 'pocket_mask', 'fragment_only_mask' 47 | } 48 | 49 | # Distribution of linker size in train data 50 | LINKER_SIZE_DIST = { 51 | 4: 85540, 52 | 3: 113928, 53 | 6: 70946, 54 | 7: 30408, 55 | 5: 77671, 56 | 9: 5177, 57 | 10: 1214, 58 | 8: 12712, 59 | 11: 158, 60 | 12: 7, 61 | } 62 | 63 | 64 | # Bond lengths from: 65 | # http://www.wiredchemist.com/chemistry/data/bond_energies_lengths.html 66 | # And: 67 | # http://chemistry-reference.com/tables/Bond%20Lengths%20and%20Enthalpies.pdf 68 | BONDS_1 = { 69 | 'H': { 70 | 'H': 74, 'C': 109, 'N': 101, 'O': 96, 'F': 92, 71 | 'B': 119, 'Si': 148, 'P': 144, 'As': 152, 'S': 134, 72 | 'Cl': 127, 'Br': 141, 'I': 161 73 | }, 74 | 'C': { 75 | 'H': 109, 'C': 154, 'N': 147, 'O': 143, 'F': 135, 76 | 'Si': 185, 'P': 184, 'S': 182, 'Cl': 177, 'Br': 194, 77 | 'I': 214 78 | }, 79 | 'N': { 80 | 'H': 101, 'C': 147, 'N': 145, 'O': 140, 'F': 136, 81 | 'Cl': 175, 'Br': 214, 'S': 168, 'I': 222, 'P': 177 82 | }, 83 | 'O': { 84 | 'H': 96, 'C': 143, 'N': 140, 'O': 148, 'F': 142, 85 | 'Br': 172, 'S': 151, 'P': 163, 'Si': 163, 'Cl': 164, 86 | 'I': 194 87 | }, 88 | 'F': { 89 | 'H': 92, 'C': 135, 'N': 136, 'O': 142, 'F': 142, 90 | 'S': 158, 'Si': 160, 'Cl': 166, 'Br': 178, 'P': 156, 91 | 'I': 187 92 | }, 93 | 'B': { 94 | 'H': 119, 'Cl': 175 95 | }, 96 | 'Si': { 97 | 'Si': 233, 'H': 148, 'C': 185, 'O': 163, 'S': 200, 98 | 'F': 160, 'Cl': 202, 'Br': 215, 'I': 243, 99 | }, 100 | 'Cl': { 101 | 'Cl': 199, 'H': 127, 'C': 177, 'N': 175, 'O': 164, 102 | 'P': 203, 'S': 207, 'B': 175, 'Si': 202, 'F': 166, 103 | 'Br': 214 104 | }, 105 | 'S': { 106 | 'H': 134, 'C': 182, 'N': 168, 'O': 151, 'S': 204, 107 | 'F': 158, 'Cl': 207, 'Br': 225, 'Si': 200, 'P': 210, 108 | 'I': 234 109 | }, 110 | 'Br': { 111 | 'Br': 228, 'H': 141, 'C': 194, 'O': 172, 'N': 214, 112 | 'Si': 215, 'S': 225, 'F': 178, 'Cl': 214, 'P': 222 113 | }, 114 | 'P': { 115 | 'P': 221, 'H': 144, 'C': 184, 'O': 163, 'Cl': 203, 116 | 'S': 210, 'F': 156, 'N': 177, 'Br': 
222 117 | }, 118 | 'I': { 119 | 'H': 161, 'C': 214, 'Si': 243, 'N': 222, 'O': 194, 120 | 'S': 234, 'F': 187, 'I': 266 121 | }, 122 | 'As': { 123 | 'H': 152 124 | } 125 | } 126 | 127 | BONDS_2 = { 128 | 'C': {'C': 134, 'N': 129, 'O': 120, 'S': 160}, 129 | 'N': {'C': 129, 'N': 125, 'O': 121}, 130 | 'O': {'C': 120, 'N': 121, 'O': 121, 'P': 150}, 131 | 'P': {'O': 150, 'S': 186}, 132 | 'S': {'P': 186} 133 | } 134 | 135 | BONDS_3 = { 136 | 'C': {'C': 120, 'N': 116, 'O': 113}, 137 | 'N': {'C': 116, 'N': 110}, 138 | 'O': {'C': 113} 139 | } 140 | 141 | BOND_DICT = [ 142 | None, 143 | Chem.rdchem.BondType.SINGLE, 144 | Chem.rdchem.BondType.DOUBLE, 145 | Chem.rdchem.BondType.TRIPLE, 146 | Chem.rdchem.BondType.AROMATIC, 147 | ] 148 | 149 | BOND2IDX = { 150 | Chem.rdchem.BondType.SINGLE: 1, 151 | Chem.rdchem.BondType.DOUBLE: 2, 152 | Chem.rdchem.BondType.TRIPLE: 3, 153 | Chem.rdchem.BondType.AROMATIC: 4, 154 | } 155 | 156 | ALLOWED_BONDS = { 157 | 'H': 1, 158 | 'C': 4, 159 | 'N': 3, 160 | 'O': 2, 161 | 'F': 1, 162 | 'B': 3, 163 | 'Al': 3, 164 | 'Si': 4, 165 | 'P': [3, 5], 166 | 'S': 4, 167 | 'Cl': 1, 168 | 'As': 3, 169 | 'Br': 1, 170 | 'I': 1, 171 | 'Hg': [1, 2], 172 | 'Bi': [3, 5] 173 | } 174 | 175 | MARGINS_EDM = [10, 5, 2] 176 | 177 | COLORS = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8'] 178 | # RADII = [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3] 179 | RADII = [0.77, 0.77, 0.77, 0.77, 0.77, 0.77, 0.77, 0.77, 0.77] 180 | 181 | ZINC_TRAIN_LINKER_ID2SIZE = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12] 182 | ZINC_TRAIN_LINKER_SIZE2ID = { 183 | size: idx 184 | for idx, size in enumerate(ZINC_TRAIN_LINKER_ID2SIZE) 185 | } 186 | ZINC_TRAIN_LINKER_SIZE_WEIGHTS = [ 187 | 3.47347831e-01, 188 | 4.63079100e-01, 189 | 5.12370917e-01, 190 | 5.62392614e-01, 191 | 1.30294388e+00, 192 | 3.24247801e+00, 193 | 8.12391184e+00, 194 | 3.45634358e+01, 195 | 2.72428571e+02, 196 | 6.26585714e+03 197 | ] 198 | 199 | 200 | GEOM_TRAIN_LINKER_ID2SIZE = [ 201 | 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 202 | 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 36, 38, 41 203 | ] 204 | GEOM_TRAIN_LINKER_SIZE2ID = { 205 | size: idx 206 | for idx, size in enumerate(GEOM_TRAIN_LINKER_ID2SIZE) 207 | } 208 | GEOM_TRAIN_LINKER_SIZE_WEIGHTS = [ 209 | 1.07790681e+00, 4.54693604e-01, 3.62575713e-01, 3.75199484e-01, 210 | 3.67812588e-01, 3.92388528e-01, 3.83421054e-01, 4.26924670e-01, 211 | 4.92768040e-01, 4.99761944e-01, 4.92342726e-01, 5.71456905e-01, 212 | 7.30631393e-01, 8.45412928e-01, 9.97252243e-01, 1.25423985e+00, 213 | 1.57316129e+00, 2.19902962e+00, 3.22640431e+00, 4.25481066e+00, 214 | 6.34749573e+00, 9.00676236e+00, 1.43084017e+01, 2.25763173e+01, 215 | 3.36867096e+01, 9.50713805e+01, 2.08693274e+02, 2.51659537e+02, 216 | 7.77856749e+02, 8.55642424e+03, 8.55642424e+03, 4.27821212e+03, 217 | 4.27821212e+03 218 | ] 219 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import numpy as np 4 | from rdkit import Chem 5 | from rdkit.Chem import AllChem 6 | from rdkit.Chem.rdchem import BondType 7 | from rdkit.Chem import ChemicalFeatures 8 | from rdkit import RDConfig 9 | from openbabel import openbabel as ob 10 | import torch 11 | 12 | 13 | ATOM_FAMILIES = ['Acceptor', 'Donor', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe', 'NegIonizable', 'PosIonizable', 14 | 'ZnBinder'] 15 | ATOM_FAMILIES_ID = {s: i for i, s in enumerate(ATOM_FAMILIES)} 16 | BOND_TYPES = 
{
17 |     BondType.UNSPECIFIED: 0,
18 |     BondType.SINGLE: 1,
19 |     BondType.DOUBLE: 2,
20 |     BondType.TRIPLE: 3,
21 |     BondType.AROMATIC: 4,
22 | }
23 | BOND_NAMES = {v: str(k) for k, v in BOND_TYPES.items()}
24 | HYBRIDIZATION_TYPE = ['S', 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2']
25 | HYBRIDIZATION_TYPE_ID = {s: i for i, s in enumerate(HYBRIDIZATION_TYPE)}
26 | 
27 | 
28 | def process_from_mol(rdmol):
29 |     fdefName = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
30 |     factory = ChemicalFeatures.BuildFeatureFactory(fdefName)
31 | 
32 |     rd_num_atoms = rdmol.GetNumAtoms()
33 |     feat_mat = np.zeros([rd_num_atoms, len(ATOM_FAMILIES)], dtype=np.int64)
34 |     for feat in factory.GetFeaturesForMol(rdmol):
35 |         feat_mat[feat.GetAtomIds(), ATOM_FAMILIES_ID[feat.GetFamily()]] = 1
36 | 
37 |     # Get hybridization in the order of atom idx.
38 |     hybridization = []
39 |     for atom in rdmol.GetAtoms():
40 |         hybr = str(atom.GetHybridization())
41 |         idx = atom.GetIdx()
42 |         hybridization.append((idx, hybr))
43 |     hybridization = sorted(hybridization)
44 |     hybridization = [v[1] for v in hybridization]
45 | 
46 |     ptable = Chem.GetPeriodicTable()
47 | 
48 |     pos = np.array(rdmol.GetConformers()[0].GetPositions(), dtype=np.float32)
49 |     element, valence, charge = [], [], []
50 |     accum_pos = 0
51 |     accum_mass = 0
52 |     for atom_idx in range(rd_num_atoms):
53 |         atom = rdmol.GetAtomWithIdx(atom_idx)
54 |         atom_num = atom.GetAtomicNum()
55 |         element.append(atom_num)
56 |         valence.append(atom.GetTotalValence())
57 |         charge.append(atom.GetFormalCharge())
58 |         atom_weight = ptable.GetAtomicWeight(atom_num)
59 |         accum_pos += pos[atom_idx] * atom_weight
60 |         accum_mass += atom_weight
61 |     center_of_mass = accum_pos / accum_mass
62 |     element = np.array(element, dtype=np.int64)
63 |     valence = np.array(valence, dtype=np.int64)
64 |     charge = np.array(charge, dtype=np.int64)
65 | 
66 |     # in edge_type, we have 1 for single bond, 2 for double bond, 3 for triple bond, and 4 for aromatic bond.
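    # Editor's note: bonds are stored in both directions ((i, j) and (j, i)) and
    # then sorted below by the flattened key i * num_atoms + j. For example, a
    # single C-C bond between atoms 0 and 2 of a 3-atom molecule contributes the
    # edge_index columns (0, 2) and (2, 0), both with edge_type 1.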
67 |     row, col, edge_type = [], [], []
68 |     for bond in rdmol.GetBonds():
69 |         start = bond.GetBeginAtomIdx()
70 |         end = bond.GetEndAtomIdx()
71 |         row += [start, end]
72 |         col += [end, start]
73 |         edge_type += 2 * [BOND_TYPES[bond.GetBondType()]]
74 | 
75 |     edge_index = np.array([row, col], dtype=np.int64)
76 |     edge_type = np.array(edge_type, dtype=np.int64)
77 | 
78 |     perm = (edge_index[0] * rd_num_atoms + edge_index[1]).argsort()
79 |     edge_index = edge_index[:, perm]
80 |     edge_type = edge_type[perm]
81 | 
82 |     data = {
83 |         'rdmol': rdmol,
84 |         'element': element,
85 |         'pos': pos,
86 |         'bond_index': edge_index,
87 |         'bond_type': edge_type,
88 |         'center_of_mass': center_of_mass,
89 |         'atom_feature': feat_mat,
90 |         'hybridization': hybridization,
91 |         'valence': valence,
92 |         'charge': charge
93 |     }
94 |     return data
95 | 
96 | 
97 | # rdmol conformer
98 | def compute_3d_coors(mol, random_seed=0):
99 |     mol = Chem.AddHs(mol)
100 |     success = AllChem.EmbedMolecule(mol, randomSeed=random_seed)
101 |     if success == -1:
102 |         return 0, 0
103 |     mol = Chem.RemoveHs(mol)
104 |     c = mol.GetConformer(0)
105 |     pos = c.GetPositions()
106 |     return pos, 1
107 | 
108 | 
109 | def compute_3d_coors_multiple(mol, numConfs=20, maxIters=400, randomSeed=1):
110 |     # mol = Chem.MolFromSmiles(smi)
111 |     mol = Chem.AddHs(mol)
112 |     AllChem.EmbedMultipleConfs(mol, numConfs=numConfs, numThreads=0, randomSeed=randomSeed)
113 |     if mol.GetConformers() == ():
114 |         return None, 0
115 |     try:
116 |         result = AllChem.MMFFOptimizeMoleculeConfs(mol, maxIters=maxIters, numThreads=0)
117 |     except Exception as e:
118 |         print(str(e))
119 |         return None, 0
120 |     mol = Chem.RemoveHs(mol)
121 |     result = [tuple((result[i][0], result[i][1], i)) for i in range(len(result)) if result[i][0] == 0]
122 |     if result == []:  # no local minimum on energy surface is found
123 |         return None, 0
124 |     result.sort()
125 |     return mol.GetConformers()[result[0][-1]].GetPositions(), 1
126 | 
127 | 
128 | def compute_3d_coors_frags(mol, numConfs=20, maxIters=400, randomSeed=1):
129 |     du = Chem.MolFromSmiles('*')
130 |     clean_frag = Chem.RemoveHs(AllChem.ReplaceSubstructs(Chem.MolFromSmiles(Chem.MolToSmiles(mol)), du, Chem.MolFromSmiles('[H]'), True)[0])
131 |     frag = Chem.CombineMols(clean_frag, Chem.MolFromSmiles("*.*"))
132 |     mol_to_link_carbon = AllChem.ReplaceSubstructs(mol, du, Chem.MolFromSmiles('C'), True)[0]
133 |     pos, _ = compute_3d_coors_multiple(mol_to_link_carbon, numConfs, maxIters, randomSeed)
134 |     return pos
135 | 
136 | 
137 | # -----
138 | def set_rdmol_positions_(mol, pos):
139 |     """
140 |     Args:
141 |         mol: An `rdkit.Chem.rdchem.Mol` object.
142 |         pos: (N_atoms, 3)
143 |     """
144 |     assert mol.GetNumAtoms() == pos.shape[0]
145 |     conf = Chem.Conformer(mol.GetNumAtoms())
146 |     for i in range(pos.shape[0]):
147 |         conf.SetAtomPosition(i, pos[i].tolist())
148 |     mol.AddConformer(conf, assignId=True)
149 | 
150 |     # for i in range(pos.shape[0]):
151 |     #     mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())
152 |     return mol
153 | 
154 | 
155 | def set_rdmol_positions(rdkit_mol, pos, reset=True):
156 |     """
157 |     Args:
158 |         rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.
159 | pos: (N_atoms, 3) 160 | """ 161 | mol = copy.deepcopy(rdkit_mol) 162 | if reset: 163 | mol.RemoveAllConformers() 164 | set_rdmol_positions_(mol, pos) 165 | return mol 166 | -------------------------------------------------------------------------------- /utils/eval_bond.py: -------------------------------------------------------------------------------- 1 | """Utils for evaluating bond length.""" 2 | 3 | import collections 4 | from typing import Tuple, Sequence, Dict 5 | from rdkit.Chem.rdMolTransforms import GetAngleDeg 6 | from rdkit import Chem 7 | import numpy as np 8 | from utils.data import BOND_TYPES 9 | 10 | BondType = Tuple[int, int, int] # (atomic_num, atomic_num, bond_type) 11 | BondLengthData = Tuple[BondType, float] # (bond_type, bond_length) 12 | BondLengthProfile = Dict[BondType, np.ndarray] # bond_type -> empirical distribution 13 | DISTANCE_BINS = np.arange(1.1, 1.7, 0.005)[:-1] 14 | ANGLE_BINS = np.arange(0., 180., 1.)[:-1] 15 | DIHEDRAL_BINS = np.arange(-180., 180., 1.)[:-1] 16 | 17 | # bond distance 18 | BOND_DISTS = ( 19 | (6, 6, 1), 20 | (6, 6, 2), 21 | (6, 7, 1), 22 | (6, 7, 2), 23 | (6, 8, 1), 24 | (6, 8, 2), 25 | (6, 6, 4), 26 | (6, 7, 4), 27 | (6, 8, 4), 28 | ) 29 | BOND_ANGLES = [ 30 | 'CCC', 31 | 'CCO', 32 | 'CNC', 33 | 'OPO', 34 | 'NCC', 35 | 'CC=O', 36 | 'COC' 37 | ] 38 | DIHEDRAL_ANGLES = [ 39 | # 'C1C-C1C-C1C', 40 | # 'C12C-C12C-C12C', 41 | # 'C1C-C1C-C1O', 42 | # 'O1C-C1C-C1O', 43 | # 'C1c-c12c-c12c', 44 | # 'C1C-C2C-C1C' 45 | ] 46 | 47 | 48 | def get_bond_angle(mol, fragment_mask, bond_smi='CCC'): 49 | """ 50 | Find bond pairs (defined by bond_smi) in mol and return the angle of the bond pair 51 | bond_smi: bond pair smiles, e.g. 'CCC' 52 | """ 53 | deg_list = [] 54 | substructure = Chem.MolFromSmiles(bond_smi) 55 | bond_pairs = mol.GetSubstructMatches(substructure) 56 | for pair in bond_pairs: 57 | if (fragment_mask[pair[0]] == 0) | (fragment_mask[pair[1]] == 0) | (fragment_mask[pair[2]] == 0): 58 | deg_list += [GetAngleDeg(mol.GetConformer(), *pair)] 59 | assert mol.GetBondBetweenAtoms(pair[0], pair[1]) is not None 60 | assert mol.GetBondBetweenAtoms(pair[2], pair[1]) is not None 61 | return deg_list 62 | 63 | 64 | def get_bond_symbol(bond): 65 | """ 66 | Return the symbol representation of a bond 67 | """ 68 | a0 = bond.GetBeginAtom().GetSymbol() 69 | a1 = bond.GetEndAtom().GetSymbol() 70 | b = str(int(bond.GetBondType())) # single: 1, double: 2, triple: 3, aromatic: 12 71 | return ''.join([a0, b, a1]) 72 | 73 | 74 | def get_triple_bonds(mol): 75 | """ 76 | Get all the bond triplets in a molecule 77 | """ 78 | valid_triple_bonds = [] 79 | for idx_bond, bond in enumerate(mol.GetBonds()): 80 | idx_begin_atom = bond.GetBeginAtomIdx() 81 | idx_end_atom = bond.GetEndAtomIdx() 82 | begin_atom = mol.GetAtomWithIdx(idx_begin_atom) 83 | end_atom = mol.GetAtomWithIdx(idx_end_atom) 84 | begin_bonds = begin_atom.GetBonds() 85 | valid_left_bonds = [] 86 | for begin_bond in begin_bonds: 87 | if begin_bond.GetIdx() == idx_bond: 88 | continue 89 | else: 90 | valid_left_bonds.append(begin_bond) 91 | if len(valid_left_bonds) == 0: 92 | continue 93 | 94 | end_bonds = end_atom.GetBonds() 95 | for end_bond in end_bonds: 96 | if end_bond.GetIdx() == idx_bond: 97 | continue 98 | else: 99 | for left_bond in valid_left_bonds: 100 | valid_triple_bonds.append([left_bond, bond, end_bond]) 101 | return valid_triple_bonds 102 | 103 | 104 | def get_dihedral_angle(mol, bonds_ref_sym): 105 | """ 106 | find bond triplets (defined by bonds_ref_sym) in mol and return the dihedral 
angle of the bond triplet
107 | bonds_ref_sym: a symbol string of a bond triplet, e.g. 'C1C-C1C-C1C'
108 | """
109 | # bonds_ref_sym = '-'.join(get_bond_symbol(bonds_ref))
110 | bonds_list = get_triple_bonds(mol)
111 | angles_list = []
112 | for bonds in bonds_list:
113 | sym = '-'.join([get_bond_symbol(b) for b in bonds])
114 | sym1 = '-'.join([get_bond_symbol(b) for b in bonds][::-1])
115 | atoms = []
116 | if (sym == bonds_ref_sym) or (sym1 == bonds_ref_sym):
117 | if sym1 == bonds_ref_sym:
118 | bonds = bonds[::-1]
119 | bond0 = bonds[0]
120 | atom0 = bond0.GetBeginAtomIdx()
121 | atom1 = bond0.GetEndAtomIdx()
122 |
123 | bond1 = bonds[1]
124 | atom1_0 = bond1.GetBeginAtomIdx()
125 | atom1_1 = bond1.GetEndAtomIdx()
126 | if atom0 == atom1_0:
127 | i, j, k = atom1, atom0, atom1_1
128 | elif atom0 == atom1_1:
129 | i, j, k = atom1, atom0, atom1_0
130 | elif atom1 == atom1_0:
131 | i, j, k = atom0, atom1, atom1_1
132 | elif atom1 == atom1_1:
133 | i, j, k = atom0, atom1, atom1_0
134 |
135 | bond2 = bonds[2]
136 | atom2_0 = bond2.GetBeginAtomIdx()
137 | atom2_1 = bond2.GetEndAtomIdx()
138 | if atom2_0 == k:
139 | l = atom2_1
140 | elif atom2_1 == k:
141 | l = atom2_0
142 | # print(i,j,k,l)
143 | angle = Chem.rdMolTransforms.GetDihedralDeg(mol.GetConformer(), i, j, k, l)
144 | angles_list.append(angle)
145 | return angles_list
146 |
147 |
148 | def bond_distance_from_mol(mol, fragment_mask):
149 | # only consider linker-related distances
150 | pos = mol.GetConformer().GetPositions()
151 | pdist = pos[None, :] - pos[:, None]
152 | pdist = np.sqrt(np.sum(pdist ** 2, axis=-1))
153 | all_distances = []
154 | for bond in mol.GetBonds():
155 | s_sym = bond.GetBeginAtom().GetAtomicNum()
156 | e_sym = bond.GetEndAtom().GetAtomicNum()
157 | s_idx, e_idx = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
158 | if (fragment_mask[s_idx] == 0) | (fragment_mask[e_idx] == 0):
159 | bond_type = BOND_TYPES[bond.GetBondType()]
160 | distance = pdist[s_idx, e_idx]
161 | all_distances.append(((s_sym, e_sym, bond_type), distance))
162 | return all_distances
163 |
164 |
165 | def _format_bond_type(bond_type: BondType) -> BondType:
166 | atom1, atom2, bond_category = bond_type
167 | if atom1 > atom2:
168 | atom1, atom2 = atom2, atom1
169 | return atom1, atom2, bond_category
170 |
171 |
172 | def get_distribution(distances: Sequence[float], bins=DISTANCE_BINS) -> np.ndarray:
173 | """Get the distribution of distances.
174 |
175 | Args:
176 | distances (list): List of distances.
177 | bins (list): bins of distances
178 | Returns:
179 | np.array: empirical distribution of distances, with length equal to len(bins) + 1.
180 | """ 181 | bin_counts = collections.Counter(np.searchsorted(bins, distances)) 182 | bin_counts = [bin_counts[i] if i in bin_counts else 0 for i in range(len(bins) + 1)] 183 | bin_counts = np.array(bin_counts) / np.sum(bin_counts) 184 | return bin_counts 185 | 186 | 187 | def get_bond_length_profile(mol_list, fragment_mask_list) -> BondLengthProfile: 188 | bond_lengths = [] 189 | for mol, mask in zip(mol_list, fragment_mask_list): 190 | mol = Chem.RemoveAllHs(mol) 191 | bond_lengths += bond_distance_from_mol(mol, mask) 192 | 193 | bond_length_profile = collections.defaultdict(list) 194 | for bond_type, bond_length in bond_lengths: 195 | bond_type = _format_bond_type(bond_type) 196 | bond_length_profile[bond_type].append(bond_length) 197 | bond_length_profile = {k: get_distribution(v) for k, v in bond_length_profile.items()} 198 | return bond_length_profile 199 | 200 | 201 | def get_bond_angles_dict(mol_list, fragment_mask_list): 202 | bond_angles = collections.defaultdict(list) 203 | dihedral_angles = collections.defaultdict(list) 204 | for mol, mask in zip(mol_list, fragment_mask_list): 205 | mol = Chem.RemoveAllHs(mol) 206 | for angle_type in BOND_ANGLES: 207 | bond_angles[angle_type] += get_bond_angle(mol, mask, bond_smi=angle_type) 208 | for angle_type in DIHEDRAL_ANGLES: 209 | dihedral_angles[angle_type] += get_dihedral_angle(mol, angle_type) 210 | return bond_angles, dihedral_angles 211 | 212 | 213 | def get_bond_angles_profile(mol_list, fragment_mask_list): 214 | angles, dihedrals = get_bond_angles_dict(mol_list, fragment_mask_list) 215 | angles_profile = {} 216 | for k, v in angles.items(): 217 | angles_profile[k] = get_distribution(v, ANGLE_BINS) 218 | for k, v in dihedrals.items(): 219 | angles_profile[k] = get_distribution(v, DIHEDRAL_BINS) 220 | return angles_profile 221 | -------------------------------------------------------------------------------- /utils/fpscores.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanjq/LinkerNet/280759c16ccecece0d81ab9cebe3f44041b80e51/utils/fpscores.pkl.gz -------------------------------------------------------------------------------- /utils/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | 7 | def rotation_matrix_from_vectors(vec1, vec2): 8 | """ Find the rotation matrix that aligns vec1 to vec2 9 | :param vec1: A 3d "source" vector 10 | :param vec2: A 3d "destination" vector 11 | :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2. 
12 | """
13 | a, b = (vec1 / torch.norm(vec1)).view(3), (vec2 / torch.norm(vec2)).view(3)
14 | v = torch.cross(a, b, dim=-1)
15 | c = torch.dot(a, b)
16 | s = torch.norm(v)
17 | kmat = torch.tensor([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
18 | if s == 0:
19 | rotation_matrix = torch.eye(3)
20 | else:
21 | rotation_matrix = torch.eye(3) + kmat + kmat @ kmat * ((1 - c) / (s ** 2))
22 | return rotation_matrix
23 |
24 |
25 | def find_rigid_transform(true_points, mapping_points):
26 | # align one fragment to the other and find the rigid transformation with the Kabsch algorithm
27 | # mapping_points @ R.T + t = true_points
28 | t1 = true_points.mean(0)
29 | t2 = mapping_points.mean(0)
30 | x1 = true_points - t1
31 | x2 = mapping_points - t2
32 |
33 | h = x2.T @ x1
34 | u, s, vt = np.linalg.svd(h)
35 | v = vt.T
36 | d = np.linalg.det(v @ u.T)
37 | e = np.array([[1, 0, 0], [0, 1, 0], [0, 0, d]])
38 | R = v @ e @ u.T
39 | t = -R @ t2.T + t1.T
40 | return R, t
41 |
42 |
43 | def get_pca_axes(coords, weights=None):
44 | """Computes the (weighted) PCA of the given coordinates.
45 | Args:
46 | coords (np.ndarray): (N, D)
47 | weights (np.ndarray, optional): (N). Defaults to None.
48 | Returns:
49 | Tuple[np.ndarray, np.ndarray]: (D), (D, D)
50 | """
51 | if weights is None:
52 | weights = np.ones((*coords.shape[:-1],))
53 | weights /= weights.sum()
54 |
55 | mean = (weights[..., None] * coords).mean()
56 | centered = coords - mean
57 | cov = centered.T @ np.diag(weights) @ centered
58 |
59 | s, vecs = np.linalg.eigh(cov)
60 | return s, vecs
61 |
62 |
63 | def find_axes(atoms, charges):
64 | """Generates equivariant axes based on PCA.
65 | Args:
66 | atoms (np.ndarray): (..., M, 3)
67 | charges (np.ndarray): (M)
68 | Returns:
69 | np.ndarray: (3, 3)
70 | """
71 | atoms, charges = deepcopy(atoms), deepcopy(charges)
72 | # First compute the axes by PCA
73 | atoms = atoms - atoms.mean(-2, keepdims=True)
74 | s, axes = get_pca_axes(atoms, charges)
75 | # Let's check whether we have identical eigenvalues;
76 | # if that's the case, we need to work with some pseudo positions
77 | # to get unique axes.
78 | unique_values = np.zeros_like(s)
79 | v, uindex = np.unique(s, return_index=True)
80 | unique_values[uindex] = v
81 | is_ambiguous = np.count_nonzero(unique_values) < np.count_nonzero(s)
82 | # We always compute the pseudo coordinates because jax.lax.cond caused
83 | # compile errors on A100 cards for unknown reasons (in the original JAX code).
84 | # Compute pseudo coordinates based on the vector inducing the largest Coulomb energy.
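# Example (illustrative, not part of the original source): four unit charges
# on the corners of a square give a degenerate PCA spectrum, which is exactly
# the ambiguous case the pseudo-coordinate trick below disambiguates:
#
#     atoms = np.array([[1., 1., 0.], [1., -1., 0.], [-1., 1., 0.], [-1., -1., 0.]])
#     s, _ = get_pca_axes(atoms, np.ones(4))
#     # s == [0., 1., 1.] -- two identical eigenvalues, so the in-plane axes
#     # are only defined up to a rotation until the positions are perturbed.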
85 | distances = atoms[None] - atoms[..., None, :] 86 | dist_norm = np.linalg.norm(distances, axis=-1) 87 | coulomb = charges[None] * charges[:, None] / (dist_norm + 1e-20) 88 | off_diag_mask = ~np.eye(atoms.shape[0], dtype=bool) 89 | coulomb, distances = coulomb[off_diag_mask], distances[off_diag_mask] 90 | idx = np.argmax(coulomb) 91 | scale_vec = distances[idx] 92 | scale_vec /= np.linalg.norm(scale_vec) 93 | # Projected atom positions 94 | proj = atoms @ scale_vec[..., None] * scale_vec 95 | diff = atoms - proj 96 | pseudo_atoms = proj * (1 + 1e-4) + diff 97 | 98 | pseudo_s, pseudo_axes = get_pca_axes(pseudo_atoms, charges) 99 | 100 | # Select pseudo axes if it is ambiguous 101 | s = np.where(is_ambiguous, pseudo_s, s) 102 | axes = np.where(is_ambiguous, pseudo_axes, axes) 103 | 104 | order = np.argsort(s)[::-1] 105 | axes = axes[:, order] 106 | 107 | # Compute an equivariant vector 108 | distances = np.linalg.norm(atoms[None] - atoms[..., None, :], axis=-1) 109 | weights = distances.sum(-1) 110 | equi_vec = ((weights * charges)[..., None] * atoms).mean(0) 111 | 112 | ve = equi_vec @ axes 113 | flips = ve < 0 114 | axes = np.where(flips[None], -axes, axes) 115 | 116 | right_hand = np.stack( 117 | [axes[:, 0], axes[:, 1], np.cross(axes[:, 0], axes[:, 1])], axis=1) 118 | # axes = np.where(np.abs(ve[-1]) < 1e-7, right_hand, axes) 119 | return right_hand 120 | 121 | 122 | def local_to_global(R, t, p): 123 | """ 124 | Description: 125 | Convert local (internal) coordinates to global (external) coordinates q. 126 | q <- Rp + t 127 | Args: 128 | R: (F, 3, 3). 129 | t: (F, 3). 130 | p: Local coordinates, (F, ..., 3). 131 | Returns: 132 | q: Global coordinates, (F, ..., 3). 133 | """ 134 | assert p.size(-1) == 3 135 | assert R.ndim - 1 == t.ndim 136 | squeeze_dim = False 137 | if R.ndim == 2: 138 | R = R.unsqueeze(0) 139 | t = t.unsqueeze(0) 140 | p = p.unsqueeze(0) 141 | squeeze_dim = True 142 | 143 | p_size = p.size() 144 | num_frags = p_size[0] 145 | 146 | p = p.view(num_frags, -1, 3).transpose(-1, -2) # (F, *, 3) -> (F, 3, *) 147 | q = torch.matmul(R, p) + t.unsqueeze(-1) # (F, 3, *) 148 | q = q.transpose(-1, -2).reshape(p_size) # (F, 3, *) -> (F, *, 3) -> (F, ..., 3) 149 | if squeeze_dim: 150 | q = q.squeeze(0) 151 | return q 152 | 153 | 154 | def global_to_local(R, t, q): 155 | """ 156 | Description: 157 | Convert global (external) coordinates q to local (internal) coordinates p. 158 | p <- R^{T}(q - t) 159 | Args: 160 | R: (F, 3, 3). 161 | t: (F, 3). 162 | q: Global coordinates, (F, ..., 3). 163 | Returns: 164 | p: Local coordinates, (F, ..., 3). 165 | """ 166 | assert q.size(-1) == 3 167 | assert R.ndim - 1 == t.ndim 168 | squeeze_dim = False 169 | if R.ndim == 2: 170 | R = R.unsqueeze(0) 171 | t = t.unsqueeze(0) 172 | q = q.unsqueeze(0) 173 | squeeze_dim = True 174 | 175 | q_size = q.size() 176 | num_frags = q_size[0] 177 | 178 | q = q.reshape(num_frags, -1, 3).transpose(-1, -2) # (F, *, 3) -> (F, 3, *) 179 | p = torch.matmul(R.transpose(-1, -2), (q - t.unsqueeze(-1))) # (F, 3, *) 180 | p = p.transpose(-1, -2).reshape(q_size) # (F, 3, *) -> (F, *, 3) -> (F, ..., 3) 181 | if squeeze_dim: 182 | p = p.squeeze(0) 183 | return p 184 | 185 | 186 | 187 | # Copyright (c) Meta Platforms, Inc. and affiliates. 188 | # All rights reserved. 189 | # 190 | # This source code is licensed under the BSD-style license found in the 191 | # LICENSE file in the root directory of this source tree. 
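# Example (illustrative sketch, not part of the original file): a quick
# self-check that local_to_global and global_to_local above are mutual
# inverses. The helper name below is editorial, not part of the module API.
def _frame_round_trip_demo():
    R = torch.tensor([[[0., -1., 0.], [1., 0., 0.], [0., 0., 1.]]])  # 90-degree turn about z, (1, 3, 3)
    t = torch.tensor([[1., 2., 3.]])                                 # frame origin, (1, 3)
    p = torch.randn(1, 5, 3)                                         # local coordinates
    q = local_to_global(R, t, p)                                     # q = R p + t
    assert torch.allclose(global_to_local(R, t, q), p, atol=1e-5)    # p = R^T (q - t)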
192 | def quaternion_to_rotation_matrix(quaternions): 193 | """ 194 | Convert rotations given as quaternions to rotation matrices. 195 | Args: 196 | quaternions: quaternions with real part first, 197 | as tensor of shape (..., 4). 198 | Returns: 199 | Rotation matrices as tensor of shape (..., 3, 3). 200 | """ 201 | quaternions = F.normalize(quaternions, dim=-1) 202 | r, i, j, k = torch.unbind(quaternions, -1) 203 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 204 | 205 | o = torch.stack( 206 | ( 207 | 1 - two_s * (j * j + k * k), 208 | two_s * (i * j - k * r), 209 | two_s * (i * k + j * r), 210 | two_s * (i * j + k * r), 211 | 1 - two_s * (i * i + k * k), 212 | two_s * (j * k - i * r), 213 | two_s * (i * k - j * r), 214 | two_s * (j * k + i * r), 215 | 1 - two_s * (i * i + j * j), 216 | ), 217 | -1, 218 | ) 219 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 220 | 221 | 222 | # Copyright (c) Meta Platforms, Inc. and affiliates. 223 | # All rights reserved. 224 | # 225 | # This source code is licensed under the BSD-style license found in the 226 | # LICENSE file in the root directory of this source tree. 227 | """ 228 | BSD License 229 | 230 | For PyTorch3D software 231 | 232 | Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. 233 | 234 | Redistribution and use in source and binary forms, with or without modification, 235 | are permitted provided that the following conditions are met: 236 | 237 | * Redistributions of source code must retain the above copyright notice, this 238 | list of conditions and the following disclaimer. 239 | 240 | * Redistributions in binary form must reproduce the above copyright notice, 241 | this list of conditions and the following disclaimer in the documentation 242 | and/or other materials provided with the distribution. 243 | 244 | * Neither the name Meta nor the names of its contributors may be used to 245 | endorse or promote products derived from this software without specific 246 | prior written permission. 247 | 248 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 249 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 250 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 251 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 252 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 253 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 254 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 255 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 256 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 257 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
258 | """ 259 | 260 | 261 | def quaternion_1ijk_to_rotation_matrix(q): 262 | """ 263 | (1 + ai + bj + ck) -> R 264 | Args: 265 | q: (..., 3) 266 | """ 267 | b, c, d = torch.unbind(q, dim=-1) 268 | s = torch.sqrt(1 + b**2 + c**2 + d**2) 269 | a, b, c, d = 1/s, b/s, c/s, d/s 270 | 271 | o = torch.stack( 272 | ( 273 | a**2 + b**2 - c**2 - d**2, 2*b*c - 2*a*d, 2*b*d + 2*a*c, 274 | 2*b*c + 2*a*d, a**2 - b**2 + c**2 - d**2, 2*c*d - 2*a*b, 275 | 2*b*d - 2*a*c, 2*c*d + 2*a*b, a**2 - b**2 - c**2 + d**2, 276 | ), 277 | -1, 278 | ) 279 | return o.reshape(q.shape[:-1] + (3, 3)) 280 | 281 | 282 | def apply_rotation_to_vector(R, p): 283 | return local_to_global(R, torch.zeros_like(p), p) 284 | 285 | 286 | def axis_angle_to_matrix(axis_angle): 287 | """ 288 | From https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html 289 | Convert rotations given as axis/angle to rotation matrices. 290 | 291 | Args: 292 | axis_angle: Rotations given as a vector in axis angle form, 293 | as a tensor of shape (..., 3), where the magnitude is 294 | the angle turned anticlockwise in radians around the 295 | vector's direction. 296 | 297 | Returns: 298 | Rotation matrices as tensor of shape (..., 3, 3). 299 | """ 300 | return quaternion_to_rotation_matrix(axis_angle_to_quaternion(axis_angle)) 301 | 302 | 303 | def axis_angle_to_quaternion(axis_angle): 304 | """ 305 | From https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html 306 | Convert rotations given as axis/angle to quaternions. 307 | 308 | Args: 309 | axis_angle: Rotations given as a vector in axis angle form, 310 | as a tensor of shape (..., 3), where the magnitude is 311 | the angle turned anticlockwise in radians around the 312 | vector's direction. 313 | 314 | Returns: 315 | quaternions with real part first, as tensor of shape (..., 4). 316 | """ 317 | angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) 318 | half_angles = 0.5 * angles 319 | eps = 1e-6 320 | small_angles = angles.abs() < eps 321 | sin_half_angles_over_angles = torch.empty_like(angles) 322 | sin_half_angles_over_angles[~small_angles] = ( 323 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 324 | ) 325 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 326 | # so sin(x/2)/x is about 1/2 - (x*x)/48 327 | sin_half_angles_over_angles[small_angles] = ( 328 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 329 | ) 330 | quaternions = torch.cat( 331 | [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 332 | ) 333 | return quaternions 334 | -------------------------------------------------------------------------------- /utils/guidance_funcs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils import so3 4 | from utils.geometry import local_to_global 5 | 6 | 7 | def compute_frag_distance_loss(p1, p2, min_d, max_d, mode='frag_center_distance'): 8 | """ 9 | :param p1: (B, 3). center of fragment 1. 10 | :param p2: (B, 3). center of fragment 2. 
11 | :param min_d: min distance 12 | :param max_d: max distance 13 | :param mode: constrained distance mode 14 | :return 15 | """ 16 | if mode == 'frag_center_distance': 17 | dist = torch.norm(p1 - p2, p=2, dim=-1) # (B, ) 18 | loss = torch.mean(torch.clamp(min_d - dist, min=0) ** 2 + torch.clamp(dist - max_d, min=0) ** 2) 19 | else: 20 | raise ValueError(mode) 21 | return loss 22 | 23 | 24 | def compute_anchor_prox_loss(x_linker, v1, v2, p1, p2, 25 | frags_local_pos_filled, frags_local_pos_mask, 26 | fragment_mask, cand_anchors_mask, batch, min_d=1.2, max_d=1.9): 27 | """ 28 | :param x_linker: (num_linker_atoms, 3) 29 | :param v1: (B, 3) 30 | :param v2: (B, 3) 31 | :param p1: (B, 3) 32 | :param p2: (B, 3) 33 | :param frags_local_pos_filled: (B, max_num_atoms, 3) 34 | :param frags_local_pos_mask: (B, max_num_atoms) of BoolTensor. 35 | :param fragment_mask: (N, ) 36 | :param cand_anchors_mask: (N, ) of BoolTensor. 37 | :param batch: (N, ) 38 | :param min_d 39 | :param max_d 40 | :return 41 | """ 42 | 43 | # transform rotation and translation to positions 44 | R1 = so3.so3vec_to_rotation(v1) 45 | R2 = so3.so3vec_to_rotation(v2) 46 | f1_local_pos = frags_local_pos_filled[::2] 47 | f2_local_pos = frags_local_pos_filled[1::2] 48 | f1_local_pos_mask = frags_local_pos_mask[::2] 49 | f2_local_pos_mask = frags_local_pos_mask[1::2] 50 | x_f1 = local_to_global(R1, p1, f1_local_pos)[f1_local_pos_mask] 51 | x_f2 = local_to_global(R2, p2, f2_local_pos)[f2_local_pos_mask] 52 | 53 | f1_anchor_mask = cand_anchors_mask[fragment_mask == 1] 54 | f2_anchor_mask = cand_anchors_mask[fragment_mask == 2] 55 | f1_batch, f2_batch = batch[fragment_mask == 1], batch[fragment_mask == 2] 56 | 57 | # approach 1: distance constraints on fragments only (unreasonable) 58 | # c_a1 = scatter_mean(x_f1[f1_anchor_mask], f1_batch[f1_anchor_mask], dim=0) # (B, 3) 59 | # c_a2 = scatter_mean(x_f2[f2_anchor_mask], f2_batch[f2_anchor_mask], dim=0) 60 | # c_na1 = scatter_mean(x_f1[~f1_anchor_mask], f1_batch[~f1_anchor_mask], dim=0) 61 | # c_na2 = scatter_mean(x_f2[~f2_anchor_mask], f2_batch[~f2_anchor_mask], dim=0) 62 | # loss = 0. 63 | # d_a1_a2 = torch.norm(c_a1 - c_a2, p=2, dim=-1) 64 | # if c_na1.size(0) > 0: 65 | # d_na1_a2 = torch.norm(c_na1 - c_a2, p=2, dim=-1) 66 | # loss += torch.mean(torch.clamp(d_a1_a2 - d_na1_a2, min=0)) 67 | # if c_na2.size(0) > 0: 68 | # d_a1_na2 = torch.norm(c_a1 - c_na2, p=2, dim=-1) 69 | # loss += torch.mean(torch.clamp(d_a1_a2 - d_a1_na2, min=0)) 70 | 71 | # approach 2: min dist of (linker, anchor) can form bond, (linker, non-anchor) cannot form bond 72 | linker_batch = batch[fragment_mask == 0] 73 | 74 | num_graphs = batch.max().item() + 1 75 | batch_losses = 0. 76 | for idx in range(num_graphs): 77 | linker_pos = x_linker[linker_batch == idx] 78 | loss_f1 = compute_prox_loss(x_f1[f1_batch == idx], linker_pos, f1_anchor_mask[f1_batch == idx], min_d, max_d) 79 | loss_f2 = compute_prox_loss(x_f2[f2_batch == idx], linker_pos, f2_anchor_mask[f2_batch == idx], min_d, max_d) 80 | batch_losses += loss_f1 + loss_f2 81 | 82 | return batch_losses / num_graphs 83 | 84 | 85 | def compute_prox_loss(frags_pos, linker_pos, anchor_mask, min_d=1.2, max_d=1.9): 86 | pairwise_dist = torch.norm(frags_pos[anchor_mask].unsqueeze(1) - linker_pos.unsqueeze(0), p=2, dim=-1) 87 | min_dist = pairwise_dist.min() 88 | # 1.2 < min dist < 1.9 89 | loss_anchor = torch.mean(torch.clamp(min_d - min_dist, min=0) ** 2 + torch.clamp(min_dist - max_d, min=0) ** 2) 90 | 91 | # non anchor min dist > 1.9 92 | loss_non_anchor = 0. 
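# Worked example (editorial, not from the original source): with the default
# window [min_d, max_d] = [1.2, 1.9] Angstrom, an anchor-linker minimum
# distance of 1.5 gives clamp(1.2 - 1.5, min=0)**2 + clamp(1.5 - 1.9, min=0)**2
# = 0, i.e. no penalty; a distance of 2.4 gives (2.4 - 1.9)**2 = 0.25,
# pulling the linker back toward bonding range.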
93 | non_anchor_pairwise_dist = torch.norm(frags_pos[~anchor_mask].unsqueeze(1) - linker_pos.unsqueeze(0), p=2, dim=-1) 94 | if non_anchor_pairwise_dist.size(0) > 0: 95 | non_anchor_min_dist = non_anchor_pairwise_dist.min() 96 | loss_non_anchor = torch.mean(torch.clamp(max_d - non_anchor_min_dist, min=0) ** 2) 97 | 98 | loss = loss_anchor + loss_non_anchor 99 | # print(f'loss anchor: {loss_anchor} loss non anchor: {loss_non_anchor}') 100 | return loss 101 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import logging 5 | import torch 6 | import numpy as np 7 | import yaml 8 | from easydict import EasyDict 9 | import sys 10 | import shutil 11 | import torch.utils.tensorboard 12 | 13 | 14 | class BlackHole(object): 15 | def __setattr__(self, name, value): 16 | pass 17 | def __call__(self, *args, **kwargs): 18 | return self 19 | def __getattr__(self, name): 20 | return self 21 | 22 | 23 | def load_config(path): 24 | with open(path, 'r') as f: 25 | return EasyDict(yaml.safe_load(f)) 26 | 27 | 28 | def get_logger(name, log_dir=None): 29 | logger = logging.getLogger(name) 30 | logger.setLevel(logging.DEBUG) 31 | formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s') 32 | 33 | stream_handler = logging.StreamHandler(sys.stdout) 34 | stream_handler.setLevel(logging.DEBUG) 35 | stream_handler.setFormatter(formatter) 36 | logger.addHandler(stream_handler) 37 | 38 | if log_dir is not None: 39 | file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt')) 40 | file_handler.setLevel(logging.DEBUG) 41 | file_handler.setFormatter(formatter) 42 | logger.addHandler(file_handler) 43 | 44 | return logger 45 | 46 | 47 | def get_new_log_dir(root='./logs', prefix='', tag=''): 48 | fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime()) 49 | if prefix != '': 50 | fn = prefix + '_' + fn 51 | if tag != '': 52 | fn = fn + '_' + tag 53 | log_dir = os.path.join(root, fn) 54 | os.makedirs(log_dir) 55 | return log_dir 56 | 57 | 58 | def seed_all(seed): 59 | torch.manual_seed(seed) 60 | np.random.seed(seed) 61 | random.seed(seed) 62 | 63 | 64 | def log_hyperparams(writer, args): 65 | from torch.utils.tensorboard.summary import hparams 66 | vars_args = {k:v if isinstance(v, str) else repr(v) for k, v in vars(args).items()} 67 | exp, ssi, sei = hparams(vars_args, {}) 68 | writer.file_writer.add_summary(exp) 69 | writer.file_writer.add_summary(ssi) 70 | writer.file_writer.add_summary(sei) 71 | 72 | 73 | def int_tuple(argstr): 74 | return tuple(map(int, argstr.split(','))) 75 | 76 | 77 | def str_tuple(argstr): 78 | return tuple(argstr.split(',')) 79 | 80 | 81 | def count_parameters(model): 82 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 83 | 84 | 85 | def unique(x, dim=None): 86 | """Unique elements of x and indices of those unique elements 87 | https://github.com/pytorch/pytorch/issues/36748#issuecomment-619514810 88 | 89 | e.g. 
90 | 91 | unique(tensor([ 92 | [1, 2, 3], 93 | [1, 2, 4], 94 | [1, 2, 3], 95 | [1, 2, 5] 96 | ]), dim=0) 97 | => (tensor([[1, 2, 3], 98 | [1, 2, 4], 99 | [1, 2, 5]]), 100 | tensor([0, 1, 3])) 101 | """ 102 | unique, inverse = torch.unique( 103 | x, sorted=True, return_inverse=True, dim=dim) 104 | perm = torch.arange(inverse.size(0), dtype=inverse.dtype, 105 | device=inverse.device) 106 | inverse, perm = inverse.flip([0]), perm.flip([0]) 107 | return unique, inverse.new_empty(unique.size(dim)).scatter_(0, inverse, perm) 108 | 109 | 110 | def setup_logdir(config, logdir, mode='train', tag='', create_dir=True): 111 | # Logging 112 | config_name = os.path.basename(config)[:os.path.basename(config).rfind('.')] 113 | log_dir = get_new_log_dir(logdir, tag=tag, prefix=config_name) if create_dir else logdir 114 | if mode == 'train': 115 | ckpt_dir = os.path.join(log_dir, 'checkpoints') 116 | vis_dir = os.path.join(log_dir, 'vis') 117 | os.makedirs(ckpt_dir, exist_ok=True) 118 | os.makedirs(vis_dir, exist_ok=True) 119 | logger = get_logger('train', log_dir) 120 | writer = torch.utils.tensorboard.SummaryWriter(log_dir) 121 | shutil.copytree('./models', os.path.join(log_dir, 'models')) 122 | elif mode == 'eval': 123 | logger = get_logger('eval', log_dir) 124 | writer, ckpt_dir, vis_dir = None, None, None 125 | else: 126 | raise ValueError 127 | logger.info(config) 128 | shutil.copyfile(config, os.path.join(log_dir, os.path.basename(config))) 129 | return logger, writer, log_dir, ckpt_dir, vis_dir 130 | -------------------------------------------------------------------------------- /utils/prior_num_atoms.py: -------------------------------------------------------------------------------- 1 | """Utils for sampling size of a linker.""" 2 | 3 | import numpy as np 4 | import pickle 5 | from collections import Counter 6 | 7 | 8 | def setup_configs(meta_path='utils/prior_num_atoms.pkl', mode='frag_center_distance'): 9 | with open(meta_path, 'rb') as f: 10 | prior_meta = pickle.load(f) 11 | all_dist = prior_meta[mode] 12 | all_n_atoms = prior_meta['num_linker_atoms'] 13 | bin_min, bin_max = np.floor(all_dist.min()), np.ceil(all_dist.max()) 14 | BINS = np.arange(bin_min, bin_max, 1.) 
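# Example (illustrative, not from the original source): if the training set's
# fragment-center distances span 2.3 .. 14.8 Angstrom, then bin_min = 2.0,
# bin_max = 15.0 and BINS = [2., 3., ..., 14.], i.e. the loop below builds one
# empirical linker-size distribution per 1-Angstrom distance bin.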
15 | CONFIGS = {'bounds': BINS, 'distributions': []} 16 | 17 | for min_d, max_d in zip(BINS[:-1], BINS[1:]): 18 | valid_idx = (min_d < all_dist) & (all_dist < max_d) 19 | c = Counter(all_n_atoms[valid_idx]) 20 | num_atoms_list, prob_list = list(c.keys()), np.array(list(c.values())) / np.sum(list(c.values())) 21 | CONFIGS['distributions'].append((num_atoms_list, prob_list)) 22 | return CONFIGS 23 | 24 | 25 | def _get_bin_idx(distance, config_dict): 26 | bounds = config_dict['bounds'] 27 | for i in range(len(bounds) - 1): 28 | if distance < bounds[i + 1]: 29 | return i 30 | return len(bounds) - 2 31 | 32 | 33 | def sample_atom_num(distance, config_dict): 34 | bin_idx = _get_bin_idx(distance, config_dict) 35 | num_atom_list, prob_list = config_dict['distributions'][bin_idx] 36 | return np.random.choice(num_atom_list, p=prob_list) 37 | -------------------------------------------------------------------------------- /utils/reconstruct_linker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.spatial.distance import cdist 4 | 5 | from rdkit import Chem, Geometry 6 | 7 | from utils import const 8 | from copy import deepcopy 9 | import re 10 | import itertools 11 | 12 | 13 | class MolReconsError(Exception): 14 | pass 15 | 16 | 17 | def get_bond_order(atom1, atom2, distance, check_exists=True, margins=const.MARGINS_EDM): 18 | distance = 100 * distance # We change the metric 19 | 20 | # Check exists for large molecules where some atom pairs do not have a 21 | # typical bond length. 22 | if check_exists: 23 | if atom1 not in const.BONDS_1: 24 | return 0 25 | if atom2 not in const.BONDS_1[atom1]: 26 | return 0 27 | 28 | # margin1, margin2 and margin3 have been tuned to maximize the stability of the QM9 true samples 29 | if distance < const.BONDS_1[atom1][atom2] + margins[0]: 30 | 31 | # Check if atoms in bonds2 dictionary. 
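# Editorial note (not from the original comments): the checks below form a
# cascade over increasingly short reference lengths -- within the single-bond
# threshold, try the double-bond threshold, then the triple-bond threshold,
# and return the highest bond order whose distance test passes (3, 2 or 1).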
32 | if atom1 in const.BONDS_2 and atom2 in const.BONDS_2[atom1]: 33 | thr_bond2 = const.BONDS_2[atom1][atom2] + margins[1] 34 | if distance < thr_bond2: 35 | if atom1 in const.BONDS_3 and atom2 in const.BONDS_3[atom1]: 36 | thr_bond3 = const.BONDS_3[atom1][atom2] + margins[2] 37 | if distance < thr_bond3: 38 | return 3 # Triple 39 | return 2 # Double 40 | return 1 # Single 41 | return 0 # No bond 42 | 43 | 44 | def create_conformer(coords): 45 | conformer = Chem.Conformer() 46 | for i, (x, y, z) in enumerate(coords): 47 | conformer.SetAtomPosition(i, Geometry.Point3D(x, y, z)) 48 | return conformer 49 | 50 | 51 | def build_linker_mol(linker_x, linker_c, add_bonds=False): 52 | mol = Chem.RWMol() 53 | for atom in linker_c: 54 | a = Chem.Atom(int(atom)) 55 | mol.AddAtom(a) 56 | 57 | if add_bonds: 58 | # predict bond order 59 | n = len(linker_x) 60 | ptable = Chem.GetPeriodicTable() 61 | dists = cdist(linker_x, linker_x, 'euclidean') 62 | E = torch.zeros((n, n), dtype=torch.int) 63 | A = torch.zeros((n, n), dtype=torch.bool) 64 | for i in range(n): 65 | for j in range(i): 66 | atom_i = ptable.GetElementSymbol(int(linker_c[i])) 67 | atom_j = ptable.GetElementSymbol(int(linker_c[j])) 68 | order = get_bond_order(atom_i, atom_j, dists[i, j], margins=const.MARGINS_EDM) 69 | if order > 0: 70 | A[i, j] = 1 71 | E[i, j] = order 72 | all_bonds = torch.nonzero(A) 73 | for bond in all_bonds: 74 | mol.AddBond(bond[0].item(), bond[1].item(), const.BOND_DICT[E[bond[0], bond[1]].item()]) 75 | return mol 76 | 77 | 78 | def reconstruct_mol(frag_mol, frag_x, linker_x, linker_c): 79 | # construct linker mol 80 | linker_mol = build_linker_mol(linker_x, linker_c, add_bonds=True) 81 | 82 | # combine mol and assign conformer 83 | mol = Chem.CombineMols(frag_mol, linker_mol) 84 | mol = Chem.RWMol(mol) 85 | # frag_x = frag_mol.GetConformer().GetPositions() 86 | all_x = np.concatenate([frag_x, linker_x], axis=0) 87 | mol.RemoveAllConformers() 88 | mol.AddConformer(create_conformer(all_x)) 89 | 90 | # add frag-linker bond 91 | n_frag_atoms = frag_mol.GetNumAtoms() 92 | n = mol.GetNumAtoms() 93 | 94 | dists = cdist(all_x, all_x, 'euclidean') 95 | E = torch.zeros((n, n), dtype=torch.int) 96 | A = torch.zeros((n, n), dtype=torch.bool) 97 | for i in range(n_frag_atoms): 98 | for j in range(n_frag_atoms, n): 99 | atom_i = mol.GetAtomWithIdx(i).GetSymbol() 100 | atom_j = mol.GetAtomWithIdx(j).GetSymbol() 101 | order = get_bond_order(atom_i, atom_j, dists[i, j], margins=const.MARGINS_EDM) 102 | if order > 0: 103 | A[i, j] = 1 104 | E[i, j] = order 105 | all_bonds = torch.nonzero(A) 106 | for bond in all_bonds: 107 | mol.AddBond(bond[0].item(), bond[1].item(), const.BOND_DICT[E[bond[0], bond[1]].item()]) 108 | 109 | # frag_c = [atom.GetAtomicNum() for atom in frag_mol.GetAtoms()] 110 | # all_x = np.concatenate([frag_x, linker_x], axis=0) 111 | # all_c = np.concatenate([frag_c, linker_c], axis=0) 112 | # print('all c: ', all_c) 113 | # mol = build_linker_mol(all_x, all_c) 114 | # 115 | # try: 116 | # Chem.SanitizeMol(mol) 117 | # fixed = True 118 | # except Exception as e: 119 | # fixed = False 120 | # 121 | # if not fixed: 122 | # mol, fixed = fix_valence(mol) 123 | return mol 124 | 125 | 126 | def reconstruct_mol_with_bond(frag_mol, frag_x, linker_x, linker_c, 127 | linker_bond_index, linker_bond_type, known_linker_bonds=True, check_validity=True): 128 | # construct linker mol 129 | linker_mol = build_linker_mol(linker_x, linker_c, add_bonds=known_linker_bonds) 130 | 131 | # combine mol and assign conformer 132 | mol = 
Chem.CombineMols(frag_mol, linker_mol) 133 | linker_atom_idx = list(range(mol.GetNumAtoms() - linker_mol.GetNumAtoms(), mol.GetNumAtoms())) 134 | mol = Chem.RWMol(mol) 135 | # frag_x = frag_mol.GetConformer().GetPositions() 136 | all_x = np.concatenate([frag_x, linker_x], axis=0) 137 | mol.RemoveAllConformers() 138 | mol.AddConformer(create_conformer(all_x)) 139 | 140 | linker_bond_index, linker_bond_type = linker_bond_index.tolist(), linker_bond_type.tolist() 141 | anchor_indices = set() 142 | # add bonds 143 | for i, type_this in enumerate(linker_bond_type): 144 | node_i, node_j = linker_bond_index[0][i], linker_bond_index[1][i] 145 | if node_i < node_j: 146 | if type_this == 0: 147 | continue 148 | else: 149 | if node_i in linker_atom_idx and node_j not in linker_atom_idx: 150 | anchor_indices.add(int(node_j)) 151 | elif node_j in linker_atom_idx and node_i not in linker_atom_idx: 152 | anchor_indices.add(int(node_i)) 153 | 154 | if type_this == 1: 155 | mol.AddBond(node_i, node_j, Chem.BondType.SINGLE) 156 | elif type_this == 2: 157 | mol.AddBond(node_i, node_j, Chem.BondType.DOUBLE) 158 | elif type_this == 3: 159 | mol.AddBond(node_i, node_j, Chem.BondType.TRIPLE) 160 | elif type_this == 4: 161 | mol.AddBond(node_i, node_j, Chem.BondType.AROMATIC) 162 | else: 163 | raise Exception('unknown bond order {}'.format(type_this)) 164 | 165 | mol = mol.GetMol() 166 | for anchor_idx in anchor_indices: 167 | atom = mol.GetAtomWithIdx(anchor_idx) 168 | atom.SetNumExplicitHs(0) 169 | 170 | if check_validity: 171 | mol = fix_validity(mol, linker_atom_idx) 172 | 173 | # check valid 174 | # rd_mol_check = Chem.MolFromSmiles(Chem.MolToSmiles(mol)) 175 | # if (rd_mol_check is None) and check_validity: 176 | # raise MolReconsError() 177 | return mol 178 | 179 | 180 | def fix_validity(mol, linker_atom_idx): 181 | try: 182 | Chem.SanitizeMol(mol) 183 | fixed = True 184 | except Exception as e: 185 | fixed = False 186 | 187 | if not fixed: 188 | try: 189 | Chem.Kekulize(deepcopy(mol)) 190 | except Chem.rdchem.KekulizeException as e: 191 | err = e 192 | if 'Unkekulized' in err.args[0]: 193 | mol, fixed = fix_aromatic(mol) 194 | 195 | # valence error for N 196 | if not fixed: 197 | mol, fixed = fix_valence(mol, linker_atom_idx) 198 | 199 | # print('s2') 200 | if not fixed: 201 | mol, fixed = fix_aromatic(mol, True, linker_atom_idx) 202 | 203 | try: 204 | Chem.SanitizeMol(mol) 205 | except Exception as e: 206 | # raise MolReconsError() 207 | return None 208 | return mol 209 | 210 | 211 | def fix_valence(mol, linker_atom_idx=None): 212 | mol = deepcopy(mol) 213 | fixed = False 214 | cnt_loop = 0 215 | while True: 216 | try: 217 | Chem.SanitizeMol(mol) 218 | fixed = True 219 | break 220 | except Chem.rdchem.AtomValenceException as e: 221 | err = e 222 | except Exception as e: 223 | return mol, False # from HERE: rerun sample 224 | cnt_loop += 1 225 | if cnt_loop > 100: 226 | break 227 | N4_valence = re.compile(u"Explicit valence for atom # ([0-9]{1,}) N, 4, is greater than permitted") 228 | index = N4_valence.findall(err.args[0]) 229 | if len(index) > 0: 230 | if linker_atom_idx is None or int(index[0]) in linker_atom_idx: 231 | mol.GetAtomWithIdx(int(index[0])).SetFormalCharge(1) 232 | return mol, fixed 233 | 234 | 235 | def get_ring_sys(mol): 236 | all_rings = Chem.GetSymmSSSR(mol) 237 | if len(all_rings) == 0: 238 | ring_sys_list = [] 239 | else: 240 | ring_sys_list = [all_rings[0]] 241 | for ring in all_rings[1:]: 242 | form_prev = False 243 | for prev_ring in ring_sys_list: 244 | if 
set(ring).intersection(set(prev_ring)):
245 | prev_ring.extend(ring)
246 | form_prev = True
247 | break
248 | if not form_prev:
249 | ring_sys_list.append(ring)
250 | ring_sys_list = [list(set(x)) for x in ring_sys_list]
251 | return ring_sys_list
252 |
253 |
254 | def get_all_subsets(ring_list):
255 | all_sub_list = []
256 | for n_sub in range(len(ring_list)+1):
257 | all_sub_list.extend(itertools.combinations(ring_list, n_sub))
258 | return all_sub_list
259 |
260 |
261 | def fix_aromatic(mol, strict=False, linker_atom_idx=None):
262 | mol_orig = mol
263 | aromatic_list = [a.GetIdx() for a in mol.GetAromaticAtoms()]
264 | N_ring_list = []
265 | S_ring_list = []
266 | for ring_sys in get_ring_sys(mol):
267 | if set(ring_sys).intersection(set(aromatic_list)):
268 | idx_N = [atom for atom in ring_sys if mol.GetAtomWithIdx(atom).GetSymbol() == 'N']
269 | if len(idx_N) > 0:
270 | idx_N.append(-1)  # -1 = option to leave this ring unchanged
271 | N_ring_list.append(idx_N)
272 | idx_S = [atom for atom in ring_sys if mol.GetAtomWithIdx(atom).GetSymbol() == 'S']
273 | if len(idx_S) > 0:
274 | idx_S.append(-1)  # -1 = option to leave this ring unchanged
275 | S_ring_list.append(idx_S)
276 | # enumerate S
277 | fixed = False
278 | if strict:
279 | S_ring_list = [s for ring in S_ring_list for s in ring if s != -1]
280 | permutation = get_all_subsets(S_ring_list)
281 | else:
282 | permutation = list(itertools.product(*S_ring_list))
283 | for perm in permutation:
284 | mol = deepcopy(mol_orig)
285 | perm = [x for x in perm if x != -1]
286 | for idx in perm:
287 | if linker_atom_idx is None or idx in linker_atom_idx:
288 | mol.GetAtomWithIdx(idx).SetFormalCharge(1)
289 | try:
290 | if strict:
291 | mol, fixed = fix_valence(mol, linker_atom_idx)
292 | Chem.SanitizeMol(mol)
293 | fixed = True
294 | break
295 | except:
296 | continue
297 | # enumerate N
298 | if not fixed:
299 | if strict:
300 | N_ring_list = [s for ring in N_ring_list for s in ring if s != -1]
301 | permutation = get_all_subsets(N_ring_list)
302 | else:
303 | permutation = list(itertools.product(*N_ring_list))
304 | for perm in permutation:  # select one atom per ring
305 | perm = [x for x in perm if x != -1]
306 | # print(perm)
307 | actions = itertools.product([0, 1], repeat=len(perm))
308 | for action in actions:  # add H or charge
309 | mol = deepcopy(mol_orig)
310 | for idx, act_atom in zip(perm, action):
311 | if linker_atom_idx is None or idx in linker_atom_idx:
312 | if act_atom == 0:
313 | mol.GetAtomWithIdx(idx).SetNumExplicitHs(1)
314 | else:
315 | mol.GetAtomWithIdx(idx).SetFormalCharge(1)
316 | try:
317 | if strict:
318 | mol, fixed = fix_valence(mol, linker_atom_idx)
319 | Chem.SanitizeMol(mol)
320 | fixed = True
321 | break
322 | except:
323 | continue
324 | if fixed:
325 | break
326 | return mol, fixed
327 |
328 |
329 | def parse_sampling_result(data_list, final_x, final_c, atom_featurizer):
330 | all_mols = []
331 | for data, x_gen, c_gen in zip(data_list, final_x, final_c):
332 | frag_pos = x_gen[data.fragment_mask > 0].cpu().numpy().astype(np.float64)
333 | linker_pos = x_gen[data.fragment_mask == 0].cpu().numpy().astype(np.float64)
334 | linker_ele = [atom_featurizer.get_element_from_index(int(c)) for c in c_gen[data.linker_mask]]
335 | full_mol = reconstruct_mol(data.frag_mol, frag_pos, linker_pos, linker_ele)
336 | all_mols.append(full_mol)
337 | return all_mols
338 |
339 |
340 | def parse_sampling_result_with_bond(data_list, final_x, final_c, final_bond, atom_featurizer,
341 | known_linker_bonds=True, check_validity=False):
342 | all_mols = []
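# Editorial note (not from the original source): a single Data object may be
# paired with a whole batch of samples; the isinstance check below then
# repeats it so every generated (x, c, bond) triple is rebuilt against the
# same fragment template. Illustrative call (the featurizer class is from
# utils/transforms.py; the argument names are hypothetical):
#
#     mols = parse_sampling_result_with_bond(data, xs, cs, bonds,
#                                            FeaturizeAtom('protac'),
#                                            check_validity=True)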
343 | if not isinstance(data_list, list):
344 | data_list = [data_list] * len(final_x)
345 | for data, x_gen, c_gen, b_gen in zip(data_list, final_x, final_c, final_bond):
346 | frag_pos = x_gen[data.fragment_mask > 0].cpu().numpy().astype(np.float64)
347 | linker_pos = x_gen[data.fragment_mask == 0].cpu().numpy().astype(np.float64)
348 | linker_ele = [atom_featurizer.get_element_from_index(int(c)) for c in c_gen[data.linker_mask]]
349 | linker_bond_type = b_gen[data.linker_bond_mask]
350 | linker_bond_index = data.edge_index[:, data.linker_bond_mask]
351 | full_mol = reconstruct_mol_with_bond(
352 | data.frag_mol, frag_pos, linker_pos, linker_ele, linker_bond_index, linker_bond_type,
353 | known_linker_bonds, check_validity)
354 | all_mols.append(full_mol)
355 | return all_mols
356 | --------------------------------------------------------------------------------
/utils/sascorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of synthetic accessibility score as described in:
3 | #
4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
5 | # Peter Ertl and Ansgar Schuffenhauer
6 | # Journal of Cheminformatics 1:8 (2009)
7 | # http://www.jcheminf.com/content/1/1/8
8 | #
9 | # several small modifications to the original paper are included
10 | # particularly a slightly different formula for the macrocyclic penalty
11 | # and taking into account also molecule symmetry (fingerprint density)
12 | #
13 | # for a set of 10k diverse molecules the agreement between the original method
14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97
15 | #
16 | # peter ertl & greg landrum, september 2013
17 | #
18 | from __future__ import print_function
19 |
20 | from rdkit import Chem
21 | from rdkit.Chem import rdMolDescriptors
22 | import pickle  # the rdkit.six compatibility shims (cPickle, iteritems) were
23 | # removed from RDKit, so plain pickle and dict.items() are used instead
24 |
25 | import math
26 | from collections import defaultdict
27 |
28 | import os.path as op
29 |
30 | _fscores = None
31 |
32 |
33 | def readFragmentScores(name='fpscores'):
34 | import gzip
35 | global _fscores
36 | # generate the full path filename:
37 | if name == "fpscores":
38 | name = op.join(op.dirname(__file__), name)
39 | _fscores = pickle.load(gzip.open('%s.pkl.gz' % name))
40 | outDict = {}
41 | for i in _fscores:
42 | for j in range(1, len(i)):
43 | outDict[i[j]] = float(i[0])
44 | _fscores = outDict
45 |
46 |
47 | def numBridgeheadsAndSpiro(mol, ri=None):
48 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
49 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
50 | return nBridgehead, nSpiro
51 |
52 |
53 | def calculateScore(m):
54 | if _fscores is None:
55 | readFragmentScores()
56 |
57 | # fragment score
58 | fp = rdMolDescriptors.GetMorganFingerprint(m,
59 | 2)  #<- 2 is the *radius* of the circular fingerprint
60 | fps = fp.GetNonzeroElements()
61 | score1 = 0.
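# Editorial note (not from the original comments): the loop below computes
# the count-weighted mean fragment contribution,
#     score1 = sum_b count(b) * fscore(b) / sum_b count(b)
# over the non-zero Morgan (radius-2) bits b, with unseen bits scored as -4.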
62 | nf = 0
63 | for bitId, v in fps.items():
64 | nf += v
65 | sfp = bitId
66 | score1 += _fscores.get(sfp, -4) * v
67 | score1 /= nf
68 |
69 | # features score
70 | nAtoms = m.GetNumAtoms()
71 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
72 | ri = m.GetRingInfo()
73 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
74 | nMacrocycles = 0
75 | for x in ri.AtomRings():
76 | if len(x) > 8:
77 | nMacrocycles += 1
78 |
79 | sizePenalty = nAtoms**1.005 - nAtoms
80 | stereoPenalty = math.log10(nChiralCenters + 1)
81 | spiroPenalty = math.log10(nSpiro + 1)
82 | bridgePenalty = math.log10(nBridgeheads + 1)
83 | macrocyclePenalty = 0.
84 | # ---------------------------------------
85 | # This differs from the paper, which defines:
86 | # macrocyclePenalty = math.log10(nMacrocycles+1)
87 | # This form generates better results when 2 or more macrocycles are present
88 | if nMacrocycles > 0:
89 | macrocyclePenalty = math.log10(2)
90 |
91 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
92 |
93 | # correction for the fingerprint density
94 | # not in the original publication, added in version 1.1
95 | # to make highly symmetrical molecules easier to synthesize
96 | score3 = 0.
97 | if nAtoms > len(fps):
98 | score3 = math.log(float(nAtoms) / len(fps)) * .5
99 |
100 | sascore = score1 + score2 + score3
101 |
102 | # need to transform the "raw" value into a scale between 1 and 10
103 | min = -4.0
104 | max = 2.5
105 | sascore = 11. - (sascore - min + 1) / (max - min) * 9.
106 | # smooth the 10-end
107 | if sascore > 8.:
108 | sascore = 8. + math.log(sascore + 1. - 9.)
109 | if sascore > 10.:
110 | sascore = 10.0
111 | elif sascore < 1.:
112 | sascore = 1.0
113 |
114 | return sascore
115 |
116 |
117 | def processMols(mols):
118 | print('smiles\tName\tsa_score')
119 | for i, m in enumerate(mols):
120 | if m is None:
121 | continue
122 |
123 | s = calculateScore(m)
124 |
125 | smiles = Chem.MolToSmiles(m)
126 | print(smiles + "\t" + m.GetProp('_Name') + "\t%.3f" % s)
127 |
128 |
129 | if __name__ == '__main__':
130 | import sys, time
131 |
132 | t1 = time.time()
133 | readFragmentScores("fpscores")
134 | t2 = time.time()
135 |
136 | suppl = Chem.SmilesMolSupplier(sys.argv[1])
137 | t3 = time.time()
138 | processMols(suppl)
139 | t4 = time.time()
140 |
141 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
142 | file=sys.stderr)
143 |
144 | #
145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
146 | # All rights reserved.
147 | #
148 | # Redistribution and use in source and binary forms, with or without
149 | # modification, are permitted provided that the following conditions are
150 | # met:
151 | #
152 | # * Redistributions of source code must retain the above copyright
153 | # notice, this list of conditions and the following disclaimer.
154 | # * Redistributions in binary form must reproduce the above
155 | # copyright notice, this list of conditions and the following
156 | # disclaimer in the documentation and/or other materials provided
157 | # with the distribution.
158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
159 | # nor the names of its contributors may be used to endorse or promote
160 | # products derived from this software without specific prior written permission.
161 | # 162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 173 | # 174 | -------------------------------------------------------------------------------- /utils/so3.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from utils.geometry import quaternion_to_rotation_matrix 8 | 9 | 10 | def log_rotation(R): 11 | trace = R[..., range(3), range(3)].sum(-1) 12 | if torch.is_grad_enabled(): 13 | # The derivative of acos at -1.0 is -inf, so to stablize the gradient, we use -0.9999 14 | min_cos = -0.999 15 | else: 16 | min_cos = -1.0 17 | cos_theta = ((trace - 1) / 2).clamp_min(min=min_cos) 18 | sin_theta = torch.sqrt(1 - cos_theta ** 2) 19 | theta = torch.acos(cos_theta) 20 | coef = ((theta + 1e-8) / (2 * sin_theta + 2e-8))[..., None, None] 21 | logR = coef * (R - R.transpose(-1, -2)) 22 | return logR 23 | 24 | 25 | def skewsym_to_so3vec(S): 26 | x = S[..., 1, 2] 27 | y = S[..., 2, 0] 28 | z = S[..., 0, 1] 29 | w = torch.stack([x, y, z], dim=-1) 30 | return w 31 | 32 | 33 | def so3vec_to_skewsym(w): 34 | x, y, z = torch.unbind(w, dim=-1) 35 | o = torch.zeros_like(x) 36 | S = torch.stack([ 37 | o, z, -y, 38 | -z, o, x, 39 | y, -x, o, 40 | ], dim=-1).reshape(w.shape[:-1] + (3, 3)) 41 | return S 42 | 43 | 44 | def exp_skewsym(S): 45 | x = torch.linalg.norm(skewsym_to_so3vec(S), dim=-1) 46 | I = torch.eye(3).to(S).view([1 for _ in range(S.dim() - 2)] + [3, 3]) 47 | 48 | sinx, cosx = torch.sin(x), torch.cos(x) 49 | b = (sinx + 1e-8) / (x + 1e-8) 50 | c = (1 - cosx + 1e-8) / (x ** 2 + 2e-8) # lim_{x->0} (1-cosx)/(x^2) = 0.5 51 | 52 | S2 = S @ S 53 | return I + b[..., None, None] * S + c[..., None, None] * S2 54 | 55 | 56 | def so3vec_to_rotation(w): 57 | return exp_skewsym(so3vec_to_skewsym(w)) 58 | 59 | 60 | def rotation_to_so3vec(R): 61 | logR = log_rotation(R) 62 | w = skewsym_to_so3vec(logR) 63 | return w 64 | 65 | 66 | def random_uniform_so3(size, device='cpu'): 67 | q = F.normalize(torch.randn(list(size) + [4, ], device=device), dim=-1) # (..., 4) 68 | return rotation_to_so3vec(quaternion_to_rotation_matrix(q)) 69 | 70 | 71 | class ApproxAngularDistribution(nn.Module): 72 | # todo: interpolation 73 | def __init__(self, stddevs, std_threshold=0.1, num_bins=8192, num_iters=1024): 74 | super().__init__() 75 | self.std_threshold = std_threshold 76 | self.num_bins = num_bins 77 | self.num_iters = num_iters 78 | self.register_buffer('stddevs', torch.FloatTensor(stddevs)) 79 | self.register_buffer('approx_flag', self.stddevs <= std_threshold) 80 | self._precompute_histograms() 81 | 82 | @staticmethod 83 | def _pdf(x, e, L): 84 | """ 85 | Args: 86 | x: (N, ) 87 | e: 
Float 88 | L: Integer 89 | """ 90 | x = x[:, None] # (N, *) 91 | c = ((1 - torch.cos(x)) / math.pi) # (N, *) 92 | l = torch.arange(0, L)[None, :] # (*, L) 93 | a = (2 * l + 1) * torch.exp(-l * (l + 1) * (e ** 2)) # (*, L) 94 | b = (torch.sin((l + 0.5) * x) + 1e-6) / (torch.sin(x / 2) + 1e-6) # (N, L) 95 | 96 | f = (c * a * b).sum(dim=1) 97 | return f 98 | 99 | def _precompute_histograms(self): 100 | X, Y = [], [] 101 | for std in self.stddevs: 102 | std = std.item() 103 | x = torch.linspace(0, math.pi, self.num_bins) # (n_bins,) 104 | y = self._pdf(x, std, self.num_iters) # (n_bins,) 105 | y = torch.nan_to_num(y).clamp_min(0) 106 | X.append(x) 107 | Y.append(y) 108 | self.register_buffer('X', torch.stack(X, dim=0)) # (n_stddevs, n_bins) 109 | self.register_buffer('Y', torch.stack(Y, dim=0)) # (n_stddevs, n_bins) 110 | 111 | def sample(self, std_idx): 112 | """ 113 | Args: 114 | std_idx: Indices of standard deviation. 115 | Returns: 116 | samples: Angular samples [0, PI), same size as std. 117 | """ 118 | size = std_idx.size() 119 | std_idx = std_idx.flatten() # (N,) 120 | 121 | # Samples from histogram 122 | prob = self.Y[std_idx] # (N, n_bins) 123 | bin_idx = torch.multinomial(prob[:, :-1], num_samples=1).squeeze(-1) # (N,) 124 | bin_start = self.X[std_idx, bin_idx] # (N,) 125 | bin_width = self.X[std_idx, bin_idx + 1] - self.X[std_idx, bin_idx] 126 | samples_hist = bin_start + torch.rand_like(bin_start) * bin_width # (N,) 127 | 128 | # Samples from Gaussian approximation 129 | mean_gaussian = self.stddevs[std_idx] * 2 130 | std_gaussian = self.stddevs[std_idx] 131 | samples_gaussian = mean_gaussian + torch.randn_like(mean_gaussian) * std_gaussian 132 | samples_gaussian = samples_gaussian.abs() % math.pi 133 | 134 | # Choose from histogram or Gaussian 135 | gaussian_flag = self.approx_flag[std_idx] 136 | samples = torch.where(gaussian_flag, samples_gaussian, samples_hist) 137 | 138 | return samples.reshape(size) 139 | 140 | 141 | def random_normal_so3(std_idx, angular_distrib, device='cpu'): 142 | size = std_idx.size() 143 | u = F.normalize(torch.randn(list(size) + [3, ], device=device), dim=-1) 144 | theta = angular_distrib.sample(std_idx) 145 | w = u * theta[..., None] 146 | return w 147 | -------------------------------------------------------------------------------- /utils/train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch_geometric.data import Data, Batch 4 | from utils.misc import BlackHole 5 | from easydict import EasyDict 6 | 7 | 8 | def repeat_data(data: Data, num_repeat) -> Batch: 9 | datas = [copy.deepcopy(data) for i in range(num_repeat)] 10 | return Batch.from_data_list(datas) 11 | 12 | 13 | def repeat_batch(batch: Batch, num_repeat) -> Batch: 14 | datas = batch.to_data_list() 15 | new_data = [] 16 | for i in range(num_repeat): 17 | new_data += copy.deepcopy(datas) 18 | return Batch.from_data_list(new_data) 19 | 20 | 21 | def inf_iterator(iterable): 22 | iterator = iterable.__iter__() 23 | while True: 24 | try: 25 | yield iterator.__next__() 26 | except StopIteration: 27 | iterator = iterable.__iter__() 28 | 29 | 30 | def get_optimizer(cfg, model): 31 | if cfg.type == 'adam': 32 | return torch.optim.Adam( 33 | model.parameters(), 34 | lr=cfg.lr, 35 | weight_decay=cfg.weight_decay, 36 | betas=(cfg.beta1, cfg.beta2, ) 37 | ) 38 | elif cfg.type == 'adamw': 39 | return torch.optim.AdamW( 40 | model.parameters(), 41 | lr=cfg.lr, 42 | weight_decay=cfg.weight_decay, 43 | betas=(cfg.beta1, 
cfg.beta2, ) 44 | ) 45 | else: 46 | raise NotImplementedError('Optimizer not supported: %s' % cfg.type) 47 | 48 | 49 | def get_scheduler(cfg, optimizer): 50 | if cfg.type == 'plateau': 51 | return torch.optim.lr_scheduler.ReduceLROnPlateau( 52 | optimizer, 53 | factor=cfg.factor, 54 | patience=cfg.patience, 55 | min_lr=cfg.min_lr 56 | ) 57 | else: 58 | raise NotImplementedError('Scheduler not supported: %s' % cfg.type) 59 | 60 | 61 | def sum_weighted_losses(losses, weights): 62 | """ 63 | Args: 64 | losses: Dict of scalar tensors. 65 | weights: Dict of weights. 66 | """ 67 | loss = 0 68 | for k in losses.keys(): 69 | if weights is None: 70 | loss = loss + losses[k] 71 | else: 72 | loss = loss + weights[k] * losses[k] 73 | return loss 74 | 75 | 76 | def log_losses(out, it, tag, train_report_iter=1, logger=BlackHole(), writer=BlackHole(), others={}): 77 | if it % train_report_iter == 0: 78 | logstr = '[%s] Iter %05d' % (tag, it) 79 | logstr += ' | loss %.4f' % out['overall'].item() 80 | for k, v in out.items(): 81 | if k == 'overall': continue 82 | logstr += ' | loss(%s) %.4f' % (k, v.item()) 83 | for k, v in others.items(): 84 | if k == 'lr': 85 | logstr += ' | %s %2.6f' % (k, v) 86 | else: 87 | logstr += ' | %s %2.4f' % (k, v) 88 | logger.info(logstr) 89 | 90 | for k, v in out.items(): 91 | if k == 'overall': 92 | writer.add_scalar('%s/loss' % tag, v, it) 93 | else: 94 | writer.add_scalar('%s/loss_%s' % (tag, k), v, it) 95 | for k, v in others.items(): 96 | writer.add_scalar('%s/%s' % (tag, k), v, it) 97 | writer.flush() 98 | 99 | 100 | class ValidationLossTape(object): 101 | 102 | def __init__(self): 103 | super().__init__() 104 | self.accumulate = {} 105 | self.others = {} 106 | self.total = 0 107 | 108 | def update(self, out, n, others={}): 109 | self.total += n 110 | for k, v in out.items(): 111 | if k not in self.accumulate: 112 | self.accumulate[k] = v.clone().detach() 113 | else: 114 | self.accumulate[k] += v.clone().detach() 115 | 116 | for k, v in others.items(): 117 | if k not in self.others: 118 | self.others[k] = v.clone().detach() 119 | else: 120 | self.others[k] += v.clone().detach() 121 | 122 | def log(self, it, logger=BlackHole(), writer=BlackHole(), tag='val'): 123 | avg = EasyDict({k: v / self.total for k, v in self.accumulate.items()}) 124 | avg_others = EasyDict({k: v / self.total for k, v in self.others.items()}) 125 | log_losses(avg, it, tag, logger=logger, writer=writer, others=avg_others) 126 | return avg['overall'] 127 | -------------------------------------------------------------------------------- /utils/train_linker_smiles.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guanjq/LinkerNet/280759c16ccecece0d81ab9cebe3f44041b80e51/utils/train_linker_smiles.pkl -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from rdkit import Chem 5 | from torch_geometric.transforms import BaseTransform 6 | 7 | from utils import so3 8 | from utils.geometry import local_to_global, global_to_local, rotation_matrix_from_vectors 9 | 10 | 11 | # required by model: 12 | # x_noisy, c_noisy, atom_feat, edge_index, edge_feat, 13 | # R1_noisy, R2_noisy, R1, R2, eps_p1, eps_p2, x_0, c_0, t, fragment_mask, batch 14 | 15 | 16 | def modify_frags_conformer(frags_local_pos, frag_idx_mask, v_frags, p_frags): 17 | 
--------------------------------------------------------------------------------
/utils/train_linker_smiles.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanjq/LinkerNet/280759c16ccecece0d81ab9cebe3f44041b80e51/utils/train_linker_smiles.pkl
--------------------------------------------------------------------------------
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from rdkit import Chem
5 | from torch_geometric.transforms import BaseTransform
6 |
7 | from utils import so3
8 | from utils.geometry import local_to_global, global_to_local, rotation_matrix_from_vectors
9 |
10 |
11 | # required by model:
12 | # x_noisy, c_noisy, atom_feat, edge_index, edge_feat,
13 | # R1_noisy, R2_noisy, R1, R2, eps_p1, eps_p2, x_0, c_0, t, fragment_mask, batch
14 |
15 |
16 | def modify_frags_conformer(frags_local_pos, frag_idx_mask, v_frags, p_frags):
17 |     R_frags = so3.so3vec_to_rotation(v_frags)
18 |     x_frags = torch.zeros_like(frags_local_pos)
19 |     for i in range(2):
20 |         noisy_pos = local_to_global(R_frags[i], p_frags[i],
21 |                                     frags_local_pos[frag_idx_mask == i + 1])
22 |         x_frags[frag_idx_mask == i + 1] = noisy_pos
23 |     return x_frags
24 |
25 |
26 | def dataset_info(dataset):  # qm9, zinc, cep
27 |     if dataset == 'qm9':
28 |         return {'atom_types': ["H", "C", "N", "O", "F"],
29 |                 'maximum_valence': {0: 1, 1: 4, 2: 3, 3: 2, 4: 1},
30 |                 'number_to_atom': {0: "H", 1: "C", 2: "N", 3: "O", 4: "F"},
31 |                 'bucket_sizes': np.array(list(range(4, 28, 2)) + [29])
32 |                 }
33 |     elif dataset == 'zinc' or dataset == 'protac':
34 |         return {'atom_types': ['Br1(0)', 'C4(0)', 'Cl1(0)', 'F1(0)', 'H1(0)', 'I1(0)',
35 |                                'N2(-1)', 'N3(0)', 'N4(1)', 'O1(-1)', 'O2(0)', 'S2(0)', 'S4(0)', 'S6(0)'],
36 |                 'maximum_valence': {0: 1, 1: 4, 2: 1, 3: 1, 4: 1, 5: 1, 6: 2, 7: 3, 8: 4, 9: 1, 10: 2, 11: 2, 12: 4,
37 |                                     13: 6, 14: 3},
38 |                 'number_to_atom': {0: 'Br', 1: 'C', 2: 'Cl', 3: 'F', 4: 'H', 5: 'I', 6: 'N', 7: 'N', 8: 'N', 9: 'O',
39 |                                    10: 'O', 11: 'S', 12: 'S', 13: 'S'},
40 |                 'bucket_sizes': np.array(
41 |                     [28, 31, 33, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 55, 58, 84])
42 |                 }
43 |
44 |     elif dataset == "cep":
45 |         return {'atom_types': ["C", "S", "N", "O", "Se", "Si"],
46 |                 'maximum_valence': {0: 4, 1: 2, 2: 3, 3: 2, 4: 2, 5: 4},
47 |                 'number_to_atom': {0: "C", 1: "S", 2: "N", 3: "O", 4: "Se", 5: "Si"},
48 |                 'bucket_sizes': np.array([25, 28, 29, 30, 32, 33, 34, 35, 36, 37, 38, 39, 43, 46])
49 |                 }
50 |     else:
51 |         print("the datasets in use are qm9|zinc|cep")
52 |         exit(1)
53 |
54 |
55 | class FeaturizeAtom(BaseTransform):
56 |
57 |     def __init__(self, dataset_name, known_anchor=False,
58 |                  add_atom_type=True, add_atom_feat=True):
59 |         super().__init__()
60 |         self.dataset_name = dataset_name
61 |         self.dataset_info = dataset_info(dataset_name)
62 |         self.known_anchor = known_anchor
63 |         self.add_atom_type = add_atom_type
64 |         self.add_atom_feat = add_atom_feat
65 |
66 |     @property
67 |     def num_classes(self):
68 |         return len(self.dataset_info['atom_types'])
69 |
70 |     @property
71 |     def feature_dim(self):
72 |         n_feat_dim = 2
73 |         if self.known_anchor:
74 |             n_feat_dim += 2
75 |         return n_feat_dim
76 |
77 |     def get_index(self, atom_num, valence, charge):
78 |         if self.dataset_name in ['zinc', 'protac']:
79 |             pt = Chem.GetPeriodicTable()
80 |             atom_str = "%s%i(%i)" % (pt.GetElementSymbol(atom_num), valence, charge)
81 |             return self.dataset_info['atom_types'].index(atom_str)
82 |         else:
83 |             raise ValueError
84 |
85 |     def get_element_from_index(self, index):
86 |         pt = Chem.GetPeriodicTable()
87 |         symb = self.dataset_info['number_to_atom'][index]
88 |         return pt.GetAtomicNumber(symb)
89 |
90 |     def __call__(self, data):
91 |         if self.add_atom_type:
92 |             x = [self.get_index(int(e), int(v), int(c)) for e, v, c in zip(data.element, data.valence, data.charge)]
93 |             data.atom_type = torch.tensor(x)
94 |         if self.add_atom_feat:
95 |             # fragment / linker indicator, independent of atom types
96 |             linker_flag = F.one_hot((data.fragment_mask == 0).long(), 2)
97 |             all_feats = [linker_flag]
98 |             # fragment anchor flag
99 |             if self.known_anchor:
100 |                 anchor_flag = F.one_hot((data.anchor_mask == 1).long(), 2)
101 |                 all_feats.append(anchor_flag)
102 |             data.atom_feat = torch.cat(all_feats, -1)
103 |         return data
104 |
105 |
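Under the zinc/protac scheme, an atom type is the element symbol plus its explicit valence and formal charge, e.g. "C4(0)". A small illustrative sketch of the round trip (not repository code):

featurizer = FeaturizeAtom('protac')
idx = featurizer.get_index(atom_num=6, valence=4, charge=0)  # "C4(0)" -> index 1
num = featurizer.get_element_from_index(idx)                 # 'C' -> atomic number 6
print(featurizer.num_classes, idx, num)                      # 14 1 6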
106 | class BuildCompleteGraph(BaseTransform):
107 |
108 |     def __init__(self, known_linker_bond=False, known_cand_anchors=False):
109 |         super().__init__()
110 |         self.known_linker_bond = known_linker_bond
111 |         self.known_cand_anchors = known_cand_anchors
112 |
113 |     @property
114 |     def num_bond_classes(self):
115 |         return 5
116 |
117 |     @property
118 |     def bond_feature_dim(self):
119 |         return 4
120 |
121 |     @staticmethod
122 |     def _get_interleave_edge_index(edge_index):
123 |         edge_index_sym = torch.stack([edge_index[1], edge_index[0]])
124 |         e = torch.zeros_like(torch.cat([edge_index, edge_index_sym], dim=-1))
125 |         e[:, ::2] = edge_index
126 |         e[:, 1::2] = edge_index_sym
127 |         return e
128 |
129 |     def _build_interleave_fc(self, n1_atoms, n2_atoms):
130 |         eij = torch.triu_indices(n1_atoms, n2_atoms, offset=1)
131 |         e = self._get_interleave_edge_index(eij)
132 |         return e
133 |
134 |     def __call__(self, data):
135 |         # fully connected graph
136 |         num_nodes = len(data.pos)
137 |         fc_edge_index = self._build_interleave_fc(num_nodes, num_nodes)
138 |         data.edge_index = fc_edge_index
139 |
140 |         # (ll, lf, fl, ff) indicator
141 |         src, dst = data.edge_index
142 |         num_edges = len(fc_edge_index[0])
143 |         edge_type = torch.zeros(num_edges).long()
144 |         l_ind_src = data.fragment_mask[src] == 0
145 |         l_ind_dst = data.fragment_mask[dst] == 0
146 |         edge_type[l_ind_src & l_ind_dst] = 0
147 |         edge_type[l_ind_src & ~l_ind_dst] = 1
148 |         edge_type[~l_ind_src & l_ind_dst] = 2
149 |         edge_type[~l_ind_src & ~l_ind_dst] = 3
150 |         edge_type = F.one_hot(edge_type, num_classes=4)
151 |         data.edge_feat = edge_type
152 |
153 |         # bond type 0: none 1: single 2: double 3: triple 4: aromatic
154 |         bond_type = torch.zeros(num_edges).long()
155 |
156 |         id_fc_edge = fc_edge_index[0] * num_nodes + fc_edge_index[1]
157 |         id_frag_bond = data.bond_index[0] * num_nodes + data.bond_index[1]
158 |         idx_edge = torch.tensor([torch.nonzero(id_fc_edge == id_).squeeze() for id_ in id_frag_bond])
159 |         bond_type[idx_edge] = data.bond_type
160 |         # data.edge_type = F.one_hot(bond_type, num_classes=5)
161 |         data.edge_type = bond_type
162 |         if self.known_linker_bond:
163 |             data.linker_bond_mask = (data.fragment_mask[src] == 0) ^ (data.fragment_mask[dst] == 0)
164 |         elif self.known_cand_anchors:
165 |             ll_bond = (data.fragment_mask[src] == 0) & (data.fragment_mask[dst] == 0)
166 |             fl_bond = (data.cand_anchors_mask[src] == 1) & (data.fragment_mask[dst] == 0)
167 |             lf_bond = (data.cand_anchors_mask[dst] == 1) & (data.fragment_mask[src] == 0)
168 |             data.linker_bond_mask = ll_bond | fl_bond | lf_bond
169 |         else:
170 |             data.linker_bond_mask = (data.fragment_mask[src] == 0) | (data.fragment_mask[dst] == 0)
171 |
172 |         data.inner_edge_mask = (data.fragment_mask[src] == data.fragment_mask[dst])
173 |         return data
174 |
175 |
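To make the edge layout concrete, here is what the interleaved fully connected graph looks like for three atoms (an illustrative check, not repository code): each undirected pair (i, j) from the upper triangle is immediately followed by its reverse (j, i).

g = BuildCompleteGraph()
print(g._build_interleave_fc(3, 3))
# tensor([[0, 1, 0, 2, 1, 2],
#         [1, 0, 2, 0, 2, 1]])
# columns: (0,1), (1,0), (0,2), (2,0), (1,2), (2,1) -- 6 directed edges for 3 nodes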
176 | class SelectCandAnchors(BaseTransform):
177 |
178 |     def __init__(self, mode='exact', k=2):
179 |         super().__init__()
180 |         self.mode = mode
181 |         assert mode in ['exact', 'k-hop']
182 |         self.k = k
183 |
184 |     @staticmethod
185 |     def bfs(nbh_list, node, k=2, valid_list=[]):
186 |         visited = [node]
187 |         queue = [node]
188 |         level = [0]
189 |         bfs_perm = []
190 |
191 |         while len(queue) > 0:
192 |             m = queue.pop(0)
193 |             l = level.pop(0)
194 |             if l > k:
195 |                 break
196 |             bfs_perm.append(m)
197 |
198 |             for neighbour in nbh_list[m]:
199 |                 if neighbour not in visited and neighbour in valid_list:
200 |                     visited.append(neighbour)
201 |                     queue.append(neighbour)
202 |                     level.append(l + 1)
203 |         return bfs_perm
204 |
205 |     def __call__(self, data):
206 |         # link_indices = (data.linker_mask == 1).nonzero()[:, 0].tolist()
207 |         # frag_indices = (data.linker_mask == 0).nonzero()[:, 0].tolist()
208 |         # anchor_indices = [j for i, j in zip(*data.bond_index.tolist()) if i in link_indices and j in frag_indices]
209 |         # data.anchor_indices = anchor_indices
210 |         cand_anchors_mask = torch.zeros_like(data.fragment_mask).bool()
211 |         if self.mode == 'exact':
212 |             cand_anchors_mask[data.anchor_indices] = True
213 |             data.cand_anchors_mask = cand_anchors_mask
214 |
215 |         elif self.mode == 'k-hop':
216 |             # data.nbh_list = {i.item(): [j.item() for k, j in enumerate(data.bond_index[1])
217 |             #                             if data.bond_index[0, k].item() == i] for i in data.bond_index[0]}
218 |             # all_cand = []
219 |             for anchor in data.anchor_indices:
220 |                 a_frag_id = data.fragment_mask[anchor]
221 |                 a_valid_list = (data.fragment_mask == a_frag_id).nonzero(as_tuple=True)[0].tolist()
222 |                 a_cand = self.bfs(data.nbh_list, anchor, k=self.k, valid_list=a_valid_list)
223 |                 a_cand = [a for a in a_cand if data.frag_mol.GetAtomWithIdx(a).GetTotalNumHs() > 0]
224 |                 cand_anchors_mask[a_cand] = True
225 |                 # all_cand.append(a_cand)
226 |             data.cand_anchors_mask = cand_anchors_mask
227 |         else:
228 |             raise ValueError(self.mode)
229 |         return data
230 |
231 |
232 | class StackFragLocalPos(BaseTransform):
233 |     def __init__(self, max_num_atoms=30):
234 |         super().__init__()
235 |         self.max_num_atoms = max_num_atoms
236 |
237 |     def __call__(self, data):
238 |         frag_idx_mask = data.fragment_mask[data.fragment_mask > 0]
239 |         f1_pos = data.frags_local_pos[frag_idx_mask == 1]
240 |         f2_pos = data.frags_local_pos[frag_idx_mask == 2]
241 |         assert len(f1_pos) <= self.max_num_atoms
242 |         assert len(f2_pos) <= self.max_num_atoms
243 |         # todo: use F.pad
244 |         f1_fill_pos = torch.cat([f1_pos, torch.zeros(self.max_num_atoms - len(f1_pos), 3)], dim=0)
245 |         f1_mask = torch.cat([torch.ones(len(f1_pos)), torch.zeros(self.max_num_atoms - len(f1_pos))], dim=0)
246 |         f2_fill_pos = torch.cat([f2_pos, torch.zeros(self.max_num_atoms - len(f2_pos), 3)], dim=0)
247 |         f2_mask = torch.cat([torch.ones(len(f2_pos)), torch.zeros(self.max_num_atoms - len(f2_pos))], dim=0)
248 |         data.frags_local_pos_filled = torch.stack([f1_fill_pos, f2_fill_pos], dim=0)
249 |         data.frags_local_pos_mask = torch.stack([f1_mask, f2_mask], dim=0).bool()
250 |         return data
251 |
252 |
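The `todo: use F.pad` above is straightforward to discharge; a behavior-preserving sketch of the padding helper it suggests (the helper name is hypothetical):

import torch
import torch.nn.functional as F

def pad_frag(pos, max_num_atoms):
    n = len(pos)
    fill_pos = F.pad(pos, (0, 0, 0, max_num_atoms - n))  # zero rows appended at the bottom
    mask = F.pad(torch.ones(n), (0, max_num_atoms - n))  # 1 = real atom, 0 = padding
    return fill_pos, mask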
253 | class RelativeGeometry(BaseTransform):
254 |     def __init__(self, mode):
255 |         super().__init__()
256 |         self.mode = mode
257 |
258 |     def __call__(self, data):
259 |         if self.mode == 'relative_pos_and_rot':
260 |             # randomly take the first / second fragment as the reference
261 |             idx = torch.randint(0, 2, [1])[0]
262 |             pos = (data.pos - data.frags_t[idx]) @ data.frags_R[idx]
263 |             frags_R = data.frags_R[idx].T @ data.frags_R
264 |             frags_t = (data.frags_t - data.frags_t[idx]) @ data.frags_R[idx]
265 |             # frags_d doesn't change
266 |             data.frags_rel_mask = torch.tensor([True, True])
267 |             data.frags_rel_mask[idx] = False  # the reference fragment will not have noise added later
268 |             data.frags_atom_rel_mask = data.fragment_mask == (2 - idx)
269 |
270 |         elif self.mode == 'two_pos_and_rot':
271 |             # still guarantees the midpoint of the two fragment centers stays at the origin
272 |             rand_rot = get_random_rot()
273 |             pos = data.pos @ rand_rot
274 |             frags_R = rand_rot.T @ data.frags_R
275 |             frags_t = data.frags_t @ rand_rot
276 |             data.frags_rel_mask = torch.tensor([True, True])
277 |
278 |         elif self.mode == 'distance_and_two_rot_aug':
279 |             # only the first row of frags_R unchanged
280 |             rand_rot = get_random_rot()
281 |             tmp_pos = data.pos @ rand_rot
282 |             tmp_frags_R = rand_rot.T @ data.frags_R
283 |             tmp_frags_t = data.frags_t @ rand_rot
284 |
285 |             rot = rotation_matrix_from_vectors(tmp_frags_t[1] - tmp_frags_t[0], torch.tensor([1., 0., 0.]))
286 |             tr = -rot @ ((tmp_frags_t[0] + tmp_frags_t[1]) / 2)
287 |             pos = tmp_pos @ rot.T + tr
288 |             frags_R = rot @ tmp_frags_R
289 |             frags_t = tmp_frags_t @ rot.T + tr
290 |             data.frags_rel_mask = torch.tensor([True, True])
291 |
292 |         elif self.mode == 'distance_and_two_rot':
293 |             # unchanged
294 |             frags_R = data.frags_R
295 |             frags_t = data.frags_t
296 |             pos = data.pos
297 |             data.frags_rel_mask = torch.tensor([True, True])
298 |
299 |         else:
300 |             raise ValueError(self.mode)
301 |
302 |         data.frags_R = frags_R
303 |         data.frags_t = frags_t
304 |         data.pos = pos
305 |         # print('frags_R: ', data.frags_R, 'frags_t: ', frags_t)
306 |         return data
307 |
308 |
309 | def get_random_rot():
310 |     M = np.random.randn(3, 3)
311 |     Q, __ = np.linalg.qr(M)
312 |     # QR may return det(Q) = -1 (a reflection); flip the sign so Q is a proper rotation
313 |     rand_rot = torch.from_numpy((Q * np.sign(np.linalg.det(Q))).astype(np.float32))
314 |     return rand_rot
315 |
316 |
317 | class ReplaceLocalFrame(BaseTransform):
318 |     def __init__(self):
319 |         super().__init__()
320 |
321 |     def __call__(self, data):
322 |         frags_R1 = get_random_rot()
323 |         frags_R2 = get_random_rot()
324 |         f1_local_pos = global_to_local(frags_R1, data.frags_t[0], data.pos[data.fragment_mask == 1])
325 |         f2_local_pos = global_to_local(frags_R2, data.frags_t[1], data.pos[data.fragment_mask == 2])
326 |         data.frags_local_pos = torch.cat([f1_local_pos, f2_local_pos], dim=0)
327 |         return data
328 |
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import py3Dmol
2 | import os
3 | import copy
4 | from rdkit import Chem
5 | from rdkit.Chem import Draw
6 |
7 |
8 | def visualize_complex(pdb_block, sdf_block, show_protein_surface=True, show_ligand=True, show_ligand_surface=True):
9 |     view = py3Dmol.view()
10 |
11 |     # Add protein to the canvas
12 |     view.addModel(pdb_block, 'pdb')
13 |     if show_protein_surface:
14 |         view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'white'}, {'model': -1})
15 |     else:
16 |         view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
17 |     view.setStyle({'model': -1}, {"cartoon": {"style": "edged", 'opacity': 0}})
18 |
19 |     # Add ligand to the canvas
20 |     if show_ligand:
21 |         view.addModel(sdf_block, 'sdf')
22 |         view.setStyle({'model': -1}, {'stick': {}})
23 |         # view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
24 |         if show_ligand_surface:
25 |             view.addSurface(py3Dmol.VDW, {'opacity': 0.8}, {'model': -1})
26 |
27 |     view.zoomTo()
28 |     return view
29 |
30 |
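A typical notebook call for the function above (a minimal sketch; the file paths are hypothetical). py3Dmol views render inline when they are a Jupyter cell's last value:

from rdkit import Chem

with open('data/protac/protein.pdb') as f:           # hypothetical path
    pdb_block = f.read()
mol = Chem.MolFromMolFile('data/protac/ligand.sdf')  # hypothetical path
view = visualize_complex(pdb_block, Chem.MolToMolBlock(mol), show_ligand_surface=False)
view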
31 | def visualize_complex_with_frags(pdb_block, all_frags, show_protein_surface=True, show_ligand=True, show_ligand_surface=True):
32 |     view = py3Dmol.view()
33 |
34 |     # Add protein to the canvas
35 |     view.addModel(pdb_block, 'pdb')
36 |     if show_protein_surface:
37 |         view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'white'}, {'model': -1})
38 |     else:
39 |         view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
40 |     view.setStyle({'model': -1}, {"cartoon": {"style": "edged", 'opacity': 0}})
41 |
42 |     # Add ligand to the canvas
43 |     if show_ligand:
44 |         for frag in all_frags:
45 |             sdf_block = Chem.MolToMolBlock(frag)
46 |             view.addModel(sdf_block, 'sdf')
47 |             view.setStyle({'model': -1}, {'stick': {}})
48 |             # view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
49 |             if show_ligand_surface:
50 |                 view.addSurface(py3Dmol.VDW, {'opacity': 0.8}, {'model': -1})
51 |
52 |     view.zoomTo()
53 |     return view
54 |
55 | def visualize_complex_highlight_pocket(pdb_block, sdf_block,
56 |                                        pocket_atom_idx, pocket_res_idx=None, pocket_chain=None,
57 |                                        show_ligand=True, show_ligand_surface=True):
58 |     view = py3Dmol.view()
59 |
60 |     # Add protein to the canvas
61 |     view.addModel(pdb_block, 'pdb')
62 |     view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'white'}, {'model': -1})
63 |     if pocket_atom_idx:
64 |         view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'red'}, {'model': -1, 'serial': pocket_atom_idx})
65 |     if pocket_res_idx:
66 |         view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'red'},
67 |                         {'model': -1, 'chain': pocket_chain, 'resi': list(set(pocket_res_idx))})
68 |     # color_map = ['red', 'yellow', 'blue', 'green']
69 |     # for idx, pocket_atom_idx in enumerate(all_pocket_atom_idx):
70 |     #     print(pocket_atom_idx)
71 |     #     view.addSurface(py3Dmol.VDW, {'opacity':0.7, 'color':color_map[idx]}, {'model': -1, 'serial': pocket_atom_idx})
72 |     # view.addSurface(py3Dmol.VDW, {'opacity':0.7,'color':'red'}, {'model': -1, 'resi': list(set(pocket_residue))})
73 |
74 |     # view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
75 |     view.setStyle({'model': -1}, {"cartoon": {"style": "edged", 'opacity': 0}})
76 |     # view.setStyle({'model': -1, 'serial': atom_idx}, {'cartoon': {'color': 'red'}})
77 |     # view.setStyle({'model': -1, 'resi': [482, 484]}, {'cartoon': {'color': 'green'}})
78 |
79 |     # Add ligand to the canvas
80 |     if show_ligand:
81 |         view.addModel(sdf_block, 'sdf')
82 |         view.setStyle({'model': -1}, {'stick': {}})
83 |         # view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
84 |         if show_ligand_surface:
85 |             view.addSurface(py3Dmol.VDW, {'opacity': 0.8}, {'model': -1})
86 |
87 |     view.zoomTo()
88 |     return view
89 |
90 |
91 | def visualize_mol_highlight_fragments(mol, match_list):
92 |     all_target_atm = []
93 |     for match in match_list:
94 |         target_atm = []
95 |         for atom in mol.GetAtoms():
96 |             if atom.GetIdx() in match:
97 |                 target_atm.append(atom.GetIdx())
98 |         all_target_atm.append(target_atm)
99 |
100 |     return Draw.MolsToGridImage([mol for _ in range(len(match_list))], highlightAtomLists=all_target_atm,
101 |                                 subImgSize=(400, 400), molsPerRow=4)
102 |
103 |
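The generated-molecule viewers below serialize coordinates into a plain XYZ block before handing them to py3Dmol: the atom count, a blank comment line, then one "symbol x y z" row per atom. For a hypothetical two-atom input, the string built by the loop would be:

# xyz block for a hypothetical C-O fragment:
xyz = "2\n\nC 0.00000000 0.00000000 0.00000000\nO 1.20000000 0.00000000 0.00000000\n"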
104 | def visualize_generated_xyz_v2(atom_pos, atom_type, protein_path, ligand_path=None, show_ligand=False, show_protein_surface=True):
105 |     ptable = Chem.GetPeriodicTable()
106 |
107 |     num_atoms = len(atom_pos)
108 |     xyz = "%d\n\n" % (num_atoms,)
109 |     for i in range(num_atoms):
110 |         symb = ptable.GetElementSymbol(atom_type[i])
111 |         x, y, z = atom_pos[i]
112 |         xyz += "%s %.8f %.8f %.8f\n" % (symb, x, y, z)
113 |
114 |     # print(xyz)
115 |
116 |     with open(protein_path, 'r') as f:
117 |         pdb_block = f.read()
118 |
119 |     view = py3Dmol.view()
120 |     # Generated molecule
121 |     view.addModel(xyz, 'xyz')
122 |     view.setStyle({'model': -1}, {'sphere': {'radius': 0.3}, 'stick': {}})
123 |
124 |     # Pocket
125 |     view.addModel(pdb_block, 'pdb')
126 |     if show_protein_surface:
127 |         view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'white'}, {'model': -1})
128 |     else:
129 |         view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
130 |     view.setStyle({'model': -1}, {"cartoon": {"style": "edged", 'opacity': 0}})
131 |
132 |     # Focus on the generated
133 |     view.zoomTo()
134 |
135 |     # Ligand
136 |     if show_ligand:
137 |         with open(ligand_path, 'r') as f:
138 |             sdf_block = f.read()
139 |         view.addModel(sdf_block, 'sdf')
140 |         view.setStyle({'model': -1}, {'stick': {}})
141 |
142 |     return view
143 |
144 |
145 | def visualize_generated_xyz(data, root, show_ligand=False):
146 |     ptable = Chem.GetPeriodicTable()
147 |
148 |     num_atoms = data.ligand_context_element.size(0)
149 |     xyz = "%d\n\n" % (num_atoms,)
150 |     for i in range(num_atoms):
151 |         symb = ptable.GetElementSymbol(data.ligand_context_element[i].item())
152 |         x, y, z = data.ligand_context_pos[i].clone().cpu().tolist()
153 |         xyz += "%s %.8f %.8f %.8f\n" % (symb, x, y, z)
154 |
155 |     # print(xyz)
156 |
157 |     protein_path = os.path.join(root, data.protein_filename)
158 |     ligand_path = os.path.join(root, data.ligand_filename)
159 |
160 |     with open(protein_path, 'r') as f:
161 |         pdb_block = f.read()
162 |     with open(ligand_path, 'r') as f:
163 |         sdf_block = f.read()
164 |
165 |     view = py3Dmol.view()
166 |     # Generated molecule
167 |     view.addModel(xyz, 'xyz')
168 |     view.setStyle({'model': -1}, {'sphere': {'radius': 0.3}, 'stick': {}})
169 |     # Focus on the generated
170 |     view.zoomTo()
171 |
172 |     # Pocket
173 |     view.addModel(pdb_block, 'pdb')
174 |     view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
175 |     # Ligand
176 |     if show_ligand:
177 |         view.addModel(sdf_block, 'sdf')
178 |         view.setStyle({'model': -1}, {'stick': {}})
179 |
180 |     return view
181 |
182 |
183 | def visualize_generated_sdf(data, protein_path, ligand_path, show_ligand=False, show_protein_surface=True):
184 |     # protein_path = os.path.join(root, data.protein_filename)
185 |     # ligand_path = os.path.join(root, data.ligand_filename)
186 |
187 |     with open(protein_path, 'r') as f:
188 |         pdb_block = f.read()
189 |
190 |     view = py3Dmol.view()
191 |     # Generated molecule
192 |     mol_block = Chem.MolToMolBlock(data.rdmol)
193 |     view.addModel(mol_block, 'sdf')
194 |     view.setStyle({'model': -1}, {'sphere': {'radius': 0.3}, 'stick': {}})
195 |     # Focus on the generated
196 |     # view.zoomTo()
197 |
198 |     # Pocket
199 |     view.addModel(pdb_block, 'pdb')
200 |     if show_protein_surface:
201 |         view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'white'}, {'model': -1})
202 |         view.setStyle({'model': -1}, {"cartoon": {"style": "edged", 'opacity': 0}})
203 |     else:
204 |         view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
205 |     # Ligand
206 |     if show_ligand:
207 |         with open(ligand_path, 'r') as f:
208 |             sdf_block = f.read()
209 |         view.addModel(sdf_block, 'sdf')
210 |         view.setStyle({'model': -1}, {'stick': {}})
211 |     view.zoomTo()
212 |     return view
213 |
214 |
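`visualize_generated_sdf` above only touches the `rdmol` attribute of its `data` argument, so any lightweight object works; a minimal sketch with hypothetical file paths:

from types import SimpleNamespace
from rdkit import Chem

sample = SimpleNamespace(rdmol=Chem.MolFromMolFile('outputs/sample_0.sdf'))  # hypothetical
view = visualize_generated_sdf(sample, 'data/protein.pdb', 'data/ligand.sdf', show_ligand=True)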
215 | def visualize_generated_arms(data_list, protein_path, ligand_path, show_ligand=False, show_protein_surface=True):
216 |     # protein_path = os.path.join(root, data.protein_filename)
217 |     # ligand_path = os.path.join(root, data.ligand_filename)
218 |
219 |     with open(protein_path, 'r') as f:
220 |         pdb_block = f.read()
221 |
222 |     view = py3Dmol.view()
223 |     # Generated molecule
224 |     for data in data_list:
225 |         mol_block = Chem.MolToMolBlock(data.rdmol)
226 |         view.addModel(mol_block, 'sdf')
227 |         view.setStyle({'model': -1}, {'sphere': {'radius': 0.3}, 'stick': {}})
228 |     # Focus on the generated
229 |     # view.zoomTo()
230 |
231 |     # Pocket
232 |     view.addModel(pdb_block, 'pdb')
233 |     if show_protein_surface:
234 |         view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'white'}, {'model': -1})
235 |         view.setStyle({'model': -1}, {"cartoon": {"style": "edged", 'opacity': 0}})
236 |     else:
237 |         view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
238 |     # Ligand
239 |     if show_ligand:
240 |         with open(ligand_path, 'r') as f:
241 |             sdf_block = f.read()
242 |         view.addModel(sdf_block, 'sdf')
243 |         view.setStyle({'model': -1}, {'stick': {}})
244 |     view.zoomTo()
245 |     return view
246 |
247 |
248 | def visualize_ligand(mol, size=(300, 300), style="stick", surface=False, opacity=0.5, viewer=None):
249 |     """Draw molecule in 3D
250 |
251 |     Args:
252 |     ----
253 |         mol: rdMol, molecule to show
254 |         size: tuple(int, int), canvas size
255 |         style: str, type of drawing molecule
256 |                style can be 'line', 'stick', 'sphere', 'cartoon'
257 |         surface, bool, display SAS
258 |         opacity, float, opacity of surface, range 0.0-1.0
259 |     Return:
260 |     ----
261 |         viewer: py3Dmol.view, a class for constructing embedded 3Dmol.js views in ipython notebooks.
262 |     """
263 |     assert style in ('line', 'stick', 'sphere', 'cartoon')
264 |     if viewer is None:
265 |         viewer = py3Dmol.view(width=size[0], height=size[1])
266 |     if isinstance(mol, list):
267 |         for i, m in enumerate(mol):
268 |             mblock = Chem.MolToMolBlock(m)
269 |             viewer.addModel(mblock, 'mol' + str(i))
270 |     elif len(mol.GetConformers()) > 1:
271 |         for i in range(len(mol.GetConformers())):
272 |             mblock = Chem.MolToMolBlock(mol, confId=i)
273 |             viewer.addModel(mblock, 'mol' + str(i))
274 |     else:
275 |         mblock = Chem.MolToMolBlock(mol)
276 |         viewer.addModel(mblock, 'mol')
277 |     viewer.setStyle({style: {}})
278 |     if surface:
279 |         viewer.addSurface(py3Dmol.SAS, {'opacity': opacity})
280 |     viewer.zoomTo()
281 |     return viewer
282 |
283 |
284 | def visualize_full_mol(frags_mol, linker_pos, linker_type):
285 |     ptable = Chem.GetPeriodicTable()
286 |
287 |     num_atoms = len(linker_pos)
288 |     xyz = "%d\n\n" % (num_atoms,)
289 |     for i in range(num_atoms):
290 |         symb = ptable.GetElementSymbol(linker_type[i])
291 |         x, y, z = linker_pos[i]
292 |         xyz += "%s %.8f %.8f %.8f\n" % (symb, x, y, z)
293 |
294 |     view = py3Dmol.view()
295 |     # Generated molecule
296 |     view.addModel(xyz, 'xyz')
297 |     view.setStyle({'model': -1}, {'sphere': {'radius': 0.3}, 'stick': {}})
298 |
299 |     mblock = Chem.MolToMolBlock(frags_mol)
300 |     view.addModel(mblock, 'sdf')
301 |     view.setStyle({'model': -1}, {'stick': {}})
302 |     view.zoomTo()
303 |     return view
304 |
305 |
306 | def mol_with_atom_index(mol):
307 |     mol = copy.deepcopy(mol)
308 |     mol.RemoveAllConformers()
309 |     for atom in mol.GetAtoms():
310 |         atom.SetAtomMapNum(atom.GetIdx())
311 |     return mol
312 |
--------------------------------------------------------------------------------
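A closing usage sketch for `mol_with_atom_index` (illustrative, not repository code): labeling atoms with their indices makes masks such as `anchor_indices` or `fragment_mask` much easier to read off a 2D depiction.

from rdkit import Chem
from rdkit.Chem import Draw

mol = Chem.MolFromSmiles('CCOC(=O)C')      # any molecule
labeled = mol_with_atom_index(mol)
Draw.MolToImage(labeled, size=(400, 400))  # atoms now carry map numbers 0, 1, 2, ...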