├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── overview.png
├── configs
│   ├── sampling
│   │   ├── protac
│   │   │   ├── protac_all_gui.yml
│   │   │   ├── protac_anchor.yml
│   │   │   ├── protac_dist.yml
│   │   │   └── protac_no_gui.yml
│   │   └── zinc.yml
│   └── training
│       ├── nn_variants
│       │   ├── pos_eps_rot_eps.yml
│       │   ├── pos_eps_rot_euler.yml
│       │   └── pos_newton_rot_eps.yml
│       ├── warhead_protac.yml
│       └── zinc.yml
├── data
│   ├── .gitignore
│   └── protac
│       ├── 3d_index.pkl
│       ├── e3_ligand.csv
│       ├── index.pkl
│       ├── linker.csv
│       ├── protac.csv
│       ├── smi_protac.txt
│       └── warhead.csv
├── datasets
│   ├── __init__.py
│   ├── linker_data.py
│   └── linker_dataset.py
├── models
│   ├── common.py
│   ├── diff_protac_bond.py
│   ├── encoders
│   │   ├── __init__.py
│   │   └── node_edge_net.py
│   ├── eps_net.py
│   └── transition.py
├── playground
│   └── check_data.ipynb
├── scripts
│   ├── baselines
│   │   ├── eval_3dlinker.py
│   │   ├── eval_delinker.py
│   │   └── eval_difflinker.py
│   ├── eval_protac.py
│   ├── prepare_data.py
│   ├── sample_protac.py
│   └── train_protac.py
└── utils
    ├── calc_SC_RDKit.py
    ├── const.py
    ├── data.py
    ├── eval_bond.py
    ├── evaluation.py
    ├── fpscores.pkl.gz
    ├── frag_utils.py
    ├── geometry.py
    ├── guidance_funcs.py
    ├── misc.py
    ├── prior_num_atoms.py
    ├── reconstruct_linker.py
    ├── sascorer.py
    ├── so3.py
    ├── train.py
    ├── train_linker_smiles.pkl
    ├── transforms.py
    ├── visualize.py
    └── wehi_pains.csv
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # IDE
132 | .idea/
133 |
134 | # OS FILE
135 | .DS_Store
136 | test_*
137 | ckpts
138 | output*
139 | logs/
140 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Jiaqi Guan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LinkerNet: Fragment Poses and Linker Co-Design with 3D Equivariant Diffusion
2 |
3 | [License: MIT](https://github.com/guanjq/targetdiff/blob/main/LICIENCE)
4 |
5 |
6 | This repository is the official implementation of LinkerNet: Fragment Poses and Linker Co-Design with 3D Equivariant Diffusion (NeurIPS 2023). [[PDF]](https://openreview.net/forum?id=6EaLIw3W7c)
7 |
8 |
9 | ![overview](assets/overview.png)
10 |
11 |
12 | ## Installation
13 |
14 | ### Dependency
15 |
16 | The code has been tested in the following environment:
17 |
18 |
19 | | Package | Version |
20 | |-------------------|-----------|
21 | | Python | 3.8 |
22 | | PyTorch | 1.13.1 |
23 | | CUDA | 11.6 |
24 | | PyTorch Geometric | 2.2.0 |
25 | | RDKit | 2022.03.2 |
26 |
27 | ### Install via Conda and Pip
28 | ```bash
29 | conda create -n targetdiff python=3.8
30 | conda activate targetdiff
31 | conda install pytorch pytorch-cuda=11.6 -c pytorch -c nvidia
32 | conda install pyg -c pyg
33 | conda install rdkit openbabel tensorboard pyyaml easydict python-lmdb -c conda-forge
34 | ```
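To verify the environment, a quick sanity check (a minimal sketch; it only assumes the standard import names of the packages listed above):

```python
# Check that the tested versions from the table above are importable.
import torch
import torch_geometric
import rdkit

print('PyTorch:', torch.__version__)            # tested with 1.13.1
print('CUDA available:', torch.cuda.is_available())
print('PyG:', torch_geometric.__version__)      # tested with 2.2.0
print('RDKit:', rdkit.__version__)              # tested with 2022.03.2
```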
35 |
36 | ---
37 | ## Data Preprocess
38 |
39 | ### PROTAC-DB
40 |
41 | We have provided all data files related to the PROTAC-DB dataset in this repo.
42 | * The raw data (the .csv files in the data/protac folder) are downloaded from [PROTAC-DB](http://cadd.zju.edu.cn/protacdb/).
43 | * The index.pkl file is produced by playground/check_data.ipynb.
44 | * The 3d_index.pkl file contains the conformations generated by RDKit; it is obtained by running the following command:
45 |
46 | ```bash
47 | python scripts/prepare_data.py --raw_path data/protac/index.pkl --dest data/protac/3d_index.pkl
48 | ```
49 |
50 | Note that the RDKit version may influence the PROTAC-DB dataset processing and splitting. We have provided the processed data and split file [here](https://drive.google.com/drive/folders/1Nt37DO1PYwPNM0_uF2Zzz4QxWD3v8pF6?usp=drive_link).
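Both index files are plain pickles, so they can be inspected directly; a minimal sketch (the structure of the stored entries is not documented here, so check playground/check_data.ipynb before relying on any particular layout):

```python
import pickle

# Peek at the PROTAC index produced by check_data.ipynb / prepare_data.py.
with open('data/protac/index.pkl', 'rb') as f:
    index = pickle.load(f)
print(type(index))
```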
51 |
52 | ### ZINC
53 |
54 | The raw ZINC data are the same as in [DiffLinker](https://zenodo.org/records/7121271).
55 | We preprocess the ZINC data to produce an index file by running:
56 | ```bash
57 | python scripts/prepare_data.py \
58 | --raw_path data/zinc_difflinker \
59 | --dest data/zinc_difflinker/index_full.pkl \
60 | --dataset zinc_difflinker --mode full
61 | ```
62 | We have also provided the preprocessed index file [here](https://drive.google.com/drive/folders/1C1srELCCNJLk8v1smjvmbE-xYvnog5jU?usp=sharing).
63 |
64 | ---
65 | ## Training
66 |     python scripts/train_protac.py configs/training/zinc.yml
67 |
68 | We have provided the [pretrained checkpoints](https://drive.google.com/drive/folders/1C1srELCCNJLk8v1smjvmbE-xYvnog5jU?usp=sharing) on ZINC / PROTAC.
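To fine-tune on PROTAC-DB starting from the ZINC checkpoint, the provided `configs/training/warhead_protac.yml` already sets `train.ckpt_path: ckpts/zinc_model.pt`, so (assuming the checkpoint has been placed under `ckpts/`):

```bash
python scripts/train_protac.py configs/training/warhead_protac.yml
```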
69 |
70 | ## Sampling
71 |     python scripts/sample_protac.py configs/sampling/zinc.yml --subset test --start_id 0 --end_id -1 --num_samples 250 --outdir outputs/zinc
72 |
73 | We have also provided the sampling results at the same link.
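For PROTAC sampling, the configs under `configs/sampling/protac/` switch the guidance terms on or off. For example, with both guidance terms enabled (the flags here mirror the ZINC command above and are assumed to apply unchanged):

```bash
python scripts/sample_protac.py configs/sampling/protac/protac_all_gui.yml --subset test --start_id 0 --end_id -1 --num_samples 100 --outdir outputs/protac
```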
74 |
75 |
76 | ## Evaluation
77 |     python scripts/eval_protac.py {SAMPLE_RESULT_DIR}
78 |
79 |
80 | ## Citation
81 | ```
82 | @inproceedings{guan2023linkernet,
83 | title={LinkerNet: Fragment Poses and Linker Co-Design with 3D Equivariant Diffusion},
84 | author={Guan, Jiaqi and Peng, Xingang and Jiang, PeiQi and Luo, Yunan and Peng, Jian and Ma, Jianzhu},
85 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
86 | year={2023}
87 | }
88 | ```
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanjq/LinkerNet/280759c16ccecece0d81ab9cebe3f44041b80e51/assets/overview.png
--------------------------------------------------------------------------------
/configs/sampling/protac/protac_all_gui.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name: protac
3 | path: ./data/protac
4 | version: v1
5 | split_mode: warhead # [warhead, ligase, random]
6 | index_name: 3d_index.pkl
7 | max_num_atoms: 70
8 |
9 | model:
10 | checkpoint: ckpts/protac_model.pt
11 |
12 | sample:
13 | seed: 2022
14 | num_samples: 100
15 | num_atoms: prior # [ref, prior]
16 | cand_bond_mask: True
17 | guidance_opt:
18 | - type: anchor_prox
19 | update: frag_rot
20 | min_d: 1.2
21 | max_d: 1.9
22 | decay: False
23 | - type: frag_distance
24 | mode: frag_center_distance
25 | constraint_mode: dynamic # [dynamic, const]
26 | sigma: 0.2
27 | min_d:
28 | max_d:
29 |
--------------------------------------------------------------------------------
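The two `guidance_opt` entries above are energy terms whose gradients steer sampling; they are implemented in utils/guidance_funcs.py, which is not included in this dump. As an illustration only (the names and the exact functional form are assumptions, not the repo's code), a `frag_distance`-style penalty on the fragment-center distance could look like:

```python
import torch
from torch_scatter import scatter_mean


def frag_center_distance_energy(pos, frag1_mask, frag2_mask, batch, min_d, max_d):
    """Quadratic penalty when the fragment-center distance leaves [min_d, max_d]."""
    c1 = scatter_mean(pos[frag1_mask], batch[frag1_mask], dim=0)  # (G, 3) centers
    c2 = scatter_mean(pos[frag2_mask], batch[frag2_mask], dim=0)  # (G, 3)
    d = torch.norm(c1 - c2, dim=-1)                               # (G,)
    # zero inside the allowed band, quadratic outside it
    return ((d - d.clamp(min=min_d, max=max_d)) ** 2).sum()

# During sampling, the guidance drift would be the negative gradient of this
# energy w.r.t. pos, e.g. obtained via torch.autograd.grad(energy, pos).
```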
/configs/sampling/protac/protac_anchor.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name: protac
3 | path: ./data/protac
4 | version: v1
5 | split_mode: warhead # [warhead, ligase, random]
6 | index_name: 3d_index.pkl
7 | max_num_atoms: 70
8 |
9 | model:
10 | checkpoint: ckpts/protac_model.pt
11 |
12 | sample:
13 | seed: 2022
14 | num_samples: 100
15 | num_atoms: prior # [ref, prior]
16 | cand_bond_mask: True
17 | guidance_opt:
18 | - type: anchor_prox
19 | update: frag_rot
20 | min_d: 1.2
21 | max_d: 1.9
22 | decay: False
23 | # - type: frag_distance
24 | # mode: frag_center_distance
25 | # constraint_mode: dynamic # [dynamic, const]
26 | # sigma: 0.2
27 | # min_d:
28 | # max_d:
29 |
--------------------------------------------------------------------------------
/configs/sampling/protac/protac_dist.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name: protac
3 | path: ./data/protac
4 | version: v1
5 | split_mode: warhead # [warhead, ligase, random]
6 | index_name: 3d_index.pkl
7 | max_num_atoms: 70
8 |
9 | model:
10 | checkpoint: ckpts/protac_model.pt
11 |
12 | sample:
13 | seed: 2022
14 | num_samples: 100
15 | num_atoms: prior # [ref, prior]
16 | cand_bond_mask: True
17 | guidance_opt:
18 | # - type: anchor_prox
19 | # update: frag_rot
20 | # min_d: 1.2
21 | # max_d: 1.9
22 | # decay: False
23 | - type: frag_distance
24 | mode: frag_center_distance
25 | constraint_mode: dynamic # [dynamic, const]
26 | sigma: 0.2
27 | min_d:
28 | max_d:
29 |
--------------------------------------------------------------------------------
/configs/sampling/protac/protac_no_gui.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name: protac
3 | path: ./data/protac
4 | version: v1
5 | split_mode: warhead # [warhead, ligase, random]
6 | index_name: 3d_index.pkl
7 | max_num_atoms: 70
8 |
9 | model:
10 | checkpoint: ckpts/protac_model.pt
11 |
12 | sample:
13 | seed: 2022
14 | num_samples: 100
15 | num_atoms: prior # [ref, prior]
16 | cand_bond_mask: True
17 | guidance_opt:
18 | # - type: anchor_prox
19 | # update: frag_rot
20 | # min_d: 1.2
21 | # max_d: 1.9
22 | # decay: False
23 | # - type: frag_distance
24 | # mode: frag_center_distance
25 | # constraint_mode: dynamic # [dynamic, const]
26 | # sigma: 0.2
27 | # min_d:
28 | # max_d:
29 |
--------------------------------------------------------------------------------
/configs/sampling/zinc.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name: zinc
3 | path: ./data/zinc_difflinker
4 | version: full
5 | index_name: index_full.pkl
6 | max_num_atoms: 30
7 |
8 | model:
9 | checkpoint: ckpts/zinc_model.pt
10 |
11 | sample:
12 | seed: 2022
13 | num_samples: 100
14 | num_atoms: ref
15 | guidance_opt:
16 | # num_steps: 100
17 | # energy_drift:
18 | # - type: frag_link_prox
19 | # min_d: 1.2
20 | # max_d: 1.9
21 |
--------------------------------------------------------------------------------
/configs/training/nn_variants/pos_eps_rot_eps.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name: zinc
3 | path: ./data/zinc_difflinker
4 | version: full # [tiny, full]
5 | index_name: index_full.pkl
6 |
7 | model:
8 | num_steps: 500
9 | node_emb_dim: 256
10 | edge_emb_dim: 64
11 | time_emb_type: plain
12 | time_emb_dim: 1
13 | time_emb_scale: 1000
14 |
15 | train_frag_rot: True
16 | train_frag_pos: True
17 | train_link: True
18 | train_bond: True
19 |
20 | frag_pos_prior:
21 | known_anchor: False
22 | known_linker_bond: False
23 | rel_geometry: two_pos_and_rot
24 |
25 | diffusion:
26 | trans_rot_opt:
27 | sche_type: cosine
28 | s: 0.01
29 | trans_pos_opt:
30 | sche_type: cosine
31 | s: 0.01
32 | trans_link_pos_opt:
33 | sche_type: cosine
34 | s: 0.01
35 | trans_link_cls_opt:
36 | sche_type: cosine
37 | s: 0.01
38 | trans_link_bond_opt:
39 | sche_type: cosine
40 | s: 0.01
41 |
42 | eps_net:
43 | net_type: node_edge_net
44 | encoder:
45 | num_blocks: 6
46 | cutoff: 15.
47 | use_gate: True
48 | num_gaussians: 20
49 | expansion_mode: exp
50 | tr_output_type: invariant_eps
51 | rot_output_type: invariant_eps
52 | output_n_heads: 8
53 | separate_att: True
54 | sym_force: False
55 |
56 | train:
57 | seed: 2023
58 | loss_weights:
59 | frag_rot: 1.0
60 | frag_pos: 1.0
61 | link_pos: 1.0
62 | link_cls: 100.0
63 | link_bond: 100.0
64 |
65 | batch_size: 64
66 | num_workers: 8
67 | n_acc_batch: 1
68 | max_iters: 500000
69 | val_freq: 2000
70 | pos_noise_std: 0.05
71 | max_grad_norm: 50.0
72 | optimizer:
73 | type: adamw
74 | lr: 5.e-4
75 | weight_decay: 1.e-8
76 | beta1: 0.99
77 | beta2: 0.999
78 | scheduler:
79 | type: plateau
80 | factor: 0.6
81 | patience: 10
82 | min_lr: 1.e-6
--------------------------------------------------------------------------------
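All `diffusion` transitions above use `sche_type: cosine` with `s: 0.01`. The authoritative implementation is models/transition.py (not included in this dump); for reference, a sketch of the standard cosine noise schedule (Nichol & Dhariwal, 2021) that these options most likely correspond to:

```python
import numpy as np


def cosine_beta_schedule(num_steps: int = 500, s: float = 0.01) -> np.ndarray:
    # alpha_bar(t) = f(t) / f(0) with f(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2)
    t = np.arange(num_steps + 1)
    f = np.cos(((t / num_steps) + s) / (1 + s) * np.pi / 2) ** 2
    alpha_bar = f / f[0]
    betas = 1.0 - alpha_bar[1:] / alpha_bar[:-1]
    return np.clip(betas, 0.0, 0.999)


betas = cosine_beta_schedule()  # matches model.num_steps: 500 in these configs
```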
/configs/training/nn_variants/pos_eps_rot_euler.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name: zinc
3 | path: ./data/zinc_difflinker
4 | version: full # [tiny, full]
5 | index_name: index_full.pkl
6 |
7 | model:
8 | num_steps: 500
9 | node_emb_dim: 256
10 | edge_emb_dim: 64
11 | time_emb_type: plain
12 | time_emb_dim: 1
13 | time_emb_scale: 1000
14 |
15 | train_frag_rot: True
16 | train_frag_pos: True
17 | train_link: True
18 | train_bond: True
19 |
20 | frag_pos_prior:
21 | known_anchor: False
22 | known_linker_bond: False
23 | rel_geometry: two_pos_and_rot
24 |
25 | diffusion:
26 | trans_rot_opt:
27 | sche_type: cosine
28 | s: 0.01
29 | trans_pos_opt:
30 | sche_type: cosine
31 | s: 0.01
32 | trans_link_pos_opt:
33 | sche_type: cosine
34 | s: 0.01
35 | trans_link_cls_opt:
36 | sche_type: cosine
37 | s: 0.01
38 | trans_link_bond_opt:
39 | sche_type: cosine
40 | s: 0.01
41 |
42 | eps_net:
43 | net_type: node_edge_net
44 | encoder:
45 | num_blocks: 6
46 | cutoff: 15.
47 | use_gate: True
48 | num_gaussians: 20
49 | expansion_mode: exp
50 | tr_output_type: invariant_eps
51 | rot_output_type: euler_equation
52 | output_n_heads: 8
53 | separate_att: True
54 | sym_force: False
55 |
56 | train:
57 | seed: 2023
58 | loss_weights:
59 | frag_rot: 1.0
60 | frag_pos: 1.0
61 | link_pos: 1.0
62 | link_cls: 100.0
63 | link_bond: 100.0
64 |
65 | batch_size: 64
66 | num_workers: 8
67 | n_acc_batch: 1
68 | max_iters: 500000
69 | val_freq: 2000
70 | pos_noise_std: 0.05
71 | max_grad_norm: 50.0
72 | optimizer:
73 | type: adamw
74 | lr: 5.e-4
75 | weight_decay: 1.e-8
76 | beta1: 0.99
77 | beta2: 0.999
78 | scheduler:
79 | type: plateau
80 | factor: 0.6
81 | patience: 10
82 | min_lr: 1.e-6
--------------------------------------------------------------------------------
/configs/training/nn_variants/pos_newton_rot_eps.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name: zinc
3 | path: ./data/zinc_difflinker
4 | version: full # [tiny, full]
5 | index_name: index_full.pkl
6 |
7 | model:
8 | num_steps: 500
9 | node_emb_dim: 256
10 | edge_emb_dim: 64
11 | time_emb_type: plain
12 | time_emb_dim: 1
13 | time_emb_scale: 1000
14 |
15 | train_frag_rot: True
16 | train_frag_pos: True
17 | train_link: True
18 | train_bond: True
19 |
20 | frag_pos_prior:
21 | known_anchor: False
22 | known_linker_bond: False
23 | rel_geometry: two_pos_and_rot
24 |
25 | diffusion:
26 | trans_rot_opt:
27 | sche_type: cosine
28 | s: 0.01
29 | trans_pos_opt:
30 | sche_type: cosine
31 | s: 0.01
32 | trans_link_pos_opt:
33 | sche_type: cosine
34 | s: 0.01
35 | trans_link_cls_opt:
36 | sche_type: cosine
37 | s: 0.01
38 | trans_link_bond_opt:
39 | sche_type: cosine
40 | s: 0.01
41 |
42 | eps_net:
43 | net_type: node_edge_net
44 | encoder:
45 | num_blocks: 6
46 | cutoff: 15.
47 | use_gate: True
48 | num_gaussians: 20
49 | expansion_mode: exp
50 | tr_output_type: newton_equation
51 | rot_output_type: invariant_eps
52 | output_n_heads: 8
53 | separate_att: True
54 | sym_force: False
55 |
56 | train:
57 | seed: 2023
58 | loss_weights:
59 | frag_rot: 1.0
60 | frag_pos: 1.0
61 | link_pos: 1.0
62 | link_cls: 100.0
63 | link_bond: 100.0
64 |
65 | batch_size: 64
66 | num_workers: 8
67 | n_acc_batch: 1
68 | max_iters: 500000
69 | val_freq: 2000
70 | pos_noise_std: 0.05
71 | max_grad_norm: 50.0
72 | optimizer:
73 | type: adamw
74 | lr: 5.e-4
75 | weight_decay: 1.e-8
76 | beta1: 0.99
77 | beta2: 0.999
78 | scheduler:
79 | type: plateau
80 | factor: 0.6
81 | patience: 10
82 | min_lr: 1.e-6
--------------------------------------------------------------------------------
/configs/training/warhead_protac.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name: protac
3 | path: ./data/protac
4 | version: v1
5 | split_mode: warhead # [warhead, ligase, random]
6 | index_name: 3d_index.pkl
7 | max_num_atoms: 70
8 |
9 | model:
10 | num_steps: 500
11 | node_emb_dim: 256
12 | edge_emb_dim: 64
13 | time_emb_type: plain
14 | time_emb_dim: 1
15 | time_emb_scale: 1000
16 |
17 | train_frag_rot: True
18 | train_frag_pos: True
19 | train_link: True
20 | train_bond: True
21 |
22 | frag_pos_prior:
23 | known_anchor: False
24 | known_linker_bond: False
25 | rel_geometry: two_pos_and_rot
26 |
27 | diffusion:
28 | trans_rot_opt:
29 | sche_type: cosine
30 | s: 0.01
31 | trans_pos_opt:
32 | sche_type: cosine
33 | s: 0.01
34 | trans_link_pos_opt:
35 | sche_type: cosine
36 | s: 0.01
37 | trans_link_cls_opt:
38 | sche_type: cosine
39 | s: 0.01
40 | trans_link_bond_opt:
41 | sche_type: cosine
42 | s: 0.01
43 |
44 | eps_net:
45 | net_type: node_edge_net
46 | encoder:
47 | num_blocks: 6
48 | cutoff: 15.
49 | use_gate: True
50 | num_gaussians: 20
51 | expansion_mode: exp
52 | tr_output_type: newton_equation
53 | rot_output_type: euler_equation
54 | output_n_heads: 8
55 |
56 | train:
57 | ckpt_path: ckpts/zinc_model.pt
58 | seed: 2023
59 | loss_weights:
60 | frag_rot: 1.0
61 | frag_pos: 1.0
62 | link_pos: 1.0
63 | link_cls: 100.0
64 | link_bond: 100.0
65 |
66 | batch_size: 2
67 | num_workers: 8
68 | n_acc_batch: 1
69 | max_iters: 500000
70 | val_freq: 2000
71 | pos_noise_std: 0.05
72 | max_grad_norm: 50.0
73 | optimizer:
74 | type: adamw
75 | lr: 1.e-4
76 | weight_decay: 1.e-8
77 | beta1: 0.99
78 | beta2: 0.999
79 | scheduler:
80 | type: plateau
81 | factor: 0.6
82 | patience: 10
83 | min_lr: 1.e-6
--------------------------------------------------------------------------------
/configs/training/zinc.yml:
--------------------------------------------------------------------------------
1 | dataset:
2 | name: zinc
3 | path: ./data/zinc_difflinker
4 | version: full # [tiny, full]
5 | index_name: index_full.pkl
6 |
7 | model:
8 | num_steps: 500
9 | node_emb_dim: 256
10 | edge_emb_dim: 64
11 | time_emb_type: plain
12 | time_emb_dim: 1
13 | time_emb_scale: 1000
14 |
15 | train_frag_rot: True
16 | train_frag_pos: True
17 | train_link: True
18 | train_bond: True
19 |
20 | frag_pos_prior:
21 | known_anchor: False
22 | known_linker_bond: False
23 | rel_geometry: two_pos_and_rot
24 |
25 | diffusion:
26 | trans_rot_opt:
27 | sche_type: cosine
28 | s: 0.01
29 | trans_pos_opt:
30 | sche_type: cosine
31 | s: 0.01
32 | trans_link_pos_opt:
33 | sche_type: cosine
34 | s: 0.01
35 | trans_link_cls_opt:
36 | sche_type: cosine
37 | s: 0.01
38 | trans_link_bond_opt:
39 | sche_type: cosine
40 | s: 0.01
41 |
42 | eps_net:
43 | net_type: node_edge_net
44 | encoder:
45 | num_blocks: 6
46 | cutoff: 15.
47 | use_gate: True
48 | num_gaussians: 20
49 | expansion_mode: exp
50 | tr_output_type: newton_equation_outer
51 | rot_output_type: euler_equation_outer
52 | output_n_heads: 8
53 | separate_att: True
54 | sym_force: False
55 |
56 | train:
57 | seed: 2023
58 | loss_weights:
59 | frag_rot: 1.0
60 | frag_pos: 1.0
61 | link_pos: 1.0
62 | link_cls: 100.0
63 | link_bond: 100.0
64 |
65 | batch_size: 64
66 | num_workers: 8
67 | n_acc_batch: 1
68 | max_iters: 500000
69 | val_freq: 2000
70 | pos_noise_std: 0.05
71 | max_grad_norm: 50.0
72 | optimizer:
73 | type: adamw
74 | lr: 5.e-4
75 | weight_decay: 1.e-8
76 | beta1: 0.99
77 | beta2: 0.999
78 | scheduler:
79 | type: plateau
80 | factor: 0.6
81 | patience: 10
82 | min_lr: 1.e-6
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | #zinc_3dlinker/*.json
2 | #zinc_3dlinker/smi_train.txt
3 | #protac/protac.sdf
4 | zinc_difflinker/index_full.pkl
5 | !.gitignore
--------------------------------------------------------------------------------
/data/protac/3d_index.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanjq/LinkerNet/280759c16ccecece0d81ab9cebe3f44041b80e51/data/protac/3d_index.pkl
--------------------------------------------------------------------------------
/data/protac/index.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanjq/LinkerNet/280759c16ccecece0d81ab9cebe3f44041b80e51/data/protac/index.pkl
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanjq/LinkerNet/280759c16ccecece0d81ab9cebe3f44041b80e51/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/linker_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_scatter
3 | import numpy as np
4 | from torch_geometric.data import Data, Batch
5 | from torch_geometric.loader import DataLoader
6 |
7 | FOLLOW_BATCH = ()
8 |
9 |
10 | class FragLinkerData(Data):
11 |
12 | def __init__(self, *args, **kwargs):
13 | super().__init__(*args, **kwargs)
14 |
15 | # @staticmethod
16 | # def from_protein_ligand_dicts(protein_dict=None, ligand_dict=None, **kwargs):
17 | # instance = FragLinkerData(**kwargs)
18 | #
19 | # if protein_dict is not None:
20 | # for key, item in protein_dict.items():
21 | # instance['protein_' + key] = item
22 | #
23 | # if ligand_dict is not None:
24 | # for key, item in ligand_dict.items():
25 | # instance['ligand_' + key] = item
26 | #
27 | # instance['ligand_nbh_list'] = {i.item(): [j.item() for k, j in enumerate(instance.ligand_bond_index[1])
28 | # if instance.ligand_bond_index[0, k].item() == i]
29 | # for i in instance.ligand_bond_index[0]}
30 | # return instance
31 |
32 | # def __inc__(self, key, value, *args, **kwargs):
33 | # if key == 'ligand_bond_index':
34 | # return self['ligand_element'].size(0)
35 | # # elif key == 'ligand_context_bond_index':
36 | # # return self['ligand_context_element'].size(0)
37 | # else:
38 | # return super().__inc__(key, value)
39 |
40 |
41 | def batch_from_data_list(data_list):
42 | return Batch.from_data_list(data_list, follow_batch=FOLLOW_BATCH)
43 |
44 |
45 | def torchify_dict(data):
46 | output = {}
47 | for k, v in data.items():
48 | if isinstance(v, np.ndarray):
49 | output[k] = torch.from_numpy(v)
50 | else:
51 | output[k] = v
52 | return output
53 |
--------------------------------------------------------------------------------
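A minimal usage sketch of the helpers above (the field names `pos` and `atom_type` are illustrative, not the dataset's actual schema; assumes the repo root is on PYTHONPATH):

```python
import numpy as np
from datasets.linker_data import FragLinkerData, batch_from_data_list, torchify_dict

raw = {
    'pos': np.random.randn(5, 3).astype(np.float32),
    'atom_type': np.zeros(5, dtype=np.int64),
}
data = FragLinkerData(**torchify_dict(raw))  # numpy arrays -> torch tensors
batch = batch_from_data_list([data, data])   # PyG Batch with FOLLOW_BATCH = ()
print(batch)
```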
/models/common.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import time
6 | from torch_geometric.nn import radius_graph, knn_graph
7 | import math
8 |
9 |
10 | class GaussianSmearing(nn.Module):
11 | # used to embed the edge distances
12 | def __init__(self, start=0.0, stop=5.0, num_gaussians=50, expansion_mode='linear'):
13 | super().__init__()
14 | self.start, self.stop, self.num_gaussians = start, stop, num_gaussians
15 | self.expansion_mode = expansion_mode
16 | if expansion_mode == 'exp':
17 | offset = torch.exp(torch.linspace(start=np.log(start+1), end=np.log(stop+1), steps=num_gaussians)) - 1
18 | diff = torch.diff(offset)
19 | diff = torch.cat([diff[:1], diff])
20 | coeff = -0.5 / (diff ** 2)
21 | self.register_buffer('coeff', coeff)
22 | elif expansion_mode == 'linear':
23 | offset = torch.linspace(start=start, end=stop, steps=num_gaussians)
24 | self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
25 | else:
26 | raise NotImplementedError('expansion_mode must be either exp or linear')
27 |
28 | self.register_buffer('offset', offset)
29 |
30 | def __repr__(self):
31 | return f'GaussianSmearing(start={self.start}, stop={self.stop}, ' \
32 | f'num_gaussians={self.num_gaussians}, expansion_mode={self.expansion_mode})'
33 |
34 | def forward(self, dist):
35 | dist = dist.view(-1, 1) - self.offset.view(1, -1)
36 | return torch.exp(self.coeff * torch.pow(dist, 2))
37 |
38 |
39 | class AngleExpansion(nn.Module):
40 | def __init__(self, start=1.0, stop=5.0, half_expansion=10):
41 | super(AngleExpansion, self).__init__()
42 | l_mul = 1. / torch.linspace(stop, start, half_expansion)
43 | r_mul = torch.linspace(start, stop, half_expansion)
44 | coeff = torch.cat([l_mul, r_mul], dim=-1)
45 | self.register_buffer('coeff', coeff)
46 |
47 | def forward(self, angle):
48 | return torch.cos(angle.view(-1, 1) * self.coeff.view(1, -1))
49 |
50 |
51 | class Swish(nn.Module):
52 | def __init__(self):
53 | super(Swish, self).__init__()
54 | self.beta = nn.Parameter(torch.tensor(1.0))
55 |
56 | def forward(self, x):
57 | return x * torch.sigmoid(self.beta * x)
58 |
59 |
60 | class ShiftedSoftplus(nn.Module):
61 | def __init__(self):
62 | super().__init__()
63 | self.shift = torch.log(torch.tensor(2.0)).item()
64 |
65 | def forward(self, x):
66 | return F.softplus(x) - self.shift
67 |
68 |
69 | NONLINEARITIES = {
70 | "tanh": nn.Tanh(),
71 | "relu": nn.ReLU(),
72 | "softplus": nn.Softplus(),
73 | "elu": nn.ELU(),
74 | "swish": Swish(),
75 | 'silu': nn.SiLU()
76 | }
77 |
78 |
79 | class MLP(nn.Module):
80 | """MLP with the same hidden dim across all layers."""
81 |
82 | def __init__(self, in_dim, out_dim, hidden_dim, num_layer=2, norm=True, act_fn='relu', act_last=False):
83 | super().__init__()
84 | layers = []
85 | for layer_idx in range(num_layer):
86 | if layer_idx == 0:
87 | layers.append(nn.Linear(in_dim, hidden_dim))
88 | elif layer_idx == num_layer - 1:
89 | layers.append(nn.Linear(hidden_dim, out_dim))
90 | else:
91 | layers.append(nn.Linear(hidden_dim, hidden_dim))
92 | if layer_idx < num_layer - 1 or act_last:
93 | if norm:
94 | layers.append(nn.LayerNorm(hidden_dim))
95 | layers.append(NONLINEARITIES[act_fn])
96 | self.net = nn.Sequential(*layers)
97 |
98 | def forward(self, x):
99 | return self.net(x)
100 |
101 |
102 | def outer_product(*vectors):
103 | for index, vector in enumerate(vectors):
104 | if index == 0:
105 | out = vector.unsqueeze(-1)
106 | else:
107 | out = out * vector.unsqueeze(1)
108 | out = out.view(out.shape[0], -1).unsqueeze(-1)
109 | return out.squeeze()
110 |
111 |
112 | def get_h_dist(dist_metric, hi, hj):
113 | if dist_metric == 'euclidean':
114 | h_dist = torch.sum((hi - hj) ** 2, -1, keepdim=True)
115 | return h_dist
116 | elif dist_metric == 'cos_sim':
117 | hi_norm = torch.norm(hi, p=2, dim=-1, keepdim=True)
118 | hj_norm = torch.norm(hj, p=2, dim=-1, keepdim=True)
119 | h_dist = torch.sum(hi * hj, -1, keepdim=True) / (hi_norm * hj_norm)
120 | return h_dist, hj_norm
121 |
122 |
123 | def compose_context(h_protein, h_ligand, pos_protein, pos_ligand, batch_protein, batch_ligand):
124 | # previous version has problems when ligand atom types are fixed
125 | # (due to sorting randomly in case of same element)
126 |
127 | batch_ctx = torch.cat([batch_protein, batch_ligand], dim=0)
128 | # sort_idx = batch_ctx.argsort()
129 | sort_idx = torch.sort(batch_ctx, stable=True).indices
130 |
131 | mask_ligand = torch.cat([
132 | torch.zeros([batch_protein.size(0)], device=batch_protein.device).bool(),
133 | torch.ones([batch_ligand.size(0)], device=batch_ligand.device).bool(),
134 | ], dim=0)[sort_idx]
135 |
136 | batch_ctx = batch_ctx[sort_idx]
137 | if isinstance(h_protein, list):
138 | h_ctx_sca = torch.cat([h_protein[0], h_ligand[0]], dim=0)[sort_idx] # (N_protein+N_ligand, H)
139 | h_ctx_vec = torch.cat([h_protein[1], h_ligand[1]], dim=0)[sort_idx] # (N_protein+N_ligand, H, 3)
140 | h_ctx = [h_ctx_sca, h_ctx_vec]
141 | else:
142 | h_ctx = torch.cat([h_protein, h_ligand], dim=0)[sort_idx] # (N_protein+N_ligand, H)
143 | pos_ctx = torch.cat([pos_protein, pos_ligand], dim=0)[sort_idx] # (N_protein+N_ligand, 3)
144 |
145 | return h_ctx, pos_ctx, batch_ctx, mask_ligand
146 |
147 |
148 | def hybrid_edge_connection(ligand_pos, protein_pos, k, ligand_index, protein_index):
149 | # fully-connected for ligand atoms
150 | dst = torch.repeat_interleave(ligand_index, len(ligand_index))
151 | src = ligand_index.repeat(len(ligand_index))
152 | mask = dst != src
153 | dst, src = dst[mask], src[mask]
154 | ll_edge_index = torch.stack([src, dst])
155 |
156 | # knn for ligand-protein edges
157 | ligand_protein_pos_dist = torch.unsqueeze(ligand_pos, 1) - torch.unsqueeze(protein_pos, 0)
158 | ligand_protein_pos_dist = torch.norm(ligand_protein_pos_dist, p=2, dim=-1)
159 | knn_p_idx = torch.topk(ligand_protein_pos_dist, k=k, largest=False, dim=1).indices
160 | knn_p_idx = protein_index[knn_p_idx]
161 | knn_l_idx = torch.unsqueeze(ligand_index, 1)
162 | knn_l_idx = knn_l_idx.repeat(1, k)
163 | pl_edge_index = torch.stack([knn_p_idx, knn_l_idx], dim=0)
164 | pl_edge_index = pl_edge_index.view(2, -1)
165 | return ll_edge_index, pl_edge_index
166 |
167 |
168 | def batch_hybrid_edge_connection(x, k, mask_ligand, batch, add_p_index=False):
169 | batch_size = batch.max().item() + 1
170 | batch_ll_edge_index, batch_pl_edge_index, batch_p_edge_index = [], [], []
171 | with torch.no_grad():
172 | for i in range(batch_size):
173 | ligand_index = ((batch == i) & (mask_ligand == 1)).nonzero()[:, 0]
174 | protein_index = ((batch == i) & (mask_ligand == 0)).nonzero()[:, 0]
175 | # print(f'batch: {i}, ligand_index: {ligand_index} {len(ligand_index)}')
176 | ligand_pos, protein_pos = x[ligand_index], x[protein_index]
177 | ll_edge_index, pl_edge_index = hybrid_edge_connection(
178 | ligand_pos, protein_pos, k, ligand_index, protein_index)
179 | batch_ll_edge_index.append(ll_edge_index)
180 | batch_pl_edge_index.append(pl_edge_index)
181 | if add_p_index:
182 | all_pos = torch.cat([protein_pos, ligand_pos], 0)
183 | p_edge_index = knn_graph(all_pos, k=k, flow='source_to_target')
184 | p_edge_index = p_edge_index[:, p_edge_index[1] < len(protein_pos)]
185 | p_src, p_dst = p_edge_index
186 | all_index = torch.cat([protein_index, ligand_index], 0)
187 | # print('len protein index: ', len(protein_index))
188 | # print('max index: ', max(p_src), max(p_dst))
189 | p_edge_index = torch.stack([all_index[p_src], all_index[p_dst]], 0)
190 | batch_p_edge_index.append(p_edge_index)
191 | # edge_index.append(torch.cat([ll_edge_index, pl_edge_index], -1))
192 |
193 | if add_p_index:
194 | edge_index = [torch.cat([ll, pl, p], -1) for ll, pl, p in zip(
195 | batch_ll_edge_index, batch_pl_edge_index, batch_p_edge_index)]
196 | else:
197 | edge_index = [torch.cat([ll, pl], -1) for ll, pl in zip(batch_ll_edge_index, batch_pl_edge_index)]
198 | edge_index = torch.cat(edge_index, -1)
199 | return edge_index
200 |
201 |
202 | # Time Embedding
203 | def basic_time_embedding(beta):
204 | emb = torch.stack([beta, torch.sin(beta), torch.cos(beta)], dim=-1) # (N, 3)
205 | return emb
206 |
207 |
208 | def sinusoidal_embedding(timesteps, embedding_dim, max_positions=10000):
209 | """ from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py """
210 | assert len(timesteps.shape) == 1
211 | half_dim = embedding_dim // 2
212 | emb = math.log(max_positions) / (half_dim - 1)
213 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
214 | emb = timesteps.float()[:, None] * emb[None, :]
215 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
216 | if embedding_dim % 2 == 1: # zero pad
217 | emb = F.pad(emb, (0, 1), mode='constant')
218 | assert emb.shape == (timesteps.shape[0], embedding_dim)
219 | return emb
220 |
221 |
222 | class GaussianFourierProjection(nn.Module):
223 | """Gaussian Fourier embeddings for noise levels.
224 | from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/models/layerspp.py#L32
225 | """
226 |
227 | def __init__(self, embedding_size=256, scale=1.0):
228 | super().__init__()
229 | self.W = nn.Parameter(torch.randn(embedding_size // 2) * scale, requires_grad=False)
230 |
231 | def forward(self, x):
232 | x_proj = x[:, None] * self.W[None, :] * 2 * math.pi
233 | emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
234 | return emb
235 |
236 |
237 | def get_timestep_embedding(embedding_type, embedding_dim, embedding_scale=10000):
238 | if embedding_type == 'basic':
239 | emb_func = (lambda x: basic_time_embedding(x))
240 | elif embedding_type == 'sin':
241 | emb_func = (lambda x: sinusoidal_embedding(embedding_scale * x, embedding_dim))
242 | elif embedding_type == 'fourier':
243 | emb_func = GaussianFourierProjection(embedding_size=embedding_dim, scale=embedding_scale)
244 | else:
245 | raise NotImplementedError(embedding_type)
246 | return emb_func
247 |
--------------------------------------------------------------------------------
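A small usage sketch for two of the building blocks above (assumes the repo root is on PYTHONPATH):

```python
import torch
from models.common import GaussianSmearing, get_timestep_embedding

# Expand 10 random distances into 20 Gaussian basis features with
# exponentially spaced centers, as the configs request (expansion_mode: exp).
smear = GaussianSmearing(start=0.0, stop=15.0, num_gaussians=20, expansion_mode='exp')
feats = smear(torch.rand(10) * 15.0)
print(feats.shape)  # torch.Size([10, 20])

# Sinusoidal time embedding for normalized timesteps in [0, 1].
emb_func = get_timestep_embedding('sin', embedding_dim=16, embedding_scale=1000)
print(emb_func(torch.rand(4)).shape)  # torch.Size([4, 16])
```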
/models/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | from .node_edge_net import NodeEdgeNet
2 |
3 |
4 | def get_refine_net(config, net_type, node_hidden_dim, edge_hidden_dim, train_link):
5 | if net_type == 'node_edge_net':
6 | refine_net = NodeEdgeNet(
7 | node_dim=node_hidden_dim,
8 | edge_dim=edge_hidden_dim,
9 | num_blocks=config.num_blocks,
10 | cutoff=config.cutoff,
11 | use_gate=config.use_gate,
12 | update_pos=train_link,
13 | expansion_mode=config.get('expansion_mode', 'linear')
14 | )
15 | else:
16 | raise ValueError(net_type)
17 | return refine_net
18 |
--------------------------------------------------------------------------------
/models/encoders/node_edge_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import Module, Linear, ModuleList
4 | from torch_scatter import scatter_sum
5 | from models.common import GaussianSmearing, MLP
6 |
7 |
8 | class NodeBlock(Module):
9 |
10 | def __init__(self, node_dim, edge_dim, hidden_dim, use_gate):
11 | super().__init__()
12 | self.use_gate = use_gate
13 | self.node_dim = node_dim
14 |
15 | self.node_net = MLP(node_dim, hidden_dim, hidden_dim)
16 | self.edge_net = MLP(edge_dim, hidden_dim, hidden_dim)
17 | self.msg_net = Linear(hidden_dim, hidden_dim)
18 |
19 | if self.use_gate:
20 | self.gate = MLP(edge_dim+node_dim+1, hidden_dim, hidden_dim) # add 1 for time
21 |
22 | self.centroid_lin = Linear(node_dim, hidden_dim)
23 | self.layer_norm = nn.LayerNorm(hidden_dim)
24 | self.act = nn.ReLU()
25 | self.out_transform = Linear(hidden_dim, node_dim)
26 |
27 | def forward(self, x, edge_index, edge_attr, node_time):
28 | """
29 | Args:
30 | x: Node features, (N, H).
31 | edge_index: (2, E).
32 | edge_attr: (E, H)
33 | """
34 | N = x.size(0)
35 | row, col = edge_index # (E,) , (E,)
36 |
37 | h_node = self.node_net(x) # (N, H)
38 |
39 | # Compose messages
40 | h_edge = self.edge_net(edge_attr) # (E, H)
41 | msg_j = self.msg_net(h_edge * h_node[col])
42 |
43 | if self.use_gate:
44 | gate = self.gate(torch.cat([edge_attr, x[col], node_time[col]], dim=-1))
45 | msg_j = msg_j * torch.sigmoid(gate)
46 |
47 | # Aggregate messages
48 | aggr_msg = scatter_sum(msg_j, row, dim=0, dim_size=N)
49 | out = self.centroid_lin(x) + aggr_msg
50 |
51 | out = self.layer_norm(out)
52 | out = self.out_transform(self.act(out))
53 | return out
54 |
55 |
56 | class BondFFN(Module):
57 | def __init__(self, bond_dim, node_dim, inter_dim, use_gate, out_dim=None):
58 | super().__init__()
59 | out_dim = bond_dim if out_dim is None else out_dim
60 | self.use_gate = use_gate
61 | self.bond_linear = Linear(bond_dim, inter_dim, bias=False)
62 | self.node_linear = Linear(node_dim, inter_dim, bias=False)
63 | self.inter_module = MLP(inter_dim, out_dim, inter_dim)
64 | if self.use_gate:
65 | self.gate = MLP(bond_dim+node_dim+1, out_dim, 32) # +1 for time
66 |
67 | def forward(self, bond_feat_input, node_feat_input, time):
68 | bond_feat = self.bond_linear(bond_feat_input)
69 | node_feat = self.node_linear(node_feat_input)
70 | inter_feat = bond_feat * node_feat
71 | inter_feat = self.inter_module(inter_feat)
72 | if self.use_gate:
73 | gate = self.gate(torch.cat([bond_feat_input, node_feat_input, time], dim=-1))
74 | inter_feat = inter_feat * torch.sigmoid(gate)
75 | return inter_feat
76 |
77 |
78 | class QKVLin(Module):
79 | def __init__(self, h_dim, key_dim, num_heads):
80 | super().__init__()
81 | self.num_heads = num_heads
82 | self.q_lin = Linear(h_dim, key_dim)
83 | self.k_lin = Linear(h_dim, key_dim)
84 | self.v_lin = Linear(h_dim, h_dim)
85 |
86 | def forward(self, inputs):
87 | n = inputs.size(0)
88 | return [
89 | self.q_lin(inputs).view(n, self.num_heads, -1),
90 | self.k_lin(inputs).view(n, self.num_heads, -1),
91 | self.v_lin(inputs).view(n, self.num_heads, -1),
92 | ]
93 |
94 |
95 | class EdgeBlock(Module):
96 | def __init__(self, edge_dim, node_dim, hidden_dim=None, use_gate=True):
97 | super().__init__()
98 | self.use_gate = use_gate
99 | inter_dim = edge_dim * 2 if hidden_dim is None else hidden_dim
100 |
101 | self.bond_ffn_left = BondFFN(edge_dim, node_dim, inter_dim=inter_dim, use_gate=use_gate)
102 | self.bond_ffn_right = BondFFN(edge_dim, node_dim, inter_dim=inter_dim, use_gate=use_gate)
103 |
104 | self.node_ffn_left = Linear(node_dim, edge_dim)
105 | self.node_ffn_right = Linear(node_dim, edge_dim)
106 |
107 | self.self_ffn = Linear(edge_dim, edge_dim)
108 | self.layer_norm = nn.LayerNorm(edge_dim)
109 | self.out_transform = Linear(edge_dim, edge_dim)
110 | self.act = nn.ReLU()
111 |
112 | def forward(self, h_bond, bond_index, h_node, bond_time):
113 | """
114 | h_bond: (b, bond_dim)
115 | bond_index: (2, b)
116 | h_node: (n, node_dim)
117 | """
118 | N = h_node.size(0)
119 | left_node, right_node = bond_index
120 |
121 | # message from neighbor bonds
122 | msg_bond_left = self.bond_ffn_left(h_bond, h_node[left_node], bond_time)
123 | msg_bond_left = scatter_sum(msg_bond_left, right_node, dim=0, dim_size=N)
124 | msg_bond_left = msg_bond_left[left_node]
125 |
126 | msg_bond_right = self.bond_ffn_right(h_bond, h_node[right_node], bond_time)
127 | msg_bond_right = scatter_sum(msg_bond_right, left_node, dim=0, dim_size=N)
128 | msg_bond_right = msg_bond_right[right_node]
129 |
130 | h_bond = (
131 | msg_bond_left + msg_bond_right
132 | + self.node_ffn_left(h_node[left_node])
133 | + self.node_ffn_right(h_node[right_node])
134 | + self.self_ffn(h_bond)
135 | )
136 | h_bond = self.layer_norm(h_bond)
137 |
138 | h_bond = self.out_transform(self.act(h_bond))
139 | return h_bond
140 |
141 |
142 | class NodeEdgeNet(Module):
143 | def __init__(self, node_dim, edge_dim, num_blocks, cutoff, use_gate, **kwargs):
144 | super().__init__()
145 | self.node_dim = node_dim
146 | self.edge_dim = edge_dim
147 | self.num_blocks = num_blocks
148 | self.cutoff = cutoff
149 | self.use_gate = use_gate
150 | self.kwargs = kwargs
151 |
152 | if 'num_gaussians' not in kwargs:
153 | num_gaussians = 16
154 | else:
155 | num_gaussians = kwargs['num_gaussians']
156 | if 'start' not in kwargs:
157 | start = 0
158 | else:
159 | start = kwargs['start']
160 | self.distance_expansion = GaussianSmearing(
161 | start=start, stop=cutoff, num_gaussians=num_gaussians, expansion_mode=kwargs['expansion_mode'])
162 | print('distance expansion: ', self.distance_expansion)
163 | if ('update_edge' in kwargs) and (not kwargs['update_edge']):
164 | self.update_edge = False
165 | input_edge_dim = num_gaussians
166 | else:
167 | self.update_edge = True # default update edge
168 | input_edge_dim = edge_dim + num_gaussians
169 |
170 | if ('update_pos' in kwargs) and (not kwargs['update_pos']):
171 | self.update_pos = False
172 | else:
173 | self.update_pos = True # default update pos
174 | print('update pos: ', self.update_pos)
175 | # node network
176 | self.node_blocks_with_edge = ModuleList()
177 | self.edge_embs = ModuleList()
178 | self.edge_blocks = ModuleList()
179 | self.pos_blocks = ModuleList()
180 | for _ in range(num_blocks):
181 | self.node_blocks_with_edge.append(NodeBlock(
182 | node_dim=node_dim, edge_dim=edge_dim, hidden_dim=node_dim, use_gate=use_gate,
183 | ))
184 | self.edge_embs.append(Linear(input_edge_dim, edge_dim))
185 | if self.update_edge:
186 | self.edge_blocks.append(EdgeBlock(
187 | edge_dim=edge_dim, node_dim=node_dim, use_gate=use_gate,
188 | ))
189 | if self.update_pos:
190 | self.pos_blocks.append(PosUpdate(
191 | node_dim=node_dim, edge_dim=edge_dim, hidden_dim=edge_dim, use_gate=use_gate,
192 | ))
193 |
194 | @property
195 | def node_hidden_dim(self):
196 | return self.node_dim
197 |
198 | @property
199 | def edge_hidden_dim(self):
200 | return self.edge_dim
201 |
202 | def forward(self, pos_node, h_node, h_edge, edge_index, linker_mask, node_time, edge_time):
203 | for i in range(self.num_blocks):
204 | # edge features before each block
205 | if self.update_pos or (i==0):
206 | h_edge_dist, relative_vec, distance = self._build_edges_dist(pos_node, edge_index)
207 | if self.update_edge:
208 | h_edge = torch.cat([h_edge, h_edge_dist], dim=-1)
209 | else:
210 | h_edge = h_edge_dist
211 | h_edge = self.edge_embs[i](h_edge)
212 |
213 | # node and edge feature updates
214 | h_node_with_edge = self.node_blocks_with_edge[i](h_node, edge_index, h_edge, node_time)
215 | if self.update_edge:
216 | h_edge = h_edge + self.edge_blocks[i](h_edge, edge_index, h_node, edge_time)
217 | h_node = h_node + h_node_with_edge
218 | # pos updates
219 | if self.update_pos:
220 | delta_x = self.pos_blocks[i](h_node, h_edge, edge_index, relative_vec, distance, edge_time)
221 | pos_node = pos_node + delta_x * linker_mask[:, None] # only linker positions will be updated
222 | return pos_node, h_node, h_edge
223 |
224 | def _build_edges_dist(self, pos, edge_index):
225 | # distance
226 | relative_vec = pos[edge_index[0]] - pos[edge_index[1]]
227 | distance = torch.norm(relative_vec, dim=-1, p=2)
228 | edge_dist = self.distance_expansion(distance)
229 | return edge_dist, relative_vec, distance
230 |
231 |
232 | class PosUpdate(Module):
233 | def __init__(self, node_dim, edge_dim, hidden_dim, use_gate):
234 | super().__init__()
235 | self.left_lin_edge = MLP(node_dim, edge_dim, hidden_dim)
236 | self.right_lin_edge = MLP(node_dim, edge_dim, hidden_dim)
237 | self.edge_lin = BondFFN(edge_dim, edge_dim, node_dim, use_gate, out_dim=1)
238 |
239 | def forward(self, h_node, h_edge, edge_index, relative_vec, distance, edge_time):
240 | edge_index_left, edge_index_right = edge_index
241 |
242 | left_feat = self.left_lin_edge(h_node[edge_index_left])
243 | right_feat = self.right_lin_edge(h_node[edge_index_right])
244 | weight_edge = self.edge_lin(h_edge, left_feat * right_feat, edge_time)
245 |
246 | # relative_vec = pos_node[edge_index_left] - pos_node[edge_index_right]
247 | # distance = torch.norm(relative_vec, dim=-1, keepdim=True)
248 | force_edge = weight_edge * relative_vec / distance.unsqueeze(-1) / (distance.unsqueeze(-1) + 1.)
249 | delta_pos = scatter_sum(force_edge, edge_index_left, dim=0, dim_size=h_node.shape[0])
250 |
251 | return delta_pos
252 |
253 |
254 | class PosPredictor(Module):
255 | def __init__(self, node_dim, edge_dim, bond_dim, use_gate):
256 | super().__init__()
257 | self.left_lin_edge = MLP(node_dim, edge_dim, hidden_dim=edge_dim)
258 | self.right_lin_edge = MLP(node_dim, edge_dim, hidden_dim=edge_dim)
259 | self.edge_lin = BondFFN(edge_dim, edge_dim, node_dim, use_gate, out_dim=1)
260 |
261 | self.bond_dim = bond_dim
262 | if bond_dim > 0:
263 | self.left_lin_bond = MLP(node_dim, bond_dim, hidden_dim=bond_dim)
264 | self.right_lin_bond = MLP(node_dim, bond_dim, hidden_dim=bond_dim)
265 | self.bond_lin = BondFFN(bond_dim, bond_dim, node_dim, use_gate, out_dim=1)
266 |
267 | def forward(self, h_node, pos_node, h_bond, bond_index, h_edge, edge_index, is_frag):
268 | # 1 pos update through edges
269 | is_left_frag = is_frag[edge_index[0]]
270 | edge_index_left, edge_index_right = edge_index[:, is_left_frag]
271 |
272 | left_feat = self.left_lin_edge(h_node[edge_index_left])
273 | right_feat = self.right_lin_edge(h_node[edge_index_right])
274 | weight_edge = self.edge_lin(h_edge[is_left_frag], left_feat * right_feat)
275 | force_edge = weight_edge * (pos_node[edge_index_left] - pos_node[edge_index_right])
276 | delta_pos = scatter_sum(force_edge, edge_index_left, dim=0, dim_size=h_node.shape[0])
277 |
278 | # 2 pos update through bonds
279 | if self.bond_dim > 0:
280 | is_left_frag = is_frag[bond_index[0]]
281 | bond_index_left, bond_index_right = bond_index[:, is_left_frag]
282 |
283 | left_feat = self.left_lin_bond(h_node[bond_index_left])
284 | right_feat = self.right_lin_bond(h_node[bond_index_right])
285 | weight_bond = self.bond_lin(h_bond[is_left_frag], left_feat * right_feat)
286 | force_bond = weight_bond * (pos_node[bond_index_left] - pos_node[bond_index_right])
287 | delta_pos = delta_pos + scatter_sum(force_bond, bond_index_left, dim=0, dim_size=h_node.shape[0])
288 |
289 | pos_update = pos_node + delta_pos / 10.
290 | return pos_update
291 |
--------------------------------------------------------------------------------
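A self-contained smoke test for NodeEdgeNet on a random graph (all dimensions are arbitrary; the time features are 1-dimensional to match time_emb_dim: 1 in the configs):

```python
import torch
from models.encoders.node_edge_net import NodeEdgeNet

N, E, node_dim, edge_dim = 8, 16, 32, 16
net = NodeEdgeNet(node_dim=node_dim, edge_dim=edge_dim, num_blocks=2, cutoff=15.0,
                  use_gate=True, num_gaussians=20, expansion_mode='exp')

pos = torch.randn(N, 3)
h_node = torch.randn(N, node_dim)
h_edge = torch.randn(E, edge_dim)
src = torch.randint(0, N, (E,))
dst = (src + torch.randint(1, N, (E,))) % N  # avoid self-loops (zero distances)
edge_index = torch.stack([src, dst])
linker_mask = torch.ones(N)   # 1 = position may be updated (linker atoms)
node_time = torch.rand(N, 1)  # per-node diffusion time
edge_time = torch.rand(E, 1)  # per-edge diffusion time

new_pos, new_h_node, new_h_edge = net(pos, h_node, h_edge, edge_index,
                                      linker_mask, node_time, edge_time)
print(new_pos.shape, new_h_node.shape, new_h_edge.shape)
```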
/models/eps_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.encoders import get_refine_net
4 | from utils.geometry import apply_rotation_to_vector, quaternion_1ijk_to_rotation_matrix
5 | from utils.so3 import so3vec_to_rotation, rotation_to_so3vec
6 | from torch_scatter import scatter_mean, scatter_softmax, scatter_sum
7 | from models.common import MLP
8 | import numpy as np
9 |
10 |
11 | class ForceLayer(nn.Module):
12 | def __init__(self, input_dim, hidden_dim, output_dim, n_heads, edge_feat_dim, time_dim,
13 | separate_att=False, act_fn='relu', norm=True):
14 | super().__init__()
15 | self.input_dim = input_dim
16 | self.hidden_dim = hidden_dim
17 | self.output_dim = output_dim
18 | self.n_heads = n_heads
19 | self.edge_feat_dim = edge_feat_dim
20 | self.time_dim = time_dim
21 | self.act_fn = act_fn
22 | self.separate_att = separate_att
23 |
24 | kv_input_dim = input_dim * 2 + edge_feat_dim
25 | self.xk_func = MLP(kv_input_dim, output_dim, hidden_dim, norm=norm, act_fn=act_fn)
26 | self.xv_func = MLP(kv_input_dim + time_dim, self.n_heads, hidden_dim, norm=norm, act_fn=act_fn)
27 | self.xq_func = MLP(input_dim, output_dim, hidden_dim, norm=norm, act_fn=act_fn)
28 |
29 | def forward(self, h, rel_x, edge_feat, edge_index, inner_edge_mask, t, e_w=None):
30 | N = h.size(0)
31 | src, dst = edge_index
32 | hi, hj = h[dst], h[src]
33 |
34 | # multi-head attention
35 | kv_input = torch.cat([edge_feat, hi, hj], -1)
36 |
37 | k = self.xk_func(kv_input).view(-1, self.n_heads, self.output_dim // self.n_heads)
38 | v = self.xv_func(torch.cat([kv_input, t[dst]], -1))
39 | e_w = e_w.view(-1, 1) if e_w is not None else 1.
40 | v = v * e_w
41 |
42 | v = v.unsqueeze(-1) * rel_x.unsqueeze(1) # (xi - xj) [n_edges, n_heads, 3]
43 | q = self.xq_func(h).view(-1, self.n_heads, self.output_dim // self.n_heads)
44 |
45 | if self.separate_att:
46 | k_in, v_in = k[inner_edge_mask], v[inner_edge_mask]
47 | alpha_in = scatter_softmax((q[dst[inner_edge_mask]] * k_in / np.sqrt(k_in.shape[-1])).sum(-1),
48 | dst[inner_edge_mask], dim=0) # (E, heads)
49 | m_in = alpha_in.unsqueeze(-1) * v_in # (E, heads, 3)
50 | inner_forces = scatter_sum(m_in, dst[inner_edge_mask], dim=0, dim_size=N).mean(1)
51 |
52 | k_out, v_out = k[~inner_edge_mask], v[~inner_edge_mask]
53 | alpha_out = scatter_softmax((q[dst[~inner_edge_mask]] * k_out / np.sqrt(k_out.shape[-1])).sum(-1),
54 | dst[~inner_edge_mask], dim=0) # (E, heads)
55 | m_out = alpha_out.unsqueeze(-1) * v_out # (E, heads, 3)
56 | outer_forces = scatter_sum(m_out, dst[~inner_edge_mask], dim=0, dim_size=N).mean(1)
57 |
58 | else:
59 | # Compute attention weights
60 | alpha = scatter_softmax((q[dst] * k / np.sqrt(k.shape[-1])).sum(-1), dst, dim=0) # (E, heads)
61 |
62 | # Perform attention-weighted message-passing
63 | m = alpha.unsqueeze(-1) * v # (E, heads, 3)
64 | inner_forces = scatter_sum(m[inner_edge_mask], dst[inner_edge_mask], dim=0, dim_size=N).mean(1)
65 | outer_forces = scatter_sum(m[~inner_edge_mask], dst[~inner_edge_mask], dim=0, dim_size=N).mean(1)
66 | # output = scatter_sum(m, dst, dim=0, dim_size=N) # (N, heads, 3)
67 | return inner_forces, outer_forces # [num_nodes, 3]
68 |
69 |
70 | class SymForceLayer(nn.Module):
71 | def __init__(self, input_dim, hidden_dim, edge_feat_dim, time_dim,
72 | act_fn='relu', norm=True):
73 | super().__init__()
74 | self.input_dim = input_dim
75 | self.hidden_dim = hidden_dim
76 | self.output_dim = 1
77 | self.edge_feat_dim = edge_feat_dim
78 | self.time_dim = time_dim
79 | self.act_fn = act_fn
80 | self.pred_layer = MLP(input_dim + edge_feat_dim + time_dim, 1, hidden_dim,
81 | num_layer=3, norm=norm, act_fn=act_fn)
82 |
83 | def forward(self, h, rel_x, edge_feat, edge_index, inner_edge_mask, t):
84 | N = h.size(0)
85 | src, dst = edge_index[:, ~inner_edge_mask]
86 | rel_x = rel_x[~inner_edge_mask]
87 | distance = torch.norm(rel_x, p=2, dim=-1)
88 | hi, hj = h[dst], h[src]
89 |
90 | feat = torch.cat([edge_feat[~inner_edge_mask], (hi + hj) / 2, t[dst]], -1)
91 | forces = self.pred_layer(feat) * rel_x / distance.unsqueeze(-1) / (distance.unsqueeze(-1) + 1.)
92 | outer_forces = scatter_sum(forces, dst, dim=0, dim_size=N)
93 | return None, outer_forces # [num_nodes, 3]
94 |
95 |
96 | class EpsilonNet(nn.Module):
97 |
98 | def __init__(self, cfg, node_emb_dim, edge_emb_dim, time_emb_dim, num_classes, num_bond_classes,
99 | train_frag_rot, train_frag_pos, train_link, train_bond, pred_frag_dist=False,
100 | use_rel_geometry=False, softmax_last=True):
101 | super().__init__()
102 | self.encoder_type = cfg.net_type
103 | self.encoder = get_refine_net(cfg.encoder, cfg.net_type, node_emb_dim, edge_emb_dim, train_link)
104 | self.num_frags = 2
105 | self.pred_frag_dist = pred_frag_dist
106 | self.use_rel_geometry = use_rel_geometry
107 | self.train_frag_rot = train_frag_rot
108 | self.train_frag_pos = train_frag_pos
109 | self.train_link = train_link
110 | self.train_bond = train_bond
111 | self.tr_output_type = cfg.get('tr_output_type', 'invariant_eps')
112 | self.rot_output_type = cfg.rot_output_type
113 | print('EpsNet Softmax Last: ', softmax_last)
114 |
115 | if 'newton_equation' in self.tr_output_type or 'euler_equation' in self.rot_output_type:
116 | if cfg.get('sym_force', False):
117 | self.force_layer = SymForceLayer(
118 | input_dim=self.encoder.node_hidden_dim,
119 | hidden_dim=self.encoder.node_hidden_dim,
120 | edge_feat_dim=self.encoder.edge_hidden_dim,
121 | time_dim=time_emb_dim
122 | )
123 | else:
124 | self.force_layer = ForceLayer(
125 | input_dim=self.encoder.node_hidden_dim,
126 | hidden_dim=self.encoder.node_hidden_dim,
127 | output_dim=self.encoder.node_hidden_dim,
128 | n_heads=cfg.output_n_heads,
129 | edge_feat_dim=self.encoder.edge_hidden_dim,
130 | time_dim=time_emb_dim,
131 | separate_att=cfg.get('separate_att', False),
132 | )
133 | print(self.force_layer)
134 | if self.tr_output_type == 'invariant_eps' or self.rot_output_type == 'invariant_eps':
135 | self.frag_aggr = nn.Sequential(
136 | nn.Linear(node_emb_dim, node_emb_dim * 2), nn.ReLU(),
137 | nn.Linear(node_emb_dim * 2, node_emb_dim), nn.ReLU(),
138 | nn.Linear(node_emb_dim, node_emb_dim)
139 | )
140 | if self.tr_output_type == 'invariant_eps':
141 | self.eps_crd_net = nn.Sequential(
142 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(),
143 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(),
144 | nn.Linear(node_emb_dim, 1 if pred_frag_dist else 3)
145 | )
146 | if self.rot_output_type == 'invariant_eps':
147 | self.eps_rot_net = nn.Sequential(
148 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(),
149 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(),
150 | nn.Linear(node_emb_dim, 3)
151 | )
152 |
153 | if softmax_last:
154 | self.eps_cls_net = nn.Sequential(
155 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(),
156 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(),
157 | nn.Linear(node_emb_dim, num_classes), nn.Softmax(dim=-1)
158 | )
159 | else:
160 | self.eps_cls_net = nn.Sequential(
161 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(),
162 | nn.Linear(node_emb_dim, node_emb_dim), nn.ReLU(),
163 | nn.Linear(node_emb_dim, num_classes)
164 | )
165 | if self.train_bond:
166 | if softmax_last:
167 | self.bond_pred_net = nn.Sequential(
168 | nn.Linear(edge_emb_dim, edge_emb_dim), nn.ReLU(),
169 | nn.Linear(edge_emb_dim, num_bond_classes), nn.Softmax(dim=-1)
170 | )
171 | else:
172 | self.bond_pred_net = nn.Sequential(
173 | nn.Linear(edge_emb_dim, edge_emb_dim), nn.ReLU(),
174 | nn.Linear(edge_emb_dim, num_bond_classes)
175 | )
176 |
177 | def update_frag_geometry(self, final_x, final_h, final_h_bond,
178 | edge_index, inner_edge_mask, batch_node, mask, node_t_emb, R):
179 |
180 | eps_d, eps_pos, recon_x = None, None, None
181 | v_next, R_next = None, None
182 | if self.tr_output_type == 'invariant_eps' or self.rot_output_type == 'invariant_eps':
183 | # Aggregate features for fragment update
184 | frag_h = final_h[mask]
185 | frag_batch = batch_node[mask]
186 | frag_h_mean = scatter_mean(frag_h, frag_batch, dim=0)
187 | frag_h_mean = self.frag_aggr(frag_h_mean) # (G, H)
188 |
189 | # Position / Distance changes
190 | if self.tr_output_type == 'invariant_eps':
191 | if self.pred_frag_dist:
192 | eps_d = self.eps_crd_net(frag_h_mean) # (G, 1)
193 | else:
194 | eps_crd = self.eps_crd_net(frag_h_mean) # local coordinates (G, 3)
195 | eps_pos = apply_rotation_to_vector(R, eps_crd) # (G, 3)
196 |
197 | if self.rot_output_type == 'invariant_eps':
198 | # New orientation
199 | eps_rot = self.eps_rot_net(frag_h_mean) # (G, 3)
200 | U = quaternion_1ijk_to_rotation_matrix(eps_rot) # (G, 3, 3)
201 | R_next = R @ U
202 | v_next = rotation_to_so3vec(R_next) # (G, 3)
203 |
204 | if 'newton_equation' in self.tr_output_type or 'euler_equation' in self.rot_output_type:
205 | src, dst = edge_index
206 | rel_x = final_x[dst] - final_x[src]
207 | inner_forces, outer_forces = self.force_layer(
208 | final_h, rel_x, final_h_bond, edge_index, inner_edge_mask, t=node_t_emb) # equivariant
209 | if 'outer' in self.tr_output_type:
210 | assert 'outer' in self.rot_output_type
211 | forces = outer_forces
212 | else:
213 | forces = inner_forces + outer_forces
214 |
215 | if 'newton_equation' in self.tr_output_type:
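   |                 # Newton update: translate each fragment center by the mean
   |                 # predicted force acting on its atoms.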
216 | frag_batch = batch_node[mask]
217 | frag_center = scatter_mean(final_x[mask], frag_batch, dim=0) # (G, 3)
218 | recon_x = frag_center + scatter_mean(forces[mask], frag_batch, dim=0)
219 |
220 | if 'euler_equation' in self.rot_output_type:
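   |                 # Euler (rigid-body) update: net torque about the fragment center
   |                 # tau = sum_i (x_i - mu) x F_i, inertia tensor
   |                 # I = sum_i (|x_i - mu|^2 E_3 - (x_i - mu)(x_i - mu)^T),
   |                 # angular velocity omega = I^{-1} tau; the orientation is then
   |                 # updated by the rotation generated by -omega.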
221 | x_f1, force_1, batch_node_f1 = final_x[mask], forces[mask], batch_node[mask]
222 | mu_1 = scatter_mean(x_f1, batch_node_f1, dim=0)[batch_node_f1]
223 | tau = scatter_sum(torch.cross(x_f1 - mu_1, force_1), batch_node_f1, dim=0) # (num_graphs, 3)
224 | inertia_mat = scatter_sum(
225 | torch.sum((x_f1 - mu_1) ** 2, dim=-1)[:, None, None] * torch.eye(3)[None].to(x_f1) -
226 | (x_f1 - mu_1).unsqueeze(-1) @ (x_f1 - mu_1).unsqueeze(-2), batch_node_f1,
227 | dim=0) # (num_graphs, 3, 3)
228 | omega = torch.linalg.solve(inertia_mat, tau.unsqueeze(-1)).squeeze(-1) # (num_graphs, 3)
229 | R_next = so3vec_to_rotation(-omega) @ R
230 | v_next = rotation_to_so3vec(R_next)
231 |
232 | assert (eps_pos is not None) or (eps_d is not None) or (recon_x is not None)
233 | return eps_pos, eps_d, recon_x, v_next, R_next
234 |
235 | def forward(self,
236 | x_noisy, node_attr, R1_noisy, R2_noisy, edge_index, edge_attr,
237 | fragment_mask, edge_mask, inner_edge_mask, batch_node,
238 | node_t_emb, edge_t_emb, graph_t_emb, rel_fragment_mask, rel_R_noisy
239 | ):
240 |         """
241 |         Args:
242 |             x_noisy: (N, 3) noisy atom positions.
243 |             node_attr: (N, H) node features.
244 |             R1_noisy, R2_noisy: (G, 3, 3) noisy orientations of the two fragments.
245 |             edge_index: (2, E)
246 |             edge_attr: (E, He) edge features.
247 |             fragment_mask: (N, ), 0 for linker atoms, 1 / 2 for the two fragments.
248 |             edge_mask: (E, ), edges whose bond types are predicted.
249 |             inner_edge_mask: (E, ), edges within a single fragment.
250 |             batch_node: (N, ), graph index of each node.
251 |             node_t_emb, edge_t_emb, graph_t_emb: diffusion time-step embeddings.
252 |             rel_fragment_mask, rel_R_noisy: only used if use_rel_geometry is True.
253 |         Returns:
254 |             A dict that may contain:
255 |             linker_x: (L, 3). linker_c: (L, K), possibly softmax-ed.
256 |             linker_bond: bond-type predictions for masked edges.
257 |             frag_eps_pos / frag_eps_d / frag_recon_x: fragment translation outputs.
258 |             frag_v_next: UPDATED (not epsilon) SO(3)-vector of orientations, (G, 3).
259 |             frag_R_next: (G, 3, 3).
260 |         """
261 | linker_mask = (fragment_mask == 0)
262 |
263 | final_x, final_h, final_h_bond = self.encoder(
264 | pos_node=x_noisy, h_node=node_attr, h_edge=edge_attr, edge_index=edge_index,
265 | linker_mask=linker_mask, node_time=node_t_emb, edge_time=edge_t_emb)
266 |
267 | # Update linker
268 | outputs = {}
269 | if self.train_link:
270 | x_denoised = final_x[linker_mask] # (L, 3)
271 | c_denoised = self.eps_cls_net(final_h[linker_mask]) # may have softmax-ed, (L, K)
272 | outputs.update({
273 | 'linker_x': x_denoised,
274 | 'linker_c': c_denoised
275 | })
276 |
277 | if self.train_bond:
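   |             # Assumes each undirected bond is stored as two consecutive directed
   |             # edges (2k: i->j, 2k+1: j->i); summing the pair symmetrizes the bond
   |             # features so both directions get the same prediction.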
278 | in_bond_feat = final_h_bond[edge_mask]
279 | in_bond_feat = (in_bond_feat[::2] + in_bond_feat[1::2]).repeat_interleave(2, dim=0)
280 | bond_denoised = self.bond_pred_net(in_bond_feat)
281 | outputs['linker_bond'] = bond_denoised
282 |
283 | # Update fragment geometry
284 | if self.train_frag_rot or self.train_frag_pos:
285 | if self.use_rel_geometry:
286 | eps_pos, _, recon_x, v_next, R_next = self.update_frag_geometry(
287 | final_x, final_h, final_h_bond,
288 | edge_index, inner_edge_mask, batch_node, rel_fragment_mask, node_t_emb, rel_R_noisy)
289 | outputs.update({
290 | 'frag_eps_pos': eps_pos,
291 | 'frag_v_next': v_next,
292 | 'frag_R_next': R_next
293 | })
294 | else:
295 | eps_pos1, eps_d1, recon_x1, v_next1, R_next1 = self.update_frag_geometry(
296 | final_x, final_h, final_h_bond,
297 | edge_index, inner_edge_mask, batch_node, (fragment_mask == 1), node_t_emb, R1_noisy)
298 | eps_pos2, eps_d2, recon_x2, v_next2, R_next2 = self.update_frag_geometry(
299 | final_x, final_h, final_h_bond,
300 | edge_index, inner_edge_mask, batch_node, (fragment_mask == 2), node_t_emb, R2_noisy)
301 |
302 |                 # zero-center so the combined update keeps the midpoint of the two fragment centers fixed
303 | if not self.pred_frag_dist:
304 | if self.tr_output_type == 'invariant_eps':
305 | center = (eps_pos1 + eps_pos2) / 2
306 | eps_pos1, eps_pos2 = eps_pos1 - center, eps_pos2 - center
307 | elif 'newton_equation' in self.tr_output_type:
308 | center = (recon_x1 + recon_x2) / 2
309 | recon_x1, recon_x2 = recon_x1 - center, recon_x2 - center
310 | else:
311 | raise ValueError(self.tr_output_type)
312 |
313 | outputs.update({
314 | 'frag_eps_pos': (eps_pos1, eps_pos2),
315 | 'frag_eps_d': eps_d1 + eps_d2 if self.pred_frag_dist else None,
316 | 'frag_recon_x': (recon_x1, recon_x2),
317 | 'frag_v_next': (v_next1, v_next2),
318 | 'frag_R_next': (R_next1, R_next2)
319 | })
320 |
321 | return outputs
322 |
--------------------------------------------------------------------------------
/scripts/baselines/eval_3dlinker.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torch
5 | from rdkit import Chem, RDLogger
6 | from torch_geometric.transforms import Compose
7 | from tqdm.auto import tqdm
8 |
9 | import utils.misc as misc
10 | import utils.transforms as trans
11 | from datasets.linker_dataset import get_linker_dataset
12 | from utils.evaluation import LinkerEvaluator
13 | from utils.visualize import *
14 |
15 |
16 | if __name__ == '__main__':
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('sample_path', type=str)
19 | parser.add_argument('--num_eval', type=int, default=400)
20 | parser.add_argument('--save_path', type=str)
21 | parser.add_argument('--config_path', type=str, default='configs/sampling/zinc.yml')
22 |
23 | args = parser.parse_args()
24 | RDLogger.DisableLog('rdApp.*')
25 |
26 | mols = []
27 | num_valid_mols = 0
28 | for fidx in range(args.num_eval):
29 | sdf_path = os.path.join(args.sample_path, f'data_{fidx}.sdf')
30 | if os.path.exists(sdf_path):
31 |             supp = Chem.SDMolSupplier(sdf_path, sanitize=False)
32 |             mols.append(list(supp))
33 | num_valid_mols += 1
34 | else:
35 | mols.append([])
36 | print('Num datapoints: ', args.num_eval)
37 | print('Num valid mols: ', num_valid_mols)
38 |
39 | config = misc.load_config(args.config_path)
40 | # Transforms
41 | atom_featurizer = trans.FeaturizeAtom(
42 | config.dataset.name, known_anchor=False, add_atom_type=True, add_atom_feat=True)
43 | graph_builder = trans.BuildCompleteGraph(known_linker_bond=False)
44 | test_transform = Compose([
45 | atom_featurizer,
46 | trans.SelectCandAnchors(mode='k-hop', k=1),
47 | graph_builder,
48 | trans.StackFragLocalPos(max_num_atoms=config.dataset.get('max_num_atoms', 30)),
49 | # trans.RelativeGeometry(mode=cfg_model.get('rel_geometry', 'distance_and_two_rot'))
50 | ])
51 | dataset, subsets = get_linker_dataset(
52 | cfg=config.dataset,
53 | transform_map={'train': None, 'val': None, 'test': test_transform}
54 | )
55 | test_set = subsets['test']
56 |
57 | all_results = []
58 | for i in range(args.num_eval):
59 | ref_data = test_set[i]
60 | gen_mols = mols[i]
61 | if len(gen_mols) == 0:
62 | continue
63 |
64 | num_frag_atoms = sum(ref_data.fragment_mask > 0)
65 | gen_data_list = []
66 | for mol in gen_mols:
67 | if mol is None:
68 | gen_data_list.append(None)
69 | continue
70 | gen_data = ref_data.clone()
71 | gen_data['pos'] = torch.from_numpy(mol.GetConformer().GetPositions().astype(np.float32))
72 | num_linker_atoms = len(gen_data['pos']) - num_frag_atoms
73 | if num_linker_atoms < 0:
74 | gen_data_list.append(None)
75 | continue
76 | gen_data['fragment_mask'] = torch.cat(
77 | [ref_data.fragment_mask[:num_frag_atoms], torch.zeros(num_linker_atoms).long()])
78 | gen_data['linker_mask'] = (gen_data['fragment_mask'] == 0)
79 | gen_data_list.append(gen_data)
80 |
81 | results = {
82 | 'gen_mols': gen_mols,
83 | 'ref_data': ref_data,
84 | 'data_list': gen_data_list
85 | }
86 | all_results.append(results)
87 |
88 | evaluator = LinkerEvaluator(all_results, reconstruct=False)
89 | evaluator.evaluate()
90 | if args.save_path is not None:
91 | evaluator.save_metrics(args.save_path)
92 |
--------------------------------------------------------------------------------
/scripts/baselines/eval_delinker.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 | import torch
5 | from rdkit import Chem, RDLogger
6 | from torch_geometric.transforms import Compose
7 | from tqdm.auto import tqdm
8 |
9 | import utils.misc as misc
10 | import utils.transforms as trans
11 | from datasets.linker_dataset import get_linker_dataset
12 | from utils.evaluation import LinkerEvaluator, standardise_linker, remove_dummys_mol
13 | from utils.visualize import *
14 | from utils import frag_utils
15 |
16 |
17 | if __name__ == '__main__':
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('sample_path', type=str)
20 | parser.add_argument('--num_eval', type=int, default=400)
21 | parser.add_argument('--num_samples', type=int, default=250)
22 | parser.add_argument('--save_path', type=str)
23 | parser.add_argument('--config_path', type=str, default='configs/sampling/zinc.yml')
24 |
25 | args = parser.parse_args()
26 | RDLogger.DisableLog('rdApp.*')
27 |
28 | mols = []
29 | with open(args.sample_path, 'r') as f:
30 | for line in tqdm(f.readlines()):
31 | parts = line.strip().split(' ')
32 | data = {
33 | 'fragments': parts[0],
34 | 'true_molecule': parts[1],
35 | 'pred_molecule': parts[2],
36 | 'pred_linker': parts[3] if len(parts) > 3 else '',
37 | }
38 | mols.append(Chem.MolFromSmiles(data['pred_molecule']))
39 | print('Num data: ', len(mols))
40 |
41 | config = misc.load_config(args.config_path)
42 | # Transforms
43 | atom_featurizer = trans.FeaturizeAtom(
44 | config.dataset.name, known_anchor=False, add_atom_type=True, add_atom_feat=True)
45 | graph_builder = trans.BuildCompleteGraph(known_linker_bond=False)
46 | test_transform = Compose([
47 | atom_featurizer,
48 | trans.SelectCandAnchors(mode='k-hop', k=1),
49 | graph_builder,
50 | trans.StackFragLocalPos(max_num_atoms=config.dataset.get('max_num_atoms', 30)),
51 | # trans.RelativeGeometry(mode=cfg_model.get('rel_geometry', 'distance_and_two_rot'))
52 | ])
53 | dataset, subsets = get_linker_dataset(
54 | cfg=config.dataset,
55 | transform_map={'train': None, 'val': None, 'test': test_transform}
56 | )
57 | test_set = subsets['test']
58 |
59 | all_results = []
60 | for i in range(args.num_eval):
61 | ref_data = test_set[i]
62 | gen_mols = [mols[idx] for idx in range(args.num_samples * i, args.num_samples * (i + 1))]
63 |
64 | num_frag_atoms = sum(ref_data.fragment_mask > 0)
65 | gen_data_list = []
66 | for mol in gen_mols:
67 | if mol is None:
68 | gen_data_list.append(None)
69 | continue
70 | gen_data = ref_data.clone()
71 | gen_num_atoms = mol.GetNumAtoms()
72 | # gen_data['pos'] = torch.from_numpy(mol.GetConformer().GetPositions().astype(np.float32))
73 | num_linker_atoms = gen_num_atoms - num_frag_atoms
74 | if num_linker_atoms < 0:
75 | gen_data_list.append(None)
76 | continue
77 | # gen_data['fragment_mask'] = torch.cat(
78 | # [ref_data.fragment_mask[:num_frag_atoms], torch.zeros(num_linker_atoms).long()])
79 | # gen_data['linker_mask'] = (gen_data['fragment_mask'] == 0)
80 | gen_data_list.append(gen_data)
81 |
82 | results = {
83 | 'gen_mols': gen_mols,
84 | 'ref_data': ref_data,
85 | 'data_list': gen_data_list
86 | }
87 | all_results.append(results)
88 |
89 |     # fix the mismatch between generated mols and the reference mol by re-reading the SMILES file
90 | data_list = []
91 | with open(args.sample_path, 'r') as f:
92 | for line in tqdm(f.readlines()):
93 | parts = line.strip().split(' ')
94 | data = {
95 | 'fragments': parts[0],
96 | 'true_molecule': parts[1],
97 | 'pred_molecule': parts[2],
98 | 'pred_linker': parts[3] if len(parts) > 3 else '',
99 | }
100 | data_list.append(data)
101 |
102 | valid_all_results = []
103 | for i in range(args.num_eval):
104 | valid_results = []
105 | gen_data_list = [data_list[idx] for idx in range(args.num_samples * i, args.num_samples * (i + 1))]
106 | for gen_data in gen_data_list:
107 | gen_smi = gen_data['pred_molecule']
108 | ref_smi = gen_data['true_molecule']
109 | raw_frag_smi = gen_data['fragments']
110 | frag_smi = Chem.MolToSmiles(remove_dummys_mol(Chem.MolFromSmiles(raw_frag_smi)))
111 |             try:
112 |                 # keep only chemically valid generated molecules
113 |                 Chem.SanitizeMol(Chem.MolFromSmiles(gen_smi), sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
114 |             except Exception:
115 |                 continue
116 |
117 | # gen_mols should contain both fragments
118 |             gen_mol, frag_mol = Chem.MolFromSmiles(gen_smi), Chem.MolFromSmiles(frag_smi)
119 |             if len(gen_mol.GetSubstructMatch(frag_mol)) != frag_mol.GetNumAtoms():
120 | continue
121 |
122 | # Determine linkers of generated molecules
123 | try:
124 | linker = frag_utils.get_linker(Chem.MolFromSmiles(gen_smi), Chem.MolFromSmiles(frag_smi), frag_smi)
125 | linker_smi = standardise_linker(linker)
126 |             except Exception:
127 | continue
128 |
129 | valid_results.append({
130 | 'ref_smi': ref_smi,
131 | 'frag_smi': frag_smi,
132 | 'gen_smi': gen_smi,
133 | 'linker_smi': linker_smi,
134 | 'metrics': {}
135 | })
136 | valid_all_results.append(valid_results)
137 |
138 | evaluator = LinkerEvaluator(all_results, reconstruct=False)
139 | evaluator.gen_all_results = valid_all_results
140 | evaluator.validity(args.num_samples)
141 |
142 | evaluator.evaluate(eval_3d=False)
143 | if args.save_path is not None:
144 | evaluator.save_metrics(args.save_path)
145 |
146 |
147 |
--------------------------------------------------------------------------------
/scripts/baselines/eval_difflinker.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torch
5 | from rdkit import Chem, RDLogger
6 | from torch_geometric.transforms import Compose
7 | from tqdm.auto import tqdm
8 |
9 | import utils.misc as misc
10 | import utils.transforms as trans
11 | from datasets.linker_dataset import get_linker_dataset
12 | from utils.evaluation import LinkerEvaluator
13 | from utils.visualize import *
14 |
15 |
16 | if __name__ == '__main__':
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('sample_path', type=str)
19 | parser.add_argument('--num_eval', type=int, default=None)
20 | parser.add_argument('--num_samples', type=int, default=100)
21 | parser.add_argument('--save_path', type=str)
22 | parser.add_argument('--config_path', type=str, default='configs/sampling/zinc.yml')
23 |
24 | args = parser.parse_args()
25 | RDLogger.DisableLog('rdApp.*')
26 |
27 | if args.sample_path.endswith('.sdf'):
28 | mols = Chem.SDMolSupplier(args.sample_path)
29 | if args.num_eval is None:
30 | args.num_eval = len(mols) // args.num_samples
31 |         print('Loading sdf done! Total mols: ', len(mols))
32 | else:
33 | mols = []
34 | folders = sorted(os.listdir(args.sample_path), key=lambda x: int(x))
35 | if args.num_eval:
36 | folders = folders[:args.num_eval]
37 | else:
38 | args.num_eval = len(folders)
39 | num_valid_mols = 0
40 | for folder in tqdm(folders, desc='Load molecules'):
41 | sdf_files = [fn for fn in os.listdir(os.path.join(args.sample_path, folder)) if fn.endswith('.sdf')]
42 | # for fname in sorted(sdf_files, key=lambda x: int(x[:-5])):
43 |         for fidx in range(args.num_samples):
44 |             sdf_path = os.path.join(args.sample_path, folder, str(fidx) + '_.sdf')
45 | if os.path.exists(sdf_path):
46 | supp = Chem.SDMolSupplier(sdf_path, sanitize=False)
47 | mol = list(supp)[0]
48 | mols.append(mol)
49 | num_valid_mols += 1
50 | else:
51 | mols.append(None)
52 | print('Num datapoints: ', args.num_eval)
53 | print('Num valid mols: ', num_valid_mols)
54 |
55 | config = misc.load_config(args.config_path)
56 | # Transforms
57 | atom_featurizer = trans.FeaturizeAtom(
58 | config.dataset.name, known_anchor=False, add_atom_type=True, add_atom_feat=True)
59 | graph_builder = trans.BuildCompleteGraph(known_linker_bond=False)
60 | test_transform = Compose([
61 | atom_featurizer,
62 | trans.SelectCandAnchors(mode='k-hop', k=1),
63 | graph_builder,
64 | trans.StackFragLocalPos(max_num_atoms=config.dataset.get('max_num_atoms', 30)),
65 | # trans.RelativeGeometry(mode=cfg_model.get('rel_geometry', 'distance_and_two_rot'))
66 | ])
67 | dataset, subsets = get_linker_dataset(
68 | cfg=config.dataset,
69 | transform_map={'train': None, 'val': None, 'test': test_transform}
70 | )
71 | test_set = subsets['test']
72 |
73 | all_results = []
74 | for i in range(args.num_eval):
75 | ref_data = test_set[i]
76 | gen_mols = [mols[idx] for idx in range(args.num_samples * i, args.num_samples * (i + 1))]
77 | num_frag_atoms = sum(ref_data.fragment_mask > 0)
78 | gen_data_list = []
79 | for mol in gen_mols:
80 | if mol is None:
81 | gen_data_list.append(None)
82 | continue
83 | gen_data = ref_data.clone()
84 | gen_data['pos'] = torch.from_numpy(mol.GetConformer().GetPositions().astype(np.float32))
85 | num_linker_atoms = len(gen_data['pos']) - num_frag_atoms
86 | if num_linker_atoms < 0:
87 | gen_data_list.append(None)
88 | continue
89 | gen_data['fragment_mask'] = torch.cat(
90 | [ref_data.fragment_mask[:num_frag_atoms], torch.zeros(num_linker_atoms).long()])
91 | gen_data['linker_mask'] = (gen_data['fragment_mask'] == 0)
92 | gen_data_list.append(gen_data)
93 |
94 | results = {
95 | 'gen_mols': gen_mols,
96 | 'ref_data': ref_data,
97 | 'data_list': gen_data_list
98 | }
99 | all_results.append(results)
100 |
101 | evaluator = LinkerEvaluator(all_results, reconstruct=False)
102 | evaluator.evaluate()
103 | if args.save_path is not None:
104 | evaluator.save_metrics(args.save_path)
105 |
--------------------------------------------------------------------------------
/scripts/eval_protac.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from rdkit import RDLogger
3 | import torch
5 | from glob import glob
6 | import os
7 | from tqdm.auto import tqdm
8 | from torch_geometric.transforms import Compose
9 | import utils.misc as misc
10 | import utils.transforms as trans
11 | from datasets.linker_dataset import get_linker_dataset
12 | from utils.evaluation import LinkerEvaluator
13 | from utils.visualize import *
14 |
15 |
16 | if __name__ == '__main__':
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('sample_path', type=str)
19 | parser.add_argument('--num_eval', type=int, default=None)
20 | parser.add_argument('--recon', type=eval, default=False)
21 | parser.add_argument('--save', type=eval, default=True)
22 | args = parser.parse_args()
23 | RDLogger.DisableLog('rdApp.*')
24 |
25 | if args.sample_path.endswith('.pt'):
26 |         print('There is only one sampling file to evaluate')
27 | results = [torch.load(args.sample_path)]
28 | else:
29 | results_file_list = sorted(glob(os.path.join(args.sample_path, 'sampling_*.pt')))
30 | if args.num_eval:
31 | results_file_list = results_file_list[:args.num_eval]
32 | print(f'There are {len(results_file_list)} files to evaluate')
33 | results = []
34 | for f in tqdm(results_file_list, desc='Load sampling files'):
35 | results.append(torch.load(f))
36 |
37 | if 'data_list' not in results[0].keys():
38 |         print('Cannot find data_list in result keys -- adding data_list based on the test set')
39 | config = misc.load_config('configs/sampling/zinc.yml')
40 | # Transforms
41 | atom_featurizer = trans.FeaturizeAtom(
42 | config.dataset.name, known_anchor=False, add_atom_type=True, add_atom_feat=True)
43 | graph_builder = trans.BuildCompleteGraph(known_linker_bond=False)
44 | test_transform = Compose([
45 | atom_featurizer,
46 | trans.SelectCandAnchors(mode='k-hop', k=2),
47 | graph_builder,
48 | trans.StackFragLocalPos(max_num_atoms=config.dataset.get('max_num_atoms', 30)),
49 | # trans.RelativeGeometry(mode=cfg_model.get('rel_geometry', 'distance_and_two_rot'))
50 | ])
51 | dataset, subsets = get_linker_dataset(
52 | cfg=config.dataset,
53 | transform_map={'train': None, 'val': None, 'test': test_transform}
54 | )
55 | test_set = subsets['test']
56 |
57 | all_results = []
58 | for i in range(len(results)):
59 | ref_data = test_set[i]
60 | gen_mols = results[i]['gen_mols']
61 | num_frag_atoms = sum(ref_data.fragment_mask > 0)
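   |         # Assumes 100 generated samples per reference datapoint, matching the
   |         # default --num_samples of scripts/sample_protac.py.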
62 | gen_data_list = [ref_data.clone() for _ in range(100)]
63 |
64 | new_results = {
65 | 'gen_mols': gen_mols,
66 | 'ref_data': ref_data,
67 | 'data_list': gen_data_list,
68 | 'final_x': results[i]['final_x'],
69 | 'final_c': results[i]['final_c'],
70 | 'final_bond': results[i]['final_bond']
71 | }
72 | all_results.append(new_results)
73 | results = all_results
74 |
75 | evaluator = LinkerEvaluator(results, reconstruct=args.recon)
76 | evaluator.evaluate()
77 | save_path = os.path.join(args.sample_path, 'summary.csv')
78 | if args.save:
79 | evaluator.save_metrics(save_path)
80 |
--------------------------------------------------------------------------------
/scripts/prepare_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | from tqdm.auto import tqdm
4 | from rdkit import Chem
5 | import torch
6 | import os
7 | import numpy as np
8 | from utils.data import compute_3d_coors_multiple, set_rdmol_positions
9 |
10 |
11 | def read_protac_file(file_path):
12 | with open(file_path, 'r') as f:
13 | lines = f.readlines()
14 | data = []
15 | for i, line in enumerate(lines):
16 | toks = line.strip().split(' ')
17 | if len(toks) == 4:
18 | smi_protac, smi_linker, smi_warhead, smi_ligase = toks
19 | else:
20 | raise ValueError("Incorrect input format.")
21 | data.append({'smi_protac': smi_protac,
22 | 'smi_linker': smi_linker,
23 | 'smi_warhead': smi_warhead,
24 | 'smi_ligase': smi_ligase})
25 | return data
26 |
27 |
28 | def validity_check(num_atoms, m1, m2, m3):
29 | if len(set(m1).intersection(set(m2))) != 0 or len(set(m2).intersection(set(m3))) != 0 or \
30 | len(set(m1).intersection(set(m3))) != 0 or len(m1) + len(m2) + len(m3) != num_atoms:
31 | return False
32 | else:
33 | return True
34 |
35 |
36 | def preprocess_protac(raw_file_path, save_path):
37 | if raw_file_path.endswith('.txt'):
38 | raw_data = read_protac_file(raw_file_path)
39 | else:
40 | with open(raw_file_path, 'rb') as f:
41 | raw_data = pickle.load(f)
42 | processed_data = []
43 | total = len(raw_data)
44 | for i, d in enumerate(tqdm(raw_data, desc='Generate 3D conformer')):
45 | smi_protac, smi_linker, smi_warhead, smi_ligase = d['smi_protac'], d['smi_linker'], \
46 | d['smi_warhead'], d['smi_ligase']
47 |
48 | # frag_0 (3d pos), frag_1 (3d pos), linker (3d pos)
49 | mol = Chem.MolFromSmiles(smi_protac)
50 | # generate 3d coordinates of mols
51 | pos, _ = compute_3d_coors_multiple(mol, maxIters=1000)
52 | if pos is None:
53 |             print('Conformer generation failed!')
54 | continue
55 | mol = set_rdmol_positions(mol, pos)
56 |
57 | if raw_file_path.endswith('.txt'):
58 | warhead_m = mol.GetSubstructMatch(Chem.MolFromSmiles(smi_warhead))
59 | ligase_m = mol.GetSubstructMatch(Chem.MolFromSmiles(smi_ligase))
60 | linker_m = mol.GetSubstructMatch(Chem.MolFromSmiles(smi_linker))
61 | d.update({
62 | 'atom_indices_warhead': warhead_m,
63 | 'atom_indices_ligase': ligase_m,
64 | 'atom_indices_linker': linker_m
65 | })
66 | else:
67 | warhead_m, ligase_m, linker_m = d['atom_indices_warhead'], d['atom_indices_ligase'], d['atom_indices_linker']
68 |
69 | valid = validity_check(mol.GetNumAtoms(), warhead_m, ligase_m, linker_m)
70 | if not valid:
71 |             print('Validity check failed!')
72 | continue
73 |
74 | processed_data.append({
75 | 'mol': mol,
76 | **d
77 | })
78 |
79 | print('Saving data')
80 | with open(save_path, 'wb') as f:
81 | pickle.dump(processed_data, f)
82 | print('Length raw data: \t%d' % total)
83 | print('Length processed data: \t%d' % len(processed_data))
84 |
85 |
86 | def preprocess_zinc_from_difflinker(raw_file_dir, save_path, mode):
87 | datasets = {}
88 | if mode == 'full':
89 | all_subsets = ['train', 'val', 'test']
90 | else:
91 | all_subsets = ['val', 'test']
92 | for subset in all_subsets:
93 | data_list = []
94 | n_fail = 0
95 | all_data = torch.load(os.path.join(raw_file_dir, f'zinc_final_{subset}.pt'), map_location='cpu')
96 | full_rdmols = Chem.SDMolSupplier(os.path.join(raw_file_dir, f'zinc_final_{subset}_mol.sdf'), sanitize=False)
97 | frag_rdmols = Chem.SDMolSupplier(os.path.join(raw_file_dir, f'zinc_final_{subset}_frag.sdf'), sanitize=False)
98 | link_rdmols = Chem.SDMolSupplier(os.path.join(raw_file_dir, f'zinc_final_{subset}_link.sdf'), sanitize=False)
99 |         assert len(all_data) == len(full_rdmols) == len(frag_rdmols) == len(link_rdmols)
100 | for i in tqdm(range(len(all_data)), desc=subset):
101 | data = all_data[i]
102 | mol, frag_mol, link_mol = full_rdmols[i], frag_rdmols[i], link_rdmols[i]
103 | # if mol is None or frag_mol is None or link_mol is None:
104 | # print('Fail i: ', i)
105 | # n_fail += 1
106 | # continue
107 | pos = data['positions']
108 | fragment_mask, linker_mask = data['fragment_mask'].bool(), data['linker_mask'].bool()
109 | # align full mol with positions, etc.
110 | mol_pos = mol.GetConformer().GetPositions()
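   |             # Match each dataset position to its nearest conformer atom; the
   |             # mapping must be a permutation (enforced by the assert below).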
111 | mapping = np.linalg.norm(mol_pos[None] - data['positions'].numpy()[:, None], axis=-1).argmin(axis=1)
112 | assert len(np.unique(mapping)) == len(mapping)
113 | new_mol = Chem.RenumberAtoms(mol, mapping.tolist())
114 | Chem.SanitizeMol(new_mol)
115 | # check frag mol and link mol are aligned
116 | assert np.allclose(frag_mol.GetConformer().GetPositions(), pos[fragment_mask].numpy())
117 | assert np.allclose(link_mol.GetConformer().GetPositions(), pos[linker_mask].numpy())
118 | # print(mapping)
119 |
120 | # print(data['anchors'].nonzero(as_tuple=True)[0].tolist())
121 | # Note: anchor atom index may be wrong!
122 | frag_mols = Chem.GetMolFrags(frag_mol, asMols=True, sanitizeFrags=False)
123 | assert len(frag_mols) == 2
124 |
125 | all_frag_atom_idx = set((fragment_mask == 1).nonzero()[:, 0].tolist())
126 | frag_atom_idx_list = []
127 | for m1 in new_mol.GetSubstructMatches(frag_mols[0]):
128 | for m2 in new_mol.GetSubstructMatches(frag_mols[1]):
129 | if len(set(m1).intersection(set(m2))) == 0 and set(m1).union(set(m2)) == all_frag_atom_idx:
130 | frag_atom_idx_list = [m1, m2]
131 | break
132 |
133 | try:
134 | assert len(frag_atom_idx_list) == 2 and all([x is not None and len(x) > 0 for x in frag_atom_idx_list])
135 |                 except AssertionError:
136 | print('Fail i: ', i)
137 | n_fail += 1
138 | continue
139 | new_fragment_mask = torch.zeros_like(fragment_mask).long()
140 | new_fragment_mask[list(frag_atom_idx_list[0])] = 1
141 | new_fragment_mask[list(frag_atom_idx_list[1])] = 2
142 |
143 | # extract frag mol directly from new_mol, in case the Kekulize error
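   |             # A bond crosses the fragment/linker boundary iff exactly one endpoint
   |             # is a fragment atom; (mask[start] > 0) == (mask[end] == 0) acts as XOR.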
144 | bond_ids = []
145 | for bond_idx, bond in enumerate(new_mol.GetBonds()):
146 | start = bond.GetBeginAtomIdx()
147 | end = bond.GetEndAtomIdx()
148 | if (new_fragment_mask[start] > 0) == (new_fragment_mask[end] == 0):
149 | bond_ids.append(bond_idx)
150 | assert len(bond_ids) == 2
151 | break_mol = Chem.FragmentOnBonds(new_mol, bond_ids, addDummies=False)
152 | frags = [f for f in Chem.GetMolFrags(break_mol, asMols=True)
153 | if f.GetNumAtoms() != link_mol.GetNumAtoms()
154 | or not np.allclose(f.GetConformer().GetPositions(), link_mol.GetConformer().GetPositions())]
155 | assert len(frags) == 2
156 | new_frag_mol = Chem.CombineMols(*frags)
157 | assert np.allclose(new_frag_mol.GetConformer().GetPositions(), frag_mol.GetConformer().GetPositions())
158 |
159 | data_list.append({
160 | 'id': data['uuid'],
161 | 'smiles': data['name'],
162 | 'mol': new_mol,
163 | 'frag_mol': new_frag_mol,
164 | 'link_mol': link_mol,
165 | 'fragment_mask': new_fragment_mask,
166 | 'atom_indices_f1': list(frag_atom_idx_list[0]),
167 | 'atom_indices_f2': list(frag_atom_idx_list[1]),
168 | 'linker_mask': linker_mask,
169 | # 'anchors': data['anchors'].bool()
170 | })
171 | print('n fail: ', n_fail)
172 | datasets[subset] = data_list
173 |
174 | print('Saving data')
175 | with open(save_path, 'wb') as f:
176 | pickle.dump(datasets, f)
177 | print('Length processed data: ', [f'{x}: {len(datasets[x])}' for x in datasets])
178 |
179 |
180 | if __name__ == "__main__":
181 | parser = argparse.ArgumentParser()
182 | parser.add_argument('--raw_path', type=str, required=True)
183 | parser.add_argument('--dataset', type=str, default='protac', choices=['protac', 'zinc_difflinker'])
184 | parser.add_argument('--dest', type=str, required=True)
185 | parser.add_argument('--mode', type=str, default='full', choices=['tiny', 'full'])
186 | args = parser.parse_args()
187 |
188 | if args.dataset == 'protac':
189 | preprocess_protac(args.raw_path, args.dest)
190 | elif args.dataset == 'zinc_difflinker':
191 | preprocess_zinc_from_difflinker(args.raw_path, args.dest, args.mode)
192 | else:
193 | raise NotImplementedError
194 |
--------------------------------------------------------------------------------
/scripts/sample_protac.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | from rdkit import RDLogger
7 | from torch_geometric.transforms import Compose
8 | from tqdm.auto import tqdm
9 |
10 | import utils.misc as misc
11 | import utils.transforms as trans
12 | from datasets.linker_dataset import get_linker_dataset
13 | from models.diff_protac_bond import DiffPROTACModel
14 | from utils.reconstruct_linker import parse_sampling_result, parse_sampling_result_with_bond
15 | from torch_geometric.data import Batch
16 | from utils.evaluation import eval_success_rate
17 | from utils.prior_num_atoms import setup_configs, sample_atom_num
18 |
19 |
20 | if __name__ == '__main__':
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('config', type=str)
23 | parser.add_argument('--ckpt_path', type=str)
24 | parser.add_argument('--subset', choices=['val', 'test'], default='val')
25 | parser.add_argument('--device', type=str, default='cuda:0')
26 | parser.add_argument('--start_id', type=int, default=0)
27 | parser.add_argument('--end_id', type=int, default=-1)
28 | parser.add_argument('--num_samples', type=int, default=100)
29 | parser.add_argument('--outdir', type=str, default='./outputs_test')
30 | parser.add_argument('--tag', type=str, default='')
31 | parser.add_argument('--save_traj', type=eval, default=False)
32 | parser.add_argument('--cand_anchors_mode', type=str, default='k-hop', choices=['k-hop', 'exact'])
33 | parser.add_argument('--cand_anchors_k', type=int, default=2)
34 | args = parser.parse_args()
35 | RDLogger.DisableLog('rdApp.*')
36 |
37 | # Load configs
38 | config = misc.load_config(args.config)
39 | config_name = os.path.basename(args.config)[:os.path.basename(args.config).rfind('.')]
40 | misc.seed_all(config.sample.seed)
41 | logger, writer, log_dir, ckpt_dir, vis_dir = misc.setup_logdir(
42 | args.config, args.outdir, mode='eval', tag=args.tag, create_dir=True)
43 | logger.info(args)
44 |
45 | # Load checkpoint
46 | ckpt_path = config.model.checkpoint if args.ckpt_path is None else args.ckpt_path
47 | ckpt = torch.load(ckpt_path, map_location=args.device)
48 |     logger.info(f'Successfully loaded the model! {ckpt_path}')
49 | cfg_model = ckpt['configs'].model
50 |
51 | # Transforms
52 | test_transform = Compose([
53 | trans.FeaturizeAtom(config.dataset.name, add_atom_feat=False)
54 | ])
55 | atom_featurizer = trans.FeaturizeAtom(
56 | config.dataset.name, known_anchor=cfg_model.known_anchor, add_atom_type=False, add_atom_feat=True)
57 | graph_builder = trans.BuildCompleteGraph(known_linker_bond=cfg_model.known_linker_bond,
58 | known_cand_anchors=config.sample.get('cand_bond_mask', False))
59 | init_transform = Compose([
60 | atom_featurizer,
61 | trans.SelectCandAnchors(mode=args.cand_anchors_mode, k=args.cand_anchors_k),
62 | graph_builder,
63 | trans.StackFragLocalPos(max_num_atoms=config.dataset.get('max_num_atoms', 30)),
64 | trans.RelativeGeometry(mode=cfg_model.get('rel_geometry', 'distance_and_two_rot'))
65 | ])
66 | logger.info(f'Init transform: {init_transform}')
67 |
68 | # Datasets and loaders
69 | logger.info('Loading dataset...')
70 | if config.dataset.split_mode == 'full':
71 | dataset = get_linker_dataset(
72 | cfg=config.dataset,
73 | transform_map=test_transform
74 | )
75 | test_set = dataset
76 | else:
77 | dataset, subsets = get_linker_dataset(
78 | cfg=config.dataset,
79 | transform_map={'train': None, 'val': test_transform, 'test': test_transform}
80 | )
81 | test_set = subsets[args.subset]
82 | logger.info(f'Test: {len(test_set)}')
83 |
84 | FOLLOW_BATCH = ['edge_type']
85 | COLLATE_EXCLUDE_KEYS = ['nbh_list']
86 |
87 | # Model
88 | logger.info('Building model...')
89 | model = DiffPROTACModel(
90 | cfg_model,
91 | num_classes=atom_featurizer.num_classes,
92 | num_bond_classes=graph_builder.num_bond_classes,
93 | atom_feature_dim=atom_featurizer.feature_dim,
94 | edge_feature_dim=graph_builder.bond_feature_dim
95 | ).to(args.device)
96 | logger.info('Num of parameters is %.2f M' % (np.sum([p.numel() for p in model.parameters()]) / 1e6))
97 | model.load_state_dict(ckpt['model'])
98 |     logger.info('Loaded model weights!')
99 |
100 | # Sampling
101 | logger.info(f'Begin sampling [{args.start_id}, {args.end_id})...')
102 | assert config.sample.num_atoms in ['ref', 'prior']
103 | if config.sample.num_atoms == 'prior':
104 | num_atoms_config = setup_configs(mode='frag_center_distance')
105 | else:
106 | num_atoms_config = None
107 |
108 | if args.end_id == -1:
109 | args.end_id = len(test_set)
110 | for idx in tqdm(range(args.start_id, args.end_id)):
111 | raw_data = test_set[idx]
112 | data_list = [raw_data.clone() for _ in range(args.num_samples)]
113 | # modify data list
114 | if num_atoms_config is None:
115 | new_data_list = [init_transform(data) for data in data_list]
116 | else:
117 | new_data_list = []
118 | for data in data_list:
119 | # sample num atoms
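   |                     # Draw the linker size from the fragment-distance-conditioned
   |                     # prior, then rebuild masks, bonds and positions with
   |                     # placeholder linker atoms of the sampled count.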
120 | dist = torch.floor(data.frags_d) # (B, )
121 | num_linker_atoms = sample_atom_num(dist, num_atoms_config).astype(int)
122 | num_f1_atoms = len(data.atom_indices_f1)
123 | num_f2_atoms = len(data.atom_indices_f2)
124 | num_f_atoms = num_f1_atoms + num_f2_atoms
125 | frag_pos = data.pos[data.fragment_mask > 0]
126 | frag_atom_type = data.atom_type[data.fragment_mask > 0]
127 | frag_bond_idx = (data.bond_index[0] < num_f_atoms) & (data.bond_index[1] < num_f_atoms)
128 |
129 | data.fragment_mask = torch.LongTensor([1] * num_f1_atoms + [2] * num_f2_atoms + [0] * num_linker_atoms)
130 | data.linker_mask = (data.fragment_mask == 0)
131 | data.anchor_mask = torch.cat([data.anchor_mask[:num_f_atoms], torch.zeros([num_linker_atoms]).long()])
132 | data.bond_index = data.bond_index[:, frag_bond_idx]
133 | data.bond_type = data.bond_type[frag_bond_idx]
134 | data.pos = torch.cat([frag_pos, torch.zeros([num_linker_atoms, 3])], dim=0)
135 | data.atom_type = torch.cat([frag_atom_type, torch.zeros([num_linker_atoms]).long()])
136 | new_data = init_transform(data)
137 | new_data_list.append(new_data)
138 |
139 | batch = Batch.from_data_list(
140 | new_data_list, follow_batch=FOLLOW_BATCH, exclude_keys=COLLATE_EXCLUDE_KEYS).to(args.device)
141 | traj_batch, final_x, final_c, final_bond = model.sample(
142 | batch,
143 | p_init_mode=cfg_model.frag_pos_prior,
144 | guidance_opt=config.sample.guidance_opt
145 | )
146 | if model.train_bond:
147 | gen_mols = parse_sampling_result_with_bond(
148 | new_data_list, final_x, final_c, final_bond, atom_featurizer,
149 | known_linker_bonds=cfg_model.known_linker_bond, check_validity=True)
150 | else:
151 | gen_mols = parse_sampling_result(new_data_list, final_x, final_c, atom_featurizer)
152 | save_path = os.path.join(log_dir, f'sampling_{idx:06d}.pt')
153 | save_dict = {
154 | 'ref_data': init_transform(raw_data),
155 |         'data_list': new_data_list,  # drop this key to reduce the size of outputs
156 | 'final_x': final_x, 'final_c': final_c, 'final_bond': final_bond,
157 | 'gen_mols': gen_mols
158 | }
159 | if args.save_traj:
160 | save_dict['traj'] = traj_batch
161 |
162 | torch.save(save_dict, save_path)
163 | logger.info('Sample done!')
164 |
165 | # Quick Eval
166 | recon_rate, complete_rate = [], []
167 | anchor_dists = []
168 | for idx in range(args.start_id, args.end_id):
169 | load_path = os.path.join(log_dir, f'sampling_{idx:06d}.pt')
170 | results = torch.load(load_path)
171 | rr, cr = eval_success_rate(results['gen_mols'])
172 | logger.info(f'idx: {idx} recon rate: {rr} complete rate: {cr}')
173 | recon_rate.append(rr)
174 | complete_rate.append(cr)
175 | logger.info(f'recon rate: mean: {np.mean(recon_rate):.4f} median: {np.median(recon_rate):.4f}')
176 | logger.info(f'complete rate: mean: {np.mean(complete_rate):.4f} median: {np.median(complete_rate):.4f}')
177 |
--------------------------------------------------------------------------------
/scripts/train_protac.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 |
5 | import numpy as np
6 | import torch.utils.tensorboard
7 | from rdkit import RDLogger
8 | from torch.nn.utils import clip_grad_norm_
9 | from torch_geometric.loader import DataLoader
10 | from torch_geometric.transforms import Compose
11 | from tqdm.auto import tqdm
12 | from torch_geometric.data import Batch
13 | import utils.misc as misc
14 | import utils.transforms as trans
15 | from datasets.linker_dataset import get_linker_dataset
16 | from models.diff_protac_bond import DiffPROTACModel
17 | from utils.evaluation import eval_success_rate
18 | from utils.reconstruct_linker import parse_sampling_result_with_bond, parse_sampling_result
19 | from utils.train import *
20 |
21 | torch.multiprocessing.set_sharing_strategy('file_system')
22 |
23 |
24 | if __name__ == '__main__':
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('config', type=str, default='./configs/training/zinc.yml')
27 | parser.add_argument('--overfit_one', type=eval, default=False)
28 | parser.add_argument('--device', type=str, default='cuda')
29 | parser.add_argument('--logdir', type=str, default='./logs')
30 | parser.add_argument('--train_report_iter', type=int, default=200)
31 | parser.add_argument('--sampling_iter', type=int, default=10000)
32 | args = parser.parse_args()
33 | RDLogger.DisableLog('rdApp.*')
34 |
35 | # Load configs
36 | config = misc.load_config(args.config)
37 | misc.seed_all(config.train.seed)
38 | logger, writer, log_dir, ckpt_dir, vis_dir = misc.setup_logdir(args.config, args.logdir)
39 | logger.info(args)
40 |
41 | # Transforms
42 | cfg_model = config.model
43 | atom_featurizer = trans.FeaturizeAtom(config.dataset.name, known_anchor=cfg_model.known_anchor)
44 | graph_builder = trans.BuildCompleteGraph(known_linker_bond=cfg_model.known_linker_bond)
45 | max_num_atoms = config.dataset.get('max_num_atoms', 30)
46 | train_transform = Compose([
47 | atom_featurizer,
48 | graph_builder,
49 | trans.StackFragLocalPos(max_num_atoms=max_num_atoms),
50 | trans.RelativeGeometry(mode=cfg_model.rel_geometry)
51 | ])
52 | test_transform = Compose([
53 | atom_featurizer,
54 | graph_builder,
55 | trans.StackFragLocalPos(max_num_atoms=max_num_atoms),
56 | trans.RelativeGeometry(mode=cfg_model.rel_geometry)
57 | ])
58 | logger.info(f'Relative Geometry: {cfg_model.rel_geometry}')
59 | logger.info(f'Train transform: {train_transform}')
60 | logger.info(f'Test transform: {test_transform}')
61 | if cfg_model.rel_geometry in ['two_pos_and_rot', 'relative_pos_and_rot']:
62 | assert cfg_model.frag_pos_prior is None
63 |
64 | # Datasets and loaders
65 | logger.info('Loading dataset...')
66 | dataset, subsets = get_linker_dataset(
67 | cfg=config.dataset,
68 | transform_map={'train': train_transform, 'val': test_transform, 'test': test_transform}
69 | )
70 |
71 | if config.dataset.version == 'tiny':
72 | if args.overfit_one:
73 | train_set, val_set, test_set = [subsets['val'][0].clone() for _ in range(config.train.batch_size)], \
74 | [subsets['val'][0]], [subsets['val'][0]]
75 | else:
76 | train_set, val_set, test_set = subsets['val'], subsets['val'], subsets['val']
77 | else:
78 | train_set, val_set, test_set = subsets['train'], subsets['val'], subsets['test']
79 | # train_set, val_set, test_set = [subsets['train'][697].clone() for _ in range(config.train.batch_size)], \
80 | # [subsets['val'][0]], [subsets['val'][0]]
81 | logger.info(f'Training: {len(train_set)} Validation: {len(val_set)} Test: {len(test_set)}')
82 |
83 | FOLLOW_BATCH = ['edge_type']
84 | COLLATE_EXCLUDE_KEYS = ['nbh_list']
85 | train_iterator = inf_iterator(DataLoader(
86 | train_set,
87 | batch_size=config.train.batch_size,
88 | shuffle=True,
89 | num_workers=config.train.num_workers,
90 | prefetch_factor=8,
91 | persistent_workers=True,
92 | follow_batch=FOLLOW_BATCH,
93 | exclude_keys=COLLATE_EXCLUDE_KEYS
94 | ))
95 | val_loader = DataLoader(val_set, config.train.batch_size, shuffle=False,
96 | follow_batch=FOLLOW_BATCH, exclude_keys=COLLATE_EXCLUDE_KEYS)
97 | test_loader = DataLoader(test_set, config.train.batch_size, shuffle=False,
98 | follow_batch=FOLLOW_BATCH, exclude_keys=COLLATE_EXCLUDE_KEYS)
99 |
100 | # Model
101 | logger.info('Building model...')
102 | model = DiffPROTACModel(
103 | config.model,
104 | num_classes=atom_featurizer.num_classes,
105 | num_bond_classes=graph_builder.num_bond_classes,
106 | atom_feature_dim=atom_featurizer.feature_dim,
107 | edge_feature_dim=graph_builder.bond_feature_dim
108 | ).to(args.device)
109 | logger.info('Num of parameters is %.2f M' % (np.sum([p.numel() for p in model.parameters()]) / 1e6))
110 | if config.train.get('ckpt_path', None):
111 | ckpt = torch.load(config.train.ckpt_path, map_location=args.device)
112 | model.load_state_dict(ckpt['model'])
113 | logger.info(f'load checkpoint from {config.train.ckpt_path}!')
114 |
115 | # Optimizer and scheduler
116 | optimizer = get_optimizer(config.train.optimizer, model)
117 | scheduler = get_scheduler(config.train.scheduler, optimizer)
118 | # if config.train.get('ckpt_path', None):
119 | # logger.info('Resuming optimizer states...')
120 | # optimizer.load_state_dict(ckpt['optimizer'])
121 | # logger.info('Resuming scheduler states...')
122 | # scheduler.load_state_dict(ckpt['scheduler'])
123 |
124 | def train(it, batch):
125 | optimizer.zero_grad()
126 | loss_dict = model.get_loss(batch, pos_noise_std=config.train.get('pos_noise_std', 0.))
127 | loss = sum_weighted_losses(loss_dict, config.train.loss_weights)
128 | loss_dict['overall'] = loss
129 | loss.backward()
130 | orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm,
131 |                                          error_if_nonfinite=True)  # ~5% of running time
132 | optimizer.step()
133 |
134 | # Logging
135 | log_losses(loss_dict, it, 'train', args.train_report_iter, logger, writer, others={
136 | 'grad': orig_grad_norm,
137 | 'lr': optimizer.param_groups[0]['lr']
138 | })
139 |
140 | def validate(it):
141 | loss_tape = ValidationLossTape()
142 | with torch.no_grad():
143 | model.eval()
144 | for i, batch in enumerate(tqdm(val_loader, desc='Validate', dynamic_ncols=True)):
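   |                 # Evaluate the loss at 10 evenly spaced diffusion timesteps for a
   |                 # lower-variance validation estimate.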
145 | for t in np.linspace(0, model.num_steps - 1, 10).astype(int):
146 | time_step = torch.tensor([t] * batch.num_graphs).to(args.device)
147 | batch = batch.to(args.device)
148 | loss_dict = model.get_loss(batch, t=time_step)
149 | loss = sum_weighted_losses(loss_dict, config.train.loss_weights)
150 | loss_dict['overall'] = loss
151 | loss_tape.update(loss_dict, 1)
152 |
153 | avg_loss = loss_tape.log(it, logger, writer, 'val')
154 | # Trigger scheduler
155 | if config.train.scheduler.type == 'plateau':
156 | scheduler.step(avg_loss)
157 | elif config.train.scheduler.type == 'warmup_plateau':
158 | scheduler.step_ReduceLROnPlateau(avg_loss)
159 | else:
160 | scheduler.step()
161 | return avg_loss
162 |
163 | def test(it):
164 | loss_tape = ValidationLossTape()
165 | with torch.no_grad():
166 | model.eval()
167 | for i, batch in enumerate(tqdm(test_loader, desc='Test', dynamic_ncols=True)):
168 | for t in np.linspace(0, model.num_steps - 1, 10).astype(int):
169 | time_step = torch.tensor([t] * batch.num_graphs).to(args.device)
170 | batch = batch.to(args.device)
171 | loss_dict = model.get_loss(batch, t=time_step)
172 | loss = sum_weighted_losses(loss_dict, config.train.loss_weights)
173 | loss_dict['overall'] = loss
174 | loss_tape.update(loss_dict, 1)
175 | avg_loss = loss_tape.log(it, logger, writer, 'test')
176 | return avg_loss
177 |
178 |
179 | try:
180 | model.train()
181 | best_loss, best_iter = None, None
182 | for it in range(1, config.train.max_iters + 1):
183 | # try:
184 | t1 = time.time()
185 | batch = next(train_iterator).to(args.device)
186 | t2 = time.time()
187 | # print('data processing time: ', t2 - t1)
188 | train(it, batch)
189 | if it % args.sampling_iter == 0:
190 | data = test_set[0]
191 | data_list = [data.clone() for _ in range(100)]
192 | batch = Batch.from_data_list(
193 | data_list, follow_batch=FOLLOW_BATCH, exclude_keys=COLLATE_EXCLUDE_KEYS).to(args.device)
194 | traj_batch, final_x, final_c, final_bond = model.sample(batch, p_init_mode=cfg_model.frag_pos_prior)
195 | if model.train_bond:
196 | gen_mols = parse_sampling_result_with_bond(
197 | data_list, final_x, final_c, final_bond, atom_featurizer,
198 | known_linker_bonds=cfg_model.known_linker_bond, check_validity=True)
199 | else:
200 | gen_mols = parse_sampling_result(data_list, final_x, final_c, atom_featurizer)
201 | save_path = os.path.join(vis_dir, f'sampling_results_{it}.pt')
202 | torch.save({
203 | 'data': data,
204 | # 'traj': traj_batch,
205 | 'final_x': final_x, 'final_c': final_c, 'final_bond': final_bond,
206 | 'gen_mols': gen_mols
207 | }, save_path)
208 | logger.info(f'dump sampling vis to {save_path}!')
209 | recon_rate, complete_rate = eval_success_rate(gen_mols)
210 | logger.info(f'recon rate: {recon_rate:.4f}')
211 | logger.info(f'complete rate: {complete_rate:.4f}')
212 | writer.add_scalar('sampling/recon_rate', recon_rate, it)
213 | writer.add_scalar('sampling/complete_rate', complete_rate, it)
214 | writer.flush()
215 |
216 | if it % config.train.val_freq == 0 or it == config.train.max_iters:
217 | val_loss = validate(it)
218 | ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it)
219 | torch.save({
220 | 'configs': config,
221 | 'model': model.state_dict(),
222 | 'optimizer': optimizer.state_dict(),
223 | 'scheduler': scheduler.state_dict(),
224 | 'iteration': it,
225 | }, ckpt_path)
226 | if best_loss is None or val_loss < best_loss:
227 | logger.info(f'[Validate] Best val loss achieved: {val_loss:.6f}')
228 | best_loss, best_iter = val_loss, it
229 | test(it)
230 | model.train()
231 |
232 | except KeyboardInterrupt:
233 | logger.info('Terminating...')
234 |
--------------------------------------------------------------------------------
/utils/calc_SC_RDKit.py:
--------------------------------------------------------------------------------
1 | import os
2 | from rdkit import Chem
3 | from rdkit.Chem import AllChem, rdShapeHelpers
4 | from rdkit.Chem.FeatMaps import FeatMaps
5 | from rdkit import RDConfig
6 |
7 | # Set up features to use in FeatureMap
8 | fdefName = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
9 | fdef = AllChem.BuildFeatureFactory(fdefName)
10 |
11 | fmParams = {}
12 | for k in fdef.GetFeatureFamilies():
13 | fparams = FeatMaps.FeatMapParams()
14 | fmParams[k] = fparams
15 |
16 | keep = ('Donor', 'Acceptor', 'NegIonizable', 'PosIonizable',
17 | 'ZnBinder', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe')
18 |
19 |
20 | def get_FeatureMapScore(query_mol, ref_mol):
21 | featLists = []
22 | for m in [query_mol, ref_mol]:
23 | rawFeats = fdef.GetFeaturesForMol(m)
24 |         # filter that list down to only include the ones we're interested in
25 | featLists.append([f for f in rawFeats if f.GetFamily() in keep])
26 | fms = [FeatMaps.FeatMap(feats=x, weights=[1] * len(x), params=fmParams) for x in featLists]
27 |     fms[0].scoreMode = FeatMaps.FeatMapScoreMode.Best
28 | fm_score = fms[0].ScoreFeats(featLists[1]) / min(fms[0].GetNumFeatures(), len(featLists[1]))
29 |
30 | return fm_score
31 |
32 |
33 | def calc_SC_RDKit_score(query_mol, ref_mol):
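   |     # SC_RDKit score: equal-weight combination of pharmacophore feature-map
   |     # overlap and volumetric overlap (1 - shape protrusion distance).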
34 | fm_score = get_FeatureMapScore(query_mol, ref_mol)
35 |
36 | protrude_dist = rdShapeHelpers.ShapeProtrudeDist(query_mol, ref_mol,
37 | allowReordering=False)
38 | SC_RDKit_score = 0.5*fm_score + 0.5*(1 - protrude_dist)
39 |
40 | return SC_RDKit_score
41 |
--------------------------------------------------------------------------------
/utils/const.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from rdkit import Chem
4 |
5 |
6 | TORCH_FLOAT = torch.float32
7 | TORCH_INT = torch.int8
8 |
9 | # #################################################################################### #
10 | # ####################################### ZINC ####################################### #
11 | # #################################################################################### #
12 |
13 | # Atom idx for one-hot encoding
14 | ATOM2IDX = {'C': 0, 'O': 1, 'N': 2, 'F': 3, 'S': 4, 'Cl': 5, 'Br': 6, 'I': 7}
15 | IDX2ATOM = {0: 'C', 1: 'O', 2: 'N', 3: 'F', 4: 'S', 5: 'Cl', 6: 'Br', 7: 'I'}
16 |
17 | # Atomic numbers (Z)
18 | CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53}
19 |
20 | # One-hot atom types
21 | NUMBER_OF_ATOM_TYPES = len(ATOM2IDX)
22 |
23 |
24 | # #################################################################################### #
25 | # ####################################### GEOM ####################################### #
26 | # #################################################################################### #
27 |
28 | # Atom idx for one-hot encoding
29 | GEOM_ATOM2IDX = {'C': 0, 'O': 1, 'N': 2, 'F': 3, 'S': 4, 'Cl': 5, 'Br': 6, 'I': 7, 'P': 8}
30 | GEOM_IDX2ATOM = {0: 'C', 1: 'O', 2: 'N', 3: 'F', 4: 'S', 5: 'Cl', 6: 'Br', 7: 'I', 8: 'P'}
31 |
32 | # Atomic numbers (Z)
33 | GEOM_CHARGES = {'C': 6, 'O': 8, 'N': 7, 'F': 9, 'S': 16, 'Cl': 17, 'Br': 35, 'I': 53, 'P': 15}
34 |
35 | # One-hot atom types
36 | GEOM_NUMBER_OF_ATOM_TYPES = len(GEOM_ATOM2IDX)
37 |
38 | # Dataset keys
39 | DATA_LIST_ATTRS = {
40 | 'uuid', 'name', 'fragments_smi', 'linker_smi', 'num_atoms'
41 | }
42 | DATA_ATTRS_TO_PAD = {
43 | 'positions', 'one_hot', 'charges', 'anchors', 'fragment_mask', 'linker_mask', 'pocket_mask', 'fragment_only_mask'
44 | }
45 | DATA_ATTRS_TO_ADD_LAST_DIM = {
46 | 'charges', 'anchors', 'fragment_mask', 'linker_mask', 'pocket_mask', 'fragment_only_mask'
47 | }
48 |
49 | # Distribution of linker size in train data
50 | LINKER_SIZE_DIST = {
51 | 4: 85540,
52 | 3: 113928,
53 | 6: 70946,
54 | 7: 30408,
55 | 5: 77671,
56 | 9: 5177,
57 | 10: 1214,
58 | 8: 12712,
59 | 11: 158,
60 | 12: 7,
61 | }
62 |
63 |
64 | # Bond lengths from:
65 | # http://www.wiredchemist.com/chemistry/data/bond_energies_lengths.html
66 | # And:
67 | # http://chemistry-reference.com/tables/Bond%20Lengths%20and%20Enthalpies.pdf
68 | BONDS_1 = {
69 | 'H': {
70 | 'H': 74, 'C': 109, 'N': 101, 'O': 96, 'F': 92,
71 | 'B': 119, 'Si': 148, 'P': 144, 'As': 152, 'S': 134,
72 | 'Cl': 127, 'Br': 141, 'I': 161
73 | },
74 | 'C': {
75 | 'H': 109, 'C': 154, 'N': 147, 'O': 143, 'F': 135,
76 | 'Si': 185, 'P': 184, 'S': 182, 'Cl': 177, 'Br': 194,
77 | 'I': 214
78 | },
79 | 'N': {
80 | 'H': 101, 'C': 147, 'N': 145, 'O': 140, 'F': 136,
81 | 'Cl': 175, 'Br': 214, 'S': 168, 'I': 222, 'P': 177
82 | },
83 | 'O': {
84 | 'H': 96, 'C': 143, 'N': 140, 'O': 148, 'F': 142,
85 | 'Br': 172, 'S': 151, 'P': 163, 'Si': 163, 'Cl': 164,
86 | 'I': 194
87 | },
88 | 'F': {
89 | 'H': 92, 'C': 135, 'N': 136, 'O': 142, 'F': 142,
90 | 'S': 158, 'Si': 160, 'Cl': 166, 'Br': 178, 'P': 156,
91 | 'I': 187
92 | },
93 | 'B': {
94 | 'H': 119, 'Cl': 175
95 | },
96 | 'Si': {
97 | 'Si': 233, 'H': 148, 'C': 185, 'O': 163, 'S': 200,
98 | 'F': 160, 'Cl': 202, 'Br': 215, 'I': 243,
99 | },
100 | 'Cl': {
101 | 'Cl': 199, 'H': 127, 'C': 177, 'N': 175, 'O': 164,
102 | 'P': 203, 'S': 207, 'B': 175, 'Si': 202, 'F': 166,
103 | 'Br': 214
104 | },
105 | 'S': {
106 | 'H': 134, 'C': 182, 'N': 168, 'O': 151, 'S': 204,
107 | 'F': 158, 'Cl': 207, 'Br': 225, 'Si': 200, 'P': 210,
108 | 'I': 234
109 | },
110 | 'Br': {
111 | 'Br': 228, 'H': 141, 'C': 194, 'O': 172, 'N': 214,
112 | 'Si': 215, 'S': 225, 'F': 178, 'Cl': 214, 'P': 222
113 | },
114 | 'P': {
115 | 'P': 221, 'H': 144, 'C': 184, 'O': 163, 'Cl': 203,
116 | 'S': 210, 'F': 156, 'N': 177, 'Br': 222
117 | },
118 | 'I': {
119 | 'H': 161, 'C': 214, 'Si': 243, 'N': 222, 'O': 194,
120 | 'S': 234, 'F': 187, 'I': 266
121 | },
122 | 'As': {
123 | 'H': 152
124 | }
125 | }
126 |
127 | BONDS_2 = {
128 | 'C': {'C': 134, 'N': 129, 'O': 120, 'S': 160},
129 | 'N': {'C': 129, 'N': 125, 'O': 121},
130 | 'O': {'C': 120, 'N': 121, 'O': 121, 'P': 150},
131 | 'P': {'O': 150, 'S': 186},
132 | 'S': {'P': 186}
133 | }
134 |
135 | BONDS_3 = {
136 | 'C': {'C': 120, 'N': 116, 'O': 113},
137 | 'N': {'C': 116, 'N': 110},
138 | 'O': {'C': 113}
139 | }
140 |
141 | BOND_DICT = [
142 | None,
143 | Chem.rdchem.BondType.SINGLE,
144 | Chem.rdchem.BondType.DOUBLE,
145 | Chem.rdchem.BondType.TRIPLE,
146 | Chem.rdchem.BondType.AROMATIC,
147 | ]
148 |
149 | BOND2IDX = {
150 | Chem.rdchem.BondType.SINGLE: 1,
151 | Chem.rdchem.BondType.DOUBLE: 2,
152 | Chem.rdchem.BondType.TRIPLE: 3,
153 | Chem.rdchem.BondType.AROMATIC: 4,
154 | }
155 |
156 | ALLOWED_BONDS = {
157 | 'H': 1,
158 | 'C': 4,
159 | 'N': 3,
160 | 'O': 2,
161 | 'F': 1,
162 | 'B': 3,
163 | 'Al': 3,
164 | 'Si': 4,
165 | 'P': [3, 5],
166 | 'S': 4,
167 | 'Cl': 1,
168 | 'As': 3,
169 | 'Br': 1,
170 | 'I': 1,
171 | 'Hg': [1, 2],
172 | 'Bi': [3, 5]
173 | }
174 |
175 | MARGINS_EDM = [10, 5, 2]
176 |
177 | COLORS = ['C0', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8']
178 | # RADII = [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3]
179 | RADII = [0.77, 0.77, 0.77, 0.77, 0.77, 0.77, 0.77, 0.77, 0.77]
180 |
181 | ZINC_TRAIN_LINKER_ID2SIZE = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
182 | ZINC_TRAIN_LINKER_SIZE2ID = {
183 | size: idx
184 | for idx, size in enumerate(ZINC_TRAIN_LINKER_ID2SIZE)
185 | }
186 | ZINC_TRAIN_LINKER_SIZE_WEIGHTS = [
187 | 3.47347831e-01,
188 | 4.63079100e-01,
189 | 5.12370917e-01,
190 | 5.62392614e-01,
191 | 1.30294388e+00,
192 | 3.24247801e+00,
193 | 8.12391184e+00,
194 | 3.45634358e+01,
195 | 2.72428571e+02,
196 | 6.26585714e+03
197 | ]
198 |
199 |
200 | GEOM_TRAIN_LINKER_ID2SIZE = [
201 | 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
202 | 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 36, 38, 41
203 | ]
204 | GEOM_TRAIN_LINKER_SIZE2ID = {
205 | size: idx
206 | for idx, size in enumerate(GEOM_TRAIN_LINKER_ID2SIZE)
207 | }
208 | GEOM_TRAIN_LINKER_SIZE_WEIGHTS = [
209 | 1.07790681e+00, 4.54693604e-01, 3.62575713e-01, 3.75199484e-01,
210 | 3.67812588e-01, 3.92388528e-01, 3.83421054e-01, 4.26924670e-01,
211 | 4.92768040e-01, 4.99761944e-01, 4.92342726e-01, 5.71456905e-01,
212 | 7.30631393e-01, 8.45412928e-01, 9.97252243e-01, 1.25423985e+00,
213 | 1.57316129e+00, 2.19902962e+00, 3.22640431e+00, 4.25481066e+00,
214 | 6.34749573e+00, 9.00676236e+00, 1.43084017e+01, 2.25763173e+01,
215 | 3.36867096e+01, 9.50713805e+01, 2.08693274e+02, 2.51659537e+02,
216 | 7.77856749e+02, 8.55642424e+03, 8.55642424e+03, 4.27821212e+03,
217 | 4.27821212e+03
218 | ]
219 |
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import numpy as np
4 | from rdkit import Chem
5 | from rdkit.Chem import AllChem
6 | from rdkit.Chem.rdchem import BondType
7 | from rdkit.Chem import ChemicalFeatures
8 | from rdkit import RDConfig
9 | from openbabel import openbabel as ob
10 | import torch
11 |
12 |
13 | ATOM_FAMILIES = ['Acceptor', 'Donor', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe', 'NegIonizable', 'PosIonizable',
14 | 'ZnBinder']
15 | ATOM_FAMILIES_ID = {s: i for i, s in enumerate(ATOM_FAMILIES)}
16 | BOND_TYPES = {
17 | BondType.UNSPECIFIED: 0,
18 | BondType.SINGLE: 1,
19 | BondType.DOUBLE: 2,
20 | BondType.TRIPLE: 3,
21 | BondType.AROMATIC: 4,
22 | }
23 | BOND_NAMES = {v: str(k) for k, v in BOND_TYPES.items()}
24 | HYBRIDIZATION_TYPE = ['S', 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2']
25 | HYBRIDIZATION_TYPE_ID = {s: i for i, s in enumerate(HYBRIDIZATION_TYPE)}
26 |
27 |
28 | def process_from_mol(rdmol):
29 | fdefName = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
30 | factory = ChemicalFeatures.BuildFeatureFactory(fdefName)
31 |
32 | rd_num_atoms = rdmol.GetNumAtoms()
33 |     feat_mat = np.zeros([rd_num_atoms, len(ATOM_FAMILIES)], dtype=np.int64)
34 | for feat in factory.GetFeaturesForMol(rdmol):
35 | feat_mat[feat.GetAtomIds(), ATOM_FAMILIES_ID[feat.GetFamily()]] = 1
36 |
37 | # Get hybridization in the order of atom idx.
38 | hybridization = []
39 | for atom in rdmol.GetAtoms():
40 | hybr = str(atom.GetHybridization())
41 | idx = atom.GetIdx()
42 | hybridization.append((idx, hybr))
43 | hybridization = sorted(hybridization)
44 | hybridization = [v[1] for v in hybridization]
45 |
46 | ptable = Chem.GetPeriodicTable()
47 |
48 | pos = np.array(rdmol.GetConformers()[0].GetPositions(), dtype=np.float32)
49 | element, valence, charge = [], [], []
50 | accum_pos = 0
51 | accum_mass = 0
52 | for atom_idx in range(rd_num_atoms):
53 | atom = rdmol.GetAtomWithIdx(atom_idx)
54 | atom_num = atom.GetAtomicNum()
55 | element.append(atom_num)
56 | valence.append(atom.GetTotalValence())
57 | charge.append(atom.GetFormalCharge())
58 | atom_weight = ptable.GetAtomicWeight(atom_num)
59 | accum_pos += pos[atom_idx] * atom_weight
60 | accum_mass += atom_weight
61 | center_of_mass = accum_pos / accum_mass
62 |     element = np.array(element, dtype=np.int64)  # np.int was removed in NumPy 1.24+
63 |     valence = np.array(valence, dtype=np.int64)
64 |     charge = np.array(charge, dtype=np.int64)
65 |
66 | # in edge_type, we have 1 for single bond, 2 for double bond, 3 for triple bond, and 4 for aromatic bond.
67 | row, col, edge_type = [], [], []
68 | for bond in rdmol.GetBonds():
69 | start = bond.GetBeginAtomIdx()
70 | end = bond.GetEndAtomIdx()
71 | row += [start, end]
72 | col += [end, start]
73 | edge_type += 2 * [BOND_TYPES[bond.GetBondType()]]
74 |
75 |     edge_index = np.array([row, col], dtype=np.int64)  # np.long was removed in NumPy 1.24+
76 |     edge_type = np.array(edge_type, dtype=np.int64)
77 |
78 | perm = (edge_index[0] * rd_num_atoms + edge_index[1]).argsort()
79 | edge_index = edge_index[:, perm]
80 | edge_type = edge_type[perm]
81 |
82 | data = {
83 | 'rdmol': rdmol,
84 | 'element': element,
85 | 'pos': pos,
86 | 'bond_index': edge_index,
87 | 'bond_type': edge_type,
88 | 'center_of_mass': center_of_mass,
89 | 'atom_feature': feat_mat,
90 | 'hybridization': hybridization,
91 | 'valence': valence,
92 | 'charge': charge
93 | }
94 | return data
95 |
96 |
97 | # ----- conformer generation helpers
98 | def compute_3d_coors(mol, random_seed=0):
99 | mol = Chem.AddHs(mol)
100 | success = AllChem.EmbedMolecule(mol, randomSeed=random_seed)
101 | if success == -1:
102 |         return None, 0  # embedding failed (match the failure convention of the helpers below)
103 | mol = Chem.RemoveHs(mol)
104 | c = mol.GetConformer(0)
105 | pos = c.GetPositions()
106 | return pos, 1
107 |
108 |
109 | def compute_3d_coors_multiple(mol, numConfs=20, maxIters=400, randomSeed=1):
110 | # mol = Chem.MolFromSmiles(smi)
111 | mol = Chem.AddHs(mol)
112 | AllChem.EmbedMultipleConfs(mol, numConfs=numConfs, numThreads=0, randomSeed=randomSeed)
113 |     if mol.GetNumConformers() == 0:  # embedding produced no conformer
114 | return None, 0
115 | try:
116 | result = AllChem.MMFFOptimizeMoleculeConfs(mol, maxIters=maxIters, numThreads=0)
117 | except Exception as e:
118 | print(str(e))
119 | return None, 0
120 | mol = Chem.RemoveHs(mol)
121 | result = [tuple((result[i][0], result[i][1], i)) for i in range(len(result)) if result[i][0] == 0]
122 | if result == []: # no local minimum on energy surface is found
123 | return None, 0
124 | result.sort()
125 | return mol.GetConformers()[result[0][-1]].GetPositions(), 1
126 |
127 |
128 | def compute_3d_coors_frags(mol, numConfs=20, maxIters=400, randomSeed=1):
129 | du = Chem.MolFromSmiles('*')
130 |     clean_frag = Chem.RemoveHs(AllChem.ReplaceSubstructs(Chem.MolFromSmiles(Chem.MolToSmiles(mol)), du, Chem.MolFromSmiles('[H]'), True)[0])
131 |     frag = Chem.CombineMols(clean_frag, Chem.MolFromSmiles("*.*"))  # note: clean_frag / frag are not used below
132 | mol_to_link_carbon = AllChem.ReplaceSubstructs(mol, du, Chem.MolFromSmiles('C'), True)[0]
133 | pos, _ = compute_3d_coors_multiple(mol_to_link_carbon, numConfs, maxIters, randomSeed)
134 | return pos
135 |
136 |
137 | # -----
138 | def set_rdmol_positions_(mol, pos):
139 | """
140 | Args:
141 |         mol: An `rdkit.Chem.rdchem.Mol` object; a conformer is added to it in place.
142 | pos: (N_atoms, 3)
143 | """
144 | assert mol.GetNumAtoms() == pos.shape[0]
145 | conf = Chem.Conformer(mol.GetNumAtoms())
146 | for i in range(pos.shape[0]):
147 | conf.SetAtomPosition(i, pos[i].tolist())
148 | mol.AddConformer(conf, assignId=True)
149 |
150 | # for i in range(pos.shape[0]):
151 | # mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())
152 | return mol
153 |
154 |
155 | def set_rdmol_positions(rdkit_mol, pos, reset=True):
156 | """
157 | Args:
158 | rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.
159 | pos: (N_atoms, 3)
160 | """
161 | mol = copy.deepcopy(rdkit_mol)
162 | if reset:
163 | mol.RemoveAllConformers()
164 | set_rdmol_positions_(mol, pos)
165 | return mol
166 |
--------------------------------------------------------------------------------
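A minimal usage sketch for the helpers above; the SMILES is illustrative, and importing utils.data requires Open Babel as well as RDKit (the module imports openbabel at the top):

    from rdkit import Chem
    from utils.data import compute_3d_coors_multiple, set_rdmol_positions, process_from_mol

    mol = Chem.MolFromSmiles('CCOC(=O)c1ccccc1')           # any embeddable molecule
    pos, ok = compute_3d_coors_multiple(mol, numConfs=20)  # lowest-energy MMFF conformer
    if ok:
        mol3d = set_rdmol_positions(mol, pos)              # attach the conformer
        data = process_from_mol(mol3d)                     # element / pos / bonds / features
        print(data['element'], data['center_of_mass'])

--------------------------------------------------------------------------------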
/utils/eval_bond.py:
--------------------------------------------------------------------------------
1 | """Utils for evaluating bond length."""
2 |
3 | import collections
4 | from typing import Tuple, Sequence, Dict
5 | from rdkit.Chem.rdMolTransforms import GetAngleDeg
6 | from rdkit import Chem
7 | import numpy as np
8 | from utils.data import BOND_TYPES
9 |
10 | BondType = Tuple[int, int, int] # (atomic_num, atomic_num, bond_type)
11 | BondLengthData = Tuple[BondType, float] # (bond_type, bond_length)
12 | BondLengthProfile = Dict[BondType, np.ndarray] # bond_type -> empirical distribution
13 | DISTANCE_BINS = np.arange(1.1, 1.7, 0.005)[:-1]
14 | ANGLE_BINS = np.arange(0., 180., 1.)[:-1]
15 | DIHEDRAL_BINS = np.arange(-180., 180., 1.)[:-1]
16 |
17 | # bond distance
18 | BOND_DISTS = (
19 | (6, 6, 1),
20 | (6, 6, 2),
21 | (6, 7, 1),
22 | (6, 7, 2),
23 | (6, 8, 1),
24 | (6, 8, 2),
25 | (6, 6, 4),
26 | (6, 7, 4),
27 | (6, 8, 4),
28 | )
29 | BOND_ANGLES = [
30 | 'CCC',
31 | 'CCO',
32 | 'CNC',
33 | 'OPO',
34 | 'NCC',
35 | 'CC=O',
36 | 'COC'
37 | ]
38 | DIHEDRAL_ANGLES = [
39 | # 'C1C-C1C-C1C',
40 | # 'C12C-C12C-C12C',
41 | # 'C1C-C1C-C1O',
42 | # 'O1C-C1C-C1O',
43 | # 'C1c-c12c-c12c',
44 | # 'C1C-C2C-C1C'
45 | ]
46 |
47 |
48 | def get_bond_angle(mol, fragment_mask, bond_smi='CCC'):
49 | """
50 |     Find atom triples matching `bond_smi` in mol and return their bond angles in degrees.
51 |     bond_smi: SMILES pattern of the angle, e.g. 'CCC'
52 | """
53 | deg_list = []
54 | substructure = Chem.MolFromSmiles(bond_smi)
55 | bond_pairs = mol.GetSubstructMatches(substructure)
56 | for pair in bond_pairs:
57 | if (fragment_mask[pair[0]] == 0) | (fragment_mask[pair[1]] == 0) | (fragment_mask[pair[2]] == 0):
58 | deg_list += [GetAngleDeg(mol.GetConformer(), *pair)]
59 | assert mol.GetBondBetweenAtoms(pair[0], pair[1]) is not None
60 | assert mol.GetBondBetweenAtoms(pair[2], pair[1]) is not None
61 | return deg_list
62 |
63 |
64 | def get_bond_symbol(bond):
65 | """
66 | Return the symbol representation of a bond
67 | """
68 | a0 = bond.GetBeginAtom().GetSymbol()
69 | a1 = bond.GetEndAtom().GetSymbol()
70 | b = str(int(bond.GetBondType())) # single: 1, double: 2, triple: 3, aromatic: 12
71 | return ''.join([a0, b, a1])
72 |
73 |
74 | def get_triple_bonds(mol):
75 | """
76 | Get all the bond triplets in a molecule
77 | """
78 | valid_triple_bonds = []
79 | for idx_bond, bond in enumerate(mol.GetBonds()):
80 | idx_begin_atom = bond.GetBeginAtomIdx()
81 | idx_end_atom = bond.GetEndAtomIdx()
82 | begin_atom = mol.GetAtomWithIdx(idx_begin_atom)
83 | end_atom = mol.GetAtomWithIdx(idx_end_atom)
84 | begin_bonds = begin_atom.GetBonds()
85 | valid_left_bonds = []
86 | for begin_bond in begin_bonds:
87 | if begin_bond.GetIdx() == idx_bond:
88 | continue
89 | else:
90 | valid_left_bonds.append(begin_bond)
91 | if len(valid_left_bonds) == 0:
92 | continue
93 |
94 | end_bonds = end_atom.GetBonds()
95 | for end_bond in end_bonds:
96 | if end_bond.GetIdx() == idx_bond:
97 | continue
98 | else:
99 | for left_bond in valid_left_bonds:
100 | valid_triple_bonds.append([left_bond, bond, end_bond])
101 | return valid_triple_bonds
102 |
103 |
104 | def get_dihedral_angle(mol, bonds_ref_sym):
105 | """
106 |     Find bond triplets matching bonds_ref_sym in mol and return their dihedral angles in degrees.
107 | bonds_ref_sym: a symbol string of bond triplet, e.g. 'C1C-C1C-C1C'
108 | """
109 | # bonds_ref_sym = '-'.join(get_bond_symbol(bonds_ref))
110 | bonds_list = get_triple_bonds(mol)
111 | angles_list = []
112 | for bonds in bonds_list:
113 | sym = '-'.join([get_bond_symbol(b) for b in bonds])
114 | sym1 = '-'.join([get_bond_symbol(b) for b in bonds][::-1])
115 | atoms = []
116 | if (sym == bonds_ref_sym) or (sym1 == bonds_ref_sym):
117 | if sym1 == bonds_ref_sym:
118 | bonds = bonds[::-1]
119 | bond0 = bonds[0]
120 | atom0 = bond0.GetBeginAtomIdx()
121 | atom1 = bond0.GetEndAtomIdx()
122 |
123 | bond1 = bonds[1]
124 | atom1_0 = bond1.GetBeginAtomIdx()
125 | atom1_1 = bond1.GetEndAtomIdx()
126 | if atom0 == atom1_0:
127 | i, j, k = atom1, atom0, atom1_1
128 | elif atom0 == atom1_1:
129 | i, j, k = atom1, atom0, atom1_0
130 | elif atom1 == atom1_0:
131 | i, j, k = atom0, atom1, atom1_1
132 | elif atom1 == atom1_1:
133 | i, j, k = atom0, atom1, atom1_0
134 |
135 | bond2 = bonds[2]
136 | atom2_0 = bond2.GetBeginAtomIdx()
137 | atom2_1 = bond2.GetEndAtomIdx()
138 | if atom2_0 == k:
139 | l = atom2_1
140 | elif atom2_1 == k:
141 | l = atom2_0
142 | # print(i,j,k,l)
143 | angle = Chem.rdMolTransforms.GetDihedralDeg(mol.GetConformer(), i, j, k, l)
144 | angles_list.append(angle)
145 | return angles_list
146 |
147 |
148 | def bond_distance_from_mol(mol, fragment_mask):
149 |     # only consider bonds involving at least one linker atom (fragment_mask == 0)
150 | pos = mol.GetConformer().GetPositions()
151 | pdist = pos[None, :] - pos[:, None]
152 | pdist = np.sqrt(np.sum(pdist ** 2, axis=-1))
153 | all_distances = []
154 | for bond in mol.GetBonds():
155 | s_sym = bond.GetBeginAtom().GetAtomicNum()
156 | e_sym = bond.GetEndAtom().GetAtomicNum()
157 | s_idx, e_idx = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
158 | if (fragment_mask[s_idx] == 0) | (fragment_mask[e_idx] == 0):
159 | bond_type = BOND_TYPES[bond.GetBondType()]
160 | distance = pdist[s_idx, e_idx]
161 | all_distances.append(((s_sym, e_sym, bond_type), distance))
162 | return all_distances
163 |
164 |
165 | def _format_bond_type(bond_type: BondType) -> BondType:
166 | atom1, atom2, bond_category = bond_type
167 | if atom1 > atom2:
168 | atom1, atom2 = atom2, atom1
169 | return atom1, atom2, bond_category
170 |
171 |
172 | def get_distribution(distances: Sequence[float], bins=DISTANCE_BINS) -> np.ndarray:
173 | """Get the distribution of distances.
174 |
175 | Args:
176 | distances (list): List of distances.
177 | bins (list): bins of distances
178 | Returns:
179 |         np.array: empirical distribution over the bins, with length len(bins) + 1.
180 | """
181 | bin_counts = collections.Counter(np.searchsorted(bins, distances))
182 | bin_counts = [bin_counts[i] if i in bin_counts else 0 for i in range(len(bins) + 1)]
183 | bin_counts = np.array(bin_counts) / np.sum(bin_counts)
184 | return bin_counts
185 |
186 |
187 | def get_bond_length_profile(mol_list, fragment_mask_list) -> BondLengthProfile:
188 | bond_lengths = []
189 | for mol, mask in zip(mol_list, fragment_mask_list):
190 | mol = Chem.RemoveAllHs(mol)
191 | bond_lengths += bond_distance_from_mol(mol, mask)
192 |
193 | bond_length_profile = collections.defaultdict(list)
194 | for bond_type, bond_length in bond_lengths:
195 | bond_type = _format_bond_type(bond_type)
196 | bond_length_profile[bond_type].append(bond_length)
197 | bond_length_profile = {k: get_distribution(v) for k, v in bond_length_profile.items()}
198 | return bond_length_profile
199 |
200 |
201 | def get_bond_angles_dict(mol_list, fragment_mask_list):
202 | bond_angles = collections.defaultdict(list)
203 | dihedral_angles = collections.defaultdict(list)
204 | for mol, mask in zip(mol_list, fragment_mask_list):
205 | mol = Chem.RemoveAllHs(mol)
206 | for angle_type in BOND_ANGLES:
207 | bond_angles[angle_type] += get_bond_angle(mol, mask, bond_smi=angle_type)
208 | for angle_type in DIHEDRAL_ANGLES:
209 | dihedral_angles[angle_type] += get_dihedral_angle(mol, angle_type)
210 | return bond_angles, dihedral_angles
211 |
212 |
213 | def get_bond_angles_profile(mol_list, fragment_mask_list):
214 | angles, dihedrals = get_bond_angles_dict(mol_list, fragment_mask_list)
215 | angles_profile = {}
216 | for k, v in angles.items():
217 | angles_profile[k] = get_distribution(v, ANGLE_BINS)
218 | for k, v in dihedrals.items():
219 | angles_profile[k] = get_distribution(v, DIHEDRAL_BINS)
220 | return angles_profile
221 |
--------------------------------------------------------------------------------
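One way to use these profiles: compare generated and reference bond-length distributions per bond type with SciPy's Jensen-Shannon distance. A self-contained sketch with toy molecules (the all-zero masks simply treat every atom as linker):

    import numpy as np
    from rdkit import Chem
    from rdkit.Chem import AllChem
    from scipy.spatial.distance import jensenshannon
    from utils.eval_bond import get_bond_length_profile

    def embed(smi):
        mol = Chem.AddHs(Chem.MolFromSmiles(smi))
        AllChem.EmbedMolecule(mol, randomSeed=0)
        return Chem.RemoveHs(mol)

    ref, gen = [embed('CCCO')], [embed('CCCN')]
    ref_p = get_bond_length_profile(ref, [np.zeros(m.GetNumAtoms()) for m in ref])
    gen_p = get_bond_length_profile(gen, [np.zeros(m.GetNumAtoms()) for m in gen])
    for bt in ref_p.keys() & gen_p.keys():
        print(bt, 'JSD = %.3f' % jensenshannon(ref_p[bt], gen_p[bt]))

--------------------------------------------------------------------------------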
/utils/fpscores.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanjq/LinkerNet/280759c16ccecece0d81ab9cebe3f44041b80e51/utils/fpscores.pkl.gz
--------------------------------------------------------------------------------
/utils/geometry.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from copy import deepcopy
5 |
6 |
7 | def rotation_matrix_from_vectors(vec1, vec2):
8 | """ Find the rotation matrix that aligns vec1 to vec2
9 | :param vec1: A 3d "source" vector
10 | :param vec2: A 3d "destination" vector
11 | :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2.
12 | """
13 | a, b = (vec1 / torch.norm(vec1)).view(3), (vec2 / torch.norm(vec2)).view(3)
14 | v = torch.cross(a, b)
15 | c = torch.dot(a, b)
16 | s = torch.norm(v)
17 | kmat = torch.tensor([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
18 |     if s == 0:  # parallel vectors; note anti-parallel inputs also land here, though they would need a 180-degree rotation
19 | rotation_matrix = torch.eye(3)
20 | else:
21 | rotation_matrix = torch.eye(3) + kmat + kmat @ kmat * ((1 - c) / (s ** 2))
22 | return rotation_matrix
23 |
24 |
25 | def find_rigid_transform(true_points, mapping_points):
26 | # align one fragment to the other and find rigid transformation with Kabsch algorithm
27 | # mapping_points @ R.T + t = true_points
28 | t1 = true_points.mean(0)
29 | t2 = mapping_points.mean(0)
30 | x1 = true_points - t1
31 | x2 = mapping_points - t2
32 |
33 | h = x2.T @ x1
34 | u, s, vt = np.linalg.svd(h)
35 | v = vt.T
36 | d = np.linalg.det(v @ u.T)
37 | e = np.array([[1, 0, 0], [0, 1, 0], [0, 0, d]])
38 | R = v @ e @ u.T
39 | t = -R @ t2.T + t1.T
40 | return R, t
41 |
42 |
43 | def get_pca_axes(coords, weights=None):
44 | """Computes the (weighted) PCA of the given coordinates.
45 | Args:
46 | coords (np.ndarray): (N ,D)
47 | weights (np.ndarray, optional): (N). Defaults to None.
48 | Returns:
49 | Tuple[np.ndarray, np.ndarray]: (D), (D, D)
50 | """
51 | if weights is None:
52 | weights = np.ones((*coords.shape[:-1],))
53 | weights /= weights.sum()
54 |
55 |     mean = (weights[..., None] * coords).sum(0)  # weighted mean (weights sum to 1)
56 | centered = coords - mean
57 | cov = centered.T @ np.diag(weights) @ centered
58 |
59 | s, vecs = np.linalg.eigh(cov)
60 | return s, vecs
61 |
62 |
63 | def find_axes(atoms, charges):
64 | """Generates equivariant axes based on PCA.
65 | Args:
66 | atoms (np.ndarray): (..., M, 3)
67 | charges (np.ndarray): (M)
68 | Returns:
69 | np.ndarray: (3, 3)
70 | """
71 | atoms, charges = deepcopy(atoms), deepcopy(charges)
72 | # First compute the axes by PCA
73 | atoms = atoms - atoms.mean(-2, keepdims=True)
74 | s, axes = get_pca_axes(atoms, charges)
75 | # Let's check whether we have identical eigenvalues
76 |     # if that's the case we need to work with some pseudo positions
77 | # to get unique atoms.
78 | unique_values = np.zeros_like(s)
79 | v, uindex = np.unique(s, return_index=True)
80 | unique_values[uindex] = v
81 | is_ambiguous = np.count_nonzero(unique_values) < np.count_nonzero(s)
82 | # We always compute the pseudo coordinates because it causes some compile errors
83 | # for some unknown reason on A100 cards with jax.lax.cond.
84 |     # Compute pseudo coordinates based on the vector inducing the largest coulomb energy.
85 | distances = atoms[None] - atoms[..., None, :]
86 | dist_norm = np.linalg.norm(distances, axis=-1)
87 | coulomb = charges[None] * charges[:, None] / (dist_norm + 1e-20)
88 | off_diag_mask = ~np.eye(atoms.shape[0], dtype=bool)
89 | coulomb, distances = coulomb[off_diag_mask], distances[off_diag_mask]
90 | idx = np.argmax(coulomb)
91 | scale_vec = distances[idx]
92 | scale_vec /= np.linalg.norm(scale_vec)
93 | # Projected atom positions
94 | proj = atoms @ scale_vec[..., None] * scale_vec
95 | diff = atoms - proj
96 | pseudo_atoms = proj * (1 + 1e-4) + diff
97 |
98 | pseudo_s, pseudo_axes = get_pca_axes(pseudo_atoms, charges)
99 |
100 | # Select pseudo axes if it is ambiguous
101 | s = np.where(is_ambiguous, pseudo_s, s)
102 | axes = np.where(is_ambiguous, pseudo_axes, axes)
103 |
104 | order = np.argsort(s)[::-1]
105 | axes = axes[:, order]
106 |
107 | # Compute an equivariant vector
108 | distances = np.linalg.norm(atoms[None] - atoms[..., None, :], axis=-1)
109 | weights = distances.sum(-1)
110 | equi_vec = ((weights * charges)[..., None] * atoms).mean(0)
111 |
112 | ve = equi_vec @ axes
113 | flips = ve < 0
114 | axes = np.where(flips[None], -axes, axes)
115 |
116 | right_hand = np.stack(
117 | [axes[:, 0], axes[:, 1], np.cross(axes[:, 0], axes[:, 1])], axis=1)
118 | # axes = np.where(np.abs(ve[-1]) < 1e-7, right_hand, axes)
119 | return right_hand
120 |
121 |
122 | def local_to_global(R, t, p):
123 | """
124 | Description:
125 | Convert local (internal) coordinates to global (external) coordinates q.
126 | q <- Rp + t
127 | Args:
128 | R: (F, 3, 3).
129 | t: (F, 3).
130 | p: Local coordinates, (F, ..., 3).
131 | Returns:
132 | q: Global coordinates, (F, ..., 3).
133 | """
134 | assert p.size(-1) == 3
135 | assert R.ndim - 1 == t.ndim
136 | squeeze_dim = False
137 | if R.ndim == 2:
138 | R = R.unsqueeze(0)
139 | t = t.unsqueeze(0)
140 | p = p.unsqueeze(0)
141 | squeeze_dim = True
142 |
143 | p_size = p.size()
144 | num_frags = p_size[0]
145 |
146 | p = p.view(num_frags, -1, 3).transpose(-1, -2) # (F, *, 3) -> (F, 3, *)
147 | q = torch.matmul(R, p) + t.unsqueeze(-1) # (F, 3, *)
148 | q = q.transpose(-1, -2).reshape(p_size) # (F, 3, *) -> (F, *, 3) -> (F, ..., 3)
149 | if squeeze_dim:
150 | q = q.squeeze(0)
151 | return q
152 |
153 |
154 | def global_to_local(R, t, q):
155 | """
156 | Description:
157 | Convert global (external) coordinates q to local (internal) coordinates p.
158 | p <- R^{T}(q - t)
159 | Args:
160 | R: (F, 3, 3).
161 | t: (F, 3).
162 | q: Global coordinates, (F, ..., 3).
163 | Returns:
164 | p: Local coordinates, (F, ..., 3).
165 | """
166 | assert q.size(-1) == 3
167 | assert R.ndim - 1 == t.ndim
168 | squeeze_dim = False
169 | if R.ndim == 2:
170 | R = R.unsqueeze(0)
171 | t = t.unsqueeze(0)
172 | q = q.unsqueeze(0)
173 | squeeze_dim = True
174 |
175 | q_size = q.size()
176 | num_frags = q_size[0]
177 |
178 | q = q.reshape(num_frags, -1, 3).transpose(-1, -2) # (F, *, 3) -> (F, 3, *)
179 | p = torch.matmul(R.transpose(-1, -2), (q - t.unsqueeze(-1))) # (F, 3, *)
180 | p = p.transpose(-1, -2).reshape(q_size) # (F, 3, *) -> (F, *, 3) -> (F, ..., 3)
181 | if squeeze_dim:
182 | p = p.squeeze(0)
183 | return p
184 |
185 |
186 |
187 | # Copyright (c) Meta Platforms, Inc. and affiliates.
188 | # All rights reserved.
189 | #
190 | # This source code is licensed under the BSD-style license found in the
191 | # LICENSE file in the root directory of this source tree.
192 | def quaternion_to_rotation_matrix(quaternions):
193 | """
194 | Convert rotations given as quaternions to rotation matrices.
195 | Args:
196 | quaternions: quaternions with real part first,
197 | as tensor of shape (..., 4).
198 | Returns:
199 | Rotation matrices as tensor of shape (..., 3, 3).
200 | """
201 | quaternions = F.normalize(quaternions, dim=-1)
202 | r, i, j, k = torch.unbind(quaternions, -1)
203 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
204 |
205 | o = torch.stack(
206 | (
207 | 1 - two_s * (j * j + k * k),
208 | two_s * (i * j - k * r),
209 | two_s * (i * k + j * r),
210 | two_s * (i * j + k * r),
211 | 1 - two_s * (i * i + k * k),
212 | two_s * (j * k - i * r),
213 | two_s * (i * k - j * r),
214 | two_s * (j * k + i * r),
215 | 1 - two_s * (i * i + j * j),
216 | ),
217 | -1,
218 | )
219 | return o.reshape(quaternions.shape[:-1] + (3, 3))
220 |
221 |
222 | # Copyright (c) Meta Platforms, Inc. and affiliates.
223 | # All rights reserved.
224 | #
225 | # This source code is licensed under the BSD-style license found in the
226 | # LICENSE file in the root directory of this source tree.
227 | """
228 | BSD License
229 |
230 | For PyTorch3D software
231 |
232 | Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
233 |
234 | Redistribution and use in source and binary forms, with or without modification,
235 | are permitted provided that the following conditions are met:
236 |
237 | * Redistributions of source code must retain the above copyright notice, this
238 | list of conditions and the following disclaimer.
239 |
240 | * Redistributions in binary form must reproduce the above copyright notice,
241 | this list of conditions and the following disclaimer in the documentation
242 | and/or other materials provided with the distribution.
243 |
244 | * Neither the name Meta nor the names of its contributors may be used to
245 | endorse or promote products derived from this software without specific
246 | prior written permission.
247 |
248 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
249 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
250 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
251 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
252 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
253 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
254 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
255 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
256 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
257 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
258 | """
259 |
260 |
261 | def quaternion_1ijk_to_rotation_matrix(q):
262 | """
263 | (1 + ai + bj + ck) -> R
264 | Args:
265 | q: (..., 3)
266 | """
267 | b, c, d = torch.unbind(q, dim=-1)
268 | s = torch.sqrt(1 + b**2 + c**2 + d**2)
269 | a, b, c, d = 1/s, b/s, c/s, d/s
270 |
271 | o = torch.stack(
272 | (
273 | a**2 + b**2 - c**2 - d**2, 2*b*c - 2*a*d, 2*b*d + 2*a*c,
274 | 2*b*c + 2*a*d, a**2 - b**2 + c**2 - d**2, 2*c*d - 2*a*b,
275 | 2*b*d - 2*a*c, 2*c*d + 2*a*b, a**2 - b**2 - c**2 + d**2,
276 | ),
277 | -1,
278 | )
279 | return o.reshape(q.shape[:-1] + (3, 3))
280 |
281 |
282 | def apply_rotation_to_vector(R, p):
283 | return local_to_global(R, torch.zeros_like(p), p)
284 |
285 |
286 | def axis_angle_to_matrix(axis_angle):
287 | """
288 | From https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html
289 | Convert rotations given as axis/angle to rotation matrices.
290 |
291 | Args:
292 | axis_angle: Rotations given as a vector in axis angle form,
293 | as a tensor of shape (..., 3), where the magnitude is
294 | the angle turned anticlockwise in radians around the
295 | vector's direction.
296 |
297 | Returns:
298 | Rotation matrices as tensor of shape (..., 3, 3).
299 | """
300 | return quaternion_to_rotation_matrix(axis_angle_to_quaternion(axis_angle))
301 |
302 |
303 | def axis_angle_to_quaternion(axis_angle):
304 | """
305 | From https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html
306 | Convert rotations given as axis/angle to quaternions.
307 |
308 | Args:
309 | axis_angle: Rotations given as a vector in axis angle form,
310 | as a tensor of shape (..., 3), where the magnitude is
311 | the angle turned anticlockwise in radians around the
312 | vector's direction.
313 |
314 | Returns:
315 | quaternions with real part first, as tensor of shape (..., 4).
316 | """
317 | angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
318 | half_angles = 0.5 * angles
319 | eps = 1e-6
320 | small_angles = angles.abs() < eps
321 | sin_half_angles_over_angles = torch.empty_like(angles)
322 | sin_half_angles_over_angles[~small_angles] = (
323 | torch.sin(half_angles[~small_angles]) / angles[~small_angles]
324 | )
325 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6
326 | # so sin(x/2)/x is about 1/2 - (x*x)/48
327 | sin_half_angles_over_angles[small_angles] = (
328 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48
329 | )
330 | quaternions = torch.cat(
331 | [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
332 | )
333 | return quaternions
334 |
--------------------------------------------------------------------------------
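A quick self-check of the frame transforms above: mapping local coordinates to global and back should be the identity. A minimal sketch:

    import torch
    from utils.geometry import axis_angle_to_matrix, local_to_global, global_to_local

    R = axis_angle_to_matrix(torch.randn(5, 3))  # (F, 3, 3) random rotations
    t = torch.randn(5, 3)                        # (F, 3) translations
    p = torch.randn(5, 10, 3)                    # (F, n_atoms, 3) local coordinates
    q = local_to_global(R, t, p)
    assert torch.allclose(global_to_local(R, t, q), p, atol=1e-5)

--------------------------------------------------------------------------------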
/utils/guidance_funcs.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils import so3
4 | from utils.geometry import local_to_global
5 |
6 |
7 | def compute_frag_distance_loss(p1, p2, min_d, max_d, mode='frag_center_distance'):
8 | """
9 | :param p1: (B, 3). center of fragment 1.
10 | :param p2: (B, 3). center of fragment 2.
11 | :param min_d: min distance
12 | :param max_d: max distance
13 | :param mode: constrained distance mode
14 | :return
15 | """
16 | if mode == 'frag_center_distance':
17 | dist = torch.norm(p1 - p2, p=2, dim=-1) # (B, )
18 | loss = torch.mean(torch.clamp(min_d - dist, min=0) ** 2 + torch.clamp(dist - max_d, min=0) ** 2)
19 | else:
20 | raise ValueError(mode)
21 | return loss
22 |
23 |
24 | def compute_anchor_prox_loss(x_linker, v1, v2, p1, p2,
25 | frags_local_pos_filled, frags_local_pos_mask,
26 | fragment_mask, cand_anchors_mask, batch, min_d=1.2, max_d=1.9):
27 | """
28 | :param x_linker: (num_linker_atoms, 3)
29 | :param v1: (B, 3)
30 | :param v2: (B, 3)
31 | :param p1: (B, 3)
32 | :param p2: (B, 3)
33 | :param frags_local_pos_filled: (B, max_num_atoms, 3)
34 | :param frags_local_pos_mask: (B, max_num_atoms) of BoolTensor.
35 | :param fragment_mask: (N, )
36 | :param cand_anchors_mask: (N, ) of BoolTensor.
37 | :param batch: (N, )
38 | :param min_d
39 | :param max_d
40 | :return
41 | """
42 |
43 | # transform rotation and translation to positions
44 | R1 = so3.so3vec_to_rotation(v1)
45 | R2 = so3.so3vec_to_rotation(v2)
46 | f1_local_pos = frags_local_pos_filled[::2]
47 | f2_local_pos = frags_local_pos_filled[1::2]
48 | f1_local_pos_mask = frags_local_pos_mask[::2]
49 | f2_local_pos_mask = frags_local_pos_mask[1::2]
50 | x_f1 = local_to_global(R1, p1, f1_local_pos)[f1_local_pos_mask]
51 | x_f2 = local_to_global(R2, p2, f2_local_pos)[f2_local_pos_mask]
52 |
53 | f1_anchor_mask = cand_anchors_mask[fragment_mask == 1]
54 | f2_anchor_mask = cand_anchors_mask[fragment_mask == 2]
55 | f1_batch, f2_batch = batch[fragment_mask == 1], batch[fragment_mask == 2]
56 |
57 | # approach 1: distance constraints on fragments only (unreasonable)
58 | # c_a1 = scatter_mean(x_f1[f1_anchor_mask], f1_batch[f1_anchor_mask], dim=0) # (B, 3)
59 | # c_a2 = scatter_mean(x_f2[f2_anchor_mask], f2_batch[f2_anchor_mask], dim=0)
60 | # c_na1 = scatter_mean(x_f1[~f1_anchor_mask], f1_batch[~f1_anchor_mask], dim=0)
61 | # c_na2 = scatter_mean(x_f2[~f2_anchor_mask], f2_batch[~f2_anchor_mask], dim=0)
62 | # loss = 0.
63 | # d_a1_a2 = torch.norm(c_a1 - c_a2, p=2, dim=-1)
64 | # if c_na1.size(0) > 0:
65 | # d_na1_a2 = torch.norm(c_na1 - c_a2, p=2, dim=-1)
66 | # loss += torch.mean(torch.clamp(d_a1_a2 - d_na1_a2, min=0))
67 | # if c_na2.size(0) > 0:
68 | # d_a1_na2 = torch.norm(c_a1 - c_na2, p=2, dim=-1)
69 | # loss += torch.mean(torch.clamp(d_a1_a2 - d_a1_na2, min=0))
70 |
71 | # approach 2: min dist of (linker, anchor) can form bond, (linker, non-anchor) cannot form bond
72 | linker_batch = batch[fragment_mask == 0]
73 |
74 | num_graphs = batch.max().item() + 1
75 | batch_losses = 0.
76 | for idx in range(num_graphs):
77 | linker_pos = x_linker[linker_batch == idx]
78 | loss_f1 = compute_prox_loss(x_f1[f1_batch == idx], linker_pos, f1_anchor_mask[f1_batch == idx], min_d, max_d)
79 | loss_f2 = compute_prox_loss(x_f2[f2_batch == idx], linker_pos, f2_anchor_mask[f2_batch == idx], min_d, max_d)
80 | batch_losses += loss_f1 + loss_f2
81 |
82 | return batch_losses / num_graphs
83 |
84 |
85 | def compute_prox_loss(frags_pos, linker_pos, anchor_mask, min_d=1.2, max_d=1.9):
86 | pairwise_dist = torch.norm(frags_pos[anchor_mask].unsqueeze(1) - linker_pos.unsqueeze(0), p=2, dim=-1)
87 | min_dist = pairwise_dist.min()
88 | # 1.2 < min dist < 1.9
89 | loss_anchor = torch.mean(torch.clamp(min_d - min_dist, min=0) ** 2 + torch.clamp(min_dist - max_d, min=0) ** 2)
90 |
91 | # non anchor min dist > 1.9
92 | loss_non_anchor = 0.
93 | non_anchor_pairwise_dist = torch.norm(frags_pos[~anchor_mask].unsqueeze(1) - linker_pos.unsqueeze(0), p=2, dim=-1)
94 | if non_anchor_pairwise_dist.size(0) > 0:
95 | non_anchor_min_dist = non_anchor_pairwise_dist.min()
96 | loss_non_anchor = torch.mean(torch.clamp(max_d - non_anchor_min_dist, min=0) ** 2)
97 |
98 | loss = loss_anchor + loss_non_anchor
99 | # print(f'loss anchor: {loss_anchor} loss non anchor: {loss_non_anchor}')
100 | return loss
101 |
--------------------------------------------------------------------------------
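These losses are differentiable in the fragment poses, so their gradients can serve as a guidance signal during sampling. A minimal sketch with illustrative distance bounds:

    import torch
    from utils.guidance_funcs import compute_frag_distance_loss

    p1 = torch.randn(8, 3, requires_grad=True)  # centers of fragment 1 (batch of 8)
    p2 = torch.randn(8, 3, requires_grad=True)  # centers of fragment 2
    loss = compute_frag_distance_loss(p1, p2, min_d=8.0, max_d=12.0)
    loss.backward()
    # stepping along -p1.grad / -p2.grad moves the centers toward the allowed range

--------------------------------------------------------------------------------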
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 | import logging
5 | import torch
6 | import numpy as np
7 | import yaml
8 | from easydict import EasyDict
9 | import sys
10 | import shutil
11 | import torch.utils.tensorboard
12 |
13 |
14 | class BlackHole(object):
15 | def __setattr__(self, name, value):
16 | pass
17 | def __call__(self, *args, **kwargs):
18 | return self
19 | def __getattr__(self, name):
20 | return self
21 |
22 |
23 | def load_config(path):
24 | with open(path, 'r') as f:
25 | return EasyDict(yaml.safe_load(f))
26 |
27 |
28 | def get_logger(name, log_dir=None):
29 | logger = logging.getLogger(name)
30 | logger.setLevel(logging.DEBUG)
31 | formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')
32 |
33 | stream_handler = logging.StreamHandler(sys.stdout)
34 | stream_handler.setLevel(logging.DEBUG)
35 | stream_handler.setFormatter(formatter)
36 | logger.addHandler(stream_handler)
37 |
38 | if log_dir is not None:
39 | file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
40 | file_handler.setLevel(logging.DEBUG)
41 | file_handler.setFormatter(formatter)
42 | logger.addHandler(file_handler)
43 |
44 | return logger
45 |
46 |
47 | def get_new_log_dir(root='./logs', prefix='', tag=''):
48 | fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
49 | if prefix != '':
50 | fn = prefix + '_' + fn
51 | if tag != '':
52 | fn = fn + '_' + tag
53 | log_dir = os.path.join(root, fn)
54 | os.makedirs(log_dir)
55 | return log_dir
56 |
57 |
58 | def seed_all(seed):
59 | torch.manual_seed(seed)
60 | np.random.seed(seed)
61 | random.seed(seed)
62 |
63 |
64 | def log_hyperparams(writer, args):
65 | from torch.utils.tensorboard.summary import hparams
66 | vars_args = {k:v if isinstance(v, str) else repr(v) for k, v in vars(args).items()}
67 | exp, ssi, sei = hparams(vars_args, {})
68 | writer.file_writer.add_summary(exp)
69 | writer.file_writer.add_summary(ssi)
70 | writer.file_writer.add_summary(sei)
71 |
72 |
73 | def int_tuple(argstr):
74 | return tuple(map(int, argstr.split(',')))
75 |
76 |
77 | def str_tuple(argstr):
78 | return tuple(argstr.split(','))
79 |
80 |
81 | def count_parameters(model):
82 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
83 |
84 |
85 | def unique(x, dim=None):
86 | """Unique elements of x and indices of those unique elements
87 | https://github.com/pytorch/pytorch/issues/36748#issuecomment-619514810
88 |
89 | e.g.
90 |
91 | unique(tensor([
92 | [1, 2, 3],
93 | [1, 2, 4],
94 | [1, 2, 3],
95 | [1, 2, 5]
96 | ]), dim=0)
97 | => (tensor([[1, 2, 3],
98 | [1, 2, 4],
99 | [1, 2, 5]]),
100 | tensor([0, 1, 3]))
101 | """
102 | unique, inverse = torch.unique(
103 | x, sorted=True, return_inverse=True, dim=dim)
104 | perm = torch.arange(inverse.size(0), dtype=inverse.dtype,
105 | device=inverse.device)
106 | inverse, perm = inverse.flip([0]), perm.flip([0])
107 | return unique, inverse.new_empty(unique.size(dim)).scatter_(0, inverse, perm)
108 |
109 |
110 | def setup_logdir(config, logdir, mode='train', tag='', create_dir=True):
111 | # Logging
112 | config_name = os.path.basename(config)[:os.path.basename(config).rfind('.')]
113 | log_dir = get_new_log_dir(logdir, tag=tag, prefix=config_name) if create_dir else logdir
114 | if mode == 'train':
115 | ckpt_dir = os.path.join(log_dir, 'checkpoints')
116 | vis_dir = os.path.join(log_dir, 'vis')
117 | os.makedirs(ckpt_dir, exist_ok=True)
118 | os.makedirs(vis_dir, exist_ok=True)
119 | logger = get_logger('train', log_dir)
120 | writer = torch.utils.tensorboard.SummaryWriter(log_dir)
121 | shutil.copytree('./models', os.path.join(log_dir, 'models'))
122 | elif mode == 'eval':
123 | logger = get_logger('eval', log_dir)
124 | writer, ckpt_dir, vis_dir = None, None, None
125 | else:
126 | raise ValueError
127 | logger.info(config)
128 | shutil.copyfile(config, os.path.join(log_dir, os.path.basename(config)))
129 | return logger, writer, log_dir, ckpt_dir, vis_dir
130 |
--------------------------------------------------------------------------------
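A typical bootstrap with these helpers, assuming it is run from the repository root so the shipped config path resolves; the seed is illustrative:

    from utils.misc import load_config, seed_all, get_logger

    config = load_config('configs/sampling/zinc.yml')
    seed_all(2023)
    logger = get_logger('demo')
    logger.info('config: %s', config)

--------------------------------------------------------------------------------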
/utils/prior_num_atoms.py:
--------------------------------------------------------------------------------
1 | """Utils for sampling size of a linker."""
2 |
3 | import numpy as np
4 | import pickle
5 | from collections import Counter
6 |
7 |
8 | def setup_configs(meta_path='utils/prior_num_atoms.pkl', mode='frag_center_distance'):
9 | with open(meta_path, 'rb') as f:
10 | prior_meta = pickle.load(f)
11 | all_dist = prior_meta[mode]
12 | all_n_atoms = prior_meta['num_linker_atoms']
13 | bin_min, bin_max = np.floor(all_dist.min()), np.ceil(all_dist.max())
14 | BINS = np.arange(bin_min, bin_max, 1.)
15 | CONFIGS = {'bounds': BINS, 'distributions': []}
16 |
17 | for min_d, max_d in zip(BINS[:-1], BINS[1:]):
18 | valid_idx = (min_d < all_dist) & (all_dist < max_d)
19 | c = Counter(all_n_atoms[valid_idx])
20 | num_atoms_list, prob_list = list(c.keys()), np.array(list(c.values())) / np.sum(list(c.values()))
21 | CONFIGS['distributions'].append((num_atoms_list, prob_list))
22 | return CONFIGS
23 |
24 |
25 | def _get_bin_idx(distance, config_dict):
26 | bounds = config_dict['bounds']
27 | for i in range(len(bounds) - 1):
28 | if distance < bounds[i + 1]:
29 | return i
30 | return len(bounds) - 2
31 |
32 |
33 | def sample_atom_num(distance, config_dict):
34 | bin_idx = _get_bin_idx(distance, config_dict)
35 | num_atom_list, prob_list = config_dict['distributions'][bin_idx]
36 | return np.random.choice(num_atom_list, p=prob_list)
37 |
--------------------------------------------------------------------------------
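Usage sketch: sample a linker size conditioned on the fragment-center distance. This assumes the statistics file utils/prior_num_atoms.pkl has been produced beforehand; it is not part of the tree listed above.

    from utils.prior_num_atoms import setup_configs, sample_atom_num

    configs = setup_configs()                 # loads utils/prior_num_atoms.pkl
    n_atoms = sample_atom_num(10.0, configs)  # size for a 10 A center distance
    print('sampled linker size:', n_atoms)

--------------------------------------------------------------------------------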
/utils/reconstruct_linker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.spatial.distance import cdist
4 |
5 | from rdkit import Chem, Geometry
6 |
7 | from utils import const
8 | from copy import deepcopy
9 | import re
10 | import itertools
11 |
12 |
13 | class MolReconsError(Exception):
14 | pass
15 |
16 |
17 | def get_bond_order(atom1, atom2, distance, check_exists=True, margins=const.MARGINS_EDM):
18 |     distance = 100 * distance  # rescale so distances match the units of the bond-length tables in utils.const
19 |
20 | # Check exists for large molecules where some atom pairs do not have a
21 | # typical bond length.
22 | if check_exists:
23 | if atom1 not in const.BONDS_1:
24 | return 0
25 | if atom2 not in const.BONDS_1[atom1]:
26 | return 0
27 |
28 | # margin1, margin2 and margin3 have been tuned to maximize the stability of the QM9 true samples
29 | if distance < const.BONDS_1[atom1][atom2] + margins[0]:
30 |
31 | # Check if atoms in bonds2 dictionary.
32 | if atom1 in const.BONDS_2 and atom2 in const.BONDS_2[atom1]:
33 | thr_bond2 = const.BONDS_2[atom1][atom2] + margins[1]
34 | if distance < thr_bond2:
35 | if atom1 in const.BONDS_3 and atom2 in const.BONDS_3[atom1]:
36 | thr_bond3 = const.BONDS_3[atom1][atom2] + margins[2]
37 | if distance < thr_bond3:
38 | return 3 # Triple
39 | return 2 # Double
40 | return 1 # Single
41 | return 0 # No bond
42 |
43 |
44 | def create_conformer(coords):
45 | conformer = Chem.Conformer()
46 | for i, (x, y, z) in enumerate(coords):
47 | conformer.SetAtomPosition(i, Geometry.Point3D(x, y, z))
48 | return conformer
49 |
50 |
51 | def build_linker_mol(linker_x, linker_c, add_bonds=False):
52 | mol = Chem.RWMol()
53 | for atom in linker_c:
54 | a = Chem.Atom(int(atom))
55 | mol.AddAtom(a)
56 |
57 | if add_bonds:
58 | # predict bond order
59 | n = len(linker_x)
60 | ptable = Chem.GetPeriodicTable()
61 | dists = cdist(linker_x, linker_x, 'euclidean')
62 | E = torch.zeros((n, n), dtype=torch.int)
63 | A = torch.zeros((n, n), dtype=torch.bool)
64 | for i in range(n):
65 | for j in range(i):
66 | atom_i = ptable.GetElementSymbol(int(linker_c[i]))
67 | atom_j = ptable.GetElementSymbol(int(linker_c[j]))
68 | order = get_bond_order(atom_i, atom_j, dists[i, j], margins=const.MARGINS_EDM)
69 | if order > 0:
70 | A[i, j] = 1
71 | E[i, j] = order
72 | all_bonds = torch.nonzero(A)
73 | for bond in all_bonds:
74 | mol.AddBond(bond[0].item(), bond[1].item(), const.BOND_DICT[E[bond[0], bond[1]].item()])
75 | return mol
76 |
77 |
78 | def reconstruct_mol(frag_mol, frag_x, linker_x, linker_c):
79 | # construct linker mol
80 | linker_mol = build_linker_mol(linker_x, linker_c, add_bonds=True)
81 |
82 | # combine mol and assign conformer
83 | mol = Chem.CombineMols(frag_mol, linker_mol)
84 | mol = Chem.RWMol(mol)
85 | # frag_x = frag_mol.GetConformer().GetPositions()
86 | all_x = np.concatenate([frag_x, linker_x], axis=0)
87 | mol.RemoveAllConformers()
88 | mol.AddConformer(create_conformer(all_x))
89 |
90 | # add frag-linker bond
91 | n_frag_atoms = frag_mol.GetNumAtoms()
92 | n = mol.GetNumAtoms()
93 |
94 | dists = cdist(all_x, all_x, 'euclidean')
95 | E = torch.zeros((n, n), dtype=torch.int)
96 | A = torch.zeros((n, n), dtype=torch.bool)
97 | for i in range(n_frag_atoms):
98 | for j in range(n_frag_atoms, n):
99 | atom_i = mol.GetAtomWithIdx(i).GetSymbol()
100 | atom_j = mol.GetAtomWithIdx(j).GetSymbol()
101 | order = get_bond_order(atom_i, atom_j, dists[i, j], margins=const.MARGINS_EDM)
102 | if order > 0:
103 | A[i, j] = 1
104 | E[i, j] = order
105 | all_bonds = torch.nonzero(A)
106 | for bond in all_bonds:
107 | mol.AddBond(bond[0].item(), bond[1].item(), const.BOND_DICT[E[bond[0], bond[1]].item()])
108 |
109 | # frag_c = [atom.GetAtomicNum() for atom in frag_mol.GetAtoms()]
110 | # all_x = np.concatenate([frag_x, linker_x], axis=0)
111 | # all_c = np.concatenate([frag_c, linker_c], axis=0)
112 | # print('all c: ', all_c)
113 | # mol = build_linker_mol(all_x, all_c)
114 | #
115 | # try:
116 | # Chem.SanitizeMol(mol)
117 | # fixed = True
118 | # except Exception as e:
119 | # fixed = False
120 | #
121 | # if not fixed:
122 | # mol, fixed = fix_valence(mol)
123 | return mol
124 |
125 |
126 | def reconstruct_mol_with_bond(frag_mol, frag_x, linker_x, linker_c,
127 | linker_bond_index, linker_bond_type, known_linker_bonds=True, check_validity=True):
128 | # construct linker mol
129 | linker_mol = build_linker_mol(linker_x, linker_c, add_bonds=known_linker_bonds)
130 |
131 | # combine mol and assign conformer
132 | mol = Chem.CombineMols(frag_mol, linker_mol)
133 | linker_atom_idx = list(range(mol.GetNumAtoms() - linker_mol.GetNumAtoms(), mol.GetNumAtoms()))
134 | mol = Chem.RWMol(mol)
135 | # frag_x = frag_mol.GetConformer().GetPositions()
136 | all_x = np.concatenate([frag_x, linker_x], axis=0)
137 | mol.RemoveAllConformers()
138 | mol.AddConformer(create_conformer(all_x))
139 |
140 | linker_bond_index, linker_bond_type = linker_bond_index.tolist(), linker_bond_type.tolist()
141 | anchor_indices = set()
142 | # add bonds
143 | for i, type_this in enumerate(linker_bond_type):
144 | node_i, node_j = linker_bond_index[0][i], linker_bond_index[1][i]
145 | if node_i < node_j:
146 | if type_this == 0:
147 | continue
148 | else:
149 | if node_i in linker_atom_idx and node_j not in linker_atom_idx:
150 | anchor_indices.add(int(node_j))
151 | elif node_j in linker_atom_idx and node_i not in linker_atom_idx:
152 | anchor_indices.add(int(node_i))
153 |
154 | if type_this == 1:
155 | mol.AddBond(node_i, node_j, Chem.BondType.SINGLE)
156 | elif type_this == 2:
157 | mol.AddBond(node_i, node_j, Chem.BondType.DOUBLE)
158 | elif type_this == 3:
159 | mol.AddBond(node_i, node_j, Chem.BondType.TRIPLE)
160 | elif type_this == 4:
161 | mol.AddBond(node_i, node_j, Chem.BondType.AROMATIC)
162 | else:
163 | raise Exception('unknown bond order {}'.format(type_this))
164 |
165 | mol = mol.GetMol()
166 | for anchor_idx in anchor_indices:
167 | atom = mol.GetAtomWithIdx(anchor_idx)
168 | atom.SetNumExplicitHs(0)
169 |
170 | if check_validity:
171 | mol = fix_validity(mol, linker_atom_idx)
172 |
173 | # check valid
174 | # rd_mol_check = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
175 | # if (rd_mol_check is None) and check_validity:
176 | # raise MolReconsError()
177 | return mol
178 |
179 |
180 | def fix_validity(mol, linker_atom_idx):
181 | try:
182 | Chem.SanitizeMol(mol)
183 | fixed = True
184 | except Exception as e:
185 | fixed = False
186 |
187 | if not fixed:
188 | try:
189 | Chem.Kekulize(deepcopy(mol))
190 |         except Chem.rdchem.KekulizeException as e:
191 |             # only an unkekulized-atoms error is fixable here (checking inside the except avoids an unbound name)
192 |             if 'Unkekulized' in e.args[0]:
193 |                 mol, fixed = fix_aromatic(mol)
194 |
195 | # valence error for N
196 | if not fixed:
197 | mol, fixed = fix_valence(mol, linker_atom_idx)
198 |
199 | # print('s2')
200 | if not fixed:
201 | mol, fixed = fix_aromatic(mol, True, linker_atom_idx)
202 |
203 | try:
204 | Chem.SanitizeMol(mol)
205 | except Exception as e:
206 | # raise MolReconsError()
207 | return None
208 | return mol
209 |
210 |
211 | def fix_valence(mol, linker_atom_idx=None):
212 | mol = deepcopy(mol)
213 | fixed = False
214 | cnt_loop = 0
215 | while True:
216 | try:
217 | Chem.SanitizeMol(mol)
218 | fixed = True
219 | break
220 | except Chem.rdchem.AtomValenceException as e:
221 | err = e
222 | except Exception as e:
223 |             return mol, False  # unrecoverable error here; the caller should re-sample
224 | cnt_loop += 1
225 | if cnt_loop > 100:
226 | break
227 | N4_valence = re.compile(u"Explicit valence for atom # ([0-9]{1,}) N, 4, is greater than permitted")
228 | index = N4_valence.findall(err.args[0])
229 | if len(index) > 0:
230 | if linker_atom_idx is None or int(index[0]) in linker_atom_idx:
231 | mol.GetAtomWithIdx(int(index[0])).SetFormalCharge(1)
232 | return mol, fixed
233 |
234 |
235 | def get_ring_sys(mol):
236 | all_rings = Chem.GetSymmSSSR(mol)
237 | if len(all_rings) == 0:
238 | ring_sys_list = []
239 | else:
240 | ring_sys_list = [all_rings[0]]
241 | for ring in all_rings[1:]:
242 | form_prev = False
243 | for prev_ring in ring_sys_list:
244 | if set(ring).intersection(set(prev_ring)):
245 | prev_ring.extend(ring)
246 | form_prev = True
247 | break
248 | if not form_prev:
249 | ring_sys_list.append(ring)
250 | ring_sys_list = [list(set(x)) for x in ring_sys_list]
251 | return ring_sys_list
252 |
253 |
254 | def get_all_subsets(ring_list):
255 | all_sub_list = []
256 | for n_sub in range(len(ring_list)+1):
257 | all_sub_list.extend(itertools.combinations(ring_list, n_sub))
258 | return all_sub_list
259 |
260 |
261 | def fix_aromatic(mol, strict=False, linker_atom_idx=None):
262 | mol_orig = mol
263 | atomatic_list = [a.GetIdx() for a in mol.GetAromaticAtoms()]
264 | N_ring_list = []
265 | S_ring_list = []
266 | for ring_sys in get_ring_sys(mol):
267 | if set(ring_sys).intersection(set(atomatic_list)):
268 | idx_N = [atom for atom in ring_sys if mol.GetAtomWithIdx(atom).GetSymbol() == 'N']
269 | if len(idx_N) > 0:
270 |                 idx_N.append(-1)  # sentinel so a permutation may leave this ring unchanged
271 | N_ring_list.append(idx_N)
272 | idx_S = [atom for atom in ring_sys if mol.GetAtomWithIdx(atom).GetSymbol() == 'S']
273 | if len(idx_S) > 0:
274 |                 idx_S.append(-1)  # sentinel so a permutation may leave this ring unchanged
275 | S_ring_list.append(idx_S)
276 | # enumerate S
277 | fixed = False
278 | if strict:
279 | S_ring_list = [s for ring in S_ring_list for s in ring if s != -1]
280 | permutation = get_all_subsets(S_ring_list)
281 | else:
282 | permutation = list(itertools.product(*S_ring_list))
283 | for perm in permutation:
284 | mol = deepcopy(mol_orig)
285 | perm = [x for x in perm if x != -1]
286 | for idx in perm:
287 | if linker_atom_idx is None or idx in linker_atom_idx:
288 | mol.GetAtomWithIdx(idx).SetFormalCharge(1)
289 | try:
290 | if strict:
291 | mol, fixed = fix_valence(mol, linker_atom_idx)
292 | Chem.SanitizeMol(mol)
293 | fixed = True
294 | break
295 | except:
296 | continue
297 | # enumerate N
298 | if not fixed:
299 | if strict:
300 | N_ring_list = [s for ring in N_ring_list for s in ring if s != -1]
301 | permutation = get_all_subsets(N_ring_list)
302 | else:
303 | permutation = list(itertools.product(*N_ring_list))
304 | for perm in permutation: # each ring select one atom
305 | perm = [x for x in perm if x != -1]
306 | # print(perm)
307 | actions = itertools.product([0, 1], repeat=len(perm))
308 | for action in actions: # add H or charge
309 | mol = deepcopy(mol_orig)
310 | for idx, act_atom in zip(perm, action):
311 | if linker_atom_idx is None or idx in linker_atom_idx:
312 | if act_atom == 0:
313 | mol.GetAtomWithIdx(idx).SetNumExplicitHs(1)
314 | else:
315 | mol.GetAtomWithIdx(idx).SetFormalCharge(1)
316 | try:
317 | if strict:
318 | mol, fixed = fix_valence(mol, linker_atom_idx)
319 | Chem.SanitizeMol(mol)
320 | fixed = True
321 | break
322 | except:
323 | continue
324 | if fixed:
325 | break
326 | return mol, fixed
327 |
328 |
329 | def parse_sampling_result(data_list, final_x, final_c, atom_featurizer):
330 | all_mols = []
331 | for data, x_gen, c_gen in zip(data_list, final_x, final_c):
332 | frag_pos = x_gen[data.fragment_mask > 0].cpu().numpy().astype(np.float64)
333 | linker_pos = x_gen[data.fragment_mask == 0].cpu().numpy().astype(np.float64)
334 | linker_ele = [atom_featurizer.get_element_from_index(int(c)) for c in c_gen[data.linker_mask]]
335 | full_mol = reconstruct_mol(data.frag_mol, frag_pos, linker_pos, linker_ele)
336 | all_mols.append(full_mol)
337 | return all_mols
338 |
339 |
340 | def parse_sampling_result_with_bond(data_list, final_x, final_c, final_bond, atom_featurizer,
341 | known_linker_bonds=True, check_validity=False):
342 | all_mols = []
343 | if not isinstance(data_list, list):
344 | data_list = [data_list] * len(final_x)
345 | for data, x_gen, c_gen, b_gen in zip(data_list, final_x, final_c, final_bond):
346 | frag_pos = x_gen[data.fragment_mask > 0].cpu().numpy().astype(np.float64)
347 | linker_pos = x_gen[data.fragment_mask == 0].cpu().numpy().astype(np.float64)
348 | linker_ele = [atom_featurizer.get_element_from_index(int(c)) for c in c_gen[data.linker_mask]]
349 | linker_bond_type = b_gen[data.linker_bond_mask]
350 | linker_bond_index = data.edge_index[:, data.linker_bond_mask]
351 | full_mol = reconstruct_mol_with_bond(
352 | data.frag_mol, frag_pos, linker_pos, linker_ele, linker_bond_index, linker_bond_type,
353 | known_linker_bonds, check_validity)
354 | all_mols.append(full_mol)
355 | return all_mols
356 |
--------------------------------------------------------------------------------
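A small demonstration of the valence-fixing path: a tetravalent, uncharged nitrogen fails sanitization, and fix_valence repairs it by assigning a formal charge. The SMILES is illustrative:

    from rdkit import Chem
    from utils.reconstruct_linker import fix_valence

    bad = Chem.MolFromSmiles('CN(C)(C)C', sanitize=False)  # N with four bonds, no charge
    fixed_mol, ok = fix_valence(bad)
    if ok:
        print(Chem.MolToSmiles(fixed_mol))  # the N now carries a +1 formal charge

--------------------------------------------------------------------------------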
/utils/sascorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of synthetic accessibility score as described in:
3 | #
4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
5 | # Peter Ertl and Ansgar Schuffenhauer
6 | # Journal of Cheminformatics 1:8 (2009)
7 | # http://www.jcheminf.com/content/1/1/8
8 | #
9 | # several small modifications to the original paper are included
10 | # particularly slightly different formula for macrocyclic penalty
11 | # and taking into account also molecule symmetry (fingerprint density)
12 | #
13 | # for a set of 10k diverse molecules the agreement between the original method
14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97
15 | #
16 | # peter ertl & greg landrum, september 2013
17 | #
18 | from __future__ import print_function
19 |
20 | from rdkit import Chem
21 | from rdkit.Chem import rdMolDescriptors
22 | import pickle as cPickle  # rdkit.six was removed from modern RDKit; plain pickle works here
23 | # (dict.items() below replaces the old rdkit.six iteritems helper)
24 |
25 | import math
26 | from collections import defaultdict
27 |
28 | import os.path as op
29 |
30 | _fscores = None
31 |
32 |
33 | def readFragmentScores(name='fpscores'):
34 | import gzip
35 | global _fscores
36 | # generate the full path filename:
37 | if name == "fpscores":
38 | name = op.join(op.dirname(__file__), name)
39 | _fscores = cPickle.load(gzip.open('%s.pkl.gz' % name))
40 | outDict = {}
41 | for i in _fscores:
42 | for j in range(1, len(i)):
43 | outDict[i[j]] = float(i[0])
44 | _fscores = outDict
45 |
46 |
47 | def numBridgeheadsAndSpiro(mol, ri=None):
48 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
49 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
50 | return nBridgehead, nSpiro
51 |
52 |
53 | def calculateScore(m):
54 | if _fscores is None:
55 | readFragmentScores()
56 |
57 | # fragment score
58 | fp = rdMolDescriptors.GetMorganFingerprint(m,
59 | 2) #<- 2 is the *radius* of the circular fingerprint
60 | fps = fp.GetNonzeroElements()
61 | score1 = 0.
62 | nf = 0
63 |     for bitId, v in fps.items():
64 | nf += v
65 | sfp = bitId
66 | score1 += _fscores.get(sfp, -4) * v
67 | score1 /= nf
68 |
69 | # features score
70 | nAtoms = m.GetNumAtoms()
71 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
72 | ri = m.GetRingInfo()
73 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
74 | nMacrocycles = 0
75 | for x in ri.AtomRings():
76 | if len(x) > 8:
77 | nMacrocycles += 1
78 |
79 | sizePenalty = nAtoms**1.005 - nAtoms
80 | stereoPenalty = math.log10(nChiralCenters + 1)
81 | spiroPenalty = math.log10(nSpiro + 1)
82 | bridgePenalty = math.log10(nBridgeheads + 1)
83 | macrocyclePenalty = 0.
84 | # ---------------------------------------
85 | # This differs from the paper, which defines:
86 | # macrocyclePenalty = math.log10(nMacrocycles+1)
87 | # This form generates better results when 2 or more macrocycles are present
88 | if nMacrocycles > 0:
89 | macrocyclePenalty = math.log10(2)
90 |
91 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
92 |
93 | # correction for the fingerprint density
94 | # not in the original publication, added in version 1.1
95 | # to make highly symmetrical molecules easier to synthetise
96 | score3 = 0.
97 | if nAtoms > len(fps):
98 | score3 = math.log(float(nAtoms) / len(fps)) * .5
99 |
100 | sascore = score1 + score2 + score3
101 |
102 | # need to transform "raw" value into scale between 1 and 10
103 | min = -4.0
104 | max = 2.5
105 | sascore = 11. - (sascore - min + 1) / (max - min) * 9.
106 | # smooth the 10-end
107 | if sascore > 8.:
108 | sascore = 8. + math.log(sascore + 1. - 9.)
109 | if sascore > 10.:
110 | sascore = 10.0
111 | elif sascore < 1.:
112 | sascore = 1.0
113 |
114 | return sascore
115 |
116 |
117 | def processMols(mols):
118 | print('smiles\tName\tsa_score')
119 | for i, m in enumerate(mols):
120 | if m is None:
121 | continue
122 |
123 | s = calculateScore(m)
124 |
125 | smiles = Chem.MolToSmiles(m)
126 |         print(smiles + "\t" + m.GetProp('_Name') + "\t%.3f" % s)
127 |
128 |
129 | if __name__ == '__main__':
130 | import sys, time
131 |
132 | t1 = time.time()
133 | readFragmentScores("fpscores")
134 | t2 = time.time()
135 |
136 | suppl = Chem.SmilesMolSupplier(sys.argv[1])
137 | t3 = time.time()
138 | processMols(suppl)
139 | t4 = time.time()
140 |
141 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
142 | file=sys.stderr)
143 |
144 | #
145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
146 | # All rights reserved.
147 | #
148 | # Redistribution and use in source and binary forms, with or without
149 | # modification, are permitted provided that the following conditions are
150 | # met:
151 | #
152 | # * Redistributions of source code must retain the above copyright
153 | # notice, this list of conditions and the following disclaimer.
154 | # * Redistributions in binary form must reproduce the above
155 | # copyright notice, this list of conditions and the following
156 | # disclaimer in the documentation and/or other materials provided
157 | # with the distribution.
158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
159 | # nor the names of its contributors may be used to endorse or promote
160 | # products derived from this software without specific prior written permission.
161 | #
162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
173 | #
174 |
--------------------------------------------------------------------------------
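Usage sketch: scores run from 1 (easy to synthesize) to 10 (hard); the fragment scores are loaded lazily from utils/fpscores.pkl.gz on the first call. The molecule is illustrative:

    from rdkit import Chem
    from utils import sascorer

    mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')  # aspirin
    print('SA score: %.2f' % sascorer.calculateScore(mol))

--------------------------------------------------------------------------------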
/utils/so3.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from utils.geometry import quaternion_to_rotation_matrix
8 |
9 |
10 | def log_rotation(R):
11 | trace = R[..., range(3), range(3)].sum(-1)
12 | if torch.is_grad_enabled():
13 |         # The derivative of acos at -1.0 is -inf, so to stabilize the gradient we clamp at -0.999
14 | min_cos = -0.999
15 | else:
16 | min_cos = -1.0
17 | cos_theta = ((trace - 1) / 2).clamp_min(min=min_cos)
18 | sin_theta = torch.sqrt(1 - cos_theta ** 2)
19 | theta = torch.acos(cos_theta)
20 | coef = ((theta + 1e-8) / (2 * sin_theta + 2e-8))[..., None, None]
21 | logR = coef * (R - R.transpose(-1, -2))
22 | return logR
23 |
24 |
25 | def skewsym_to_so3vec(S):
26 | x = S[..., 1, 2]
27 | y = S[..., 2, 0]
28 | z = S[..., 0, 1]
29 | w = torch.stack([x, y, z], dim=-1)
30 | return w
31 |
32 |
33 | def so3vec_to_skewsym(w):
34 | x, y, z = torch.unbind(w, dim=-1)
35 | o = torch.zeros_like(x)
36 | S = torch.stack([
37 | o, z, -y,
38 | -z, o, x,
39 | y, -x, o,
40 | ], dim=-1).reshape(w.shape[:-1] + (3, 3))
41 | return S
42 |
43 |
44 | def exp_skewsym(S):
45 | x = torch.linalg.norm(skewsym_to_so3vec(S), dim=-1)
46 | I = torch.eye(3).to(S).view([1 for _ in range(S.dim() - 2)] + [3, 3])
47 |
48 | sinx, cosx = torch.sin(x), torch.cos(x)
49 | b = (sinx + 1e-8) / (x + 1e-8)
50 | c = (1 - cosx + 1e-8) / (x ** 2 + 2e-8) # lim_{x->0} (1-cosx)/(x^2) = 0.5
51 |
52 | S2 = S @ S
53 | return I + b[..., None, None] * S + c[..., None, None] * S2
54 |
55 |
56 | def so3vec_to_rotation(w):
57 | return exp_skewsym(so3vec_to_skewsym(w))
58 |
59 |
60 | def rotation_to_so3vec(R):
61 | logR = log_rotation(R)
62 | w = skewsym_to_so3vec(logR)
63 | return w
64 |
65 |
66 | def random_uniform_so3(size, device='cpu'):
67 | q = F.normalize(torch.randn(list(size) + [4, ], device=device), dim=-1) # (..., 4)
68 | return rotation_to_so3vec(quaternion_to_rotation_matrix(q))
69 |
70 |
71 | class ApproxAngularDistribution(nn.Module):
72 | # todo: interpolation
73 | def __init__(self, stddevs, std_threshold=0.1, num_bins=8192, num_iters=1024):
74 | super().__init__()
75 | self.std_threshold = std_threshold
76 | self.num_bins = num_bins
77 | self.num_iters = num_iters
78 | self.register_buffer('stddevs', torch.FloatTensor(stddevs))
79 | self.register_buffer('approx_flag', self.stddevs <= std_threshold)
80 | self._precompute_histograms()
81 |
82 | @staticmethod
83 | def _pdf(x, e, L):
84 |         """Truncated series approximation of the isotropic Gaussian density on SO(3).
85 |         Args:
86 |             x: (N, ) rotation angles in [0, pi].
87 |             e: Float, std of the distribution.
88 |             L: Integer, number of series terms.
89 |         """
90 | x = x[:, None] # (N, *)
91 | c = ((1 - torch.cos(x)) / math.pi) # (N, *)
92 | l = torch.arange(0, L)[None, :] # (*, L)
93 | a = (2 * l + 1) * torch.exp(-l * (l + 1) * (e ** 2)) # (*, L)
94 | b = (torch.sin((l + 0.5) * x) + 1e-6) / (torch.sin(x / 2) + 1e-6) # (N, L)
95 |
96 | f = (c * a * b).sum(dim=1)
97 | return f
98 |
99 | def _precompute_histograms(self):
100 | X, Y = [], []
101 | for std in self.stddevs:
102 | std = std.item()
103 | x = torch.linspace(0, math.pi, self.num_bins) # (n_bins,)
104 | y = self._pdf(x, std, self.num_iters) # (n_bins,)
105 | y = torch.nan_to_num(y).clamp_min(0)
106 | X.append(x)
107 | Y.append(y)
108 | self.register_buffer('X', torch.stack(X, dim=0)) # (n_stddevs, n_bins)
109 | self.register_buffer('Y', torch.stack(Y, dim=0)) # (n_stddevs, n_bins)
110 |
111 | def sample(self, std_idx):
112 | """
113 | Args:
114 | std_idx: Indices of standard deviation.
115 | Returns:
116 | samples: Angular samples [0, PI), same size as std.
117 | """
118 | size = std_idx.size()
119 | std_idx = std_idx.flatten() # (N,)
120 |
121 | # Samples from histogram
122 | prob = self.Y[std_idx] # (N, n_bins)
123 | bin_idx = torch.multinomial(prob[:, :-1], num_samples=1).squeeze(-1) # (N,)
124 | bin_start = self.X[std_idx, bin_idx] # (N,)
125 | bin_width = self.X[std_idx, bin_idx + 1] - self.X[std_idx, bin_idx]
126 | samples_hist = bin_start + torch.rand_like(bin_start) * bin_width # (N,)
127 |
128 | # Samples from Gaussian approximation
129 | mean_gaussian = self.stddevs[std_idx] * 2
130 | std_gaussian = self.stddevs[std_idx]
131 | samples_gaussian = mean_gaussian + torch.randn_like(mean_gaussian) * std_gaussian
132 | samples_gaussian = samples_gaussian.abs() % math.pi
133 |
134 | # Choose from histogram or Gaussian
135 | gaussian_flag = self.approx_flag[std_idx]
136 | samples = torch.where(gaussian_flag, samples_gaussian, samples_hist)
137 |
138 | return samples.reshape(size)
139 |
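# Editor's sketch (not in the original file): _pdf above is the truncated
# series for the angle density of the isotropic Gaussian on SO(3), IGSO(3):
#   f(w) = (1 - cos w)/pi * sum_l (2l+1) exp(-l(l+1) e^2) sin((l+1/2) w) / sin(w/2),
# tabulated into a per-noise-level histogram. For stddevs at or below
# std_threshold the truncated series is numerically unreliable, so sample()
# falls back to a folded Gaussian with mean 2*sigma. A hedged usage example:
#
#   stddevs = torch.linspace(0.01, 1.5, 100).tolist()
#   distrib = ApproxAngularDistribution(stddevs)
#   angles = distrib.sample(torch.randint(0, 100, [8]))  # (8,) angles in [0, pi)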
140 |
141 | def random_normal_so3(std_idx, angular_distrib, device='cpu'):
142 | size = std_idx.size()
143 | u = F.normalize(torch.randn(list(size) + [3, ], device=device), dim=-1)
144 | theta = angular_distrib.sample(std_idx)
145 | w = u * theta[..., None]
146 | return w
147 |
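# Editor's sketch: random_normal_so3 composes a uniformly random rotation axis
# with an angle drawn from the IGSO(3) angular distribution above, returning
# axis-angle (so3) vectors. Continuing the example above (assumed names):
#
#   w = random_normal_so3(torch.zeros(16).long(), distrib)  # (16, 3)
#   R = so3vec_to_rotation(w)                               # (16, 3, 3)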
--------------------------------------------------------------------------------
/utils/train.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from torch_geometric.data import Data, Batch
4 | from utils.misc import BlackHole
5 | from easydict import EasyDict
6 |
7 |
8 | def repeat_data(data: Data, num_repeat) -> Batch:
9 | datas = [copy.deepcopy(data) for _ in range(num_repeat)]
10 | return Batch.from_data_list(datas)
11 |
12 |
13 | def repeat_batch(batch: Batch, num_repeat) -> Batch:
14 | datas = batch.to_data_list()
15 | new_data = []
16 | for _ in range(num_repeat):
17 | new_data += copy.deepcopy(datas)
18 | return Batch.from_data_list(new_data)
19 |
20 |
21 | def inf_iterator(iterable):
22 | iterator = iterable.__iter__()
23 | while True:
24 | try:
25 | yield iterator.__next__()
26 | except StopIteration:
27 | iterator = iterable.__iter__()
28 |
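# Editor's sketch: inf_iterator turns a finite iterable (typically a
# DataLoader) into an endless stream that restarts on StopIteration, which
# suits iteration-based rather than epoch-based training loops. Assumed usage:
#
#   train_iter = inf_iterator(train_loader)
#   for it in range(1, max_iters + 1):
#       batch = next(train_iter)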
29 |
30 | def get_optimizer(cfg, model):
31 | if cfg.type == 'adam':
32 | return torch.optim.Adam(
33 | model.parameters(),
34 | lr=cfg.lr,
35 | weight_decay=cfg.weight_decay,
36 | betas=(cfg.beta1, cfg.beta2, )
37 | )
38 | elif cfg.type == 'adamw':
39 | return torch.optim.AdamW(
40 | model.parameters(),
41 | lr=cfg.lr,
42 | weight_decay=cfg.weight_decay,
43 | betas=(cfg.beta1, cfg.beta2, )
44 | )
45 | else:
46 | raise NotImplementedError('Optimizer not supported: %s' % cfg.type)
47 |
48 |
49 | def get_scheduler(cfg, optimizer):
50 | if cfg.type == 'plateau':
51 | return torch.optim.lr_scheduler.ReduceLROnPlateau(
52 | optimizer,
53 | factor=cfg.factor,
54 | patience=cfg.patience,
55 | min_lr=cfg.min_lr
56 | )
57 | else:
58 | raise NotImplementedError('Scheduler not supported: %s' % cfg.type)
59 |
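# Editor's sketch: both factories read attribute-style configs, e.g. an
# EasyDict loaded from one of the YAML files under configs/. The field names
# below mirror what the functions access; the values are illustrative only:
#
#   opt_cfg = EasyDict(type='adam', lr=1e-3, weight_decay=0.0, beta1=0.95, beta2=0.999)
#   sched_cfg = EasyDict(type='plateau', factor=0.6, patience=10, min_lr=1e-5)
#   optimizer = get_optimizer(opt_cfg, model)
#   scheduler = get_scheduler(sched_cfg, optimizer)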
60 |
61 | def sum_weighted_losses(losses, weights):
62 | """
63 | Args:
64 | losses: Dict of scalar tensors.
65 | weights: Dict of weights.
66 | """
67 | loss = 0
68 | for k in losses.keys():
69 | if weights is None:
70 | loss = loss + losses[k]
71 | else:
72 | loss = loss + weights[k] * losses[k]
73 | return loss
74 |
75 |
76 | def log_losses(out, it, tag, train_report_iter=1, logger=BlackHole(), writer=BlackHole(), others={}):
77 | if it % train_report_iter == 0:
78 | logstr = '[%s] Iter %05d' % (tag, it)
79 | logstr += ' | loss %.4f' % out['overall'].item()
80 | for k, v in out.items():
81 | if k == 'overall': continue
82 | logstr += ' | loss(%s) %.4f' % (k, v.item())
83 | for k, v in others.items():
84 | if k == 'lr':
85 | logstr += ' | %s %2.6f' % (k, v)
86 | else:
87 | logstr += ' | %s %2.4f' % (k, v)
88 | logger.info(logstr)
89 |
90 | for k, v in out.items():
91 | if k == 'overall':
92 | writer.add_scalar('%s/loss' % tag, v, it)
93 | else:
94 | writer.add_scalar('%s/loss_%s' % (tag, k), v, it)
95 | for k, v in others.items():
96 | writer.add_scalar('%s/%s' % (tag, k), v, it)
97 | writer.flush()
98 |
99 |
100 | class ValidationLossTape(object):
101 |
102 | def __init__(self):
103 | super().__init__()
104 | self.accumulate = {}
105 | self.others = {}
106 | self.total = 0
107 |
108 | def update(self, out, n, others={}):
109 | self.total += n
110 | for k, v in out.items():
111 | if k not in self.accumulate:
112 | self.accumulate[k] = v.clone().detach()
113 | else:
114 | self.accumulate[k] += v.clone().detach()
115 |
116 | for k, v in others.items():
117 | if k not in self.others:
118 | self.others[k] = v.clone().detach()
119 | else:
120 | self.others[k] += v.clone().detach()
121 |
122 | def log(self, it, logger=BlackHole(), writer=BlackHole(), tag='val'):
123 | avg = EasyDict({k: v / self.total for k, v in self.accumulate.items()})
124 | avg_others = EasyDict({k: v / self.total for k, v in self.others.items()})
125 | log_losses(avg, it, tag, logger=logger, writer=writer, others=avg_others)
126 | return avg['overall']
127 |
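# Editor's sketch: a typical validation loop pairing ValidationLossTape with
# the 'plateau' scheduler above (names other than the tape API are assumed):
#
#   tape = ValidationLossTape()
#   for batch in val_loader:
#       out = model.get_loss(batch)        # dict of scalar losses incl. 'overall'
#       tape.update(out, n=batch.num_graphs)
#   avg = tape.log(it, logger=logger, writer=writer, tag='val')
#   scheduler.step(avg)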
--------------------------------------------------------------------------------
/utils/train_linker_smiles.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guanjq/LinkerNet/280759c16ccecece0d81ab9cebe3f44041b80e51/utils/train_linker_smiles.pkl
--------------------------------------------------------------------------------
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from rdkit import Chem
5 | from torch_geometric.transforms import BaseTransform
6 |
7 | from utils import so3
8 | from utils.geometry import local_to_global, global_to_local, rotation_matrix_from_vectors
9 |
10 |
11 | # required by model:
12 | # x_noisy, c_noisy, atom_feat, edge_index, edge_feat,
13 | # R1_noisy, R2_noisy, R1, R2, eps_p1, eps_p2, x_0, c_0, t, fragment_mask, batch
14 |
15 |
16 | def modify_frags_conformer(frags_local_pos, frag_idx_mask, v_frags, p_frags):
17 | R_frags = so3.so3vec_to_rotation(v_frags)
18 | x_frags = torch.zeros_like(frags_local_pos)
19 | for i in range(2):
20 | noisy_pos = local_to_global(R_frags[i], p_frags[i],
21 | frags_local_pos[frag_idx_mask == i + 1])
22 | x_frags[frag_idx_mask == i + 1] = noisy_pos
23 | return x_frags
24 |
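# Editor's note: modify_frags_conformer maps each fragment's local coordinates
# back to the global frame, x = R_i @ p_local + t_i, with R_i decoded from the
# so3 vectors v_frags and t_i taken from p_frags (fragments are labelled 1 and
# 2 in frag_idx_mask); see local_to_global in utils/geometry.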
25 |
26 | def dataset_info(dataset):  # qm9, zinc/protac, cep
27 | if dataset == 'qm9':
28 | return {'atom_types': ["H", "C", "N", "O", "F"],
29 | 'maximum_valence': {0: 1, 1: 4, 2: 3, 3: 2, 4: 1},
30 | 'number_to_atom': {0: "H", 1: "C", 2: "N", 3: "O", 4: "F"},
31 | 'bucket_sizes': np.array(list(range(4, 28, 2)) + [29])
32 | }
33 | elif dataset == 'zinc' or dataset == 'protac':
34 | return {'atom_types': ['Br1(0)', 'C4(0)', 'Cl1(0)', 'F1(0)', 'H1(0)', 'I1(0)',
35 | 'N2(-1)', 'N3(0)', 'N4(1)', 'O1(-1)', 'O2(0)', 'S2(0)', 'S4(0)', 'S6(0)'],
36 | 'maximum_valence': {0: 1, 1: 4, 2: 1, 3: 1, 4: 1, 5: 1, 6: 2, 7: 3, 8: 4, 9: 1, 10: 2, 11: 2, 12: 4,
37 | 13: 6, 14: 3},
38 | 'number_to_atom': {0: 'Br', 1: 'C', 2: 'Cl', 3: 'F', 4: 'H', 5: 'I', 6: 'N', 7: 'N', 8: 'N', 9: 'O',
39 | 10: 'O', 11: 'S', 12: 'S', 13: 'S'},
40 | 'bucket_sizes': np.array(
41 | [28, 31, 33, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 55, 58, 84])
42 | }
43 |
44 | elif dataset == "cep":
45 | return {'atom_types': ["C", "S", "N", "O", "Se", "Si"],
46 | 'maximum_valence': {0: 4, 1: 2, 2: 3, 3: 2, 4: 2, 5: 4},
47 | 'number_to_atom': {0: "C", 1: "S", 2: "N", 3: "O", 4: "Se", 5: "Si"},
48 | 'bucket_sizes': np.array([25, 28, 29, 30, 32, 33, 34, 35, 36, 37, 38, 39, 43, 46])
49 | }
50 | else:
51 | print("the datasets in use are qm9|zinc|protac|cep")
52 | exit(1)
53 |
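# Editor's note: for zinc/protac each atom type is the string
# "<symbol><valence>(<charge>)", e.g. 'N4(1)' is a +1 nitrogen with valence 4;
# FeaturizeAtom.get_index below rebuilds exactly this key. A quick check:
#
#   info = dataset_info('zinc')
#   info['atom_types'].index('C4(0)')   # -> 1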
54 |
55 | class FeaturizeAtom(BaseTransform):
56 |
57 | def __init__(self, dataset_name, known_anchor=False,
58 | add_atom_type=True, add_atom_feat=True):
59 | super().__init__()
60 | self.dataset_name = dataset_name
61 | self.dataset_info = dataset_info(dataset_name)
62 | self.known_anchor = known_anchor
63 | self.add_atom_type = add_atom_type
64 | self.add_atom_feat = add_atom_feat
65 |
66 | @property
67 | def num_classes(self):
68 | return len(self.dataset_info['atom_types'])
69 |
70 | @property
71 | def feature_dim(self):
72 | n_feat_dim = 2
73 | if self.known_anchor:
74 | n_feat_dim += 2
75 | return n_feat_dim
76 |
77 | def get_index(self, atom_num, valence, charge):
78 | if self.dataset_name in ['zinc', 'protac']:
79 | pt = Chem.GetPeriodicTable()
80 | atom_str = "%s%i(%i)" % (pt.GetElementSymbol(atom_num), valence, charge)
81 | return self.dataset_info['atom_types'].index(atom_str)
82 | else:
83 | raise ValueError
84 |
85 | def get_element_from_index(self, index):
86 | pt = Chem.GetPeriodicTable()
87 | symb = self.dataset_info['number_to_atom'][index]
88 | return pt.GetAtomicNumber(symb)
89 |
90 | def __call__(self, data):
91 | if self.add_atom_type:
92 | x = [self.get_index(int(e), int(v), int(c)) for e, v, c in zip(data.element, data.valence, data.charge)]
93 | data.atom_type = torch.tensor(x)
94 | if self.add_atom_feat:
96 | # fragment / linker indicator, independent of atom types
96 | linker_flag = F.one_hot((data.fragment_mask == 0).long(), 2)
97 | all_feats = [linker_flag]
98 | # fragment anchor flag
99 | if self.known_anchor:
100 | anchor_flag = F.one_hot((data.anchor_mask == 1).long(), 2)
101 | all_feats.append(anchor_flag)
102 | data.atom_feat = torch.cat(all_feats, -1)
103 | return data
104 |
105 |
106 | class BuildCompleteGraph(BaseTransform):
107 |
108 | def __init__(self, known_linker_bond=False, known_cand_anchors=False):
109 | super().__init__()
110 | self.known_linker_bond = known_linker_bond
111 | self.known_cand_anchors = known_cand_anchors
112 |
113 | @property
114 | def num_bond_classes(self):
115 | return 5
116 |
117 | @property
118 | def bond_feature_dim(self):
119 | return 4
120 |
121 | @staticmethod
122 | def _get_interleave_edge_index(edge_index):
123 | edge_index_sym = torch.stack([edge_index[1], edge_index[0]])
124 | e = torch.zeros_like(torch.cat([edge_index, edge_index_sym], dim=-1))
125 | e[:, ::2] = edge_index
126 | e[:, 1::2] = edge_index_sym
127 | return e
128 |
129 | def _build_interleave_fc(self, n1_atoms, n2_atoms):
130 | eij = torch.triu_indices(n1_atoms, n2_atoms, offset=1)
131 | e = self._get_interleave_edge_index(eij)
132 | return e
133 |
134 | def __call__(self, data):
135 | # fully connected graph
136 | num_nodes = len(data.pos)
137 | fc_edge_index = self._build_interleave_fc(num_nodes, num_nodes)
138 | data.edge_index = fc_edge_index
139 |
140 | # (ll, lf, fl, ff) indicator
141 | src, dst = data.edge_index
142 | num_edges = len(fc_edge_index[0])
143 | edge_type = torch.zeros(num_edges).long()
144 | l_ind_src = data.fragment_mask[src] == 0
145 | l_ind_dst = data.fragment_mask[dst] == 0
146 | edge_type[l_ind_src & l_ind_dst] = 0
147 | edge_type[l_ind_src & ~l_ind_dst] = 1
148 | edge_type[~l_ind_src & l_ind_dst] = 2
149 | edge_type[~l_ind_src & ~l_ind_dst] = 3
150 | edge_type = F.one_hot(edge_type, num_classes=4)
151 | data.edge_feat = edge_type
152 |
153 | # bond type 0: none 1: single 2: double 3: triple 4: aromatic
154 | bond_type = torch.zeros(num_edges).long()
155 |
156 | id_fc_edge = fc_edge_index[0] * num_nodes + fc_edge_index[1]
157 | id_frag_bond = data.bond_index[0] * num_nodes + data.bond_index[1]
158 | idx_edge = torch.tensor([torch.nonzero(id_fc_edge == id_).squeeze() for id_ in id_frag_bond])
159 | bond_type[idx_edge] = data.bond_type
160 | # data.edge_type = F.one_hot(bond_type, num_classes=5)
161 | data.edge_type = bond_type
162 | if self.known_linker_bond:
163 | data.linker_bond_mask = (data.fragment_mask[src] == 0) ^ (data.fragment_mask[dst] == 0)
164 | elif self.known_cand_anchors:
165 | ll_bond = (data.fragment_mask[src] == 0) & (data.fragment_mask[dst] == 0)
166 | fl_bond = (data.cand_anchors_mask[src] == 1) & (data.fragment_mask[dst] == 0)
167 | lf_bond = (data.cand_anchors_mask[dst] == 1) & (data.fragment_mask[src] == 0)
168 | data.linker_bond_mask = ll_bond | fl_bond | lf_bond
169 | else:
170 | data.linker_bond_mask = (data.fragment_mask[src] == 0) | (data.fragment_mask[dst] == 0)
171 |
172 | data.inner_edge_mask = (data.fragment_mask[src] == data.fragment_mask[dst])
173 | return data
174 |
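# Editor's sketch: _build_interleave_fc orders the fully-connected edges so
# each undirected pair appears as two consecutive directed edges (i->j, j->i),
# which lets symmetric edge updates operate on pairs by reshaping. For 3 nodes:
#
#   BuildCompleteGraph()._build_interleave_fc(3, 3)
#   # tensor([[0, 1, 0, 2, 1, 2],
#   #         [1, 0, 2, 0, 2, 1]])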
175 |
176 | class SelectCandAnchors(BaseTransform):
177 |
178 | def __init__(self, mode='exact', k=2):
179 | super().__init__()
180 | self.mode = mode
181 | assert mode in ['exact', 'k-hop']
182 | self.k = k
183 |
184 | @staticmethod
185 | def bfs(nbh_list, node, k=2, valid_list=()):  # immutable default avoids the shared-list pitfall
186 | visited = [node]
187 | queue = [node]
188 | level = [0]
189 | bfs_perm = []
190 |
191 | while len(queue) > 0:
192 | m = queue.pop(0)
193 | l = level.pop(0)
194 | if l > k:
195 | break
196 | bfs_perm.append(m)
197 |
198 | for neighbour in nbh_list[m]:
199 | if neighbour not in visited and neighbour in valid_list:
200 | visited.append(neighbour)
201 | queue.append(neighbour)
202 | level.append(l + 1)
203 | return bfs_perm
204 |
205 | def __call__(self, data):
206 | # link_indices = (data.linker_mask == 1).nonzero()[:, 0].tolist()
207 | # frag_indices = (data.linker_mask == 0).nonzero()[:, 0].tolist()
208 | # anchor_indices = [j for i, j in zip(*data.bond_index.tolist()) if i in link_indices and j in frag_indices]
209 | # data.anchor_indices = anchor_indices
210 | cand_anchors_mask = torch.zeros_like(data.fragment_mask).bool()
211 | if self.mode == 'exact':
212 | cand_anchors_mask[data.anchor_indices] = True
213 | data.cand_anchors_mask = cand_anchors_mask
214 |
215 | elif self.mode == 'k-hop':
216 | # data.nbh_list = {i.item(): [j.item() for k, j in enumerate(data.bond_index[1])
217 | # if data.bond_index[0, k].item() == i] for i in data.bond_index[0]}
218 | # all_cand = []
219 | for anchor in data.anchor_indices:
220 | a_frag_id = data.fragment_mask[anchor]
221 | a_valid_list = (data.fragment_mask == a_frag_id).nonzero(as_tuple=True)[0].tolist()
222 | a_cand = self.bfs(data.nbh_list, anchor, k=self.k, valid_list=a_valid_list)
223 | a_cand = [a for a in a_cand if data.frag_mol.GetAtomWithIdx(a).GetTotalNumHs() > 0]
224 | cand_anchors_mask[a_cand] = True
225 | # all_cand.append(a_cand)
226 | data.cand_anchors_mask = cand_anchors_mask
227 | else:
228 | raise ValueError(self.mode)
229 | return data
230 |
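# Editor's note: in 'k-hop' mode every atom within k bonds of a ground-truth
# anchor becomes a candidate, restricted to the same fragment and to atoms
# carrying at least one hydrogen (so forming a new bond is chemically
# feasible), e.g. SelectCandAnchors(mode='k-hop', k=2).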
231 |
232 | class StackFragLocalPos(BaseTransform):
233 | def __init__(self, max_num_atoms=30):
234 | super().__init__()
235 | self.max_num_atoms = max_num_atoms
236 |
237 | def __call__(self, data):
238 | frag_idx_mask = data.fragment_mask[data.fragment_mask > 0]
239 | f1_pos = data.frags_local_pos[frag_idx_mask == 1]
240 | f2_pos = data.frags_local_pos[frag_idx_mask == 2]
241 | assert len(f1_pos) <= self.max_num_atoms
242 | assert len(f2_pos) <= self.max_num_atoms
243 | # todo: use F.pad
244 | f1_fill_pos = torch.cat([f1_pos, torch.zeros(self.max_num_atoms - len(f1_pos), 3)], dim=0)
245 | f1_mask = torch.cat([torch.ones(len(f1_pos)), torch.zeros(self.max_num_atoms - len(f1_pos))], dim=0)
246 | f2_fill_pos = torch.cat([f2_pos, torch.zeros(self.max_num_atoms - len(f2_pos), 3)], dim=0)
247 | f2_mask = torch.cat([torch.ones(len(f2_pos)), torch.zeros(self.max_num_atoms - len(f2_pos))], dim=0)
248 | data.frags_local_pos_filled = torch.stack([f1_fill_pos, f2_fill_pos], dim=0)
249 | data.frags_local_pos_mask = torch.stack([f1_mask, f2_mask], dim=0).bool()
250 | return data
251 |
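# Editor's note re the todo above: the concatenation-based padding could be
# written with F.pad without changing behavior, e.g.
#   f1_fill_pos = F.pad(f1_pos, (0, 0, 0, self.max_num_atoms - len(f1_pos)))
# which appends zero rows along the atom dimension.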
252 |
253 | class RelativeGeometry(BaseTransform):
254 | def __init__(self, mode):
255 | super().__init__()
256 | self.mode = mode
257 |
258 | def __call__(self, data):
259 | if self.mode == 'relative_pos_and_rot':
260 | # randomly take first / second fragment as the reference
261 | idx = torch.randint(0, 2, [1])[0]
262 | pos = (data.pos - data.frags_t[idx]) @ data.frags_R[idx]
263 | frags_R = data.frags_R[idx].T @ data.frags_R
264 | frags_t = (data.frags_t - data.frags_t[idx]) @ data.frags_R[idx]
265 | # frags_d doesn't change
266 | data.frags_rel_mask = torch.tensor([True, True])
267 | data.frags_rel_mask[idx] = False # the reference fragment will not be added noise later
268 | data.frags_atom_rel_mask = data.fragment_mask == (2 - idx)
269 |
270 | elif self.mode == 'two_pos_and_rot':
271 | # still guarantees that the midpoint of the two fragments' centers stays at the origin
272 | rand_rot = get_random_rot()
273 | pos = data.pos @ rand_rot
274 | frags_R = rand_rot.T @ data.frags_R
275 | frags_t = data.frags_t @ rand_rot
276 | data.frags_rel_mask = torch.tensor([True, True])
277 |
278 | elif self.mode == 'distance_and_two_rot_aug':
279 | # randomly rotate, then align the inter-fragment axis with x and center the midpoint at the origin
280 | rand_rot = get_random_rot()
281 | tmp_pos = data.pos @ rand_rot
282 | tmp_frags_R = rand_rot.T @ data.frags_R
283 | tmp_frags_t = data.frags_t @ rand_rot
284 |
285 | rot = rotation_matrix_from_vectors(tmp_frags_t[1] - tmp_frags_t[0], torch.tensor([1., 0., 0.]))
286 | tr = -rot @ ((tmp_frags_t[0] + tmp_frags_t[1]) / 2)
287 | pos = tmp_pos @ rot.T + tr
288 | frags_R = rot @ tmp_frags_R
289 | frags_t = tmp_frags_t @ rot.T + tr
290 | data.frags_rel_mask = torch.tensor([True, True])
291 |
292 | elif self.mode == 'distance_and_two_rot':
293 | # unchanged
294 | frags_R = data.frags_R
295 | frags_t = data.frags_t
296 | pos = data.pos
297 | data.frags_rel_mask = torch.tensor([True, True])
298 |
299 | else:
300 | raise ValueError(self.mode)
301 |
302 | data.frags_R = frags_R
303 | data.frags_t = frags_t
304 | data.pos = pos
305 | # print('frags_R: ', data.frags_R, 'frags_t: ', frags_t)
306 | return data
307 |
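# Editor's summary of the modes above: 'relative_pos_and_rot' expresses
# everything in a randomly chosen fragment's frame (that fragment stays
# noise-free); 'two_pos_and_rot' applies one random global rotation, keeping
# the midpoint of the fragment centers at the origin; 'distance_and_two_rot_aug'
# additionally aligns the inter-fragment axis with x, preserving only the
# fragment distance; 'distance_and_two_rot' leaves the geometry unchanged.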
308 |
309 | def get_random_rot():
310 | M = np.random.randn(3, 3)
311 | Q, __ = np.linalg.qr(M)  # orthogonal, but det(Q) may be -1 (a reflection); flip one column's sign if a proper rotation is required
312 | rand_rot = torch.from_numpy(Q.astype(np.float32))
313 | return rand_rot
314 |
315 |
316 | class ReplaceLocalFrame(BaseTransform):
317 | def __init__(self):
318 | super().__init__()
319 |
320 | def __call__(self, data):
321 | frags_R1 = get_random_rot()
322 | frags_R2 = get_random_rot()
323 | f1_local_pos = global_to_local(frags_R1, data.frags_t[0], data.pos[data.fragment_mask == 1])
324 | f2_local_pos = global_to_local(frags_R2, data.frags_t[1], data.pos[data.fragment_mask == 2])
325 | data.frags_local_pos = torch.cat([f1_local_pos, f2_local_pos], dim=0)
326 | return data
327 |
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import py3Dmol
2 | import os
3 | import copy
4 | from rdkit import Chem
5 | from rdkit.Chem import Draw
6 |
7 |
8 | def visualize_complex(pdb_block, sdf_block, show_protein_surface=True, show_ligand=True, show_ligand_surface=True):
9 | view = py3Dmol.view()
10 |
11 | # Add protein to the canvas
12 | view.addModel(pdb_block, 'pdb')
13 | if show_protein_surface:
14 | view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'white'}, {'model': -1})
15 | else:
16 | view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
17 | view.setStyle({'model': -1}, {"cartoon": {"style": "edged", 'opacity': 0}})
18 |
19 | # Add ligand to the canvas
20 | if show_ligand:
21 | view.addModel(sdf_block, 'sdf')
22 | view.setStyle({'model': -1}, {'stick': {}})
23 | # view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
24 | if show_ligand_surface:
25 | view.addSurface(py3Dmol.VDW, {'opacity': 0.8}, {'model': -1})
26 |
27 | view.zoomTo()
28 | return view
29 |
30 |
31 | def visualize_complex_with_frags(pdb_block, all_frags, show_protein_surface=True, show_ligand=True, show_ligand_surface=True):
32 | view = py3Dmol.view()
33 |
34 | # Add protein to the canvas
35 | view.addModel(pdb_block, 'pdb')
36 | if show_protein_surface:
37 | view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'white'}, {'model': -1})
38 | else:
39 | view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
40 | view.setStyle({'model': -1}, {"cartoon": {"style": "edged", 'opacity': 0}})
41 |
42 | # Add ligand to the canvas
43 | if show_ligand:
44 | for frag in all_frags:
45 | sdf_block = Chem.MolToMolBlock(frag)
46 | view.addModel(sdf_block, 'sdf')
47 | view.setStyle({'model': -1}, {'stick': {}})
48 | # view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
49 | if show_ligand_surface:
50 | view.addSurface(py3Dmol.VDW, {'opacity': 0.8}, {'model': -1})
51 |
52 | view.zoomTo()
53 | return view
54 |
55 | def visualize_complex_highlight_pocket(pdb_block, sdf_block,
56 | pocket_atom_idx, pocket_res_idx=None, pocket_chain=None,
57 | show_ligand=True, show_ligand_surface=True):
58 | view = py3Dmol.view()
59 |
60 | # Add protein to the canvas
61 | view.addModel(pdb_block, 'pdb')
62 | view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'white'}, {'model': -1})
63 | if pocket_atom_idx:
64 | view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'red'}, {'model': -1, 'serial': pocket_atom_idx})
65 | if pocket_res_idx:
66 | view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'red'},
67 | {'model': -1, 'chain': pocket_chain, 'resi': list(set(pocket_res_idx))})
68 | # color_map = ['red', 'yellow', 'blue', 'green']
69 | # for idx, pocket_atom_idx in enumerate(all_pocket_atom_idx):
70 | # print(pocket_atom_idx)
71 | # view.addSurface(py3Dmol.VDW, {'opacity':0.7, 'color':color_map[idx]}, {'model': -1, 'serial': pocket_atom_idx})
72 | # view.addSurface(py3Dmol.VDW, {'opacity':0.7,'color':'red'}, {'model': -1, 'resi': list(set(pocket_residue))})
73 |
74 | # view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
75 | view.setStyle({'model': -1}, {"cartoon": {"style": "edged", 'opacity': 0}})
76 | # view.setStyle({'model': -1, 'serial': atom_idx}, {'cartoon': {'color': 'red'}})
77 | # view.setStyle({'model': -1, 'resi': [482, 484]}, {'cartoon': {'color': 'green'}})
78 |
79 | # Add ligand to the canvas
80 | if show_ligand:
81 | view.addModel(sdf_block, 'sdf')
82 | view.setStyle({'model': -1}, {'stick': {}})
83 | # view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
84 | if show_ligand_surface:
85 | view.addSurface(py3Dmol.VDW, {'opacity': 0.8}, {'model': -1})
86 |
87 | view.zoomTo()
88 | return view
89 |
90 |
91 | def visualize_mol_highlight_fragments(mol, match_list):
92 | all_target_atm = []
93 | for match in match_list:
94 | target_atm = []
95 | for atom in mol.GetAtoms():
96 | if atom.GetIdx() in match:
97 | target_atm.append(atom.GetIdx())
98 | all_target_atm.append(target_atm)
99 |
100 | return Draw.MolsToGridImage([mol for _ in range(len(match_list))], highlightAtomLists=all_target_atm,
101 | subImgSize=(400, 400), molsPerRow=4)
102 |
103 |
104 | def visualize_generated_xyz_v2(atom_pos, atom_type, protein_path, ligand_path=None, show_ligand=False, show_protein_surface=True):
105 | ptable = Chem.GetPeriodicTable()
106 |
107 | num_atoms = len(atom_pos)
108 | xyz = "%d\n\n" % (num_atoms,)
109 | for i in range(num_atoms):
110 | symb = ptable.GetElementSymbol(atom_type[i])
111 | x, y, z = atom_pos[i]
112 | xyz += "%s %.8f %.8f %.8f\n" % (symb, x, y, z)
113 |
114 | # print(xyz)
115 |
116 | with open(protein_path, 'r') as f:
117 | pdb_block = f.read()
118 |
119 | view = py3Dmol.view()
120 | # Generated molecule
121 | view.addModel(xyz, 'xyz')
122 | view.setStyle({'model': -1}, {'sphere': {'radius': 0.3}, 'stick': {}})
123 |
124 | # Pocket
125 | view.addModel(pdb_block, 'pdb')
126 | if show_protein_surface:
127 | view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'white'}, {'model': -1})
128 | else:
129 | view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
130 | view.setStyle({'model': -1}, {"cartoon": {"style": "edged", 'opacity': 0}})
131 |
132 | # Focus on the generated
133 | view.zoomTo()
134 |
135 | # Ligand
136 | if show_ligand:
137 | with open(ligand_path, 'r') as f:
138 | sdf_block = f.read()
139 | view.addModel(sdf_block, 'sdf')
140 | view.setStyle({'model': -1}, {'stick': {}})
141 |
142 | return view
143 |
144 |
145 | def visualize_generated_xyz(data, root, show_ligand=False):
146 | ptable = Chem.GetPeriodicTable()
147 |
148 | num_atoms = data.ligand_context_element.size(0)
149 | xyz = "%d\n\n" % (num_atoms,)
150 | for i in range(num_atoms):
151 | symb = ptable.GetElementSymbol(data.ligand_context_element[i].item())
152 | x, y, z = data.ligand_context_pos[i].clone().cpu().tolist()
153 | xyz += "%s %.8f %.8f %.8f\n" % (symb, x, y, z)
154 |
155 | # print(xyz)
156 |
157 | protein_path = os.path.join(root, data.protein_filename)
158 | ligand_path = os.path.join(root, data.ligand_filename)
159 |
160 | with open(protein_path, 'r') as f:
161 | pdb_block = f.read()
162 | with open(ligand_path, 'r') as f:
163 | sdf_block = f.read()
164 |
165 | view = py3Dmol.view()
166 | # Generated molecule
167 | view.addModel(xyz, 'xyz')
168 | view.setStyle({'model': -1}, {'sphere': {'radius': 0.3}, 'stick': {}})
169 | # Focus on the generated
170 | view.zoomTo()
171 |
172 | # Pocket
173 | view.addModel(pdb_block, 'pdb')
174 | view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
175 | # Ligand
176 | if show_ligand:
177 | view.addModel(sdf_block, 'sdf')
178 | view.setStyle({'model': -1}, {'stick': {}})
179 |
180 | return view
181 |
182 |
183 | def visualize_generated_sdf(data, protein_path, ligand_path, show_ligand=False, show_protein_surface=True):
184 | # protein_path = os.path.join(root, data.protein_filename)
185 | # ligand_path = os.path.join(root, data.ligand_filename)
186 |
187 | with open(protein_path, 'r') as f:
188 | pdb_block = f.read()
189 |
190 | view = py3Dmol.view()
191 | # Generated molecule
192 | mol_block = Chem.MolToMolBlock(data.rdmol)
193 | view.addModel(mol_block, 'sdf')
194 | view.setStyle({'model': -1}, {'sphere': {'radius': 0.3}, 'stick': {}})
195 | # Focus on the generated
196 | # view.zoomTo()
197 |
198 | # Pocket
199 | view.addModel(pdb_block, 'pdb')
200 | if show_protein_surface:
201 | view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'white'}, {'model': -1})
202 | view.setStyle({'model': -1}, {"cartoon": {"style": "edged", 'opacity': 0}})
203 | else:
204 | view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
205 | # Ligand
206 | if show_ligand:
207 | with open(ligand_path, 'r') as f:
208 | sdf_block = f.read()
209 | view.addModel(sdf_block, 'sdf')
210 | view.setStyle({'model': -1}, {'stick': {}})
211 | view.zoomTo()
212 | return view
213 |
214 |
215 | def visualize_generated_arms(data_list, protein_path, ligand_path, show_ligand=False, show_protein_surface=True):
216 | # protein_path = os.path.join(root, data.protein_filename)
217 | # ligand_path = os.path.join(root, data.ligand_filename)
218 |
219 | with open(protein_path, 'r') as f:
220 | pdb_block = f.read()
221 |
222 | view = py3Dmol.view()
223 | # Generated molecule
224 | for data in data_list:
225 | mol_block = Chem.MolToMolBlock(data.rdmol)
226 | view.addModel(mol_block, 'sdf')
227 | view.setStyle({'model': -1}, {'sphere': {'radius': 0.3}, 'stick': {}})
228 | # Focus on the generated
229 | # view.zoomTo()
230 |
231 | # Pocket
232 | view.addModel(pdb_block, 'pdb')
233 | if show_protein_surface:
234 | view.addSurface(py3Dmol.VDW, {'opacity': 0.7, 'color': 'white'}, {'model': -1})
235 | view.setStyle({'model': -1}, {"cartoon": {"style": "edged", 'opacity': 0}})
236 | else:
237 | view.setStyle({'model': -1}, {'cartoon': {'color': 'spectrum'}, 'line': {}})
238 | # Ligand
239 | if show_ligand:
240 | with open(ligand_path, 'r') as f:
241 | sdf_block = f.read()
242 | view.addModel(sdf_block, 'sdf')
243 | view.setStyle({'model': -1}, {'stick': {}})
244 | view.zoomTo()
245 | return view
246 |
247 |
248 | def visualize_ligand(mol, size=(300, 300), style="stick", surface=False, opacity=0.5, viewer=None):
249 | """Draw molecule in 3D
250 |
251 | Args:
252 | ----
253 | mol: rdMol, molecule to show
254 | size: tuple(int, int), canvas size
255 | style: str, type of drawing molecule
256 | style can be 'line', 'stick', 'sphere', 'cartoon'
257 | surface: bool, display SAS
258 | opacity: float, opacity of surface, range 0.0-1.0
259 | Return:
260 | ----
261 | viewer: py3Dmol.view, a class for constructing embedded 3Dmol.js views in ipython notebooks.
262 | """
263 | assert style in ('line', 'stick', 'sphere', 'cartoon')
264 | if viewer is None:
265 | viewer = py3Dmol.view(width=size[0], height=size[1])
266 | if isinstance(mol, list):
267 | for i, m in enumerate(mol):
268 | mblock = Chem.MolToMolBlock(m)
269 | viewer.addModel(mblock, 'mol' + str(i))
270 | elif len(mol.GetConformers()) > 1:
271 | for i in range(len(mol.GetConformers())):
272 | mblock = Chem.MolToMolBlock(mol, confId=i)
273 | viewer.addModel(mblock, 'mol' + str(i))
274 | else:
275 | mblock = Chem.MolToMolBlock(mol)
276 | viewer.addModel(mblock, 'mol')
277 | viewer.setStyle({style: {}})
278 | if surface:
279 | viewer.addSurface(py3Dmol.SAS, {'opacity': opacity})
280 | viewer.zoomTo()
281 | return viewer
282 |
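# Editor's sketch: a minimal notebook usage of visualize_ligand, assuming an
# RDKit molecule with a 3D conformer (AllChem.EmbedMolecule generates one):
#
#   from rdkit.Chem import AllChem
#   mol = Chem.AddHs(Chem.MolFromSmiles('CCO'))
#   AllChem.EmbedMolecule(mol, randomSeed=42)
#   visualize_ligand(mol, style='stick', surface=True)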
283 |
284 | def visualize_full_mol(frags_mol, linker_pos, linker_type):
285 | ptable = Chem.GetPeriodicTable()
286 |
287 | num_atoms = len(linker_pos)
288 | xyz = "%d\n\n" % (num_atoms,)
289 | for i in range(num_atoms):
290 | symb = ptable.GetElementSymbol(linker_type[i])
291 | x, y, z = linker_pos[i]
292 | xyz += "%s %.8f %.8f %.8f\n" % (symb, x, y, z)
293 |
294 | view = py3Dmol.view()
295 | # Generated molecule
296 | view.addModel(xyz, 'xyz')
297 | view.setStyle({'model': -1}, {'sphere': {'radius': 0.3}, 'stick': {}})
298 |
299 | mblock = Chem.MolToMolBlock(frags_mol)
300 | view.addModel(mblock, 'sdf')
301 | view.setStyle({'model': -1}, {'stick': {}})
302 | view.zoomTo()
303 | return view
304 |
305 |
306 | def mol_with_atom_index(mol):
307 | mol = copy.deepcopy(mol)
308 | mol.RemoveAllConformers()
309 | for atom in mol.GetAtoms():
310 | atom.SetAtomMapNum(atom.GetIdx())
311 | return mol
312 |
--------------------------------------------------------------------------------