├── .gitignore
├── LICENSE
├── README.md
├── assets
├── CCCCC[C@@H](C)CO_traj.gif
├── Geodiff500-CCCCC[C@@H](C)CO_traj.gif
├── bbbp_sider.png
├── case_study.png
├── clintox_bace.png
├── mask_diff_framework.png
├── subgdiff_framework.jpg
├── subgdiff_framework.png
└── toxcast.png
├── checkpoints
├── qm9_200steps
│ ├── 2000000.pt
│ ├── qm9_200steps.yml
│ └── samples
│ │ ├── log_eval_samples_all.txt
│ │ ├── samples_all.pkl
│ │ ├── samples_all_covmat.csv
│ │ └── samples_all_covmat.pkl
├── qm9_500steps
│ ├── 2000000.pt
│ ├── qm9_500steps.yml
│ └── samples
│ │ ├── log_eval_samples_all.txt
│ │ ├── samples_all.pkl
│ │ └── samples_all_covmat.csv
└── sample_2024_02_05__15_50_46_SubGDiff500
│ └── log.txt
├── configs
├── qm9_200steps.yml
└── qm9_500steps.yml
├── datasets
├── chem.py
├── qm9.py
└── rdmol2data.py
├── env.yaml
├── eval_covmat.py
├── eval_prop.py
├── finetune
├── config.py
├── datasets
│ ├── __init__.py
│ ├── datasets_QM9.py
│ └── rdmol2data.py
└── splitters.py
├── finetune_qm9.py
├── models
├── common.py
├── encoder
│ ├── __init__.py
│ ├── coarse.py
│ ├── edge.py
│ ├── gin.py
│ └── schnet.py
├── epsnet
│ ├── __init__.py
│ ├── attention.py
│ ├── diffusion.py
│ ├── dualenc.py
│ └── dualencoder.py
└── geometry.py
├── test.py
├── test.sh
├── train.py
├── train.sh
├── train_dist.py
├── utils
├── chem.py
├── common.py
├── datasets.py
├── evaluation
│ └── covmat.py
├── misc.py
├── transforms.py
└── visualize.py
└── visualization.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | logs
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 SubGDiff authors.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SUBGDIFF: A Subgraph Diffusion Model to Improve Molecular Representation Learning
2 |
3 | The official implementation of the NeurIPS 2024 paper [SubgDiff: A Subgraph Diffusion Model to Improve Molecular Representation Learning](https://arxiv.org/abs/2405.05665).
4 |
5 |
6 |

7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | ## Environments
16 |
17 |
18 | ### Dependency
19 | ```
20 | python=3.7
21 |
22 | pytorch 1.11.0 py3.7_cuda11.3_cudnn8.2.0_0
23 | torch-cluster 1.6.0 pypi_0
24 | torch-geometric 1.7.2 pypi_0
25 | torch-scatter 2.0.9 pypi_0
26 | torch-sparse 0.6.13 pypi_0
27 |
28 | rdkit 2023.3.2 pypi_0
29 | ```
30 | Other useful packages in our environment: see `env.txt` / `env.yaml`.
31 |
32 |
33 | ## Dataset
34 |
35 | The dataset can be downloaded directly from [https://zenodo.org/records/10616999](https://zenodo.org/records/10616999).
36 |
37 | ### Official Dataset
38 |
39 |
40 | The official raw GEOM dataset is available [[here]](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/JNGTDF).
41 |
42 |
43 | ### Prepare your own GEOM dataset from scratch (optional)
44 |
45 | You can also download the original full GEOM dataset and prepare your own data split. A guide is available on the previous work GeoDiff's [[GitHub page]](https://github.com/MinkaiXu/GeoDiff).
46 |
47 | ## Training
48 |
49 | All hyper-parameters and training details are provided in the config files (`./configs/*.yml`); feel free to tune these parameters.
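As an illustration of how the schedule fields in these configs (`beta_schedule: sigmoid`, `beta_start`, `beta_end`, `num_diffusion_timesteps`) translate into a diffusion noise schedule, here is a minimal sketch of a common sigmoid beta schedule; the exact formula used by the training code may differ:

```python
import numpy as np

def sigmoid_beta_schedule(beta_start=1e-7, beta_end=2e-2, num_steps=500):
    """Sigmoid ramp from ~beta_start to ~beta_end over num_steps timesteps.

    Illustrative only: defaults mirror configs/qm9_500steps.yml, but the
    repo's actual schedule implementation may differ in detail.
    """
    t = np.linspace(-6, 6, num_steps)
    return 1.0 / (1.0 + np.exp(-t)) * (beta_end - beta_start) + beta_start

betas = sigmoid_beta_schedule()
alphas_cumprod = np.cumprod(1.0 - betas)  # cumulative signal retention per step
```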
50 |
51 | You can train the model with the following commands:
52 |
53 | - QM9 dataset
54 | ```bash
55 | # Default settings
56 |
57 | sh train.sh
58 |
59 | ```
60 |
61 | ## Generation
62 |
63 |
64 |
65 |
66 |
67 | You can generate conformations for the entire test set, or a part of it, with the following command:
68 |
70 | ```bash
71 | sh test.sh
72 | ```
73 |
74 | Here `start_idx` and `end_idx` specify the range of the test set to sample from. All sampling-related hyper-parameters can be set in `test.py`. Following GeoDiff, we use `start_idx=800` and `end_idx=1000` in our experiments.
75 |
76 | Conformations of some drug-like molecules generated by SubGDiff can be found in `visualization.ipynb`.
77 |
78 | ## Evaluation
79 |
80 | After generating conformations with the above commands, the results of all benchmark tasks can be calculated from the generated data.
81 |
82 | ### Conformation Generation
83 |
84 | The `COV` and `MAT` scores on the GEOM datasets can be calculated using the following commands:
85 |
86 | Evaluation on generated samples:
87 | ```bash
88 | python eval_covmat.py checkpoints/qm9_500steps/samples/samples_all.pkl
89 | python eval_covmat.py checkpoints/qm9_200steps/samples/samples_all.pkl
90 | ```
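For reference, the COV (coverage) and MAT (matching) scores computed by `utils/evaluation/covmat.py` can be sketched as below, given a pairwise RMSD matrix between the reference and generated conformations of one molecule. The definitions follow GeoDiff; the function and variable names here are illustrative, not the repo's API:

```python
import numpy as np

def cov_mat_scores(rmsd, delta=0.5):
    """COV/MAT sketch for one molecule.

    rmsd: (num_ref, num_gen) matrix of RMSDs between reference and generated
    conformations. Recall (R) scores ask how well references are covered by
    generations; Precision (P) scores ask how close each generation is to
    some reference.
    """
    min_over_gen = rmsd.min(axis=1)   # best generated match per reference
    min_over_ref = rmsd.min(axis=0)   # best reference match per generation
    cov_r = float((min_over_gen < delta).mean())
    mat_r = float(min_over_gen.mean())
    cov_p = float((min_over_ref < delta).mean())
    mat_p = float(min_over_ref.mean())
    return cov_r, mat_r, cov_p, mat_p
```

The reported numbers aggregate these per-molecule scores (mean/median/std) over the filtered test set.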
91 |
92 |
93 |
94 |
95 |
96 |
97 | ## Visualizing molecules with PyMol
98 |
99 | - Molecule visualization:
100 |
101 | Run `visualization.ipynb`
102 |
103 | - Trajectory generation (with `--save_traj`):
104 |
105 | ```bash
106 | python test.py --ckpt checkpoints/qm9_500steps/2000000.pt --config checkpoints/qm9_500steps/qm9_500steps.yml --test_set './data/GEOM/QM9/test_data_1k.pkl' --start_idx 800 --end_idx 1000 --sampling_type same_mask_noisy --n_steps 500 --device cuda:1 --w_global 0.1 --clip 1000 --clip_local 20 --global_start_sigma 5 --tag SubGDiff500 --save_traj
107 | ```
108 |
109 | Run `visualization.ipynb`.
110 |
111 |
112 |
113 |
114 |
115 | ## Subgraph diffusion process
116 |
117 | Please see the class `SubgraphNoiseTransform` in `utils/transforms.py`.
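For intuition, here is a highly simplified sketch of a subgraph-masked forward diffusion step: only atoms belonging to the sampled subgraph receive Gaussian noise at this step. The actual `SubgraphNoiseTransform` also handles subgraph/mask sampling and uses the paper's exact state-adaptive parameterization, which this sketch does not reproduce:

```python
import numpy as np

def subgraph_forward_step(x_prev, beta_t, mask, rng):
    """One illustrative forward step noising only the masked subgraph.

    x_prev: (num_atoms, 3) coordinates; mask: (num_atoms, 1) binary subgraph
    indicator. Atoms with mask == 0 keep their coordinates unchanged.
    """
    noise = rng.standard_normal(x_prev.shape)
    noised = np.sqrt(1.0 - beta_t) * x_prev + np.sqrt(beta_t) * noise
    return mask * noised + (1.0 - mask) * x_prev

rng = np.random.default_rng(0)
x0 = rng.standard_normal((5, 3))
mask = np.array([[1.0], [1.0], [1.0], [0.0], [0.0]])  # 3-atom subgraph
x1 = subgraph_forward_step(x0, 0.01, mask, rng)
```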
118 |
119 |
120 | ## Visualization
121 | ### Sampling trajectory of SubGDiff (ours)
122 | The video shows that, during sampling, the denoising network denoises only the atomic coordinates of a subgraph at each timestep.
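This per-timestep behavior can be sketched as a DDPM-style reverse update applied only to the atoms inside the subgraph mask; the actual SubGDiff sampler (`--sampling_type same_mask_noisy`) is more involved, so treat the names and formula below as an illustrative simplification:

```python
import numpy as np

def masked_reverse_step(x_t, eps_pred, mask, alpha_t, alpha_bar_t, sigma_t, rng):
    """Illustrative reverse step: update only atoms with mask == 1.

    x_t, eps_pred: (num_atoms, 3) coordinates and predicted noise;
    mask: (num_atoms, 1) binary subgraph indicator. Unmasked atoms keep
    their coordinates at this timestep.
    """
    mean = (x_t - (1.0 - alpha_t) / np.sqrt(1.0 - alpha_bar_t) * eps_pred) / np.sqrt(alpha_t)
    x_next = mean + sigma_t * rng.standard_normal(x_t.shape)
    return mask * x_next + (1.0 - mask) * x_t

rng = np.random.default_rng(1)
x_t = rng.standard_normal((4, 3))
eps_pred = rng.standard_normal((4, 3))
mask = np.array([[1.0], [0.0], [1.0], [0.0]])
x_prev = masked_reverse_step(x_t, eps_pred, mask, 0.99, 0.5, 0.05, rng)
```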
123 |
124 |
125 |
126 | ### Sampling trajectory of GeoDiff
127 |
128 | The video shows that, during sampling, the denoising network denoises all atomic coordinates at each timestep.
129 |
130 |
131 |
132 | ### Visualization of final sampling results
133 | The following figures come from the final step of the sampling trajectories above (SubGDiff and GeoDiff). The figures show that SubGDiff generates a conformation closer to the ground truth.
134 |
135 |
136 |
137 | ## Citation
138 | ```
139 | @inproceedings{zhang2024subgdiff,
140 | title={SubgDiff: A Subgraph Diffusion Model to Improve Molecular Representation Learning},
141 | author={Zhang, Jiying and Liu, Zijing and Wang, Yu and Feng, Bin and Li, Yu},
142 | booktitle={Advances in Neural Information Processing Systems},
143 | year={2024}
144 | }
145 | ```
146 |
147 |
148 | ## Acknowledgement
149 |
150 | This repo is built upon the previous work GeoDiff's [[codebase]](https://github.com/MinkaiXu/GeoDiff).
151 |
--------------------------------------------------------------------------------
/assets/CCCCC[C@@H](C)CO_traj.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/CCCCC[C@@H](C)CO_traj.gif
--------------------------------------------------------------------------------
/assets/Geodiff500-CCCCC[C@@H](C)CO_traj.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/Geodiff500-CCCCC[C@@H](C)CO_traj.gif
--------------------------------------------------------------------------------
/assets/bbbp_sider.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/bbbp_sider.png
--------------------------------------------------------------------------------
/assets/case_study.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/case_study.png
--------------------------------------------------------------------------------
/assets/clintox_bace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/clintox_bace.png
--------------------------------------------------------------------------------
/assets/mask_diff_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/mask_diff_framework.png
--------------------------------------------------------------------------------
/assets/subgdiff_framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/subgdiff_framework.jpg
--------------------------------------------------------------------------------
/assets/subgdiff_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/subgdiff_framework.png
--------------------------------------------------------------------------------
/assets/toxcast.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/toxcast.png
--------------------------------------------------------------------------------
/checkpoints/qm9_200steps/2000000.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/checkpoints/qm9_200steps/2000000.pt
--------------------------------------------------------------------------------
/checkpoints/qm9_200steps/qm9_200steps.yml:
--------------------------------------------------------------------------------
1 | model:
2 | type: subgraph_diffusion
3 | network: dualenc
4 | hidden_dim: 128
5 | num_convs: 6
6 | num_convs_local: 4
7 | cutoff: 10.0
8 | mlp_act: relu
9 | beta_schedule: sigmoid
10 | beta_start: 1.e-7
11 | beta_end: 5.e-2
12 | num_diffusion_timesteps: 200
13 | edge_order: 3
14 | edge_encoder: mlp
15 | smooth_conv: false
16 | same_mask_steps: 10
17 | mask_pred: MLP #
18 |
19 | train:
20 | seed: 2021
21 | batch_size: 64
22 | val_freq: 2000
23 | max_iters: 2000000
24 | max_grad_norm: 10000.00
25 | anneal_power: 2.0
26 | optimizer:
27 | type: adam
28 | lr: 1.e-3
29 | weight_decay: 0.
30 | beta1: 0.95
31 | beta2: 0.999
32 | scheduler:
33 | type: plateau
34 | factor: 0.6
35 | patience: 30
36 |
37 | dataset:
38 | train: ./data/GEOM/QM9/train_data_40k_subgraph.pkl
39 | val: ./data/GEOM/QM9/val_data_5k_subgraph.pkl
40 | test: ./data/GEOM/QM9/test_data_1k.pkl
41 |
--------------------------------------------------------------------------------
/checkpoints/qm9_200steps/samples/log_eval_samples_all.txt:
--------------------------------------------------------------------------------
1 | [2024-02-04 16:07:52,500::eval::INFO] Loading results: checkpoints/qm9_200steps/samples/samples_all.pkl
2 | [2024-02-04 16:07:52,671::eval::INFO] Total: 200
3 | [2024-02-04 16:07:52,718::eval::INFO] Filtered: 196 / 200
4 | [2024-02-04 16:11:31,018::eval::INFO]
5 | COV-R_mean COV-R_median COV-R_std COV-P_mean COV-P_median COV-P_std
6 | 0.05 0.000101 0.000000 0.001164 0.000051 0.000000 0.000582
7 | 0.10 0.066095 0.020000 0.121742 0.024568 0.008969 0.047813
8 | 0.15 0.251479 0.171766 0.240313 0.095490 0.062149 0.108542
9 | 0.20 0.394886 0.338929 0.269503 0.161195 0.128679 0.144146
10 | 0.25 0.495578 0.445437 0.270431 0.213917 0.174318 0.162835
11 | 0.30 0.575141 0.536866 0.261140 0.259591 0.215421 0.175224
12 | 0.35 0.642270 0.614450 0.241695 0.299245 0.251838 0.183180
13 | 0.40 0.706840 0.700558 0.218976 0.340382 0.297155 0.190494
14 | 0.45 0.777926 0.811150 0.193104 0.397087 0.385711 0.202436
15 | 0.50 0.855279 0.889899 0.154255 0.477603 0.458864 0.213210
16 | 0.55 0.928285 0.963525 0.109881 0.579527 0.548333 0.220027
17 | 0.60 0.970038 1.000000 0.083130 0.675407 0.706145 0.219565
18 | 0.65 0.986554 1.000000 0.074638 0.755032 0.820000 0.209216
19 | 0.70 0.992985 1.000000 0.071559 0.815814 0.888153 0.196520
20 | 0.75 0.994636 1.000000 0.070967 0.857790 0.939757 0.186596
21 | 0.80 0.994974 1.000000 0.069723 0.888436 0.969891 0.180300
22 | 0.85 0.995051 1.000000 0.069115 0.908737 0.988604 0.175182
23 | 0.90 0.995116 1.000000 0.068201 0.919555 0.992726 0.170642
24 | 0.95 0.995225 1.000000 0.066679 0.925715 0.994135 0.167369
25 | 1.00 0.995378 1.000000 0.064548 0.930677 0.995702 0.163408
26 | 1.05 0.995508 1.000000 0.062721 0.935447 0.997041 0.158606
27 | 1.10 0.995617 1.000000 0.061199 0.939520 1.000000 0.152420
28 | 1.15 0.995726 1.000000 0.059676 0.946509 1.000000 0.141974
29 | 1.20 0.995770 1.000000 0.059067 0.956307 1.000000 0.128556
30 | 1.25 0.995857 1.000000 0.057849 0.963814 1.000000 0.115558
31 | 1.30 0.996054 1.000000 0.055109 0.971617 1.000000 0.100980
32 | 1.35 0.996206 1.000000 0.052978 0.977215 1.000000 0.090921
33 | 1.40 0.996293 1.000000 0.051760 0.980891 1.000000 0.085915
34 | 1.45 0.996359 1.000000 0.050847 0.983148 1.000000 0.083228
35 | 1.50 0.996381 1.000000 0.050542 0.984951 1.000000 0.081133
36 | 1.55 0.996555 1.000000 0.048106 0.986234 1.000000 0.078929
37 | 1.60 0.996708 1.000000 0.045975 0.987412 1.000000 0.075759
38 | 1.65 0.996860 1.000000 0.043844 0.988513 1.000000 0.073259
39 | 1.70 0.996991 1.000000 0.042017 0.989432 1.000000 0.070975
40 | 1.75 0.997035 1.000000 0.041408 0.990098 1.000000 0.069999
41 | 1.80 0.997209 1.000000 0.038972 0.990665 1.000000 0.069282
42 | 1.85 0.997362 1.000000 0.036841 0.990981 1.000000 0.068877
43 | 1.90 0.997449 1.000000 0.035623 0.991103 1.000000 0.068748
44 | 1.95 0.997493 1.000000 0.035014 0.991254 1.000000 0.068592
45 | 2.00 0.997623 1.000000 0.033187 0.991325 1.000000 0.068536
46 | 2.05 0.997820 1.000000 0.030447 0.991533 1.000000 0.068376
47 | 2.10 0.997929 1.000000 0.028925 0.991670 1.000000 0.068318
48 | 2.15 0.998038 1.000000 0.027402 0.991785 1.000000 0.068218
49 | 2.20 0.998147 1.000000 0.025880 0.991888 1.000000 0.068131
50 | 2.25 0.998278 1.000000 0.024053 0.992025 1.000000 0.068077
51 | 2.30 0.998452 1.000000 0.021617 0.992133 1.000000 0.068008
52 | 2.35 0.998517 1.000000 0.020704 0.992277 1.000000 0.067936
53 | 2.40 0.998626 1.000000 0.019182 0.992367 1.000000 0.067877
54 | 2.45 0.998757 1.000000 0.017355 0.992443 1.000000 0.067853
55 | 2.50 0.998823 1.000000 0.016441 0.992569 1.000000 0.067770
56 | 2.55 0.998910 1.000000 0.015224 0.992771 1.000000 0.067659
57 | 2.60 0.998953 1.000000 0.014615 0.992986 1.000000 0.067406
58 | 2.65 0.999019 1.000000 0.013701 0.993105 1.000000 0.067358
59 | 2.70 0.999041 1.000000 0.013397 0.993267 1.000000 0.067290
60 | 2.75 0.999128 1.000000 0.012179 0.993411 1.000000 0.067245
61 | 2.80 0.999215 1.000000 0.010961 0.993485 1.000000 0.067239
62 | 2.85 0.999280 1.000000 0.010048 0.993523 1.000000 0.067234
63 | 2.90 0.999389 1.000000 0.008525 0.993642 1.000000 0.067200
64 | 2.95 0.999455 1.000000 0.007612 0.993802 1.000000 0.067145
65 | 3.00 0.999520 1.000000 0.006698 0.993957 1.000000 0.066967
66 | [2024-02-04 16:11:31,019::eval::INFO] MAT-R_mean: 0.2994 | MAT-R_median: 0.3033 | MAT-R_std 0.1516
67 | [2024-02-04 16:11:31,019::eval::INFO] MAT-P_mean: 0.6971 | MAT-P_median: 0.5118 | MAT-P_std 2.5074
68 |
--------------------------------------------------------------------------------
/checkpoints/qm9_200steps/samples/samples_all.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/checkpoints/qm9_200steps/samples/samples_all.pkl
--------------------------------------------------------------------------------
/checkpoints/qm9_200steps/samples/samples_all_covmat.csv:
--------------------------------------------------------------------------------
1 | ,COV-R_mean,COV-R_median,COV-R_std,COV-P_mean,COV-P_median,COV-P_std
2 | 0.05,0.00010115094409681809,0.0,0.0011640228873020942,5.0575472048409045e-05,0.0,0.0005820114436510471
3 | 0.1,0.06609495767865171,0.02,0.12174158801889531,0.024568470531056104,0.00896879021879022,0.04781286790066844
4 | 0.15000000000000002,0.25147861556229295,0.17176573426573427,0.24031302256830991,0.09549018442721859,0.06214887640449438,0.10854193223018146
5 | 0.2,0.39488560672391226,0.33892896781354054,0.26950262874024494,0.1611949990040838,0.12867924528301888,0.14414591443505212
6 | 0.25,0.495578308485712,0.4454365079365079,0.27043135026333043,0.2139171955421679,0.17431772237196763,0.16283464245645138
7 | 0.3,0.5751414020096989,0.5368659420289855,0.261139857784902,0.25959096564817863,0.21542104276069018,0.17522366323829522
8 | 0.35000000000000003,0.6422702023841537,0.6144499762920816,0.24169531183888102,0.2992450830655438,0.25183823529411764,0.18317981223346985
9 | 0.4,0.7068397136032962,0.7005577005577006,0.21897640291850254,0.3403816817546149,0.2971547706647044,0.1904941003097485
10 | 0.45,0.7779260397839731,0.8111499611499611,0.19310368616606913,0.3970872122656049,0.3857105538140021,0.20243626841896464
11 | 0.5,0.8552794707902401,0.8898989898989899,0.15425467404097734,0.4776029383577201,0.4588638195004029,0.21321046987652306
12 | 0.55,0.9282851301968171,0.9635254988913525,0.10988135747472798,0.5795268208031766,0.5483333333333333,0.22002722750626835
13 | 0.6000000000000001,0.9700379104922083,1.0,0.08312952507674312,0.6754073596234932,0.7061450381679388,0.21956497708697012
14 | 0.6500000000000001,0.9865537800384889,1.0,0.07463815907443776,0.7550323049100166,0.82,0.20921620553741033
15 | 0.7000000000000001,0.9929853737304797,1.0,0.07155868754577098,0.8158136099120913,0.8881530537159676,0.19651985338361522
16 | 0.7500000000000001,0.9946358770719277,1.0,0.07096688870463422,0.8577904550128883,0.9397568499364907,0.18659572585478243
17 | 0.8,0.994974060757471,1.0,0.06972292755097403,0.8884358294495973,0.96989121989122,0.18030012562785835
18 | 0.8500000000000001,0.9950505843362987,1.0,0.06911482840431585,0.9087371749305106,0.9886043165467626,0.17518222913895623
19 | 0.9000000000000001,0.9951159951159951,1.0,0.06820141657518429,0.9195547405461229,0.9927264071096065,0.17064175753164304
20 | 0.9500000000000001,0.9952250130821559,1.0,0.06667906352663068,0.9257145168578484,0.9941348469212246,0.1673688687041287
21 | 1.0,0.9953776382347811,1.0,0.06454776925865623,0.9306773625960487,0.9957017414403778,0.1634075896624802
22 | 1.05,0.9955084597941742,1.0,0.06272094560039261,0.935446725107578,0.9970410057078825,0.1586063789618342
23 | 1.1,0.9956174777603348,1.0,0.06119859255183924,0.9395204124683095,1.0,0.152419504372303
24 | 1.1500000000000001,0.9957264957264957,1.0,0.05967623950328628,0.9465093686998486,1.0,0.14197445548432253
25 | 1.2000000000000002,0.9957701029129601,1.0,0.059067298283864894,0.9563072992756991,1.0,0.12855642503738074
26 | 1.2500000000000002,0.9958573172858887,1.0,0.05784941584502237,0.9638137301251927,1.0,0.1155583754771393
27 | 1.3,0.9960535496249782,1.0,0.055109180357626616,0.9716174237976887,1.0,0.1009795540939641
28 | 1.35,0.9962061747776034,1.0,0.05297788608965196,0.977215272414187,1.0,0.09092131674120832
29 | 1.4000000000000001,0.996293389150532,1.0,0.051760003650809246,0.9808910969497924,1.0,0.08591510637395192
30 | 1.4500000000000002,0.9963587999302285,1.0,0.05084659182167728,0.983148122013181,1.0,0.08322779138667874
31 | 1.5000000000000002,0.9963806035234607,1.0,0.050542121211966815,0.9849507843993025,1.0,0.0811330566896536
32 | 1.55,0.996555032269318,1.0,0.0481063563342816,0.9862340328117838,1.0,0.07892924040163384
33 | 1.6,0.9967076574219432,1.0,0.04597506206630725,0.9874119019936729,1.0,0.07575912823805947
34 | 1.6500000000000001,0.9968602825745683,1.0,0.04384376779833257,0.9885129265244994,1.0,0.07325930227652488
35 | 1.7000000000000002,0.9969911041339613,1.0,0.04201694414006876,0.9894323691498574,1.0,0.07097455659098532
36 | 1.7500000000000002,0.9970347113204255,1.0,0.04140800292064751,0.9900979348481429,1.0,0.06999894670956665
37 | 1.8,0.9972091400662829,1.0,0.03897223804296238,0.9906648938221057,1.0,0.06928199405144711
38 | 1.85,0.9973617652189081,1.0,0.03684094377498774,0.9909808483778288,1.0,0.06887669335797462
39 | 1.9000000000000001,0.9974489795918368,1.0,0.03562306133614526,0.9911032797610435,1.0,0.06874801464285607
40 | 1.9500000000000002,0.997492586778301,1.0,0.03501412011672394,0.9912535044606053,1.0,0.0685915841352487
41 | 2.0,0.9976234083376941,1.0,0.03318729645846007,0.9913251162039075,1.0,0.06853646887305148
42 | 2.05,0.9978196406767835,1.0,0.030447060971064418,0.9915327256362755,1.0,0.06837629507239903
43 | 2.1,0.9979286586429443,1.0,0.028924707922511186,0.9916700226830903,1.0,0.06831802519418569
44 | 2.15,0.9980376766091051,1.0,0.027402354873957837,0.9917853570153623,1.0,0.06821822152178604
45 | 2.1999999999999997,0.998146694575266,1.0,0.025880001825404626,0.9918878868717073,1.0,0.06813082872340301
46 | 2.25,0.998277516134659,1.0,0.0240531781671408,0.992025080354741,1.0,0.06807688759329783
47 | 2.3,0.9984519448805163,1.0,0.021617413289455645,0.9921326737180102,1.0,0.06800773717756628
48 | 2.35,0.9985173556602128,1.0,0.020704001460323754,0.9922768728993338,1.0,0.0679355595484163
49 | 2.4,0.9986263736263736,1.0,0.019181648411770544,0.9923669445512545,1.0,0.06787658534090224
50 | 2.45,0.9987571951857667,1.0,0.017354824753506623,0.9924427627055161,1.0,0.0678528920931322
51 | 2.5,0.9988226059654631,1.0,0.016441412924374683,0.9925691446767388,1.0,0.06776985632542669
52 | 2.55,0.9989098203383917,1.0,0.015223530485532205,0.9927706967346303,1.0,0.0676585945729409
53 | 2.6,0.998953427524856,1.0,0.0146145892661109,0.992985656238477,1.0,0.067406330877912
54 | 2.65,0.9990188383045526,1.0,0.01370117743697892,0.9931048123061254,1.0,0.06735829846344064
55 | 2.7,0.9990406418977847,1.0,0.013396706827268361,0.9932666500729165,1.0,0.06729041498263391
56 | 2.75,0.9991278562707133,1.0,0.012178824388425704,0.9934106002392857,1.0,0.0672454049817615
57 | 2.8,0.999215070643642,1.0,0.010960941949583141,0.9934850452644369,1.0,0.0672394728913095
58 | 2.85,0.9992804814233386,1.0,0.010047530120451266,0.9935231911849659,1.0,0.06723375971009532
59 | 2.9,0.9993894993894993,1.0,0.008525177071898033,0.9936415499477965,1.0,0.06720012074993871
60 | 2.95,0.999454910169196,1.0,0.007611765242766104,0.9938015207107023,1.0,0.06714470011672279
61 | 3.0,0.9995203209488924,1.0,0.006698353413634183,0.9939568186459425,1.0,0.06696721073132718
62 |
--------------------------------------------------------------------------------
/checkpoints/qm9_200steps/samples/samples_all_covmat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/checkpoints/qm9_200steps/samples/samples_all_covmat.pkl
--------------------------------------------------------------------------------
/checkpoints/qm9_500steps/2000000.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/checkpoints/qm9_500steps/2000000.pt
--------------------------------------------------------------------------------
/checkpoints/qm9_500steps/qm9_500steps.yml:
--------------------------------------------------------------------------------
1 | model:
2 | type: subgraph_diffusion
3 | network: dualenc
4 | hidden_dim: 128
5 | num_convs: 6
6 | num_convs_local: 4
7 | cutoff: 10.0
8 | mlp_act: relu
9 | beta_schedule: sigmoid
10 | beta_start: 1.e-7
11 | beta_end: 2.e-2
12 | num_diffusion_timesteps: 500
13 | edge_order: 3
14 | edge_encoder: mlp
15 | smooth_conv: false
16 | same_mask_steps: 25
17 | mask_pred: MLP
18 |
19 | train:
20 | seed: 2021
21 | batch_size: 64
22 | val_freq: 2000
23 | max_iters: 2000000
24 | max_grad_norm: 10000.00
25 | anneal_power: 2.0
26 | optimizer:
27 | type: adam
28 | lr: 1.e-3
29 | weight_decay: 0.
30 | beta1: 0.95
31 | beta2: 0.999
32 | scheduler:
33 | type: plateau
34 | factor: 0.6
35 | patience: 30
36 |
37 | dataset:
38 | train: ./data/GEOM/QM9/train_data_40k_subgraph.pkl
39 | val: ./data/GEOM/QM9/val_data_5k_subgraph.pkl
40 | test: ./data/GEOM/QM9/test_data_1k.pkl
41 |
--------------------------------------------------------------------------------
/checkpoints/qm9_500steps/samples/log_eval_samples_all.txt:
--------------------------------------------------------------------------------
1 | [2024-02-04 15:46:59,065::eval::INFO] Loading results: checkpoints/qm9_500steps/samples/samples_all.pkl
2 | [2024-02-04 15:46:59,370::eval::INFO] Total: 200
3 | [2024-02-04 15:46:59,421::eval::INFO] Filtered: 196 / 200
4 | [2024-02-04 15:50:38,265::eval::INFO]
5 | COV-R_mean COV-R_median COV-R_std COV-P_mean COV-P_median COV-P_std
6 | 0.05 0.032712 0.009060 0.071551 0.014927 0.004525 0.034089
7 | 0.10 0.271354 0.196429 0.234439 0.119802 0.083333 0.134013
8 | 0.15 0.428807 0.352484 0.277974 0.190273 0.155627 0.165403
9 | 0.20 0.528011 0.481637 0.279454 0.237885 0.209299 0.178689
10 | 0.25 0.602192 0.586570 0.261421 0.279192 0.240906 0.186384
11 | 0.30 0.664374 0.676175 0.240212 0.313114 0.267225 0.189866
12 | 0.35 0.720297 0.726223 0.219854 0.343309 0.288202 0.193520
13 | 0.40 0.774024 0.784423 0.196447 0.376003 0.347277 0.198925
14 | 0.45 0.833441 0.863962 0.167611 0.426181 0.407452 0.209374
15 | 0.50 0.897786 0.941707 0.132740 0.500399 0.483076 0.218857
16 | 0.55 0.953916 0.990521 0.094453 0.607295 0.593957 0.227172
17 | 0.60 0.981479 1.000000 0.072338 0.706846 0.753186 0.218912
18 | 0.65 0.992245 1.000000 0.063863 0.780884 0.854572 0.206989
19 | 0.70 0.995319 1.000000 0.059396 0.834365 0.926402 0.193845
20 | 0.75 0.996082 1.000000 0.053294 0.870825 0.950993 0.184226
21 | 0.80 0.996671 1.000000 0.045077 0.898107 0.982091 0.177346
22 | 0.85 0.997151 1.000000 0.038382 0.913005 0.991560 0.173059
23 | 0.90 0.997950 1.000000 0.028620 0.921745 0.994490 0.168994
24 | 0.95 0.998452 1.000000 0.021617 0.927871 0.996598 0.162295
25 | 1.00 0.998779 1.000000 0.017050 0.933256 1.000000 0.154835
26 | 1.05 0.999062 1.000000 0.013092 0.937535 1.000000 0.148893
27 | 1.10 0.999280 1.000000 0.010048 0.939898 1.000000 0.145510
28 | 1.15 0.999499 1.000000 0.007003 0.948699 1.000000 0.129947
29 | 1.20 0.999564 1.000000 0.006089 0.958744 1.000000 0.113603
30 | 1.25 0.999564 1.000000 0.006089 0.967544 1.000000 0.094783
31 | 1.30 0.999738 1.000000 0.003654 0.975255 1.000000 0.075496
32 | 1.35 0.999782 1.000000 0.003045 0.980040 1.000000 0.064290
33 | 1.40 0.999847 1.000000 0.002131 0.983131 1.000000 0.058868
34 | 1.45 0.999935 1.000000 0.000913 0.984870 1.000000 0.054976
35 | 1.50 0.999935 1.000000 0.000913 0.986283 1.000000 0.052289
36 | 1.55 0.999935 1.000000 0.000913 0.987467 1.000000 0.048353
37 | 1.60 0.999956 1.000000 0.000609 0.988913 1.000000 0.041454
38 | 1.65 0.999956 1.000000 0.000609 0.990281 1.000000 0.036710
39 | 1.70 1.000000 1.000000 0.000000 0.991068 1.000000 0.034609
40 | 1.75 1.000000 1.000000 0.000000 0.991494 1.000000 0.033564
41 | 1.80 1.000000 1.000000 0.000000 0.991963 1.000000 0.032702
42 | 1.85 1.000000 1.000000 0.000000 0.992254 1.000000 0.032477
43 | 1.90 1.000000 1.000000 0.000000 0.992391 1.000000 0.032234
44 | 1.95 1.000000 1.000000 0.000000 0.992674 1.000000 0.031985
45 | 2.00 1.000000 1.000000 0.000000 0.992753 1.000000 0.031887
46 | 2.05 1.000000 1.000000 0.000000 0.992848 1.000000 0.031793
47 | 2.10 1.000000 1.000000 0.000000 0.992954 1.000000 0.031578
48 | 2.15 1.000000 1.000000 0.000000 0.992991 1.000000 0.031521
49 | 2.20 1.000000 1.000000 0.000000 0.993108 1.000000 0.031292
50 | 2.25 1.000000 1.000000 0.000000 0.993329 1.000000 0.031037
51 | 2.30 1.000000 1.000000 0.000000 0.993635 1.000000 0.030488
52 | 2.35 1.000000 1.000000 0.000000 0.993873 1.000000 0.030190
53 | 2.40 1.000000 1.000000 0.000000 0.994005 1.000000 0.030056
54 | 2.45 1.000000 1.000000 0.000000 0.994271 1.000000 0.029820
55 | 2.50 1.000000 1.000000 0.000000 0.994462 1.000000 0.029347
56 | 2.55 1.000000 1.000000 0.000000 0.994680 1.000000 0.028992
57 | 2.60 1.000000 1.000000 0.000000 0.994878 1.000000 0.028735
58 | 2.65 1.000000 1.000000 0.000000 0.995038 1.000000 0.028563
59 | 2.70 1.000000 1.000000 0.000000 0.995204 1.000000 0.028403
60 | 2.75 1.000000 1.000000 0.000000 0.995414 1.000000 0.028231
61 | 2.80 1.000000 1.000000 0.000000 0.995573 1.000000 0.028092
62 | 2.85 1.000000 1.000000 0.000000 0.995729 1.000000 0.027932
63 | 2.90 1.000000 1.000000 0.000000 0.995884 1.000000 0.027833
64 | 2.95 1.000000 1.000000 0.000000 0.996047 1.000000 0.027739
65 | 3.00 1.000000 1.000000 0.000000 0.996186 1.000000 0.027643
66 | [2024-02-04 15:50:38,265::eval::INFO] MAT-R_mean: 0.2417 | MAT-R_median: 0.2449 | MAT-R_std 0.1080
67 | [2024-02-04 15:50:38,265::eval::INFO] MAT-P_mean: 0.5571 | MAT-P_median: 0.4921 | MAT-P_std 0.9973
68 |
--------------------------------------------------------------------------------
/checkpoints/qm9_500steps/samples/samples_all.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/checkpoints/qm9_500steps/samples/samples_all.pkl
--------------------------------------------------------------------------------
/checkpoints/qm9_500steps/samples/samples_all_covmat.csv:
--------------------------------------------------------------------------------
1 | ,COV-R_mean,COV-R_median,COV-R_std,COV-P_mean,COV-P_median,COV-P_std
2 | 0.05,0.032711709588828866,0.009060127231684794,0.07155121746047126,0.01492675832356874,0.004524979524979525,0.034089220735110545
3 | 0.1,0.27135416895301273,0.19642857142857142,0.23443867499848575,0.11980226736434715,0.08333333333333333,0.13401287470087428
4 | 0.15000000000000002,0.4288073021012164,0.35248447204968947,0.27797417934518226,0.190273354171487,0.15562678062678065,0.1654033279682946
5 | 0.2,0.5280113926962319,0.48163746630727766,0.2794543928792841,0.23788510027403467,0.2092993951612903,0.17868862397169458
6 | 0.25,0.602192453468659,0.5865704772475028,0.26142063747855177,0.27919161044555635,0.24090608465608465,0.18638428676495009
7 | 0.3,0.6643740008489631,0.6761750202858747,0.24021151185365944,0.31311439089352006,0.2672250101722501,0.18986576768280886
8 | 0.35000000000000003,0.7202965146420955,0.726222968717195,0.21985373768490118,0.3433087085641054,0.28820181362145475,0.19352008655771624
9 | 0.4,0.7740240266398146,0.7844234079173837,0.19644685529177316,0.3760026786921755,0.34727658186562294,0.198924632306208
10 | 0.45,0.8334414964196138,0.8639622641509435,0.16761055074163497,0.42618096533426403,0.4074519230769231,0.2093742760990014
11 | 0.5,0.8977858561714935,0.9417073688681248,0.13274028496424486,0.5003986958275349,0.48307560137457045,0.21885739334926382
12 | 0.55,0.9539161546591544,0.9905211141060197,0.0944533687904892,0.6072945027695789,0.5939568231957137,0.2271723329068604
13 | 0.6000000000000001,0.9814789414566789,1.0,0.07233786517678083,0.7068456567617836,0.7531863051355742,0.2189121009131429
14 | 0.6500000000000001,0.9922447513430593,1.0,0.06386335173301483,0.7808842672110684,0.8545716727195044,0.20698890270111314
15 | 0.7000000000000001,0.9953187829487754,1.0,0.059396109887135985,0.834365483392989,0.9264015151515151,0.19384478512427444
16 | 0.7500000000000001,0.9960823303680446,1.0,0.053294101359138465,0.870824870090479,0.9509934248141796,0.18422579985422147
17 | 0.8,0.9966710273853132,1.0,0.04507686955175547,0.8981065306541568,0.9820914614646626,0.17734639493916543
18 | 0.8500000000000001,0.9971507064364208,1.0,0.038382447594486065,0.9130047864160271,0.9915598290598291,0.17305904858373555
19 | 0.9000000000000001,0.9979504622361764,1.0,0.02862023731280038,0.9217454042839928,0.9944899817850638,0.1689944096995206
20 | 0.9500000000000001,0.9984519448805163,1.0,0.021617413289455645,0.9278712160469375,0.9965984820436875,0.16229528840360846
21 | 1.0,0.9987789987789989,1.0,0.017050354143795975,0.9332559880655352,1.0,0.1548352915716108
22 | 1.05,0.999062445491017,1.0,0.013092236217557642,0.9375348746068178,1.0,0.14889251457048863
23 | 1.1,0.9992804814233386,1.0,0.010047530120451266,0.939897913659391,1.0,0.14550950610767346
24 | 1.1500000000000001,0.9994985173556603,1.0,0.007002824023344788,0.9486988951016855,1.0,0.12994718273100558
25 | 1.2000000000000002,0.9995639281353567,1.0,0.006089412194212856,0.95874378154494,1.0,0.11360274839243543
26 | 1.2500000000000002,0.9995639281353567,1.0,0.006089412194212856,0.967544428669111,1.0,0.0947830177806354
27 | 1.3,0.999738356881214,1.0,0.003653647316527728,0.9752545458407527,1.0,0.07549569659693457
28 | 1.35,0.9997819640676784,1.0,0.003044706097106426,0.9800399324238473,1.0,0.06429047461466812
29 | 1.4000000000000001,0.9998473748473747,1.0,0.002131294267974505,0.9831307753114109,1.0,0.058868116271158594
30 | 1.4500000000000002,0.9999345892203035,1.0,0.000913411829131929,0.984870393452086,1.0,0.054975545894377496
31 | 1.5000000000000002,0.9999345892203035,1.0,0.000913411829131929,0.9862833473340975,1.0,0.05228913274126952
32 | 1.55,0.9999345892203035,1.0,0.000913411829131929,0.9874672713227176,1.0,0.048352809168404245
33 | 1.6,0.9999563928135358,1.0,0.0006089412194212831,0.9889134634166011,1.0,0.04145438560195448
34 | 1.6500000000000001,0.9999563928135358,1.0,0.0006089412194212831,0.9902806046839789,1.0,0.0367100097724998
35 | 1.7000000000000002,1.0,1.0,0.0,0.9910679746533974,1.0,0.034609100019943016
36 | 1.7500000000000002,1.0,1.0,0.0,0.9914944487118257,1.0,0.03356422106813105
37 | 1.8,1.0,1.0,0.0,0.9919634750400658,1.0,0.03270181787106268
38 | 1.85,1.0,1.0,0.0,0.9922539003393859,1.0,0.03247688094681033
39 | 1.9000000000000001,1.0,1.0,0.0,0.9923908143943294,1.0,0.0322338432847492
40 | 1.9500000000000002,1.0,1.0,0.0,0.9926738263858396,1.0,0.031985447486153375
41 | 2.0,1.0,1.0,0.0,0.9927525363163924,1.0,0.03188675422226196
42 | 2.05,1.0,1.0,0.0,0.9928475103713997,1.0,0.031793055571801544
43 | 2.1,1.0,1.0,0.0,0.9929542137292435,1.0,0.031577512164597095
44 | 2.15,1.0,1.0,0.0,0.9929914860251126,1.0,0.03152091976648795
45 | 2.1999999999999997,1.0,1.0,0.0,0.9931078181515925,1.0,0.03129190632244978
46 | 2.25,1.0,1.0,0.0,0.9933287063523303,1.0,0.03103680541201891
47 | 2.3,1.0,1.0,0.0,0.9936352922310345,1.0,0.030487584066969144
48 | 2.35,1.0,1.0,0.0,0.9938734261288,1.0,0.030189537670333205
49 | 2.4,1.0,1.0,0.0,0.9940045140326391,1.0,0.03005552702336885
50 | 2.45,1.0,1.0,0.0,0.99427139763627,1.0,0.029820178205711857
51 | 2.5,1.0,1.0,0.0,0.9944622124541159,1.0,0.02934684373554284
52 | 2.55,1.0,1.0,0.0,0.9946795555602882,1.0,0.028992053985942273
53 | 2.6,1.0,1.0,0.0,0.9948780135917167,1.0,0.028735354869193028
54 | 2.65,1.0,1.0,0.0,0.99503801395476,1.0,0.028563297446339213
55 | 2.7,1.0,1.0,0.0,0.9952040372067215,1.0,0.028403367668943507
56 | 2.75,1.0,1.0,0.0,0.9954139037773834,1.0,0.028231055564201697
57 | 2.8,1.0,1.0,0.0,0.9955730070156924,1.0,0.02809158835084354
58 | 2.85,1.0,1.0,0.0,0.9957292201822051,1.0,0.02793247439555708
59 | 2.9,1.0,1.0,0.0,0.995883750779165,1.0,0.027832614175859433
60 | 2.95,1.0,1.0,0.0,0.9960471507970968,1.0,0.02773939895707281
61 | 3.0,1.0,1.0,0.0,0.9961860763629112,1.0,0.027642858668279
62 |
--------------------------------------------------------------------------------
/checkpoints/sample_2024_02_05__15_50_46_SubGDiff500/log.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/checkpoints/sample_2024_02_05__15_50_46_SubGDiff500/log.txt
--------------------------------------------------------------------------------
/configs/qm9_200steps.yml:
--------------------------------------------------------------------------------
1 | model:
2 | type: subgraph_diffusion
3 | network: dualenc
4 | hidden_dim: 128
5 | num_convs: 6
6 | num_convs_local: 4
7 | cutoff: 10.0
8 | mlp_act: relu
9 | beta_schedule: sigmoid
10 | beta_start: 1.e-7
11 | beta_end: 5.e-2
12 | num_diffusion_timesteps: 200
13 | edge_order: 3
14 | edge_encoder: mlp
15 | smooth_conv: false
16 | same_mask_steps: 10 # 10-same-subgraph diffusion
17 |     mask_pred: MLP
18 |
19 | train:
20 | seed: 2021
21 | batch_size: 64
22 | val_freq: 2000
23 | max_iters: 2000000
24 | max_grad_norm: 10000.00
25 | anneal_power: 2.0
26 | optimizer:
27 | type: adam
28 | lr: 1.e-3
29 | weight_decay: 0.
30 | beta1: 0.95
31 | beta2: 0.999
32 | scheduler:
33 | type: plateau
34 | factor: 0.6
35 | patience: 30
36 |
37 | dataset:
38 | train: ./data/GEOM/QM9/train_data_40k_subgraph.pkl
39 | val: ./data/GEOM/QM9/val_data_5k_subgraph.pkl
40 | test: ./data/GEOM/QM9/test_data_1k.pkl
41 |
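For reference, a config file like the one above can be loaded into an attribute-accessible object. This is a minimal, hypothetical sketch assuming PyYAML; the repository's own config loader may differ:

```python
import yaml
from types import SimpleNamespace

def _to_ns(obj):
    # Recursively wrap dicts so nested keys become attributes,
    # e.g. cfg.model.num_diffusion_timesteps.
    if isinstance(obj, dict):
        return SimpleNamespace(**{k: _to_ns(v) for k, v in obj.items()})
    return obj

def load_config(path):
    # Parse the YAML config into nested namespaces.
    with open(path) as f:
        return _to_ns(yaml.safe_load(f))
```

With `configs/qm9_200steps.yml`, `cfg.model.beta_schedule` would be `'sigmoid'` and `cfg.train.optimizer.lr` would be `1e-3`.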
--------------------------------------------------------------------------------
/configs/qm9_500steps.yml:
--------------------------------------------------------------------------------
1 | model:
2 | type: subgraph_diffusion
3 | network: dualenc
4 | hidden_dim: 128
5 | num_convs: 6
6 | num_convs_local: 4
7 | cutoff: 10.0
8 | mlp_act: relu
9 | beta_schedule: sigmoid
10 | beta_start: 1.e-7
11 | beta_end: 2.e-2
12 | num_diffusion_timesteps: 500
13 | edge_order: 3
14 | edge_encoder: mlp
15 | smooth_conv: false
16 | same_mask_steps: 25 # 25-same-subgraph diffusion
17 | mask_pred: MLP
18 |
19 | train:
20 | seed: 2021
21 | batch_size: 64
22 | val_freq: 2000
23 | max_iters: 2000000
24 | max_grad_norm: 10000.00
25 | anneal_power: 2.0
26 | optimizer:
27 | type: adam
28 | lr: 1.e-3
29 | weight_decay: 0.
30 | beta1: 0.95
31 | beta2: 0.999
32 | scheduler:
33 | type: plateau
34 | factor: 0.6
35 | patience: 30
36 |
37 | dataset:
38 | train: ./data/GEOM/QM9/train_data_40k_subgraph.pkl
39 | val: ./data/GEOM/QM9/val_data_5k_subgraph.pkl
40 | test: ./data/GEOM/QM9/test_data_1k.pkl
41 |
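The `beta_schedule`, `beta_start`, `beta_end`, and `num_diffusion_timesteps` fields define the forward-noising variance schedule. As a hypothetical illustration of how a `sigmoid` schedule is commonly constructed in diffusion codebases (the actual construction lives elsewhere in this repository, e.g. under `models/epsnet/`, and may differ):

```python
import numpy as np

def sigmoid_beta_schedule(beta_start, beta_end, num_steps):
    # A sigmoid-shaped ramp from beta_start to beta_end over num_steps timesteps.
    x = np.linspace(-6, 6, num_steps)
    sig = 1.0 / (1.0 + np.exp(-x))
    return sig * (beta_end - beta_start) + beta_start

betas = sigmoid_beta_schedule(1e-7, 2e-2, 500)   # values from qm9_500steps.yml
alphas_cumprod = np.cumprod(1.0 - betas)          # cumulative signal retention per step
```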
--------------------------------------------------------------------------------
/datasets/chem.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from torchvision.transforms.functional import to_tensor
4 |
5 | import rdkit
6 | import rdkit.Chem.Draw
7 | from rdkit import Chem
8 | from rdkit.Chem import rdDepictor as DP
9 | from rdkit.Chem import PeriodicTable as PT
10 | from rdkit.Chem import rdMolAlign as MA
11 | from rdkit.Chem.rdchem import BondType as BT
12 | from rdkit.Chem.rdchem import Mol, GetPeriodicTable
13 | from rdkit.Chem.Draw import rdMolDraw2D as MD2
14 | from rdkit.Chem.rdmolops import RemoveHs
15 | from typing import List, Tuple
16 |
17 |
18 |
19 | BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}
20 | BOND_NAMES = {i: t for i, t in enumerate(BT.names.keys())}
21 |
22 |
23 | def set_conformer_positions(conf, pos):
24 | for i in range(pos.shape[0]):
25 | conf.SetAtomPosition(i, pos[i].tolist())
26 | return conf
27 |
28 |
29 | def draw_mol_image(rdkit_mol, tensor=False):
30 | rdkit_mol.UpdatePropertyCache()
31 | img = rdkit.Chem.Draw.MolToImage(rdkit_mol, kekulize=False)
32 | if tensor:
33 | return to_tensor(img)
34 | else:
35 | return img
36 |
37 |
38 | def update_data_rdmol_positions(data):
39 | for i in range(data.pos.size(0)):
40 | data.rdmol.GetConformer(0).SetAtomPosition(i, data.pos[i].tolist())
41 | return data
42 |
43 |
44 | def update_data_pos_from_rdmol(data):
45 | new_pos = torch.FloatTensor(data.rdmol.GetConformer(0).GetPositions()).to(data.pos)
46 | data.pos = new_pos
47 | return data
48 |
49 |
50 | def set_rdmol_positions(rdkit_mol, pos):
51 | """
52 | Args:
53 | rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.
54 | pos: (N_atoms, 3)
55 | """
56 | mol = copy.deepcopy(rdkit_mol)
57 | set_rdmol_positions_(mol, pos)
58 | return mol
59 |
60 |
61 | def set_rdmol_positions_(mol, pos):
62 | """
63 | Args:
64 | rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.
65 | pos: (N_atoms, 3)
66 | """
67 | for i in range(pos.shape[0]):
68 | mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())
69 | return mol
70 |
71 |
72 | def get_atom_symbol(atomic_number):
73 | return PT.GetElementSymbol(GetPeriodicTable(), atomic_number)
74 |
75 |
76 | def mol_to_smiles(mol: Mol) -> str:
77 | return Chem.MolToSmiles(mol, allHsExplicit=True)
78 |
79 |
80 | def mol_to_smiles_without_Hs(mol: Mol) -> str:
81 | return Chem.MolToSmiles(Chem.RemoveHs(mol))
82 |
83 |
84 | def remove_duplicate_mols(molecules: List[Mol]) -> List[Mol]:
85 | unique_tuples: List[Tuple[str, Mol]] = []
86 |
87 | for molecule in molecules:
88 | duplicate = False
89 | smiles = mol_to_smiles(molecule)
90 | for unique_smiles, _ in unique_tuples:
91 | if smiles == unique_smiles:
92 | duplicate = True
93 | break
94 |
95 | if not duplicate:
96 | unique_tuples.append((smiles, molecule))
97 |
98 | return [mol for smiles, mol in unique_tuples]
99 |
100 |
101 | def get_atoms_in_ring(mol):
102 | atoms = set()
103 | for ring in mol.GetRingInfo().AtomRings():
104 | for a in ring:
105 | atoms.add(a)
106 | return atoms
107 |
108 |
109 | def get_2D_mol(mol):
110 | mol = copy.deepcopy(mol)
111 | DP.Compute2DCoords(mol)
112 | return mol
113 |
114 |
115 | def draw_mol_svg(mol, molSize=(450, 150), kekulize=False):
116 | mc = Chem.Mol(mol.ToBinary())
117 | if kekulize:
118 | try:
119 | Chem.Kekulize(mc)
120 |         except Exception:
121 | mc = Chem.Mol(mol.ToBinary())
122 | if not mc.GetNumConformers():
123 | DP.Compute2DCoords(mc)
124 | drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])
125 | drawer.DrawMolecule(mc)
126 | drawer.FinishDrawing()
127 | svg = drawer.GetDrawingText()
128 | # It seems that the svg renderer used doesn't quite hit the spec.
129 | # Here are some fixes to make it work in the notebook, although I think
130 | # the underlying issue needs to be resolved at the generation step
131 | # return svg.replace('svg:','')
132 | return svg
133 |
134 |
135 | def GetBestRMSD(probe, ref):
136 | probe = RemoveHs(probe)
137 | ref = RemoveHs(ref)
138 | rmsd = MA.GetBestRMS(probe, ref)
139 | return rmsd
140 |
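As a usage sketch of the position-setting pattern behind `set_rdmol_positions` above (assumes RDKit is installed and embedding succeeds; `CCO` is just an illustrative molecule):

```python
import copy
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem

# Build a tiny molecule with one embedded conformer.
mol = Chem.AddHs(Chem.MolFromSmiles('CCO'))
AllChem.EmbedMolecule(mol, randomSeed=0)

# Mirror set_rdmol_positions: deep-copy, then overwrite atom coordinates,
# leaving the original conformer untouched.
new_pos = np.zeros((mol.GetNumAtoms(), 3))
mol2 = copy.deepcopy(mol)
for i in range(new_pos.shape[0]):
    mol2.GetConformer(0).SetAtomPosition(i, new_pos[i].tolist())
```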
--------------------------------------------------------------------------------
/datasets/rdmol2data.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 |
4 | import torch
5 | from torch_geometric.data import Data
6 |
7 | from torch_scatter import scatter
8 | #from torch.utils.data import Dataset
9 |
10 | from rdkit import Chem
11 | from rdkit.Chem.rdchem import Mol, HybridizationType, BondType
12 | from rdkit import RDLogger
13 | RDLogger.DisableLog('rdApp.*')
14 |
15 | # from confgf import utils
16 | # from datasets.chem import BOND_TYPES, BOND_NAMES
17 | from rdkit.Chem.rdchem import BondType as BT
18 |
19 | BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}
20 |
21 | def rdmol_to_data(mol: Mol, pos=None, y=None, idx=None, smiles=None):
22 | assert mol.GetNumConformers() == 1
23 | N = mol.GetNumAtoms()
24 | if smiles is None:
25 | smiles = Chem.MolToSmiles(mol)
26 | # pos = torch.tensor(mol.GetConformer(0).GetPositions(), dtype=torch.float32)
27 |
28 | atomic_number = []
29 | aromatic = []
30 | sp = []
31 | sp2 = []
32 | sp3 = []
33 | num_hs = []
34 | for atom in mol.GetAtoms():
35 | atomic_number.append(atom.GetAtomicNum())
36 | aromatic.append(1 if atom.GetIsAromatic() else 0)
37 | hybridization = atom.GetHybridization()
38 | sp.append(1 if hybridization == HybridizationType.SP else 0)
39 | sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
40 | sp3.append(1 if hybridization == HybridizationType.SP3 else 0)
41 |
42 | z = torch.tensor(atomic_number, dtype=torch.long)
43 |
44 | row, col, edge_type = [], [], []
45 | for bond in mol.GetBonds():
46 | start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
47 | row += [start, end]
48 | col += [end, start]
49 | edge_type += 2 * [BOND_TYPES[bond.GetBondType()]]
50 |
51 | edge_index = torch.tensor([row, col], dtype=torch.long)
52 | edge_type = torch.tensor(edge_type)
53 |
54 | perm = (edge_index[0] * N + edge_index[1]).argsort()
55 | edge_index = edge_index[:, perm]
56 | edge_type = edge_type[perm]
57 |
58 | row, col = edge_index
59 | hs = (z == 1).to(torch.float32)
60 |
61 | num_hs = scatter(hs[row], col, dim_size=N, reduce='sum').tolist()
62 |
63 | if smiles is None:
64 | smiles = Chem.MolToSmiles(mol)
65 | try:
66 | name = mol.GetProp('_Name')
67 |     except KeyError:
68 |         name = None
69 |
70 |     data = Data(atom_type=z, pos=pos, edge_index=edge_index, edge_type=edge_type,
71 |                 rdmol=copy.deepcopy(mol), smiles=smiles, y=y, idx=idx, name=name)
72 | #data.nx = to_networkx(data, to_undirected=True)
73 |
74 | return data
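The argsort on `edge_index[0] * N + edge_index[1]` (line 54 above) canonicalizes the bidirectional edge list into lexicographic (source, target) order; a standalone illustration:

```python
import torch

N = 3  # number of atoms
# Each bond is stored in both directions, in whatever order the bonds arrive.
edge_index = torch.tensor([[1, 0, 2, 1],
                           [0, 1, 1, 2]])
# Encoding each edge as source * N + target and sorting yields a
# deterministic lexicographic ordering of the edge columns.
perm = (edge_index[0] * N + edge_index[1]).argsort()
sorted_edges = edge_index[:, perm]
# sorted_edges columns: (0,1), (1,0), (1,2), (2,1)
```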
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
1 | name: SubGDiff
2 | channels:
3 | - conda-forge
4 | - psi4
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - ambit=0.3=h137fa24_1
10 | - atomicwrites=1.4.0=py_0
11 | - backcall=0.2.0=pyh9f0ad1d_0
12 | - blas=1.0=mkl
13 | - brotlipy=0.7.0=py37h27cfd23_1003
14 | - bzip2=1.0.8=h7b6447c_0
15 | - ca-certificates=2023.7.22=hbcca054_0
16 | - cairo=1.16.0=hb05425b_5
17 | - certifi=2023.7.22=pyhd8ed1ab_0
18 | - cffi=1.15.1=py37h5eee18b_3
19 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
20 | - chemps2=1.8.9=h8c3debe_0
21 | - cryptography=39.0.1=py37h9ce1e76_0
22 | - cudatoolkit=11.3.1=h2bc3f7f_2
23 | - decorator=4.4.0=py37_1
24 | - deepdiff=3.3.0=py37_1
25 | - dkh=1.2=h173d85e_2
26 | - entrypoints=0.4=pyhd8ed1ab_0
27 | - expat=2.4.9=h6a678d5_0
28 | - ffmpeg=4.2.2=h20bf706_0
29 | - fontconfig=2.14.1=h4c34cd2_2
30 | - freetype=2.12.1=h4a9f257_0
31 | - gau2grid=1.3.1=h035aef0_0
32 | - gdma=2.2.6=h0e1e685_6
33 | - giflib=5.2.1=h5eee18b_3
34 | - glib=2.69.1=he621ea3_2
35 | - gmp=6.2.1=h295c915_3
36 | - gnutls=3.6.15=he1e5248_0
37 | - hdf5=1.10.2=hba1933b_1
38 | - icu=58.2=he6710b0_3
39 | - idna=3.4=py37h06a4308_0
40 | - importlib_metadata=4.11.3=hd3eb1b0_0
41 | - intel-openmp=2023.1.0=hdb19cb5_46305
42 | - ipython=7.33.0=py37h89c1867_0
43 | - ipython_genutils=0.2.0=py_1
44 | - jedi=0.18.2=pyhd8ed1ab_0
45 | - joblib=1.1.0=pyhd3eb1b0_0
46 | - jpeg=9b=0
47 | - jsonpickle=0.9.6=py37_0
48 | - lame=3.100=h7b6447c_0
49 | - lcms2=2.12=h3be6417_0
50 | - ld_impl_linux-64=2.38=h1181459_1
51 | - libblas=3.9.0=1_h6e990d7_netlib
52 | - libboost=1.73.0=h3ff78a5_11
53 | - libcblas=3.9.0=3_h893e4fe_netlib
54 | - libffi=3.4.4=h6a678d5_0
55 | - libgcc-ng=11.2.0=h1234567_1
56 | - libgfortran-ng=7.5.0=ha8ba4b0_17
57 | - libgfortran4=7.5.0=ha8ba4b0_17
58 | - libgomp=11.2.0=h1234567_1
59 | - libidn2=2.3.4=h5eee18b_0
60 | - libint=1.2.1=hb4a4fd4_6
61 | - liblapack=3.9.0=3_h893e4fe_netlib
62 | - libopus=1.3.1=h7b6447c_0
63 | - libpng=1.6.39=h5eee18b_0
64 | - libsodium=1.0.18=h36c2ea0_1
65 | - libstdcxx-ng=11.2.0=h1234567_1
66 | - libtasn1=4.19.0=h5eee18b_0
67 | - libtiff=4.1.0=h2733197_1
68 | - libunistring=0.9.10=h27cfd23_0
69 | - libuuid=1.41.5=h5eee18b_0
70 | - libuv=1.44.2=h5eee18b_0
71 | - libvpx=1.7.0=h439df22_0
72 | - libwebp=1.2.0=h89dd481_0
73 | - libxc=4.3.4=h7b6447c_0
74 | - libxcb=1.15=h7f8727e_0
75 | - libxml2=2.10.3=hcbfbd50_0
76 | - lz4-c=1.9.4=h6a678d5_0
77 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0
78 | - mkl=2019.4=243
79 | - mkl-service=2.3.0=py37he8ac12f_0
80 | - mkl_fft=1.0.14=py37hd81dba3_0
81 | - mkl_random=1.0.4=py37hd81dba3_0
82 | - more-itertools=8.12.0=pyhd3eb1b0_0
83 | - ncurses=6.4=h6a678d5_0
84 | - nettle=3.7.3=hbbd107a_1
85 | - networkx=2.5.1=pyhd8ed1ab_0
86 | - numpy-base=1.17.0=py37hde5b4d6_0
87 | - nvcc_linux-64=10.1=hcaf9a05_10
88 | - openh264=2.1.1=h4ff587b_0
89 | - openssl=1.1.1v=h7f8727e_0
90 | - pandas=1.2.3=py37hdc94413_0
91 | - parso=0.8.3=pyhd8ed1ab_0
92 | - pcmsolver=1.2.1.1=py37h6d17ec8_2
93 | - pcre=8.45=h295c915_0
94 | - pexpect=4.8.0=pyh1a96a4e_2
95 | - pickleshare=0.7.5=py_1003
96 | - pillow=9.3.0=py37hace64e9_1
97 | - pint=0.10=py_0
98 | - pip=23.2=pyhd8ed1ab_0
99 | - pixman=0.40.0=h7f8727e_1
100 | - pluggy=1.0.0=py37h06a4308_1
101 | - prettytable=3.5.0=py37h06a4308_0
102 | - prompt-toolkit=3.0.39=pyha770c72_0
103 | - psi4=1.3.2+ecbda83=py37h06ff01c_1
104 | - psutil=5.9.0=py37h5eee18b_0
105 | - ptyprocess=0.7.0=pyhd3deb0d_0
106 | - py=1.11.0=pyhd3eb1b0_0
107 | - py-boost=1.73.0=py37ha9443f7_11
108 | - pycparser=2.21=pyhd3eb1b0_0
109 | - pydantic=1.3=py37h516909a_0
110 | - pygments=2.15.1=pyhd8ed1ab_0
111 | - pyopenssl=23.0.0=py37h06a4308_0
112 | - pysocks=1.7.1=py37_1
113 | - pytest=3.10.1=py37_1000
114 | - python=3.7.16=h7a1cb2a_0
115 | - python-dateutil=2.8.2=pyhd3eb1b0_0
116 | - python_abi=3.7=2_cp37m
117 | - pytorch=1.11.0=py3.7_cuda11.3_cudnn8.2.0_0
118 | - pytorch-mutex=1.0=cuda
119 | - pytz=2022.7=py37h06a4308_0
120 | - qcelemental=0.17.0=py_0
121 | - readline=8.2=h5eee18b_0
122 | - requests=2.28.1=py37h06a4308_0
123 | - scikit-learn=0.23.2=py37hddcf8d6_3
124 | - scipy=1.5.3=py37h8911b10_0
125 | - setuptools=59.8.0=py37h89c1867_1
126 | - simint=0.7=h642920c_1
127 | - six=1.16.0=pyhd3eb1b0_1
128 | - sqlite=3.41.2=h5eee18b_0
129 | - threadpoolctl=2.2.0=pyh0d69192_0
130 | - tk=8.6.12=h1ccaba5_0
131 | - torchaudio=0.11.0=py37_cu113
132 | - torchvision=0.12.0=py37_cu113
133 | - traitlets=5.9.0=pyhd8ed1ab_0
134 | - typing_extensions=4.1.1=pyh06a4308_0
135 | - urllib3=1.26.14=py37h06a4308_0
136 | - wcwidth=0.2.5=pyhd3eb1b0_0
137 | - wheel=0.38.4=py37h06a4308_0
138 | - x264=1!157.20191217=h7b6447c_0
139 | - xz=5.4.2=h5eee18b_0
140 | - zeromq=4.3.4=h9c3ff4c_1
141 | - zlib=1.2.13=h5eee18b_0
142 | - zstd=1.4.9=haebb681_0
143 | - pip:
144 | - absl-py==1.4.0
145 | - addict==2.4.0
146 | - aiofiles==22.1.0
147 | - aiosqlite==0.19.0
148 | - anyio==3.7.1
149 | - argon2-cffi==21.3.0
150 | - argon2-cffi-bindings==21.2.0
151 | - argparse==1.4.0
152 | - arrow==1.2.3
153 | - ase==3.22.1
154 | - atom3d==0.2.6
155 | - attrs==23.1.0
156 | - babel==2.12.1
157 | - beautifulsoup4==4.12.2
158 | - biopython==1.81
159 | - bleach==6.0.0
160 | - cached-property==1.5.2
161 | - cachetools==5.3.1
162 | - click==8.1.6
163 | - cloudpickle==2.2.1
164 | - comm==0.1.4
165 | - cycler==0.11.0
166 | - cython==3.0.0
167 | - dataclasses==0.8
168 | - debtcollector==2.5.0
169 | - debugpy==1.6.8
170 | - defusedxml==0.7.1
171 | - descriptastorus==2.5.0.20
172 | - dgl-cu110==0.6.1
173 | - dgllife==0.3.2
174 | - dill==0.3.7
175 | - easy-parallel==0.1.6
176 | - easydict==1.10
177 | - exceptiongroup==1.1.2
178 | - fastjsonschema==2.18.0
179 | - filelock==3.12.2
180 | - fonttools==4.38.0
181 | - fqdn==1.5.1
182 | - freesasa==2.2.0.post3
183 | - fuzzywuzzy==0.18.0
184 | - gdown==4.7.1
185 | - google-auth==2.22.0
186 | - google-auth-oauthlib==0.4.6
187 | - googledrivedownloader==0.4
188 | - grpcio==1.56.2
189 | - h5py==3.8.0
190 | - horovod==0.28.1
191 | - huggingface-hub==0.16.4
192 | - hyperopt==0.2.7
193 | - importlib-metadata==6.7.0
194 | - importlib-resources==5.12.0
195 | - ipykernel==6.16.2
196 | - ipywidgets==8.1.0
197 | - isodate==0.6.1
198 | - isoduration==20.11.0
199 | - jinja2==3.1.2
200 | - json5==0.9.14
201 | - jsonpointer==2.4
202 | - jsonschema==4.17.3
203 | - jupyter==1.0.0
204 | - jupyter-client==7.4.9
205 | - jupyter-console==6.6.3
206 | - jupyter-core==4.12.0
207 | - jupyter-events==0.6.3
208 | - jupyter-server==1.24.0
209 | - jupyter-server-fileid==0.9.0
210 | - jupyter-server-ydoc==0.8.0
211 | - jupyter-ydoc==0.2.5
212 | - jupyterlab==3.6.5
213 | - jupyterlab-pygments==0.2.2
214 | - jupyterlab-server==2.24.0
215 | - jupyterlab-widgets==3.0.8
216 | - kiwisolver==1.4.4
217 | - lmdb==1.4.1
218 | - markdown==3.4.4
219 | - markdown-it-py==2.2.0
220 | - markupsafe==2.1.3
221 | - matplotlib==3.5.3
222 | - mdurl==0.1.2
223 | - mistune==3.0.1
224 | - mmcv==1.5.0
225 | - mmengine==0.8.4
226 | - msgpack==1.0.5
227 | - multipledispatch==1.0.0
228 | - multiprocess==0.70.15
229 | - nbclassic==1.0.0
230 | - nbclient==0.7.4
231 | - nbconvert==7.6.0
232 | - nbformat==5.8.0
233 | - nest-asyncio==1.5.7
234 | - notebook==6.5.5
235 | - notebook-shim==0.2.3
236 | - numexpr==2.8.5
237 | - numpy==1.21.6
238 | - nvidia-nccl-cu11==2.18.3
239 | - oauthlib==3.2.2
240 | - opencv-python==4.8.0.74
241 | - packaging==23.1
242 | - pandas-flavor==0.6.0
243 | - pandocfilters==1.5.0
244 | - pathos==0.3.1
245 | - pkgutil-resolve-name==1.3.10
246 | - platformdirs==3.10.0
247 | - pox==0.3.3
248 | - ppft==1.7.6.7
249 | - prometheus-client==0.17.1
250 | - protobuf==3.20.3
251 | - psikit==0.2.0
252 | - py3dmol==2.0.3
253 | - pyasn1==0.5.0
254 | - pyasn1-modules==0.3.0
255 | - pyparsing==3.1.0
256 | - pyrr==0.10.3
257 | - pyrsistent==0.19.3
258 | - pytdc==0.4.1
259 | - python-dotenv==0.21.1
260 | - python-json-logger==2.0.7
261 | - python-louvain==0.16
262 | - pyyaml==6.0.1
263 | - pyzmq==24.0.1
264 | - qtconsole==5.4.3
265 | - qtpy==2.3.1
266 | - rdflib==6.3.2
267 | - rdkit==2023.3.2
268 | - rdkit-pypi==2023.3.1b1
269 | - regex==2023.6.3
270 | - requests-oauthlib==1.3.1
271 | - rfc3339-validator==0.1.4
272 | - rfc3986-validator==0.1.1
273 | - rich==13.5.2
274 | - rsa==4.9
275 | - safetensors==0.3.1
276 | - seaborn==0.12.2
277 | - send2trash==1.8.2
278 | - sniffio==1.3.0
279 | - soupsieve==2.4.1
280 | - tables==3.7.0
281 | - tensorboard==2.11.2
282 | - tensorboard-data-server==0.6.1
283 | - tensorboard-plugin-wit==1.8.1
284 | - termcolor==2.3.0
285 | - terminado==0.17.1
286 | - tinycss2==1.2.1
287 | - tokenizers==0.13.3
288 | - tomli==2.0.1
289 | - torch-cluster==1.6.0
290 | - torch-geometric==1.7.2
291 | - torch-scatter==2.0.9
292 | - torch-sparse==0.6.13
293 | - tornado==6.2
294 | - tqdm==4.65.0
295 | - transformers==4.31.0
296 | - typing-extensions==4.7.1
297 | - uri-template==1.3.0
298 | - webcolors==1.13
299 | - webencodings==0.5.1
300 | - websocket-client==1.6.1
301 | - werkzeug==2.2.3
302 | - widgetsnbextension==4.0.8
303 | - wilds==2.0.0
304 | - wrapt==1.15.0
305 | - xarray==0.20.2
306 | - y-py==0.6.0
307 | - yapf==0.40.1
308 | - ypy-websocket==0.8.4
309 | - zipp==3.15.0
310 |
--------------------------------------------------------------------------------
/eval_covmat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import pickle
4 | import torch
5 |
6 | from utils.datasets import PackedConformationDataset
7 | from utils.evaluation.covmat import CovMatEvaluator, print_covmat_results
8 | from utils.misc import *
9 |
10 |
11 | if __name__ == '__main__':
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('path', type=str, default="samples_all.pkl")
14 | parser.add_argument('--num_workers', type=int, default=8)
15 | parser.add_argument('--ratio', type=int, default=2)
16 | parser.add_argument('--start_idx', type=int, default=0)
17 | args = parser.parse_args()
18 | assert os.path.isfile(args.path)
19 |
20 | # Logging
21 | tag = args.path.split('/')[-1].split('.')[0]
22 | logger = get_logger('eval', os.path.dirname(args.path), 'log_eval_'+tag+'.txt')
23 |
24 | # Load results
25 | logger.info('Loading results: %s' % args.path)
26 | with open(args.path, 'rb') as f:
27 | packed_dataset = pickle.load(f)
28 | logger.info('Total: %d' % len(packed_dataset))
29 |
30 | # Evaluator
31 | evaluator = CovMatEvaluator(
32 | num_workers = args.num_workers,
33 |
34 | ratio = args.ratio,
35 | print_fn=logger.info,
36 | )
37 | results = evaluator(
38 | packed_data_list = list(packed_dataset),
39 | start_idx = args.start_idx,
40 | )
41 | df = print_covmat_results(results, print_fn=logger.info)
42 |
43 | # Save results
44 | csv_fn = args.path[:-4] + '_covmat.csv'
45 | results_fn = args.path[:-4] + '_covmat.pkl'
46 | df.to_csv(csv_fn)
47 | with open(results_fn, 'wb') as f:
48 | pickle.dump(results, f)
49 |
50 |
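The COV/MAT columns in the CSVs above follow the standard conformer-generation metrics: at threshold delta, COV-R is the fraction of reference conformers matched by at least one generated conformer within delta RMSD, and MAT-R is the mean best-match RMSD. A minimal numpy sketch of the recall-side metrics (hypothetical; the repo's actual implementation is `utils/evaluation/covmat.py`):

```python
import numpy as np

def cov_mat_recall(rmsd, delta):
    # rmsd: (num_gen, num_ref) pairwise RMSD matrix for one molecule.
    best_per_ref = rmsd.min(axis=0)              # closest generated conformer per reference
    cov_r = float((best_per_ref < delta).mean()) # coverage at threshold delta
    mat_r = float(best_per_ref.mean())           # mean best-match RMSD
    return cov_r, mat_r
```

Transposing the matrix (minimizing over references per generated conformer) gives the precision-side COV-P/MAT-P analogues.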
--------------------------------------------------------------------------------
/eval_prop.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import argparse
4 | import torch
5 | import numpy as np
6 | from psikit import Psikit
7 | from tqdm.auto import tqdm
8 | from easydict import EasyDict
9 | from torch_geometric.data import Data
10 |
11 | from utils.datasets import PackedConformationDataset
12 | from utils.chem import set_rdmol_positions
13 |
14 |
15 | class PropertyCalculator(object):
16 |
17 | def __init__(self, threads, memory, seed):
18 | super().__init__()
19 | self.pk = Psikit(threads=threads, memory=memory)
20 | self.seed = seed
21 |
22 | def __call__(self, data, num_confs=50):
23 | rdmol = data.rdmol
24 | confs = data.pos_prop
25 |
26 | conf_idx = np.arange(confs.shape[0])
27 | np.random.RandomState(self.seed).shuffle(conf_idx)
28 | conf_idx = conf_idx[:num_confs]
29 |
30 | data.prop_conf_idx = []
31 | data.prop_energy = []
32 | data.prop_homo = []
33 | data.prop_lumo = []
34 | data.prop_dipo = []
35 |
36 | for idx in tqdm(conf_idx):
37 | mol = set_rdmol_positions(rdmol, confs[idx])
38 | self.pk.mol = mol
39 | try:
40 | energy, homo, lumo, dipo = self.pk.energy(), self.pk.HOMO, self.pk.LUMO, self.pk.dipolemoment[-1]
41 | data.prop_conf_idx.append(idx)
42 | data.prop_energy.append(energy)
43 | data.prop_homo.append(homo)
44 | data.prop_lumo.append(lumo)
45 | data.prop_dipo.append(dipo)
46 |         except Exception:  # Psi4 can fail on some conformers; skip them
47 |             pass
48 |
49 | return data
50 |
51 |
52 | def get_prop_matrix(data):
53 | """
54 | Returns:
55 | properties: (4, num_confs) numpy tensor. Energy, HOMO, LUMO, DipoleMoment
56 | """
57 | return np.array([
58 | data.prop_energy,
59 | data.prop_homo,
60 | data.prop_lumo,
61 | data.prop_dipo,
62 | ])
63 |
64 |
65 | def get_ensemble_energy(props):
66 | """
67 | Args:
68 | props: (4, num_confs)
69 | """
70 | avg_ener = np.mean(props[0, :])
71 | low_ener = np.min(props[0, :])
72 | gaps = np.abs(props[1, :] - props[2, :])
73 | avg_gap = np.mean(gaps)
74 | min_gap = np.min(gaps)
75 | max_gap = np.max(gaps)
76 | return np.array([
77 | avg_ener, low_ener, avg_gap, min_gap, max_gap,
78 | ])
79 |
80 | HART_TO_EV = 27.211
81 | HART_TO_KCALPERMOL = 627.5
82 |
83 | if __name__ == '__main__':
84 | parser = argparse.ArgumentParser()
85 | parser.add_argument('--dataset', type=str, default='./data/GEOM/QM9/qm9_property.pkl')
86 | parser.add_argument('--generated', type=str, default=None)
87 | parser.add_argument('--num_confs', type=int, default=50)
88 | parser.add_argument('--threads', type=int, default=8)
89 | parser.add_argument('--memory', type=int, default=16)
90 | parser.add_argument('--seed', type=int, default=2021)
91 | args = parser.parse_args()
92 |
93 | prop_cal = PropertyCalculator(threads=args.threads, memory=args.memory, seed=args.seed)
94 |
95 | cache_ref_fn = os.path.join(
96 | os.path.dirname(args.dataset),
97 | os.path.basename(args.dataset)[:-4] + '_prop.pkl'
98 | )
99 | if not os.path.exists(cache_ref_fn):
100 | dset = PackedConformationDataset(args.dataset)
101 | dset = [data for data in dset]
102 | dset_prop = []
103 | for data in dset:
104 | data.pos_prop = data.pos_ref.reshape(-1, data.num_nodes, 3)
105 | dset_prop.append(prop_cal(data, args.num_confs))
106 | with open(cache_ref_fn, 'wb') as f:
107 | pickle.dump(dset_prop, f)
108 | dset = dset_prop
109 | else:
110 | with open(cache_ref_fn, 'rb') as f:
111 | dset = pickle.load(f)
112 |
113 |
114 | if args.generated is None:
115 | exit()
116 |
117 | print('Start evaluation.')
118 |
119 | cache_gen_fn = os.path.join(
120 | os.path.dirname(args.generated),
121 | os.path.basename(args.generated)[:-4] + '_prop.pkl'
122 | )
123 | if not os.path.exists(cache_gen_fn):
124 | with open(args.generated, 'rb') as f:
125 | gens = pickle.load(f)
126 | gens_prop = []
127 | for data in gens:
128 | if not isinstance(data, Data):
129 | data = EasyDict(data)
130 | data.num_nodes = data.rdmol.GetNumAtoms()
131 | data.pos_prop = data.pos_gen.reshape(-1, data.num_nodes, 3)
132 | gens_prop.append(prop_cal(data, args.num_confs))
133 | with open(cache_gen_fn, 'wb') as f:
134 | pickle.dump(gens_prop, f)
135 | gens = gens_prop
136 | else:
137 | with open(cache_gen_fn, 'rb') as f:
138 | gens = pickle.load(f)
139 |
140 |
141 | dset = {d.smiles:d for d in dset}
142 | gens = {d.smiles:d for d in gens}
143 | all_diff = []
144 | for smiles in dset.keys():
145 | if smiles not in gens:
146 | continue
147 |
148 | prop_gts = get_ensemble_energy(get_prop_matrix(dset[smiles])) * HART_TO_EV
149 | prop_gen = get_ensemble_energy(get_prop_matrix(gens[smiles])) * HART_TO_EV
150 | # prop_gts = np.mean(get_prop_matrix(dset[smiles]), axis=1)
151 | # prop_gen = np.mean(get_prop_matrix(gens[smiles]), axis=1)
152 |
153 | # print(get_prop_matrix(gens[smiles])[0])
154 |
155 | prop_diff = np.abs(prop_gts - prop_gen)
156 |
157 | print('\nProperty: %s' % smiles)
158 | print(' Gts :', prop_gts)
159 | print(' Gen :', prop_gen)
160 | print(' Diff:', prop_diff)
161 |
162 | all_diff.append(prop_diff.reshape(1, -1))
163 | all_diff = np.vstack(all_diff) # (num_mols, 4)
164 | print(all_diff.shape)
165 |
166 | print('[Difference]')
167 | print(' Mean: ', np.mean(all_diff, axis=0))
168 | print(' Median:', np.median(all_diff, axis=0))
169 | print(' Std: ', np.std(all_diff, axis=0))
170 |
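To make the ensemble statistics concrete, here is `get_ensemble_energy`'s computation replayed on a toy `(4, num_confs)` property matrix (dummy numbers, not real Psi4 output):

```python
import numpy as np

# Rows: energy, HOMO, LUMO, dipole moment; columns: conformers (dummy values).
props = np.array([
    [-1.00, -1.20, -0.90],
    [-0.30, -0.28, -0.32],
    [-0.10, -0.05, -0.12],
    [ 1.00,  1.10,  0.90],
])
gaps = np.abs(props[1] - props[2])          # HOMO-LUMO gap per conformer
stats = np.array([props[0].mean(),          # average ensemble energy
                  props[0].min(),           # lowest-energy conformer
                  gaps.mean(), gaps.min(), gaps.max()])
stats_ev = stats * 27.211                   # Hartree -> eV, as in the script
```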
--------------------------------------------------------------------------------
/finetune/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | parser = argparse.ArgumentParser()
5 |
6 | # about seed and basic info
7 | parser.add_argument("--seed", type=int, default=42)
8 | parser.add_argument("--device", type=int, default=0)
9 |
10 | # parser.add_argument(
11 | # "--model_3d",
12 | # type=str,
13 | # default="schnet",
14 | # choices=[
15 | # "schnet",
16 | # "dimenet",
17 | # "dimenetPP",
18 | # "tfn",
19 | # "se3_transformer",
20 | # "egnn",
21 | # "spherenet",
22 | # "segnn",
23 | # "painn",
24 | # "gemnet",
25 | # "nequip",
26 | # "allegro",
27 | # ],
28 | # )
29 | parser.add_argument("--model_3d", type=str, default="mask_diff")
30 | parser.add_argument(
31 | "--model_2d",
32 | type=str,
33 | default="gin",
34 | choices=[
35 | "gin",
36 | "schnet",
37 | "dimenet",
38 | "dimenetPP",
39 | "tfn",
40 | "se3_transformer",
41 | "egnn",
42 | "spherenet",
43 | "segnn",
44 | "painn",
45 | "gemnet",
46 | "nequip",
47 | "allegro",
48 | ],
49 | )
50 |
51 | # about dataset and dataloader
52 | parser.add_argument("--dataset", type=str, default="qm9")
53 | parser.add_argument("--task", type=str, default="alpha")
54 | parser.add_argument("--num_workers", type=int, default=0)
55 | parser.add_argument("--only_one_atom_type", dest="only_one_atom_type", action="store_true")
56 | parser.set_defaults(only_one_atom_type=False)
57 |
58 | # for MD17
59 | # The default hyper from here: https://github.com/divelab/DIG_storage/tree/main/3dgraph/md17
60 | parser.add_argument("--md17_energy_coeff", type=float, default=0.05)
61 | parser.add_argument("--md17_force_coeff", type=float, default=0.95)
62 |
63 | # for COLL
64 | # The default hyper from here: https://github.com/divelab/DIG_storage/tree/main/3dgraph/md17
65 | parser.add_argument("--coll_energy_coeff", type=float, default=0.05)
66 | parser.add_argument("--coll_force_coeff", type=float, default=0.95)
67 |
68 | # for LBA
69 | # Default hyperparameters from: https://github.com/drorlab/atom3d/blob/master/examples/lep/enn/utils.py#L37-L43
70 | parser.add_argument("--LBA_year", type=int, default=2020)
71 | parser.add_argument("--LBA_dist", type=float, default=6)
72 | parser.add_argument("--LBA_maxnum", type=int, default=500)
73 | parser.add_argument("--LBA_use_complex", dest="LBA_use_complex", action="store_true")
74 | parser.add_argument("--LBA_no_complex", dest="LBA_use_complex", action="store_false")
75 | parser.set_defaults(LBA_use_complex=False)
76 |
77 | # for LEP
78 | # Default hyperparameters from: https://github.com/drorlab/atom3d/blob/master/examples/lep/enn/utils.py#L48-L55
79 | parser.add_argument("--LEP_dist", type=float, default=6)
80 | parser.add_argument("--LEP_maxnum", type=int, default=400)
81 | parser.add_argument("--LEP_droph", dest="LEP_droph", action="store_true")
82 | parser.add_argument("--LEP_useh", dest="LEP_droph", action="store_false")
83 | parser.set_defaults(LEP_droph=False)
84 |
85 | # for MoleculeNet
86 | parser.add_argument("--moleculenet_num_conformers", type=int, default=10)
87 |
88 | # about training strategies
89 | parser.add_argument("--split", type=str, default="customized_01",
90 | choices=["customized_01", "customized_02", "random", "atom3d_lba_split30"])
91 | parser.add_argument("--MD17_train_batch_size", type=int, default=1)
92 | parser.add_argument("--batch_size", type=int, default=128)
93 | parser.add_argument("--epochs", type=int, default=100)
94 | parser.add_argument("--lr", type=float, default=1e-4)
95 | parser.add_argument("--lr_scale", type=float, default=1)
96 | parser.add_argument("--decay", type=float, default=0)
97 | parser.add_argument("--print_every_epoch", type=int, default=1)
98 | parser.add_argument("--loss", type=str, default="mae", choices=["mse", "mae"])
99 | parser.add_argument("--lr_scheduler", type=str, default="CosineAnnealingLR")
100 | parser.add_argument("--lr_decay_factor", type=float, default=0.5)
101 | parser.add_argument("--lr_decay_step_size", type=int, default=100)
102 | parser.add_argument("--lr_decay_patience", type=int, default=50)
103 | parser.add_argument("--min_lr", type=float, default=1e-6)
104 | parser.add_argument("--verbose", dest="verbose", action="store_true")
105 | parser.add_argument("--no_verbose", dest="verbose", action="store_false")
106 | parser.set_defaults(verbose=False)
107 | parser.add_argument("--use_rotation_transform", dest="use_rotation_transform", action="store_true")
108 | parser.add_argument("--no_rotation_transform", dest="use_rotation_transform", action="store_false")
109 | parser.set_defaults(use_rotation_transform=False)
110 |
111 | # for SchNet
112 | parser.add_argument("--num_filters", type=int, default=128)
113 | parser.add_argument("--num_interactions", type=int, default=6)
114 | parser.add_argument("--num_gaussians", type=int, default=51)
115 | parser.add_argument("--cutoff", type=float, default=10)
116 | parser.add_argument("--readout", type=str, default="mean", choices=["mean", "add"])
117 |
118 | # for PaiNN
119 | parser.add_argument("--painn_radius_cutoff", type=float, default=5.0)
120 | parser.add_argument("--painn_n_interactions", type=int, default=3)
121 | parser.add_argument("--painn_n_rbf", type=int, default=20)
122 | parser.add_argument("--painn_readout", type=str, default="add", choices=["mean", "add"])
123 |
124 | ######################### for Charge Prediction SSL #########################
125 | parser.add_argument("--charge_masking_ratio", type=float, default=0.3)
126 |
127 | ######################### for Distance Perturbation SSL #########################
128 | parser.add_argument("--distance_sample_ratio", type=float, default=1)
129 |
130 | ######################### for Torsion Angle Perturbation SSL #########################
131 | parser.add_argument("--torsion_angle_sample_ratio", type=float, default=0.001)
132 |
133 | ######################### for Position Perturbation SSL #########################
134 | parser.add_argument("--PP_mu", type=float, default=0)
135 | parser.add_argument("--PP_sigma", type=float, default=0.3)
136 |
137 |
138 | ######################### for GraphMVP SSL #########################
139 | ### for 2D GNN
140 | parser.add_argument("--gnn_type", type=str, default="gin")
141 | parser.add_argument("--num_layer", type=int, default=5)
142 | parser.add_argument("--emb_dim", type=int, default=128)
143 | parser.add_argument("--dropout_ratio", type=float, default=0.5)
144 | parser.add_argument("--graph_pooling", type=str, default="mean")
145 | parser.add_argument("--JK", type=str, default="last")
146 | parser.add_argument("--gnn_2d_lr_scale", type=float, default=1)
147 |
148 | ######################### for GeoSSL #########################
149 | parser.add_argument("--GeoSSL_mu", type=float, default=0)
150 | parser.add_argument("--GeoSSL_sigma", type=float, default=0.3)
151 | parser.add_argument("--GeoSSL_atom_masking_ratio", type=float, default=0.3)
152 | parser.add_argument("--GeoSSL_option", type=str, default="EBM_NCE", choices=["DDM", "EBM_NCE", "InfoNCE", "RR"])
153 | parser.add_argument("--EBM_NCE_SM_coefficient", type=float, default=10.)
154 |
155 | parser.add_argument("--SM_sigma_begin", type=float, default=10)
156 | parser.add_argument("--SM_sigma_end", type=float, default=0.01)
157 | parser.add_argument("--SM_num_noise_level", type=int, default=50)
158 | parser.add_argument("--SM_noise_type", type=str, default="symmetry", choices=["symmetry", "random"])
159 | parser.add_argument("--SM_anneal_power", type=float, default=2)
160 |
161 | ######################### for GraphMVP SSL #########################
162 | ### for 3D GNN
163 | parser.add_argument("--gnn_3d_lr_scale", type=float, default=1)
164 |
165 | ### for masking
166 | parser.add_argument("--SSL_masking_ratio", type=float, default=0.15)
167 |
168 | ### for 2D-3D Contrastive SSL
169 | parser.add_argument("--CL_neg_samples", type=int, default=1)
170 | parser.add_argument("--CL_similarity_metric", type=str, default="InfoNCE_dot_prod",
171 | choices=["InfoNCE_dot_prod", "EBM_dot_prod"])
172 | parser.add_argument("--T", type=float, default=0.1)
173 | parser.add_argument("--normalize", dest="normalize", action="store_true")
174 | parser.add_argument("--no_normalize", dest="normalize", action="store_false")
175 | parser.add_argument("--alpha_1", type=float, default=1)
176 |
177 | ### for 2D-3D Generative SSL
178 | parser.add_argument("--GraphMVP_AE_model", type=str, default="VAE")
179 | parser.add_argument("--detach_target", dest="detach_target", action="store_true")
180 | parser.add_argument("--no_detach_target", dest="detach_target", action="store_false")
181 | parser.set_defaults(detach_target=True)
182 | parser.add_argument("--AE_loss", type=str, default="l2", choices=["l1", "l2", "cosine"])
183 | parser.add_argument("--beta", type=float, default=1)
184 | parser.add_argument("--alpha_2", type=float, default=1)
185 |
186 | ### for 2D SSL
187 | parser.add_argument("--GraphMVP_2D_mode", type=str, default="AM", choices=["AM", "CP"])
188 | parser.add_argument("--alpha_3", type=float, default=1)
189 | ### for AttributeMask
190 | parser.add_argument("--mask_rate", type=float, default=0.15)
191 | parser.add_argument("--mask_edge", type=int, default=0)
192 | ### for ContextPred
193 | parser.add_argument("--csize", type=int, default=3)
194 | parser.add_argument("--contextpred_neg_samples", type=int, default=1)
195 | #######################################################################
196 |
197 |
198 |
199 | ##### about if we would print out eval metric for training data
200 | parser.add_argument("--eval_train", dest="eval_train", action="store_true")
201 | parser.add_argument("--no_eval_train", dest="eval_train", action="store_false")
202 | parser.set_defaults(eval_train=False)
203 | ##### about if we would print out eval metric for training data
204 | ##### this is only for COLL
205 | parser.add_argument("--eval_test", dest="eval_test", action="store_true")
206 | parser.add_argument("--no_eval_test", dest="eval_test", action="store_false")
207 | parser.set_defaults(eval_test=True)
208 |
209 | parser.add_argument("--input_data_dir", type=str, default="")
210 |
211 | # about loading and saving
212 | parser.add_argument("--input_model_file", type=str, default="")
213 | parser.add_argument("--output_model_dir", type=str, default="")
214 |
215 | ## for mask diff (GeoDiff)
216 | parser.add_argument('--ckpt', type=str,default="", help='path for loading the checkpoint')
217 | parser.add_argument('--config', type=str, default=None)
218 | parser.add_argument('--tag', type=str, default=None)
219 | parser.add_argument('--random_init', action="store_true", default=False)
220 |
221 | args = parser.parse_args()
222 | print("arguments\t", args)
223 |
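A recurring pattern in this config is the paired on/off switch: two flags write to the same `dest`, with `set_defaults` fixing the fallback value. A minimal, self-contained sketch of that pattern (the flag names below mirror the config but the parser is rebuilt from scratch for illustration):

```python
import argparse

# Paired-flag pattern: --verbose / --no_verbose both write to args.verbose,
# and set_defaults supplies the value when neither switch is given.
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="alpha")
parser.add_argument("--verbose", dest="verbose", action="store_true")
parser.add_argument("--no_verbose", dest="verbose", action="store_false")
parser.set_defaults(verbose=False)

args = parser.parse_args(["--task", "homo", "--verbose"])
# args.task is "homo"; args.verbose is True; with no flags, verbose stays False
```

This keeps boolean options overridable in either direction from the command line, which matters when a shell script sets one default and an experiment needs to flip it.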
--------------------------------------------------------------------------------
/finetune/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # from finetune.datasets.datasets_Molecule3D import Molecule3D
2 |
3 | from .datasets_QM9 import MoleculeDatasetQM9
4 |
5 | from .rdmol2data import rdmol_to_data
6 | # from finetune.datasets.datasets_MD17 import DatasetMD17
7 | # from finetune.datasets.datasets_MD17Radius import DatasetMD17Radius
8 |
9 | # from finetune.datasets.datasets_LBA import DatasetLBA, TransformLBA
10 | # from finetune.datasets.datasets_LBARadius import DatasetLBARadius
11 |
12 | # from finetune.datasets.datasets_LEP import DatasetLEP, TransformLEP
13 | # from finetune.datasets.datasets_LEPRadius import DatasetLEPRadius
14 |
15 | # from finetune.datasets.datasets_3D import Molecule3DDataset
16 | # from finetune.datasets.datasets_3D_Masking import Molecule3DMaskingDataset
17 | # from finetune.datasets.datasets_3D_Radius import MoleculeDataset3DRadius
18 |
19 | # from finetune.datasets.datasets_utils import graph_data_obj_to_nx_simple, nx_to_graph_data_obj_simple
20 |
--------------------------------------------------------------------------------
/finetune/datasets/datasets_QM9.py:
--------------------------------------------------------------------------------
1 | import os
2 | from itertools import repeat
3 |
4 | import pandas as pd
5 | import torch
6 | from rdkit import Chem
7 | from rdkit.Chem import AllChem
8 | from scipy.constants import physical_constants
9 | from torch_geometric.data import (Data, InMemoryDataset, download_url,
10 | extract_zip)
11 |
12 | # from Geom3D.datasets.datasets_utils import mol_to_graph_data_obj_simple_3D
13 | from datasets.rdmol2data import rdmol_to_data
14 |
15 | class MoleculeDatasetQM9(InMemoryDataset):
16 | raw_url = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip"
17 | raw_url2 = "https://ndownloader.figshare.com/files/3195404"
18 | raw_url3 = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/qm9.csv"
19 | raw_url4 = "https://springernature.figshare.com/ndownloader/files/3195395"
20 |
21 | def __init__(
22 | self,
23 | root,
24 | dataset,
25 | task,
26 | rotation_transform=None,
27 | transform=None,
28 | pre_transform=None,
29 | pre_filter=None,
30 | calculate_thermo=True,
31 | ):
32 | """
33 | The complete columns are
34 | A,B,C,mu,alpha,homo,lumo,gap,r2,zpve,u0,u298,h298,g298,cv,u0_atom,u298_atom,h298_atom,g298_atom
35 | and we take
36 | mu,alpha,homo,lumo,gap,r2,zpve,u0,u298,h298,g298,cv
37 | """
38 | self.root = root
39 | self.rotation_transform = rotation_transform
40 | self.transform = transform
41 | self.pre_transform = pre_transform
42 | self.pre_filter = pre_filter
43 |
44 | self.target_field = [
45 | "mu",
46 | "alpha",
47 | "homo",
48 | "lumo",
49 | "gap",
50 | "r2",
51 | "zpve",
52 | "u0",
53 | "u298",
54 | "h298",
55 | "g298",
56 | "cv",
57 | "gap_02",
58 | ]
59 | self.pd_target_field = [
60 | "mu",
61 | "alpha",
62 | "homo",
63 | "lumo",
64 | "gap",
65 | "r2",
66 | "zpve",
67 | "u0",
68 | "u298",
69 | "h298",
70 | "g298",
71 | "cv",
72 | ]
73 | self.task = task
74 | if self.task == "qm9":
75 | self.task_id = None
76 | else:
77 | self.task_id = self.target_field.index(task)
78 | self.calculate_thermo = calculate_thermo
79 | self.atom_dict = {"H": 1, "C": 6, "N": 7, "O": 8, "F": 9}
80 |
81 | # TODO: need double-check
82 | # https://github.com/atomistic-machine-learning/schnetpack/blob/master/src/schnetpack/datasets/qm9.py
83 |         # QM9 dataset units: property_unit_dict = {
84 | # QM9.A: "GHz",
85 | # QM9.B: "GHz",
86 | # QM9.C: "GHz",
87 | # QM9.mu: "Debye",
88 | # QM9.alpha: "a0 a0 a0",
89 | # QM9.homo: "Ha",
90 | # QM9.lumo: "Ha",
91 | # QM9.gap: "Ha",
92 | # QM9.r2: "a0 a0",
93 | # QM9.zpve: "Ha",
94 | # QM9.U0: "Ha",
95 | # QM9.U: "Ha",
96 | # QM9.H: "Ha",
97 | # QM9.G: "Ha",
98 | # QM9.Cv: "cal/mol/K",
99 | # }
100 | # HAR2EV = 27.211386246
101 | # KCALMOL2EV = 0.04336414
102 |
103 | # conversion = torch.tensor([
104 | # 1., 1., HAR2EV, HAR2EV, HAR2EV, 1., HAR2EV, HAR2EV, HAR2EV, HAR2EV, HAR2EV,
105 | # 1., KCALMOL2EV, KCALMOL2EV, KCALMOL2EV, KCALMOL2EV, 1., 1., 1.
106 | # ])
107 |
108 | # Now we are following these two:
109 | # https://github.com/risilab/cormorant/blob/master/examples/train_qm9.py
110 | # https://github.com/FabianFuchsML/se3-transformer-public/blob/master/experiments/qm9/QM9.py
111 |
112 | self.hartree2eV = physical_constants["hartree-electron volt relationship"][0]
113 |
114 | self.conversion = {
115 | "mu": 1.0,
116 | "alpha": 1.0,
117 | "homo": self.hartree2eV,
118 | "lumo": self.hartree2eV,
119 | "gap": self.hartree2eV,
120 | "gap_02": self.hartree2eV,
121 | "r2": 1.0,
122 | "zpve": self.hartree2eV,
123 | "u0": self.hartree2eV,
124 | "u298": self.hartree2eV,
125 | "h298": self.hartree2eV,
126 | "g298": self.hartree2eV,
127 | "cv": 1.0,
128 | }
129 |
130 | super(MoleculeDatasetQM9, self).__init__(
131 | root, transform, pre_transform, pre_filter
132 | )
133 | self.dataset = dataset
134 | self.data, self.slices = torch.load(self.processed_paths[0])
135 | print("Dataset: {}\nData: {}".format(self.dataset, self.data))
136 |
137 | return
138 |
139 | def mean(self):
140 | y = torch.stack([self.get(i).y for i in range(len(self))], dim=0)
141 | y = y.mean(dim=0)
142 | return y
143 |
144 | def std(self):
145 | y = torch.stack([self.get(i).y for i in range(len(self))], dim=0)
146 | y = y.std(dim=0)
147 | return y
148 |
149 | def get(self, idx):
150 | data = Data()
151 | for key in self.data.keys:
152 | item, slices = self.data[key], self.slices[key]
153 | s = list(repeat(slice(None), item.dim()))
154 | s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
155 | data[key] = item[s]
156 | if self.rotation_transform is not None:
157 | data.positions = self.rotation_transform(data.positions)
158 | return data
159 |
160 | @property
161 | def raw_file_names(self):
162 | return [
163 | "gdb9.sdf",
164 | "gdb9.sdf.csv",
165 | "uncharacterized.txt",
166 | "qm9.csv",
167 | "atomref.txt",
168 | ]
169 |
170 | @property
171 | def processed_file_names(self):
172 | return "geometric_data_processed.pt"
173 |
174 | def download(self):
175 | file_path = download_url(self.raw_url, self.raw_dir)
176 | extract_zip(file_path, self.raw_dir)
177 | os.unlink(file_path)
178 |
179 | download_url(self.raw_url2, self.raw_dir)
180 | os.rename(
181 | os.path.join(self.raw_dir, "3195404"),
182 | os.path.join(self.raw_dir, "uncharacterized.txt"),
183 | )
184 |
185 | download_url(self.raw_url3, self.raw_dir)
186 |
187 | download_url(self.raw_url4, self.raw_dir)
188 | os.rename(
189 | os.path.join(self.raw_dir, "3195395"),
190 | os.path.join(self.raw_dir, "atomref.txt"),
191 | )
192 | return
193 |
194 | def get_thermo_dict(self):
195 | gdb9_txt_thermo = self.raw_paths[4]
196 | # Loop over file of thermochemical energies
197 |         # thermochemical targets zpve, u0, u298, h298, g298, cv sit in columns 6-11
198 |         therm_targets = [6, 7, 8, 9, 10, 11]
199 |
200 |         # Dictionary that maps element symbols to nuclear charges
201 | id2charge = self.atom_dict
202 |
203 | # Loop over file of thermochemical energies
204 | therm_energy = {target: {} for target in therm_targets}
205 | with open(gdb9_txt_thermo) as f:
206 | for line in f:
207 | # If line starts with an element, convert the rest to a list of energies.
208 | split = line.split()
209 |
210 | # Check charge corresponds to an atom
211 | if len(split) == 0 or split[0] not in id2charge.keys():
212 | continue
213 |
214 | # Loop over learning targets with defined thermochemical energy
215 | for therm_target, split_therm in zip(therm_targets, split[1:]):
216 | therm_energy[therm_target][id2charge[split[0]]] = float(split_therm)
217 |
218 | return therm_energy
219 |
220 | def process(self):
221 | therm_energy = self.get_thermo_dict()
222 | print("therm_energy\t", therm_energy)
223 |
224 | df = pd.read_csv(self.raw_paths[1])
225 | df = df[self.pd_target_field]
226 | df["gap_02"] = df["lumo"] - df["homo"]
227 |
228 | target = df.to_numpy()
229 | target = torch.tensor(target, dtype=torch.float)
230 |
231 | with open(self.raw_paths[2], "r") as f:
232 |             # These are the mis-matched molecules, according to the `uncharacterized.txt` file.
233 | skip = [int(x.split()[0]) - 1 for x in f.read().split("\n")[9:-2]]
234 |
235 | data_df = pd.read_csv(self.raw_paths[3])
236 | whole_smiles_list = data_df["smiles"].tolist()
237 |         print("first 100 SMILES\t", whole_smiles_list[:100])
238 |
239 | suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False, sanitize=False)
240 |
241 | print("suppl: {}\tsmiles_list: {}".format(len(suppl), len(whole_smiles_list)))
242 |
243 | data_list, data_smiles_list, data_name_list, idx, invalid_count = (
244 | [],
245 | [],
246 | [],
247 | 0,
248 | 0,
249 | )
250 | for i, mol in enumerate(suppl):
251 | if i in skip:
252 | print("Exception with (skip)\t", i)
253 | invalid_count += 1
254 | continue
255 |
256 | # data, atom_count = mol_to_graph_data_obj_simple_3D(mol)
257 | data, atom_count = rdmol_to_data(mol=mol)
258 |
259 | data.id = torch.tensor([idx])
260 | temp_y = target[i]
261 | if self.calculate_thermo:
262 | for atom, count in atom_count.items():
263 | if atom not in self.atom_dict.values():
264 | continue
265 | for target_id, atom_sub_dic in therm_energy.items():
266 | temp_y[target_id] -= atom_sub_dic[atom] * count
267 |
268 | # convert units
269 |             for col_idx, col in enumerate(self.target_field):  # avoid shadowing the running molecule index `idx`
270 |                 temp_y[col_idx] *= self.conversion[col]
271 | data.y = temp_y
272 |
273 | name = mol.GetProp("_Name")
274 | smiles = whole_smiles_list[i]
275 |
276 | # TODO: need double-check this
277 | temp_mol = AllChem.MolFromSmiles(smiles)
278 | if temp_mol is None:
279 | print("Exception with (invalid mol)\t", i)
280 | invalid_count += 1
281 | continue
282 |
283 | data_smiles_list.append(smiles)
284 | data_name_list.append(name)
285 | data_list.append(data)
286 | idx += 1
287 |
288 | print(
289 | "mol id: [0, {}]\tlen of smiles: {}\tlen of set(smiles): {}".format(
290 | idx - 1, len(data_smiles_list), len(set(data_smiles_list))
291 | )
292 | )
293 | print("{} invalid molecules".format(invalid_count))
294 |
295 | if self.pre_filter is not None:
296 | data_list = [data for data in data_list if self.pre_filter(data)]
297 |
298 | if self.pre_transform is not None:
299 | data_list = [self.pre_transform(data) for data in data_list]
300 |
301 | # TODO: need double-check later, the smiles list are identical here?
302 | data_smiles_series = pd.Series(data_smiles_list)
303 | saver_path = os.path.join(self.processed_dir, "smiles.csv")
304 | print("saving to {}".format(saver_path))
305 | data_smiles_series.to_csv(saver_path, index=False, header=False)
306 |
307 | data_name_series = pd.Series(data_name_list)
308 | saver_path = os.path.join(self.processed_dir, "name.csv")
309 | print("saving to {}".format(saver_path))
310 | data_name_series.to_csv(saver_path, index=False, header=False)
311 |
312 | data, slices = self.collate(data_list)
313 | torch.save((data, slices), self.processed_paths[0])
314 |
315 | return
316 |
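Inside `process()`, extensive targets are first corrected by subtracting each atom's thermochemical reference energy, then scaled by the unit-conversion factor (Hartree to eV for energies). A minimal sketch of that correction for a single target column; the reference energies and raw value below are illustrative placeholders, not real QM9 numbers:

```python
# Thermochemical correction sketch: subtract per-atom reference energies for
# target column 7 (u0), then convert Hartree -> eV.
hartree2eV = 27.211386245988  # CODATA hartree-electron volt relationship

therm_energy = {7: {1: -0.5, 6: -37.8}}  # {target column: {atomic number: ref energy}}, placeholder values
atom_count = {1: 4, 6: 1}                # e.g. methane: four H, one C

temp_y = {7: -40.5}                      # raw u0 in Hartree (illustrative)
for atom, count in atom_count.items():
    temp_y[7] -= therm_energy[7][atom] * count   # -40.5 + 2.0 + 37.8 = -0.7

u0_atomization_eV = temp_y[7] * hartree2eV
```

The per-atom subtraction turns the raw electronic energy into an atomization-style quantity, which is what most QM9 benchmarks actually regress on.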
--------------------------------------------------------------------------------
/finetune/datasets/rdmol2data.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 |
4 | import torch
5 | from torch_geometric.data import Data
6 |
7 | from torch_scatter import scatter
8 | #from torch.utils.data import Dataset
9 |
10 | from rdkit import Chem
11 | from rdkit.Chem.rdchem import Mol, HybridizationType, BondType
12 | from rdkit import RDLogger
13 | RDLogger.DisableLog('rdApp.*')
14 |
15 | from collections import defaultdict
16 |
17 | # from confgf import utils
18 | # from datasets.chem import BOND_TYPES, BOND_NAMES
19 | from rdkit.Chem.rdchem import BondType as BT
20 |
21 | BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}
22 |
23 | def rdmol_to_data(mol: Mol, pos=None, y=None, idx=None, smiles=None):
24 | assert mol.GetNumConformers() == 1
25 | N = mol.GetNumAtoms()
26 | # if smiles is None:
27 | # smiles = Chem.MolToSmiles(mol)
28 | pos = torch.tensor(mol.GetConformer(0).GetPositions(), dtype=torch.float32)
29 |
30 | atomic_number = []
31 | aromatic = []
32 | sp = []
33 | sp2 = []
34 | sp3 = []
35 | num_hs = []
36 | atom_count = defaultdict(int)
37 | for atom in mol.GetAtoms():
38 | atomic_number.append(atom.GetAtomicNum())
39 | atom_count[atom.GetAtomicNum()] += 1
40 | aromatic.append(1 if atom.GetIsAromatic() else 0)
41 | hybridization = atom.GetHybridization()
42 | sp.append(1 if hybridization == HybridizationType.SP else 0)
43 | sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
44 | sp3.append(1 if hybridization == HybridizationType.SP3 else 0)
45 |
46 | z = torch.tensor(atomic_number, dtype=torch.long)
47 |
48 | row, col, edge_type = [], [], []
49 | for bond in mol.GetBonds():
50 | start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
51 | row += [start, end]
52 | col += [end, start]
53 | edge_type += 2 * [BOND_TYPES[bond.GetBondType()]]
54 |
55 | edge_index = torch.tensor([row, col], dtype=torch.long)
56 | edge_type = torch.tensor(edge_type)
57 |
58 | perm = (edge_index[0] * N + edge_index[1]).argsort()
59 | edge_index = edge_index[:, perm]
60 | edge_type = edge_type[perm]
61 |
62 | row, col = edge_index
63 | hs = (z == 1).to(torch.float32)
64 |
65 | num_hs = scatter(hs[row], col, dim_size=N, reduce='sum').tolist()
66 |
67 | if smiles is None:
68 | smiles = Chem.MolToSmiles(mol)
69 |
70 | try:
71 | name = mol.GetProp('_Name')
72 |     except KeyError:
73 |         name = None
74 |
75 | # data = Data(atom_type=z, pos=pos, edge_index=edge_index, edge_type=edge_type,
76 | # rdmol=copy.deepcopy(mol), smiles=smiles,y=y,id=idx, name=name)
77 | data = Data(atom_type=z, pos=pos, edge_index=edge_index, edge_type=edge_type)
78 | #data.nx = to_networkx(data, to_undirected=True)
79 |
80 | return data, atom_count
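The edge construction above stores every bond in both directions and then canonicalizes the ordering by sorting on the flattened key `row * N + col`, so edges end up grouped by source atom. A torch-free sketch of just that ordering step, with a hypothetical bond list:

```python
# Edge canonicalization sketch: duplicate each bond in both directions, then
# sort by the flat key row * N + col (mirrors the argsort in rdmol_to_data).
N = 4
bonds = [(2, 0), (0, 3), (1, 2)]  # hypothetical (begin, end) atom-index pairs

row, col = [], []
for start, end in bonds:
    row += [start, end]
    col += [end, start]

perm = sorted(range(len(row)), key=lambda k: row[k] * N + col[k])
edge_index = [(row[k], col[k]) for k in perm]
# edges now grouped by source atom: (0,2), (0,3), (1,2), (2,0), (2,1), (3,0)
```

Sorting on `row * N + col` is equivalent to lexicographic order on `(row, col)` because `col < N`, which gives downstream message-passing code a deterministic edge layout.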
--------------------------------------------------------------------------------
/finetune/splitters.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | from collections import defaultdict
5 | from itertools import compress
6 |
7 | import numpy as np
8 | import torch
9 | from rdkit.Chem.Scaffolds import MurckoScaffold
10 | # from sklearn.model_selection import StratifiedKFold
11 |
12 |
13 | def generate_scaffold(smiles, include_chirality=False):
14 | """Obtain Bemis-Murcko scaffold from smiles
15 | :return: smiles of scaffold"""
16 | scaffold = MurckoScaffold.MurckoScaffoldSmiles(
17 | smiles=smiles, includeChirality=include_chirality
18 | )
19 | return scaffold
20 |
21 |
22 | # # test generate_scaffold
23 | # s = 'Cc1cc(Oc2nccc(CCC)c2)ccc1'
24 | # scaffold = generate_scaffold(s)
25 | # assert scaffold == 'c1ccc(Oc2ccccn2)cc1'
26 |
27 |
28 | def scaffold_split(
29 | dataset,
30 | smiles_list,
31 | task_idx=None,
32 | null_value=0,
33 | frac_train=0.8,
34 | frac_valid=0.1,
35 | frac_test=0.1,
36 | return_smiles=False,
37 | ):
38 | """
39 | Adapted from https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py
40 | Split dataset by Bemis-Murcko scaffolds
41 | This function can also ignore examples containing null values for a
42 | selected task when splitting. Deterministic split
43 | :param dataset: pytorch geometric dataset obj
44 | :param smiles_list: list of smiles corresponding to the dataset obj
45 | :param task_idx: column idx of the data.y tensor. Will filter out
46 | examples with null value in specified task column of the data.y tensor
47 | prior to splitting. If None, then no filtering
48 | :param null_value: float that specifies null value in data.y to filter if
49 | task_idx is provided
50 | :param frac_train, frac_valid, frac_test: fractions
51 |     :param return_smiles: return SMILES if True
52 | :return: train, valid, test slices of the input dataset obj."""
53 | np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)
54 |
55 | if task_idx is not None:
56 | # filter based on null values in task_idx
57 | # get task array
58 | y_task = np.array([data.y[task_idx].item() for data in dataset])
59 | # boolean array that correspond to non null values
60 | non_null = y_task != null_value
61 | smiles_list = list(compress(enumerate(smiles_list), non_null))
62 | else:
63 | non_null = np.ones(len(dataset)) == 1
64 | smiles_list = list(compress(enumerate(smiles_list), non_null))
65 |
66 | # create dict of the form {scaffold_i: [idx1, idx....]}
67 | all_scaffolds = {}
68 | for i, smiles in smiles_list:
69 | scaffold = generate_scaffold(smiles, include_chirality=True)
70 | if scaffold not in all_scaffolds:
71 | all_scaffolds[scaffold] = [i]
72 | else:
73 | all_scaffolds[scaffold].append(i)
74 |
75 | # sort from largest to smallest sets
76 | all_scaffolds = {key: sorted(value) for key, value in all_scaffolds.items()}
77 | all_scaffold_sets = [
78 | scaffold_set
79 | for (scaffold, scaffold_set) in sorted(
80 | all_scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True
81 | )
82 | ]
83 |
84 | # get train, valid test indices
85 | train_cutoff = frac_train * len(smiles_list)
86 | valid_cutoff = (frac_train + frac_valid) * len(smiles_list)
87 | train_idx, valid_idx, test_idx = [], [], []
88 | for scaffold_set in all_scaffold_sets:
89 | if len(train_idx) + len(scaffold_set) > train_cutoff:
90 | if len(train_idx) + len(valid_idx) + len(scaffold_set) > valid_cutoff:
91 | test_idx.extend(scaffold_set)
92 | else:
93 | valid_idx.extend(scaffold_set)
94 | else:
95 | train_idx.extend(scaffold_set)
96 |
97 | assert len(set(train_idx).intersection(set(valid_idx))) == 0
98 | assert len(set(test_idx).intersection(set(valid_idx))) == 0
99 |
100 | train_dataset = dataset[torch.tensor(train_idx)]
101 | valid_dataset = dataset[torch.tensor(valid_idx)]
102 | test_dataset = dataset[torch.tensor(test_idx)]
103 |
104 | if not return_smiles:
105 | return train_dataset, valid_dataset, test_dataset
106 | else:
107 | train_smiles = [smiles_list[i][1] for i in train_idx]
108 | valid_smiles = [smiles_list[i][1] for i in valid_idx]
109 | test_smiles = [smiles_list[i][1] for i in test_idx]
110 | return (
111 | train_dataset,
112 | valid_dataset,
113 | test_dataset,
114 | (train_smiles, valid_smiles, test_smiles),
115 | )
116 |
117 |
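`scaffold_split` above sorts scaffold groups from largest to smallest and then fills the splits greedily: a whole group goes to train unless it would push past the train cutoff, then to valid unless it would push past the valid cutoff, else to test. A self-contained sketch of that fill order with toy index groups (no RDKit needed):

```python
# Greedy scaffold fill sketch: groups are assigned whole, largest first,
# so no scaffold is ever shared between train/valid/test.
all_scaffold_sets = [[0, 1, 2, 3], [4, 5, 6], [7, 8], [9]]  # hypothetical groups
n = 10
train_cutoff = 0.8 * n
valid_cutoff = 0.9 * n

train_idx, valid_idx, test_idx = [], [], []
for scaffold_set in all_scaffold_sets:
    if len(train_idx) + len(scaffold_set) > train_cutoff:
        if len(train_idx) + len(valid_idx) + len(scaffold_set) > valid_cutoff:
            test_idx.extend(scaffold_set)
        else:
            valid_idx.extend(scaffold_set)
    else:
        train_idx.extend(scaffold_set)
```

Note the consequence: because groups are indivisible, the realized fractions only approximate `frac_train`/`frac_valid`/`frac_test`, and a later small group can still land in train after an earlier large group overflowed into valid.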
118 | def random_scaffold_split(
119 | dataset,
120 | smiles_list,
121 | task_idx=None,
122 | null_value=0,
123 | frac_train=0.8,
124 | frac_valid=0.1,
125 | frac_test=0.1,
126 | seed=0,
127 | ):
128 | """
129 | Adapted from https://github.com/pfnet-research/chainer-chemistry/blob/master/
130 | chainer_chemistry/dataset/splitters/scaffold_splitter.py
131 | Split dataset by Bemis-Murcko scaffolds
132 | This function can also ignore examples containing null values for a
133 |     selected task when splitting. Random split, reproducible given the seed.
134 | :param dataset: pytorch geometric dataset obj
135 | :param smiles_list: list of smiles corresponding to the dataset obj
136 | :param task_idx: column idx of the data.y tensor. Will filter out
137 | examples with null value in specified task column of the data.y tensor
138 | prior to splitting. If None, then no filtering
139 | :param null_value: float that specifies null value in data.y to filter if
140 | task_idx is provided
141 | :param frac_train, frac_valid, frac_test: fractions, floats
142 | :param seed: seed
143 | :return: train, valid, test slices of the input dataset obj"""
144 |
145 | np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)
146 |
147 | if task_idx is not None:
148 | # filter based on null values in task_idx get task array
149 | y_task = np.array([data.y[task_idx].item() for data in dataset])
150 | # boolean array that correspond to non null values
151 | non_null = y_task != null_value
152 | smiles_list = list(compress(enumerate(smiles_list), non_null))
153 | else:
154 | non_null = np.ones(len(dataset)) == 1
155 | smiles_list = list(compress(enumerate(smiles_list), non_null))
156 |
157 | rng = np.random.RandomState(seed)
158 |
159 | scaffolds = defaultdict(list)
160 | for ind, smiles in smiles_list:
161 | scaffold = generate_scaffold(smiles, include_chirality=True)
162 | scaffolds[scaffold].append(ind)
163 |
164 | scaffold_sets = rng.permutation(list(scaffolds.values()))
165 |
166 | n_total_valid = int(np.floor(frac_valid * len(dataset)))
167 | n_total_test = int(np.floor(frac_test * len(dataset)))
168 |
169 | train_idx = []
170 | valid_idx = []
171 | test_idx = []
172 |
173 | for scaffold_set in scaffold_sets:
174 | if len(valid_idx) + len(scaffold_set) <= n_total_valid:
175 | valid_idx.extend(scaffold_set)
176 | elif len(test_idx) + len(scaffold_set) <= n_total_test:
177 | test_idx.extend(scaffold_set)
178 | else:
179 | train_idx.extend(scaffold_set)
180 |
181 | train_dataset = dataset[torch.tensor(train_idx)]
182 | valid_dataset = dataset[torch.tensor(valid_idx)]
183 | test_dataset = dataset[torch.tensor(test_idx)]
184 |
185 | return train_dataset, valid_dataset, test_dataset
186 |
187 |
188 | def random_split(
189 | dataset,
190 | task_idx=None,
191 | null_value=0,
192 | frac_train=0.8,
193 | frac_valid=0.1,
194 | frac_test=0.1,
195 | seed=0,
196 | smiles_list=None,
197 | ):
198 | """
199 | :return: train, valid, test slices of the input dataset obj. If
200 | smiles_list != None, also returns ([train_smiles_list],
201 | [valid_smiles_list], [test_smiles_list])"""
202 |
203 | np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)
204 |
205 | if task_idx is not None:
206 | # filter based on null values in task_idx
207 | # get task array
208 | y_task = np.array([data.y[task_idx].item() for data in dataset])
209 | non_null = (
210 | y_task != null_value
211 | ) # boolean array that correspond to non null values
212 | idx_array = np.where(non_null)[0]
213 | dataset = dataset[torch.tensor(idx_array)] # examples containing non
214 | # null labels in the specified task_idx
215 | else:
216 | pass
217 |
218 | num_mols = len(dataset)
219 | random.seed(seed)
220 | print("using seed\t", seed)
221 | all_idx = list(range(num_mols))
222 | random.shuffle(all_idx)
223 |
224 | train_idx = all_idx[: int(frac_train * num_mols)]
225 | valid_idx = all_idx[
226 | int(frac_train * num_mols) : int(frac_valid * num_mols)
227 | + int(frac_train * num_mols)
228 | ]
229 | test_idx = all_idx[int(frac_valid * num_mols) + int(frac_train * num_mols) :]
230 |
231 | assert len(set(train_idx).intersection(set(valid_idx))) == 0
232 | assert len(set(valid_idx).intersection(set(test_idx))) == 0
233 | assert len(train_idx) + len(valid_idx) + len(test_idx) == num_mols
234 |
235 | train_dataset = dataset[torch.tensor(train_idx)]
236 | valid_dataset = dataset[torch.tensor(valid_idx)]
237 | test_dataset = dataset[torch.tensor(test_idx)]
238 |
239 | if not smiles_list:
240 | return train_dataset, valid_dataset, test_dataset
241 | else:
242 | train_smiles = [smiles_list[i] for i in train_idx]
243 | valid_smiles = [smiles_list[i] for i in valid_idx]
244 | test_smiles = [smiles_list[i] for i in test_idx]
245 | return (
246 | train_dataset,
247 | valid_dataset,
248 | test_dataset,
249 | (train_smiles, valid_smiles, test_smiles),
250 | )
251 |
252 |
253 | def qm9_random_customized_01(
254 | dataset, task_idx=None, null_value=0, seed=0, smiles_list=None
255 | ):
256 | if task_idx is not None:
257 | # filter based on null values in task_idx
258 | # get task array
259 | y_task = np.array([data.y[task_idx].item() for data in dataset])
260 | non_null = (
261 | y_task != null_value
262 | ) # boolean array that correspond to non null values
263 | idx_array = np.where(non_null)[0]
264 | dataset = dataset[torch.tensor(idx_array)] # examples containing non
265 | # null labels in the specified task_idx
266 | else:
267 | pass
268 |
269 | num_mols = len(dataset)
270 | np.random.seed(seed)
271 | all_idx = np.random.permutation(num_mols)
272 |
273 | Nmols = 133885 - 3054
274 | Ntrain = 110000
275 | Nvalid = 10000
276 | Ntest = Nmols - (Ntrain + Nvalid)
277 |
278 | train_idx = all_idx[:Ntrain]
279 | valid_idx = all_idx[Ntrain : Ntrain + Nvalid]
280 | test_idx = all_idx[Ntrain + Nvalid :]
281 |
282 | print("train_idx: ", train_idx)
283 | print("valid_idx: ", valid_idx)
284 | print("test_idx: ", test_idx)
285 | # np.savez("customized_01", train_idx=train_idx, valid_idx=valid_idx, test_idx=test_idx)
286 |
287 | assert len(set(train_idx).intersection(set(valid_idx))) == 0
288 | assert len(set(valid_idx).intersection(set(test_idx))) == 0
289 | assert len(train_idx) + len(valid_idx) + len(test_idx) == num_mols
290 |
291 | train_dataset = dataset[torch.tensor(train_idx)]
292 | valid_dataset = dataset[torch.tensor(valid_idx)]
293 | test_dataset = dataset[torch.tensor(test_idx)]
294 |
295 | if not smiles_list:
296 | return train_dataset, valid_dataset, test_dataset
297 | else:
298 | train_smiles = [smiles_list[i] for i in train_idx]
299 | valid_smiles = [smiles_list[i] for i in valid_idx]
300 | test_smiles = [smiles_list[i] for i in test_idx]
301 | return (
302 | train_dataset,
303 | valid_dataset,
304 | test_dataset,
305 | (train_smiles, valid_smiles, test_smiles),
306 | )
307 |
308 |
309 | def qm9_random_customized_02(
310 | dataset, task_idx=None, null_value=0, seed=0, smiles_list=None
311 | ):
312 | if task_idx is not None:
313 | # filter based on null values in task_idx
314 | # get task array
315 | y_task = np.array([data.y[task_idx].item() for data in dataset])
316 | non_null = (
317 | y_task != null_value
318 | ) # boolean array that correspond to non null values
319 | idx_array = np.where(non_null)[0]
320 | dataset = dataset[torch.tensor(idx_array)] # examples containing non
321 | # null labels in the specified task_idx
322 | else:
323 | pass
324 |
325 | num_mols = len(dataset)
326 | np.random.seed(seed)
327 | print("using seed\t", seed)
328 | all_idx = np.random.permutation(num_mols)
329 |
330 | Nmols = 133885 - 3054
331 | Ntrain = 100000
332 | Ntest = int(0.1 * Nmols)
333 | Nvalid = Nmols - (Ntrain + Ntest)
334 |
335 | train_idx = all_idx[:Ntrain]
336 | valid_idx = all_idx[Ntrain : Ntrain + Nvalid]
337 | test_idx = all_idx[Ntrain + Nvalid :]
338 |
339 | assert len(set(train_idx).intersection(set(valid_idx))) == 0
340 | assert len(set(valid_idx).intersection(set(test_idx))) == 0
341 | assert len(train_idx) + len(valid_idx) + len(test_idx) == num_mols
342 |
343 | train_dataset = dataset[torch.tensor(train_idx)]
344 | valid_dataset = dataset[torch.tensor(valid_idx)]
345 | test_dataset = dataset[torch.tensor(test_idx)]
346 |
347 | if not smiles_list:
348 | return train_dataset, valid_dataset, test_dataset
349 | else:
350 | train_smiles = [smiles_list[i] for i in train_idx]
351 | valid_smiles = [smiles_list[i] for i in valid_idx]
352 | test_smiles = [smiles_list[i] for i in test_idx]
353 | return (
354 | train_dataset,
355 | valid_dataset,
356 | test_dataset,
357 | (train_smiles, valid_smiles, test_smiles),
358 | )
359 |
360 |
361 | def atom3d_lba_split(dataset, data_root, year):
362 | json_file = os.path.join(data_root, 'processed', 'pdb_id2data_id_{}.json'.format(year))
363 | f = open(json_file, 'r')
364 | pdb_id2data_id = json.load(f)
365 |
366 | def load_pdb_id_list_from_file(file):
367 | pdb_id_list = []
368 | with open(file, 'r') as f:
369 | lines = f.readlines()
370 | for line in lines:
371 | pdb_id_list.append(line.strip())
372 | return pdb_id_list
373 |
374 | def load_data_id(mode):
375 | pdb_id_file = os.path.join(data_root, 'processed', 'targets', '{}.txt'.format(mode))
376 | pdb_id_list = load_pdb_id_list_from_file(pdb_id_file)
377 | data_id_list = [pdb_id2data_id[pdb_id] for pdb_id in pdb_id_list]
378 | return data_id_list
379 |
380 | train_idx = load_data_id('train')
381 | valid_idx = load_data_id('val')
382 | test_idx = load_data_id('test')
383 |
384 | train_dataset = dataset[torch.tensor(train_idx)]
385 | valid_dataset = dataset[torch.tensor(valid_idx)]
386 | test_dataset = dataset[torch.tensor(test_idx)]
387 |
388 | return train_dataset, valid_dataset, test_dataset
389 |
390 |
--------------------------------------------------------------------------------
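For reference, the seeded shuffle-and-slice used by the random splitters above can be sketched standalone. This is a minimal pure-Python sketch; `random_split_indices` and its default fractions are illustrative, not names from the repo:

```python
import random

def random_split_indices(num_mols, frac_train=0.8, frac_valid=0.1, seed=42):
    """Shuffle all indices with a fixed seed, then cut three contiguous slices."""
    all_idx = list(range(num_mols))
    random.seed(seed)
    random.shuffle(all_idx)
    n_train = int(frac_train * num_mols)
    n_valid = int(frac_valid * num_mols)
    train_idx = all_idx[:n_train]
    valid_idx = all_idx[n_train:n_train + n_valid]
    test_idx = all_idx[n_train + n_valid:]
    # The three slices are disjoint and cover every molecule exactly once.
    return train_idx, valid_idx, test_idx

train, valid, test = random_split_indices(1000)
```

Because the slices are taken from one shuffled permutation, the disjointness and coverage asserted in the splitters above hold by construction.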
/models/common.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch_geometric.nn import radius_graph, radius
6 | from torch_scatter import scatter_mean, scatter_add, scatter_max
7 | from torch_sparse import coalesce
8 | from torch_geometric.utils import to_dense_adj, dense_to_sparse
9 |
10 | from utils.chem import BOND_TYPES
11 |
12 |
13 | class MeanReadout(nn.Module):
14 | """Mean readout operator over graphs with variadic sizes."""
15 |
16 | def forward(self, data, input):
17 | """
18 | Perform readout over the graph(s).
19 | Parameters:
20 | data (torch_geometric.data.Data): batched graph
21 | input (Tensor): node representations
22 | Returns:
23 | Tensor: graph representations
24 | """
25 | output = scatter_mean(input, data.batch, dim=0, dim_size=data.num_graphs)
26 | return output
27 |
28 |
29 | class SumReadout(nn.Module):
30 | """Sum readout operator over graphs with variadic sizes."""
31 |
32 | def forward(self, data, input):
33 | """
34 | Perform readout over the graph(s).
35 | Parameters:
36 | data (torch_geometric.data.Data): batched graph
37 | input (Tensor): node representations
38 | Returns:
39 | Tensor: graph representations
40 | """
41 | output = scatter_add(input, data.batch, dim=0, dim_size=data.num_graphs)
42 | return output
43 |
44 |
45 |
46 | class MultiLayerPerceptron(nn.Module):
47 | """
48 | Multi-layer Perceptron.
49 | Note there is no activation or dropout in the last layer.
50 | Parameters:
51 | input_dim (int): input dimension
52 | hidden_dims (list of int): hidden dimensions
53 | activation (str or function, optional): activation function
54 | dropout (float, optional): dropout rate
55 | """
56 |
57 | def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0):
58 | super(MultiLayerPerceptron, self).__init__()
59 |
60 | self.dims = [input_dim] + hidden_dims
61 | if isinstance(activation, str):
62 | self.activation = getattr(F, activation)
63 | else:
64 | self.activation = activation
65 | if dropout:
66 | self.dropout = nn.Dropout(dropout)
67 | else:
68 | self.dropout = None
69 |
70 | self.layers = nn.ModuleList()
71 | for i in range(len(self.dims) - 1):
72 | self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))
73 |
74 | def forward(self, input):
75 | """"""
76 | x = input
77 | for i, layer in enumerate(self.layers):
78 | x = layer(x)
79 | if i < len(self.layers) - 1:
80 | if self.activation:
81 | x = self.activation(x)
82 | if self.dropout:
83 | x = self.dropout(x)
84 | return x
85 |
86 |
87 | def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):
88 | h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]
89 | h_pair = torch.cat([h_row*h_col, edge_attr], dim=-1) # (E, 2H)
90 | return h_pair
91 |
92 |
93 | def generate_symmetric_edge_noise(num_nodes_per_graph, edge_index, edge2graph, device):
94 | num_cum_nodes = num_nodes_per_graph.cumsum(0) # (G, )
95 | node_offset = num_cum_nodes - num_nodes_per_graph # (G, )
96 | edge_offset = node_offset[edge2graph] # (E, )
97 |
98 | num_nodes_square = num_nodes_per_graph**2 # (G, )
99 | num_nodes_square_cumsum = num_nodes_square.cumsum(-1) # (G, )
100 | edge_start = num_nodes_square_cumsum - num_nodes_square # (G, )
101 | edge_start = edge_start[edge2graph]
102 |
103 | all_len = num_nodes_square_cumsum[-1]
104 |
105 | node_index = edge_index.t() - edge_offset.unsqueeze(-1)
106 | node_large = node_index.max(dim=-1)[0]
107 | node_small = node_index.min(dim=-1)[0]
108 | undirected_edge_id = node_large * (node_large + 1) + node_small + edge_start
109 |
110 | symm_noise = torch.zeros(size=[all_len.item()], device=device)
111 | symm_noise.normal_()
112 | d_noise = symm_noise[undirected_edge_id].unsqueeze(-1) # (E, 1)
113 | return d_noise
114 |
115 |
116 | def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):
117 | """
118 | Args:
119 | num_nodes: Number of atoms.
120 | edge_index: Bond indices of the original graph.
121 | edge_type: Bond types of the original graph.
122 | order: Extension order.
123 | Returns:
124 | new_edge_index: Extended edge indices.
125 | new_edge_type: Extended edge types.
126 | """
127 |
128 | def binarize(x):
129 | return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))
130 |
131 | def get_higher_order_adj_matrix(adj, order):
132 | """
133 | Args:
134 | adj: (N, N)
135 | type_mat: (N, N)
136 | Returns:
137 | Following attributes will be updated:
138 | - edge_index
139 | - edge_type
140 | Following attributes will be added to the data object:
141 | - bond_edge_index: Original edge_index.
142 | """
143 | adj_mats = [torch.eye(adj.size(0), dtype=torch.long, device=adj.device), \
144 | binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device))]
145 |
146 | for i in range(2, order+1):
147 | adj_mats.append(binarize(adj_mats[i-1] @ adj_mats[1]))
148 | order_mat = torch.zeros_like(adj)
149 |
150 | for i in range(1, order+1):
151 | order_mat = order_mat + (adj_mats[i] - adj_mats[i-1]) * i
152 |
153 | return order_mat
154 |
155 | num_types = len(BOND_TYPES)
156 |
157 | N = num_nodes
158 | adj = to_dense_adj(edge_index).squeeze(0)
159 | adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)
160 |
161 | type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)
162 | type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))
163 | assert (type_mat * type_highorder == 0).all()
164 | type_new = type_mat + type_highorder
165 |
166 | new_edge_index, new_edge_type = dense_to_sparse(type_new)
167 | _, edge_order = dense_to_sparse(adj_order)
168 |
169 | # data.bond_edge_index = data.edge_index # Save original edges
170 | new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data
171 |
172 | # [Note] This is not necessary
173 | # data.is_bond = (data.edge_type < num_types)
174 |
175 | # [Note] In earlier versions, `edge_order` attribute will be added.
176 | # However, it doesn't seem to be necessary anymore so I removed it.
177 | # edge_index_1, data.edge_order = coalesce(new_edge_index, edge_order.long(), N, N) # modify data
178 | # assert (data.edge_index == edge_index_1).all()
179 |
180 | return new_edge_index, new_edge_type
181 |
182 |
183 | def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):
184 |
185 | assert edge_type.dim() == 1
186 | N = pos.size(0)
187 |
188 | bgraph_adj = torch.sparse.LongTensor(
189 | edge_index,
190 | edge_type,
191 | torch.Size([N, N])
192 | )
193 |
194 | if is_sidechain is None:
195 | rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)
196 | else:
197 | # fetch sidechain and its batch index
198 | is_sidechain = is_sidechain.bool()
199 | dummy_index = torch.arange(pos.size(0), device=pos.device)
200 | sidechain_pos = pos[is_sidechain]
201 | sidechain_index = dummy_index[is_sidechain]
202 | sidechain_batch = batch[is_sidechain]
203 |
204 | assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)
205 | r_edge_index_x = assign_index[1]
206 | r_edge_index_y = assign_index[0]
207 | r_edge_index_y = sidechain_index[r_edge_index_y]
208 |
209 | rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)
210 | rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)
211 | rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)
212 | # delete self loop
213 | rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]
214 |
215 | rgraph_adj = torch.sparse.LongTensor(
216 | rgraph_edge_index,
217 | torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,
218 | torch.Size([N, N])
219 | )
220 |
221 | composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)
222 | # edge_index = composed_adj.indices()
223 | # dist = (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)
224 |
225 | new_edge_index = composed_adj.indices()
226 | new_edge_type = composed_adj.values().long()
227 |
228 | return new_edge_index, new_edge_type
229 |
230 |
231 | def extend_graph_order_radius(num_nodes, pos, edge_index, edge_type, batch, order=3, cutoff=10.0,
232 | extend_order=True, extend_radius=True, is_sidechain=None):
233 |
234 | if extend_order:
235 | edge_index, edge_type = _extend_graph_order(
236 | num_nodes=num_nodes,
237 | edge_index=edge_index,
238 | edge_type=edge_type, order=order
239 | )
240 | # edge_index_order = edge_index
241 | # edge_type_order = edge_type
242 |
243 | if extend_radius:
244 | edge_index, edge_type = _extend_to_radius_graph(
245 | pos=pos,
246 | edge_index=edge_index,
247 | edge_type=edge_type,
248 | cutoff=cutoff,
249 | batch=batch,
250 | is_sidechain=is_sidechain
251 |
252 | )
253 |
254 | return edge_index, edge_type
255 |
256 |
257 | def coarse_grain(pos, node_attr, subgraph_index, batch):
258 | cluster_pos = scatter_mean(pos, index=subgraph_index, dim=0) # (num_clusters, 3)
259 | cluster_attr = scatter_add(node_attr, index=subgraph_index, dim=0) # (num_clusters, H)
260 | cluster_batch, _ = scatter_max(batch, index=subgraph_index, dim=0) # (num_clusters, )
261 |
262 | return cluster_pos, cluster_attr, cluster_batch
263 |
264 |
265 | def batch_to_natoms(batch):
266 | return scatter_add(torch.ones_like(batch), index=batch, dim=0)
267 |
268 |
269 | def get_complete_graph(natoms):
270 | """
271 | Args:
272 | natoms: Number of nodes per graph, (B, ).
273 | Returns:
274 | edge_index: (2, E) with E = sum_i (N_i^2 - N_i), where N_i is the number of nodes of the i-th graph (complete directed graph per molecule, without self-loops).
275 | num_edges: (B, ), number of edges per graph.
276 | """
277 | natoms_sqr = (natoms ** 2).long()
278 | num_atom_pairs = torch.sum(natoms_sqr)
279 | natoms_expand = torch.repeat_interleave(natoms, natoms_sqr)
280 |
281 | index_offset = torch.cumsum(natoms, dim=0) - natoms
282 | index_offset_expand = torch.repeat_interleave(index_offset, natoms_sqr)
283 |
284 | index_sqr_offset = torch.cumsum(natoms_sqr, dim=0) - natoms_sqr
285 | index_sqr_offset = torch.repeat_interleave(index_sqr_offset, natoms_sqr)
286 |
287 | atom_count_sqr = torch.arange(num_atom_pairs, device=num_atom_pairs.device) - index_sqr_offset
288 |
289 | index1 = (atom_count_sqr // natoms_expand).long() + index_offset_expand
290 | index2 = (atom_count_sqr % natoms_expand).long() + index_offset_expand
291 | edge_index = torch.cat([index1.view(1, -1), index2.view(1, -1)])
292 | mask = torch.logical_not(index1 == index2)
293 | edge_index = edge_index[:, mask]
294 |
295 | num_edges = natoms_sqr - natoms # Number of edges per graph
296 |
297 | return edge_index, num_edges
298 |
--------------------------------------------------------------------------------
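The `binarize`/matrix-power trick inside `_extend_graph_order` can be sketched in NumPy. This is a minimal standalone version; `higher_order_adj` is an illustrative name, and the path graph is just test input:

```python
import numpy as np

def binarize(x):
    return (x > 0).astype(np.int64)

def higher_order_adj(adj, order):
    # adj_mats[k] marks nodes reachable within k hops (self included at k=0).
    n = adj.shape[0]
    eye = np.eye(n, dtype=np.int64)
    adj_mats = [eye, binarize(adj + eye)]
    for _ in range(2, order + 1):
        adj_mats.append(binarize(adj_mats[-1] @ adj_mats[1]))
    # Differencing consecutive reachability masks labels each pair (i, j)
    # with its smallest hop count k, for 1 <= k <= order.
    order_mat = np.zeros_like(adj)
    for k in range(1, order + 1):
        order_mat += (adj_mats[k] - adj_mats[k - 1]) * k
    return order_mat

# Path graph 0-1-2-3: 2-hop and 3-hop neighbors get labels 2 and 3.
adj = np.zeros((4, 4), dtype=np.int64)
for i, j in [(0, 1), (1, 2), (2, 3)]:
    adj[i, j] = adj[j, i] = 1
om = higher_order_adj(adj, order=3)
```

In `_extend_graph_order`, entries with label > 1 are then mapped to new "virtual bond" types (`num_types + adj_order - 1`) and merged with the original bond types.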
/models/encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .schnet import *
2 | from .gin import *
3 | from .edge import *
4 | from .coarse import *
5 |
--------------------------------------------------------------------------------
/models/encoder/coarse.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Module
3 | from torch_scatter import scatter_add, scatter_mean, scatter_max
4 |
5 | from ..common import coarse_grain, batch_to_natoms, get_complete_graph
6 | from .schnet import SchNetEncoder, GaussianSmearing
7 |
8 |
9 | class CoarseGrainingEncoder(Module):
10 |
11 | def __init__(self, hidden_channels, num_filters, num_interactions, edge_channels, cutoff, smooth):
12 | super().__init__()
13 | self.encoder = SchNetEncoder(
14 | hidden_channels=hidden_channels,
15 | num_filters=num_filters,
16 | num_interactions=num_interactions,
17 | edge_channels=edge_channels,
18 | cutoff=cutoff,
19 | smooth=smooth,
20 | )
21 | self.distexp = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels)
22 |
23 |
24 | def forward(self, pos, node_attr, subgraph_index, batch, return_coarse=False):
25 | """
26 | Args:
27 | pos: (N, 3)
28 | node_attr: (N, H)
29 | subgraph_index: (N, )
30 | batch: (N, )
31 | """
32 | cluster_pos, cluster_attr, cluster_batch = coarse_grain(pos, node_attr, subgraph_index, batch)
33 |
34 | edge_index, _ = get_complete_graph(batch_to_natoms(cluster_batch))
35 | row, col = edge_index
36 | edge_length = torch.norm(cluster_pos[row] - cluster_pos[col], dim=1, p=2)
37 | edge_attr = self.distexp(edge_length)
38 |
39 | h = self.encoder(
40 | z = cluster_attr,
41 | edge_index = edge_index,
42 | edge_length = edge_length,
43 | edge_attr = edge_attr,
44 | embed_node = False,
45 | )
46 |
47 | if return_coarse:
48 | return h, cluster_pos, cluster_attr, cluster_batch, edge_index
49 | else:
50 | return h
51 |
--------------------------------------------------------------------------------
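A plain-NumPy sketch of what `coarse_grain` computes with `scatter_mean`/`scatter_add`: cluster positions are the mean of member positions and cluster features the sum of member features. `coarse_grain_np` is an illustrative name, loop-based for clarity, with the batch handling omitted:

```python
import numpy as np

def coarse_grain_np(pos, node_attr, subgraph_index):
    # subgraph_index[i] gives the cluster id of node i.
    num_clusters = int(subgraph_index.max()) + 1
    cluster_pos = np.zeros((num_clusters, pos.shape[1]))
    cluster_attr = np.zeros((num_clusters, node_attr.shape[1]))
    counts = np.zeros(num_clusters)
    for i, c in enumerate(subgraph_index):
        cluster_pos[c] += pos[i]      # scatter_add, divided by counts below
        cluster_attr[c] += node_attr[i]  # scatter_add
        counts[c] += 1
    cluster_pos /= counts[:, None]    # -> scatter_mean
    return cluster_pos, cluster_attr

pos = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [1.0, 1.0, 0.0]])
attr = np.ones((3, 4))
subgraph_index = np.array([0, 0, 1])
cpos, cattr = coarse_grain_np(pos, attr, subgraph_index)
```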
/models/encoder/edge.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn import Module, Sequential, ModuleList, Linear, Embedding
4 | from torch_geometric.nn import MessagePassing, radius_graph
5 | from torch_sparse import coalesce
6 | from torch_geometric.data import Data
7 | from torch_geometric.utils import to_dense_adj, dense_to_sparse
8 | from math import pi as PI
9 |
10 | from utils.chem import BOND_TYPES
11 | from ..common import MeanReadout, SumReadout, MultiLayerPerceptron
12 |
13 |
14 | class GaussianSmearingEdgeEncoder(Module):
15 |
16 | def __init__(self, num_gaussians=64, cutoff=10.0):
17 | super().__init__()
18 | #self.NUM_BOND_TYPES = 22
19 | self.num_gaussians = num_gaussians
20 | self.cutoff = cutoff
21 | self.rbf = GaussianSmearing(start=0.0, stop=cutoff * 2, num_gaussians=num_gaussians) # Larger `stop` to encode more cases
22 | self.bond_emb = Embedding(100, embedding_dim=num_gaussians)
23 |
24 | @property
25 | def out_channels(self):
26 | return self.num_gaussians * 2
27 |
28 | def forward(self, edge_length, edge_type):
29 | """
30 | Input:
31 | edge_length: The length of edges, shape=(E, 1).
32 | edge_type: The type of edges, shape=(E,)
33 | Returns:
34 | edge_attr: The representation of edges. (E, 2 * num_gaussians)
35 | """
36 | edge_attr = torch.cat([self.rbf(edge_length), self.bond_emb(edge_type)], dim=1)
37 | return edge_attr
38 |
39 |
40 | class MLPEdgeEncoder(Module):
41 |
42 | def __init__(self, hidden_dim=100, activation='relu'):
43 | super().__init__()
44 | self.hidden_dim = hidden_dim
45 | self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)
46 | self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)
47 |
48 | @property
49 | def out_channels(self):
50 | return self.hidden_dim
51 |
52 | def forward(self, edge_length, edge_type):
53 | """
54 | Input:
55 | edge_length: The length of edges, shape=(E, 1).
56 | edge_type: The type of edges, shape=(E,)
57 | Returns:
58 | edge_attr: The representation of edges. (E, hidden_dim)
59 | """
60 | d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)
61 | edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)
62 | return d_emb * edge_attr # (num_edge, hidden)
63 |
64 |
65 | def get_edge_encoder(cfg):
66 | if cfg.edge_encoder == 'mlp':
67 | return MLPEdgeEncoder(cfg.hidden_dim, cfg.mlp_act)
68 | elif cfg.edge_encoder == 'gaussian':
69 | return GaussianSmearingEdgeEncoder(cfg.hidden_dim // 2, cutoff=cfg.cutoff)
70 | else:
71 | raise NotImplementedError('Unknown edge encoder: %s' % cfg.edge_encoder)
72 |
73 |
--------------------------------------------------------------------------------
/models/encoder/gin.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | from typing import Callable, Union
3 | from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size
4 |
5 | import torch
6 | from torch import Tensor
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch_sparse import SparseTensor, matmul
10 | from torch_geometric.nn.conv import MessagePassing
11 |
12 | from ..common import MeanReadout, SumReadout, MultiLayerPerceptron
13 |
14 |
15 | class GINEConv(MessagePassing):
16 |
17 | def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
18 | activation="softplus", **kwargs):
19 | super(GINEConv, self).__init__(aggr='add', **kwargs)
20 | self.nn = nn
21 | self.initial_eps = eps
22 |
23 | if isinstance(activation, str):
24 | self.activation = getattr(F, activation)
25 | else:
26 | self.activation = None
27 |
28 | if train_eps:
29 | self.eps = torch.nn.Parameter(torch.Tensor([eps]))
30 | else:
31 | self.register_buffer('eps', torch.Tensor([eps]))
32 |
33 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
34 | edge_attr: OptTensor = None, size: Size = None) -> Tensor:
35 | """"""
36 | if isinstance(x, Tensor):
37 | x: OptPairTensor = (x, x)
38 |
39 | # Node and edge feature dimensionalities need to match.
40 | if isinstance(edge_index, Tensor):
41 | assert edge_attr is not None
42 | assert x[0].size(-1) == edge_attr.size(-1)
43 | elif isinstance(edge_index, SparseTensor):
44 | assert x[0].size(-1) == edge_index.size(-1)
45 |
46 | # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
47 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
48 |
49 | x_r = x[1]
50 | if x_r is not None:
51 | out = out + (1 + self.eps) * x_r
52 |
53 | return self.nn(out)
54 |
55 | def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
56 | if self.activation:
57 | return self.activation(x_j + edge_attr)
58 | else:
59 | return x_j + edge_attr
60 |
61 | def __repr__(self):
62 | return '{}(nn={})'.format(self.__class__.__name__, self.nn)
63 |
64 |
65 | class GINEncoder(torch.nn.Module):
66 |
67 | def __init__(self, hidden_dim, num_convs=3, activation='relu', short_cut=True, concat_hidden=False):
68 | super().__init__()
69 |
70 | self.hidden_dim = hidden_dim
71 | self.num_convs = num_convs
72 | self.short_cut = short_cut
73 | self.concat_hidden = concat_hidden
74 | self.node_emb = nn.Embedding(100, hidden_dim)
75 |
76 | if isinstance(activation, str):
77 | self.activation = getattr(F, activation)
78 | else:
79 | self.activation = None
80 |
81 | self.convs = nn.ModuleList()
82 | for i in range(self.num_convs):
83 | self.convs.append(GINEConv(MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], \
84 | activation=activation), activation=activation))
85 |
86 |
87 |
88 | def forward(self, z, edge_index, edge_attr):
89 | """
90 | Input:
91 | data: (torch_geometric.data.Data): batched graph
92 | node_attr: node feature tensor with shape (num_node, hidden)
93 | edge_attr: edge feature tensor with shape (num_edge, hidden)
94 | Output:
95 | node_attr
96 | graph feature
97 | """
98 |
99 | node_attr = self.node_emb(z) # (num_node, hidden)
100 |
101 | hiddens = []
102 | conv_input = node_attr # (num_node, hidden)
103 |
104 | for conv_idx, conv in enumerate(self.convs):
105 | hidden = conv(conv_input, edge_index, edge_attr)
106 | if conv_idx < len(self.convs) - 1 and self.activation is not None:
107 | hidden = self.activation(hidden)
108 | assert hidden.shape == conv_input.shape
109 | if self.short_cut and hidden.shape == conv_input.shape:
110 | hidden = hidden + conv_input
111 |
112 | hiddens.append(hidden)
113 | conv_input = hidden
114 |
115 | if self.concat_hidden:
116 | node_feature = torch.cat(hiddens, dim=-1)
117 | else:
118 | node_feature = hiddens[-1]
119 |
120 | return node_feature
121 |
--------------------------------------------------------------------------------
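The GINEConv step above (message `act(x_j + edge_attr)`, update `aggr + (1 + eps) * x`) can be sketched in NumPy. This sketch uses softplus, the class default activation, and omits the trailing `self.nn` MLP:

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def gine_step(x, edge_index, edge_attr, eps=0.0):
    # Sum softplus(x_j + e_ji) over incoming edges of each target node,
    # then add the (1 + eps)-scaled central node feature.
    src, dst = edge_index
    out = np.zeros_like(x)
    for s, d, e in zip(src, dst, edge_attr):
        out[d] += softplus(x[s] + e)
    return out + (1.0 + eps) * x

# Two nodes with zero features and zero edge features, edges both ways:
# each node receives a single softplus(0) = log(2) message.
x = np.zeros((2, 4))
edge_index = (np.array([0, 1]), np.array([1, 0]))
edge_attr = np.zeros((2, 4))
out = gine_step(x, edge_index, edge_attr)
```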
/models/encoder/schnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.nn import Module, Sequential, ModuleList, Linear, Embedding
4 | from torch_geometric.nn import MessagePassing, radius_graph
5 | from torch_sparse import coalesce
6 | from torch_geometric.data import Data
7 | from torch_geometric.utils import to_dense_adj, dense_to_sparse
8 | from math import pi as PI
9 |
10 | from utils.chem import BOND_TYPES
11 | from ..common import MeanReadout, SumReadout, MultiLayerPerceptron
12 |
13 |
14 | class GaussianSmearing(torch.nn.Module):
15 | def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
16 | super(GaussianSmearing, self).__init__()
17 | offset = torch.linspace(start, stop, num_gaussians)
18 | self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
19 | self.register_buffer('offset', offset)
20 |
21 | def forward(self, dist):
22 | dist = dist.view(-1, 1) - self.offset.view(1, -1)
23 | return torch.exp(self.coeff * torch.pow(dist, 2))
24 |
25 |
26 | class AsymmetricSineCosineSmearing(Module):
27 |
28 | def __init__(self, num_basis=50):
29 | super().__init__()
30 | num_basis_k = num_basis // 2
31 | num_basis_l = num_basis - num_basis_k
32 | self.register_buffer('freq_k', torch.arange(1, num_basis_k + 1).float())
33 | self.register_buffer('freq_l', torch.arange(1, num_basis_l + 1).float())
34 |
35 | @property
36 | def num_basis(self):
37 | return self.freq_k.size(0) + self.freq_l.size(0)
38 |
39 | def forward(self, angle):
40 | # If we don't incorporate `cos`, the embedding of 0-deg and 180-deg will be the
41 | # same, which is undesirable.
42 | s = torch.sin(angle.view(-1, 1) * self.freq_k.view(1, -1)) # (num_angles, num_basis_k)
43 | c = torch.cos(angle.view(-1, 1) * self.freq_l.view(1, -1)) # (num_angles, num_basis_l)
44 | return torch.cat([s, c], dim=-1)
45 |
46 |
47 | class SymmetricCosineSmearing(Module):
48 |
49 | def __init__(self, num_basis=50):
50 | super().__init__()
51 | self.register_buffer('freq_k', torch.arange(1, num_basis+1).float())
52 |
53 | @property
54 | def num_basis(self):
55 | return self.freq_k.size(0)
56 |
57 | def forward(self, angle):
58 | return torch.cos(angle.view(-1, 1) * self.freq_k.view(1, -1)) # (num_angles, num_basis)
59 |
60 |
61 | class ShiftedSoftplus(torch.nn.Module):
62 | def __init__(self):
63 | super(ShiftedSoftplus, self).__init__()
64 | self.shift = torch.log(torch.tensor(2.0)).item()
65 |
66 | def forward(self, x):
67 | return F.softplus(x) - self.shift
68 |
69 |
70 | class CFConv(MessagePassing):
71 | def __init__(self, in_channels, out_channels, num_filters, nn, cutoff, smooth):
72 | super(CFConv, self).__init__(aggr='add')
73 | self.lin1 = Linear(in_channels, num_filters, bias=False)
74 | self.lin2 = Linear(num_filters, out_channels)
75 | self.nn = nn
76 | self.cutoff = cutoff
77 | self.smooth = smooth
78 |
79 | self.reset_parameters()
80 |
81 | def reset_parameters(self):
82 | torch.nn.init.xavier_uniform_(self.lin1.weight)
83 | torch.nn.init.xavier_uniform_(self.lin2.weight)
84 | self.lin2.bias.data.fill_(0)
85 |
86 | def forward(self, x, edge_index, edge_length, edge_attr):
87 | if self.smooth:
88 | C = 0.5 * (torch.cos(edge_length * PI / self.cutoff) + 1.0)
89 | C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff
90 | else:
91 | C = (edge_length <= self.cutoff).float()
92 | W = self.nn(edge_attr) * C.view(-1, 1)
93 |
94 | x = self.lin1(x)
95 | x = self.propagate(edge_index, x=x, W=W)
96 | x = self.lin2(x)
97 | return x
98 |
99 | def message(self, x_j, W):
100 | return x_j * W
101 |
102 |
103 | class InteractionBlock(torch.nn.Module):
104 | def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):
105 | super(InteractionBlock, self).__init__()
106 | mlp = Sequential(
107 | Linear(num_gaussians, num_filters),
108 | ShiftedSoftplus(),
109 | Linear(num_filters, num_filters),
110 | )
111 | self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)
112 | self.act = ShiftedSoftplus()
113 | self.lin = Linear(hidden_channels, hidden_channels)
114 |
115 | def forward(self, x, edge_index, edge_length, edge_attr):
116 | x = self.conv(x, edge_index, edge_length, edge_attr)
117 | x = self.act(x)
118 | x = self.lin(x)
119 | return x
120 |
121 |
122 | class SchNetEncoder(Module):
123 |
124 | def __init__(self, hidden_channels=128, num_filters=128,
125 | num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False):
126 | super().__init__()
127 |
128 | self.hidden_channels = hidden_channels
129 | self.num_filters = num_filters
130 | self.num_interactions = num_interactions
131 | self.cutoff = cutoff
132 |
133 | self.embedding = Embedding(100, hidden_channels, max_norm=10.0)
134 |
135 | self.interactions = ModuleList()
136 | for _ in range(num_interactions):
137 | block = InteractionBlock(hidden_channels, edge_channels,
138 | num_filters, cutoff, smooth)
139 | self.interactions.append(block)
140 |
141 | def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):
142 | if embed_node:
143 | assert z.dim() == 1 and z.dtype == torch.long
144 | h = self.embedding(z)
145 | else:
146 | h = z
147 | for interaction in self.interactions:
148 | h = h + interaction(h, edge_index, edge_length, edge_attr)
149 |
150 | return h
151 |
152 |
--------------------------------------------------------------------------------
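The `GaussianSmearing` radial basis expansion can be reproduced in a few lines of NumPy; this standalone sketch mirrors the module above, with the Gaussian width tied to the grid spacing exactly as in `__init__` and defaults matching the class:

```python
import numpy as np

def gaussian_smearing(dist, start=0.0, stop=5.0, num_gaussians=50):
    # Place num_gaussians centers on [start, stop]; each distance becomes a
    # vector of Gaussian activations, peaking at the nearest centers.
    offset = np.linspace(start, stop, num_gaussians)
    coeff = -0.5 / (offset[1] - offset[0]) ** 2
    d = dist.reshape(-1, 1) - offset.reshape(1, -1)
    return np.exp(coeff * d ** 2)

feat = gaussian_smearing(np.array([0.0, 2.5, 5.0]))
```

A distance sitting exactly on a center activates that basis function with value 1; all activations lie in (0, 1].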
/models/epsnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .dualenc import DualEncoderEpsNetwork
2 |
3 | def get_model(config):
4 | if config.network == 'dualenc':
5 | return DualEncoderEpsNetwork(config)
6 | else:
7 | raise NotImplementedError('Unknown network: %s' % config.network)
8 |
--------------------------------------------------------------------------------
/models/epsnet/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch_geometric.nn import MessagePassing
5 |
6 |
7 | from torch_geometric.nn import GATConv
8 |
9 | class GAT(torch.nn.Module):
10 | def __init__(self, in_channels, hidden_channels, out_channels, num_heads):
11 | super(GAT, self).__init__()
12 |
13 | self.conv1 = GATConv(in_channels, hidden_channels //2 // num_heads, heads=num_heads)
14 | # self.conv2 = GATConv(hidden_channels, hidden_channels //2 // num_heads, heads=num_heads)
15 | self.conv3 = GATConv(hidden_channels//2, out_channels, heads=1)
16 |
17 | def forward(self, x, edge_index):
18 | x = F.relu(self.conv1(x, edge_index))
19 | ## self.dropout
20 | # x = F.relu(self.conv2(x, edge_index))
21 | x = self.conv3(x, edge_index)
22 | return x
23 |
24 | class Pyg_SelfAttention(MessagePassing):
25 | def __init__(self, in_channels, out_channels, activation):
26 | super(Pyg_SelfAttention, self).__init__(aggr='add')
27 |
28 | self.activation = activation
29 | self.lin_query = torch.nn.Linear(in_channels, out_channels)
30 | self.lin_key = torch.nn.Linear(in_channels, out_channels)
31 | self.lin_value = torch.nn.Linear(in_channels, out_channels)
32 |
33 | def forward(self, node_presentation, edge_index):
34 | # Calculate query, key and value
35 | query = self.lin_query(node_presentation)
36 | key = self.lin_key(node_presentation)
37 | value = self.lin_value(node_presentation)
38 |
39 | # Calculate attention weights
40 | alpha = self.propagate(edge_index, x=(query, key), size=(node_presentation.size(0), node_presentation.size(0)))
41 | alpha = self.activation(alpha)
42 |
43 | # Apply attention weights to values
44 | attended_values = self.propagate(edge_index, x=(alpha, value), size=(node_presentation.size(0), node_presentation.size(0)))
45 |
46 | return attended_values
47 |
48 | def message(self, x_i, x_j):
49 | # Calculate attention weights
50 | alpha = torch.matmul(x_i, x_j.transpose(0, 1))
51 |
52 | return alpha
53 |
54 | def update(self, aggr_out):
55 | return aggr_out
56 |
57 |
58 |
59 | class SelfAttention(nn.Module):
60 | def __init__(self, in_dim):
61 | super(SelfAttention, self).__init__()
62 |
63 |         self.activation = nn.Softmax(dim=-1)
64 |
65 | self.query = nn.Linear(in_dim, in_dim, bias=False)
66 | self.key = nn.Linear(in_dim, in_dim, bias=False)
67 | self.value = nn.Linear(in_dim, in_dim, bias=False)
68 |
69 | def forward(self, node_presentation):
70 | query = self.query(node_presentation)
71 | key = self.key(node_presentation)
72 | value = self.value(node_presentation)
73 |
74 | # Calculate attention weights
75 | attention_weights = torch.matmul(query, key.transpose(-2, -1))
76 | attention_weights = self.activation(attention_weights)
77 |
78 | # Apply attention weights to values
79 | attended_values = torch.matmul(attention_weights, value)
80 | node = attended_values.sum(-1)
81 | return node
82 |
--------------------------------------------------------------------------------
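The dense `SelfAttention` at the end of the file is plain matrix arithmetic, so its shapes can be checked without building the module. A NumPy sketch of the same computation (random matrices stand in for the learned `query`/`key`/`value` linear layers, which is an assumption for illustration only):

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 5, 8                           # 5 nodes, 8-dim features
x = rng.standard_normal((N, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))

q, k, v = x @ Wq, x @ Wk, x @ Wv
scores = q @ k.T                                      # (N, N) pairwise scores
w = np.exp(scores - scores.max(axis=-1, keepdims=True))
w /= w.sum(axis=-1, keepdims=True)                    # row-wise softmax
attended = w @ v                                      # (N, d) attended values
node = attended.sum(-1)                               # (N,) per-node scalar, as in forward()

assert scores.shape == (N, N)
assert np.allclose(w.sum(-1), 1.0)
assert node.shape == (N,)
```

Note the final `.sum(-1)` collapses the feature dimension, so the module returns one scalar per node rather than a feature vector.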
/models/epsnet/diffusion.py:
--------------------------------------------------------------------------------
1 | import os, time
2 | import math
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def get_timestep_embedding(timesteps, embedding_dim):
8 |     """
9 |     Build sinusoidal timestep embeddings.
10 |     This matches the implementation in Denoising Diffusion Probabilistic
11 |     Models (tensor2tensor / Fairseq), but differs slightly from the
12 |     description in Section 3.5 of "Attention Is All You Need".
13 |     Returns a tensor of shape (len(timesteps), embedding_dim).
14 |     """
15 | assert len(timesteps.shape) == 1
16 |
17 | half_dim = embedding_dim // 2
18 | emb = math.log(10000) / (half_dim - 1)
19 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
20 | emb = emb.to(device=timesteps.device)
21 | emb = timesteps.float()[:, None] * emb[None, :]
22 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
23 | if embedding_dim % 2 == 1: # zero pad
24 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
25 | return emb
26 |
27 |
28 | def nonlinearity(x):
29 | # swish
30 | return x*torch.sigmoid(x)
31 |
32 |
33 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
34 | def sigmoid(x):
35 | return 1 / (np.exp(-x) + 1)
36 |
37 | if beta_schedule == "quad":
38 | betas = (
39 | np.linspace(
40 | beta_start ** 0.5,
41 | beta_end ** 0.5,
42 | num_diffusion_timesteps,
43 | dtype=np.float64,
44 | )
45 | ** 2
46 | )
47 | elif beta_schedule == "linear":
48 | betas = np.linspace(
49 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
50 | )
51 | elif beta_schedule == "const":
52 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
53 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
54 | betas = 1.0 / np.linspace(
55 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
56 | )
57 | elif beta_schedule == "sigmoid":
58 | betas = np.linspace(-6, 6, num_diffusion_timesteps)
59 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
60 | else:
61 | raise NotImplementedError(beta_schedule)
62 | assert betas.shape == (num_diffusion_timesteps,)
63 | return betas
--------------------------------------------------------------------------------
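`get_beta_schedule` above is pure NumPy, so its branches can be sanity-checked in isolation. A small sketch of the "linear" and "sigmoid" branches (the constants here are illustrative, not the repo's configured values):

```python
import numpy as np

def linear_betas(beta_start, beta_end, T):
    # Mirrors the "linear" branch of get_beta_schedule above.
    return np.linspace(beta_start, beta_end, T, dtype=np.float64)

def sigmoid_betas(beta_start, beta_end, T):
    # Mirrors the "sigmoid" branch: squash a [-6, 6] ramp into [start, end].
    x = np.linspace(-6, 6, T)
    return 1 / (1 + np.exp(-x)) * (beta_end - beta_start) + beta_start

betas = linear_betas(1e-4, 2e-2, 500)
alphas = np.cumprod(1.0 - betas)      # the cumulative \bar{alpha}_t used downstream

assert betas.shape == (500,)
assert np.all(np.diff(betas) > 0)     # noise level increases monotonically
assert 0.0 < alphas[-1] < alphas[0] <= 1.0

s = sigmoid_betas(1e-4, 2e-2, 500)
assert np.all((s >= 1e-4) & (s <= 2e-2))
```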
/models/epsnet/dualencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch_scatter import scatter_add, scatter_mean
4 | from torch_scatter import scatter
5 | from torch_geometric.data import Data, Batch
6 | import numpy as np
7 | from numpy import pi as PI
8 | from tqdm.auto import tqdm
9 |
10 | from utils.chem import BOND_TYPES
11 | from ..common import MultiLayerPerceptron, assemble_atom_pair_feature, generate_symmetric_edge_noise, extend_graph_order_radius
12 | from ..encoder import SchNetEncoder, GINEncoder, get_edge_encoder
13 | from ..geometry import get_distance, get_angle, get_dihedral, eq_transform
14 | from models.epsnet.attention import GAT, SelfAttention
15 | from models.epsnet.diffusion import get_timestep_embedding, get_beta_schedule
16 | import pdb
17 |
18 |
19 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
20 | def sigmoid(x):
21 | return 1 / (np.exp(-x) + 1)
22 |
23 | if beta_schedule == "quad":
24 | betas = (
25 | np.linspace(
26 | beta_start ** 0.5,
27 | beta_end ** 0.5,
28 | num_diffusion_timesteps,
29 | dtype=np.float64,
30 | )
31 | ** 2
32 | )
33 | elif beta_schedule == "linear":
34 | betas = np.linspace(
35 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
36 | )
37 | elif beta_schedule == "const":
38 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
39 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
40 | betas = 1.0 / np.linspace(
41 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
42 | )
43 | elif beta_schedule == "sigmoid":
44 | betas = np.linspace(-6, 6, num_diffusion_timesteps)
45 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
46 | else:
47 | raise NotImplementedError(beta_schedule)
48 | assert betas.shape == (num_diffusion_timesteps,)
49 | return betas
50 |
51 |
52 | class DualEncoderEpsNetwork(nn.Module):
53 |
54 | def __init__(self, config):
55 | super().__init__()
56 | self.config = config
57 | self.num_timesteps=config.num_diffusion_timesteps
58 | """
59 | edge_encoder: Takes both edge type and edge length as input and outputs a vector
60 | [Note]: node embedding is done in SchNetEncoder
61 | """
62 |         self.model_type = config.type  # 'diffusion' | 'dsm' | 'selected_diffusion' | 'subgraph_diffusion'
63 | self.edge_encoder_global = get_edge_encoder(config)
64 | self.edge_encoder_local = get_edge_encoder(config)
65 | self.is_emb_time = True
66 | if self.is_emb_time:
67 | '''
68 | timestep embedding
69 | '''
70 | self.hidden_dim = config.hidden_dim
71 | self.temb = nn.Module()
72 | self.temb.dense = nn.ModuleList([
73 | torch.nn.Linear(config.hidden_dim,
74 | config.hidden_dim*4),
75 | torch.nn.Linear(config.hidden_dim*4,
76 | config.hidden_dim*4),
77 | ])
78 | self.temb_proj = torch.nn.Linear(config.hidden_dim*4,
79 | config.hidden_dim)
80 | self.nonlinearity = nn.ReLU()
81 | self.is_emb_time = False
82 | """
83 | The graph neural network that extracts node-wise features.
84 | """
85 | self.encoder_global = SchNetEncoder(
86 | hidden_channels=config.hidden_dim,
87 | num_filters=config.hidden_dim,
88 | num_interactions=config.num_convs,
89 | edge_channels=self.edge_encoder_global.out_channels,
90 | cutoff=config.cutoff,
91 | smooth=config.smooth_conv,
92 | )
93 | self.encoder_local = GINEncoder(
94 | hidden_dim=config.hidden_dim,
95 | num_convs=config.num_convs_local,
96 | )
97 | self.grad_global_dist_mlp = MultiLayerPerceptron(
98 | 2 * config.hidden_dim,
99 | [config.hidden_dim, config.hidden_dim // 2, 1],
100 | activation=config.mlp_act
101 | )
102 |
103 | self.grad_local_dist_mlp = MultiLayerPerceptron(
104 | 2 * config.hidden_dim,
105 | [config.hidden_dim, config.hidden_dim // 2, 1],
106 | activation=config.mlp_act
107 | )
108 |
109 |         ''' header for masked vector prediction '''
110 | if self.model_type in ['selected_diffusion', 'subgraph_diffusion']:
111 | self.edge_encoder_mask = get_edge_encoder(config)
112 | self.encoder_mask = SchNetEncoder(
113 | hidden_channels=config.hidden_dim,
114 | num_filters=config.hidden_dim,
115 | num_interactions=config.num_convs,
116 | edge_channels=self.edge_encoder_global.out_channels,
117 | cutoff=config.cutoff,
118 | smooth=config.smooth_conv,
119 | )
120 | self.mask_pred = self.config.get("mask_pred", "MLP").upper()
121 | if self.mask_pred.upper()=="GAT":
122 | self.mask_predictor = GAT(in_channels=config.hidden_dim, hidden_channels=config.hidden_dim//2, out_channels=1, num_heads=2)
123 | elif self.mask_pred=='MLP':
124 | self.mask_predictor = MultiLayerPerceptron(
125 | config.hidden_dim,
126 | [config.hidden_dim, config.hidden_dim // 2, 1],
127 | activation=config.mlp_act
128 | )
129 | elif self.mask_pred=='2BMLP':
130 | self.mask_predictor = MultiLayerPerceptron(
131 | config.hidden_dim*2,
132 | [config.hidden_dim, config.hidden_dim // 2, 1],
133 | activation=config.mlp_act
134 | )
135 | else:
136 |             raise NotImplementedError('Unknown mask_pred: %s' % self.mask_pred)
137 | self.CEloss = nn.CrossEntropyLoss()
138 | self.BCEloss = nn.BCEWithLogitsLoss(reduction='none')
139 | else:
140 | self.register_parameter("mask_predictor", None)
141 |
142 | self.model_mask = nn.ModuleList([self.mask_predictor,self.temb,self.temb_proj])
143 |
144 | self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])
145 | '''
146 | Incorporate parameters together
147 | '''
148 |
149 | self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])
150 |
151 |
152 |
153 | if self.model_type in ['selected_diffusion','subgraph_diffusion','diffusion']:
154 | # denoising diffusion
155 | ## betas
156 | betas = get_beta_schedule(
157 | beta_schedule=config.beta_schedule,
158 | beta_start=config.beta_start,
159 | beta_end=config.beta_end,
160 | num_diffusion_timesteps=config.num_diffusion_timesteps,
161 | )
162 | betas = torch.from_numpy(betas).float()
163 | self.betas = nn.Parameter(betas, requires_grad=False)
164 | ## variances
165 | alphas = (1. - betas).cumprod(dim=0)
166 | self.alphas = nn.Parameter(alphas, requires_grad=False)
167 | self.num_timesteps = self.betas.size(0)
168 | elif self.model_type == 'dsm':
169 | # denoising score matching
170 | sigmas = torch.tensor(
171 | np.exp(np.linspace(np.log(config.sigma_begin), np.log(config.sigma_end),
172 | config.num_noise_level)), dtype=torch.float32)
173 | self.sigmas = nn.Parameter(sigmas, requires_grad=False) # (num_noise_level)
174 | self.num_timesteps = self.sigmas.size(0) # betas.shape[0]
175 |
176 |
177 |
178 | def forward(self, atom_type, pos, bond_index, bond_type, batch, time_step,
179 | edge_index=None, edge_type=None, edge_length=None, return_edges=False,
180 | extend_order=True, extend_radius=True, is_sidechain=None):
181 | """
182 | Args:
183 | atom_type: Types of atoms, (N, ).
184 | bond_index: Indices of bonds (not extended, not radius-graph), (2, E).
185 | bond_type: Bond types, (E, ).
186 | batch: Node index to graph index, (N, ).
187 | """
188 | N = atom_type.size(0)
189 | if edge_index is None or edge_type is None or edge_length is None:
190 | edge_index, edge_type = extend_graph_order_radius(
191 | num_nodes=N,
192 | pos=pos,
193 | edge_index=bond_index,
194 | edge_type=bond_type,
195 | batch=batch,
196 | order=self.config.edge_order,
197 | cutoff=self.config.cutoff,
198 | extend_order=extend_order,
199 | extend_radius=extend_radius,
200 | is_sidechain=is_sidechain,
201 | )
202 | edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)
203 | local_edge_mask = is_local_edge(edge_type) # (E, )
204 |
205 | # Emb time_step
206 | if self.model_type in ['selected_diffusion','diffusion',"subgraph_diffusion"]:
207 | # # timestep embedding
208 |
209 | if self.is_emb_time:
210 | temb = get_timestep_embedding(time_step, self.hidden_dim)
211 | temb = self.temb.dense[0](temb)
212 | # temb = self.nonlinearity(temb)
213 | # temb = self.temb.dense[1](temb)
214 | temb = self.temb_proj(self.nonlinearity(temb)) # (G, dim)
215 |
216 | sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)
217 |
218 | # Encoding global
219 | edge_attr_global = self.edge_encoder_global(
220 | edge_length=edge_length,
221 | edge_type=edge_type
222 | )
223 | # Embed edges
224 | # edge_attr += temb_edge
225 |
226 | # Global
227 | node_attr_global = self.encoder_global(
228 | z=atom_type,
229 | edge_index=edge_index,
230 | edge_length=edge_length,
231 | edge_attr=edge_attr_global,
232 | )
233 | if self.is_emb_time: node_attr_global = node_attr_global + 0.1*temb[batch]
234 | ## Assemble pairwise features
235 | # (h_i,h_j,e_ij)
236 | h_pair_global = assemble_atom_pair_feature(
237 | node_attr=node_attr_global,
238 | edge_index=edge_index,
239 | edge_attr=edge_attr_global,
240 | ) # (E_global, 2H)
241 | ## Invariant features of edges (radius graph, global)
242 | edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)
243 |
244 | # Encoding local
245 | edge_attr_local = self.edge_encoder_global(
246 | # edge_attr_local = self.edge_encoder_local(
247 | edge_length=edge_length,
248 | edge_type=edge_type
249 | ) # Embed edges
250 | # edge_attr += temb_edge
251 |
252 | # Local
253 | node_attr_local = self.encoder_local(
254 | z=atom_type,
255 | edge_index=edge_index[:, local_edge_mask],
256 | edge_attr=edge_attr_local[local_edge_mask],
257 | )
258 | if self.is_emb_time: node_attr_local = node_attr_local + 0.1*temb[batch]
259 | ## Assemble pairwise features
260 | h_pair_local = assemble_atom_pair_feature(
261 | node_attr=node_attr_local,
262 | edge_index=edge_index[:, local_edge_mask],
263 | edge_attr=edge_attr_local[local_edge_mask],
264 | ) # (E_local, 2H)
265 | ## Invariant features of edges (bond graph, local)
266 | if isinstance(sigma_edge, torch.Tensor):
267 | edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge[local_edge_mask]) # (E_local, 1)
268 | else:
269 | edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)
270 |
271 |
272 |
273 | self.node_attr_global = node_attr_global
274 | self.node_attr_local= node_attr_local
275 | self.batch = batch
276 |
277 | if return_edges:
278 | if self.model_type in ['selected_diffusion', "subgraph_diffusion"]:
279 | if self.mask_pred=="MLP":
280 | mask_emb=node_attr_global
281 | # mask_emb=torch.cat([node_attr_local,node_attr_global],dim=-1)
282 | node_mask_pred=self.mask_predictor(mask_emb)
283 | elif self.mask_pred=='2BMLP':
284 | mask_emb=torch.cat([node_attr_local,node_attr_global],dim=-1)
285 | node_mask_pred=self.mask_predictor(mask_emb)
286 | elif self.mask_pred.upper()=="GAT":
287 | mask_emb=node_attr_global
288 | node_mask_pred = self.mask_predictor(mask_emb,edge_index)
289 | return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask, node_mask_pred
290 | return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask
291 | else:
292 | return edge_inv_global, edge_inv_local
293 |
294 | def get_node_embeddings(self):
295 | self.node_embs = torch.cat([self.node_attr_local,self.node_attr_global],dim=-1)
296 |
297 | return self.node_embs
298 |     def get_global_node_embeddings(self):
299 |         return self.node_attr_global
302 |
303 |     def get_graph_embeddings(self):
304 |         self.node_embs = torch.cat([self.node_attr_local, self.node_attr_global], dim=-1)
305 |         self.graph_embs = scatter(self.node_embs, self.batch, dim=0, reduce='sum')  # reduce: "add" / "sum" / "mean"
306 |         return self.graph_embs
314 |
315 | import torch.nn.functional as F
316 | class ShiftedSoftplus(torch.nn.Module):
317 | def __init__(self):
318 | super(ShiftedSoftplus, self).__init__()
319 | self.shift = torch.log(torch.tensor(2.0)).item()
320 |
321 | def forward(self, x):
322 | return F.softplus(x) - self.shift
323 |
324 | class GraphPooling(torch.nn.Module):
325 | def __init__(self, hidden_channels, out_dim, readout="mean"):
326 | super(GraphPooling, self).__init__()
327 | self.lin1 = nn.Linear(hidden_channels, hidden_channels)
328 | self.act = ShiftedSoftplus()
329 | self.lin2 = nn.Linear(hidden_channels, hidden_channels)
330 | self.readout = readout
331 | self.output_layer = nn.Linear(hidden_channels, out_dim)
332 |
333 | def forward(self, x, batch):
334 | h = self.lin1(x)
335 | h = self.act(h)
336 | h = self.lin2(h)
337 |         graph_embs = scatter(h, batch, dim=0, reduce=self.readout)  # ["add", "sum", "mean"]
338 | graph_embs = self.output_layer(graph_embs)
339 | return graph_embs
340 |
341 | def is_bond(edge_type):
342 | return torch.logical_and(edge_type < len(BOND_TYPES), edge_type > 0)
343 |
344 |
345 | def is_angle_edge(edge_type):
346 | return edge_type == len(BOND_TYPES) + 1 - 1
347 |
348 |
349 | def is_dihedral_edge(edge_type):
350 | return edge_type == len(BOND_TYPES) + 2 - 1
351 |
352 |
353 | def is_radius_edge(edge_type):
354 | return edge_type == 0
355 |
356 |
357 | def is_local_edge(edge_type):
358 | return edge_type > 0
359 |
360 |
361 | def is_train_edge(edge_index, is_sidechain):
362 | if is_sidechain is None:
363 | return torch.ones(edge_index.size(1), device=edge_index.device).bool()
364 | else:
365 | is_sidechain = is_sidechain.bool()
366 | return torch.logical_or(is_sidechain[edge_index[0]], is_sidechain[edge_index[1]])
367 |
368 |
369 | def regularize_bond_length(edge_type, edge_length, rng=5.0):
370 | mask = is_bond(edge_type).float().reshape(-1, 1)
371 | d = -torch.clamp(edge_length - rng, min=0.0, max=float('inf')) * mask
372 | return d
373 |
374 |
375 | def center_pos(pos, batch):
376 | pos_center = pos - scatter_mean(pos, batch, dim=0)[batch]
377 | return pos_center
378 |
379 |
380 | def clip_norm(vec, limit, p=2):
381 |     norm = torch.norm(vec, dim=-1, p=p, keepdim=True)
382 | denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))
383 | return vec * denom
384 |
--------------------------------------------------------------------------------
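Two of the small helpers at the bottom of the file, `center_pos` and `clip_norm`, are easy to check numerically. A NumPy sketch of the same arithmetic (the toy coordinates and batch index here are made up for illustration):

```python
import numpy as np

pos = np.array([[0., 0., 0.], [2., 0., 0.], [10., 0., 0.], [14., 0., 0.]])
batch = np.array([0, 0, 1, 1])      # two graphs with two atoms each

# center_pos: subtract each graph's centroid from its own atoms.
centroids = np.stack([pos[batch == g].mean(0) for g in np.unique(batch)])
centered = pos - centroids[batch]
assert np.allclose(centered[:2].mean(0), 0.0)   # graph 0 is now zero-centered
assert np.allclose(centered[2:].mean(0), 0.0)   # graph 1 is now zero-centered

# clip_norm: rescale rows whose L2 norm exceeds the limit, leave others alone.
limit = 1.5
norm = np.linalg.norm(centered, axis=-1, keepdims=True)
denom = np.where(norm > limit, limit / norm, 1.0)
clipped = centered * denom
assert np.all(np.linalg.norm(clipped, axis=-1) <= limit + 1e-9)
```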
/models/geometry.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_scatter import scatter_add
3 |
4 |
5 | def get_distance(pos, edge_index):
6 | return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)
7 |
8 |
9 | def eq_transform(score_d, pos, edge_index, edge_length):
10 | N = pos.size(0)
11 | dd_dr = (1. / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3) # (r_i-r_j)/d_{ij}
12 | # https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/add.html
13 | score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) \
14 | + scatter_add(- dd_dr * score_d, edge_index[1], dim=0, dim_size=N) # (N, 3) (-dd_dr)= p_j-p_i; d_ij = -d_ji
15 | return score_pos
16 |
17 |
18 | def convert_cluster_score_d(cluster_score_d, cluster_pos, cluster_edge_index, cluster_edge_length, subgraph_index):
19 | """
20 | Args:
21 | cluster_score_d: (E_c, 1)
22 | subgraph_index: (N, )
23 | """
24 | cluster_score_pos = eq_transform(cluster_score_d, cluster_pos, cluster_edge_index, cluster_edge_length) # (C, 3)
25 | score_pos = cluster_score_pos[subgraph_index]
26 | return score_pos
27 |
28 |
29 | def get_angle(pos, angle_index):
30 | """
31 | Args:
32 | pos: (N, 3)
33 | angle_index: (3, A), left-center-right.
34 | """
35 | n1, ctr, n2 = angle_index # (A, )
36 | v1 = pos[n1] - pos[ctr] # (A, 3)
37 | v2 = pos[n2] - pos[ctr]
38 | inner_prod = torch.sum(v1 * v2, dim=-1, keepdim=True) # (A, 1)
39 | length_prod = torch.norm(v1, dim=-1, keepdim=True) * torch.norm(v2, dim=-1, keepdim=True) # (A, 1)
40 | angle = torch.acos(inner_prod / length_prod) # (A, 1)
41 | return angle
42 |
43 |
44 | def get_dihedral(pos, dihedral_index):
45 | """
46 | Args:
47 | pos: (N, 3)
48 | dihedral: (4, A)
49 | """
50 | n1, ctr1, ctr2, n2 = dihedral_index # (A, )
51 | v_ctr = pos[ctr2] - pos[ctr1] # (A, 3)
52 | v1 = pos[n1] - pos[ctr1]
53 | v2 = pos[n2] - pos[ctr2]
54 | n1 = torch.cross(v_ctr, v1, dim=-1) # Normal vectors of the two planes
55 | n2 = torch.cross(v_ctr, v2, dim=-1)
56 | inner_prod = torch.sum(n1 * n2, dim=1, keepdim=True) # (A, 1)
57 | length_prod = torch.norm(n1, dim=-1, keepdim=True) * torch.norm(n2, dim=-1, keepdim=True)
58 | dihedral = torch.acos(inner_prod / length_prod)
59 | return dihedral
60 |
61 |
62 |
--------------------------------------------------------------------------------
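`get_angle` above is the standard arccos-of-normalized-dot-product formula. A NumPy check of the same arithmetic on a known right angle:

```python
import numpy as np

pos = np.array([[1., 0., 0.],    # n1
                [0., 0., 0.],    # ctr
                [0., 1., 0.]])   # n2

# Same arithmetic as get_angle for a single left-center-right triple.
v1, v2 = pos[0] - pos[1], pos[2] - pos[1]
cos_theta = v1 @ v2 / (np.linalg.norm(v1) * np.linalg.norm(v2))
angle = np.arccos(cos_theta)

assert np.isclose(angle, np.pi / 2)   # 90 degrees, as expected
```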
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | # os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
3 | import argparse
4 | import pickle
5 | import yaml
6 | import torch
7 | from glob import glob
8 | from tqdm.auto import tqdm
9 | from easydict import EasyDict
10 |
11 | from models.epsnet import *
12 | from utils.datasets import *
13 | from utils.transforms import *
14 | from utils.misc import *
15 |
16 |
17 | def num_confs(num:str):
18 | if num.endswith('x'):
19 | return lambda x:x*int(num[:-1])
20 | elif int(num) > 0:
21 | return lambda x:int(num)
22 | else:
23 | raise ValueError()
24 |
25 |
26 | if __name__ == '__main__':
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument('--ckpt', type=str, help='path for loading the checkpoint')
29 | parser.add_argument('--config', type=str, default=None)
30 |     parser.add_argument('--save_traj', action='store_true', default=False,
31 |                     help='whether to store the whole sampling trajectory')
32 |     parser.add_argument('--not_use_mask', action='store_true', default=False,
33 |                     help='disable the subgraph mask during sampling')
34 | parser.add_argument('--resume', type=str, default=None)
35 | parser.add_argument('--tag', type=str, default=None)
36 | parser.add_argument('--num_confs', type=num_confs, default=num_confs('2x'))
37 | parser.add_argument('--test_set', type=str, default=None)
38 | parser.add_argument('--start_idx', type=int, default=0)
39 | parser.add_argument('--end_idx', type=int, default=200)
40 | parser.add_argument('--out_dir', type=str, default=None)
41 | parser.add_argument('--device', type=str, default='cuda:0')
42 | parser.add_argument('--clip', type=float, default=1000.0)
43 | parser.add_argument('--clip_local', type=float, default=20)
44 | parser.add_argument('--n_steps', type=int, default=5000,
45 | help='sampling num steps; for DSM framework, this means num steps for each noise scale')
46 | parser.add_argument('--global_start_sigma', type=float, default=0.5,
47 | help='enable global gradients only when noise is low')
48 | parser.add_argument('--w_global', type=float, default=0.3,
49 | help='weight for global gradients')
50 | # Parameters for DDPM
51 |     parser.add_argument('--sampling_type', type=str, default='ddpm_noisy',
52 |                     help='sampling method: generalized (DDIM), ddpm_noisy (DDPM), or ld (Langevin Dynamics)')
53 | parser.add_argument('--eta', type=float, default=1.0,
54 | help='weight for DDIM and DDPM: 0->DDIM, 1->DDPM')
55 | args = parser.parse_args()
56 |
57 | # Load checkpoint
58 | ckpt = torch.load(args.ckpt, map_location=torch.device('cpu'))
59 | if args.config is None:
60 | config_path = glob(os.path.join(os.path.dirname(os.path.dirname(args.ckpt)), '*.yml'))[0]
61 | else:
62 |         config_path = args.config
63 | with open(config_path, 'r') as f:
64 | config = EasyDict(yaml.safe_load(f))
65 | seed_all(config.train.seed)
66 | log_dir = os.path.dirname(os.path.dirname(args.ckpt))
67 |
68 | # Logging
69 | if args.tag is None:
70 | args.tag=os.path.basename(os.path.dirname(args.test_set))
71 | output_dir = get_new_log_dir(log_dir, 'sample', tag=args.tag)
72 | logger = get_logger('test', output_dir)
73 | # Datasets and loaders
74 | logger.info(args)
75 | logger.info('Loading datasets...')
76 | transforms = Compose([
77 | CountNodesPerGraph(),
78 | AddHigherOrderEdges(order=config.model.edge_order), # Offline edge augmentation
79 | ])
80 | if args.test_set is None:
81 | test_set = PackedConformationDataset(config.dataset.test, transform=transforms)
82 | logger.info(f'Loading {config.dataset.test}')
83 | else:
84 | test_set = PackedConformationDataset(args.test_set, transform=transforms)
85 | logger.info(f'Loading {args.test_set}')
86 | # Model
87 | logger.info('Loading model...')
88 | model = get_model(ckpt['config'].model).to(args.device)
89 | model.load_state_dict(ckpt['model'])
90 |
91 | test_set_selected = []
92 | for i, data in enumerate(test_set):
93 | if not (args.start_idx <= i < args.end_idx): continue
94 | if args.tag=="qm92drugs":
95 | atom_set = set(atom.GetSymbol() for atom in data.rdmol.GetAtoms())
96 | qm9_set={'H','C', 'N', 'O', 'F'}
97 | if atom_set <= qm9_set:
98 | test_set_selected.append(data)
99 | else:
100 | test_set_selected.append(data)
101 |
102 |
103 | done_smiles = set()
104 | results = []
105 | if args.resume is not None:
106 | with open(args.resume, 'rb') as f:
107 | results = pickle.load(f)
108 | for data in results:
109 | done_smiles.add(data.smiles)
110 |
111 | for i, data in enumerate(tqdm(test_set_selected)):
112 | if data.smiles in done_smiles:
113 | logger.info('Molecule#%d is already done.' % i)
114 | continue
115 |
116 | num_refs = data.pos_ref.size(0) // data.num_nodes
117 | num_samples = args.num_confs(num_refs)
118 |
119 | data_input = data.clone()
120 | data_input['pos_ref'] = None
121 | batch = repeat_data(data_input, num_samples).to(args.device)
122 |
123 | clip_local = None
124 | for try_n in range(3): # Maximum number of retry
125 | try:
126 | pos_init = torch.randn(batch.num_nodes, 3).to(args.device)
127 | pos_gen, pos_gen_traj = model.langevin_dynamics_sample_diffusion_subgraph(
128 | atom_type=batch.atom_type,
129 | pos_init=pos_init,
130 | bond_index=batch.edge_index,
131 | bond_type=batch.edge_type,
132 | batch=batch.batch,
133 | num_graphs=batch.num_graphs,
134 | extend_order=False, # Done in transforms.
135 | n_steps=args.n_steps,
136 | step_lr=1e-6,
137 | w_global=args.w_global,
138 | global_start_sigma=args.global_start_sigma,
139 | clip=args.clip,
140 | clip_local=clip_local,
141 | sampling_type=args.sampling_type,
142 | eta=args.eta,
143 | use_mask=not args.not_use_mask
144 | )
145 | pos_gen = pos_gen.cpu()
146 | if args.save_traj:
147 | data.pos_gen = torch.stack(pos_gen_traj)
148 | else:
149 | data.pos_gen = pos_gen
150 | results.append(data)
151 | done_smiles.add(data.smiles)
152 |
153 | save_path = os.path.join(output_dir, 'samples_%d.pkl' % i)
154 | logger.info('Saving samples to: %s' % save_path)
155 | with open(save_path, 'wb') as f:
156 | pickle.dump(results, f)
157 |
158 |                 break  # No errors occurred; break the retry loop
159 | except FloatingPointError:
160 | # clip_local = 100
161 | if try_n==1:
162 | clip_local = args.clip_local
163 | logger.warning(f'Retrying with local clipping. clip_local={clip_local}')
164 | if try_n==2:
165 | clip_local = args.clip_local//2
166 | logger.warning(f'Retrying with local clipping. clip_local={clip_local}')
167 | seed_all(config.train.seed + try_n)
168 | pass
169 |
170 | save_path = os.path.join(output_dir, 'samples_all.pkl')
171 | logger.info('Saving samples to: %s' % save_path)
172 | logger.info(output_dir)
173 | def get_mol_key(data):
174 | for i, d in enumerate(test_set_selected):
175 | if d.smiles == data.smiles:
176 | return i
177 | return -1
178 | results.sort(key=get_mol_key)
179 |
180 | with open(save_path, 'wb') as f:
181 | pickle.dump(results, f)
182 |
183 |
--------------------------------------------------------------------------------
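The `--num_confs` flag above accepts either an absolute count or an `Nx` multiplier of the number of reference conformers; `num_confs` returns a closure either way. A quick illustration of the helper's behavior:

```python
# Mirrors the num_confs helper in test.py.
def num_confs(num: str):
    if num.endswith('x'):
        return lambda x: x * int(num[:-1])   # '2x' -> twice the reference count
    elif int(num) > 0:
        return lambda x: int(num)            # '50' -> exactly 50 conformers
    else:
        raise ValueError(num)

assert num_confs('2x')(10) == 20   # 10 reference conformers -> 20 samples
assert num_confs('50')(10) == 50   # fixed count, ignores the reference count
```

With the default of `'2x'`, each molecule is sampled with twice as many conformers as it has references in the test set.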
/test.sh:
--------------------------------------------------------------------------------
1 | python test.py --ckpt checkpoints/qm9_500steps/2000000.pt --config checkpoints/qm9_500steps/qm9_500steps.yml --test_set './data/GEOM/QM9/test_data_1k.pkl' --start_idx 800 --end_idx 1000 --sampling_type same_mask_noisy --n_steps 500 --device cuda:1 --w_global 0.1 --clip 1000 --clip_local 20 --global_start_sigma 5 --tag SubGDiff500
2 |
3 | python test.py --ckpt checkpoints/qm9_200steps/2000000.pt --config checkpoints/qm9_200steps/qm9_200steps.yml --test_set './data/GEOM/QM9/test_data_1k.pkl' --start_idx 800 --end_idx 1000 --sampling_type same_mask_noisy --n_steps 200 --device cuda:0 --w_global 0.1 --clip 1000 --clip_local 20 --global_start_sigma 5 --tag SubGDiff200
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import yaml
5 | from easydict import EasyDict
6 | from tqdm.auto import tqdm
7 | from glob import glob
8 | import torch
9 | import torch.utils.tensorboard
10 | from torch.nn.utils import clip_grad_norm_
11 | from torch_geometric.data import DataLoader
12 |
13 | from models.epsnet import get_model
14 | from utils.datasets import ConformationDataset
15 | from utils.transforms import *
16 | from utils.misc import *
17 | from utils.common import get_optimizer, get_scheduler
18 |
19 |
20 | if __name__ == '__main__':
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--config', type=str, default='./configs/qm9_500steps.yml')
23 | parser.add_argument('--device', type=str, default='cuda:4')
24 | parser.add_argument('--resume_iter', type=int, default=None)
25 | parser.add_argument('--logdir', type=str, default='./logs')
27 | args = parser.parse_args()
28 |
29 |
30 | resume = os.path.isdir(args.config)
31 | if resume:
32 | config_path = glob(os.path.join(args.config, '*.yml'))[0]
33 | resume_from = args.config
34 | else:
35 | config_path = args.config
36 |
37 | with open(config_path, 'r') as f:
38 | config = EasyDict(yaml.safe_load(f))
39 | config_name = os.path.basename(config_path)[:os.path.basename(config_path).rfind('.')]
40 | seed_all(config.train.seed)
41 |
42 |
43 | # Logging
44 | if resume:
45 | log_dir = get_new_log_dir(args.logdir, prefix=config_name, tag='resume')
46 | os.symlink(os.path.realpath(resume_from), os.path.join(log_dir, os.path.basename(resume_from.rstrip("/"))))
47 | else:
48 | log_dir = get_new_log_dir(args.logdir, prefix=config_name)
49 | shutil.copytree('./models', os.path.join(log_dir, 'models'))
50 | ckpt_dir = os.path.join(log_dir, 'checkpoints')
51 | os.makedirs(ckpt_dir, exist_ok=True)
52 | logger = get_logger('train', log_dir)
53 | writer = torch.utils.tensorboard.SummaryWriter(log_dir)
54 | logger.info(args)
55 | logger.info(config)
56 | shutil.copyfile(config_path, os.path.join(log_dir, os.path.basename(config_path)))
57 |
58 | # Datasets and loaders
59 | logger.info('Loading datasets...')
60 | noise_transforms=None
61 | if config.model.type=='subgraph_diffusion':
62 | from utils.transforms import SubgraphNoiseTransform
63 | noise_transforms= SubgraphNoiseTransform(config.model)
64 |
65 | transforms = CountNodesPerGraph()
66 | train_set = ConformationDataset(config.dataset.train, transform=transforms,noise_transform=noise_transforms,config=config.model)
67 | val_set = ConformationDataset(config.dataset.val, transform=transforms, noise_transform=noise_transforms,config=config.model)
68 | train_iterator = inf_iterator(DataLoader(train_set, config.train.batch_size, shuffle=True))
69 | val_loader = DataLoader(val_set, config.train.batch_size, shuffle=False)
70 |
71 | # Model
72 | logger.info('Building model...')
73 | model = get_model(config.model).to(args.device)
74 | # Optimizer
75 | optimizer_global = get_optimizer(config.train.optimizer, model.model_global)
76 | optimizer_local = get_optimizer(config.train.optimizer, model.model_local)
77 | scheduler_global = get_scheduler(config.train.scheduler, optimizer_global)
78 | scheduler_local = get_scheduler(config.train.scheduler, optimizer_local)
79 | start_iter = 1
80 |
81 | # Resume from checkpoint
82 | if resume:
83 | ckpt_path, start_iter = get_checkpoint_path(os.path.join(resume_from, 'checkpoints'), it=args.resume_iter)
84 | logger.info('Resuming from: %s' % ckpt_path)
85 | logger.info('Iteration: %d' % start_iter)
86 | ckpt = torch.load(ckpt_path)
87 | model.load_state_dict(ckpt['model'])
88 | optimizer_global.load_state_dict(ckpt['optimizer_global'])
89 | optimizer_local.load_state_dict(ckpt['optimizer_local'])
90 | scheduler_global.load_state_dict(ckpt['scheduler_global'])
91 | scheduler_local.load_state_dict(ckpt['scheduler_local'])
92 |
93 | def train(it):
94 | model.train()
95 | optimizer_global.zero_grad()
96 | optimizer_local.zero_grad()
97 | batch = next(train_iterator).to(args.device)
98 | loss, loss_global, loss_local = model.get_loss(
99 | data=batch,
100 | atom_type=batch.atom_type,
101 | pos=batch.pos,
102 | bond_index=batch.edge_index,
103 | bond_type=batch.edge_type,
104 | batch=batch.batch,
105 | num_nodes_per_graph=batch.num_nodes_per_graph,
106 | num_graphs=batch.num_graphs,
107 | anneal_power=config.train.anneal_power,
108 | return_unreduced_loss=True
109 | )
110 | loss = loss.mean()
111 | loss.backward()
112 | orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm)
113 | optimizer_global.step()
114 | optimizer_local.step()
115 |
116 | logger.info('[Train] Iter %05d | Loss %.2f | Loss(Global) %.2f | Loss(Local) %.2f | Grad %.2f | LR(Global) %.6f | LR(Local) %.6f |%s' % (
117 | it, loss.item(), loss_global.mean().item(), loss_local.mean().item(), orig_grad_norm, optimizer_global.param_groups[0]['lr'], optimizer_local.param_groups[0]['lr'], log_dir
118 | ))
119 | writer.add_scalar('train/loss', loss, it)
120 | writer.add_scalar('train/loss_global', loss_global.mean(), it)
121 | writer.add_scalar('train/loss_local', loss_local.mean(), it)
122 | writer.add_scalar('train/lr_global', optimizer_global.param_groups[0]['lr'], it)
123 | writer.add_scalar('train/lr_local', optimizer_local.param_groups[0]['lr'], it)
124 | writer.add_scalar('train/grad_norm', orig_grad_norm, it)
125 | writer.flush()
126 |
127 | def validate(it):
128 | sum_loss, sum_n = 0, 0
129 | sum_loss_global, sum_n_global = 0, 0
130 | sum_loss_local, sum_n_local = 0, 0
131 | with torch.no_grad():
132 | model.eval()
133 | for i, batch in enumerate(tqdm(val_loader, desc='Validation')):
134 | batch = batch.to(args.device)
135 | loss, loss_global, loss_local = model.get_loss(
136 | data=batch,
137 | atom_type=batch.atom_type,
138 | pos=batch.pos,
139 | bond_index=batch.edge_index,
140 | bond_type=batch.edge_type,
141 | batch=batch.batch,
142 | num_nodes_per_graph=batch.num_nodes_per_graph,
143 | num_graphs=batch.num_graphs,
144 | anneal_power=config.train.anneal_power,
145 | return_unreduced_loss=True
146 | )
147 | sum_loss += loss.sum().item()
148 | sum_n += loss.size(0)
149 | sum_loss_global += loss_global.sum().item()
150 | sum_n_global += loss_global.size(0)
151 | sum_loss_local += loss_local.sum().item()
152 | sum_n_local += loss_local.size(0)
153 | avg_loss = sum_loss / sum_n
154 | avg_loss_global = sum_loss_global / sum_n_global
155 | avg_loss_local = sum_loss_local / sum_n_local
156 |
157 | if config.train.scheduler.type == 'plateau':
158 | scheduler_global.step(avg_loss_global)
159 | scheduler_local.step(avg_loss_local)
160 | else:
161 | scheduler_global.step()
162 | scheduler_local.step()
163 |
164 | logger.info('[Validate] Iter %05d | Loss %.6f | Loss(Global) %.6f | Loss(Local) %.6f' % (
165 | it, avg_loss, avg_loss_global, avg_loss_local,
166 | ))
167 | writer.add_scalar('val/loss', avg_loss, it)
168 | writer.add_scalar('val/loss_global', avg_loss_global, it)
169 | writer.add_scalar('val/loss_local', avg_loss_local, it)
170 | writer.flush()
171 | return avg_loss
172 |
173 | try:
174 | for it in range(start_iter, config.train.max_iters + 1):
175 | train(it)
176 | # TODO if avg_val_loss < : save
177 | if it % config.train.val_freq == 0 or it == config.train.max_iters:
178 | avg_val_loss = validate(it)
179 | ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it)
180 | torch.save({
181 | 'config': config,
182 | 'model': model.state_dict(),
183 | 'optimizer_global': optimizer_global.state_dict(),
184 | 'scheduler_global': scheduler_global.state_dict(),
185 | 'optimizer_local': optimizer_local.state_dict(),
186 | 'scheduler_local': scheduler_local.state_dict(),
187 | 'iteration': it,
188 | 'avg_val_loss': avg_val_loss,
189 | }, ckpt_path)
190 | except KeyboardInterrupt:
191 | logger.info('Terminating...')
192 |
193 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python train_dist.py --config ./configs/qm9_200steps.yml --device cuda:0 --logdir ./logs --tag qm9_200steps --n_jobs 12 --print_freq 200
2 | python train_dist.py --config ./configs/qm9_500steps.yml --device cuda:1 --logdir ./logs --tag qm9_500steps --n_jobs 12 --print_freq 200
3 |
--------------------------------------------------------------------------------
/train_dist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import yaml
5 | from easydict import EasyDict
6 | from tqdm.auto import tqdm
7 | from glob import glob
8 | import torch
9 | import torch.utils.tensorboard
10 | from torch.nn.utils import clip_grad_norm_
11 | from torch_geometric.data import DataLoader
12 |
13 |
14 |
15 |
16 | from models.epsnet import get_model
17 | from utils.datasets import ConformationDataset
18 | from utils.transforms import *
19 | from utils.misc import *
20 | from utils.common import get_optimizer, get_scheduler
21 |
22 |
23 | import torch.distributed as dist
24 | def setup_dist(args, port=None, backend="nccl", verbose=False):
25 |     # Minimal DDP bootstrap: torchrun exports RANK/LOCAL_RANK/WORLD_SIZE,
26 |     # which init_process_group picks up via the default env:// store.
27 |     rank = int(os.environ.get("RANK", 0))
28 |     local_rank = int(os.environ.get("LOCAL_RANK", 0))
29 |     world_size = int(os.environ.get("WORLD_SIZE", 1))
30 |     dist.init_process_group(backend=backend)
31 |     torch.cuda.set_device(local_rank)
32 |     return rank, local_rank, world_size, torch.device("cuda", local_rank)
27 |
28 | def reduce_mean(tensor, nprocs):
29 | rt = tensor.clone()
30 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
31 | rt = rt/ nprocs
32 | return rt
33 |
34 | if __name__ == '__main__':
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument('--config', type=str, default='./configs/qm9_500steps.yml')
37 | parser.add_argument('--device', type=str, default='cuda:4')
38 |
39 | parser.add_argument('--resume_iter', type=int, default=None)
40 | parser.add_argument('--logdir', type=str, default='./logs')
41 |     parser.add_argument('--distribution', action='store_true', default=False,
42 |                         help='enable DistributedDataParallel (DDP) training')
43 |     parser.add_argument('--tag', type=str, default='', help="optional tag marking the experiment, appended to the log directory name")
44 |     parser.add_argument('--n_jobs', type=int, default=2, help="number of DataLoader worker processes")
45 |     parser.add_argument('--print_freq', type=int, default=50, help="log training metrics every N iterations")
46 | args = parser.parse_args()
47 |
48 |     args.distribution = False  # DDP is disabled by default; enable via setup_dist for multi-GPU runs
49 |
50 | resume = os.path.isdir(args.config)
51 | if resume:
52 | config_path = glob(os.path.join(args.config, '*.yml'))[0]
53 | resume_from = args.config
54 | else:
55 | config_path = args.config
56 |
57 | with open(config_path, 'r') as f:
58 | config = EasyDict(yaml.safe_load(f))
59 | config_name = os.path.basename(config_path)[:os.path.basename(config_path).rfind('.')]
60 | seed_all(config.train.seed)
61 |
62 | args.local_rank = int(args.device.split(":")[-1])
63 |     if args.distribution:
64 | rank, local_rank, world_size, device = setup_dist(args, verbose=True)
65 | args.device=device
66 | args.local_rank=local_rank
67 | setattr(config, 'local_rank', local_rank)
68 | setattr(config, 'world_size', world_size)
69 | setattr(config, 'tag', args.tag)
70 |
71 |
72 | master_worker = (rank == 0) if args.distribution else True
73 | args.nprocs = torch.cuda.device_count()
74 |
75 |
76 | if master_worker:
77 | # Logging
78 | if resume:
79 | log_dir = get_new_log_dir(args.logdir, prefix=config_name+args.tag, tag='resume')
80 | os.symlink(os.path.realpath(resume_from), os.path.join(log_dir, os.path.basename(resume_from.rstrip("/"))))
81 | else:
82 | log_dir = get_new_log_dir(args.logdir, prefix=config_name+args.tag)
83 | shutil.copytree('./models', os.path.join(log_dir, 'models'))
84 | shutil.copytree('./utils', os.path.join(log_dir, 'utils'))
85 | ckpt_dir = os.path.join(log_dir, 'checkpoints')
86 | os.makedirs(ckpt_dir, exist_ok=True)
87 | logger = get_logger('train', log_dir)
88 | writer = torch.utils.tensorboard.SummaryWriter(log_dir)
89 | logger.info(args)
90 | logger.info(config)
91 | shutil.copyfile(config_path, os.path.join(log_dir, os.path.basename(config_path)))
92 |
93 | # Datasets and loaders
94 | if master_worker: logger.info('Loading datasets...')
95 |     noise_transforms = None
96 |     if config.model.type == 'subgraph_diffusion':
97 |         from utils.transforms import SubgraphNoiseTransform
98 |         noise_transforms = SubgraphNoiseTransform(config.model, tag=args.tag)
99 | # noise_transform_ddpm= SubgraphNoiseTransform(config.model, tag=args.tag, ddpm=False)
100 |
101 | transforms = CountNodesPerGraph()
102 | train_set = ConformationDataset(config.dataset.train, transform=transforms, noise_transform=noise_transforms,config=config.model)
103 | val_set = ConformationDataset(config.dataset.val, transform=transforms, noise_transform=noise_transforms,config=config.model)
104 | train_iterator = inf_iterator(DataLoader(train_set, config.train.batch_size, num_workers=args.n_jobs, shuffle=True))
105 | val_loader = DataLoader(val_set, config.train.batch_size, num_workers=args.n_jobs, shuffle=False)
106 |
107 |
108 | if args.distribution:
109 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
110 | train_iterator = inf_iterator(DataLoader(train_set, config.train.batch_size, num_workers=args.n_jobs, shuffle=False,sampler=train_sampler))
111 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_set)
112 | val_loader = DataLoader(val_set, config.train.batch_size, shuffle=False, num_workers=args.n_jobs, sampler=val_sampler)
113 |
114 | # Model
115 | if master_worker: logger.info('Building model...')
116 | model = get_model(config.model).to(args.device)
117 | # model = get_model(config.model).cuda()
118 |
119 | # Optimizer
120 | optimizer_global = get_optimizer(config.train.optimizer, model.model_global) # module.
121 | optimizer_local = get_optimizer(config.train.optimizer, model.model_local)
122 | optimizer_mask = get_optimizer(config.train.optimizer, model.model_mask)
123 | scheduler_global = get_scheduler(config.train.scheduler, optimizer_global)
124 | scheduler_local = get_scheduler(config.train.scheduler, optimizer_local)
125 | scheduler_mask = get_scheduler(config.train.scheduler, optimizer_mask)
126 | start_iter = 1
127 |
128 |
129 | if args.distribution:
130 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
131 |
132 | # Resume from checkpoint
133 | if resume:
134 | ckpt_path, start_iter = get_checkpoint_path(os.path.join(resume_from, 'checkpoints'), it=args.resume_iter)
135 | logger.info('Resuming from: %s' % ckpt_path)
136 | logger.info('Iteration: %d' % start_iter)
137 | ckpt = torch.load(ckpt_path)
138 | model.load_state_dict(ckpt['model'])
139 | optimizer_global.load_state_dict(ckpt['optimizer_global'])
140 | optimizer_local.load_state_dict(ckpt['optimizer_local'])
141 |         if 'optimizer_mask' in ckpt:  # checkpoints from the non-mask variant lack these keys
142 |             optimizer_mask.load_state_dict(ckpt['optimizer_mask'])
143 |         scheduler_global.load_state_dict(ckpt['scheduler_global'])
144 |         scheduler_local.load_state_dict(ckpt['scheduler_local'])
145 |         if 'scheduler_mask' in ckpt:
146 |             scheduler_mask.load_state_dict(ckpt['scheduler_mask'])
147 |
148 | def train(it):
149 | model.train()
150 | optimizer_global.zero_grad()
151 | optimizer_local.zero_grad()
152 | optimizer_mask.zero_grad()
153 | ddpm_step = config.train.max_iters//2
154 | batch = next(train_iterator).to(args.device)
155 |
156 | if args.distribution:
157 | loss_func=model.module.get_loss
158 | else:
159 | loss_func=model.get_loss
160 |
161 | loss, loss_global, loss_local, loss_mask = loss_func(
162 | data=batch,
163 | atom_type=batch.atom_type,
164 | pos=batch.pos,
165 | bond_index=batch.edge_index,
166 | bond_type=batch.edge_type,
167 | batch=batch.batch,
168 | num_nodes_per_graph=batch.num_nodes_per_graph,
169 | num_graphs=batch.num_graphs,
170 | anneal_power=config.train.anneal_power,
171 | return_unreduced_loss=True
172 | )
173 | loss_mask = loss_mask.mean()
174 | if hasattr(batch,"last_select"):
175 | sum_selected = batch.last_select.sum()
176 | loss = loss.sum()/sum_selected
177 | loss_global=loss_global.sum()/sum_selected
178 | loss_local = loss_local.sum()/sum_selected
179 | else:
180 | loss = loss.mean()
181 | loss_global=loss_global.mean()
182 | loss_local = loss_local.mean()
183 |
184 | if args.distribution:
185 |
186 | reduced_loss =reduce_mean(loss, args.nprocs)
187 | reduced_loss_global = reduce_mean(loss_global, args.nprocs)
188 | reduced_loss_local = reduce_mean(loss_local, args.nprocs)
189 | reduced_loss_mask = reduce_mean(loss_mask, args.nprocs)
190 |
191 | loss.backward()
192 | orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm)
193 | optimizer_global.step()
194 | optimizer_local.step()
195 | optimizer_mask.step()
196 |
197 | if master_worker and (it-1) % args.print_freq == 0:
198 | if args.distribution:
199 | loss=reduced_loss
200 | loss_global =reduced_loss_global
201 | loss_local = reduced_loss_local
202 | loss_mask = reduced_loss_mask
203 | logger.info('[Train] Epoch %05d | Iter %05d | Loss %.2f | Loss(Global) %.2f | Loss(Local) %.2f | Loss(mask) %.2f | Grad %.2f | LR(Global) %.6f | LR(Local) %.6f| LR(mask) %.6f|%s' % (
204 | (it*config.train.batch_size)//len(train_set), it, loss.item(), loss_global.item(), loss_local.item(), loss_mask.item(), orig_grad_norm, optimizer_global.param_groups[0]['lr'],
205 | optimizer_local.param_groups[0]['lr'],optimizer_mask.param_groups[0]['lr'], log_dir
206 | ))
207 | writer.add_scalar('train/loss', loss, it)
208 | writer.add_scalar('train/loss_global', loss_global.mean(), it)
209 | writer.add_scalar('train/loss_local', loss_local.mean(), it)
210 | writer.add_scalar('train/loss_mask', loss_mask.mean(), it)
211 | writer.add_scalar('train/lr_global', optimizer_global.param_groups[0]['lr'], it)
212 | writer.add_scalar('train/lr_local', optimizer_local.param_groups[0]['lr'], it)
213 | writer.add_scalar('train/lr_mask', optimizer_mask.param_groups[0]['lr'], it)
214 | writer.add_scalar('train/grad_norm', orig_grad_norm, it)
215 | writer.flush()
216 |
217 | def validate(it):
218 | sum_loss, sum_n = torch.tensor(0.0).to(args.local_rank), 0
219 | sum_loss_global, sum_n_global = torch.tensor(0.0).to(args.local_rank), 0
220 | sum_loss_local, sum_n_local = torch.tensor(0.0).to(args.local_rank), 0
221 | sum_loss_mask, sum_n_mask = torch.tensor(0.0).to(args.local_rank), 0
222 | # print("validate....",local_rank,end=' | ')
223 | with torch.no_grad():
224 | model.eval()
225 | for i, batch in enumerate(tqdm(val_loader, desc='Validation',disable=not master_worker)):
226 | batch = batch.to(args.local_rank)
227 | if args.distribution:
228 | loss_func=model.module.get_loss
229 | else:
230 | loss_func=model.get_loss
231 |
232 | loss, loss_global, loss_local, loss_mask = loss_func(
233 | data=batch,
234 | atom_type=batch.atom_type,
235 | pos=batch.pos,
236 | bond_index=batch.edge_index,
237 | bond_type=batch.edge_type,
238 | batch=batch.batch,
239 | num_nodes_per_graph=batch.num_nodes_per_graph,
240 | num_graphs=batch.num_graphs,
241 | anneal_power=config.train.anneal_power,
242 | return_unreduced_loss=True
243 | )
244 | sum_loss += loss.sum().item()
245 | sum_loss_local += loss_local.sum().item()
246 | sum_loss_global += loss_global.sum().item()
247 | sum_loss_mask += loss_mask.sum().item()
248 | sum_n_mask += loss_mask.size(0)
249 | if hasattr(batch, "last_select"):
250 | sum_selected = batch.last_select.sum()
251 | sum_n +=sum_selected
252 | sum_n_local +=sum_selected
253 | sum_n_global +=sum_selected
254 |
255 | else:
256 | sum_n += loss.size(0)
257 | sum_n_local += loss_local.size(0)
258 | sum_n_global += loss_global.size(0)
259 |
260 | avg_loss = sum_loss / sum_n
261 | avg_loss_global = sum_loss_global / sum_n_global
262 | avg_loss_local = sum_loss_local / sum_n_local
263 | avg_loss_mask = sum_loss_mask / sum_n_mask
264 |
265 | if args.distribution:
266 | dist.barrier()
267 | avg_loss =reduce_mean(avg_loss, args.nprocs)
268 | avg_loss_global = reduce_mean(avg_loss_global, args.nprocs)
269 | avg_loss_local = reduce_mean(avg_loss_local, args.nprocs)
270 | avg_loss_mask = reduce_mean(avg_loss_mask, args.nprocs)
271 |
272 |
273 | if config.train.scheduler.type == 'plateau':
274 | scheduler_global.step(avg_loss_global)
275 | scheduler_local.step(avg_loss_local)
276 | # scheduler_mask.step(avg_loss_mask)
277 | else:
278 | scheduler_global.step()
279 | scheduler_local.step()
280 | scheduler_mask.step()
281 |
282 | if master_worker:
283 | logger.info('[Validate] Iter %05d | Loss %.6f | Loss(Global) %.6f | Loss(Local) %.6f | Loss(mask) %.6f' % (
284 | it, avg_loss, avg_loss_global, avg_loss_local, avg_loss_mask
285 | ))
286 | writer.add_scalar('val/loss', avg_loss, it)
287 | writer.add_scalar('val/loss_global', avg_loss_global, it)
288 | writer.add_scalar('val/loss_local', avg_loss_local, it)
289 | writer.add_scalar('val/loss_mask', avg_loss_mask, it)
290 | writer.flush()
291 | return avg_loss
292 |
293 |
294 | if master_worker: print("training....")
295 | try:
296 | for it in range(start_iter, config.train.max_iters + 1):
297 |
298 | # train_sampler.set_epoch(it)
299 | train(it)
300 | # TODO if avg_val_loss < : save
301 | if it % config.train.val_freq == 0 or it == config.train.max_iters:
302 | avg_val_loss = validate(it)
303 | if master_worker and (it % 20000 == 0 or it == config.train.max_iters):
304 | # print("saving checkpoint....")
305 | ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it)
306 | torch.save({
307 | 'config': config,
308 | 'model': model.state_dict(),
309 | 'optimizer_global': optimizer_global.state_dict(),
310 | 'scheduler_global': scheduler_global.state_dict(),
311 | 'optimizer_local': optimizer_local.state_dict(),
312 | 'scheduler_local': scheduler_local.state_dict(),
313 | 'optimizer_mask': optimizer_mask.state_dict(),
314 | 'scheduler_mask': scheduler_mask.state_dict(),
315 | 'iteration': it,
316 | 'avg_val_loss': avg_val_loss,
317 | }, ckpt_path)
318 | except KeyboardInterrupt:
319 | if master_worker: logger.info('Terminating...')
320 |
321 |
322 |
323 |
324 |
--------------------------------------------------------------------------------
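In `train_dist.py`, when the batch carries a `last_select` mask the per-node losses are normalized by the number of selected atoms (`loss.sum() / batch.last_select.sum()`) instead of by a plain mean. A dependency-free sketch of that reduction, using a hypothetical `masked_mean` helper that assumes unselected nodes contribute zero loss (as in the subgraph-diffusion setup):

```python
def masked_mean(per_node_loss, select_mask=None):
    """Average a per-node loss over the selected nodes only.

    Mirrors the `loss.sum() / batch.last_select.sum()` reduction in
    train_dist.py; falls back to an ordinary mean when no mask is given.
    Assumes unselected nodes carry zero loss.
    """
    if select_mask is None:
        return sum(per_node_loss) / len(per_node_loss)
    return sum(per_node_loss) / sum(select_mask)
```

With a mask of two selected nodes, the denominator is 2 regardless of batch size, so the gradient scale stays comparable across batches with different subgraph sizes.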
/utils/chem.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import torch
3 | from torchvision.transforms.functional import to_tensor
4 | import rdkit
5 | import rdkit.Chem.Draw
6 | from rdkit import Chem
7 | from rdkit.Chem import rdDepictor as DP
8 | from rdkit.Chem import PeriodicTable as PT
9 | from rdkit.Chem import rdMolAlign as MA
10 | from rdkit.Chem.rdchem import BondType as BT
11 | from rdkit.Chem.rdchem import Mol,GetPeriodicTable
12 | from rdkit.Chem.Draw import rdMolDraw2D as MD2
13 | from rdkit.Chem.rdmolops import RemoveHs
14 | from typing import List, Tuple
15 |
16 |
17 | BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}
18 | BOND_NAMES = {i: t for i, t in enumerate(BT.names.keys())}
19 |
20 |
21 | def set_conformer_positions(conf, pos):
22 | for i in range(pos.shape[0]):
23 | conf.SetAtomPosition(i, pos[i].tolist())
24 | return conf
25 |
26 |
27 | def draw_mol_image(rdkit_mol, tensor=False):
28 | rdkit_mol.UpdatePropertyCache()
29 | img = rdkit.Chem.Draw.MolToImage(rdkit_mol, kekulize=False)
30 | if tensor:
31 | return to_tensor(img)
32 | else:
33 | return img
34 |
35 |
36 | def update_data_rdmol_positions(data):
37 | for i in range(data.pos.size(0)):
38 | data.rdmol.GetConformer(0).SetAtomPosition(i, data.pos[i].tolist())
39 | return data
40 |
41 |
42 | def update_data_pos_from_rdmol(data):
43 | new_pos = torch.FloatTensor(data.rdmol.GetConformer(0).GetPositions()).to(data.pos)
44 | data.pos = new_pos
45 | return data
46 |
47 |
48 | def set_rdmol_positions(rdkit_mol, pos):
49 | """
50 | Args:
51 | rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.
52 | pos: (N_atoms, 3)
53 | """
54 | mol = deepcopy(rdkit_mol)
55 | set_rdmol_positions_(mol, pos)
56 | return mol
57 |
58 |
59 | def set_rdmol_positions_(mol, pos):
60 | """
61 | Args:
62 | rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.
63 | pos: (N_atoms, 3)
64 | """
65 | for i in range(pos.shape[0]):
66 | mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())
67 | return mol
68 |
69 |
70 | def get_atom_symbol(atomic_number):
71 | return PT.GetElementSymbol(GetPeriodicTable(), atomic_number)
72 |
73 |
74 | def mol_to_smiles(mol: Mol) -> str:
75 | return Chem.MolToSmiles(mol, allHsExplicit=True)
76 |
77 |
78 | def mol_to_smiles_without_Hs(mol: Mol) -> str:
79 | return Chem.MolToSmiles(Chem.RemoveHs(mol))
80 |
81 |
82 | def remove_duplicate_mols(molecules: List[Mol]) -> List[Mol]:
83 | unique_tuples: List[Tuple[str, Mol]] = []
84 |
85 | for molecule in molecules:
86 | duplicate = False
87 | smiles = mol_to_smiles(molecule)
88 | for unique_smiles, _ in unique_tuples:
89 | if smiles == unique_smiles:
90 | duplicate = True
91 | break
92 |
93 | if not duplicate:
94 | unique_tuples.append((smiles, molecule))
95 |
96 | return [mol for smiles, mol in unique_tuples]
97 |
98 |
99 | def get_atoms_in_ring(mol):
100 | atoms = set()
101 | for ring in mol.GetRingInfo().AtomRings():
102 | for a in ring:
103 | atoms.add(a)
104 | return atoms
105 |
106 |
107 | def get_2D_mol(mol):
108 | mol = deepcopy(mol)
109 | DP.Compute2DCoords(mol)
110 | return mol
111 |
112 |
113 | def draw_mol_svg(mol,molSize=(450,150),kekulize=False):
114 | mc = Chem.Mol(mol.ToBinary())
115 | if kekulize:
116 | try:
117 | Chem.Kekulize(mc)
118 |         except Exception:  # kekulization can fail; fall back to the aromatic form
119 | mc = Chem.Mol(mol.ToBinary())
120 | if not mc.GetNumConformers():
121 | DP.Compute2DCoords(mc)
122 | drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])
123 | drawer.DrawMolecule(mc)
124 | drawer.FinishDrawing()
125 | svg = drawer.GetDrawingText()
126 | # It seems that the svg renderer used doesn't quite hit the spec.
127 | # Here are some fixes to make it work in the notebook, although I think
128 | # the underlying issue needs to be resolved at the generation step
129 | # return svg.replace('svg:','')
130 | return svg
131 |
132 |
133 | def get_best_rmsd(probe, ref):
134 | probe = RemoveHs(probe)
135 | ref = RemoveHs(ref)
136 | rmsd = MA.GetBestRMS(probe, ref)
137 | return rmsd
138 |
139 |
--------------------------------------------------------------------------------
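`remove_duplicate_mols` above deduplicates by canonical SMILES with a linear scan over previously seen strings, which is O(n²). The same order-preserving dedup can be done with a set for O(n) membership tests; a generic sketch over (key, item) pairs, with no RDKit dependency (`key` would be `mol_to_smiles` in the original):

```python
def dedup_by_key(items, key):
    """Keep the first item for each key, preserving input order.

    Same behaviour as remove_duplicate_mols, where `key` is the
    canonical-SMILES function applied to each molecule.
    """
    seen = set()
    unique = []
    for item in items:
        k = key(item)
        if k not in seen:       # first occurrence wins
            seen.add(k)
            unique.append(item)
    return unique
```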
/utils/common.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import warnings
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch_geometric.data import Data, Batch
8 |
9 |
10 | #customize exp lr scheduler with min lr
11 | class ExponentialLR_with_minLr(torch.optim.lr_scheduler.ExponentialLR):
12 | def __init__(self, optimizer, gamma, min_lr=1e-4, last_epoch=-1, verbose=False):
13 | self.gamma = gamma
14 | self.min_lr = min_lr
15 | super(ExponentialLR_with_minLr, self).__init__(optimizer, gamma, last_epoch, verbose)
16 |
17 | def get_lr(self):
18 | if not self._get_lr_called_within_step:
19 | warnings.warn("To get the last learning rate computed by the scheduler, "
20 | "please use `get_last_lr()`.", UserWarning)
21 |
22 | if self.last_epoch == 0:
23 | return self.base_lrs
24 | return [max(group['lr'] * self.gamma, self.min_lr)
25 | for group in self.optimizer.param_groups]
26 |
27 | def _get_closed_form_lr(self):
28 | return [max(base_lr * self.gamma ** self.last_epoch, self.min_lr)
29 | for base_lr in self.base_lrs]
30 |
31 |
32 | def repeat_data(data: Data, num_repeat) -> Batch:
33 | datas = [copy.deepcopy(data) for i in range(num_repeat)]
34 | return Batch.from_data_list(datas)
35 |
36 | def repeat_batch(batch: Batch, num_repeat) -> Batch:
37 | datas = batch.to_data_list()
38 | new_data = []
39 | for i in range(num_repeat):
40 | new_data = new_data + copy.deepcopy(datas)
41 | return Batch.from_data_list(new_data)
42 |
43 |
44 | def get_optimizer(cfg, model):
45 | if cfg.type == 'adam':
46 | return torch.optim.Adam(
47 | model.parameters(),
48 | lr=cfg.lr,
49 | weight_decay=cfg.weight_decay,
50 | betas=(cfg.beta1, cfg.beta2, )
51 | )
52 | else:
53 | raise NotImplementedError('Optimizer not supported: %s' % cfg.type)
54 |
55 |
56 | def get_scheduler(cfg, optimizer):
57 | if cfg.type == 'plateau':
58 | return torch.optim.lr_scheduler.ReduceLROnPlateau(
59 | optimizer,
60 | factor=cfg.factor,
61 | patience=cfg.patience,
62 | )
63 | elif cfg.type == 'expmin':
64 | return ExponentialLR_with_minLr(
65 | optimizer,
66 | gamma=cfg.factor,
67 | min_lr=cfg.min_lr,
68 | )
69 | elif cfg.type == 'expmin_milestone':
70 | gamma = np.exp(np.log(cfg.factor) / cfg.milestone)
71 | return ExponentialLR_with_minLr(
72 | optimizer,
73 | gamma=gamma,
74 | min_lr=cfg.min_lr,
75 | )
76 | else:
77 | raise NotImplementedError('Scheduler not supported: %s' % cfg.type)
--------------------------------------------------------------------------------
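The closed-form schedule implemented by `ExponentialLR_with_minLr` above is an exponential decay clamped from below, exactly as in `_get_closed_form_lr`. A pure-Python sketch of the formula:

```python
def lr_at_epoch(base_lr, gamma, min_lr, epoch):
    """Learning rate after `epoch` steps of ExponentialLR_with_minLr:
    exponential decay base_lr * gamma^epoch, clamped at min_lr."""
    return max(base_lr * gamma ** epoch, min_lr)
```

The `expmin_milestone` variant derives `gamma` so the decay reaches `cfg.factor` at `cfg.milestone` steps: `gamma = exp(log(factor) / milestone)`.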
/utils/evaluation/covmat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pandas as pd
4 | import multiprocessing as mp
5 | from torch_geometric.data import Data
6 | from functools import partial
7 | from easydict import EasyDict
8 | from tqdm.auto import tqdm
9 | from rdkit import Chem
10 | from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule
11 |
12 | from ..chem import set_rdmol_positions, get_best_rmsd
13 |
14 |
15 | def get_rmsd_confusion_matrix(data: Data, useFF=False):
16 | data['pos_ref'] = data['pos_ref'].reshape(-1, data['rdmol'].GetNumAtoms(), 3)
17 | data['pos_gen'] = data['pos_gen'].reshape(-1, data['rdmol'].GetNumAtoms(), 3)
18 | num_gen = data['pos_gen'].shape[0]
19 | num_ref = data['pos_ref'].shape[0]
20 |
21 | # assert num_gen == data.num_pos_gen.item()
22 | # assert num_ref == data.num_pos_ref.item()
23 |
24 |     rmsd_confusion_mat = -1 * np.ones([num_ref, num_gen], dtype=np.float64)  # np.float alias removed in NumPy 1.24
25 |
26 | for i in range(num_gen):
27 | gen_mol = set_rdmol_positions(data['rdmol'], data['pos_gen'][i])
28 | if useFF:
29 | #print('Applying FF on generated molecules...')
30 | MMFFOptimizeMolecule(gen_mol)
31 | for j in range(num_ref):
32 | ref_mol = set_rdmol_positions(data['rdmol'], data['pos_ref'][j])
33 |
34 | rmsd_confusion_mat[j,i] = get_best_rmsd(gen_mol, ref_mol)
35 |
36 | return rmsd_confusion_mat
37 |
38 |
39 | def evaluate_conf(data: Data, useFF=False, threshold=0.5):
40 | rmsd_confusion_mat = get_rmsd_confusion_matrix(data, useFF=useFF)
41 | rmsd_ref_min = rmsd_confusion_mat.min(-1)
42 | #print('done one mol')
43 | #print(rmsd_ref_min)
44 | return (rmsd_ref_min<=threshold).mean(), rmsd_ref_min.mean()
45 |
46 |
47 | def print_covmat_results(results, print_fn=print):
48 | df = pd.DataFrame({
49 | 'COV-R_mean': np.mean(results.CoverageR, 0),
50 | 'COV-R_median': np.median(results.CoverageR, 0),
51 | 'COV-R_std': np.std(results.CoverageR, 0),
52 | 'COV-P_mean': np.mean(results.CoverageP, 0),
53 | 'COV-P_median': np.median(results.CoverageP, 0),
54 | 'COV-P_std': np.std(results.CoverageP, 0),
55 | }, index=results.thresholds)
56 | print_fn('\n' + str(df))
57 | print_fn('MAT-R_mean: %.4f | MAT-R_median: %.4f | MAT-R_std %.4f' % (
58 | np.mean(results.MatchingR), np.median(results.MatchingR), np.std(results.MatchingR)
59 | ))
60 | print_fn('MAT-P_mean: %.4f | MAT-P_median: %.4f | MAT-P_std %.4f' % (
61 | np.mean(results.MatchingP), np.median(results.MatchingP), np.std(results.MatchingP)
62 | ))
63 | return df
64 |
65 |
66 |
67 | class CovMatEvaluator(object):
68 |
69 | def __init__(self,
70 | num_workers=8,
71 | use_force_field=False,
72 | thresholds=np.arange(0.05, 3.05, 0.05),
73 | ratio=2,
74 | filter_disconnected=True,
75 | print_fn=print,
76 | ):
77 | super().__init__()
78 | self.num_workers = num_workers
79 | self.use_force_field = use_force_field
80 | self.thresholds = np.array(thresholds).flatten()
81 |
82 | self.ratio = ratio
83 | self.filter_disconnected = filter_disconnected
84 |
85 | self.pool = mp.Pool(num_workers)
86 | self.print_fn = print_fn
87 |
88 | def __call__(self, packed_data_list, start_idx=0):
89 | func = partial(get_rmsd_confusion_matrix, useFF=self.use_force_field)
90 |
91 | filtered_data_list = []
92 | for data in packed_data_list:
93 | if 'pos_gen' not in data or 'pos_ref' not in data: continue
94 | if self.filter_disconnected and ('.' in data['smiles']): continue
95 |
96 | data['pos_ref'] = data['pos_ref'].reshape(-1, data['rdmol'].GetNumAtoms(), 3)
97 | data['pos_gen'] = data['pos_gen'].reshape(-1, data['rdmol'].GetNumAtoms(), 3)
98 |
99 | num_gen = data['pos_ref'].shape[0] * self.ratio
100 | if data['pos_gen'].shape[0] < num_gen: continue
101 | data['pos_gen'] = data['pos_gen'][:num_gen]
102 |
103 | filtered_data_list.append(data)
104 |
105 | filtered_data_list = filtered_data_list[start_idx:]
106 | self.print_fn('Filtered: %d / %d' % (len(filtered_data_list), len(packed_data_list)))
107 |
108 | covr_scores = []
109 | matr_scores = []
110 | covp_scores = []
111 | matp_scores = []
112 | for confusion_mat in tqdm(self.pool.imap(func, filtered_data_list), total=len(filtered_data_list)):
113 | # confusion_mat: (num_ref, num_gen)
114 | rmsd_ref_min = confusion_mat.min(-1) # np (num_ref, )
115 | rmsd_gen_min = confusion_mat.min(0) # np (num_gen, )
116 |
117 | rmsd_cov_thres = rmsd_ref_min.reshape(-1, 1) <= self.thresholds.reshape(1, -1) # np (num_ref, num_thres)
118 | rmsd_jnk_thres = rmsd_gen_min.reshape(-1, 1) <= self.thresholds.reshape(1, -1) # np (num_gen, num_thres)
119 |
120 | matr_scores.append(rmsd_ref_min.mean())
121 | covr_scores.append(rmsd_cov_thres.mean(0, keepdims=True)) # np (1, num_thres)
122 | matp_scores.append(rmsd_gen_min.mean())
123 | covp_scores.append(rmsd_jnk_thres.mean(0, keepdims=True)) # np (1, num_thres)
124 |
125 | covr_scores = np.vstack(covr_scores) # np (num_mols, num_thres)
126 | matr_scores = np.array(matr_scores) # np (num_mols, )
127 | covp_scores = np.vstack(covp_scores) # np (num_mols, num_thres)
128 | matp_scores = np.array(matp_scores)
129 |
130 | results = EasyDict({
131 | 'CoverageR': covr_scores,
132 | 'MatchingR': matr_scores,
133 | 'thresholds': self.thresholds,
134 | 'CoverageP': covp_scores,
135 | 'MatchingP': matp_scores
136 | })
137 | # print_conformation_eval_results(results)
138 | return results
139 |
--------------------------------------------------------------------------------
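In `CovMatEvaluator`, each confusion matrix has shape (num_ref, num_gen); COV-R is the fraction of reference conformers matched by some generated one within a threshold, and MAT-R is the mean best RMSD per reference. A dependency-free sketch of these recall-side metrics (the precision-side COV-P/MAT-P are the same computation with the min taken over the other axis):

```python
def cov_mat_recall(confusion, threshold):
    """COV-R and MAT-R from a (num_ref x num_gen) RMSD matrix,
    given as a list of rows (one row per reference conformer)."""
    ref_min = [min(row) for row in confusion]   # best RMSD for each reference
    cov_r = sum(r <= threshold for r in ref_min) / len(ref_min)
    mat_r = sum(ref_min) / len(ref_min)
    return cov_r, mat_r
```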
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import random
4 | import logging
5 | import torch
6 | import numpy as np
7 | from glob import glob
8 | from logging import Logger
9 | from tqdm.auto import tqdm
10 | from torch_geometric.data import Batch
11 |
12 |
13 | class BlackHole(object):
14 | def __setattr__(self, name, value):
15 | pass
16 | def __call__(self, *args, **kwargs):
17 | return self
18 | def __getattr__(self, name):
19 | return self
20 |
21 |
22 | def get_logger(name, log_dir=None, log_fn='log.txt'):
23 | logger = logging.getLogger(name)
24 | logger.setLevel(logging.DEBUG)
25 | formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')
26 |
27 | stream_handler = logging.StreamHandler()
28 | stream_handler.setLevel(logging.DEBUG)
29 | stream_handler.setFormatter(formatter)
30 | logger.addHandler(stream_handler)
31 |
32 | if log_dir is not None:
33 | file_handler = logging.FileHandler(os.path.join(log_dir, log_fn))
34 | file_handler.setLevel(logging.DEBUG)
35 | file_handler.setFormatter(formatter)
36 | logger.addHandler(file_handler)
37 |
38 | return logger
39 |
40 |
41 | def get_new_log_dir(root='./logs', prefix='', tag=''):
42 | fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
43 | if prefix != '':
44 | fn = prefix + '_' + fn
45 | if tag != '':
46 | fn = fn + '_' + tag
47 | log_dir = os.path.join(root, fn)
48 | os.makedirs(log_dir)
49 | return log_dir
50 |
51 |
52 | def seed_all(seed):
53 | torch.manual_seed(seed)
54 | np.random.seed(seed)
55 | random.seed(seed)
56 |
57 |
58 | def inf_iterator(iterable):
59 |     iterator = iter(iterable)
60 |     while True:
61 |         try:
62 |             yield next(iterator)
63 |         except StopIteration:
64 |             iterator = iter(iterable)
66 |
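
`inf_iterator` turns any re-iterable (e.g. a DataLoader) into an endless stream by restarting it whenever it is exhausted, which lets a training loop count iterations instead of epochs. A dependency-free sketch with a plain list:

```python
from itertools import islice

def inf_iterator(iterable):
    # Restart the underlying iterable whenever it runs out.
    iterator = iter(iterable)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(iterable)

batches = ["a", "b", "c"]
stream = inf_iterator(batches)
print(list(islice(stream, 7)))  # ['a', 'b', 'c', 'a', 'b', 'c', 'a']
```

Note this re-invokes `iter(iterable)` on each wrap-around, so a DataLoader reshuffles between passes exactly as it would across epochs.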
67 | def log_hyperparams(writer, args):
68 | from torch.utils.tensorboard.summary import hparams
69 | vars_args = {k:v if isinstance(v, str) else repr(v) for k, v in vars(args).items()}
70 | exp, ssi, sei = hparams(vars_args, {})
71 | writer.file_writer.add_summary(exp)
72 | writer.file_writer.add_summary(ssi)
73 | writer.file_writer.add_summary(sei)
74 |
75 |
76 | def int_tuple(argstr):
77 | return tuple(map(int, argstr.split(',')))
78 |
79 |
80 | def str_tuple(argstr):
81 | return tuple(argstr.split(','))
82 |
83 |
84 | def repeat_data(data, num_repeat):
85 | datas = [data.clone() for i in range(num_repeat)]
86 | return Batch.from_data_list(datas)
87 |
88 |
89 | def repeat_batch(batch, num_repeat):
90 |     datas = batch.to_data_list()
91 |     new_data = []
92 |     for i in range(num_repeat):
93 |         new_data = new_data + [d.clone() for d in datas]  # a list has no .clone(); clone each Data object
94 |     return Batch.from_data_list(new_data)
95 |
96 |
97 | def get_checkpoint_path(folder, it=None):
98 | if it is not None:
99 | return os.path.join(folder, '%d.pt' % it), it
100 | all_iters = list(map(lambda x: int(os.path.basename(x[:-3])), glob(os.path.join(folder, '*.pt'))))
101 | all_iters.sort()
102 | return os.path.join(folder, '%d.pt' % all_iters[-1]), all_iters[-1]
103 |
--------------------------------------------------------------------------------
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from torch_geometric.data import Data
4 | from torch_geometric.transforms import Compose
5 | from torch_geometric.utils import to_dense_adj, dense_to_sparse
6 | from torch_sparse import coalesce
7 |
8 | from torch_geometric.utils import to_networkx
9 | import networkx as nx
10 | import numpy as np
11 | from scipy.spatial.transform import Rotation as R  # needed by modify_conformer below
12 |
13 | from .chem import BOND_TYPES, BOND_NAMES, get_atom_symbol
14 |
15 |
16 |
17 |
18 | class AddHigherOrderEdges(object):
19 |
20 | def __init__(self, order, num_types=len(BOND_TYPES)):
21 | super().__init__()
22 | self.order = order
23 | self.num_types = num_types
24 |
25 | def binarize(self, x):
26 | return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))
27 |
28 | def get_higher_order_adj_matrix(self, adj, order):
29 | """
30 | Args:
31 | adj: (N, N)
32 | type_mat: (N, N)
33 | """
34 | adj_mats = [torch.eye(adj.size(0), dtype=torch.long, device=adj.device), \
35 | self.binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device))]
36 |
37 | for i in range(2, order+1):
38 | adj_mats.append(self.binarize(adj_mats[i-1] @ adj_mats[1]))
39 | order_mat = torch.zeros_like(adj)
40 |
41 | for i in range(1, order+1):
42 | order_mat = order_mat + (adj_mats[i] - adj_mats[i-1]) * i
43 |
44 | return order_mat
45 |
46 | def __call__(self, data: Data):
47 | N = data.num_nodes
48 | adj = to_dense_adj(data.edge_index).squeeze(0)
49 | adj_order = self.get_higher_order_adj_matrix(adj, self.order) # (N, N)
50 |
51 | type_mat = to_dense_adj(data.edge_index, edge_attr=data.edge_type).squeeze(0) # (N, N)
52 | type_highorder = torch.where(adj_order > 1, self.num_types + adj_order - 1, torch.zeros_like(adj_order))
53 | assert (type_mat * type_highorder == 0).all()
54 | type_new = type_mat + type_highorder
55 |
56 | new_edge_index, new_edge_type = dense_to_sparse(type_new)
57 | _, edge_order = dense_to_sparse(adj_order)
58 |
59 | data.bond_edge_index = data.edge_index # Save original edges
60 | data.edge_index, data.edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data
61 | edge_index_1, data.edge_order = coalesce(new_edge_index, edge_order.long(), N, N) # modify data
62 | data.is_bond = (data.edge_type < self.num_types)
63 | assert (data.edge_index == edge_index_1).all()
64 |
65 | return data
66 |
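
`AddHigherOrderEdges` marks atom pairs at graph distance 2, 3, … as extra "virtual" edges by taking boolean powers of the adjacency matrix: `adj_mats[i]` is nonzero exactly where two atoms are within `i` hops, so `adj_mats[i] - adj_mats[i-1]` picks out the pairs first reached at distance `i`. A dependency-free sketch of `get_higher_order_adj_matrix` on a 3-atom path (0-1-2):

```python
def binarize(m):
    return [[1 if v > 0 else 0 for v in row] for row in m]

def matmul(a, b):
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

def higher_order_adj(adj, order):
    n = len(adj)
    eye = [[int(i == j) for j in range(n)] for i in range(n)]
    # mats[i] is 1 wherever two nodes are within i hops of each other
    mats = [eye, binarize([[adj[i][j] + eye[i][j] for j in range(n)]
                           for i in range(n)])]
    for i in range(2, order + 1):
        mats.append(binarize(matmul(mats[i - 1], mats[1])))
    out = [[0] * n for _ in range(n)]
    for i in range(1, order + 1):          # first-reached distance wins
        for r in range(n):
            for c in range(n):
                out[r][c] += (mats[i][r][c] - mats[i - 1][r][c]) * i
    return out

path = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]   # atoms 0-1-2
print(higher_order_adj(path, 2))
# [[0, 1, 2], [1, 0, 1], [2, 1, 0]] -> pair (0, 2) becomes an order-2 edge
```

In the transform itself these order-2 and order-3 entries are then offset by `num_types` so they get edge-type IDs disjoint from real bond types.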
67 | class AddEdgeLength(object):
68 |
69 | def __call__(self, data: Data):
70 |
71 | pos = data.pos
72 | row, col = data.edge_index
73 | d = (pos[row] - pos[col]).norm(dim=-1).unsqueeze(-1) # (num_edge, 1)
74 | data.edge_length = d
75 | return data
76 |
77 |
78 | # Add attribute placeholder for data object, so that we can use batch.to_data_list
79 | class AddPlaceHolder(object):
80 | def __call__(self, data: Data):
81 | data.pos_gen = -1. * torch.ones_like(data.pos)
82 | data.d_gen = -1. * torch.ones_like(data.edge_length)
83 | data.d_recover = -1. * torch.ones_like(data.edge_length)
84 | return data
85 |
86 |
87 | class AddEdgeName(object):
88 |
89 | def __init__(self, asymmetric=True):
90 | super().__init__()
91 | self.bonds = copy.deepcopy(BOND_NAMES)
92 | self.bonds[len(BOND_NAMES) + 1] = 'Angle'
93 | self.bonds[len(BOND_NAMES) + 2] = 'Dihedral'
94 | self.asymmetric = asymmetric
95 |
96 | def __call__(self, data:Data):
97 | data.edge_name = []
98 | for i in range(data.edge_index.size(1)):
99 | tail = data.edge_index[0, i]
100 | head = data.edge_index[1, i]
101 | if self.asymmetric and tail >= head:
102 | data.edge_name.append('')
103 | continue
104 | tail_name = get_atom_symbol(data.atom_type[tail].item())
105 | head_name = get_atom_symbol(data.atom_type[head].item())
106 | name = '%s_%s_%s_%d_%d' % (
107 | self.bonds[data.edge_type[i].item()] if data.edge_type[i].item() in self.bonds else 'E'+str(data.edge_type[i].item()),
108 | tail_name,
109 | head_name,
110 | tail,
111 | head,
112 | )
113 | if hasattr(data, 'edge_length'):
114 | name += '_%.3f' % (data.edge_length[i].item())
115 | data.edge_name.append(name)
116 | return data
117 |
118 |
119 | class AddAngleDihedral(object):
120 |
121 | def __init__(self):
122 | super().__init__()
123 |
124 | @staticmethod
125 | def iter_angle_triplet(bond_mat):
126 | n_atoms = bond_mat.size(0)
127 | for j in range(n_atoms):
128 | for k in range(n_atoms):
129 | for l in range(n_atoms):
130 | if bond_mat[j, k].item() == 0 or bond_mat[k, l].item() == 0: continue
131 | if (j == k) or (k == l) or (j >= l): continue
132 | yield(j, k, l)
133 |
134 | @staticmethod
135 | def iter_dihedral_quartet(bond_mat):
136 | n_atoms = bond_mat.size(0)
137 | for i in range(n_atoms):
138 | for j in range(n_atoms):
139 | if i >= j: continue
140 | if bond_mat[i,j].item() == 0:continue
141 | for k in range(n_atoms):
142 | for l in range(n_atoms):
143 | if (k in (i,j)) or (l in (i,j)): continue
144 | if bond_mat[k,i].item() == 0 or bond_mat[l,j].item() == 0: continue
145 | yield(k, i, j, l)
146 |
147 | def __call__(self, data:Data):
148 | N = data.num_nodes
149 | if 'is_bond' in data:
150 | bond_mat = to_dense_adj(data.edge_index, edge_attr=data.is_bond).long().squeeze(0) > 0
151 | else:
152 | bond_mat = to_dense_adj(data.edge_index, edge_attr=data.edge_type).long().squeeze(0) > 0
153 |
154 |         # Note: if an attribute name contains `index`, PyG automatically
155 |         # increments it during batching.
156 | data.angle_index = torch.LongTensor(list(self.iter_angle_triplet(bond_mat))).t()
157 | data.dihedral_index = torch.LongTensor(list(self.iter_dihedral_quartet(bond_mat))).t()
158 |
159 | return data
160 |
161 |
162 | class CountNodesPerGraph(object):
163 |
164 | def __init__(self) -> None:
165 | super().__init__()
166 |
167 | def __call__(self, data):
168 | data.num_nodes_per_graph = torch.LongTensor([data.num_nodes])
169 | return data
170 |
171 | from torch import nn
172 |
173 | class SubgraphNoiseTransform(object):
174 | """ expectation state + k-step same-subgraph diffusion"""
175 | def __init__(self, config, tag='', ddpm=False, boltzmann_weight=False):
176 |
177 |         self.config = config
178 |         self.ddpm = ddpm  # vanilla DDPM: every atom is treated as an independent point
179 |         self.tag = tag
180 |
181 | betas = get_beta_schedule(
182 | beta_schedule=config.beta_schedule,
183 | beta_start=config.beta_start,
184 | beta_end=config.beta_end,
185 | num_diffusion_timesteps=config.num_diffusion_timesteps,
186 | )
187 | betas = torch.from_numpy(betas).float()
188 | self.betas = nn.Parameter(betas, requires_grad=False)
189 | self.one_minu_beta_sqrt = (1-self.betas).sqrt()
190 | ## variances
191 | alphas = (1. - betas).cumprod(dim=0)
192 | self.alphas = nn.Parameter(alphas, requires_grad=False)
193 | check_alphas(self.alphas)
194 | self.num_timesteps = self.betas.size(0)
195 |         self.same_mask_steps = self.config.get("same_mask_steps", 250)  # k: number of steps that share one subgraph mask
196 |
197 |
198 |
199 | def __call__(self, data):
200 |         # Sample a noise level for this conformer.
201 |         # Four elements for DDPM: original data (pos), Gaussian noise (pos_noise),
202 |         # beta (sigma), and time_step.
203 |         atom_type = data.atom_type
204 |         pos = data.pos
205 |         bond_index = data.edge_index
206 |         bond_type = data.edge_type
207 |         num_nodes_per_graph = data.num_nodes_per_graph
208 |         mask_subgraph = data.mask_subgraph
209 |         time_step = torch.randint(
210 |             1, self.num_timesteps, size=(1,), device=pos.device)
211 |         # time_step == 0 corresponds to beta_1 in the forward-process equation
212 |
213 | data.time_step = time_step
214 | beta = self.betas.index_select(-1, time_step)
215 | data.beta = beta
216 |
217 |         if self.ddpm:
218 |             alpha = self.alphas.index_select(-1, time_step)
219 |             alpha = alpha.expand(pos.shape[0], 1)
220 |             data.alpha = alpha
221 |             data.last_select = torch.ones_like(alpha)
222 |             data.noise_scale = 1 - alpha
223 |             del data.mask_subgraph
224 |             return data
225 |
226 |
227 |
228 |         last_select, alpha, noise_scale = self._get_alpha_nodes_expect_same_mask(
229 |             mask_subgraph, time_step, k=self.same_mask_steps)
230 |         data.alpha = alpha
231 |         data.last_select = last_select  # used to predict the last selected subgraph
232 |         data.noise_scale = noise_scale
233 | del data.mask_subgraph
234 | return data
235 |
236 |
237 | def _get_alpha_nodes_expect_same_mask(self, mask_subgraph, time_step, p=0.5, k=250):
238 | """ expectation stata + k-same subgraph(mask) step diffusion """
239 | expect_step = time_step.div(k,rounding_mode='floor') # m := time_step//k
240 | mask_step = time_step % k
241 |
242 |         if mask_step == 0:
243 |             expect_step -= 1
244 |             mask_step = k
245 |         selected_index = torch.randint(low=0, high=mask_subgraph.shape[-1], size=(1,))
246 |         selected_node = mask_subgraph.index_select(-1, selected_index)  # e.g. mask_subgraph[:, [idx]]
247 | if expect_step==0:
248 | selected_nodes = selected_node.repeat(1, mask_step+1)
249 | selected_nodes[:,[i for i in range(0, min(3,time_step+1))]] = True
250 | bern_beta_mask_t = self.betas[0:time_step+1]*selected_nodes
251 |
252 | alpha_t = (1. - bern_beta_mask_t).prod(dim=-1,keepdim=True)
253 |
254 | return selected_node, alpha_t, 1-alpha_t
255 |
256 | ## Phase I: compute t step mean state 0: km
257 | p = mask_subgraph.sum(-1).unsqueeze(-1)/mask_subgraph.size(-1) # node selection probability
258 | if self.tag.startswith('Recover_GeoDiff'):
259 | p = torch.ones_like(p)
260 |             selected_node[:, :] = True
261 | if not hasattr(self, "prod_one_minu_beta_sqrt"):
262 | self.prod_one_minu_beta_sqrt =torch.zeros(self.num_timesteps//k)
263 | self.prod_one_minu_beta_sqrt[0] = self.alphas[k-1]
264 |             # For every k steps, we compute an expectation state
265 | for j in range(1, self.num_timesteps//k):
266 | self.prod_one_minu_beta_sqrt[j] = (1-self.betas)[j*k:(j+1)*k].prod(dim=-1) # \prod_{i=(j-1)k+1}^{kj}(1-\beta_i)
267 | self.prod_one_minu_beta_sqrt = self.prod_one_minu_beta_sqrt.sqrt()
268 |
269 | alpha_exp = (p*self.prod_one_minu_beta_sqrt[:expect_step] + 1-p)**2 # \alpha_j= (p\sqrt{\prod_{i=(j-1)k+1}^{kj}(1-\beta_i)} + 1-p)^2
270 | alpha_nodes= alpha_exp.cumprod(dim=-1) # \bar\alpha
271 | noise_scale = (alpha_nodes[:,expect_step-1]/alpha_nodes).mm( (1-self.prod_one_minu_beta_sqrt[:expect_step]**2).unsqueeze(-1) ) * p**2 # [ p\sqrt{\sum_{l=1}^{m} \frac{\bar\alpha_{m}}{\bar\alpha_{l}} (1-\prod_{i=(l-1)k+1}^{kl}(1-\beta_i))} ]^2
272 |
273 |
274 | ## Phase II: time step: km+1 -> t
275 | bern_beta_mask_t = self.betas[time_step-mask_step:time_step+1].repeat(selected_node.shape[0], 1)*selected_node
276 | alpha_t = (1. - bern_beta_mask_t).prod(dim=-1,keepdim=True)
277 | ## combine
278 | alpha_node_t = alpha_t * alpha_nodes.index_select(-1, expect_step-1)
279 | noise_scale_t = alpha_t * noise_scale + 1-alpha_t
280 | if (alpha_node_t==0).sum():
281 | print(alpha_node_t,time_step)
282 |
283 | return selected_node, alpha_node_t, noise_scale_t
284 |
285 |
286 |     def __repr__(self) -> str:
287 |         return (f'{self.__class__.__name__}(num_timesteps={self.num_timesteps}, '
288 |                 f'same_mask_steps={self.same_mask_steps})')
289 |
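
The inline comments inside `_get_alpha_nodes_expect_same_mask` can be collected in one place. With `p` the node-selection probability, `k` the same-mask window, and `m = ⌊t/k⌋`, the method computes (a transcription of the code's own formulas, so index conventions follow the code):

```latex
% Phase I: expectation over subgraph choice, one block of k steps at a time
\alpha_j = \Bigl(p\,\sqrt{\textstyle\prod_{i=(j-1)k+1}^{jk}(1-\beta_i)} \;+\; 1-p\Bigr)^{2},
\qquad \bar\alpha_m = \prod_{j=1}^{m}\alpha_j
% accumulated noise scale after the m expectation blocks
\sigma_m^2 = p^{2}\sum_{l=1}^{m}\frac{\bar\alpha_m}{\bar\alpha_l}
  \Bigl(1-\textstyle\prod_{i=(l-1)k+1}^{lk}(1-\beta_i)\Bigr)
% Phase II: the remaining steps km+1,\dots,t reuse one sampled mask s
\alpha_t^{(s)} = \prod_{i=km+1}^{t}(1-\beta_i s),\qquad
\bar\alpha_t = \alpha_t^{(s)}\,\bar\alpha_m,\qquad
\sigma_t^2 = \alpha_t^{(s)}\,\sigma_m^2 + 1-\alpha_t^{(s)}
```

The returned triple `(selected_node, alpha_node_t, noise_scale_t)` corresponds to `(s, \bar\alpha_t, \sigma_t^2)`.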
290 |
291 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
292 | def sigmoid(x):
293 | return 1 / (np.exp(-x) + 1)
294 |
295 | if beta_schedule == "quad":
296 | betas = (
297 | np.linspace(
298 | beta_start ** 0.5,
299 | beta_end ** 0.5,
300 | num_diffusion_timesteps,
301 | dtype=np.float64,
302 | )
303 | ** 2
304 | )
305 | elif beta_schedule == "linear":
306 | betas = np.linspace(
307 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
308 | )
309 | elif beta_schedule == "const":
310 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
311 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
312 | betas = 1.0 / np.linspace(
313 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
314 | )
315 | elif beta_schedule == "sigmoid":
316 | betas = np.linspace(-6, 6, num_diffusion_timesteps)
317 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
318 | else:
319 | raise NotImplementedError(beta_schedule)
320 | assert betas.shape == (num_diffusion_timesteps,)
321 | return betas
322 |
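
Every branch of `get_beta_schedule` returns a length-T array of betas; the transforms then form the cumulative alphas via `(1 - betas).cumprod(...)`, the standard DDPM quantity `\bar\alpha_t = \prod_{i<=t}(1-\beta_i)` that scales the clean signal at step t. A dependency-free sketch of the "linear" branch and the resulting alpha-bar (the concrete beta range below is illustrative, not taken from the repo's configs):

```python
def linear_betas(beta_start, beta_end, T):
    # Mirrors np.linspace in the "linear" branch of get_beta_schedule.
    if T == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (T - 1)
    return [beta_start + i * step for i in range(T)]

def alpha_bar(betas):
    # bar-alpha_t = prod_{i<=t} (1 - beta_i): the DDPM signal scale at step t.
    out, acc = [], 1.0
    for b in betas:
        acc *= 1.0 - b
        out.append(acc)
    return out

betas = linear_betas(1e-7, 2e-3, 500)   # illustrative range
abar = alpha_bar(betas)
assert len(betas) == 500 and 0.0 < abar[-1] < abar[0]
```

`check_alphas` below guards against exactly the failure mode visible here: if the betas are too large for the number of steps, `abar` underflows to zero and the model's signal-to-noise bookkeeping breaks.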
323 | def modify_conformer(pos, edge_index, mask_rotate, torsion_updates, as_numpy=False):
324 | if type(pos) != np.ndarray: pos = pos.cpu().numpy()
325 | for idx_edge, e in enumerate(edge_index.cpu().numpy()):
326 | if torsion_updates[idx_edge] == 0:
327 | continue
328 | u, v = e[0], e[1]
329 |
330 | # check if need to reverse the edge, v should be connected to the part that gets rotated
331 | assert not mask_rotate[idx_edge, u]
332 | assert mask_rotate[idx_edge, v]
333 |
334 | rot_vec = pos[u] - pos[v] # convention: positive rotation if pointing inwards. NOTE: DIFFERENT FROM THE PAPER!
335 | rot_vec = rot_vec * torsion_updates[idx_edge] / np.linalg.norm(rot_vec) # idx_edge!
336 | rot_mat = R.from_rotvec(rot_vec).as_matrix()
337 |
338 | pos[mask_rotate[idx_edge]] = (pos[mask_rotate[idx_edge]] - pos[v]) @ rot_mat.T + pos[v]
339 |
340 | if not as_numpy: pos = torch.from_numpy(pos.astype(np.float32))
341 | return pos
342 |
343 | def get_transformation_mask(pyg_data):
344 |     ''' Build the subgraph (mask) distribution from rotatable bonds. '''
345 | G = to_networkx(pyg_data, to_undirected=False)
346 | to_rotate = []
347 | edge_set=[]
348 | edges = pyg_data.edge_index.T.numpy()
349 | for i in range(0, edges.shape[0]):
350 | # assert edges[i, 0] == edges[i+1, 1]
351 | edge = list(edges[i])
352 |         if edge[::-1] in edge_set:  # skip the reverse copy of an undirected edge
353 | continue
354 | edge_set.append(edge)
355 | G2 = G.to_undirected()
356 | G2.remove_edge(*edges[i])
357 | if not nx.is_connected(G2):
358 | l = list(sorted(nx.connected_components(G2), key=len))
359 | if len(l[0]) > 1:
360 | to_rotate.extend(l)
361 |
362 | if len(to_rotate)==0:
363 | to_rotate.append(list(G.nodes()))
364 |
365 | mask_rotate = np.zeros((len(G.nodes()), len(to_rotate)), dtype=bool)
366 | for i in range(len(to_rotate)):
367 | mask_rotate.T[i][list(to_rotate[i])] = True
368 | pyg_data.mask_subgraph = torch.tensor(mask_rotate)
369 | return pyg_data
370 |
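
`get_transformation_mask` builds the pool of candidate subgraphs: it deletes each bond from the undirected molecular graph, and whenever the graph falls apart into fragments that each contain more than one atom, the fragments become boolean node masks (the columns of `mask_subgraph`). A minimal dependency-free sketch of that idea, using a hand-rolled BFS in place of networkx:

```python
from collections import deque

def components(n, edges):
    # Connected components of an undirected graph via BFS.
    adj = {i: set() for i in range(n)}
    for u, v in edges:
        adj[u].add(v)
        adj[v].add(u)
    seen, comps = set(), []
    for s in range(n):
        if s in seen:
            continue
        comp, queue = {s}, deque([s])
        seen.add(s)
        while queue:
            u = queue.popleft()
            for w in adj[u]:
                if w not in seen:
                    seen.add(w)
                    comp.add(w)
                    queue.append(w)
        comps.append(comp)
    return comps

def candidate_subgraphs(n, edges):
    # Delete each bond; keep the fragments when the smallest has > 1 atom.
    pools = []
    for e in edges:
        rest = [x for x in edges if x != e]
        comps = sorted(components(n, rest), key=len)
        if len(comps) > 1 and len(comps[0]) > 1:
            pools.extend(comps)
    return pools

# Butane-like chain 0-1-2-3: only cutting the central bond gives two 2-atom fragments.
print(candidate_subgraphs(4, [(0, 1), (1, 2), (2, 3)]))  # [{0, 1}, {2, 3}]
```

As in the original, a molecule with no such bond falls back to a single mask covering all atoms.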
371 | def get_alpah_nodes_schedule(pyg_data, config):
372 | mask_subgraph=pyg_data.mask_subgraph
373 | selected_index= torch.randint(low=0,high=mask_subgraph.shape[-1], size=(config.num_diffusion_timesteps,))
374 | selected_nodes=mask_subgraph.index_select(-1, selected_index) # mask_subgraph[:,[1,2,5]]
375 |     # guarantee that noise is added to every node in the first few steps
376 |     selected_nodes[:, :min(3, config.num_diffusion_timesteps + 1)] = True
377 |
378 | betas = get_beta_schedule(
379 | beta_schedule=config.beta_schedule,
380 | beta_start=config.beta_start,
381 | beta_end=config.beta_end,
382 | num_diffusion_timesteps=config.num_diffusion_timesteps,
383 | )
384 | betas = torch.from_numpy(betas).float()
385 |
386 | betas_nodes = betas * selected_nodes
387 | alpha_nodes = (1. - betas_nodes).cumprod(dim=-1) # alpha_t for each node
388 | pyg_data.selected_nodes, pyg_data.alpha_nodes = selected_nodes, alpha_nodes
389 | return pyg_data
390 |
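
The per-node schedule above masks each beta with the node-selection indicator before taking the cumulative product, so a node only accumulates noise at the steps where its subgraph was selected: `alpha_nodes = (1 - betas * selected).cumprod(...)`. A scalar sketch of that line for a single node:

```python
def node_alpha_bar(betas, selected):
    # bar-alpha_t(node) = prod_{i<=t} (1 - beta_i * s_i), with s_i in {0, 1}
    out, acc = [], 1.0
    for beta, s in zip(betas, selected):
        acc *= 1.0 - beta * (1.0 if s else 0.0)
        out.append(acc)
    return out

# Node selected at steps 1 and 3 only: beta_2 leaves its alpha-bar unchanged,
# giving approximately [0.9, 0.9, 0.63].
print(node_alpha_bar([0.1, 0.2, 0.3], [True, False, True]))
```

An always-selected node recovers the ordinary DDPM `\bar\alpha_t`, which is exactly the GeoDiff-recovery special case handled by the `Recover_GeoDiff` tag.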
391 | def check_alphas(alphas):
392 |     for n, a in enumerate(alphas):
393 |         if a == 0:
394 |             print(f"Warning: cumulative alpha reaches zero at time step {n}")
395 |             break
396 |     print("The smallest alpha is", a.item())
397 |
398 |
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import py3Dmol
2 | from rdkit import Chem
3 |
4 |
5 | def visualize_mol(mol, size=(300, 300), surface=False, opacity=0.5):
6 | """Draw molecule in 3D
7 |
8 | Args:
9 | ----
10 | mol: rdMol, molecule to show
11 |         size: tuple(int, int), canvas size
12 |         surface: bool, whether to display the solvent-accessible surface (SAS)
13 |         opacity: float, opacity of the surface, in the range 0.0-1.0
16 | Return:
17 | ----
18 | viewer: py3Dmol.view, a class for constructing embedded 3Dmol.js views in ipython notebooks.
19 | """
20 | # assert style in ('line', 'stick', 'sphere', 'carton')
21 | mblock = Chem.MolToMolBlock(mol)
22 | viewer = py3Dmol.view(width=size[0], height=size[1])
23 | viewer.addModel(mblock, 'mol')
24 | viewer.setStyle({'stick':{}, 'sphere':{'radius':0.35}})
25 | if surface:
26 | viewer.addSurface(py3Dmol.SAS, {'opacity': opacity})
27 | viewer.zoomTo()
28 | return viewer
29 |
--------------------------------------------------------------------------------