├── .gitignore ├── LICENSE ├── README.md ├── assets ├── CCCCC[C@@H](C)CO_traj.gif ├── Geodiff500-CCCCC[C@@H](C)CO_traj.gif ├── bbbp_sider.png ├── case_study.png ├── clintox_bace.png ├── mask_diff_framework.png ├── subgdiff_framework.jpg ├── subgdiff_framework.png └── toxcast.png ├── checkpoints ├── qm9_200steps │ ├── 2000000.pt │ ├── qm9_200steps.yml │ └── samples │ │ ├── log_eval_samples_all.txt │ │ ├── samples_all.pkl │ │ ├── samples_all_covmat.csv │ │ └── samples_all_covmat.pkl ├── qm9_500steps │ ├── 2000000.pt │ ├── qm9_500steps.yml │ └── samples │ │ ├── log_eval_samples_all.txt │ │ ├── samples_all.pkl │ │ └── samples_all_covmat.csv └── sample_2024_02_05__15_50_46_SubGDiff500 │ └── log.txt ├── configs ├── qm9_200steps.yml └── qm9_500steps.yml ├── datasets ├── chem.py ├── qm9.py └── rdmol2data.py ├── env.yaml ├── eval_covmat.py ├── eval_prop.py ├── finetune ├── config.py ├── datasets │ ├── __init__.py │ ├── datasets_QM9.py │ └── rdmol2data.py └── splitters.py ├── finetune_qm9.py ├── models ├── common.py ├── encoder │ ├── __init__.py │ ├── coarse.py │ ├── edge.py │ ├── gin.py │ └── schnet.py ├── epsnet │ ├── __init__.py │ ├── attention.py │ ├── diffusion.py │ ├── dualenc.py │ └── dualencoder.py └── geometry.py ├── test.py ├── test.sh ├── train.py ├── train.sh ├── train_dist.py ├── utils ├── chem.py ├── common.py ├── datasets.py ├── evaluation │ └── covmat.py ├── misc.py ├── transforms.py └── visualize.py └── visualization.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | logs 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 SubGDiff authors. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SUBGDIFF: A Subgraph Diffusion Model to Improve Molecular Representation Learning 2 | 3 | The official implementation of the NeurIPS 2024 paper [SubgDiff: A Subgraph Diffusion Model to Improve Molecular Representation Learning](https://arxiv.org/abs/2405.05665). 4 | 5 |
6 | framework 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | ## Environments 16 | 17 | 18 | ### Dependency 19 | ``` 20 | python=3.7 21 | 22 | pytorch 1.11.0 py3.7_cuda11.3_cudnn8.2.0_0 23 | torch-cluster 1.6.0 pypi_0 24 | torch-geometric 1.7.2 pypi_0 25 | torch-scatter 2.0.9 pypi_0 26 | torch-sparse 0.6.13 pypi_0 27 | 28 | rdkit 2023.3.2 pypi_0 29 | ``` 30 | Other useful packages in our environment: see `env.txt` / `env.yaml` 31 | 32 | 33 | ## Dataset 34 | 35 | The dataset can be downloaded directly from [https://zenodo.org/records/10616999](https://zenodo.org/records/10616999) 36 | 37 | ### Official Dataset 38 | 39 | 40 | The official raw GEOM dataset is available [[here]](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/JNGTDF). 41 | 42 | 43 | ### Prepare your own GEOM dataset from scratch (optional) 44 | 45 | You can also download the original full GEOM dataset and prepare your own data split. A guide is available at the previous work GeoDiff's [[github page]](https://github.com/MinkaiXu/GeoDiff). 46 | 47 | ## Training 48 | 49 | All hyper-parameters and training details are provided in the config files (`./configs/*.yml`); feel free to tune these parameters. 50 | 51 | You can train the model with the following commands: 52 | 53 | - QM9 dataset 54 | ```bash 55 | # Default settings 56 | 57 | sh train.sh 58 | 59 | ``` 60 | 61 | ## Generation 62 | 63 | 64 | 65 | 66 | 67 | You can generate conformations for the entire test set, or a part of it. 68 | 69 | One can use the following command to generate new samples: 70 | ```bash 71 | sh test.sh 72 | ``` 73 | 74 | Here `start_idx` and `end_idx` indicate the range of the test set that we want to use. All hyper-parameters related to sampling can be set in `test.py`. Following GeoDiff, we use `start_idx=800` and `end_idx=1000` in our experiments.
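During sampling with `same_mask_noisy`, only the atoms in the currently selected subgraph are updated at each timestep; the remaining coordinates stay fixed. A minimal NumPy sketch of such a masked update step (illustrative only: the function name and exact parameterization are ours, and the real implementation is the class `SubgraphNoiseTransform` in `utils/transforms.py`):

```python
import numpy as np

def masked_diffusion_step(x, node_mask, beta_t, rng=None):
    """Apply one noising step to the coordinates of masked atoms only.

    x:         (N, 3) atomic coordinates
    node_mask: (N,)   1 = atom belongs to the sampled subgraph, 0 = frozen
    beta_t:    float  noise level at this timestep
    """
    if rng is None:
        rng = np.random.default_rng(0)
    mask = node_mask.astype(float)[:, None]                 # (N, 1), broadcasts over xyz
    noised = np.sqrt(1.0 - beta_t) * x + np.sqrt(beta_t) * rng.standard_normal(x.shape)
    # Atoms outside the subgraph keep their coordinates unchanged.
    return mask * noised + (1.0 - mask) * x
```

The same masking idea applies in reverse: at each denoising step, only the coordinates of the selected subgraph are moved by the predicted noise.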
75 | 76 | Conformations of some drug-like molecules generated by SubGDiff can be seen in `visualization.ipynb` 77 | 78 | ## Evaluation 79 | 80 | After generating conformations following the above commands, the results of all benchmark tasks can be calculated based on the generated data. 81 | 82 | ### Conformation Generation 83 | 84 | The `COV` and `MAT` scores on the GEOM datasets can be calculated using the following commands: 85 | 86 | Evaluation on generated samples: 87 | ```bash 88 | python eval_covmat.py checkpoints/qm9_500steps/samples/samples_all.pkl 89 | python eval_covmat.py checkpoints/qm9_200steps/samples/samples_all.pkl 90 | ``` 91 | 92 | 93 | 94 | 95 | 96 | 97 | ## Visualizing molecules with PyMol 98 | 99 | - molecule visualization: 100 | 101 | Run `visualization.ipynb` 102 | 103 | - molecule trajectory generation (`--save_traj`): 104 | 105 | ```bash 106 | python test.py --ckpt checkpoints/qm9_500steps/2000000.pt --config checkpoints/qm9_500steps/qm9_500steps.yml --test_set './data/GEOM/QM9/test_data_1k.pkl' --start_idx 800 --end_idx 1000 --sampling_type same_mask_noisy --n_steps 500 --device cuda:1 --w_global 0.1 --clip 1000 --clip_local 20 --global_start_sigma 5 --tag SubGDiff500 --save_traj 107 | ``` 108 | 109 | Run `visualization.ipynb`. 110 | 111 | 112 | 113 | 114 | 115 | ## Subgraph diffusion process 116 | 117 | Please see the class `SubgraphNoiseTransform` in `utils/transforms.py`. 118 | 119 | 120 | ## Visualization 121 | ### Sampling trajectory of SubGDiff (ours) 122 | The video shows that the denoising network denoises only the atomic coordinates of a subgraph at each timestep during the sampling process. 123 | 124 | 500steps 125 | 126 | ### Sampling trajectory of GeoDiff 127 | 128 | The video shows that the denoising network denoises all atomic coordinates at each timestep during the sampling process.
129 | 130 | 500steps 131 | 132 | ### Visualization of final sampling results 133 | The following figures come from the final step of the sampling trajectories above (SubGDiff and GeoDiff). From the figures, we can see that our SubGDiff generates conformations more similar to the ground truth. 134 | 135 | 136 | 137 | ## Citation 138 | ``` 139 | @inproceedings{zhang2024subgdiff, 140 | title={SubgDiff: A Subgraph Diffusion Model to Improve Molecular Representation Learning}, 141 | author={Zhang, Jiying and Liu, Zijing and Wang, Yu and Feng, Bin and Li, Yu}, 142 | booktitle={Advances in Neural Information Processing Systems}, 143 | year={2024} 144 | } 145 | ``` 146 | 147 | 148 | ## Acknowledgement 149 | 150 | This repo is built upon the previous work GeoDiff's [[codebase]](https://github.com/MinkaiXu/GeoDiff). 151 | -------------------------------------------------------------------------------- /assets/CCCCC[C@@H](C)CO_traj.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/CCCCC[C@@H](C)CO_traj.gif -------------------------------------------------------------------------------- /assets/Geodiff500-CCCCC[C@@H](C)CO_traj.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/Geodiff500-CCCCC[C@@H](C)CO_traj.gif -------------------------------------------------------------------------------- /assets/bbbp_sider.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/bbbp_sider.png -------------------------------------------------------------------------------- /assets/case_study.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/case_study.png -------------------------------------------------------------------------------- /assets/clintox_bace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/clintox_bace.png -------------------------------------------------------------------------------- /assets/mask_diff_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/mask_diff_framework.png -------------------------------------------------------------------------------- /assets/subgdiff_framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/subgdiff_framework.jpg -------------------------------------------------------------------------------- /assets/subgdiff_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/subgdiff_framework.png -------------------------------------------------------------------------------- /assets/toxcast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/assets/toxcast.png -------------------------------------------------------------------------------- /checkpoints/qm9_200steps/2000000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/checkpoints/qm9_200steps/2000000.pt 
-------------------------------------------------------------------------------- /checkpoints/qm9_200steps/qm9_200steps.yml: -------------------------------------------------------------------------------- 1 | model: 2 | type: subgraph_diffusion 3 | network: dualenc 4 | hidden_dim: 128 5 | num_convs: 6 6 | num_convs_local: 4 7 | cutoff: 10.0 8 | mlp_act: relu 9 | beta_schedule: sigmoid 10 | beta_start: 1.e-7 11 | beta_end: 5.e-2 12 | num_diffusion_timesteps: 200 13 | edge_order: 3 14 | edge_encoder: mlp 15 | smooth_conv: false 16 | same_mask_steps: 10 17 | mask_pred: MLP # 18 | 19 | train: 20 | seed: 2021 21 | batch_size: 64 22 | val_freq: 2000 23 | max_iters: 2000000 24 | max_grad_norm: 10000.00 25 | anneal_power: 2.0 26 | optimizer: 27 | type: adam 28 | lr: 1.e-3 29 | weight_decay: 0. 30 | beta1: 0.95 31 | beta2: 0.999 32 | scheduler: 33 | type: plateau 34 | factor: 0.6 35 | patience: 30 36 | 37 | dataset: 38 | train: ./data/GEOM/QM9/train_data_40k_subgraph.pkl 39 | val: ./data/GEOM/QM9/val_data_5k_subgraph.pkl 40 | test: ./data/GEOM/QM9/test_data_1k.pkl 41 | -------------------------------------------------------------------------------- /checkpoints/qm9_200steps/samples/log_eval_samples_all.txt: -------------------------------------------------------------------------------- 1 | [2024-02-04 16:07:52,500::eval::INFO] Loading results: checkpoints/qm9_200steps/samples/samples_all.pkl 2 | [2024-02-04 16:07:52,671::eval::INFO] Total: 200 3 | [2024-02-04 16:07:52,718::eval::INFO] Filtered: 196 / 200 4 | [2024-02-04 16:11:31,018::eval::INFO] 5 | COV-R_mean COV-R_median COV-R_std COV-P_mean COV-P_median COV-P_std 6 | 0.05 0.000101 0.000000 0.001164 0.000051 0.000000 0.000582 7 | 0.10 0.066095 0.020000 0.121742 0.024568 0.008969 0.047813 8 | 0.15 0.251479 0.171766 0.240313 0.095490 0.062149 0.108542 9 | 0.20 0.394886 0.338929 0.269503 0.161195 0.128679 0.144146 10 | 0.25 0.495578 0.445437 0.270431 0.213917 0.174318 0.162835 11 | 0.30 0.575141 0.536866 0.261140 0.259591 
0.215421 0.175224 12 | 0.35 0.642270 0.614450 0.241695 0.299245 0.251838 0.183180 13 | 0.40 0.706840 0.700558 0.218976 0.340382 0.297155 0.190494 14 | 0.45 0.777926 0.811150 0.193104 0.397087 0.385711 0.202436 15 | 0.50 0.855279 0.889899 0.154255 0.477603 0.458864 0.213210 16 | 0.55 0.928285 0.963525 0.109881 0.579527 0.548333 0.220027 17 | 0.60 0.970038 1.000000 0.083130 0.675407 0.706145 0.219565 18 | 0.65 0.986554 1.000000 0.074638 0.755032 0.820000 0.209216 19 | 0.70 0.992985 1.000000 0.071559 0.815814 0.888153 0.196520 20 | 0.75 0.994636 1.000000 0.070967 0.857790 0.939757 0.186596 21 | 0.80 0.994974 1.000000 0.069723 0.888436 0.969891 0.180300 22 | 0.85 0.995051 1.000000 0.069115 0.908737 0.988604 0.175182 23 | 0.90 0.995116 1.000000 0.068201 0.919555 0.992726 0.170642 24 | 0.95 0.995225 1.000000 0.066679 0.925715 0.994135 0.167369 25 | 1.00 0.995378 1.000000 0.064548 0.930677 0.995702 0.163408 26 | 1.05 0.995508 1.000000 0.062721 0.935447 0.997041 0.158606 27 | 1.10 0.995617 1.000000 0.061199 0.939520 1.000000 0.152420 28 | 1.15 0.995726 1.000000 0.059676 0.946509 1.000000 0.141974 29 | 1.20 0.995770 1.000000 0.059067 0.956307 1.000000 0.128556 30 | 1.25 0.995857 1.000000 0.057849 0.963814 1.000000 0.115558 31 | 1.30 0.996054 1.000000 0.055109 0.971617 1.000000 0.100980 32 | 1.35 0.996206 1.000000 0.052978 0.977215 1.000000 0.090921 33 | 1.40 0.996293 1.000000 0.051760 0.980891 1.000000 0.085915 34 | 1.45 0.996359 1.000000 0.050847 0.983148 1.000000 0.083228 35 | 1.50 0.996381 1.000000 0.050542 0.984951 1.000000 0.081133 36 | 1.55 0.996555 1.000000 0.048106 0.986234 1.000000 0.078929 37 | 1.60 0.996708 1.000000 0.045975 0.987412 1.000000 0.075759 38 | 1.65 0.996860 1.000000 0.043844 0.988513 1.000000 0.073259 39 | 1.70 0.996991 1.000000 0.042017 0.989432 1.000000 0.070975 40 | 1.75 0.997035 1.000000 0.041408 0.990098 1.000000 0.069999 41 | 1.80 0.997209 1.000000 0.038972 0.990665 1.000000 0.069282 42 | 1.85 0.997362 1.000000 0.036841 0.990981 1.000000 
0.068877 43 | 1.90 0.997449 1.000000 0.035623 0.991103 1.000000 0.068748 44 | 1.95 0.997493 1.000000 0.035014 0.991254 1.000000 0.068592 45 | 2.00 0.997623 1.000000 0.033187 0.991325 1.000000 0.068536 46 | 2.05 0.997820 1.000000 0.030447 0.991533 1.000000 0.068376 47 | 2.10 0.997929 1.000000 0.028925 0.991670 1.000000 0.068318 48 | 2.15 0.998038 1.000000 0.027402 0.991785 1.000000 0.068218 49 | 2.20 0.998147 1.000000 0.025880 0.991888 1.000000 0.068131 50 | 2.25 0.998278 1.000000 0.024053 0.992025 1.000000 0.068077 51 | 2.30 0.998452 1.000000 0.021617 0.992133 1.000000 0.068008 52 | 2.35 0.998517 1.000000 0.020704 0.992277 1.000000 0.067936 53 | 2.40 0.998626 1.000000 0.019182 0.992367 1.000000 0.067877 54 | 2.45 0.998757 1.000000 0.017355 0.992443 1.000000 0.067853 55 | 2.50 0.998823 1.000000 0.016441 0.992569 1.000000 0.067770 56 | 2.55 0.998910 1.000000 0.015224 0.992771 1.000000 0.067659 57 | 2.60 0.998953 1.000000 0.014615 0.992986 1.000000 0.067406 58 | 2.65 0.999019 1.000000 0.013701 0.993105 1.000000 0.067358 59 | 2.70 0.999041 1.000000 0.013397 0.993267 1.000000 0.067290 60 | 2.75 0.999128 1.000000 0.012179 0.993411 1.000000 0.067245 61 | 2.80 0.999215 1.000000 0.010961 0.993485 1.000000 0.067239 62 | 2.85 0.999280 1.000000 0.010048 0.993523 1.000000 0.067234 63 | 2.90 0.999389 1.000000 0.008525 0.993642 1.000000 0.067200 64 | 2.95 0.999455 1.000000 0.007612 0.993802 1.000000 0.067145 65 | 3.00 0.999520 1.000000 0.006698 0.993957 1.000000 0.066967 66 | [2024-02-04 16:11:31,019::eval::INFO] MAT-R_mean: 0.2994 | MAT-R_median: 0.3033 | MAT-R_std 0.1516 67 | [2024-02-04 16:11:31,019::eval::INFO] MAT-P_mean: 0.6971 | MAT-P_median: 0.5118 | MAT-P_std 2.5074 68 | -------------------------------------------------------------------------------- /checkpoints/qm9_200steps/samples/samples_all.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/checkpoints/qm9_200steps/samples/samples_all.pkl -------------------------------------------------------------------------------- /checkpoints/qm9_200steps/samples/samples_all_covmat.csv: -------------------------------------------------------------------------------- 1 | ,COV-R_mean,COV-R_median,COV-R_std,COV-P_mean,COV-P_median,COV-P_std 2 | 0.05,0.00010115094409681809,0.0,0.0011640228873020942,5.0575472048409045e-05,0.0,0.0005820114436510471 3 | 0.1,0.06609495767865171,0.02,0.12174158801889531,0.024568470531056104,0.00896879021879022,0.04781286790066844 4 | 0.15000000000000002,0.25147861556229295,0.17176573426573427,0.24031302256830991,0.09549018442721859,0.06214887640449438,0.10854193223018146 5 | 0.2,0.39488560672391226,0.33892896781354054,0.26950262874024494,0.1611949990040838,0.12867924528301888,0.14414591443505212 6 | 0.25,0.495578308485712,0.4454365079365079,0.27043135026333043,0.2139171955421679,0.17431772237196763,0.16283464245645138 7 | 0.3,0.5751414020096989,0.5368659420289855,0.261139857784902,0.25959096564817863,0.21542104276069018,0.17522366323829522 8 | 0.35000000000000003,0.6422702023841537,0.6144499762920816,0.24169531183888102,0.2992450830655438,0.25183823529411764,0.18317981223346985 9 | 0.4,0.7068397136032962,0.7005577005577006,0.21897640291850254,0.3403816817546149,0.2971547706647044,0.1904941003097485 10 | 0.45,0.7779260397839731,0.8111499611499611,0.19310368616606913,0.3970872122656049,0.3857105538140021,0.20243626841896464 11 | 0.5,0.8552794707902401,0.8898989898989899,0.15425467404097734,0.4776029383577201,0.4588638195004029,0.21321046987652306 12 | 0.55,0.9282851301968171,0.9635254988913525,0.10988135747472798,0.5795268208031766,0.5483333333333333,0.22002722750626835 13 | 0.6000000000000001,0.9700379104922083,1.0,0.08312952507674312,0.6754073596234932,0.7061450381679388,0.21956497708697012 14 | 
0.6500000000000001,0.9865537800384889,1.0,0.07463815907443776,0.7550323049100166,0.82,0.20921620553741033 15 | 0.7000000000000001,0.9929853737304797,1.0,0.07155868754577098,0.8158136099120913,0.8881530537159676,0.19651985338361522 16 | 0.7500000000000001,0.9946358770719277,1.0,0.07096688870463422,0.8577904550128883,0.9397568499364907,0.18659572585478243 17 | 0.8,0.994974060757471,1.0,0.06972292755097403,0.8884358294495973,0.96989121989122,0.18030012562785835 18 | 0.8500000000000001,0.9950505843362987,1.0,0.06911482840431585,0.9087371749305106,0.9886043165467626,0.17518222913895623 19 | 0.9000000000000001,0.9951159951159951,1.0,0.06820141657518429,0.9195547405461229,0.9927264071096065,0.17064175753164304 20 | 0.9500000000000001,0.9952250130821559,1.0,0.06667906352663068,0.9257145168578484,0.9941348469212246,0.1673688687041287 21 | 1.0,0.9953776382347811,1.0,0.06454776925865623,0.9306773625960487,0.9957017414403778,0.1634075896624802 22 | 1.05,0.9955084597941742,1.0,0.06272094560039261,0.935446725107578,0.9970410057078825,0.1586063789618342 23 | 1.1,0.9956174777603348,1.0,0.06119859255183924,0.9395204124683095,1.0,0.152419504372303 24 | 1.1500000000000001,0.9957264957264957,1.0,0.05967623950328628,0.9465093686998486,1.0,0.14197445548432253 25 | 1.2000000000000002,0.9957701029129601,1.0,0.059067298283864894,0.9563072992756991,1.0,0.12855642503738074 26 | 1.2500000000000002,0.9958573172858887,1.0,0.05784941584502237,0.9638137301251927,1.0,0.1155583754771393 27 | 1.3,0.9960535496249782,1.0,0.055109180357626616,0.9716174237976887,1.0,0.1009795540939641 28 | 1.35,0.9962061747776034,1.0,0.05297788608965196,0.977215272414187,1.0,0.09092131674120832 29 | 1.4000000000000001,0.996293389150532,1.0,0.051760003650809246,0.9808910969497924,1.0,0.08591510637395192 30 | 1.4500000000000002,0.9963587999302285,1.0,0.05084659182167728,0.983148122013181,1.0,0.08322779138667874 31 | 1.5000000000000002,0.9963806035234607,1.0,0.050542121211966815,0.9849507843993025,1.0,0.0811330566896536 32 
| 1.55,0.996555032269318,1.0,0.0481063563342816,0.9862340328117838,1.0,0.07892924040163384 33 | 1.6,0.9967076574219432,1.0,0.04597506206630725,0.9874119019936729,1.0,0.07575912823805947 34 | 1.6500000000000001,0.9968602825745683,1.0,0.04384376779833257,0.9885129265244994,1.0,0.07325930227652488 35 | 1.7000000000000002,0.9969911041339613,1.0,0.04201694414006876,0.9894323691498574,1.0,0.07097455659098532 36 | 1.7500000000000002,0.9970347113204255,1.0,0.04140800292064751,0.9900979348481429,1.0,0.06999894670956665 37 | 1.8,0.9972091400662829,1.0,0.03897223804296238,0.9906648938221057,1.0,0.06928199405144711 38 | 1.85,0.9973617652189081,1.0,0.03684094377498774,0.9909808483778288,1.0,0.06887669335797462 39 | 1.9000000000000001,0.9974489795918368,1.0,0.03562306133614526,0.9911032797610435,1.0,0.06874801464285607 40 | 1.9500000000000002,0.997492586778301,1.0,0.03501412011672394,0.9912535044606053,1.0,0.0685915841352487 41 | 2.0,0.9976234083376941,1.0,0.03318729645846007,0.9913251162039075,1.0,0.06853646887305148 42 | 2.05,0.9978196406767835,1.0,0.030447060971064418,0.9915327256362755,1.0,0.06837629507239903 43 | 2.1,0.9979286586429443,1.0,0.028924707922511186,0.9916700226830903,1.0,0.06831802519418569 44 | 2.15,0.9980376766091051,1.0,0.027402354873957837,0.9917853570153623,1.0,0.06821822152178604 45 | 2.1999999999999997,0.998146694575266,1.0,0.025880001825404626,0.9918878868717073,1.0,0.06813082872340301 46 | 2.25,0.998277516134659,1.0,0.0240531781671408,0.992025080354741,1.0,0.06807688759329783 47 | 2.3,0.9984519448805163,1.0,0.021617413289455645,0.9921326737180102,1.0,0.06800773717756628 48 | 2.35,0.9985173556602128,1.0,0.020704001460323754,0.9922768728993338,1.0,0.0679355595484163 49 | 2.4,0.9986263736263736,1.0,0.019181648411770544,0.9923669445512545,1.0,0.06787658534090224 50 | 2.45,0.9987571951857667,1.0,0.017354824753506623,0.9924427627055161,1.0,0.0678528920931322 51 | 2.5,0.9988226059654631,1.0,0.016441412924374683,0.9925691446767388,1.0,0.06776985632542669 52 | 
2.55,0.9989098203383917,1.0,0.015223530485532205,0.9927706967346303,1.0,0.0676585945729409 53 | 2.6,0.998953427524856,1.0,0.0146145892661109,0.992985656238477,1.0,0.067406330877912 54 | 2.65,0.9990188383045526,1.0,0.01370117743697892,0.9931048123061254,1.0,0.06735829846344064 55 | 2.7,0.9990406418977847,1.0,0.013396706827268361,0.9932666500729165,1.0,0.06729041498263391 56 | 2.75,0.9991278562707133,1.0,0.012178824388425704,0.9934106002392857,1.0,0.0672454049817615 57 | 2.8,0.999215070643642,1.0,0.010960941949583141,0.9934850452644369,1.0,0.0672394728913095 58 | 2.85,0.9992804814233386,1.0,0.010047530120451266,0.9935231911849659,1.0,0.06723375971009532 59 | 2.9,0.9993894993894993,1.0,0.008525177071898033,0.9936415499477965,1.0,0.06720012074993871 60 | 2.95,0.999454910169196,1.0,0.007611765242766104,0.9938015207107023,1.0,0.06714470011672279 61 | 3.0,0.9995203209488924,1.0,0.006698353413634183,0.9939568186459425,1.0,0.06696721073132718 62 | -------------------------------------------------------------------------------- /checkpoints/qm9_200steps/samples/samples_all_covmat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/checkpoints/qm9_200steps/samples/samples_all_covmat.pkl -------------------------------------------------------------------------------- /checkpoints/qm9_500steps/2000000.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/checkpoints/qm9_500steps/2000000.pt -------------------------------------------------------------------------------- /checkpoints/qm9_500steps/qm9_500steps.yml: -------------------------------------------------------------------------------- 1 | model: 2 | type: subgraph_diffusion 3 | network: dualenc 4 | hidden_dim: 128 5 | num_convs: 6 6 | num_convs_local: 4 7 | cutoff: 10.0 8 | 
mlp_act: relu 9 | beta_schedule: sigmoid 10 | beta_start: 1.e-7 11 | beta_end: 2.e-2 12 | num_diffusion_timesteps: 500 13 | edge_order: 3 14 | edge_encoder: mlp 15 | smooth_conv: false 16 | same_mask_steps: 25 17 | mask_pred: MLP 18 | 19 | train: 20 | seed: 2021 21 | batch_size: 64 22 | val_freq: 2000 23 | max_iters: 2000000 24 | max_grad_norm: 10000.00 25 | anneal_power: 2.0 26 | optimizer: 27 | type: adam 28 | lr: 1.e-3 29 | weight_decay: 0. 30 | beta1: 0.95 31 | beta2: 0.999 32 | scheduler: 33 | type: plateau 34 | factor: 0.6 35 | patience: 30 36 | 37 | dataset: 38 | train: ./data/GEOM/QM9/train_data_40k_subgraph.pkl 39 | val: ./data/GEOM/QM9/val_data_5k_subgraph.pkl 40 | test: ./data/GEOM/QM9/test_data_1k.pkl 41 | -------------------------------------------------------------------------------- /checkpoints/qm9_500steps/samples/log_eval_samples_all.txt: -------------------------------------------------------------------------------- 1 | [2024-02-04 15:46:59,065::eval::INFO] Loading results: checkpoints/qm9_500steps/samples/samples_all.pkl 2 | [2024-02-04 15:46:59,370::eval::INFO] Total: 200 3 | [2024-02-04 15:46:59,421::eval::INFO] Filtered: 196 / 200 4 | [2024-02-04 15:50:38,265::eval::INFO] 5 | COV-R_mean COV-R_median COV-R_std COV-P_mean COV-P_median COV-P_std 6 | 0.05 0.032712 0.009060 0.071551 0.014927 0.004525 0.034089 7 | 0.10 0.271354 0.196429 0.234439 0.119802 0.083333 0.134013 8 | 0.15 0.428807 0.352484 0.277974 0.190273 0.155627 0.165403 9 | 0.20 0.528011 0.481637 0.279454 0.237885 0.209299 0.178689 10 | 0.25 0.602192 0.586570 0.261421 0.279192 0.240906 0.186384 11 | 0.30 0.664374 0.676175 0.240212 0.313114 0.267225 0.189866 12 | 0.35 0.720297 0.726223 0.219854 0.343309 0.288202 0.193520 13 | 0.40 0.774024 0.784423 0.196447 0.376003 0.347277 0.198925 14 | 0.45 0.833441 0.863962 0.167611 0.426181 0.407452 0.209374 15 | 0.50 0.897786 0.941707 0.132740 0.500399 0.483076 0.218857 16 | 0.55 0.953916 0.990521 0.094453 0.607295 0.593957 0.227172 17 | 0.60 
0.981479 1.000000 0.072338 0.706846 0.753186 0.218912 18 | 0.65 0.992245 1.000000 0.063863 0.780884 0.854572 0.206989 19 | 0.70 0.995319 1.000000 0.059396 0.834365 0.926402 0.193845 20 | 0.75 0.996082 1.000000 0.053294 0.870825 0.950993 0.184226 21 | 0.80 0.996671 1.000000 0.045077 0.898107 0.982091 0.177346 22 | 0.85 0.997151 1.000000 0.038382 0.913005 0.991560 0.173059 23 | 0.90 0.997950 1.000000 0.028620 0.921745 0.994490 0.168994 24 | 0.95 0.998452 1.000000 0.021617 0.927871 0.996598 0.162295 25 | 1.00 0.998779 1.000000 0.017050 0.933256 1.000000 0.154835 26 | 1.05 0.999062 1.000000 0.013092 0.937535 1.000000 0.148893 27 | 1.10 0.999280 1.000000 0.010048 0.939898 1.000000 0.145510 28 | 1.15 0.999499 1.000000 0.007003 0.948699 1.000000 0.129947 29 | 1.20 0.999564 1.000000 0.006089 0.958744 1.000000 0.113603 30 | 1.25 0.999564 1.000000 0.006089 0.967544 1.000000 0.094783 31 | 1.30 0.999738 1.000000 0.003654 0.975255 1.000000 0.075496 32 | 1.35 0.999782 1.000000 0.003045 0.980040 1.000000 0.064290 33 | 1.40 0.999847 1.000000 0.002131 0.983131 1.000000 0.058868 34 | 1.45 0.999935 1.000000 0.000913 0.984870 1.000000 0.054976 35 | 1.50 0.999935 1.000000 0.000913 0.986283 1.000000 0.052289 36 | 1.55 0.999935 1.000000 0.000913 0.987467 1.000000 0.048353 37 | 1.60 0.999956 1.000000 0.000609 0.988913 1.000000 0.041454 38 | 1.65 0.999956 1.000000 0.000609 0.990281 1.000000 0.036710 39 | 1.70 1.000000 1.000000 0.000000 0.991068 1.000000 0.034609 40 | 1.75 1.000000 1.000000 0.000000 0.991494 1.000000 0.033564 41 | 1.80 1.000000 1.000000 0.000000 0.991963 1.000000 0.032702 42 | 1.85 1.000000 1.000000 0.000000 0.992254 1.000000 0.032477 43 | 1.90 1.000000 1.000000 0.000000 0.992391 1.000000 0.032234 44 | 1.95 1.000000 1.000000 0.000000 0.992674 1.000000 0.031985 45 | 2.00 1.000000 1.000000 0.000000 0.992753 1.000000 0.031887 46 | 2.05 1.000000 1.000000 0.000000 0.992848 1.000000 0.031793 47 | 2.10 1.000000 1.000000 0.000000 0.992954 1.000000 0.031578 48 | 2.15 1.000000 
1.000000 0.000000 0.992991 1.000000 0.031521 49 | 2.20 1.000000 1.000000 0.000000 0.993108 1.000000 0.031292 50 | 2.25 1.000000 1.000000 0.000000 0.993329 1.000000 0.031037 51 | 2.30 1.000000 1.000000 0.000000 0.993635 1.000000 0.030488 52 | 2.35 1.000000 1.000000 0.000000 0.993873 1.000000 0.030190 53 | 2.40 1.000000 1.000000 0.000000 0.994005 1.000000 0.030056 54 | 2.45 1.000000 1.000000 0.000000 0.994271 1.000000 0.029820 55 | 2.50 1.000000 1.000000 0.000000 0.994462 1.000000 0.029347 56 | 2.55 1.000000 1.000000 0.000000 0.994680 1.000000 0.028992 57 | 2.60 1.000000 1.000000 0.000000 0.994878 1.000000 0.028735 58 | 2.65 1.000000 1.000000 0.000000 0.995038 1.000000 0.028563 59 | 2.70 1.000000 1.000000 0.000000 0.995204 1.000000 0.028403 60 | 2.75 1.000000 1.000000 0.000000 0.995414 1.000000 0.028231 61 | 2.80 1.000000 1.000000 0.000000 0.995573 1.000000 0.028092 62 | 2.85 1.000000 1.000000 0.000000 0.995729 1.000000 0.027932 63 | 2.90 1.000000 1.000000 0.000000 0.995884 1.000000 0.027833 64 | 2.95 1.000000 1.000000 0.000000 0.996047 1.000000 0.027739 65 | 3.00 1.000000 1.000000 0.000000 0.996186 1.000000 0.027643 66 | [2024-02-04 15:50:38,265::eval::INFO] MAT-R_mean: 0.2417 | MAT-R_median: 0.2449 | MAT-R_std 0.1080 67 | [2024-02-04 15:50:38,265::eval::INFO] MAT-P_mean: 0.5571 | MAT-P_median: 0.4921 | MAT-P_std 0.9973 68 | -------------------------------------------------------------------------------- /checkpoints/qm9_500steps/samples/samples_all.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/checkpoints/qm9_500steps/samples/samples_all.pkl -------------------------------------------------------------------------------- /checkpoints/qm9_500steps/samples/samples_all_covmat.csv: -------------------------------------------------------------------------------- 1 | ,COV-R_mean,COV-R_median,COV-R_std,COV-P_mean,COV-P_median,COV-P_std 2 | 
0.05,0.032711709588828866,0.009060127231684794,0.07155121746047126,0.01492675832356874,0.004524979524979525,0.034089220735110545 3 | 0.1,0.27135416895301273,0.19642857142857142,0.23443867499848575,0.11980226736434715,0.08333333333333333,0.13401287470087428 4 | 0.15000000000000002,0.4288073021012164,0.35248447204968947,0.27797417934518226,0.190273354171487,0.15562678062678065,0.1654033279682946 5 | 0.2,0.5280113926962319,0.48163746630727766,0.2794543928792841,0.23788510027403467,0.2092993951612903,0.17868862397169458 6 | 0.25,0.602192453468659,0.5865704772475028,0.26142063747855177,0.27919161044555635,0.24090608465608465,0.18638428676495009 7 | 0.3,0.6643740008489631,0.6761750202858747,0.24021151185365944,0.31311439089352006,0.2672250101722501,0.18986576768280886 8 | 0.35000000000000003,0.7202965146420955,0.726222968717195,0.21985373768490118,0.3433087085641054,0.28820181362145475,0.19352008655771624 9 | 0.4,0.7740240266398146,0.7844234079173837,0.19644685529177316,0.3760026786921755,0.34727658186562294,0.198924632306208 10 | 0.45,0.8334414964196138,0.8639622641509435,0.16761055074163497,0.42618096533426403,0.4074519230769231,0.2093742760990014 11 | 0.5,0.8977858561714935,0.9417073688681248,0.13274028496424486,0.5003986958275349,0.48307560137457045,0.21885739334926382 12 | 0.55,0.9539161546591544,0.9905211141060197,0.0944533687904892,0.6072945027695789,0.5939568231957137,0.2271723329068604 13 | 0.6000000000000001,0.9814789414566789,1.0,0.07233786517678083,0.7068456567617836,0.7531863051355742,0.2189121009131429 14 | 0.6500000000000001,0.9922447513430593,1.0,0.06386335173301483,0.7808842672110684,0.8545716727195044,0.20698890270111314 15 | 0.7000000000000001,0.9953187829487754,1.0,0.059396109887135985,0.834365483392989,0.9264015151515151,0.19384478512427444 16 | 0.7500000000000001,0.9960823303680446,1.0,0.053294101359138465,0.870824870090479,0.9509934248141796,0.18422579985422147 17 | 
0.8,0.9966710273853132,1.0,0.04507686955175547,0.8981065306541568,0.9820914614646626,0.17734639493916543 18 | 0.8500000000000001,0.9971507064364208,1.0,0.038382447594486065,0.9130047864160271,0.9915598290598291,0.17305904858373555 19 | 0.9000000000000001,0.9979504622361764,1.0,0.02862023731280038,0.9217454042839928,0.9944899817850638,0.1689944096995206 20 | 0.9500000000000001,0.9984519448805163,1.0,0.021617413289455645,0.9278712160469375,0.9965984820436875,0.16229528840360846 21 | 1.0,0.9987789987789989,1.0,0.017050354143795975,0.9332559880655352,1.0,0.1548352915716108 22 | 1.05,0.999062445491017,1.0,0.013092236217557642,0.9375348746068178,1.0,0.14889251457048863 23 | 1.1,0.9992804814233386,1.0,0.010047530120451266,0.939897913659391,1.0,0.14550950610767346 24 | 1.1500000000000001,0.9994985173556603,1.0,0.007002824023344788,0.9486988951016855,1.0,0.12994718273100558 25 | 1.2000000000000002,0.9995639281353567,1.0,0.006089412194212856,0.95874378154494,1.0,0.11360274839243543 26 | 1.2500000000000002,0.9995639281353567,1.0,0.006089412194212856,0.967544428669111,1.0,0.0947830177806354 27 | 1.3,0.999738356881214,1.0,0.003653647316527728,0.9752545458407527,1.0,0.07549569659693457 28 | 1.35,0.9997819640676784,1.0,0.003044706097106426,0.9800399324238473,1.0,0.06429047461466812 29 | 1.4000000000000001,0.9998473748473747,1.0,0.002131294267974505,0.9831307753114109,1.0,0.058868116271158594 30 | 1.4500000000000002,0.9999345892203035,1.0,0.000913411829131929,0.984870393452086,1.0,0.054975545894377496 31 | 1.5000000000000002,0.9999345892203035,1.0,0.000913411829131929,0.9862833473340975,1.0,0.05228913274126952 32 | 1.55,0.9999345892203035,1.0,0.000913411829131929,0.9874672713227176,1.0,0.048352809168404245 33 | 1.6,0.9999563928135358,1.0,0.0006089412194212831,0.9889134634166011,1.0,0.04145438560195448 34 | 1.6500000000000001,0.9999563928135358,1.0,0.0006089412194212831,0.9902806046839789,1.0,0.0367100097724998 35 | 
1.7000000000000002,1.0,1.0,0.0,0.9910679746533974,1.0,0.034609100019943016 36 | 1.7500000000000002,1.0,1.0,0.0,0.9914944487118257,1.0,0.03356422106813105 37 | 1.8,1.0,1.0,0.0,0.9919634750400658,1.0,0.03270181787106268 38 | 1.85,1.0,1.0,0.0,0.9922539003393859,1.0,0.03247688094681033 39 | 1.9000000000000001,1.0,1.0,0.0,0.9923908143943294,1.0,0.0322338432847492 40 | 1.9500000000000002,1.0,1.0,0.0,0.9926738263858396,1.0,0.031985447486153375 41 | 2.0,1.0,1.0,0.0,0.9927525363163924,1.0,0.03188675422226196 42 | 2.05,1.0,1.0,0.0,0.9928475103713997,1.0,0.031793055571801544 43 | 2.1,1.0,1.0,0.0,0.9929542137292435,1.0,0.031577512164597095 44 | 2.15,1.0,1.0,0.0,0.9929914860251126,1.0,0.03152091976648795 45 | 2.1999999999999997,1.0,1.0,0.0,0.9931078181515925,1.0,0.03129190632244978 46 | 2.25,1.0,1.0,0.0,0.9933287063523303,1.0,0.03103680541201891 47 | 2.3,1.0,1.0,0.0,0.9936352922310345,1.0,0.030487584066969144 48 | 2.35,1.0,1.0,0.0,0.9938734261288,1.0,0.030189537670333205 49 | 2.4,1.0,1.0,0.0,0.9940045140326391,1.0,0.03005552702336885 50 | 2.45,1.0,1.0,0.0,0.99427139763627,1.0,0.029820178205711857 51 | 2.5,1.0,1.0,0.0,0.9944622124541159,1.0,0.02934684373554284 52 | 2.55,1.0,1.0,0.0,0.9946795555602882,1.0,0.028992053985942273 53 | 2.6,1.0,1.0,0.0,0.9948780135917167,1.0,0.028735354869193028 54 | 2.65,1.0,1.0,0.0,0.99503801395476,1.0,0.028563297446339213 55 | 2.7,1.0,1.0,0.0,0.9952040372067215,1.0,0.028403367668943507 56 | 2.75,1.0,1.0,0.0,0.9954139037773834,1.0,0.028231055564201697 57 | 2.8,1.0,1.0,0.0,0.9955730070156924,1.0,0.02809158835084354 58 | 2.85,1.0,1.0,0.0,0.9957292201822051,1.0,0.02793247439555708 59 | 2.9,1.0,1.0,0.0,0.995883750779165,1.0,0.027832614175859433 60 | 2.95,1.0,1.0,0.0,0.9960471507970968,1.0,0.02773939895707281 61 | 3.0,1.0,1.0,0.0,0.9961860763629112,1.0,0.027642858668279 62 | -------------------------------------------------------------------------------- /checkpoints/sample_2024_02_05__15_50_46_SubGDiff500/log.txt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-XL/SubgDiff/f9cf4db06f010a31523aec5b9a8063834c7bb9e5/checkpoints/sample_2024_02_05__15_50_46_SubGDiff500/log.txt -------------------------------------------------------------------------------- /configs/qm9_200steps.yml: -------------------------------------------------------------------------------- 1 | model: 2 | type: subgraph_diffusion 3 | network: dualenc 4 | hidden_dim: 128 5 | num_convs: 6 6 | num_convs_local: 4 7 | cutoff: 10.0 8 | mlp_act: relu 9 | beta_schedule: sigmoid 10 | beta_start: 1.e-7 11 | beta_end: 5.e-2 12 | num_diffusion_timesteps: 200 13 | edge_order: 3 14 | edge_encoder: mlp 15 | smooth_conv: false 16 | same_mask_steps: 10 # 10-same-subgraph diffusion 17 | mask_pred: MLP 18 | 19 | train: 20 | seed: 2021 21 | batch_size: 64 22 | val_freq: 2000 23 | max_iters: 2000000 24 | max_grad_norm: 10000.00 25 | anneal_power: 2.0 26 | optimizer: 27 | type: adam 28 | lr: 1.e-3 29 | weight_decay: 0.
30 | beta1: 0.95 31 | beta2: 0.999 32 | scheduler: 33 | type: plateau 34 | factor: 0.6 35 | patience: 30 36 | 37 | dataset: 38 | train: ./data/GEOM/QM9/train_data_40k_subgraph.pkl 39 | val: ./data/GEOM/QM9/val_data_5k_subgraph.pkl 40 | test: ./data/GEOM/QM9/test_data_1k.pkl 41 | -------------------------------------------------------------------------------- /configs/qm9_500steps.yml: -------------------------------------------------------------------------------- 1 | model: 2 | type: subgraph_diffusion 3 | network: dualenc 4 | hidden_dim: 128 5 | num_convs: 6 6 | num_convs_local: 4 7 | cutoff: 10.0 8 | mlp_act: relu 9 | beta_schedule: sigmoid 10 | beta_start: 1.e-7 11 | beta_end: 2.e-2 12 | num_diffusion_timesteps: 500 13 | edge_order: 3 14 | edge_encoder: mlp 15 | smooth_conv: false 16 | same_mask_steps: 25 # 25-same-subgraph diffusion 17 | mask_pred: MLP 18 | 19 | train: 20 | seed: 2021 21 | batch_size: 64 22 | val_freq: 2000 23 | max_iters: 2000000 24 | max_grad_norm: 10000.00 25 | anneal_power: 2.0 26 | optimizer: 27 | type: adam 28 | lr: 1.e-3 29 | weight_decay: 0. 
30 | beta1: 0.95 31 | beta2: 0.999 32 | scheduler: 33 | type: plateau 34 | factor: 0.6 35 | patience: 30 36 | 37 | dataset: 38 | train: ./data/GEOM/QM9/train_data_40k_subgraph.pkl 39 | val: ./data/GEOM/QM9/val_data_5k_subgraph.pkl 40 | test: ./data/GEOM/QM9/test_data_1k.pkl 41 | -------------------------------------------------------------------------------- /datasets/chem.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torchvision.transforms.functional import to_tensor 4 | 5 | import rdkit 6 | import rdkit.Chem.Draw 7 | from rdkit import Chem 8 | from rdkit.Chem import rdDepictor as DP 9 | from rdkit.Chem import PeriodicTable as PT 10 | from rdkit.Chem import rdMolAlign as MA 11 | from rdkit.Chem.rdchem import BondType as BT 12 | from rdkit.Chem.rdchem import Mol, GetPeriodicTable 13 | from rdkit.Chem.Draw import rdMolDraw2D as MD2 14 | from rdkit.Chem.rdmolops import RemoveHs 15 | from typing import List, Tuple 16 | 17 | 18 | 19 | BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())} 20 | BOND_NAMES = {i: t for i, t in enumerate(BT.names.keys())} 21 | 22 | 23 | def set_conformer_positions(conf, pos): 24 | for i in range(pos.shape[0]): 25 | conf.SetAtomPosition(i, pos[i].tolist()) 26 | return conf 27 | 28 | 29 | def draw_mol_image(rdkit_mol, tensor=False): 30 | rdkit_mol.UpdatePropertyCache() 31 | img = rdkit.Chem.Draw.MolToImage(rdkit_mol, kekulize=False) 32 | if tensor: 33 | return to_tensor(img) 34 | else: 35 | return img 36 | 37 | 38 | def update_data_rdmol_positions(data): 39 | for i in range(data.pos.size(0)): 40 | data.rdmol.GetConformer(0).SetAtomPosition(i, data.pos[i].tolist()) 41 | return data 42 | 43 | 44 | def update_data_pos_from_rdmol(data): 45 | new_pos = torch.FloatTensor(data.rdmol.GetConformer(0).GetPositions()).to(data.pos) 46 | data.pos = new_pos 47 | return data 48 | 49 | 50 | def set_rdmol_positions(rdkit_mol, pos): 51 | """ 52 | Args: 53 | rdkit_mol: An 
`rdkit.Chem.rdchem.Mol` object. 54 | pos: (N_atoms, 3) 55 | """ 56 | mol = copy.deepcopy(rdkit_mol) 57 | set_rdmol_positions_(mol, pos) 58 | return mol 59 | 60 | 61 | def set_rdmol_positions_(mol, pos): 62 | """ 63 | Args: 64 | rdkit_mol: An `rdkit.Chem.rdchem.Mol` object. 65 | pos: (N_atoms, 3) 66 | """ 67 | for i in range(pos.shape[0]): 68 | mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist()) 69 | return mol 70 | 71 | 72 | def get_atom_symbol(atomic_number): 73 | return PT.GetElementSymbol(GetPeriodicTable(), atomic_number) 74 | 75 | 76 | def mol_to_smiles(mol: Mol) -> str: 77 | return Chem.MolToSmiles(mol, allHsExplicit=True) 78 | 79 | 80 | def mol_to_smiles_without_Hs(mol: Mol) -> str: 81 | return Chem.MolToSmiles(Chem.RemoveHs(mol)) 82 | 83 | 84 | def remove_duplicate_mols(molecules: List[Mol]) -> List[Mol]: 85 | unique_tuples: List[Tuple[str, Mol]] = [] 86 | 87 | for molecule in molecules: 88 | duplicate = False 89 | smiles = mol_to_smiles(molecule) 90 | for unique_smiles, _ in unique_tuples: 91 | if smiles == unique_smiles: 92 | duplicate = True 93 | break 94 | 95 | if not duplicate: 96 | unique_tuples.append((smiles, molecule)) 97 | 98 | return [mol for smiles, mol in unique_tuples] 99 | 100 | 101 | def get_atoms_in_ring(mol): 102 | atoms = set() 103 | for ring in mol.GetRingInfo().AtomRings(): 104 | for a in ring: 105 | atoms.add(a) 106 | return atoms 107 | 108 | 109 | def get_2D_mol(mol): 110 | mol = copy.deepcopy(mol) 111 | DP.Compute2DCoords(mol) 112 | return mol 113 | 114 | 115 | def draw_mol_svg(mol,molSize=(450,150),kekulize=False): 116 | mc = Chem.Mol(mol.ToBinary()) 117 | if kekulize: 118 | try: 119 | Chem.Kekulize(mc) 120 | except: 121 | mc = Chem.Mol(mol.ToBinary()) 122 | if not mc.GetNumConformers(): 123 | DP.Compute2DCoords(mc) 124 | drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1]) 125 | drawer.DrawMolecule(mc) 126 | drawer.FinishDrawing() 127 | svg = drawer.GetDrawingText() 128 | # It seems that the svg renderer used doesn't quite hit the 
spec. 129 | # Here are some fixes to make it work in the notebook, although I think 130 | # the underlying issue needs to be resolved at the generation step 131 | # return svg.replace('svg:','') 132 | return svg 133 | 134 | 135 | def GetBestRMSD(probe, ref): 136 | probe = RemoveHs(probe) 137 | ref = RemoveHs(ref) 138 | rmsd = MA.GetBestRMS(probe, ref) 139 | return rmsd 140 | -------------------------------------------------------------------------------- /datasets/rdmol2data.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | import torch 5 | from torch_geometric.data import Data 6 | 7 | from torch_scatter import scatter 8 | #from torch.utils.data import Dataset 9 | 10 | from rdkit import Chem 11 | from rdkit.Chem.rdchem import Mol, HybridizationType, BondType 12 | from rdkit import RDLogger 13 | RDLogger.DisableLog('rdApp.*') 14 | 15 | # from confgf import utils 16 | # from datasets.chem import BOND_TYPES, BOND_NAMES 17 | from rdkit.Chem.rdchem import BondType as BT 18 | 19 | BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())} 20 | 21 | def rdmol_to_data(mol:Mol, pos=None,y=None, idx=None, smiles=None): 22 | assert mol.GetNumConformers() == 1 23 | N = mol.GetNumAtoms() 24 | if smiles is None: 25 | smiles = Chem.MolToSmiles(mol) 26 | # pos = torch.tensor(mol.GetConformer(0).GetPositions(), dtype=torch.float32) 27 | 28 | atomic_number = [] 29 | aromatic = [] 30 | sp = [] 31 | sp2 = [] 32 | sp3 = [] 33 | num_hs = [] 34 | for atom in mol.GetAtoms(): 35 | atomic_number.append(atom.GetAtomicNum()) 36 | aromatic.append(1 if atom.GetIsAromatic() else 0) 37 | hybridization = atom.GetHybridization() 38 | sp.append(1 if hybridization == HybridizationType.SP else 0) 39 | sp2.append(1 if hybridization == HybridizationType.SP2 else 0) 40 | sp3.append(1 if hybridization == HybridizationType.SP3 else 0) 41 | 42 | z = torch.tensor(atomic_number, dtype=torch.long) 43 | 44 | row, col, edge_type = [], [], [] 45 | 
for bond in mol.GetBonds(): 46 | start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() 47 | row += [start, end] 48 | col += [end, start] 49 | edge_type += 2 * [BOND_TYPES[bond.GetBondType()]] 50 | 51 | edge_index = torch.tensor([row, col], dtype=torch.long) 52 | edge_type = torch.tensor(edge_type) 53 | 54 | perm = (edge_index[0] * N + edge_index[1]).argsort() 55 | edge_index = edge_index[:, perm] 56 | edge_type = edge_type[perm] 57 | 58 | row, col = edge_index 59 | hs = (z == 1).to(torch.float32) 60 | 61 | num_hs = scatter(hs[row], col, dim_size=N, reduce='sum').tolist() 62 | 63 | if smiles is None: 64 | smiles = Chem.MolToSmiles(mol) 65 | try: 66 | name = mol.GetProp('_Name') 67 | except: 68 | name=None 69 | 70 | data = Data(atom_type=z, pos=pos, edge_index=edge_index, edge_type=edge_type, 71 | rdmol=copy.deepcopy(mol), smiles=smiles,y=y,idx=idx, name=name) 72 | #data.nx = to_networkx(data, to_undirected=True) 73 | 74 | return data -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: SubGDiff 2 | channels: 3 | - conda-forge 4 | - psi4 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - ambit=0.3=h137fa24_1 10 | - atomicwrites=1.4.0=py_0 11 | - backcall=0.2.0=pyh9f0ad1d_0 12 | - blas=1.0=mkl 13 | - brotlipy=0.7.0=py37h27cfd23_1003 14 | - bzip2=1.0.8=h7b6447c_0 15 | - ca-certificates=2023.7.22=hbcca054_0 16 | - cairo=1.16.0=hb05425b_5 17 | - certifi=2023.7.22=pyhd8ed1ab_0 18 | - cffi=1.15.1=py37h5eee18b_3 19 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 20 | - chemps2=1.8.9=h8c3debe_0 21 | - cryptography=39.0.1=py37h9ce1e76_0 22 | - cudatoolkit=11.3.1=h2bc3f7f_2 23 | - decorator=4.4.0=py37_1 24 | - deepdiff=3.3.0=py37_1 25 | - dkh=1.2=h173d85e_2 26 | - entrypoints=0.4=pyhd8ed1ab_0 27 | - expat=2.4.9=h6a678d5_0 28 | - ffmpeg=4.2.2=h20bf706_0 29 | - fontconfig=2.14.1=h4c34cd2_2 30 | - 
freetype=2.12.1=h4a9f257_0 31 | - gau2grid=1.3.1=h035aef0_0 32 | - gdma=2.2.6=h0e1e685_6 33 | - giflib=5.2.1=h5eee18b_3 34 | - glib=2.69.1=he621ea3_2 35 | - gmp=6.2.1=h295c915_3 36 | - gnutls=3.6.15=he1e5248_0 37 | - hdf5=1.10.2=hba1933b_1 38 | - icu=58.2=he6710b0_3 39 | - idna=3.4=py37h06a4308_0 40 | - importlib_metadata=4.11.3=hd3eb1b0_0 41 | - intel-openmp=2023.1.0=hdb19cb5_46305 42 | - ipython=7.33.0=py37h89c1867_0 43 | - ipython_genutils=0.2.0=py_1 44 | - jedi=0.18.2=pyhd8ed1ab_0 45 | - joblib=1.1.0=pyhd3eb1b0_0 46 | - jpeg=9b=0 47 | - jsonpickle=0.9.6=py37_0 48 | - lame=3.100=h7b6447c_0 49 | - lcms2=2.12=h3be6417_0 50 | - ld_impl_linux-64=2.38=h1181459_1 51 | - libblas=3.9.0=1_h6e990d7_netlib 52 | - libboost=1.73.0=h3ff78a5_11 53 | - libcblas=3.9.0=3_h893e4fe_netlib 54 | - libffi=3.4.4=h6a678d5_0 55 | - libgcc-ng=11.2.0=h1234567_1 56 | - libgfortran-ng=7.5.0=ha8ba4b0_17 57 | - libgfortran4=7.5.0=ha8ba4b0_17 58 | - libgomp=11.2.0=h1234567_1 59 | - libidn2=2.3.4=h5eee18b_0 60 | - libint=1.2.1=hb4a4fd4_6 61 | - liblapack=3.9.0=3_h893e4fe_netlib 62 | - libopus=1.3.1=h7b6447c_0 63 | - libpng=1.6.39=h5eee18b_0 64 | - libsodium=1.0.18=h36c2ea0_1 65 | - libstdcxx-ng=11.2.0=h1234567_1 66 | - libtasn1=4.19.0=h5eee18b_0 67 | - libtiff=4.1.0=h2733197_1 68 | - libunistring=0.9.10=h27cfd23_0 69 | - libuuid=1.41.5=h5eee18b_0 70 | - libuv=1.44.2=h5eee18b_0 71 | - libvpx=1.7.0=h439df22_0 72 | - libwebp=1.2.0=h89dd481_0 73 | - libxc=4.3.4=h7b6447c_0 74 | - libxcb=1.15=h7f8727e_0 75 | - libxml2=2.10.3=hcbfbd50_0 76 | - lz4-c=1.9.4=h6a678d5_0 77 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 78 | - mkl=2019.4=243 79 | - mkl-service=2.3.0=py37he8ac12f_0 80 | - mkl_fft=1.0.14=py37hd81dba3_0 81 | - mkl_random=1.0.4=py37hd81dba3_0 82 | - more-itertools=8.12.0=pyhd3eb1b0_0 83 | - ncurses=6.4=h6a678d5_0 84 | - nettle=3.7.3=hbbd107a_1 85 | - networkx=2.5.1=pyhd8ed1ab_0 86 | - numpy-base=1.17.0=py37hde5b4d6_0 87 | - nvcc_linux-64=10.1=hcaf9a05_10 88 | - openh264=2.1.1=h4ff587b_0 89 | - 
openssl=1.1.1v=h7f8727e_0 90 | - pandas=1.2.3=py37hdc94413_0 91 | - parso=0.8.3=pyhd8ed1ab_0 92 | - pcmsolver=1.2.1.1=py37h6d17ec8_2 93 | - pcre=8.45=h295c915_0 94 | - pexpect=4.8.0=pyh1a96a4e_2 95 | - pickleshare=0.7.5=py_1003 96 | - pillow=9.3.0=py37hace64e9_1 97 | - pint=0.10=py_0 98 | - pip=23.2=pyhd8ed1ab_0 99 | - pixman=0.40.0=h7f8727e_1 100 | - pluggy=1.0.0=py37h06a4308_1 101 | - prettytable=3.5.0=py37h06a4308_0 102 | - prompt-toolkit=3.0.39=pyha770c72_0 103 | - psi4=1.3.2+ecbda83=py37h06ff01c_1 104 | - psutil=5.9.0=py37h5eee18b_0 105 | - ptyprocess=0.7.0=pyhd3deb0d_0 106 | - py=1.11.0=pyhd3eb1b0_0 107 | - py-boost=1.73.0=py37ha9443f7_11 108 | - pycparser=2.21=pyhd3eb1b0_0 109 | - pydantic=1.3=py37h516909a_0 110 | - pygments=2.15.1=pyhd8ed1ab_0 111 | - pyopenssl=23.0.0=py37h06a4308_0 112 | - pysocks=1.7.1=py37_1 113 | - pytest=3.10.1=py37_1000 114 | - python=3.7.16=h7a1cb2a_0 115 | - python-dateutil=2.8.2=pyhd3eb1b0_0 116 | - python_abi=3.7=2_cp37m 117 | - pytorch=1.11.0=py3.7_cuda11.3_cudnn8.2.0_0 118 | - pytorch-mutex=1.0=cuda 119 | - pytz=2022.7=py37h06a4308_0 120 | - qcelemental=0.17.0=py_0 121 | - readline=8.2=h5eee18b_0 122 | - requests=2.28.1=py37h06a4308_0 123 | - scikit-learn=0.23.2=py37hddcf8d6_3 124 | - scipy=1.5.3=py37h8911b10_0 125 | - setuptools=59.8.0=py37h89c1867_1 126 | - simint=0.7=h642920c_1 127 | - six=1.16.0=pyhd3eb1b0_1 128 | - sqlite=3.41.2=h5eee18b_0 129 | - threadpoolctl=2.2.0=pyh0d69192_0 130 | - tk=8.6.12=h1ccaba5_0 131 | - torchaudio=0.11.0=py37_cu113 132 | - torchvision=0.12.0=py37_cu113 133 | - traitlets=5.9.0=pyhd8ed1ab_0 134 | - typing_extensions=4.1.1=pyh06a4308_0 135 | - urllib3=1.26.14=py37h06a4308_0 136 | - wcwidth=0.2.5=pyhd3eb1b0_0 137 | - wheel=0.38.4=py37h06a4308_0 138 | - x264=1!157.20191217=h7b6447c_0 139 | - xz=5.4.2=h5eee18b_0 140 | - zeromq=4.3.4=h9c3ff4c_1 141 | - zlib=1.2.13=h5eee18b_0 142 | - zstd=1.4.9=haebb681_0 143 | - pip: 144 | - absl-py==1.4.0 145 | - addict==2.4.0 146 | - aiofiles==22.1.0 147 | - 
aiosqlite==0.19.0 148 | - anyio==3.7.1 149 | - argon2-cffi==21.3.0 150 | - argon2-cffi-bindings==21.2.0 151 | - argparse==1.4.0 152 | - arrow==1.2.3 153 | - ase==3.22.1 154 | - atom3d==0.2.6 155 | - attrs==23.1.0 156 | - babel==2.12.1 157 | - beautifulsoup4==4.12.2 158 | - biopython==1.81 159 | - bleach==6.0.0 160 | - cached-property==1.5.2 161 | - cachetools==5.3.1 162 | - click==8.1.6 163 | - cloudpickle==2.2.1 164 | - comm==0.1.4 165 | - cycler==0.11.0 166 | - cython==3.0.0 167 | - dataclasses==0.8 168 | - debtcollector==2.5.0 169 | - debugpy==1.6.8 170 | - defusedxml==0.7.1 171 | - descriptastorus==2.5.0.20 172 | - dgl-cu110==0.6.1 173 | - dgllife==0.3.2 174 | - dill==0.3.7 175 | - easy-parallel==0.1.6 176 | - easydict==1.10 177 | - exceptiongroup==1.1.2 178 | - fastjsonschema==2.18.0 179 | - filelock==3.12.2 180 | - fonttools==4.38.0 181 | - fqdn==1.5.1 182 | - freesasa==2.2.0.post3 183 | - fuzzywuzzy==0.18.0 184 | - gdown==4.7.1 185 | - google-auth==2.22.0 186 | - google-auth-oauthlib==0.4.6 187 | - googledrivedownloader==0.4 188 | - grpcio==1.56.2 189 | - h5py==3.8.0 190 | - horovod==0.28.1 191 | - huggingface-hub==0.16.4 192 | - hyperopt==0.2.7 193 | - importlib-metadata==6.7.0 194 | - importlib-resources==5.12.0 195 | - ipykernel==6.16.2 196 | - ipywidgets==8.1.0 197 | - isodate==0.6.1 198 | - isoduration==20.11.0 199 | - jinja2==3.1.2 200 | - json5==0.9.14 201 | - jsonpointer==2.4 202 | - jsonschema==4.17.3 203 | - jupyter==1.0.0 204 | - jupyter-client==7.4.9 205 | - jupyter-console==6.6.3 206 | - jupyter-core==4.12.0 207 | - jupyter-events==0.6.3 208 | - jupyter-server==1.24.0 209 | - jupyter-server-fileid==0.9.0 210 | - jupyter-server-ydoc==0.8.0 211 | - jupyter-ydoc==0.2.5 212 | - jupyterlab==3.6.5 213 | - jupyterlab-pygments==0.2.2 214 | - jupyterlab-server==2.24.0 215 | - jupyterlab-widgets==3.0.8 216 | - kiwisolver==1.4.4 217 | - lmdb==1.4.1 218 | - markdown==3.4.4 219 | - markdown-it-py==2.2.0 220 | - markupsafe==2.1.3 221 | - matplotlib==3.5.3 222 
| - mdurl==0.1.2 223 | - mistune==3.0.1 224 | - mmcv==1.5.0 225 | - mmengine==0.8.4 226 | - msgpack==1.0.5 227 | - multipledispatch==1.0.0 228 | - multiprocess==0.70.15 229 | - nbclassic==1.0.0 230 | - nbclient==0.7.4 231 | - nbconvert==7.6.0 232 | - nbformat==5.8.0 233 | - nest-asyncio==1.5.7 234 | - notebook==6.5.5 235 | - notebook-shim==0.2.3 236 | - numexpr==2.8.5 237 | - numpy==1.21.6 238 | - nvidia-nccl-cu11==2.18.3 239 | - oauthlib==3.2.2 240 | - opencv-python==4.8.0.74 241 | - packaging==23.1 242 | - pandas-flavor==0.6.0 243 | - pandocfilters==1.5.0 244 | - pathos==0.3.1 245 | - pkgutil-resolve-name==1.3.10 246 | - platformdirs==3.10.0 247 | - pox==0.3.3 248 | - ppft==1.7.6.7 249 | - prometheus-client==0.17.1 250 | - protobuf==3.20.3 251 | - psikit==0.2.0 252 | - py3dmol==2.0.3 253 | - pyasn1==0.5.0 254 | - pyasn1-modules==0.3.0 255 | - pyparsing==3.1.0 256 | - pyrr==0.10.3 257 | - pyrsistent==0.19.3 258 | - pytdc==0.4.1 259 | - python-dotenv==0.21.1 260 | - python-json-logger==2.0.7 261 | - python-louvain==0.16 262 | - pyyaml==6.0.1 263 | - pyzmq==24.0.1 264 | - qtconsole==5.4.3 265 | - qtpy==2.3.1 266 | - rdflib==6.3.2 267 | - rdkit==2023.3.2 268 | - rdkit-pypi==2023.3.1b1 269 | - regex==2023.6.3 270 | - requests-oauthlib==1.3.1 271 | - rfc3339-validator==0.1.4 272 | - rfc3986-validator==0.1.1 273 | - rich==13.5.2 274 | - rsa==4.9 275 | - safetensors==0.3.1 276 | - seaborn==0.12.2 277 | - send2trash==1.8.2 278 | - sniffio==1.3.0 279 | - soupsieve==2.4.1 280 | - tables==3.7.0 281 | - tensorboard==2.11.2 282 | - tensorboard-data-server==0.6.1 283 | - tensorboard-plugin-wit==1.8.1 284 | - termcolor==2.3.0 285 | - terminado==0.17.1 286 | - tinycss2==1.2.1 287 | - tokenizers==0.13.3 288 | - tomli==2.0.1 289 | - torch-cluster==1.6.0 290 | - torch-geometric==1.7.2 291 | - torch-scatter==2.0.9 292 | - torch-sparse==0.6.13 293 | - tornado==6.2 294 | - tqdm==4.65.0 295 | - transformers==4.31.0 296 | - typing-extensions==4.7.1 297 | - uri-template==1.3.0 298 | - 
webcolors==1.13 299 | - webencodings==0.5.1 300 | - websocket-client==1.6.1 301 | - werkzeug==2.2.3 302 | - widgetsnbextension==4.0.8 303 | - wilds==2.0.0 304 | - wrapt==1.15.0 305 | - xarray==0.20.2 306 | - y-py==0.6.0 307 | - yapf==0.40.1 308 | - ypy-websocket==0.8.4 309 | - zipp==3.15.0 310 | -------------------------------------------------------------------------------- /eval_covmat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pickle 4 | import torch 5 | 6 | from utils.datasets import PackedConformationDataset 7 | from utils.evaluation.covmat import CovMatEvaluator, print_covmat_results 8 | from utils.misc import * 9 | 10 | 11 | if __name__ == '__main__': 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('path', type=str, default="samples_all.pkl") 14 | parser.add_argument('--num_workers', type=int, default=8) 15 | parser.add_argument('--ratio', type=int, default=2) 16 | parser.add_argument('--start_idx', type=int, default=0) 17 | args = parser.parse_args() 18 | assert os.path.isfile(args.path) 19 | 20 | # Logging 21 | tag = args.path.split('/')[-1].split('.')[0] 22 | logger = get_logger('eval', os.path.dirname(args.path), 'log_eval_'+tag+'.txt') 23 | 24 | # Load results 25 | logger.info('Loading results: %s' % args.path) 26 | with open(args.path, 'rb') as f: 27 | packed_dataset = pickle.load(f) 28 | logger.info('Total: %d' % len(packed_dataset)) 29 | 30 | # Evaluator 31 | evaluator = CovMatEvaluator( 32 | num_workers = args.num_workers, 33 | 34 | ratio = args.ratio, 35 | print_fn=logger.info, 36 | ) 37 | results = evaluator( 38 | packed_data_list = list(packed_dataset), 39 | start_idx = args.start_idx, 40 | ) 41 | df = print_covmat_results(results, print_fn=logger.info) 42 | 43 | # Save results 44 | csv_fn = args.path[:-4] + '_covmat.csv' 45 | results_fn = args.path[:-4] + '_covmat.pkl' 46 | df.to_csv(csv_fn) 47 | with open(results_fn, 'wb') as f: 48 | 
pickle.dump(results, f) 49 | 50 | -------------------------------------------------------------------------------- /eval_prop.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import argparse 4 | import torch 5 | import numpy as np 6 | from psikit import Psikit 7 | from tqdm.auto import tqdm 8 | from easydict import EasyDict 9 | from torch_geometric.data import Data 10 | 11 | from utils.datasets import PackedConformationDataset 12 | from utils.chem import set_rdmol_positions 13 | 14 | 15 | class PropertyCalculator(object): 16 | 17 | def __init__(self, threads, memory, seed): 18 | super().__init__() 19 | self.pk = Psikit(threads=threads, memory=memory) 20 | self.seed = seed 21 | 22 | def __call__(self, data, num_confs=50): 23 | rdmol = data.rdmol 24 | confs = data.pos_prop 25 | 26 | conf_idx = np.arange(confs.shape[0]) 27 | np.random.RandomState(self.seed).shuffle(conf_idx) 28 | conf_idx = conf_idx[:num_confs] 29 | 30 | data.prop_conf_idx = [] 31 | data.prop_energy = [] 32 | data.prop_homo = [] 33 | data.prop_lumo = [] 34 | data.prop_dipo = [] 35 | 36 | for idx in tqdm(conf_idx): 37 | mol = set_rdmol_positions(rdmol, confs[idx]) 38 | self.pk.mol = mol 39 | try: 40 | energy, homo, lumo, dipo = self.pk.energy(), self.pk.HOMO, self.pk.LUMO, self.pk.dipolemoment[-1] 41 | data.prop_conf_idx.append(idx) 42 | data.prop_energy.append(energy) 43 | data.prop_homo.append(homo) 44 | data.prop_lumo.append(lumo) 45 | data.prop_dipo.append(dipo) 46 | except: 47 | pass 48 | 49 | return data 50 | 51 | 52 | def get_prop_matrix(data): 53 | """ 54 | Returns: 55 | properties: (4, num_confs) numpy tensor. 
Energy, HOMO, LUMO, DipoleMoment 56 | """ 57 | return np.array([ 58 | data.prop_energy, 59 | data.prop_homo, 60 | data.prop_lumo, 61 | data.prop_dipo, 62 | ]) 63 | 64 | 65 | def get_ensemble_energy(props): 66 | """ 67 | Args: 68 | props: (4, num_confs) 69 | """ 70 | avg_ener = np.mean(props[0, :]) 71 | low_ener = np.min(props[0, :]) 72 | gaps = np.abs(props[1, :] - props[2, :]) 73 | avg_gap = np.mean(gaps) 74 | min_gap = np.min(gaps) 75 | max_gap = np.max(gaps) 76 | return np.array([ 77 | avg_ener, low_ener, avg_gap, min_gap, max_gap, 78 | ]) 79 | 80 | HART_TO_EV = 27.211 81 | HART_TO_KCALPERMOL = 627.5 82 | 83 | if __name__ == '__main__': 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--dataset', type=str, default='./data/GEOM/QM9/qm9_property.pkl') 86 | parser.add_argument('--generated', type=str, default=None) 87 | parser.add_argument('--num_confs', type=int, default=50) 88 | parser.add_argument('--threads', type=int, default=8) 89 | parser.add_argument('--memory', type=int, default=16) 90 | parser.add_argument('--seed', type=int, default=2021) 91 | args = parser.parse_args() 92 | 93 | prop_cal = PropertyCalculator(threads=args.threads, memory=args.memory, seed=args.seed) 94 | 95 | cache_ref_fn = os.path.join( 96 | os.path.dirname(args.dataset), 97 | os.path.basename(args.dataset)[:-4] + '_prop.pkl' 98 | ) 99 | if not os.path.exists(cache_ref_fn): 100 | dset = PackedConformationDataset(args.dataset) 101 | dset = [data for data in dset] 102 | dset_prop = [] 103 | for data in dset: 104 | data.pos_prop = data.pos_ref.reshape(-1, data.num_nodes, 3) 105 | dset_prop.append(prop_cal(data, args.num_confs)) 106 | with open(cache_ref_fn, 'wb') as f: 107 | pickle.dump(dset_prop, f) 108 | dset = dset_prop 109 | else: 110 | with open(cache_ref_fn, 'rb') as f: 111 | dset = pickle.load(f) 112 | 113 | 114 | if args.generated is None: 115 | exit() 116 | 117 | print('Start evaluation.') 118 | 119 | cache_gen_fn = os.path.join( 120 | 
os.path.dirname(args.generated), 121 | os.path.basename(args.generated)[:-4] + '_prop.pkl' 122 | ) 123 | if not os.path.exists(cache_gen_fn): 124 | with open(args.generated, 'rb') as f: 125 | gens = pickle.load(f) 126 | gens_prop = [] 127 | for data in gens: 128 | if not isinstance(data, Data): 129 | data = EasyDict(data) 130 | data.num_nodes = data.rdmol.GetNumAtoms() 131 | data.pos_prop = data.pos_gen.reshape(-1, data.num_nodes, 3) 132 | gens_prop.append(prop_cal(data, args.num_confs)) 133 | with open(cache_gen_fn, 'wb') as f: 134 | pickle.dump(gens_prop, f) 135 | gens = gens_prop 136 | else: 137 | with open(cache_gen_fn, 'rb') as f: 138 | gens = pickle.load(f) 139 | 140 | 141 | dset = {d.smiles:d for d in dset} 142 | gens = {d.smiles:d for d in gens} 143 | all_diff = [] 144 | for smiles in dset.keys(): 145 | if smiles not in gens: 146 | continue 147 | 148 | prop_gts = get_ensemble_energy(get_prop_matrix(dset[smiles])) * HART_TO_EV 149 | prop_gen = get_ensemble_energy(get_prop_matrix(gens[smiles])) * HART_TO_EV 150 | # prop_gts = np.mean(get_prop_matrix(dset[smiles]), axis=1) 151 | # prop_gen = np.mean(get_prop_matrix(gens[smiles]), axis=1) 152 | 153 | # print(get_prop_matrix(gens[smiles])[0]) 154 | 155 | prop_diff = np.abs(prop_gts - prop_gen) 156 | 157 | print('\nProperty: %s' % smiles) 158 | print(' Gts :', prop_gts) 159 | print(' Gen :', prop_gen) 160 | print(' Diff:', prop_diff) 161 | 162 | all_diff.append(prop_diff.reshape(1, -1)) 163 | all_diff = np.vstack(all_diff) # (num_mols, 4) 164 | print(all_diff.shape) 165 | 166 | print('[Difference]') 167 | print(' Mean: ', np.mean(all_diff, axis=0)) 168 | print(' Median:', np.median(all_diff, axis=0)) 169 | print(' Std: ', np.std(all_diff, axis=0)) 170 | -------------------------------------------------------------------------------- /finetune/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | parser =
argparse.ArgumentParser() 5 | 6 | # about seed and basic info 7 | parser.add_argument("--seed", type=int, default=42) 8 | parser.add_argument("--device", type=int, default=0) 9 | 10 | # parser.add_argument( 11 | # "--model_3d", 12 | # type=str, 13 | # default="schnet", 14 | # choices=[ 15 | # "schnet", 16 | # "dimenet", 17 | # "dimenetPP", 18 | # "tfn", 19 | # "se3_transformer", 20 | # "egnn", 21 | # "spherenet", 22 | # "segnn", 23 | # "painn", 24 | # "gemnet", 25 | # "nequip", 26 | # "allegro", 27 | # ], 28 | # ) 29 | parser.add_argument("--model_3d", type=str, default="mask_diff") 30 | parser.add_argument( 31 | "--model_2d", 32 | type=str, 33 | default="gin", 34 | choices=[ 35 | "gin", 36 | "schnet", 37 | "dimenet", 38 | "dimenetPP", 39 | "tfn", 40 | "se3_transformer", 41 | "egnn", 42 | "spherenet", 43 | "segnn", 44 | "painn", 45 | "gemnet", 46 | "nequip", 47 | "allegro", 48 | ], 49 | ) 50 | 51 | # about dataset and dataloader 52 | parser.add_argument("--dataset", type=str, default="qm9") 53 | parser.add_argument("--task", type=str, default="alpha") 54 | parser.add_argument("--num_workers", type=int, default=0) 55 | parser.add_argument("--only_one_atom_type", dest="only_one_atom_type", action="store_true") 56 | parser.set_defaults(only_one_atom_type=False) 57 | 58 | # for MD17 59 | # The default hyper from here: https://github.com/divelab/DIG_storage/tree/main/3dgraph/md17 60 | parser.add_argument("--md17_energy_coeff", type=float, default=0.05) 61 | parser.add_argument("--md17_force_coeff", type=float, default=0.95) 62 | 63 | # for COLL 64 | # The default hyper from here: https://github.com/divelab/DIG_storage/tree/main/3dgraph/md17 65 | parser.add_argument("--coll_energy_coeff", type=float, default=0.05) 66 | parser.add_argument("--coll_force_coeff", type=float, default=0.95) 67 | 68 | # for LBA 69 | # The default hyper from here: https://github.com/drorlab/atom3d/blob/master/examples/lep/enn/utils.py#L37-L43 70 | parser.add_argument("--LBA_year", type=int, 
default=2020) 71 | parser.add_argument("--LBA_dist", type=float, default=6) 72 | parser.add_argument("--LBA_maxnum", type=int, default=500) 73 | parser.add_argument("--LBA_use_complex", dest="LBA_use_complex", action="store_true") 74 | parser.add_argument("--LBA_no_complex", dest="LBA_use_complex", action="store_false") 75 | parser.set_defaults(LBA_use_complex=False) 76 | 77 | # for LEP 78 | # The default hyper from here: https://github.com/drorlab/atom3d/blob/master/examples/lep/enn/utils.py#L48-L55 79 | parser.add_argument("--LEP_dist", type=float, default=6) 80 | parser.add_argument("--LEP_maxnum", type=float, default=400) 81 | parser.add_argument("--LEP_droph", dest="LEP_droph", action="store_true") 82 | parser.add_argument("--LEP_useh", dest="LEP_droph", action="store_false") 83 | parser.set_defaults(LEP_droph=False) 84 | 85 | # for MoleculeNet 86 | parser.add_argument("--moleculenet_num_conformers", type=int, default=10) 87 | 88 | # about training strategies 89 | parser.add_argument("--split", type=str, default="customized_01", 90 | choices=["customized_01", "customized_02", "random", "atom3d_lba_split30"]) 91 | parser.add_argument("--MD17_train_batch_size", type=int, default=1) 92 | parser.add_argument("--batch_size", type=int, default=128) 93 | parser.add_argument("--epochs", type=int, default=100) 94 | parser.add_argument("--lr", type=float, default=1e-4) 95 | parser.add_argument("--lr_scale", type=float, default=1) 96 | parser.add_argument("--decay", type=float, default=0) 97 | parser.add_argument("--print_every_epoch", type=int, default=1) 98 | parser.add_argument("--loss", type=str, default="mae", choices=["mse", "mae"]) 99 | parser.add_argument("--lr_scheduler", type=str, default="CosineAnnealingLR") 100 | parser.add_argument("--lr_decay_factor", type=float, default=0.5) 101 | parser.add_argument("--lr_decay_step_size", type=int, default=100) 102 | parser.add_argument("--lr_decay_patience", type=int, default=50) 103 | parser.add_argument("--min_lr", 
type=float, default=1e-6) 104 | parser.add_argument("--verbose", dest="verbose", action="store_true") 105 | parser.add_argument("--no_verbose", dest="verbose", action="store_false") 106 | parser.set_defaults(verbose=False) 107 | parser.add_argument("--use_rotation_transform", dest="use_rotation_transform", action="store_true") 108 | parser.add_argument("--no_rotation_transform", dest="use_rotation_transform", action="store_false") 109 | parser.set_defaults(use_rotation_transform=False) 110 | 111 | # for SchNet 112 | parser.add_argument("--num_filters", type=int, default=128) 113 | parser.add_argument("--num_interactions", type=int, default=6) 114 | parser.add_argument("--num_gaussians", type=int, default=51) 115 | parser.add_argument("--cutoff", type=float, default=10) 116 | parser.add_argument("--readout", type=str, default="mean", choices=["mean", "add"]) 117 | 118 | # for PaiNN 119 | parser.add_argument("--painn_radius_cutoff", type=float, default=5.0) 120 | parser.add_argument("--painn_n_interactions", type=int, default=3) 121 | parser.add_argument("--painn_n_rbf", type=int, default=20) 122 | parser.add_argument("--painn_readout", type=str, default="add", choices=["mean", "add"]) 123 | 124 | ######################### for Charge Prediction SSL ######################### 125 | parser.add_argument("--charge_masking_ratio", type=float, default=0.3) 126 | 127 | ######################### for Distance Perturbation SSL ######################### 128 | parser.add_argument("--distance_sample_ratio", type=float, default=1) 129 | 130 | ######################### for Torsion Angle Perturbation SSL ######################### 131 | parser.add_argument("--torsion_angle_sample_ratio", type=float, default=0.001) 132 | 133 | ######################### for Position Perturbation SSL ######################### 134 | parser.add_argument("--PP_mu", type=float, default=0) 135 | parser.add_argument("--PP_sigma", type=float, default=0.3) 136 | 137 | 138 | ######################### for GraphMVP 
SSL ######################### 139 | ### for 2D GNN 140 | parser.add_argument("--gnn_type", type=str, default="gin") 141 | parser.add_argument("--num_layer", type=int, default=5) 142 | parser.add_argument("--emb_dim", type=int, default=128) 143 | parser.add_argument("--dropout_ratio", type=float, default=0.5) 144 | parser.add_argument("--graph_pooling", type=str, default="mean") 145 | parser.add_argument("--JK", type=str, default="last") 146 | parser.add_argument("--gnn_2d_lr_scale", type=float, default=1) 147 | 148 | ######################### for GeoSSL ######################### 149 | parser.add_argument("--GeoSSL_mu", type=float, default=0) 150 | parser.add_argument("--GeoSSL_sigma", type=float, default=0.3) 151 | parser.add_argument("--GeoSSL_atom_masking_ratio", type=float, default=0.3) 152 | parser.add_argument("--GeoSSL_option", type=str, default="EBM_NCE", choices=["DDM", "EBM_NCE", "InfoNCE", "RR"]) 153 | parser.add_argument("--EBM_NCE_SM_coefficient", type=float, default=10.) 154 | 155 | parser.add_argument("--SM_sigma_begin", type=float, default=10) 156 | parser.add_argument("--SM_sigma_end", type=float, default=0.01) 157 | parser.add_argument("--SM_num_noise_level", type=int, default=50) 158 | parser.add_argument("--SM_noise_type", type=str, default="symmetry", choices=["symmetry", "random"]) 159 | parser.add_argument("--SM_anneal_power", type=float, default=2) 160 | 161 | ######################### for GraphMVP SSL ######################### 162 | ### for 3D GNN 163 | parser.add_argument("--gnn_3d_lr_scale", type=float, default=1) 164 | 165 | ### for masking 166 | parser.add_argument("--SSL_masking_ratio", type=float, default=0.15) 167 | 168 | ### for 2D-3D Contrastive SSL 169 | parser.add_argument("--CL_neg_samples", type=int, default=1) 170 | parser.add_argument("--CL_similarity_metric", type=str, default="InfoNCE_dot_prod", 171 | choices=["InfoNCE_dot_prod", "EBM_dot_prod"]) 172 | parser.add_argument("--T", type=float, default=0.1) 173 | 
parser.add_argument("--normalize", dest="normalize", action="store_true") 174 | parser.add_argument("--no_normalize", dest="normalize", action="store_false") 175 | parser.add_argument("--alpha_1", type=float, default=1) 176 | 177 | ### for 2D-3D Generative SSL 178 | parser.add_argument("--GraphMVP_AE_model", type=str, default="VAE") 179 | parser.add_argument("--detach_target", dest="detach_target", action="store_true") 180 | parser.add_argument("--no_detach_target", dest="detach_target", action="store_false") 181 | parser.set_defaults(detach_target=True) 182 | parser.add_argument("--AE_loss", type=str, default="l2", choices=["l1", "l2", "cosine"]) 183 | parser.add_argument("--beta", type=float, default=1) 184 | parser.add_argument("--alpha_2", type=float, default=1) 185 | 186 | ### for 2D SSL 187 | parser.add_argument("--GraphMVP_2D_mode", type=str, default="AM", choices=["AM", "CP"]) 188 | parser.add_argument("--alpha_3", type=float, default=1) 189 | ### for AttributeMask 190 | parser.add_argument("--mask_rate", type=float, default=0.15) 191 | parser.add_argument("--mask_edge", type=int, default=0) 192 | ### for ContextPred 193 | parser.add_argument("--csize", type=int, default=3) 194 | parser.add_argument("--contextpred_neg_samples", type=int, default=1) 195 | ####################################################################### 196 | 197 | 198 | 199 | ##### whether to print eval metrics for the training data 200 | parser.add_argument("--eval_train", dest="eval_train", action="store_true") 201 | parser.add_argument("--no_eval_train", dest="eval_train", action="store_false") 202 | parser.set_defaults(eval_train=False) 203 | ##### whether to print eval metrics for the test data 204 | ##### this is only for COLL 205 | parser.add_argument("--eval_test", dest="eval_test", action="store_true") 206 | parser.add_argument("--no_eval_test", dest="eval_test", action="store_false") 207 | parser.set_defaults(eval_test=True) 208 | 209 | 
parser.add_argument("--input_data_dir", type=str, default="") 210 | 211 | # about loading and saving 212 | parser.add_argument("--input_model_file", type=str, default="") 213 | parser.add_argument("--output_model_dir", type=str, default="") 214 | 215 | ## for mask diff (GeoDiff) 216 | parser.add_argument('--ckpt', type=str,default="", help='path for loading the checkpoint') 217 | parser.add_argument('--config', type=str, default=None) 218 | parser.add_argument('--tag', type=str, default=None) 219 | parser.add_argument('--random_init', action="store_true", default=False) 220 | 221 | args = parser.parse_args() 222 | print("arguments\t", args) 223 | -------------------------------------------------------------------------------- /finetune/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # from finetune.datasets.datasets_Molecule3D import Molecule3D 2 | 3 | from .datasets_QM9 import MoleculeDatasetQM9 4 | 5 | from .rdmol2data import rdmol_to_data 6 | # from finetune.datasets.datasets_MD17 import DatasetMD17 7 | # from finetune.datasets.datasets_MD17Radius import DatasetMD17Radius 8 | 9 | # from finetune.datasets.datasets_LBA import DatasetLBA, TransformLBA 10 | # from finetune.datasets.datasets_LBARadius import DatasetLBARadius 11 | 12 | # from finetune.datasets.datasets_LEP import DatasetLEP, TransformLEP 13 | # from finetune.datasets.datasets_LEPRadius import DatasetLEPRadius 14 | 15 | # from finetune.datasets.datasets_3D import Molecule3DDataset 16 | # from finetune.datasets.datasets_3D_Masking import Molecule3DMaskingDataset 17 | # from finetune.datasets.datasets_3D_Radius import MoleculeDataset3DRadius 18 | 19 | # from finetune.datasets.datasets_utils import graph_data_obj_to_nx_simple, nx_to_graph_data_obj_simple 20 | -------------------------------------------------------------------------------- /finetune/datasets/datasets_QM9.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from itertools import repeat 3 | 4 | import pandas as pd 5 | import torch 6 | from rdkit import Chem 7 | from rdkit.Chem import AllChem 8 | from scipy.constants import physical_constants 9 | from torch_geometric.data import (Data, InMemoryDataset, download_url, 10 | extract_zip) 11 | 12 | # from Geom3D.datasets.datasets_utils import mol_to_graph_data_obj_simple_3D 13 | from datasets.rdmol2data import rdmol_to_data 14 | 15 | class MoleculeDatasetQM9(InMemoryDataset): 16 | raw_url = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip" 17 | raw_url2 = "https://ndownloader.figshare.com/files/3195404" 18 | raw_url3 = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/qm9.csv" 19 | raw_url4 = "https://springernature.figshare.com/ndownloader/files/3195395" 20 | 21 | def __init__( 22 | self, 23 | root, 24 | dataset, 25 | task, 26 | rotation_transform=None, 27 | transform=None, 28 | pre_transform=None, 29 | pre_filter=None, 30 | calculate_thermo=True, 31 | ): 32 | """ 33 | The complete columns are 34 | A,B,C,mu,alpha,homo,lumo,gap,r2,zpve,u0,u298,h298,g298,cv,u0_atom,u298_atom,h298_atom,g298_atom 35 | and we take 36 | mu,alpha,homo,lumo,gap,r2,zpve,u0,u298,h298,g298,cv 37 | """ 38 | self.root = root 39 | self.rotation_transform = rotation_transform 40 | self.transform = transform 41 | self.pre_transform = pre_transform 42 | self.pre_filter = pre_filter 43 | 44 | self.target_field = [ 45 | "mu", 46 | "alpha", 47 | "homo", 48 | "lumo", 49 | "gap", 50 | "r2", 51 | "zpve", 52 | "u0", 53 | "u298", 54 | "h298", 55 | "g298", 56 | "cv", 57 | "gap_02", 58 | ] 59 | self.pd_target_field = [ 60 | "mu", 61 | "alpha", 62 | "homo", 63 | "lumo", 64 | "gap", 65 | "r2", 66 | "zpve", 67 | "u0", 68 | "u298", 69 | "h298", 70 | "g298", 71 | "cv", 72 | ] 73 | self.task = task 74 | if self.task == "qm9": 75 | self.task_id = None 76 | else: 77 | self.task_id = 
self.target_field.index(task) 78 | self.calculate_thermo = calculate_thermo 79 | self.atom_dict = {"H": 1, "C": 6, "N": 7, "O": 8, "F": 9} 80 | 81 | # TODO: need double-check 82 | # https://github.com/atomistic-machine-learning/schnetpack/blob/master/src/schnetpack/datasets/qm9.py 83 | # QM9 dataset units: property_unit_dict = { 84 | # QM9.A: "GHz", 85 | # QM9.B: "GHz", 86 | # QM9.C: "GHz", 87 | # QM9.mu: "Debye", 88 | # QM9.alpha: "a0 a0 a0", 89 | # QM9.homo: "Ha", 90 | # QM9.lumo: "Ha", 91 | # QM9.gap: "Ha", 92 | # QM9.r2: "a0 a0", 93 | # QM9.zpve: "Ha", 94 | # QM9.U0: "Ha", 95 | # QM9.U: "Ha", 96 | # QM9.H: "Ha", 97 | # QM9.G: "Ha", 98 | # QM9.Cv: "cal/mol/K", 99 | # } 100 | # HAR2EV = 27.211386246 101 | # KCALMOL2EV = 0.04336414 102 | 103 | # conversion = torch.tensor([ 104 | # 1., 1., HAR2EV, HAR2EV, HAR2EV, 1., HAR2EV, HAR2EV, HAR2EV, HAR2EV, HAR2EV, 105 | # 1., KCALMOL2EV, KCALMOL2EV, KCALMOL2EV, KCALMOL2EV, 1., 1., 1. 106 | # ]) 107 | 108 | # Now we are following these two: 109 | # https://github.com/risilab/cormorant/blob/master/examples/train_qm9.py 110 | # https://github.com/FabianFuchsML/se3-transformer-public/blob/master/experiments/qm9/QM9.py 111 | 112 | self.hartree2eV = physical_constants["hartree-electron volt relationship"][0] 113 | 114 | self.conversion = { 115 | "mu": 1.0, 116 | "alpha": 1.0, 117 | "homo": self.hartree2eV, 118 | "lumo": self.hartree2eV, 119 | "gap": self.hartree2eV, 120 | "gap_02": self.hartree2eV, 121 | "r2": 1.0, 122 | "zpve": self.hartree2eV, 123 | "u0": self.hartree2eV, 124 | "u298": self.hartree2eV, 125 | "h298": self.hartree2eV, 126 | "g298": self.hartree2eV, 127 | "cv": 1.0, 128 | } 129 | 130 | super(MoleculeDatasetQM9, self).__init__( 131 | root, transform, pre_transform, pre_filter 132 | ) 133 | self.dataset = dataset 134 | self.data, self.slices = torch.load(self.processed_paths[0]) 135 | print("Dataset: {}\nData: {}".format(self.dataset, self.data)) 136 | 137 | return 138 | 139 | def mean(self): 140 | y = 
torch.stack([self.get(i).y for i in range(len(self))], dim=0) 141 | y = y.mean(dim=0) 142 | return y 143 | 144 | def std(self): 145 | y = torch.stack([self.get(i).y for i in range(len(self))], dim=0) 146 | y = y.std(dim=0) 147 | return y 148 | 149 | def get(self, idx): 150 | data = Data() 151 | for key in self.data.keys: 152 | item, slices = self.data[key], self.slices[key] 153 | s = list(repeat(slice(None), item.dim())) 154 | s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) 155 | data[key] = item[s] 156 | if self.rotation_transform is not None: 157 | data.positions = self.rotation_transform(data.positions) 158 | return data 159 | 160 | @property 161 | def raw_file_names(self): 162 | return [ 163 | "gdb9.sdf", 164 | "gdb9.sdf.csv", 165 | "uncharacterized.txt", 166 | "qm9.csv", 167 | "atomref.txt", 168 | ] 169 | 170 | @property 171 | def processed_file_names(self): 172 | return "geometric_data_processed.pt" 173 | 174 | def download(self): 175 | file_path = download_url(self.raw_url, self.raw_dir) 176 | extract_zip(file_path, self.raw_dir) 177 | os.unlink(file_path) 178 | 179 | download_url(self.raw_url2, self.raw_dir) 180 | os.rename( 181 | os.path.join(self.raw_dir, "3195404"), 182 | os.path.join(self.raw_dir, "uncharacterized.txt"), 183 | ) 184 | 185 | download_url(self.raw_url3, self.raw_dir) 186 | 187 | download_url(self.raw_url4, self.raw_dir) 188 | os.rename( 189 | os.path.join(self.raw_dir, "3195395"), 190 | os.path.join(self.raw_dir, "atomref.txt"), 191 | ) 192 | return 193 | 194 | def get_thermo_dict(self): 195 | gdb9_txt_thermo = self.raw_paths[4] 196 | # Targets with a thermochemical reference energy: zpve, u0, u298, h298, g298, cv, 197 | # referenced below by their column indices in self.target_field 198 | therm_targets = [6, 7, 8, 9, 10, 11] 199 | 200 | # Dictionary that maps atom symbols to atomic numbers 201 | id2charge = self.atom_dict 202 | 203 | # Loop over file of thermochemical energies 204 | therm_energy = {target: {} for target in therm_targets} 205 | with open(gdb9_txt_thermo) as f: 206 | for line
in f: 207 | # If line starts with an element, convert the rest to a list of energies. 208 | split = line.split() 209 | 210 | # Check charge corresponds to an atom 211 | if len(split) == 0 or split[0] not in id2charge.keys(): 212 | continue 213 | 214 | # Loop over learning targets with defined thermochemical energy 215 | for therm_target, split_therm in zip(therm_targets, split[1:]): 216 | therm_energy[therm_target][id2charge[split[0]]] = float(split_therm) 217 | 218 | return therm_energy 219 | 220 | def process(self): 221 | therm_energy = self.get_thermo_dict() 222 | print("therm_energy\t", therm_energy) 223 | 224 | df = pd.read_csv(self.raw_paths[1]) 225 | df = df[self.pd_target_field] 226 | df["gap_02"] = df["lumo"] - df["homo"] 227 | 228 | target = df.to_numpy() 229 | target = torch.tensor(target, dtype=torch.float) 230 | 231 | with open(self.raw_paths[2], "r") as f: 232 | # These are the mismatched molecules, listed in the `uncharacterized.txt` file. 233 | skip = [int(x.split()[0]) - 1 for x in f.read().split("\n")[9:-2]] 234 | 235 | data_df = pd.read_csv(self.raw_paths[3]) 236 | whole_smiles_list = data_df["smiles"].tolist() 237 | print("TODO\t", whole_smiles_list[:100]) 238 | 239 | suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False, sanitize=False) 240 | 241 | print("suppl: {}\tsmiles_list: {}".format(len(suppl), len(whole_smiles_list))) 242 | 243 | data_list, data_smiles_list, data_name_list, idx, invalid_count = ( 244 | [], 245 | [], 246 | [], 247 | 0, 248 | 0, 249 | ) 250 | for i, mol in enumerate(suppl): 251 | if i in skip: 252 | print("Exception with (skip)\t", i) 253 | invalid_count += 1 254 | continue 255 | 256 | # data, atom_count = mol_to_graph_data_obj_simple_3D(mol) 257 | data, atom_count = rdmol_to_data(mol=mol) 258 | 259 | data.id = torch.tensor([idx]) 260 | temp_y = target[i] 261 | if self.calculate_thermo: 262 | for atom, count in atom_count.items(): 263 | if atom not in self.atom_dict.values(): 264 | continue 265 | for target_id, 
atom_sub_dic in therm_energy.items(): 266 | temp_y[target_id] -= atom_sub_dic[atom] * count 267 | 268 | # convert units (separate loop variable: `idx` is the running molecule counter) 269 | for col_id, col in enumerate(self.target_field): 270 | temp_y[col_id] *= self.conversion[col] 271 | data.y = temp_y 272 | 273 | name = mol.GetProp("_Name") 274 | smiles = whole_smiles_list[i] 275 | 276 | # TODO: need double-check this 277 | temp_mol = AllChem.MolFromSmiles(smiles) 278 | if temp_mol is None: 279 | print("Exception with (invalid mol)\t", i) 280 | invalid_count += 1 281 | continue 282 | 283 | data_smiles_list.append(smiles) 284 | data_name_list.append(name) 285 | data_list.append(data) 286 | idx += 1 287 | 288 | print( 289 | "mol id: [0, {}]\tlen of smiles: {}\tlen of set(smiles): {}".format( 290 | idx - 1, len(data_smiles_list), len(set(data_smiles_list)) 291 | ) 292 | ) 293 | print("{} invalid molecules".format(invalid_count)) 294 | 295 | if self.pre_filter is not None: 296 | data_list = [data for data in data_list if self.pre_filter(data)] 297 | 298 | if self.pre_transform is not None: 299 | data_list = [self.pre_transform(data) for data in data_list] 300 | 301 | # TODO: need double-check later, are the smiles lists identical here?
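For reference, the per-column unit conversion that `process` applies to `temp_y` can be sketched in isolation. The target values below are hypothetical, and the hartree-to-eV factor is hard-coded here rather than read from `scipy.constants.physical_constants` as the class does:

```python
# CODATA "hartree-electron volt relationship" value, as physical_constants reports it
hartree2eV = 27.211386245988

# Subset of the dataset's self.conversion table: energies in Hartree -> eV,
# other targets kept in their native units.
conversion = {
    "mu": 1.0,           # Debye, unchanged
    "homo": hartree2eV,  # Ha -> eV
    "lumo": hartree2eV,  # Ha -> eV
    "gap": hartree2eV,   # Ha -> eV
    "cv": 1.0,           # cal/mol/K, unchanged
}

# Hypothetical raw targets for one molecule (not real QM9 values)
raw = {"mu": 2.5, "homo": -0.24, "lumo": 0.01, "gap": 0.25, "cv": 6.3}
converted = {k: v * conversion[k] for k, v in raw.items()}
print(converted["gap"])  # 0.25 Ha expressed in eV, roughly 6.80
```

This mirrors the loop above, where each column of `temp_y` is multiplied by the factor keyed on its name in `self.target_field`.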
302 | data_smiles_series = pd.Series(data_smiles_list) 303 | saver_path = os.path.join(self.processed_dir, "smiles.csv") 304 | print("saving to {}".format(saver_path)) 305 | data_smiles_series.to_csv(saver_path, index=False, header=False) 306 | 307 | data_name_series = pd.Series(data_name_list) 308 | saver_path = os.path.join(self.processed_dir, "name.csv") 309 | print("saving to {}".format(saver_path)) 310 | data_name_series.to_csv(saver_path, index=False, header=False) 311 | 312 | data, slices = self.collate(data_list) 313 | torch.save((data, slices), self.processed_paths[0]) 314 | 315 | return 316 | -------------------------------------------------------------------------------- /finetune/datasets/rdmol2data.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | import torch 5 | from torch_geometric.data import Data 6 | 7 | from torch_scatter import scatter 8 | #from torch.utils.data import Dataset 9 | 10 | from rdkit import Chem 11 | from rdkit.Chem.rdchem import Mol, HybridizationType, BondType 12 | from rdkit import RDLogger 13 | RDLogger.DisableLog('rdApp.*') 14 | 15 | from collections import defaultdict 16 | 17 | # from confgf import utils 18 | # from datasets.chem import BOND_TYPES, BOND_NAMES 19 | from rdkit.Chem.rdchem import BondType as BT 20 | 21 | BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())} 22 | 23 | def rdmol_to_data(mol:Mol, pos=None,y=None, idx=None, smiles=None): 24 | assert mol.GetNumConformers() == 1 25 | N = mol.GetNumAtoms() 26 | # if smiles is None: 27 | # smiles = Chem.MolToSmiles(mol) 28 | pos = torch.tensor(mol.GetConformer(0).GetPositions(), dtype=torch.float32) 29 | 30 | atomic_number = [] 31 | aromatic = [] 32 | sp = [] 33 | sp2 = [] 34 | sp3 = [] 35 | num_hs = [] 36 | atom_count = defaultdict(int) 37 | for atom in mol.GetAtoms(): 38 | atomic_number.append(atom.GetAtomicNum()) 39 | atom_count[atom.GetAtomicNum()] += 1 40 | aromatic.append(1 if 
atom.GetIsAromatic() else 0) 41 | hybridization = atom.GetHybridization() 42 | sp.append(1 if hybridization == HybridizationType.SP else 0) 43 | sp2.append(1 if hybridization == HybridizationType.SP2 else 0) 44 | sp3.append(1 if hybridization == HybridizationType.SP3 else 0) 45 | 46 | z = torch.tensor(atomic_number, dtype=torch.long) 47 | 48 | row, col, edge_type = [], [], [] 49 | for bond in mol.GetBonds(): 50 | start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() 51 | row += [start, end] 52 | col += [end, start] 53 | edge_type += 2 * [BOND_TYPES[bond.GetBondType()]] 54 | 55 | edge_index = torch.tensor([row, col], dtype=torch.long) 56 | edge_type = torch.tensor(edge_type) 57 | 58 | perm = (edge_index[0] * N + edge_index[1]).argsort() 59 | edge_index = edge_index[:, perm] 60 | edge_type = edge_type[perm] 61 | 62 | row, col = edge_index 63 | hs = (z == 1).to(torch.float32) 64 | 65 | num_hs = scatter(hs[row], col, dim_size=N, reduce='sum').tolist() 66 | 67 | if smiles is None: 68 | smiles = Chem.MolToSmiles(mol) 69 | 70 | try: 71 | name = mol.GetProp('_Name') 72 | except KeyError:  # RDKit raises KeyError when the '_Name' property is missing 73 | name = None 74 | 75 | # data = Data(atom_type=z, pos=pos, edge_index=edge_index, edge_type=edge_type, 76 | # rdmol=copy.deepcopy(mol), smiles=smiles,y=y,id=idx, name=name) 77 | data = Data(atom_type=z, pos=pos, edge_index=edge_index, edge_type=edge_type) 78 | #data.nx = to_networkx(data, to_undirected=True) 79 | 80 | return data, atom_count -------------------------------------------------------------------------------- /finetune/splitters.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from collections import defaultdict 5 | from itertools import compress 6 | 7 | import numpy as np 8 | import torch 9 | from rdkit.Chem.Scaffolds import MurckoScaffold 10 | # from sklearn.model_selection import StratifiedKFold 11 | 12 | 13 | def generate_scaffold(smiles, include_chirality=False): 14 | """Obtain
Bemis-Murcko scaffold from smiles 15 | :return: smiles of scaffold""" 16 | scaffold = MurckoScaffold.MurckoScaffoldSmiles( 17 | smiles=smiles, includeChirality=include_chirality 18 | ) 19 | return scaffold 20 | 21 | 22 | # # test generate_scaffold 23 | # s = 'Cc1cc(Oc2nccc(CCC)c2)ccc1' 24 | # scaffold = generate_scaffold(s) 25 | # assert scaffold == 'c1ccc(Oc2ccccn2)cc1' 26 | 27 | 28 | def scaffold_split( 29 | dataset, 30 | smiles_list, 31 | task_idx=None, 32 | null_value=0, 33 | frac_train=0.8, 34 | frac_valid=0.1, 35 | frac_test=0.1, 36 | return_smiles=False, 37 | ): 38 | """ 39 | Adapted from https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py 40 | Split dataset by Bemis-Murcko scaffolds 41 | This function can also ignore examples containing null values for a 42 | selected task when splitting. Deterministic split 43 | :param dataset: pytorch geometric dataset obj 44 | :param smiles_list: list of smiles corresponding to the dataset obj 45 | :param task_idx: column idx of the data.y tensor. Will filter out 46 | examples with null value in specified task column of the data.y tensor 47 | prior to splitting. 
If None, then no filtering 48 | :param null_value: float that specifies null value in data.y to filter if 49 | task_idx is provided 50 | :param frac_train, frac_valid, frac_test: fractions 51 | :param return_smiles: return SMILES if True 52 | :return: train, valid, test slices of the input dataset obj.""" 53 | np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0) 54 | 55 | if task_idx is not None: 56 | # filter based on null values in task_idx 57 | # get task array 58 | y_task = np.array([data.y[task_idx].item() for data in dataset]) 59 | # boolean array that corresponds to non-null values 60 | non_null = y_task != null_value 61 | smiles_list = list(compress(enumerate(smiles_list), non_null)) 62 | else: 63 | non_null = np.ones(len(dataset)) == 1 64 | smiles_list = list(compress(enumerate(smiles_list), non_null)) 65 | 66 | # create dict of the form {scaffold_i: [idx1, idx....]} 67 | all_scaffolds = {} 68 | for i, smiles in smiles_list: 69 | scaffold = generate_scaffold(smiles, include_chirality=True) 70 | if scaffold not in all_scaffolds: 71 | all_scaffolds[scaffold] = [i] 72 | else: 73 | all_scaffolds[scaffold].append(i) 74 | 75 | # sort from largest to smallest sets 76 | all_scaffolds = {key: sorted(value) for key, value in all_scaffolds.items()} 77 | all_scaffold_sets = [ 78 | scaffold_set 79 | for (scaffold, scaffold_set) in sorted( 80 | all_scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True 81 | ) 82 | ] 83 | 84 | # get train, valid, test indices 85 | train_cutoff = frac_train * len(smiles_list) 86 | valid_cutoff = (frac_train + frac_valid) * len(smiles_list) 87 | train_idx, valid_idx, test_idx = [], [], [] 88 | for scaffold_set in all_scaffold_sets: 89 | if len(train_idx) + len(scaffold_set) > train_cutoff: 90 | if len(train_idx) + len(valid_idx) + len(scaffold_set) > valid_cutoff: 91 | test_idx.extend(scaffold_set) 92 | else: 93 | valid_idx.extend(scaffold_set) 94 | else: 95 | train_idx.extend(scaffold_set) 96 | 97 | assert 
len(set(train_idx).intersection(set(valid_idx))) == 0 98 | assert len(set(test_idx).intersection(set(valid_idx))) == 0 99 | 100 | train_dataset = dataset[torch.tensor(train_idx)] 101 | valid_dataset = dataset[torch.tensor(valid_idx)] 102 | test_dataset = dataset[torch.tensor(test_idx)] 103 | 104 | if not return_smiles: 105 | return train_dataset, valid_dataset, test_dataset 106 | else: 107 | train_smiles = [smiles_list[i][1] for i in train_idx] 108 | valid_smiles = [smiles_list[i][1] for i in valid_idx] 109 | test_smiles = [smiles_list[i][1] for i in test_idx] 110 | return ( 111 | train_dataset, 112 | valid_dataset, 113 | test_dataset, 114 | (train_smiles, valid_smiles, test_smiles), 115 | ) 116 | 117 | 118 | def random_scaffold_split( 119 | dataset, 120 | smiles_list, 121 | task_idx=None, 122 | null_value=0, 123 | frac_train=0.8, 124 | frac_valid=0.1, 125 | frac_test=0.1, 126 | seed=0, 127 | ): 128 | """ 129 | Adapted from https://github.com/pfnet-research/chainer-chemistry/blob/master/ 130 | chainer_chemistry/dataset/splitters/scaffold_splitter.py 131 | Split dataset by Bemis-Murcko scaffolds 132 | This function can also ignore examples containing null values for a 133 | selected task when splitting. Deterministic split 134 | :param dataset: pytorch geometric dataset obj 135 | :param smiles_list: list of smiles corresponding to the dataset obj 136 | :param task_idx: column idx of the data.y tensor. Will filter out 137 | examples with null value in specified task column of the data.y tensor 138 | prior to splitting. 
If None, then no filtering 139 | :param null_value: float that specifies null value in data.y to filter if 140 | task_idx is provided 141 | :param frac_train, frac_valid, frac_test: fractions, floats 142 | :param seed: seed 143 | :return: train, valid, test slices of the input dataset obj""" 144 | 145 | np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0) 146 | 147 | if task_idx is not None: 148 | # filter based on null values in task_idx get task array 149 | y_task = np.array([data.y[task_idx].item() for data in dataset]) 150 | # boolean array that correspond to non null values 151 | non_null = y_task != null_value 152 | smiles_list = list(compress(enumerate(smiles_list), non_null)) 153 | else: 154 | non_null = np.ones(len(dataset)) == 1 155 | smiles_list = list(compress(enumerate(smiles_list), non_null)) 156 | 157 | rng = np.random.RandomState(seed) 158 | 159 | scaffolds = defaultdict(list) 160 | for ind, smiles in smiles_list: 161 | scaffold = generate_scaffold(smiles, include_chirality=True) 162 | scaffolds[scaffold].append(ind) 163 | 164 | scaffold_sets = rng.permutation(list(scaffolds.values())) 165 | 166 | n_total_valid = int(np.floor(frac_valid * len(dataset))) 167 | n_total_test = int(np.floor(frac_test * len(dataset))) 168 | 169 | train_idx = [] 170 | valid_idx = [] 171 | test_idx = [] 172 | 173 | for scaffold_set in scaffold_sets: 174 | if len(valid_idx) + len(scaffold_set) <= n_total_valid: 175 | valid_idx.extend(scaffold_set) 176 | elif len(test_idx) + len(scaffold_set) <= n_total_test: 177 | test_idx.extend(scaffold_set) 178 | else: 179 | train_idx.extend(scaffold_set) 180 | 181 | train_dataset = dataset[torch.tensor(train_idx)] 182 | valid_dataset = dataset[torch.tensor(valid_idx)] 183 | test_dataset = dataset[torch.tensor(test_idx)] 184 | 185 | return train_dataset, valid_dataset, test_dataset 186 | 187 | 188 | def random_split( 189 | dataset, 190 | task_idx=None, 191 | null_value=0, 192 | frac_train=0.8, 193 | frac_valid=0.1, 
194 | frac_test=0.1, 195 | seed=0, 196 | smiles_list=None, 197 | ): 198 | """ 199 | :return: train, valid, test slices of the input dataset obj. If 200 | smiles_list != None, also returns ([train_smiles_list], 201 | [valid_smiles_list], [test_smiles_list])""" 202 | 203 | np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0) 204 | 205 | if task_idx is not None: 206 | # filter based on null values in task_idx 207 | # get task array 208 | y_task = np.array([data.y[task_idx].item() for data in dataset]) 209 | non_null = ( 210 | y_task != null_value 211 | ) # boolean array that correspond to non null values 212 | idx_array = np.where(non_null)[0] 213 | dataset = dataset[torch.tensor(idx_array)] # examples containing non 214 | # null labels in the specified task_idx 215 | else: 216 | pass 217 | 218 | num_mols = len(dataset) 219 | random.seed(seed) 220 | print("using seed\t", seed) 221 | all_idx = list(range(num_mols)) 222 | random.shuffle(all_idx) 223 | 224 | train_idx = all_idx[: int(frac_train * num_mols)] 225 | valid_idx = all_idx[ 226 | int(frac_train * num_mols) : int(frac_valid * num_mols) 227 | + int(frac_train * num_mols) 228 | ] 229 | test_idx = all_idx[int(frac_valid * num_mols) + int(frac_train * num_mols) :] 230 | 231 | assert len(set(train_idx).intersection(set(valid_idx))) == 0 232 | assert len(set(valid_idx).intersection(set(test_idx))) == 0 233 | assert len(train_idx) + len(valid_idx) + len(test_idx) == num_mols 234 | 235 | train_dataset = dataset[torch.tensor(train_idx)] 236 | valid_dataset = dataset[torch.tensor(valid_idx)] 237 | test_dataset = dataset[torch.tensor(test_idx)] 238 | 239 | if not smiles_list: 240 | return train_dataset, valid_dataset, test_dataset 241 | else: 242 | train_smiles = [smiles_list[i] for i in train_idx] 243 | valid_smiles = [smiles_list[i] for i in valid_idx] 244 | test_smiles = [smiles_list[i] for i in test_idx] 245 | return ( 246 | train_dataset, 247 | valid_dataset, 248 | test_dataset, 249 | 
(train_smiles, valid_smiles, test_smiles), 250 | ) 251 | 252 | 253 | def qm9_random_customized_01( 254 | dataset, task_idx=None, null_value=0, seed=0, smiles_list=None 255 | ): 256 | if task_idx is not None: 257 | # filter based on null values in task_idx 258 | # get task array 259 | y_task = np.array([data.y[task_idx].item() for data in dataset]) 260 | non_null = ( 261 | y_task != null_value 262 | ) # boolean array that correspond to non null values 263 | idx_array = np.where(non_null)[0] 264 | dataset = dataset[torch.tensor(idx_array)] # examples containing non 265 | # null labels in the specified task_idx 266 | else: 267 | pass 268 | 269 | num_mols = len(dataset) 270 | np.random.seed(seed) 271 | all_idx = np.random.permutation(num_mols) 272 | 273 | Nmols = 133885 - 3054 274 | Ntrain = 110000 275 | Nvalid = 10000 276 | Ntest = Nmols - (Ntrain + Nvalid) 277 | 278 | train_idx = all_idx[:Ntrain] 279 | valid_idx = all_idx[Ntrain : Ntrain + Nvalid] 280 | test_idx = all_idx[Ntrain + Nvalid :] 281 | 282 | print("train_idx: ", train_idx) 283 | print("valid_idx: ", valid_idx) 284 | print("test_idx: ", test_idx) 285 | # np.savez("customized_01", train_idx=train_idx, valid_idx=valid_idx, test_idx=test_idx) 286 | 287 | assert len(set(train_idx).intersection(set(valid_idx))) == 0 288 | assert len(set(valid_idx).intersection(set(test_idx))) == 0 289 | assert len(train_idx) + len(valid_idx) + len(test_idx) == num_mols 290 | 291 | train_dataset = dataset[torch.tensor(train_idx)] 292 | valid_dataset = dataset[torch.tensor(valid_idx)] 293 | test_dataset = dataset[torch.tensor(test_idx)] 294 | 295 | if not smiles_list: 296 | return train_dataset, valid_dataset, test_dataset 297 | else: 298 | train_smiles = [smiles_list[i] for i in train_idx] 299 | valid_smiles = [smiles_list[i] for i in valid_idx] 300 | test_smiles = [smiles_list[i] for i in test_idx] 301 | return ( 302 | train_dataset, 303 | valid_dataset, 304 | test_dataset, 305 | (train_smiles, valid_smiles, test_smiles), 306 | 
) 307 | 308 | 309 | def qm9_random_customized_02( 310 | dataset, task_idx=None, null_value=0, seed=0, smiles_list=None 311 | ): 312 | if task_idx is not None: 313 | # filter based on null values in task_idx 314 | # get task array 315 | y_task = np.array([data.y[task_idx].item() for data in dataset]) 316 | non_null = ( 317 | y_task != null_value 318 | ) # boolean array that correspond to non null values 319 | idx_array = np.where(non_null)[0] 320 | dataset = dataset[torch.tensor(idx_array)] # examples containing non 321 | # null labels in the specified task_idx 322 | else: 323 | pass 324 | 325 | num_mols = len(dataset) 326 | np.random.seed(seed) 327 | print("using seed\t", seed) 328 | all_idx = np.random.permutation(num_mols) 329 | 330 | Nmols = 133885 - 3054 331 | Ntrain = 100000 332 | Ntest = int(0.1 * Nmols) 333 | Nvalid = Nmols - (Ntrain + Ntest) 334 | 335 | train_idx = all_idx[:Ntrain] 336 | valid_idx = all_idx[Ntrain : Ntrain + Nvalid] 337 | test_idx = all_idx[Ntrain + Nvalid :] 338 | 339 | assert len(set(train_idx).intersection(set(valid_idx))) == 0 340 | assert len(set(valid_idx).intersection(set(test_idx))) == 0 341 | assert len(train_idx) + len(valid_idx) + len(test_idx) == num_mols 342 | 343 | train_dataset = dataset[torch.tensor(train_idx)] 344 | valid_dataset = dataset[torch.tensor(valid_idx)] 345 | test_dataset = dataset[torch.tensor(test_idx)] 346 | 347 | if not smiles_list: 348 | return train_dataset, valid_dataset, test_dataset 349 | else: 350 | train_smiles = [smiles_list[i] for i in train_idx] 351 | valid_smiles = [smiles_list[i] for i in valid_idx] 352 | test_smiles = [smiles_list[i] for i in test_idx] 353 | return ( 354 | train_dataset, 355 | valid_dataset, 356 | test_dataset, 357 | (train_smiles, valid_smiles, test_smiles), 358 | ) 359 | 360 | 361 | def atom3d_lba_split(dataset, data_root, year): 362 | json_file = os.path.join(data_root, 'processed', 'pdb_id2data_id_{}.json'.format(year)) 363 | f = open(json_file, 'r') 364 | pdb_id2data_id = 
json.load(f) 365 | 366 | def load_pdb_id_list_from_file(file): 367 | pdb_id_list = [] 368 | with open(file, 'r') as f: 369 | lines = f.readlines() 370 | for line in lines: 371 | pdb_id_list.append(line.strip()) 372 | return pdb_id_list 373 | 374 | def load_data_id(mode): 375 | pdb_id_file = os.path.join(data_root, 'processed', 'targets', '{}.txt'.format(mode)) 376 | pdb_id_list = load_pdb_id_list_from_file(pdb_id_file) 377 | data_id_list = [pdb_id2data_id[pdb_id] for pdb_id in pdb_id_list] 378 | return data_id_list 379 | 380 | train_idx = load_data_id('train') 381 | valid_idx = load_data_id('val') 382 | test_idx = load_data_id('test') 383 | 384 | train_dataset = dataset[torch.tensor(train_idx)] 385 | valid_dataset = dataset[torch.tensor(valid_idx)] 386 | test_dataset = dataset[torch.tensor(test_idx)] 387 | 388 | return train_dataset, valid_dataset, test_dataset 389 | 390 | -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch_geometric.nn import radius_graph, radius 6 | from torch_scatter import scatter_mean, scatter_add, scatter_max 7 | from torch_sparse import coalesce 8 | from torch_geometric.utils import to_dense_adj, dense_to_sparse 9 | 10 | from utils.chem import BOND_TYPES 11 | 12 | 13 | class MeanReadout(nn.Module): 14 | """Mean readout operator over graphs with variadic sizes.""" 15 | 16 | def forward(self, data, input): 17 | """ 18 | Perform readout over the graph(s). 
19 | Parameters: 20 | data (torch_geometric.data.Data): batched graph 21 | input (Tensor): node representations 22 | Returns: 23 | Tensor: graph representations 24 | """ 25 | output = scatter_mean(input, data.batch, dim=0, dim_size=data.num_graphs) 26 | return output 27 | 28 | 29 | class SumReadout(nn.Module): 30 | """Sum readout operator over graphs with variadic sizes.""" 31 | 32 | def forward(self, data, input): 33 | """ 34 | Perform readout over the graph(s). 35 | Parameters: 36 | data (torch_geometric.data.Data): batched graph 37 | input (Tensor): node representations 38 | Returns: 39 | Tensor: graph representations 40 | """ 41 | output = scatter_add(input, data.batch, dim=0, dim_size=data.num_graphs) 42 | return output 43 | 44 | 45 | 46 | class MultiLayerPerceptron(nn.Module): 47 | """ 48 | Multi-layer Perceptron. 49 | Note there is no activation or dropout in the last layer. 50 | Parameters: 51 | input_dim (int): input dimension 52 | hidden_dims (list of int): hidden dimensions 53 | activation (str or function, optional): activation function 54 | dropout (float, optional): dropout rate 55 | """ 56 | 57 | def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0): 58 | super(MultiLayerPerceptron, self).__init__() 59 | 60 | self.dims = [input_dim] + hidden_dims 61 | if isinstance(activation, str): 62 | self.activation = getattr(F, activation) 63 | else: 64 | self.activation = None 65 | if dropout: 66 | self.dropout = nn.Dropout(dropout) 67 | else: 68 | self.dropout = None 69 | 70 | self.layers = nn.ModuleList() 71 | for i in range(len(self.dims) - 1): 72 | self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1])) 73 | 74 | def forward(self, input): 75 | """""" 76 | x = input 77 | for i, layer in enumerate(self.layers): 78 | x = layer(x) 79 | if i < len(self.layers) - 1: 80 | if self.activation: 81 | x = self.activation(x) 82 | if self.dropout: 83 | x = self.dropout(x) 84 | return x 85 | 86 | 87 | def assemble_atom_pair_feature(node_attr,
edge_index, edge_attr): 88 | h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]] 89 | h_pair = torch.cat([h_row*h_col, edge_attr], dim=-1) # (E, 2H) 90 | return h_pair 91 | 92 | 93 | def generate_symmetric_edge_noise(num_nodes_per_graph, edge_index, edge2graph, device): 94 | num_cum_nodes = num_nodes_per_graph.cumsum(0) # (G, ) 95 | node_offset = num_cum_nodes - num_nodes_per_graph # (G, ) 96 | edge_offset = node_offset[edge2graph] # (E, ) 97 | 98 | num_nodes_square = num_nodes_per_graph**2 # (G, ) 99 | num_nodes_square_cumsum = num_nodes_square.cumsum(-1) # (G, ) 100 | edge_start = num_nodes_square_cumsum - num_nodes_square # (G, ) 101 | edge_start = edge_start[edge2graph] 102 | 103 | all_len = num_nodes_square_cumsum[-1] 104 | 105 | node_index = edge_index.t() - edge_offset.unsqueeze(-1) 106 | node_large = node_index.max(dim=-1)[0] 107 | node_small = node_index.min(dim=-1)[0] 108 | undirected_edge_id = node_large * (node_large + 1) + node_small + edge_start 109 | 110 | symm_noise = torch.zeros(size=[all_len.item()], device=device) 111 | symm_noise.normal_() 112 | d_noise = symm_noise[undirected_edge_id].unsqueeze(-1) # (E, 1) 113 | return d_noise 114 | 115 | 116 | def _extend_graph_order(num_nodes, edge_index, edge_type, order=3): 117 | """ 118 | Args: 119 | num_nodes: Number of atoms. 120 | edge_index: Bond indices of the original graph. 121 | edge_type: Bond types of the original graph. 122 | order: Extension order. 123 | Returns: 124 | new_edge_index: Extended edge indices. 125 | new_edge_type: Extended edge types. 126 | """ 127 | 128 | def binarize(x): 129 | return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x)) 130 | 131 | def get_higher_order_adj_matrix(adj, order): 132 | """ 133 | Args: 134 | adj: (N, N) 135 | type_mat: (N, N) 136 | Returns: 137 | Following attributes will be updated: 138 | - edge_index 139 | - edge_type 140 | Following attributes will be added to the data object: 141 | - bond_edge_index: Original edge_index. 
142 | """ 143 | adj_mats = [torch.eye(adj.size(0), dtype=torch.long, device=adj.device), \ 144 | binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device))] 145 | 146 | for i in range(2, order+1): 147 | adj_mats.append(binarize(adj_mats[i-1] @ adj_mats[1])) 148 | order_mat = torch.zeros_like(adj) 149 | 150 | for i in range(1, order+1): 151 | order_mat = order_mat + (adj_mats[i] - adj_mats[i-1]) * i 152 | 153 | return order_mat 154 | 155 | num_types = len(BOND_TYPES) 156 | 157 | N = num_nodes 158 | adj = to_dense_adj(edge_index).squeeze(0) 159 | adj_order = get_higher_order_adj_matrix(adj, order) # (N, N) 160 | 161 | type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N) 162 | type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order)) 163 | assert (type_mat * type_highorder == 0).all() 164 | type_new = type_mat + type_highorder 165 | 166 | new_edge_index, new_edge_type = dense_to_sparse(type_new) 167 | _, edge_order = dense_to_sparse(adj_order) 168 | 169 | # data.bond_edge_index = data.edge_index # Save original edges 170 | new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data 171 | 172 | # [Note] This is not necessary 173 | # data.is_bond = (data.edge_type < num_types) 174 | 175 | # [Note] In earlier versions, `edge_order` attribute will be added. 176 | # However, it doesn't seem to be necessary anymore so I removed it. 
177 | # edge_index_1, data.edge_order = coalesce(new_edge_index, edge_order.long(), N, N) # modify data 178 | # assert (data.edge_index == edge_index_1).all() 179 | 180 | return new_edge_index, new_edge_type 181 | 182 | 183 | def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None): 184 | 185 | assert edge_type.dim() == 1 186 | N = pos.size(0) 187 | 188 | bgraph_adj = torch.sparse.LongTensor( 189 | edge_index, 190 | edge_type, 191 | torch.Size([N, N]) 192 | ) 193 | 194 | if is_sidechain is None: 195 | rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r) 196 | else: 197 | # fetch sidechain and its batch index 198 | is_sidechain = is_sidechain.bool() 199 | dummy_index = torch.arange(pos.size(0), device=pos.device) 200 | sidechain_pos = pos[is_sidechain] 201 | sidechain_index = dummy_index[is_sidechain] 202 | sidechain_batch = batch[is_sidechain] 203 | 204 | assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch) 205 | r_edge_index_x = assign_index[1] 206 | r_edge_index_y = assign_index[0] 207 | r_edge_index_y = sidechain_index[r_edge_index_y] 208 | 209 | rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E) 210 | rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E) 211 | rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E) 212 | # delete self loop 213 | rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])] 214 | 215 | rgraph_adj = torch.sparse.LongTensor( 216 | rgraph_edge_index, 217 | torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number, 218 | torch.Size([N, N]) 219 | ) 220 | 221 | composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T) 222 | # edge_index = composed_adj.indices() 223 | # dist = (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1) 224 | 225 | new_edge_index = 
composed_adj.indices() 226 | new_edge_type = composed_adj.values().long() 227 | 228 | return new_edge_index, new_edge_type 229 | 230 | 231 | def extend_graph_order_radius(num_nodes, pos, edge_index, edge_type, batch, order=3, cutoff=10.0, 232 | extend_order=True, extend_radius=True, is_sidechain=None): 233 | 234 | if extend_order: 235 | edge_index, edge_type = _extend_graph_order( 236 | num_nodes=num_nodes, 237 | edge_index=edge_index, 238 | edge_type=edge_type, order=order 239 | ) 240 | # edge_index_order = edge_index 241 | # edge_type_order = edge_type 242 | 243 | if extend_radius: 244 | edge_index, edge_type = _extend_to_radius_graph( 245 | pos=pos, 246 | edge_index=edge_index, 247 | edge_type=edge_type, 248 | cutoff=cutoff, 249 | batch=batch, 250 | is_sidechain=is_sidechain 251 | 252 | ) 253 | 254 | return edge_index, edge_type 255 | 256 | 257 | def coarse_grain(pos, node_attr, subgraph_index, batch): 258 | cluster_pos = scatter_mean(pos, index=subgraph_index, dim=0) # (num_clusters, 3) 259 | cluster_attr = scatter_add(node_attr, index=subgraph_index, dim=0) # (num_clusters, H) 260 | cluster_batch, _ = scatter_max(batch, index=subgraph_index, dim=0) # (num_clusters, ) 261 | 262 | return cluster_pos, cluster_attr, cluster_batch 263 | 264 | 265 | def batch_to_natoms(batch): 266 | return scatter_add(torch.ones_like(batch), index=batch, dim=0) 267 | 268 | 269 | def get_complete_graph(natoms): 270 | """ 271 | Args: 272 | natoms: Number of nodes per graph, (B, 1). 273 | Returns: 274 | edge_index: (2, N_1 + N_2 + ... + N_{B-1}), where N_i is the number of nodes of the i-th graph. 275 | num_edges: (B, ), number of edges per graph. 
276 | """ 277 | natoms_sqr = (natoms ** 2).long() 278 | num_atom_pairs = torch.sum(natoms_sqr) 279 | natoms_expand = torch.repeat_interleave(natoms, natoms_sqr) 280 | 281 | index_offset = torch.cumsum(natoms, dim=0) - natoms 282 | index_offset_expand = torch.repeat_interleave(index_offset, natoms_sqr) 283 | 284 | index_sqr_offset = torch.cumsum(natoms_sqr, dim=0) - natoms_sqr 285 | index_sqr_offset = torch.repeat_interleave(index_sqr_offset, natoms_sqr) 286 | 287 | atom_count_sqr = torch.arange(num_atom_pairs, device=num_atom_pairs.device) - index_sqr_offset 288 | 289 | index1 = (atom_count_sqr // natoms_expand).long() + index_offset_expand 290 | index2 = (atom_count_sqr % natoms_expand).long() + index_offset_expand 291 | edge_index = torch.cat([index1.view(1, -1), index2.view(1, -1)]) 292 | mask = torch.logical_not(index1 == index2) 293 | edge_index = edge_index[:, mask] 294 | 295 | num_edges = natoms_sqr - natoms # Number of edges per graph 296 | 297 | return edge_index, num_edges 298 | -------------------------------------------------------------------------------- /models/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .schnet import * 2 | from .gin import * 3 | from .edge import * 4 | from .coarse import * 5 | -------------------------------------------------------------------------------- /models/encoder/coarse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | from torch_scatter import scatter_add, scatter_mean, scatter_max 4 | 5 | from ..common import coarse_grain, batch_to_natoms, get_complete_graph 6 | from .schnet import SchNetEncoder, GaussianSmearing 7 | 8 | 9 | class CoarseGrainingEncoder(Module): 10 | 11 | def __init__(self, hidden_channels, num_filters, num_interactions, edge_channels, cutoff, smooth): 12 | super().__init__() 13 | self.encoder = SchNetEncoder( 14 | hidden_channels=hidden_channels, 15 | 
num_filters=num_filters, 16 | num_interactions=num_interactions, 17 | edge_channels=edge_channels, 18 | cutoff=cutoff, 19 | smooth=smooth, 20 | ) 21 | self.distexp = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels) 22 | 23 | 24 | def forward(self, pos, node_attr, subgraph_index, batch, return_coarse=False): 25 | """ 26 | Args: 27 | pos: (N, 3) 28 | node_attr: (N, H) 29 | subgraph_index: (N, ) 30 | batch: (N, ) 31 | """ 32 | cluster_pos, cluster_attr, cluster_batch = coarse_grain(pos, node_attr, subgraph_index, batch) 33 | 34 | edge_index, _ = get_complete_graph(batch_to_natoms(cluster_batch)) 35 | row, col = edge_index 36 | edge_length = torch.norm(cluster_pos[row] - cluster_pos[col], dim=1, p=2) 37 | edge_attr = self.distexp(edge_length) 38 | 39 | h = self.encoder( 40 | z = cluster_attr, 41 | edge_index = edge_index, 42 | edge_length = edge_length, 43 | edge_attr = edge_attr, 44 | embed_node = False, 45 | ) 46 | 47 | if return_coarse: 48 | return h, cluster_pos, cluster_attr, cluster_batch, edge_index 49 | else: 50 | return h 51 | -------------------------------------------------------------------------------- /models/encoder/edge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Module, Sequential, ModuleList, Linear, Embedding 4 | from torch_geometric.nn import MessagePassing, radius_graph 5 | from torch_sparse import coalesce 6 | from torch_geometric.data import Data 7 | from torch_geometric.utils import to_dense_adj, dense_to_sparse 8 | from math import pi as PI 9 | 10 | from utils.chem import BOND_TYPES 11 | from ..common import MeanReadout, SumReadout, MultiLayerPerceptron 12 | 13 | 14 | class GaussianSmearingEdgeEncoder(Module): 15 | 16 | def __init__(self, num_gaussians=64, cutoff=10.0): 17 | super().__init__() 18 | #self.NUM_BOND_TYPES = 22 19 | self.num_gaussians = num_gaussians 20 | self.cutoff = cutoff 21 | self.rbf =
GaussianSmearing(start=0.0, stop=cutoff * 2, num_gaussians=num_gaussians) # Larger `stop` to encode more cases 22 | self.bond_emb = Embedding(100, embedding_dim=num_gaussians) 23 | 24 | @property 25 | def out_channels(self): 26 | return self.num_gaussians * 2 27 | 28 | def forward(self, edge_length, edge_type): 29 | """ 30 | Input: 31 | edge_length: The length of edges, shape=(E, 1). 32 | edge_type: The type of edges, shape=(E,) 33 | Returns: 34 | edge_attr: The representation of edges. (E, 2 * num_gaussians) 35 | """ 36 | edge_attr = torch.cat([self.rbf(edge_length), self.bond_emb(edge_type)], dim=1) 37 | return edge_attr 38 | 39 | 40 | class MLPEdgeEncoder(Module): 41 | 42 | def __init__(self, hidden_dim=100, activation='relu'): 43 | super().__init__() 44 | self.hidden_dim = hidden_dim 45 | self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim) 46 | self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation) 47 | 48 | @property 49 | def out_channels(self): 50 | return self.hidden_dim 51 | 52 | def forward(self, edge_length, edge_type): 53 | """ 54 | Input: 55 | edge_length: The length of edges, shape=(E, 1). 56 | edge_type: The type of edges, shape=(E,) 57 | Returns: 58 | edge_attr: The representation of edges.
(E, hidden_dim) 59 | """ 60 | d_emb = self.mlp(edge_length) # (num_edge, hidden_dim) 61 | edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim) 62 | return d_emb * edge_attr # (num_edge, hidden) 63 | 64 | 65 | def get_edge_encoder(cfg): 66 | if cfg.edge_encoder == 'mlp': 67 | return MLPEdgeEncoder(cfg.hidden_dim, cfg.mlp_act) 68 | elif cfg.edge_encoder == 'gaussian': 69 | return GaussianSmearingEdgeEncoder(cfg.hidden_dim // 2, cutoff=cfg.cutoff) 70 | else: 71 | raise NotImplementedError('Unknown edge encoder: %s' % cfg.edge_encoder) 72 | 73 | -------------------------------------------------------------------------------- /models/encoder/gin.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from typing import Callable, Union 3 | from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size 4 | 5 | import torch 6 | from torch import Tensor 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch_sparse import SparseTensor, matmul 10 | from torch_geometric.nn.conv import MessagePassing 11 | 12 | from ..common import MeanReadout, SumReadout, MultiLayerPerceptron 13 | 14 | 15 | class GINEConv(MessagePassing): 16 | 17 | def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False, 18 | activation="softplus", **kwargs): 19 | super(GINEConv, self).__init__(aggr='add', **kwargs) 20 | self.nn = nn 21 | self.initial_eps = eps 22 | 23 | if isinstance(activation, str): 24 | self.activation = getattr(F, activation) 25 | else: 26 | self.activation = None 27 | 28 | if train_eps: 29 | self.eps = torch.nn.Parameter(torch.Tensor([eps])) 30 | else: 31 | self.register_buffer('eps', torch.Tensor([eps])) 32 | 33 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, 34 | edge_attr: OptTensor = None, size: Size = None) -> Tensor: 35 | """""" 36 | if isinstance(x, Tensor): 37 | x: OptPairTensor = (x, x) 38 | 39 | # Node and edge feature dimensionalities
need to match. 40 | if isinstance(edge_index, Tensor): 41 | assert edge_attr is not None 42 | assert x[0].size(-1) == edge_attr.size(-1) 43 | elif isinstance(edge_index, SparseTensor): 44 | assert x[0].size(-1) == edge_index.size(-1) 45 | 46 | # propagate_type: (x: OptPairTensor, edge_attr: OptTensor) 47 | out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size) 48 | 49 | x_r = x[1] 50 | if x_r is not None: 51 | out = out + (1 + self.eps) * x_r 52 | 53 | return self.nn(out) 54 | 55 | def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor: 56 | if self.activation: 57 | return self.activation(x_j + edge_attr) 58 | else: 59 | return x_j + edge_attr 60 | 61 | def __repr__(self): 62 | return '{}(nn={})'.format(self.__class__.__name__, self.nn) 63 | 64 | 65 | class GINEncoder(torch.nn.Module): 66 | 67 | def __init__(self, hidden_dim, num_convs=3, activation='relu', short_cut=True, concat_hidden=False): 68 | super().__init__() 69 | 70 | self.hidden_dim = hidden_dim 71 | self.num_convs = num_convs 72 | self.short_cut = short_cut 73 | self.concat_hidden = concat_hidden 74 | self.node_emb = nn.Embedding(100, hidden_dim) 75 | 76 | if isinstance(activation, str): 77 | self.activation = getattr(F, activation) 78 | else: 79 | self.activation = None 80 | 81 | self.convs = nn.ModuleList() 82 | for i in range(self.num_convs): 83 | self.convs.append(GINEConv(MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], \ 84 | activation=activation), activation=activation)) 85 | 86 | 87 | 88 | def forward(self, z, edge_index, edge_attr): 89 | """ 90 | Input: 91 | data: (torch_geometric.data.Data): batched graph 92 | node_attr: node feature tensor with shape (num_node, hidden) 93 | edge_attr: edge feature tensor with shape (num_edge, hidden) 94 | Output: 95 | node_attr 96 | graph feature 97 | """ 98 | 99 | node_attr = self.node_emb(z) # (num_node, hidden) 100 | 101 | hiddens = [] 102 | conv_input = node_attr # (num_node, hidden) 103 | 104 | for conv_idx, conv in 
enumerate(self.convs): 105 | hidden = conv(conv_input, edge_index, edge_attr) 106 | if conv_idx < len(self.convs) - 1 and self.activation is not None: 107 | hidden = self.activation(hidden) 108 | assert hidden.shape == conv_input.shape 109 | if self.short_cut and hidden.shape == conv_input.shape: 110 | hidden = hidden + conv_input 111 | 112 | hiddens.append(hidden) 113 | conv_input = hidden 114 | 115 | if self.concat_hidden: 116 | node_feature = torch.cat(hiddens, dim=-1) 117 | else: 118 | node_feature = hiddens[-1] 119 | 120 | return node_feature 121 | -------------------------------------------------------------------------------- /models/encoder/schnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Module, Sequential, ModuleList, Linear, Embedding 4 | from torch_geometric.nn import MessagePassing, radius_graph 5 | from torch_sparse import coalesce 6 | from torch_geometric.data import Data 7 | from torch_geometric.utils import to_dense_adj, dense_to_sparse 8 | from math import pi as PI 9 | 10 | from utils.chem import BOND_TYPES 11 | from ..common import MeanReadout, SumReadout, MultiLayerPerceptron 12 | 13 | 14 | class GaussianSmearing(torch.nn.Module): 15 | def __init__(self, start=0.0, stop=5.0, num_gaussians=50): 16 | super(GaussianSmearing, self).__init__() 17 | offset = torch.linspace(start, stop, num_gaussians) 18 | self.coeff = -0.5 / (offset[1] - offset[0]).item()**2 19 | self.register_buffer('offset', offset) 20 | 21 | def forward(self, dist): 22 | dist = dist.view(-1, 1) - self.offset.view(1, -1) 23 | return torch.exp(self.coeff * torch.pow(dist, 2)) 24 | 25 | 26 | class AsymmetricSineCosineSmearing(Module): 27 | 28 | def __init__(self, num_basis=50): 29 | super().__init__() 30 | num_basis_k = num_basis // 2 31 | num_basis_l = num_basis - num_basis_k 32 | self.register_buffer('freq_k', torch.arange(1, num_basis_k + 1).float()) 33 | 
self.register_buffer('freq_l', torch.arange(1, num_basis_l + 1).float()) 34 | 35 | @property 36 | def num_basis(self): 37 | return self.freq_k.size(0) + self.freq_l.size(0) 38 | 39 | def forward(self, angle): 40 | # If we don't incorporate `cos`, the embedding of 0-deg and 180-deg will be the 41 | # same, which is undesirable. 42 | s = torch.sin(angle.view(-1, 1) * self.freq_k.view(1, -1)) # (num_angles, num_basis_k) 43 | c = torch.cos(angle.view(-1, 1) * self.freq_l.view(1, -1)) # (num_angles, num_basis_l) 44 | return torch.cat([s, c], dim=-1) 45 | 46 | 47 | class SymmetricCosineSmearing(Module): 48 | 49 | def __init__(self, num_basis=50): 50 | super().__init__() 51 | self.register_buffer('freq_k', torch.arange(1, num_basis+1).float()) 52 | 53 | @property 54 | def num_basis(self): 55 | return self.freq_k.size(0) 56 | 57 | def forward(self, angle): 58 | return torch.cos(angle.view(-1, 1) * self.freq_k.view(1, -1)) # (num_angles, num_basis) 59 | 60 | 61 | class ShiftedSoftplus(torch.nn.Module): 62 | def __init__(self): 63 | super(ShiftedSoftplus, self).__init__() 64 | self.shift = torch.log(torch.tensor(2.0)).item() 65 | 66 | def forward(self, x): 67 | return F.softplus(x) - self.shift 68 | 69 | 70 | class CFConv(MessagePassing): 71 | def __init__(self, in_channels, out_channels, num_filters, nn, cutoff, smooth): 72 | super(CFConv, self).__init__(aggr='add') 73 | self.lin1 = Linear(in_channels, num_filters, bias=False) 74 | self.lin2 = Linear(num_filters, out_channels) 75 | self.nn = nn 76 | self.cutoff = cutoff 77 | self.smooth = smooth 78 | 79 | self.reset_parameters() 80 | 81 | def reset_parameters(self): 82 | torch.nn.init.xavier_uniform_(self.lin1.weight) 83 | torch.nn.init.xavier_uniform_(self.lin2.weight) 84 | self.lin2.bias.data.fill_(0) 85 | 86 | def forward(self, x, edge_index, edge_length, edge_attr): 87 | if self.smooth: 88 | C = 0.5 * (torch.cos(edge_length * PI / self.cutoff) + 1.0) 89 | C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # 
Modification: cutoff 90 | else: 91 | C = (edge_length <= self.cutoff).float() 92 | W = self.nn(edge_attr) * C.view(-1, 1) 93 | 94 | x = self.lin1(x) 95 | x = self.propagate(edge_index, x=x, W=W) 96 | x = self.lin2(x) 97 | return x 98 | 99 | def message(self, x_j, W): 100 | return x_j * W 101 | 102 | 103 | class InteractionBlock(torch.nn.Module): 104 | def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth): 105 | super(InteractionBlock, self).__init__() 106 | mlp = Sequential( 107 | Linear(num_gaussians, num_filters), 108 | ShiftedSoftplus(), 109 | Linear(num_filters, num_filters), 110 | ) 111 | self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth) 112 | self.act = ShiftedSoftplus() 113 | self.lin = Linear(hidden_channels, hidden_channels) 114 | 115 | def forward(self, x, edge_index, edge_length, edge_attr): 116 | x = self.conv(x, edge_index, edge_length, edge_attr) 117 | x = self.act(x) 118 | x = self.lin(x) 119 | return x 120 | 121 | 122 | class SchNetEncoder(Module): 123 | 124 | def __init__(self, hidden_channels=128, num_filters=128, 125 | num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False): 126 | super().__init__() 127 | 128 | self.hidden_channels = hidden_channels 129 | self.num_filters = num_filters 130 | self.num_interactions = num_interactions 131 | self.cutoff = cutoff 132 | 133 | self.embedding = Embedding(100, hidden_channels, max_norm=10.0) 134 | 135 | self.interactions = ModuleList() 136 | for _ in range(num_interactions): 137 | block = InteractionBlock(hidden_channels, edge_channels, 138 | num_filters, cutoff, smooth) 139 | self.interactions.append(block) 140 | 141 | def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True): 142 | if embed_node: 143 | assert z.dim() == 1 and z.dtype == torch.long 144 | h = self.embedding(z) 145 | else: 146 | h = z 147 | for interaction in self.interactions: 148 | h = h + interaction(h, edge_index, edge_length, edge_attr) 149 | 150 
| return h 151 | 152 | -------------------------------------------------------------------------------- /models/epsnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .dualenc import DualEncoderEpsNetwork 2 | 3 | def get_model(config): 4 | if config.network == 'dualenc': 5 | return DualEncoderEpsNetwork(config) 6 | else: 7 | raise NotImplementedError('Unknown network: %s' % config.network) 8 | -------------------------------------------------------------------------------- /models/epsnet/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import MessagePassing 5 | 6 | 7 | from torch_geometric.nn import GATConv 8 | 9 | class GAT(torch.nn.Module): 10 | def __init__(self, in_channels, hidden_channels, out_channels, num_heads): 11 | super(GAT, self).__init__() 12 | 13 | self.conv1 = GATConv(in_channels, hidden_channels //2 // num_heads, heads=num_heads) 14 | # self.conv2 = GATConv(hidden_channels, hidden_channels //2 // num_heads, heads=num_heads) 15 | self.conv3 = GATConv(hidden_channels//2, out_channels, heads=1) 16 | 17 | def forward(self, x, edge_index): 18 | x = F.relu(self.conv1(x, edge_index)) 19 | ## self.dropout 20 | # x = F.relu(self.conv2(x, edge_index)) 21 | x = self.conv3(x, edge_index) 22 | return x 23 | 24 | class Pyg_SelfAttention(MessagePassing): 25 | def __init__(self, in_channels, out_channels, activation): 26 | super(Pyg_SelfAttention, self).__init__(aggr='add') 27 | 28 | self.activation = activation 29 | self.lin_query = torch.nn.Linear(in_channels, out_channels) 30 | self.lin_key = torch.nn.Linear(in_channels, out_channels) 31 | self.lin_value = torch.nn.Linear(in_channels, out_channels) 32 | 33 | def forward(self, node_presentation, edge_index): 34 | # Calculate query, key and value 35 | query = self.lin_query(node_presentation) 36 | key =
self.lin_key(node_presentation) 37 | value = self.lin_value(node_presentation) 38 | 39 | # Calculate attention weights 40 | alpha = self.propagate(edge_index, x=(query, key), size=(node_presentation.size(0), node_presentation.size(0))) 41 | alpha = self.activation(alpha) 42 | 43 | # Apply attention weights to values 44 | attended_values = self.propagate(edge_index, x=(alpha, value), size=(node_presentation.size(0), node_presentation.size(0))) 45 | 46 | return attended_values 47 | 48 | def message(self, x_i, x_j): 49 | # Calculate attention weights 50 | alpha = torch.matmul(x_i, x_j.transpose(0, 1)) 51 | 52 | return alpha 53 | 54 | def update(self, aggr_out): 55 | return aggr_out 56 | 57 | 58 | 59 | class SelfAttention(nn.Module): 60 | def __init__(self, in_dim): 61 | super(SelfAttention, self).__init__() 62 | 63 | self.activation = nn.Softmax() 64 | 65 | self.query = nn.Linear(in_dim, in_dim, bias=False) 66 | self.key = nn.Linear(in_dim, in_dim, bias=False) 67 | self.value = nn.Linear(in_dim, in_dim, bias=False) 68 | 69 | def forward(self, node_presentation): 70 | query = self.query(node_presentation) 71 | key = self.key(node_presentation) 72 | value = self.value(node_presentation) 73 | 74 | # Calculate attention weights 75 | attention_weights = torch.matmul(query, key.transpose(-2, -1)) 76 | attention_weights = self.activation(attention_weights) 77 | 78 | # Apply attention weights to values 79 | attended_values = torch.matmul(attention_weights, value) 80 | node = attended_values.sum(-1) 81 | return node 82 | -------------------------------------------------------------------------------- /models/epsnet/diffusion.py: -------------------------------------------------------------------------------- 1 | import os, time 2 | import math 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def get_timestep_embedding(timesteps, embedding_dim): 8 | """ 9 | This matches the implementation in Denoising Diffusion Probabilistic Models: 10 | From Fairseq. 
11 | Build sinusoidal embeddings. 12 | This matches the implementation in tensor2tensor, but differs slightly 13 | from the description in Section 3.5 of "Attention Is All You Need". 14 | """ 15 | assert len(timesteps.shape) == 1 16 | 17 | half_dim = embedding_dim // 2 18 | emb = math.log(10000) / (half_dim - 1) 19 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 20 | emb = emb.to(device=timesteps.device) 21 | emb = timesteps.float()[:, None] * emb[None, :] 22 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 23 | if embedding_dim % 2 == 1: # zero pad 24 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 25 | return emb 26 | 27 | 28 | def nonlinearity(x): 29 | # swish 30 | return x*torch.sigmoid(x) 31 | 32 | 33 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 34 | def sigmoid(x): 35 | return 1 / (np.exp(-x) + 1) 36 | 37 | if beta_schedule == "quad": 38 | betas = ( 39 | np.linspace( 40 | beta_start ** 0.5, 41 | beta_end ** 0.5, 42 | num_diffusion_timesteps, 43 | dtype=np.float64, 44 | ) 45 | ** 2 46 | ) 47 | elif beta_schedule == "linear": 48 | betas = np.linspace( 49 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 50 | ) 51 | elif beta_schedule == "const": 52 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 53 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 54 | betas = 1.0 / np.linspace( 55 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 56 | ) 57 | elif beta_schedule == "sigmoid": 58 | betas = np.linspace(-6, 6, num_diffusion_timesteps) 59 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start 60 | else: 61 | raise NotImplementedError(beta_schedule) 62 | assert betas.shape == (num_diffusion_timesteps,) 63 | return betas -------------------------------------------------------------------------------- /models/epsnet/dualencoder.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_scatter import scatter_add, scatter_mean 4 | from torch_scatter import scatter 5 | from torch_geometric.data import Data, Batch 6 | import numpy as np 7 | from numpy import pi as PI 8 | from tqdm.auto import tqdm 9 | 10 | from utils.chem import BOND_TYPES 11 | from ..common import MultiLayerPerceptron, assemble_atom_pair_feature, generate_symmetric_edge_noise, extend_graph_order_radius 12 | from ..encoder import SchNetEncoder, GINEncoder, get_edge_encoder 13 | from ..geometry import get_distance, get_angle, get_dihedral, eq_transform 14 | from models.epsnet.attention import GAT, SelfAttention 15 | from models.epsnet.diffusion import get_timestep_embedding, get_beta_schedule 16 | import pdb 17 | 18 | 19 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 20 | def sigmoid(x): 21 | return 1 / (np.exp(-x) + 1) 22 | 23 | if beta_schedule == "quad": 24 | betas = ( 25 | np.linspace( 26 | beta_start ** 0.5, 27 | beta_end ** 0.5, 28 | num_diffusion_timesteps, 29 | dtype=np.float64, 30 | ) 31 | ** 2 32 | ) 33 | elif beta_schedule == "linear": 34 | betas = np.linspace( 35 | beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 36 | ) 37 | elif beta_schedule == "const": 38 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 39 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 40 | betas = 1.0 / np.linspace( 41 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 42 | ) 43 | elif beta_schedule == "sigmoid": 44 | betas = np.linspace(-6, 6, num_diffusion_timesteps) 45 | betas = sigmoid(betas) * (beta_end - beta_start) + beta_start 46 | else: 47 | raise NotImplementedError(beta_schedule) 48 | assert betas.shape == (num_diffusion_timesteps,) 49 | return betas 50 | 51 | 52 | class DualEncoderEpsNetwork(nn.Module): 53 | 54 | def __init__(self, 
config): 55 | super().__init__() 56 | self.config = config 57 | self.num_timesteps=config.num_diffusion_timesteps 58 | """ 59 | edge_encoder: Takes both edge type and edge length as input and outputs a vector 60 | [Note]: node embedding is done in SchNetEncoder 61 | """ 62 | self.model_type = config.type # config.type # 'diffusion'; 'dsm' 63 | self.edge_encoder_global = get_edge_encoder(config) 64 | self.edge_encoder_local = get_edge_encoder(config) 65 | self.is_emb_time = True 66 | if self.is_emb_time: 67 | ''' 68 | timestep embedding 69 | ''' 70 | self.hidden_dim = config.hidden_dim 71 | self.temb = nn.Module() 72 | self.temb.dense = nn.ModuleList([ 73 | torch.nn.Linear(config.hidden_dim, 74 | config.hidden_dim*4), 75 | torch.nn.Linear(config.hidden_dim*4, 76 | config.hidden_dim*4), 77 | ]) 78 | self.temb_proj = torch.nn.Linear(config.hidden_dim*4, 79 | config.hidden_dim) 80 | self.nonlinearity = nn.ReLU() 81 | self.is_emb_time = False 82 | """ 83 | The graph neural network that extracts node-wise features. 
84 | """ 85 | self.encoder_global = SchNetEncoder( 86 | hidden_channels=config.hidden_dim, 87 | num_filters=config.hidden_dim, 88 | num_interactions=config.num_convs, 89 | edge_channels=self.edge_encoder_global.out_channels, 90 | cutoff=config.cutoff, 91 | smooth=config.smooth_conv, 92 | ) 93 | self.encoder_local = GINEncoder( 94 | hidden_dim=config.hidden_dim, 95 | num_convs=config.num_convs_local, 96 | ) 97 | self.grad_global_dist_mlp = MultiLayerPerceptron( 98 | 2 * config.hidden_dim, 99 | [config.hidden_dim, config.hidden_dim // 2, 1], 100 | activation=config.mlp_act 101 | ) 102 | 103 | self.grad_local_dist_mlp = MultiLayerPerceptron( 104 | 2 * config.hidden_dim, 105 | [config.hidden_dim, config.hidden_dim // 2, 1], 106 | activation=config.mlp_act 107 | ) 108 | 109 | ''' header for masked vector prediciton''' 110 | if self.model_type in ['selected_diffusion', 'subgraph_diffusion']: 111 | self.edge_encoder_mask = get_edge_encoder(config) 112 | self.encoder_mask = SchNetEncoder( 113 | hidden_channels=config.hidden_dim, 114 | num_filters=config.hidden_dim, 115 | num_interactions=config.num_convs, 116 | edge_channels=self.edge_encoder_global.out_channels, 117 | cutoff=config.cutoff, 118 | smooth=config.smooth_conv, 119 | ) 120 | self.mask_pred = self.config.get("mask_pred", "MLP").upper() 121 | if self.mask_pred.upper()=="GAT": 122 | self.mask_predictor = GAT(in_channels=config.hidden_dim, hidden_channels=config.hidden_dim//2, out_channels=1, num_heads=2) 123 | elif self.mask_pred=='MLP': 124 | self.mask_predictor = MultiLayerPerceptron( 125 | config.hidden_dim, 126 | [config.hidden_dim, config.hidden_dim // 2, 1], 127 | activation=config.mlp_act 128 | ) 129 | elif self.mask_pred=='2BMLP': 130 | self.mask_predictor = MultiLayerPerceptron( 131 | config.hidden_dim*2, 132 | [config.hidden_dim, config.hidden_dim // 2, 1], 133 | activation=config.mlp_act 134 | ) 135 | else: 136 | raise 137 | self.CEloss = nn.CrossEntropyLoss() 138 | self.BCEloss = 
nn.BCEWithLogitsLoss(reduction='none') 139 | else: 140 | self.register_parameter("mask_predictor", None) 141 | 142 | self.model_mask = nn.ModuleList([self.mask_predictor,self.temb,self.temb_proj]) 143 | 144 | self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp]) 145 | ''' 146 | Incorporate parameters together 147 | ''' 148 | 149 | self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp]) 150 | 151 | 152 | 153 | if self.model_type in ['selected_diffusion','subgraph_diffusion','diffusion']: 154 | # denoising diffusion 155 | ## betas 156 | betas = get_beta_schedule( 157 | beta_schedule=config.beta_schedule, 158 | beta_start=config.beta_start, 159 | beta_end=config.beta_end, 160 | num_diffusion_timesteps=config.num_diffusion_timesteps, 161 | ) 162 | betas = torch.from_numpy(betas).float() 163 | self.betas = nn.Parameter(betas, requires_grad=False) 164 | ## variances 165 | alphas = (1. - betas).cumprod(dim=0) 166 | self.alphas = nn.Parameter(alphas, requires_grad=False) 167 | self.num_timesteps = self.betas.size(0) 168 | elif self.model_type == 'dsm': 169 | # denoising score matching 170 | sigmas = torch.tensor( 171 | np.exp(np.linspace(np.log(config.sigma_begin), np.log(config.sigma_end), 172 | config.num_noise_level)), dtype=torch.float32) 173 | self.sigmas = nn.Parameter(sigmas, requires_grad=False) # (num_noise_level) 174 | self.num_timesteps = self.sigmas.size(0) # betas.shape[0] 175 | 176 | 177 | 178 | def forward(self, atom_type, pos, bond_index, bond_type, batch, time_step, 179 | edge_index=None, edge_type=None, edge_length=None, return_edges=False, 180 | extend_order=True, extend_radius=True, is_sidechain=None): 181 | """ 182 | Args: 183 | atom_type: Types of atoms, (N, ). 184 | bond_index: Indices of bonds (not extended, not radius-graph), (2, E). 185 | bond_type: Bond types, (E, ). 186 | batch: Node index to graph index, (N, ). 
187 | """ 188 | N = atom_type.size(0) 189 | if edge_index is None or edge_type is None or edge_length is None: 190 | edge_index, edge_type = extend_graph_order_radius( 191 | num_nodes=N, 192 | pos=pos, 193 | edge_index=bond_index, 194 | edge_type=bond_type, 195 | batch=batch, 196 | order=self.config.edge_order, 197 | cutoff=self.config.cutoff, 198 | extend_order=extend_order, 199 | extend_radius=extend_radius, 200 | is_sidechain=is_sidechain, 201 | ) 202 | edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1) 203 | local_edge_mask = is_local_edge(edge_type) # (E, ) 204 | 205 | # Emb time_step 206 | if self.model_type in ['selected_diffusion','diffusion',"subgraph_diffusion"]: 207 | # # timestep embedding 208 | 209 | if self.is_emb_time: 210 | temb = get_timestep_embedding(time_step, self.hidden_dim) 211 | temb = self.temb.dense[0](temb) 212 | # temb = self.nonlinearity(temb) 213 | # temb = self.temb.dense[1](temb) 214 | temb = self.temb_proj(self.nonlinearity(temb)) # (G, dim) 215 | 216 | sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1) 217 | 218 | # Encoding global 219 | edge_attr_global = self.edge_encoder_global( 220 | edge_length=edge_length, 221 | edge_type=edge_type 222 | ) 223 | # Embed edges 224 | # edge_attr += temb_edge 225 | 226 | # Global 227 | node_attr_global = self.encoder_global( 228 | z=atom_type, 229 | edge_index=edge_index, 230 | edge_length=edge_length, 231 | edge_attr=edge_attr_global, 232 | ) 233 | if self.is_emb_time: node_attr_global = node_attr_global + 0.1*temb[batch] 234 | ## Assemble pairwise features 235 | # (h_i,h_j,e_ij) 236 | h_pair_global = assemble_atom_pair_feature( 237 | node_attr=node_attr_global, 238 | edge_index=edge_index, 239 | edge_attr=edge_attr_global, 240 | ) # (E_global, 2H) 241 | ## Invariant features of edges (radius graph, global) 242 | edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1) 243 | 244 | # Encoding local 245 | 
edge_attr_local = self.edge_encoder_global( 246 | # edge_attr_local = self.edge_encoder_local( 247 | edge_length=edge_length, 248 | edge_type=edge_type 249 | ) # Embed edges 250 | # edge_attr += temb_edge 251 | 252 | # Local 253 | node_attr_local = self.encoder_local( 254 | z=atom_type, 255 | edge_index=edge_index[:, local_edge_mask], 256 | edge_attr=edge_attr_local[local_edge_mask], 257 | ) 258 | if self.is_emb_time: node_attr_local = node_attr_local + 0.1*temb[batch] 259 | ## Assemble pairwise features 260 | h_pair_local = assemble_atom_pair_feature( 261 | node_attr=node_attr_local, 262 | edge_index=edge_index[:, local_edge_mask], 263 | edge_attr=edge_attr_local[local_edge_mask], 264 | ) # (E_local, 2H) 265 | ## Invariant features of edges (bond graph, local) 266 | if isinstance(sigma_edge, torch.Tensor): 267 | edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge[local_edge_mask]) # (E_local, 1) 268 | else: 269 | edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1) 270 | 271 | 272 | 273 | self.node_attr_global = node_attr_global 274 | self.node_attr_local= node_attr_local 275 | self.batch = batch 276 | 277 | if return_edges: 278 | if self.model_type in ['selected_diffusion', "subgraph_diffusion"]: 279 | if self.mask_pred=="MLP": 280 | mask_emb=node_attr_global 281 | # mask_emb=torch.cat([node_attr_local,node_attr_global],dim=-1) 282 | node_mask_pred=self.mask_predictor(mask_emb) 283 | elif self.mask_pred=='2BMLP': 284 | mask_emb=torch.cat([node_attr_local,node_attr_global],dim=-1) 285 | node_mask_pred=self.mask_predictor(mask_emb) 286 | elif self.mask_pred.upper()=="GAT": 287 | mask_emb=node_attr_global 288 | node_mask_pred = self.mask_predictor(mask_emb,edge_index) 289 | return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask, node_mask_pred 290 | return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask 291 | else: 292 | return 
edge_inv_global, edge_inv_local
293 | 
294 |     def get_node_embeddings(self):
295 |         self.node_embs = torch.cat([self.node_attr_local, self.node_attr_global], dim=-1)
296 | 
297 |         return self.node_embs
298 |     def get_global_node_embeddings(self):
301 |         return self.node_attr_global
302 | 
310 |     def get_graph_embeddings(self):
311 |         self.node_embs = torch.cat([self.node_attr_local, self.node_attr_global], dim=-1)
312 |         self.graph_embs = scatter(self.node_embs, self.batch, dim=0, reduce='sum')  # reduce: "add", "sum", or "mean"
313 |         return self.graph_embs
314 | 
315 | import torch.nn.functional as F
316 | class ShiftedSoftplus(torch.nn.Module):
317 |     def __init__(self):
318 |         super(ShiftedSoftplus, self).__init__()
319 |         self.shift = torch.log(torch.tensor(2.0)).item()
320 | 
321 |     def forward(self, x):
322 |         return F.softplus(x) - self.shift
323 | 
324 | class GraphPooling(torch.nn.Module):
325 |     def __init__(self, hidden_channels, out_dim, readout="mean"):
326 |         super(GraphPooling, self).__init__()
327 |         self.lin1 = nn.Linear(hidden_channels, hidden_channels)
328 |         self.act = ShiftedSoftplus()
329 |         self.lin2 = nn.Linear(hidden_channels, hidden_channels)
330 |         self.readout = readout
331 |         self.output_layer = nn.Linear(hidden_channels, out_dim)
332 | 
333 |     def forward(self, x, batch):
334 |         h = self.lin1(x)
335 |         h = self.act(h)
336 |         h = self.lin2(h)
337 |         graph_embs = scatter(h, batch, dim=0, reduce=self.readout)  # readout: "add", "sum", or "mean"
338 |         graph_embs = self.output_layer(graph_embs)
339 |         return graph_embs
340 | 
341 | def is_bond(edge_type):
342 |     return torch.logical_and(edge_type < len(BOND_TYPES), 
edge_type > 0) 343 | 344 | 345 | def is_angle_edge(edge_type): 346 | return edge_type == len(BOND_TYPES) + 1 - 1 347 | 348 | 349 | def is_dihedral_edge(edge_type): 350 | return edge_type == len(BOND_TYPES) + 2 - 1 351 | 352 | 353 | def is_radius_edge(edge_type): 354 | return edge_type == 0 355 | 356 | 357 | def is_local_edge(edge_type): 358 | return edge_type > 0 359 | 360 | 361 | def is_train_edge(edge_index, is_sidechain): 362 | if is_sidechain is None: 363 | return torch.ones(edge_index.size(1), device=edge_index.device).bool() 364 | else: 365 | is_sidechain = is_sidechain.bool() 366 | return torch.logical_or(is_sidechain[edge_index[0]], is_sidechain[edge_index[1]]) 367 | 368 | 369 | def regularize_bond_length(edge_type, edge_length, rng=5.0): 370 | mask = is_bond(edge_type).float().reshape(-1, 1) 371 | d = -torch.clamp(edge_length - rng, min=0.0, max=float('inf')) * mask 372 | return d 373 | 374 | 375 | def center_pos(pos, batch): 376 | pos_center = pos - scatter_mean(pos, batch, dim=0)[batch] 377 | return pos_center 378 | 379 | 380 | def clip_norm(vec, limit, p=2): 381 | norm = torch.norm(vec, dim=-1, p=2, keepdim=True) 382 | denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm)) 383 | return vec * denom 384 | -------------------------------------------------------------------------------- /models/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_scatter import scatter_add 3 | 4 | 5 | def get_distance(pos, edge_index): 6 | return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1) 7 | 8 | 9 | def eq_transform(score_d, pos, edge_index, edge_length): 10 | N = pos.size(0) 11 | dd_dr = (1. 
/ edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3) # (r_i-r_j)/d_{ij} 12 | # https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/add.html 13 | score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) \ 14 | + scatter_add(- dd_dr * score_d, edge_index[1], dim=0, dim_size=N) # (N, 3) (-dd_dr)= p_j-p_i; d_ij = -d_ji 15 | return score_pos 16 | 17 | 18 | def convert_cluster_score_d(cluster_score_d, cluster_pos, cluster_edge_index, cluster_edge_length, subgraph_index): 19 | """ 20 | Args: 21 | cluster_score_d: (E_c, 1) 22 | subgraph_index: (N, ) 23 | """ 24 | cluster_score_pos = eq_transform(cluster_score_d, cluster_pos, cluster_edge_index, cluster_edge_length) # (C, 3) 25 | score_pos = cluster_score_pos[subgraph_index] 26 | return score_pos 27 | 28 | 29 | def get_angle(pos, angle_index): 30 | """ 31 | Args: 32 | pos: (N, 3) 33 | angle_index: (3, A), left-center-right. 34 | """ 35 | n1, ctr, n2 = angle_index # (A, ) 36 | v1 = pos[n1] - pos[ctr] # (A, 3) 37 | v2 = pos[n2] - pos[ctr] 38 | inner_prod = torch.sum(v1 * v2, dim=-1, keepdim=True) # (A, 1) 39 | length_prod = torch.norm(v1, dim=-1, keepdim=True) * torch.norm(v2, dim=-1, keepdim=True) # (A, 1) 40 | angle = torch.acos(inner_prod / length_prod) # (A, 1) 41 | return angle 42 | 43 | 44 | def get_dihedral(pos, dihedral_index): 45 | """ 46 | Args: 47 | pos: (N, 3) 48 | dihedral: (4, A) 49 | """ 50 | n1, ctr1, ctr2, n2 = dihedral_index # (A, ) 51 | v_ctr = pos[ctr2] - pos[ctr1] # (A, 3) 52 | v1 = pos[n1] - pos[ctr1] 53 | v2 = pos[n2] - pos[ctr2] 54 | n1 = torch.cross(v_ctr, v1, dim=-1) # Normal vectors of the two planes 55 | n2 = torch.cross(v_ctr, v2, dim=-1) 56 | inner_prod = torch.sum(n1 * n2, dim=1, keepdim=True) # (A, 1) 57 | length_prod = torch.norm(n1, dim=-1, keepdim=True) * torch.norm(n2, dim=-1, keepdim=True) 58 | dihedral = torch.acos(inner_prod / length_prod) 59 | return dihedral 60 | 61 | 62 | 
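The chain rule that `eq_transform` applies — converting a scalar score on each edge length into per-node 3-D position scores via ∂d_ij/∂r_i = (r_i − r_j)/d_ij — can be sanity-checked without torch. The snippet below is a hypothetical, dependency-free re-implementation on a two-atom toy graph; the tiny `scatter_add` helper is a stand-in for `torch_scatter.scatter_add`, and it is illustrative rather than part of the repository.

```python
import math

def scatter_add(vecs, index, dim_size):
    # Minimal stand-in for torch_scatter.scatter_add on lists of 3-vectors.
    out = [[0.0, 0.0, 0.0] for _ in range(dim_size)]
    for v, i in zip(vecs, index):
        for k in range(3):
            out[i][k] += v[k]
    return out

def eq_transform(score_d, pos, edge_index, edge_length):
    # d(d_ij)/d(r_i) = (r_i - r_j) / d_ij, so an edge-length score s_ij
    # contributes +s_ij * dd_dr to node i and -s_ij * dd_dr to node j.
    src, dst = edge_index
    dd_dr = [[(pos[i][k] - pos[j][k]) / d for k in range(3)]
             for i, j, d in zip(src, dst, edge_length)]
    plus = scatter_add([[s * c for c in v] for s, v in zip(score_d, dd_dr)], src, len(pos))
    minus = scatter_add([[-s * c for c in v] for s, v in zip(score_d, dd_dr)], dst, len(pos))
    return [[a + b for a, b in zip(p, m)] for p, m in zip(plus, minus)]

# Two atoms 2 A apart on the x-axis, one directed edge 0 -> 1.
pos = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]
edge_index = ([0], [1])
d01 = math.dist(pos[0], pos[1])  # 2.0
scores = eq_transform([1.0], pos, edge_index, [d01])
print(scores)  # equal and opposite position scores along the bond axis
```

Note how the two nodes receive equal and opposite contributions, which is why the resulting position score field is translation-invariant and rotation-equivariant.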
-------------------------------------------------------------------------------- /test.py: --------------------------------------------------------------------------------
1 | import os
2 | # os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
3 | import argparse
4 | import pickle
5 | import yaml
6 | import torch
7 | from glob import glob
8 | from tqdm.auto import tqdm
9 | from easydict import EasyDict
10 | 
11 | from models.epsnet import *
12 | from utils.datasets import *
13 | from utils.transforms import *
14 | from utils.misc import *
15 | 
16 | 
17 | def num_confs(num: str):
18 |     if num.endswith('x'):
19 |         return lambda x: x * int(num[:-1])
20 |     elif int(num) > 0:
21 |         return lambda x: int(num)
22 |     else:
23 |         raise ValueError()
24 | 
25 | 
26 | if __name__ == '__main__':
27 |     parser = argparse.ArgumentParser()
28 |     parser.add_argument('--ckpt', type=str, help='path for loading the checkpoint')
29 |     parser.add_argument('--config', type=str, default=None)
30 |     parser.add_argument('--save_traj', action='store_true', default=False,
31 |                         help='whether to store the whole sampling trajectory')
32 |     parser.add_argument('--not_use_mask', action='store_true', default=False,
33 |                         help='whether to disable the mask in sampling')
34 |     parser.add_argument('--resume', type=str, default=None)
35 |     parser.add_argument('--tag', type=str, default=None)
36 |     parser.add_argument('--num_confs', type=num_confs, default=num_confs('2x'))
37 |     parser.add_argument('--test_set', type=str, default=None)
38 |     parser.add_argument('--start_idx', type=int, default=0)
39 |     parser.add_argument('--end_idx', type=int, default=200)
40 |     parser.add_argument('--out_dir', type=str, default=None)
41 |     parser.add_argument('--device', type=str, default='cuda:0')
42 |     parser.add_argument('--clip', type=float, default=1000.0)
43 |     parser.add_argument('--clip_local', type=float, default=20)
44 |     parser.add_argument('--n_steps', type=int, default=5000,
45 |                         help='sampling num steps; for the DSM framework, this means num steps for each noise 
scale') 46 | parser.add_argument('--global_start_sigma', type=float, default=0.5, 47 | help='enable global gradients only when noise is low') 48 | parser.add_argument('--w_global', type=float, default=0.3, 49 | help='weight for global gradients') 50 | # Parameters for DDPM 51 | parser.add_argument('--sampling_type', type=str, default='ddpm_noisy', 52 | help='generalized, ddpm_noisy, ld: sampling method for DDIM, DDPM or Langevin Dynamics') 53 | parser.add_argument('--eta', type=float, default=1.0, 54 | help='weight for DDIM and DDPM: 0->DDIM, 1->DDPM') 55 | args = parser.parse_args() 56 | 57 | # Load checkpoint 58 | ckpt = torch.load(args.ckpt, map_location=torch.device('cpu')) 59 | if args.config is None: 60 | config_path = glob(os.path.join(os.path.dirname(os.path.dirname(args.ckpt)), '*.yml'))[0] 61 | else: 62 | config_path =args.config 63 | with open(config_path, 'r') as f: 64 | config = EasyDict(yaml.safe_load(f)) 65 | seed_all(config.train.seed) 66 | log_dir = os.path.dirname(os.path.dirname(args.ckpt)) 67 | 68 | # Logging 69 | if args.tag is None: 70 | args.tag=os.path.basename(os.path.dirname(args.test_set)) 71 | output_dir = get_new_log_dir(log_dir, 'sample', tag=args.tag) 72 | logger = get_logger('test', output_dir) 73 | # Datasets and loaders 74 | logger.info(args) 75 | logger.info('Loading datasets...') 76 | transforms = Compose([ 77 | CountNodesPerGraph(), 78 | AddHigherOrderEdges(order=config.model.edge_order), # Offline edge augmentation 79 | ]) 80 | if args.test_set is None: 81 | test_set = PackedConformationDataset(config.dataset.test, transform=transforms) 82 | logger.info(f'Loading {config.dataset.test}') 83 | else: 84 | test_set = PackedConformationDataset(args.test_set, transform=transforms) 85 | logger.info(f'Loading {args.test_set}') 86 | # Model 87 | logger.info('Loading model...') 88 | model = get_model(ckpt['config'].model).to(args.device) 89 | model.load_state_dict(ckpt['model']) 90 | 91 | test_set_selected = [] 92 | for i, data in 
enumerate(test_set): 93 | if not (args.start_idx <= i < args.end_idx): continue 94 | if args.tag=="qm92drugs": 95 | atom_set = set(atom.GetSymbol() for atom in data.rdmol.GetAtoms()) 96 | qm9_set={'H','C', 'N', 'O', 'F'} 97 | if atom_set <= qm9_set: 98 | test_set_selected.append(data) 99 | else: 100 | test_set_selected.append(data) 101 | 102 | 103 | done_smiles = set() 104 | results = [] 105 | if args.resume is not None: 106 | with open(args.resume, 'rb') as f: 107 | results = pickle.load(f) 108 | for data in results: 109 | done_smiles.add(data.smiles) 110 | 111 | for i, data in enumerate(tqdm(test_set_selected)): 112 | if data.smiles in done_smiles: 113 | logger.info('Molecule#%d is already done.' % i) 114 | continue 115 | 116 | num_refs = data.pos_ref.size(0) // data.num_nodes 117 | num_samples = args.num_confs(num_refs) 118 | 119 | data_input = data.clone() 120 | data_input['pos_ref'] = None 121 | batch = repeat_data(data_input, num_samples).to(args.device) 122 | 123 | clip_local = None 124 | for try_n in range(3): # Maximum number of retry 125 | try: 126 | pos_init = torch.randn(batch.num_nodes, 3).to(args.device) 127 | pos_gen, pos_gen_traj = model.langevin_dynamics_sample_diffusion_subgraph( 128 | atom_type=batch.atom_type, 129 | pos_init=pos_init, 130 | bond_index=batch.edge_index, 131 | bond_type=batch.edge_type, 132 | batch=batch.batch, 133 | num_graphs=batch.num_graphs, 134 | extend_order=False, # Done in transforms. 
135 |                         n_steps=args.n_steps,
136 |                         step_lr=1e-6,
137 |                         w_global=args.w_global,
138 |                         global_start_sigma=args.global_start_sigma,
139 |                         clip=args.clip,
140 |                         clip_local=clip_local,
141 |                         sampling_type=args.sampling_type,
142 |                         eta=args.eta,
143 |                         use_mask=not args.not_use_mask
144 |                     )
145 |                     pos_gen = pos_gen.cpu()
146 |                     if args.save_traj:
147 |                         data.pos_gen = torch.stack(pos_gen_traj)
148 |                     else:
149 |                         data.pos_gen = pos_gen
150 |                     results.append(data)
151 |                     done_smiles.add(data.smiles)
152 | 
153 |                     save_path = os.path.join(output_dir, 'samples_%d.pkl' % i)
154 |                     logger.info('Saving samples to: %s' % save_path)
155 |                     with open(save_path, 'wb') as f:
156 |                         pickle.dump(results, f)
157 | 
158 |                     break  # No errors occurred; break the retry loop
159 |                 except FloatingPointError:
160 |                     # clip_local = 100
161 |                     if try_n == 1:
162 |                         clip_local = args.clip_local
163 |                         logger.warning(f'Retrying with local clipping. clip_local={clip_local}')
164 |                     if try_n == 2:
165 |                         clip_local = args.clip_local // 2
166 |                         logger.warning(f'Retrying with local clipping. clip_local={clip_local}')
167 |                     seed_all(config.train.seed + try_n)
168 | 
169 | 
170 |     save_path = os.path.join(output_dir, 'samples_all.pkl')
171 |     logger.info('Saving samples to: %s' % save_path)
172 |     logger.info(output_dir)
173 |     def get_mol_key(data):
174 |         for i, d in enumerate(test_set_selected):
175 |             if d.smiles == data.smiles:
176 |                 return i
177 |         return -1
178 |     results.sort(key=get_mol_key)
179 | 
180 |     with open(save_path, 'wb') as f:
181 |         pickle.dump(results, f)
182 | 
183 | 
-------------------------------------------------------------------------------- /test.sh: --------------------------------------------------------------------------------
1 | python test.py --ckpt checkpoints/qm9_500steps/2000000.pt --config checkpoints/qm9_500steps/qm9_500steps.yml --test_set './data/GEOM/QM9/test_data_1k.pkl' --start_idx 800 --end_idx 1000 --sampling_type same_mask_noisy --n_steps 500 --device cuda:1 --w_global 0.1 --clip 1000 --clip_local 20 --global_start_sigma 5 --tag SubGDiff500
2 | 
3 | python test.py --ckpt checkpoints/qm9_200steps/2000000.pt --config checkpoints/qm9_200steps/qm9_200steps.yml --test_set './data/GEOM/QM9/test_data_1k.pkl' --start_idx 800 --end_idx 1000 --sampling_type same_mask_noisy --n_steps 200 --device cuda:0 --w_global 0.1 --clip 1000 --clip_local 20 --global_start_sigma 5 --tag SubGDiff200
4 | 
5 | 
6 | 
7 | 
-------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import yaml
5 | from easydict import EasyDict
6 | from tqdm.auto import tqdm
7 | from glob import glob
8 | import torch
9 | import torch.utils.tensorboard
10 | from torch.nn.utils import clip_grad_norm_
11 | from torch_geometric.data import DataLoader
12 | 
13 | from models.epsnet import get_model
14 | from utils.datasets import ConformationDataset
15 | from utils.transforms import *
16 | from utils.misc 
import * 17 | from utils.common import get_optimizer, get_scheduler 18 | 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--config', type=str, default='./configs/qm9_500steps.yml') 23 | parser.add_argument('--device', type=str, default='cuda:4') 24 | parser.add_argument('--resume_iter', type=int, default=None) 25 | parser.add_argument('--logdir', type=str, default='./logs') 26 | # parser.add_argument('--logdir', type=str, default='./logs') 27 | args = parser.parse_args() 28 | 29 | 30 | resume = os.path.isdir(args.config) 31 | if resume: 32 | config_path = glob(os.path.join(args.config, '*.yml'))[0] 33 | resume_from = args.config 34 | else: 35 | config_path = args.config 36 | 37 | with open(config_path, 'r') as f: 38 | config = EasyDict(yaml.safe_load(f)) 39 | config_name = os.path.basename(config_path)[:os.path.basename(config_path).rfind('.')] 40 | seed_all(config.train.seed) 41 | 42 | 43 | # Logging 44 | if resume: 45 | log_dir = get_new_log_dir(args.logdir, prefix=config_name, tag='resume') 46 | os.symlink(os.path.realpath(resume_from), os.path.join(log_dir, os.path.basename(resume_from.rstrip("/")))) 47 | else: 48 | log_dir = get_new_log_dir(args.logdir, prefix=config_name) 49 | shutil.copytree('./models', os.path.join(log_dir, 'models')) 50 | ckpt_dir = os.path.join(log_dir, 'checkpoints') 51 | os.makedirs(ckpt_dir, exist_ok=True) 52 | logger = get_logger('train', log_dir) 53 | writer = torch.utils.tensorboard.SummaryWriter(log_dir) 54 | logger.info(args) 55 | logger.info(config) 56 | shutil.copyfile(config_path, os.path.join(log_dir, os.path.basename(config_path))) 57 | 58 | # Datasets and loaders 59 | logger.info('Loading datasets...') 60 | noise_transforms=None 61 | if config.model.type=='subgraph_diffusion': 62 | from utils.transforms import SubgraphNoiseTransform 63 | noise_transforms= SubgraphNoiseTransform(config.model) 64 | 65 | transforms = CountNodesPerGraph() 66 | train_set = 
ConformationDataset(config.dataset.train, transform=transforms,noise_transform=noise_transforms,config=config.model) 67 | val_set = ConformationDataset(config.dataset.val, transform=transforms, noise_transform=noise_transforms,config=config.model) 68 | train_iterator = inf_iterator(DataLoader(train_set, config.train.batch_size, shuffle=True)) 69 | val_loader = DataLoader(val_set, config.train.batch_size, shuffle=False) 70 | 71 | # Model 72 | logger.info('Building model...') 73 | model = get_model(config.model).to(args.device) 74 | # Optimizer 75 | optimizer_global = get_optimizer(config.train.optimizer, model.model_global) 76 | optimizer_local = get_optimizer(config.train.optimizer, model.model_local) 77 | scheduler_global = get_scheduler(config.train.scheduler, optimizer_global) 78 | scheduler_local = get_scheduler(config.train.scheduler, optimizer_local) 79 | start_iter = 1 80 | 81 | # Resume from checkpoint 82 | if resume: 83 | ckpt_path, start_iter = get_checkpoint_path(os.path.join(resume_from, 'checkpoints'), it=args.resume_iter) 84 | logger.info('Resuming from: %s' % ckpt_path) 85 | logger.info('Iteration: %d' % start_iter) 86 | ckpt = torch.load(ckpt_path) 87 | model.load_state_dict(ckpt['model']) 88 | optimizer_global.load_state_dict(ckpt['optimizer_global']) 89 | optimizer_local.load_state_dict(ckpt['optimizer_local']) 90 | scheduler_global.load_state_dict(ckpt['scheduler_global']) 91 | scheduler_local.load_state_dict(ckpt['scheduler_local']) 92 | 93 | def train(it): 94 | model.train() 95 | optimizer_global.zero_grad() 96 | optimizer_local.zero_grad() 97 | batch = next(train_iterator).to(args.device) 98 | loss, loss_global, loss_local = model.get_loss( 99 | data=batch, 100 | atom_type=batch.atom_type, 101 | pos=batch.pos, 102 | bond_index=batch.edge_index, 103 | bond_type=batch.edge_type, 104 | batch=batch.batch, 105 | num_nodes_per_graph=batch.num_nodes_per_graph, 106 | num_graphs=batch.num_graphs, 107 | anneal_power=config.train.anneal_power, 108 | 
return_unreduced_loss=True 109 | ) 110 | loss = loss.mean() 111 | loss.backward() 112 | orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm) 113 | optimizer_global.step() 114 | optimizer_local.step() 115 | 116 | logger.info('[Train] Iter %05d | Loss %.2f | Loss(Global) %.2f | Loss(Local) %.2f | Grad %.2f | LR(Global) %.6f | LR(Local) %.6f |%s' % ( 117 | it, loss.item(), loss_global.mean().item(), loss_local.mean().item(), orig_grad_norm, optimizer_global.param_groups[0]['lr'], optimizer_local.param_groups[0]['lr'], log_dir 118 | )) 119 | writer.add_scalar('train/loss', loss, it) 120 | writer.add_scalar('train/loss_global', loss_global.mean(), it) 121 | writer.add_scalar('train/loss_local', loss_local.mean(), it) 122 | writer.add_scalar('train/lr_global', optimizer_global.param_groups[0]['lr'], it) 123 | writer.add_scalar('train/lr_local', optimizer_local.param_groups[0]['lr'], it) 124 | writer.add_scalar('train/grad_norm', orig_grad_norm, it) 125 | writer.flush() 126 | 127 | def validate(it): 128 | sum_loss, sum_n = 0, 0 129 | sum_loss_global, sum_n_global = 0, 0 130 | sum_loss_local, sum_n_local = 0, 0 131 | with torch.no_grad(): 132 | model.eval() 133 | for i, batch in enumerate(tqdm(val_loader, desc='Validation')): 134 | batch = batch.to(args.device) 135 | loss, loss_global, loss_local = model.get_loss( 136 | data=batch, 137 | atom_type=batch.atom_type, 138 | pos=batch.pos, 139 | bond_index=batch.edge_index, 140 | bond_type=batch.edge_type, 141 | batch=batch.batch, 142 | num_nodes_per_graph=batch.num_nodes_per_graph, 143 | num_graphs=batch.num_graphs, 144 | anneal_power=config.train.anneal_power, 145 | return_unreduced_loss=True 146 | ) 147 | sum_loss += loss.sum().item() 148 | sum_n += loss.size(0) 149 | sum_loss_global += loss_global.sum().item() 150 | sum_n_global += loss_global.size(0) 151 | sum_loss_local += loss_local.sum().item() 152 | sum_n_local += loss_local.size(0) 153 | avg_loss = sum_loss / sum_n 154 | avg_loss_global = 
sum_loss_global / sum_n_global
155 |         avg_loss_local = sum_loss_local / sum_n_local
156 | 
157 |         if config.train.scheduler.type == 'plateau':
158 |             scheduler_global.step(avg_loss_global)
159 |             scheduler_local.step(avg_loss_local)
160 |         else:
161 |             scheduler_global.step()
162 |             scheduler_local.step()
163 | 
164 |         logger.info('[Validate] Iter %05d | Loss %.6f | Loss(Global) %.6f | Loss(Local) %.6f' % (
165 |             it, avg_loss, avg_loss_global, avg_loss_local,
166 |         ))
167 |         writer.add_scalar('val/loss', avg_loss, it)
168 |         writer.add_scalar('val/loss_global', avg_loss_global, it)
169 |         writer.add_scalar('val/loss_local', avg_loss_local, it)
170 |         writer.flush()
171 |         return avg_loss
172 | 
173 |     try:
174 |         for it in range(start_iter, config.train.max_iters + 1):
175 |             train(it)
176 |             # TODO if avg_val_loss < : save
177 |             if it % config.train.val_freq == 0 or it == config.train.max_iters:
178 |                 avg_val_loss = validate(it)
179 |                 ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it)
180 |                 torch.save({
181 |                     'config': config,
182 |                     'model': model.state_dict(),
183 |                     'optimizer_global': optimizer_global.state_dict(),
184 |                     'scheduler_global': scheduler_global.state_dict(),
185 |                     'optimizer_local': optimizer_local.state_dict(),
186 |                     'scheduler_local': scheduler_local.state_dict(),
187 |                     'iteration': it,
188 |                     'avg_val_loss': avg_val_loss,
189 |                 }, ckpt_path)
190 |     except KeyboardInterrupt:
191 |         logger.info('Terminating...')
192 | 
193 | 
-------------------------------------------------------------------------------- /train.sh: --------------------------------------------------------------------------------
1 | python train_dist.py --config ./configs/qm9_200steps.yml --device cuda:0 --logdir ./logs --tag qm9_200steps --n_jobs 12 --print_freq 200
2 | python train_dist.py --config ./configs/qm9_500steps.yml --device cuda:1 --logdir ./logs --tag qm9_500steps --n_jobs 12 --print_freq 200
3 | 
-------------------------------------------------------------------------------- /train_dist.py: 
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import yaml
5 | from easydict import EasyDict
6 | from tqdm.auto import tqdm
7 | from glob import glob
8 | import torch
9 | import torch.utils.tensorboard
10 | from torch.nn.utils import clip_grad_norm_
11 | from torch_geometric.data import DataLoader
12 | 
13 | 
14 | 
15 | 
16 | from models.epsnet import get_model
17 | from utils.datasets import ConformationDataset
18 | from utils.transforms import *
19 | from utils.misc import *
20 | from utils.common import get_optimizer, get_scheduler
21 | 
22 | 
23 | import torch.distributed as dist
24 | def setup_dist(args, port=None, backend="nccl", verbose=False):
25 |     # TODO: initialize the process group here and return (rank, local_rank, world_size, device)
26 |     raise NotImplementedError('distributed setup is not implemented yet')
27 | 
28 | def reduce_mean(tensor, nprocs):
29 |     rt = tensor.clone()
30 |     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
31 |     rt = rt / nprocs
32 |     return rt
33 | 
34 | if __name__ == '__main__':
35 |     parser = argparse.ArgumentParser()
36 |     parser.add_argument('--config', type=str, default='./configs/qm9_500steps.yml')
37 |     parser.add_argument('--device', type=str, default='cuda:4')
38 | 
39 |     parser.add_argument('--resume_iter', type=int, default=None)
40 |     parser.add_argument('--logdir', type=str, default='./logs')
41 |     parser.add_argument('--distribution', action='store_true', default=False,
42 |                         help='enable DDP training')
43 |     parser.add_argument('--tag', type=str, default='', help='just for marking the experiment information')
44 |     parser.add_argument('--n_jobs', type=int, default=2, help='number of DataLoader worker processes')
45 |     parser.add_argument('--print_freq', type=int, default=50, help='')
46 |     args = parser.parse_args()
47 | 
48 |     args.distribution = False  # DDP is disabled until setup_dist is implemented
49 | 
50 |     resume = os.path.isdir(args.config)
51 |     if resume:
52 |         config_path = glob(os.path.join(args.config, '*.yml'))[0]
53 |         resume_from = args.config
54 |     else:
55 |         config_path = args.config
56 | 
57 |     with 
open(config_path, 'r') as f: 58 | config = EasyDict(yaml.safe_load(f)) 59 | config_name = os.path.basename(config_path)[:os.path.basename(config_path).rfind('.')] 60 | seed_all(config.train.seed) 61 | 62 | args.local_rank = int(args.device.split(":")[-1]) 63 | if 0 and args.distribution: 64 | rank, local_rank, world_size, device = setup_dist(args, verbose=True) 65 | args.device=device 66 | args.local_rank=local_rank 67 | setattr(config, 'local_rank', local_rank) 68 | setattr(config, 'world_size', world_size) 69 | setattr(config, 'tag', args.tag) 70 | 71 | 72 | master_worker = (rank == 0) if args.distribution else True 73 | args.nprocs = torch.cuda.device_count() 74 | 75 | 76 | if master_worker: 77 | # Logging 78 | if resume: 79 | log_dir = get_new_log_dir(args.logdir, prefix=config_name+args.tag, tag='resume') 80 | os.symlink(os.path.realpath(resume_from), os.path.join(log_dir, os.path.basename(resume_from.rstrip("/")))) 81 | else: 82 | log_dir = get_new_log_dir(args.logdir, prefix=config_name+args.tag) 83 | shutil.copytree('./models', os.path.join(log_dir, 'models')) 84 | shutil.copytree('./utils', os.path.join(log_dir, 'utils')) 85 | ckpt_dir = os.path.join(log_dir, 'checkpoints') 86 | os.makedirs(ckpt_dir, exist_ok=True) 87 | logger = get_logger('train', log_dir) 88 | writer = torch.utils.tensorboard.SummaryWriter(log_dir) 89 | logger.info(args) 90 | logger.info(config) 91 | shutil.copyfile(config_path, os.path.join(log_dir, os.path.basename(config_path))) 92 | 93 | # Datasets and loaders 94 | if master_worker: logger.info('Loading datasets...') 95 | noise_transforms=None 96 | if config.model.type=='subgraph_diffusion': 97 | from utils.transforms import SubgraphNoiseTransform 98 | noise_transforms= SubgraphNoiseTransform(config.model, tag=args.tag) 99 | # noise_transform_ddpm= SubgraphNoiseTransform(config.model, tag=args.tag, ddpm=False) 100 | 101 | transforms = CountNodesPerGraph() 102 | train_set = ConformationDataset(config.dataset.train, 
transform=transforms, noise_transform=noise_transforms,config=config.model) 103 | val_set = ConformationDataset(config.dataset.val, transform=transforms, noise_transform=noise_transforms,config=config.model) 104 | train_iterator = inf_iterator(DataLoader(train_set, config.train.batch_size, num_workers=args.n_jobs, shuffle=True)) 105 | val_loader = DataLoader(val_set, config.train.batch_size, num_workers=args.n_jobs, shuffle=False) 106 | 107 | 108 | if args.distribution: 109 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) 110 | train_iterator = inf_iterator(DataLoader(train_set, config.train.batch_size, num_workers=args.n_jobs, shuffle=False,sampler=train_sampler)) 111 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_set) 112 | val_loader = DataLoader(val_set, config.train.batch_size, shuffle=False, num_workers=args.n_jobs, sampler=val_sampler) 113 | 114 | # Model 115 | if master_worker: logger.info('Building model...') 116 | model = get_model(config.model).to(args.device) 117 | # model = get_model(config.model).cuda() 118 | 119 | # Optimizer 120 | optimizer_global = get_optimizer(config.train.optimizer, model.model_global) # module. 
121 | optimizer_local = get_optimizer(config.train.optimizer, model.model_local) 122 | optimizer_mask = get_optimizer(config.train.optimizer, model.model_mask) 123 | scheduler_global = get_scheduler(config.train.scheduler, optimizer_global) 124 | scheduler_local = get_scheduler(config.train.scheduler, optimizer_local) 125 | scheduler_mask = get_scheduler(config.train.scheduler, optimizer_mask) 126 | start_iter = 1 127 | 128 | 129 | if args.distribution: 130 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) 131 | 132 | # Resume from checkpoint 133 | if resume: 134 | ckpt_path, start_iter = get_checkpoint_path(os.path.join(resume_from, 'checkpoints'), it=args.resume_iter) 135 | logger.info('Resuming from: %s' % ckpt_path) 136 | logger.info('Iteration: %d' % start_iter) 137 | ckpt = torch.load(ckpt_path) 138 | model.load_state_dict(ckpt['model']) 139 | optimizer_global.load_state_dict(ckpt['optimizer_global']) 140 | optimizer_local.load_state_dict(ckpt['optimizer_local']) 141 | try: optimizer_mask.load_state_dict(ckpt['optimizer_mask']) 142 | except: pass 143 | scheduler_global.load_state_dict(ckpt['scheduler_global']) 144 | scheduler_local.load_state_dict(ckpt['scheduler_local']) 145 | try: scheduler_mask.load_state_dict(ckpt['scheduler_mask']) 146 | except: pass 147 | 148 | def train(it): 149 | model.train() 150 | optimizer_global.zero_grad() 151 | optimizer_local.zero_grad() 152 | optimizer_mask.zero_grad() 153 | ddpm_step = config.train.max_iters//2 154 | batch = next(train_iterator).to(args.device) 155 | 156 | if args.distribution: 157 | loss_func=model.module.get_loss 158 | else: 159 | loss_func=model.get_loss 160 | 161 | loss, loss_global, loss_local, loss_mask = loss_func( 162 | data=batch, 163 | atom_type=batch.atom_type, 164 | pos=batch.pos, 165 | bond_index=batch.edge_index, 166 | bond_type=batch.edge_type, 167 | batch=batch.batch, 168 | num_nodes_per_graph=batch.num_nodes_per_graph, 169 | num_graphs=batch.num_graphs, 
170 | anneal_power=config.train.anneal_power, 171 | return_unreduced_loss=True 172 | ) 173 | loss_mask = loss_mask.mean() 174 | if hasattr(batch,"last_select"): 175 | sum_selected = batch.last_select.sum() 176 | loss = loss.sum()/sum_selected 177 | loss_global=loss_global.sum()/sum_selected 178 | loss_local = loss_local.sum()/sum_selected 179 | else: 180 | loss = loss.mean() 181 | loss_global=loss_global.mean() 182 | loss_local = loss_local.mean() 183 | 184 | if args.distribution: 185 | 186 | reduced_loss =reduce_mean(loss, args.nprocs) 187 | reduced_loss_global = reduce_mean(loss_global, args.nprocs) 188 | reduced_loss_local = reduce_mean(loss_local, args.nprocs) 189 | reduced_loss_mask = reduce_mean(loss_mask, args.nprocs) 190 | 191 | loss.backward() 192 | orig_grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm) 193 | optimizer_global.step() 194 | optimizer_local.step() 195 | optimizer_mask.step() 196 | 197 | if master_worker and (it-1) % args.print_freq == 0: 198 | if args.distribution: 199 | loss=reduced_loss 200 | loss_global =reduced_loss_global 201 | loss_local = reduced_loss_local 202 | loss_mask = reduced_loss_mask 203 | logger.info('[Train] Epoch %05d | Iter %05d | Loss %.2f | Loss(Global) %.2f | Loss(Local) %.2f | Loss(mask) %.2f | Grad %.2f | LR(Global) %.6f | LR(Local) %.6f| LR(mask) %.6f|%s' % ( 204 | (it*config.train.batch_size)//len(train_set), it, loss.item(), loss_global.item(), loss_local.item(), loss_mask.item(), orig_grad_norm, optimizer_global.param_groups[0]['lr'], 205 | optimizer_local.param_groups[0]['lr'],optimizer_mask.param_groups[0]['lr'], log_dir 206 | )) 207 | writer.add_scalar('train/loss', loss, it) 208 | writer.add_scalar('train/loss_global', loss_global.mean(), it) 209 | writer.add_scalar('train/loss_local', loss_local.mean(), it) 210 | writer.add_scalar('train/loss_mask', loss_mask.mean(), it) 211 | writer.add_scalar('train/lr_global', optimizer_global.param_groups[0]['lr'], it) 212 | 
writer.add_scalar('train/lr_local', optimizer_local.param_groups[0]['lr'], it) 213 | writer.add_scalar('train/lr_mask', optimizer_mask.param_groups[0]['lr'], it) 214 | writer.add_scalar('train/grad_norm', orig_grad_norm, it) 215 | writer.flush() 216 | 217 | def validate(it): 218 | sum_loss, sum_n = torch.tensor(0.0).to(args.local_rank), 0 219 | sum_loss_global, sum_n_global = torch.tensor(0.0).to(args.local_rank), 0 220 | sum_loss_local, sum_n_local = torch.tensor(0.0).to(args.local_rank), 0 221 | sum_loss_mask, sum_n_mask = torch.tensor(0.0).to(args.local_rank), 0 222 | # print("validate....",local_rank,end=' | ') 223 | with torch.no_grad(): 224 | model.eval() 225 | for i, batch in enumerate(tqdm(val_loader, desc='Validation',disable=not master_worker)): 226 | batch = batch.to(args.local_rank) 227 | if args.distribution: 228 | loss_func=model.module.get_loss 229 | else: 230 | loss_func=model.get_loss 231 | 232 | loss, loss_global, loss_local, loss_mask = loss_func( 233 | data=batch, 234 | atom_type=batch.atom_type, 235 | pos=batch.pos, 236 | bond_index=batch.edge_index, 237 | bond_type=batch.edge_type, 238 | batch=batch.batch, 239 | num_nodes_per_graph=batch.num_nodes_per_graph, 240 | num_graphs=batch.num_graphs, 241 | anneal_power=config.train.anneal_power, 242 | return_unreduced_loss=True 243 | ) 244 | sum_loss += loss.sum().item() 245 | sum_loss_local += loss_local.sum().item() 246 | sum_loss_global += loss_global.sum().item() 247 | sum_loss_mask += loss_mask.sum().item() 248 | sum_n_mask += loss_mask.size(0) 249 | if hasattr(batch, "last_select"): 250 | sum_selected = batch.last_select.sum() 251 | sum_n +=sum_selected 252 | sum_n_local +=sum_selected 253 | sum_n_global +=sum_selected 254 | 255 | else: 256 | sum_n += loss.size(0) 257 | sum_n_local += loss_local.size(0) 258 | sum_n_global += loss_global.size(0) 259 | 260 | avg_loss = sum_loss / sum_n 261 | avg_loss_global = sum_loss_global / sum_n_global 262 | avg_loss_local = sum_loss_local / sum_n_local 263 | 
avg_loss_mask = sum_loss_mask / sum_n_mask 264 | 265 | if args.distribution: 266 | dist.barrier() 267 | avg_loss =reduce_mean(avg_loss, args.nprocs) 268 | avg_loss_global = reduce_mean(avg_loss_global, args.nprocs) 269 | avg_loss_local = reduce_mean(avg_loss_local, args.nprocs) 270 | avg_loss_mask = reduce_mean(avg_loss_mask, args.nprocs) 271 | 272 | 273 | if config.train.scheduler.type == 'plateau': 274 | scheduler_global.step(avg_loss_global) 275 | scheduler_local.step(avg_loss_local) 276 | # scheduler_mask.step(avg_loss_mask) 277 | else: 278 | scheduler_global.step() 279 | scheduler_local.step() 280 | scheduler_mask.step() 281 | 282 | if master_worker: 283 | logger.info('[Validate] Iter %05d | Loss %.6f | Loss(Global) %.6f | Loss(Local) %.6f | Loss(mask) %.6f' % ( 284 | it, avg_loss, avg_loss_global, avg_loss_local, avg_loss_mask 285 | )) 286 | writer.add_scalar('val/loss', avg_loss, it) 287 | writer.add_scalar('val/loss_global', avg_loss_global, it) 288 | writer.add_scalar('val/loss_local', avg_loss_local, it) 289 | writer.add_scalar('val/loss_mask', avg_loss_mask, it) 290 | writer.flush() 291 | return avg_loss 292 | 293 | 294 | if master_worker: print("training....") 295 | try: 296 | for it in range(start_iter, config.train.max_iters + 1): 297 | 298 | # train_sampler.set_epoch(it) 299 | train(it) 300 | # TODO if avg_val_loss < : save 301 | if it % config.train.val_freq == 0 or it == config.train.max_iters: 302 | avg_val_loss = validate(it) 303 | if master_worker and (it % 20000 == 0 or it == config.train.max_iters): 304 | # print("saving checkpoint....") 305 | ckpt_path = os.path.join(ckpt_dir, '%d.pt' % it) 306 | torch.save({ 307 | 'config': config, 308 | 'model': model.state_dict(), 309 | 'optimizer_global': optimizer_global.state_dict(), 310 | 'scheduler_global': scheduler_global.state_dict(), 311 | 'optimizer_local': optimizer_local.state_dict(), 312 | 'scheduler_local': scheduler_local.state_dict(), 313 | 'optimizer_mask': optimizer_mask.state_dict(), 314 
| 'scheduler_mask': scheduler_mask.state_dict(), 315 | 'iteration': it, 316 | 'avg_val_loss': avg_val_loss, 317 | }, ckpt_path) 318 | except KeyboardInterrupt: 319 | if master_worker: logger.info('Terminating...') 320 | 321 | 322 | 323 | 324 | -------------------------------------------------------------------------------- /utils/chem.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch 3 | from torchvision.transforms.functional import to_tensor 4 | import rdkit 5 | import rdkit.Chem.Draw 6 | from rdkit import Chem 7 | from rdkit.Chem import rdDepictor as DP 8 | from rdkit.Chem import PeriodicTable as PT 9 | from rdkit.Chem import rdMolAlign as MA 10 | from rdkit.Chem.rdchem import BondType as BT 11 | from rdkit.Chem.rdchem import Mol,GetPeriodicTable 12 | from rdkit.Chem.Draw import rdMolDraw2D as MD2 13 | from rdkit.Chem.rdmolops import RemoveHs 14 | from typing import List, Tuple 15 | 16 | 17 | BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())} 18 | BOND_NAMES = {i: t for i, t in enumerate(BT.names.keys())} 19 | 20 | 21 | def set_conformer_positions(conf, pos): 22 | for i in range(pos.shape[0]): 23 | conf.SetAtomPosition(i, pos[i].tolist()) 24 | return conf 25 | 26 | 27 | def draw_mol_image(rdkit_mol, tensor=False): 28 | rdkit_mol.UpdatePropertyCache() 29 | img = rdkit.Chem.Draw.MolToImage(rdkit_mol, kekulize=False) 30 | if tensor: 31 | return to_tensor(img) 32 | else: 33 | return img 34 | 35 | 36 | def update_data_rdmol_positions(data): 37 | for i in range(data.pos.size(0)): 38 | data.rdmol.GetConformer(0).SetAtomPosition(i, data.pos[i].tolist()) 39 | return data 40 | 41 | 42 | def update_data_pos_from_rdmol(data): 43 | new_pos = torch.FloatTensor(data.rdmol.GetConformer(0).GetPositions()).to(data.pos) 44 | data.pos = new_pos 45 | return data 46 | 47 | 48 | def set_rdmol_positions(rdkit_mol, pos): 49 | """ 50 | Args: 51 | rdkit_mol: An `rdkit.Chem.rdchem.Mol` object. 
52 | pos: (N_atoms, 3) 53 | """ 54 | mol = deepcopy(rdkit_mol) 55 | set_rdmol_positions_(mol, pos) 56 | return mol 57 | 58 | 59 | def set_rdmol_positions_(mol, pos): 60 | """ 61 | Args: 62 | mol: An `rdkit.Chem.rdchem.Mol` object. 63 | pos: (N_atoms, 3) 64 | """ 65 | for i in range(pos.shape[0]): 66 | mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist()) 67 | return mol 68 | 69 | 70 | def get_atom_symbol(atomic_number): 71 | return PT.GetElementSymbol(GetPeriodicTable(), atomic_number) 72 | 73 | 74 | def mol_to_smiles(mol: Mol) -> str: 75 | return Chem.MolToSmiles(mol, allHsExplicit=True) 76 | 77 | 78 | def mol_to_smiles_without_Hs(mol: Mol) -> str: 79 | return Chem.MolToSmiles(Chem.RemoveHs(mol)) 80 | 81 | 82 | def remove_duplicate_mols(molecules: List[Mol]) -> List[Mol]: 83 | unique_tuples: List[Tuple[str, Mol]] = [] 84 | 85 | for molecule in molecules: 86 | duplicate = False 87 | smiles = mol_to_smiles(molecule) 88 | for unique_smiles, _ in unique_tuples: 89 | if smiles == unique_smiles: 90 | duplicate = True 91 | break 92 | 93 | if not duplicate: 94 | unique_tuples.append((smiles, molecule)) 95 | 96 | return [mol for smiles, mol in unique_tuples] 97 | 98 | 99 | def get_atoms_in_ring(mol): 100 | atoms = set() 101 | for ring in mol.GetRingInfo().AtomRings(): 102 | for a in ring: 103 | atoms.add(a) 104 | return atoms 105 | 106 | 107 | def get_2D_mol(mol): 108 | mol = deepcopy(mol) 109 | DP.Compute2DCoords(mol) 110 | return mol 111 | 112 | 113 | def draw_mol_svg(mol, molSize=(450, 150), kekulize=False): 114 | mc = Chem.Mol(mol.ToBinary()) 115 | if kekulize: 116 | try: 117 | Chem.Kekulize(mc) 118 | except Exception: 119 | mc = Chem.Mol(mol.ToBinary()) 120 | if not mc.GetNumConformers(): 121 | DP.Compute2DCoords(mc) 122 | drawer = MD2.MolDraw2DSVG(molSize[0], molSize[1]) 123 | drawer.DrawMolecule(mc) 124 | drawer.FinishDrawing() 125 | svg = drawer.GetDrawingText() 126 | # It seems that the svg renderer used doesn't quite hit the spec. 
127 | # Here are some fixes to make it work in the notebook, although I think 128 | # the underlying issue needs to be resolved at the generation step 129 | # return svg.replace('svg:','') 130 | return svg 131 | 132 | 133 | def get_best_rmsd(probe, ref): 134 | probe = RemoveHs(probe) 135 | ref = RemoveHs(ref) 136 | rmsd = MA.GetBestRMS(probe, ref) 137 | return rmsd 138 | 139 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import warnings 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch_geometric.data import Data, Batch 8 | 9 | 10 | #customize exp lr scheduler with min lr 11 | class ExponentialLR_with_minLr(torch.optim.lr_scheduler.ExponentialLR): 12 | def __init__(self, optimizer, gamma, min_lr=1e-4, last_epoch=-1, verbose=False): 13 | self.gamma = gamma 14 | self.min_lr = min_lr 15 | super(ExponentialLR_with_minLr, self).__init__(optimizer, gamma, last_epoch, verbose) 16 | 17 | def get_lr(self): 18 | if not self._get_lr_called_within_step: 19 | warnings.warn("To get the last learning rate computed by the scheduler, " 20 | "please use `get_last_lr()`.", UserWarning) 21 | 22 | if self.last_epoch == 0: 23 | return self.base_lrs 24 | return [max(group['lr'] * self.gamma, self.min_lr) 25 | for group in self.optimizer.param_groups] 26 | 27 | def _get_closed_form_lr(self): 28 | return [max(base_lr * self.gamma ** self.last_epoch, self.min_lr) 29 | for base_lr in self.base_lrs] 30 | 31 | 32 | def repeat_data(data: Data, num_repeat) -> Batch: 33 | datas = [copy.deepcopy(data) for i in range(num_repeat)] 34 | return Batch.from_data_list(datas) 35 | 36 | def repeat_batch(batch: Batch, num_repeat) -> Batch: 37 | datas = batch.to_data_list() 38 | new_data = [] 39 | for i in range(num_repeat): 40 | new_data = new_data + copy.deepcopy(datas) 41 | return Batch.from_data_list(new_data) 42 | 43 
| 44 | def get_optimizer(cfg, model): 45 | if cfg.type == 'adam': 46 | return torch.optim.Adam( 47 | model.parameters(), 48 | lr=cfg.lr, 49 | weight_decay=cfg.weight_decay, 50 | betas=(cfg.beta1, cfg.beta2, ) 51 | ) 52 | else: 53 | raise NotImplementedError('Optimizer not supported: %s' % cfg.type) 54 | 55 | 56 | def get_scheduler(cfg, optimizer): 57 | if cfg.type == 'plateau': 58 | return torch.optim.lr_scheduler.ReduceLROnPlateau( 59 | optimizer, 60 | factor=cfg.factor, 61 | patience=cfg.patience, 62 | ) 63 | elif cfg.type == 'expmin': 64 | return ExponentialLR_with_minLr( 65 | optimizer, 66 | gamma=cfg.factor, 67 | min_lr=cfg.min_lr, 68 | ) 69 | elif cfg.type == 'expmin_milestone': 70 | gamma = np.exp(np.log(cfg.factor) / cfg.milestone) 71 | return ExponentialLR_with_minLr( 72 | optimizer, 73 | gamma=gamma, 74 | min_lr=cfg.min_lr, 75 | ) 76 | else: 77 | raise NotImplementedError('Scheduler not supported: %s' % cfg.type) -------------------------------------------------------------------------------- /utils/evaluation/covmat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | import multiprocessing as mp 5 | from torch_geometric.data import Data 6 | from functools import partial 7 | from easydict import EasyDict 8 | from tqdm.auto import tqdm 9 | from rdkit import Chem 10 | from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule 11 | 12 | from ..chem import set_rdmol_positions, get_best_rmsd 13 | 14 | 15 | def get_rmsd_confusion_matrix(data: Data, useFF=False): 16 | data['pos_ref'] = data['pos_ref'].reshape(-1, data['rdmol'].GetNumAtoms(), 3) 17 | data['pos_gen'] = data['pos_gen'].reshape(-1, data['rdmol'].GetNumAtoms(), 3) 18 | num_gen = data['pos_gen'].shape[0] 19 | num_ref = data['pos_ref'].shape[0] 20 | 21 | # assert num_gen == data.num_pos_gen.item() 22 | # assert num_ref == data.num_pos_ref.item() 23 | 24 | rmsd_confusion_mat = -1 * np.ones([num_ref, 
num_gen], dtype=np.float64) # `np.float` was removed from NumPy; use float64 25 | 26 | for i in range(num_gen): 27 | gen_mol = set_rdmol_positions(data['rdmol'], data['pos_gen'][i]) 28 | if useFF: 29 | #print('Applying FF on generated molecules...') 30 | MMFFOptimizeMolecule(gen_mol) 31 | for j in range(num_ref): 32 | ref_mol = set_rdmol_positions(data['rdmol'], data['pos_ref'][j]) 33 | 34 | rmsd_confusion_mat[j,i] = get_best_rmsd(gen_mol, ref_mol) 35 | 36 | return rmsd_confusion_mat 37 | 38 | 39 | def evaluate_conf(data: Data, useFF=False, threshold=0.5): 40 | rmsd_confusion_mat = get_rmsd_confusion_matrix(data, useFF=useFF) 41 | rmsd_ref_min = rmsd_confusion_mat.min(-1) 42 | #print('done one mol') 43 | #print(rmsd_ref_min) 44 | return (rmsd_ref_min<=threshold).mean(), rmsd_ref_min.mean() 45 | 46 | 47 | def print_covmat_results(results, print_fn=print): 48 | df = pd.DataFrame({ 49 | 'COV-R_mean': np.mean(results.CoverageR, 0), 50 | 'COV-R_median': np.median(results.CoverageR, 0), 51 | 'COV-R_std': np.std(results.CoverageR, 0), 52 | 'COV-P_mean': np.mean(results.CoverageP, 0), 53 | 'COV-P_median': np.median(results.CoverageP, 0), 54 | 'COV-P_std': np.std(results.CoverageP, 0), 55 | }, index=results.thresholds) 56 | print_fn('\n' + str(df)) 57 | print_fn('MAT-R_mean: %.4f | MAT-R_median: %.4f | MAT-R_std %.4f' % ( 58 | np.mean(results.MatchingR), np.median(results.MatchingR), np.std(results.MatchingR) 59 | )) 60 | print_fn('MAT-P_mean: %.4f | MAT-P_median: %.4f | MAT-P_std %.4f' % ( 61 | np.mean(results.MatchingP), np.median(results.MatchingP), np.std(results.MatchingP) 62 | )) 63 | return df 64 | 65 | 66 | 67 | class CovMatEvaluator(object): 68 | 69 | def __init__(self, 70 | num_workers=8, 71 | use_force_field=False, 72 | thresholds=np.arange(0.05, 3.05, 0.05), 73 | ratio=2, 74 | filter_disconnected=True, 75 | print_fn=print, 76 | ): 77 | super().__init__() 78 | self.num_workers = num_workers 79 | self.use_force_field = use_force_field 80 | self.thresholds = np.array(thresholds).flatten() 81 | 82 | self.ratio = 
ratio 83 | self.filter_disconnected = filter_disconnected 84 | 85 | self.pool = mp.Pool(num_workers) 86 | self.print_fn = print_fn 87 | 88 | def __call__(self, packed_data_list, start_idx=0): 89 | func = partial(get_rmsd_confusion_matrix, useFF=self.use_force_field) 90 | 91 | filtered_data_list = [] 92 | for data in packed_data_list: 93 | if 'pos_gen' not in data or 'pos_ref' not in data: continue 94 | if self.filter_disconnected and ('.' in data['smiles']): continue 95 | 96 | data['pos_ref'] = data['pos_ref'].reshape(-1, data['rdmol'].GetNumAtoms(), 3) 97 | data['pos_gen'] = data['pos_gen'].reshape(-1, data['rdmol'].GetNumAtoms(), 3) 98 | 99 | num_gen = data['pos_ref'].shape[0] * self.ratio 100 | if data['pos_gen'].shape[0] < num_gen: continue 101 | data['pos_gen'] = data['pos_gen'][:num_gen] 102 | 103 | filtered_data_list.append(data) 104 | 105 | filtered_data_list = filtered_data_list[start_idx:] 106 | self.print_fn('Filtered: %d / %d' % (len(filtered_data_list), len(packed_data_list))) 107 | 108 | covr_scores = [] 109 | matr_scores = [] 110 | covp_scores = [] 111 | matp_scores = [] 112 | for confusion_mat in tqdm(self.pool.imap(func, filtered_data_list), total=len(filtered_data_list)): 113 | # confusion_mat: (num_ref, num_gen) 114 | rmsd_ref_min = confusion_mat.min(-1) # np (num_ref, ) 115 | rmsd_gen_min = confusion_mat.min(0) # np (num_gen, ) 116 | 117 | rmsd_cov_thres = rmsd_ref_min.reshape(-1, 1) <= self.thresholds.reshape(1, -1) # np (num_ref, num_thres) 118 | rmsd_jnk_thres = rmsd_gen_min.reshape(-1, 1) <= self.thresholds.reshape(1, -1) # np (num_gen, num_thres) 119 | 120 | matr_scores.append(rmsd_ref_min.mean()) 121 | covr_scores.append(rmsd_cov_thres.mean(0, keepdims=True)) # np (1, num_thres) 122 | matp_scores.append(rmsd_gen_min.mean()) 123 | covp_scores.append(rmsd_jnk_thres.mean(0, keepdims=True)) # np (1, num_thres) 124 | 125 | covr_scores = np.vstack(covr_scores) # np (num_mols, num_thres) 126 | matr_scores = np.array(matr_scores) # np (num_mols, ) 
127 | covp_scores = np.vstack(covp_scores) # np (num_mols, num_thres) 128 | matp_scores = np.array(matp_scores) 129 | 130 | results = EasyDict({ 131 | 'CoverageR': covr_scores, 132 | 'MatchingR': matr_scores, 133 | 'thresholds': self.thresholds, 134 | 'CoverageP': covp_scores, 135 | 'MatchingP': matp_scores 136 | }) 137 | # print_conformation_eval_results(results) 138 | return results 139 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import random 4 | import logging 5 | import torch 6 | import numpy as np 7 | from glob import glob 8 | from logging import Logger 9 | from tqdm.auto import tqdm 10 | from torch_geometric.data import Batch 11 | 12 | 13 | class BlackHole(object): 14 | def __setattr__(self, name, value): 15 | pass 16 | def __call__(self, *args, **kwargs): 17 | return self 18 | def __getattr__(self, name): 19 | return self 20 | 21 | 22 | def get_logger(name, log_dir=None, log_fn='log.txt'): 23 | logger = logging.getLogger(name) 24 | logger.setLevel(logging.DEBUG) 25 | formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s') 26 | 27 | stream_handler = logging.StreamHandler() 28 | stream_handler.setLevel(logging.DEBUG) 29 | stream_handler.setFormatter(formatter) 30 | logger.addHandler(stream_handler) 31 | 32 | if log_dir is not None: 33 | file_handler = logging.FileHandler(os.path.join(log_dir, log_fn)) 34 | file_handler.setLevel(logging.DEBUG) 35 | file_handler.setFormatter(formatter) 36 | logger.addHandler(file_handler) 37 | 38 | return logger 39 | 40 | 41 | def get_new_log_dir(root='./logs', prefix='', tag=''): 42 | fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime()) 43 | if prefix != '': 44 | fn = prefix + '_' + fn 45 | if tag != '': 46 | fn = fn + '_' + tag 47 | log_dir = os.path.join(root, fn) 48 | os.makedirs(log_dir) 49 | return log_dir 50 | 51 | 52 | 
def seed_all(seed): 53 | torch.manual_seed(seed) 54 | np.random.seed(seed) 55 | random.seed(seed) 56 | 57 | 58 | def inf_iterator(iterable): 59 | iterator = iterable.__iter__() 60 | while True: 61 | try: 62 | yield iterator.__next__() 63 | except StopIteration: 64 | iterator = iterable.__iter__() 65 | 66 | 67 | def log_hyperparams(writer, args): 68 | from torch.utils.tensorboard.summary import hparams 69 | vars_args = {k:v if isinstance(v, str) else repr(v) for k, v in vars(args).items()} 70 | exp, ssi, sei = hparams(vars_args, {}) 71 | writer.file_writer.add_summary(exp) 72 | writer.file_writer.add_summary(ssi) 73 | writer.file_writer.add_summary(sei) 74 | 75 | 76 | def int_tuple(argstr): 77 | return tuple(map(int, argstr.split(','))) 78 | 79 | 80 | def str_tuple(argstr): 81 | return tuple(argstr.split(',')) 82 | 83 | 84 | def repeat_data(data, num_repeat): 85 | datas = [data.clone() for i in range(num_repeat)] 86 | return Batch.from_data_list(datas) 87 | 88 | 89 | def repeat_batch(batch, num_repeat): 90 | datas = batch.to_data_list() 91 | new_data = [] 92 | for i in range(num_repeat): 93 | new_data = new_data + [d.clone() for d in datas] # a Python list has no .clone(); clone each Data object 94 | return Batch.from_data_list(new_data) 95 | 96 | 97 | def get_checkpoint_path(folder, it=None): 98 | if it is not None: 99 | return os.path.join(folder, '%d.pt' % it), it 100 | all_iters = list(map(lambda x: int(os.path.basename(x[:-3])), glob(os.path.join(folder, '*.pt')))) 101 | all_iters.sort() 102 | return os.path.join(folder, '%d.pt' % all_iters[-1]), all_iters[-1] 103 | -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch_geometric.data import Data 4 | from torch_geometric.transforms import Compose 5 | from torch_geometric.utils import to_dense_adj, dense_to_sparse 6 | from torch_sparse import coalesce 7 | 8 | from torch_geometric.utils import to_networkx 9 
| import networkx as nx 10 | import numpy as np 11 | 12 | 13 | from .chem import BOND_TYPES, BOND_NAMES, get_atom_symbol 14 | 15 | 16 | 17 | 18 | class AddHigherOrderEdges(object): 19 | 20 | def __init__(self, order, num_types=len(BOND_TYPES)): 21 | super().__init__() 22 | self.order = order 23 | self.num_types = num_types 24 | 25 | def binarize(self, x): 26 | return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x)) 27 | 28 | def get_higher_order_adj_matrix(self, adj, order): 29 | """ 30 | Args: 31 | adj: (N, N) 32 | type_mat: (N, N) 33 | """ 34 | adj_mats = [torch.eye(adj.size(0), dtype=torch.long, device=adj.device), \ 35 | self.binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device))] 36 | 37 | for i in range(2, order+1): 38 | adj_mats.append(self.binarize(adj_mats[i-1] @ adj_mats[1])) 39 | order_mat = torch.zeros_like(adj) 40 | 41 | for i in range(1, order+1): 42 | order_mat = order_mat + (adj_mats[i] - adj_mats[i-1]) * i 43 | 44 | return order_mat 45 | 46 | def __call__(self, data: Data): 47 | N = data.num_nodes 48 | adj = to_dense_adj(data.edge_index).squeeze(0) 49 | adj_order = self.get_higher_order_adj_matrix(adj, self.order) # (N, N) 50 | 51 | type_mat = to_dense_adj(data.edge_index, edge_attr=data.edge_type).squeeze(0) # (N, N) 52 | type_highorder = torch.where(adj_order > 1, self.num_types + adj_order - 1, torch.zeros_like(adj_order)) 53 | assert (type_mat * type_highorder == 0).all() 54 | type_new = type_mat + type_highorder 55 | 56 | new_edge_index, new_edge_type = dense_to_sparse(type_new) 57 | _, edge_order = dense_to_sparse(adj_order) 58 | 59 | data.bond_edge_index = data.edge_index # Save original edges 60 | data.edge_index, data.edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data 61 | edge_index_1, data.edge_order = coalesce(new_edge_index, edge_order.long(), N, N) # modify data 62 | data.is_bond = (data.edge_type < self.num_types) 63 | assert (data.edge_index == edge_index_1).all() 64 | 65 | 
return data 66 | 67 | class AddEdgeLength(object): 68 | 69 | def __call__(self, data: Data): 70 | 71 | pos = data.pos 72 | row, col = data.edge_index 73 | d = (pos[row] - pos[col]).norm(dim=-1).unsqueeze(-1) # (num_edge, 1) 74 | data.edge_length = d 75 | return data 76 | 77 | 78 | # Add attribute placeholder for data object, so that we can use batch.to_data_list 79 | class AddPlaceHolder(object): 80 | def __call__(self, data: Data): 81 | data.pos_gen = -1. * torch.ones_like(data.pos) 82 | data.d_gen = -1. * torch.ones_like(data.edge_length) 83 | data.d_recover = -1. * torch.ones_like(data.edge_length) 84 | return data 85 | 86 | 87 | class AddEdgeName(object): 88 | 89 | def __init__(self, asymmetric=True): 90 | super().__init__() 91 | self.bonds = copy.deepcopy(BOND_NAMES) 92 | self.bonds[len(BOND_NAMES) + 1] = 'Angle' 93 | self.bonds[len(BOND_NAMES) + 2] = 'Dihedral' 94 | self.asymmetric = asymmetric 95 | 96 | def __call__(self, data:Data): 97 | data.edge_name = [] 98 | for i in range(data.edge_index.size(1)): 99 | tail = data.edge_index[0, i] 100 | head = data.edge_index[1, i] 101 | if self.asymmetric and tail >= head: 102 | data.edge_name.append('') 103 | continue 104 | tail_name = get_atom_symbol(data.atom_type[tail].item()) 105 | head_name = get_atom_symbol(data.atom_type[head].item()) 106 | name = '%s_%s_%s_%d_%d' % ( 107 | self.bonds[data.edge_type[i].item()] if data.edge_type[i].item() in self.bonds else 'E'+str(data.edge_type[i].item()), 108 | tail_name, 109 | head_name, 110 | tail, 111 | head, 112 | ) 113 | if hasattr(data, 'edge_length'): 114 | name += '_%.3f' % (data.edge_length[i].item()) 115 | data.edge_name.append(name) 116 | return data 117 | 118 | 119 | class AddAngleDihedral(object): 120 | 121 | def __init__(self): 122 | super().__init__() 123 | 124 | @staticmethod 125 | def iter_angle_triplet(bond_mat): 126 | n_atoms = bond_mat.size(0) 127 | for j in range(n_atoms): 128 | for k in range(n_atoms): 129 | for l in range(n_atoms): 130 | if bond_mat[j, 
k].item() == 0 or bond_mat[k, l].item() == 0: continue 131 | if (j == k) or (k == l) or (j >= l): continue 132 | yield(j, k, l) 133 | 134 | @staticmethod 135 | def iter_dihedral_quartet(bond_mat): 136 | n_atoms = bond_mat.size(0) 137 | for i in range(n_atoms): 138 | for j in range(n_atoms): 139 | if i >= j: continue 140 | if bond_mat[i,j].item() == 0:continue 141 | for k in range(n_atoms): 142 | for l in range(n_atoms): 143 | if (k in (i,j)) or (l in (i,j)): continue 144 | if bond_mat[k,i].item() == 0 or bond_mat[l,j].item() == 0: continue 145 | yield(k, i, j, l) 146 | 147 | def __call__(self, data:Data): 148 | N = data.num_nodes 149 | if 'is_bond' in data: 150 | bond_mat = to_dense_adj(data.edge_index, edge_attr=data.is_bond).long().squeeze(0) > 0 151 | else: 152 | bond_mat = to_dense_adj(data.edge_index, edge_attr=data.edge_type).long().squeeze(0) > 0 153 | 154 | # Note: if the name of attribute contains `index`, it will automatically 155 | # increases during batching. 156 | data.angle_index = torch.LongTensor(list(self.iter_angle_triplet(bond_mat))).t() 157 | data.dihedral_index = torch.LongTensor(list(self.iter_dihedral_quartet(bond_mat))).t() 158 | 159 | return data 160 | 161 | 162 | class CountNodesPerGraph(object): 163 | 164 | def __init__(self) -> None: 165 | super().__init__() 166 | 167 | def __call__(self, data): 168 | data.num_nodes_per_graph = torch.LongTensor([data.num_nodes]) 169 | return data 170 | 171 | from torch import nn 172 | 173 | class SubgraphNoiseTransform(object): 174 | """ expectation state + k-step same-subgraph diffusion""" 175 | def __init__(self, config, tag='', ddpm=False, boltzmann_weight=False): 176 | 177 | self.config = config 178 | self.ddpm=ddpm # typical DDPM, each atom views as independent point 179 | self.tag= tag 180 | 181 | betas = get_beta_schedule( 182 | beta_schedule=config.beta_schedule, 183 | beta_start=config.beta_start, 184 | beta_end=config.beta_end, 185 | num_diffusion_timesteps=config.num_diffusion_timesteps, 186 | 
)
        betas = torch.from_numpy(betas).float()
        self.betas = nn.Parameter(betas, requires_grad=False)
        self.one_minu_beta_sqrt = (1 - self.betas).sqrt()
        ## variances
        alphas = (1. - betas).cumprod(dim=0)
        self.alphas = nn.Parameter(alphas, requires_grad=False)
        check_alphas(self.alphas)
        self.num_timesteps = self.betas.size(0)
        self.same_mask_steps = self.config.get("same_mask_steps", 250)  # number of consecutive steps k that share the same subgraph mask

    def __call__(self, data):
        # Select a conformer.
        # Four elements for DDPM: original_data (pos), gaussian_noise (pos_noise), beta (sigma), time_step.
        # Sample noise levels.
        atom_type = data.atom_type
        pos = data.pos
        bond_index = data.edge_index
        bond_type = data.edge_type
        num_nodes_per_graph = data.num_nodes_per_graph
        mask_subgraph = data.mask_subgraph
        time_step = torch.randint(
            1, self.num_timesteps, size=(1,), device=pos.device)
        # index 0 of self.betas corresponds to beta_1 in the equations

        data.time_step = time_step
        beta = self.betas.index_select(-1, time_step)
        data.beta = beta

        if self.ddpm:
            alpha = self.alphas.index_select(-1, time_step)
            alpha = alpha.expand(pos.shape[0], 1)
            data.alpha = alpha
            data.last_select = torch.ones_like(alpha)
            data.noise_scale = 1 - alpha
            del data.mask_subgraph
            return data

        last_select, alpha, noise_scale = self._get_alpha_nodes_expect_same_mask(
            mask_subgraph, time_step, k=self.same_mask_steps)

        data.alpha = alpha
        data.last_select = last_select  # for predicting the last selected subgraph
        data.noise_scale = noise_scale
        del data.mask_subgraph
        return data

    def _get_alpha_nodes_expect_same_mask(self, mask_subgraph, time_step, p=0.5, k=250):
        """Expectation state + k-step same-subgraph (mask) diffusion."""
        expect_step = time_step.div(k, rounding_mode='floor')  # m := time_step // k
        mask_step =
time_step % k

        if mask_step == 0:
            expect_step -= 1
            mask_step = k
        selected_index = torch.randint(low=0, high=mask_subgraph.shape[-1], size=(1,))
        selected_node = mask_subgraph.index_select(-1, selected_index)  # e.g. mask_subgraph[:, [1, 2, 5]]
        if expect_step == 0:
            selected_nodes = selected_node.repeat(1, mask_step + 1)
            selected_nodes[:, [i for i in range(0, min(3, time_step + 1))]] = True
            bern_beta_mask_t = self.betas[0:time_step + 1] * selected_nodes

            alpha_t = (1. - bern_beta_mask_t).prod(dim=-1, keepdim=True)

            return selected_node, alpha_t, 1 - alpha_t

        ## Phase I: compute the expectation (mean) state for time steps 0 .. km
        p = mask_subgraph.sum(-1).unsqueeze(-1) / mask_subgraph.size(-1)  # node selection probability
        if self.tag.startswith('Recover_GeoDiff'):
            p = torch.ones_like(p)
            selected_node[:][:] = True
        if not hasattr(self, "prod_one_minu_beta_sqrt"):
            self.prod_one_minu_beta_sqrt = torch.zeros(self.num_timesteps // k)
            self.prod_one_minu_beta_sqrt[0] = self.alphas[k - 1]
            # For every k steps, we calculate an expectation state
            for j in range(1, self.num_timesteps // k):
                self.prod_one_minu_beta_sqrt[j] = (1 - self.betas)[j * k:(j + 1) * k].prod(dim=-1)  # \prod_{i=(j-1)k+1}^{kj} (1 - \beta_i)
            self.prod_one_minu_beta_sqrt = self.prod_one_minu_beta_sqrt.sqrt()

        alpha_exp = (p * self.prod_one_minu_beta_sqrt[:expect_step] + 1 - p) ** 2  # \alpha_j = (p \sqrt{\prod_{i=(j-1)k+1}^{kj} (1-\beta_i)} + 1 - p)^2
        alpha_nodes = alpha_exp.cumprod(dim=-1)  # \bar\alpha
        noise_scale = (alpha_nodes[:, expect_step - 1] / alpha_nodes).mm((1 - self.prod_one_minu_beta_sqrt[:expect_step] ** 2).unsqueeze(-1)) * p ** 2  # [ p \sqrt{\sum_{l=1}^{m} \frac{\bar\alpha_m}{\bar\alpha_l} (1 - \prod_{i=(l-1)k+1}^{kl} (1-\beta_i))} ]^2

        ## Phase II: time steps km+1 -> t
        bern_beta_mask_t = self.betas[time_step - mask_step:time_step + 1].repeat(selected_node.shape[0], 1) * selected_node
        alpha_t = (1.
- bern_beta_mask_t).prod(dim=-1, keepdim=True)
        ## combine
        alpha_node_t = alpha_t * alpha_nodes.index_select(-1, expect_step - 1)
        noise_scale_t = alpha_t * noise_scale + 1 - alpha_t
        if (alpha_node_t == 0).sum():
            print(alpha_node_t, time_step)

        return selected_node, alpha_node_t, noise_scale_t

    def __repr__(self) -> str:
        # This class has no sigma_min/sigma_max attributes, so report the
        # schedule parameters it actually owns.
        return (f'{self.__class__.__name__}(num_timesteps={self.num_timesteps}, '
                f'same_mask_steps={self.same_mask_steps})')


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    def sigmoid(x):
        return 1 / (np.exp(-x) + 1)

    if beta_schedule == "quad":
        betas = (
            np.linspace(
                beta_start ** 0.5,
                beta_end ** 0.5,
                num_diffusion_timesteps,
                dtype=np.float64,
            )
            ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "sigmoid":
        betas = np.linspace(-6, 6, num_diffusion_timesteps)
        betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas

def modify_conformer(pos, edge_index, mask_rotate, torsion_updates, as_numpy=False):
    if not isinstance(pos, np.ndarray): pos = pos.cpu().numpy()
    for idx_edge, e in enumerate(edge_index.cpu().numpy()):
        if torsion_updates[idx_edge] == 0:
            continue
        u, v = e[0], e[1]

        # Check whether the edge needs to be reversed: v should be connected to the part that gets rotated.
        assert
not mask_rotate[idx_edge, u]
        assert mask_rotate[idx_edge, v]

        rot_vec = pos[u] - pos[v]  # convention: positive rotation if pointing inwards. NOTE: DIFFERENT FROM THE PAPER!
        rot_vec = rot_vec * torsion_updates[idx_edge] / np.linalg.norm(rot_vec)  # idx_edge!
        rot_mat = R.from_rotvec(rot_vec).as_matrix()

        pos[mask_rotate[idx_edge]] = (pos[mask_rotate[idx_edge]] - pos[v]) @ rot_mat.T + pos[v]

    if not as_numpy: pos = torch.from_numpy(pos.astype(np.float32))
    return pos

def get_transformation_mask(pyg_data):
    '''Get the subgraph distribution.'''
    G = to_networkx(pyg_data, to_undirected=False)
    to_rotate = []
    edge_set = []
    edges = pyg_data.edge_index.T.numpy()
    for i in range(0, edges.shape[0]):
        # assert edges[i, 0] == edges[i+1, 1]
        edge = list(edges[i])
        if edge[::-1] in edge_set:  # skip the duplicated undirected edge
            continue
        edge_set.append(edge)
        G2 = G.to_undirected()
        G2.remove_edge(*edges[i])
        if not nx.is_connected(G2):
            l = sorted(nx.connected_components(G2), key=len)
            if len(l[0]) > 1:
                to_rotate.extend(l)

    if len(to_rotate) == 0:
        to_rotate.append(list(G.nodes()))

    mask_rotate = np.zeros((len(G.nodes()), len(to_rotate)), dtype=bool)
    for i in range(len(to_rotate)):
        mask_rotate.T[i][list(to_rotate[i])] = True
    pyg_data.mask_subgraph = torch.tensor(mask_rotate)
    return pyg_data

def get_alpah_nodes_schedule(pyg_data, config):
    mask_subgraph = pyg_data.mask_subgraph
    selected_index = torch.randint(low=0, high=mask_subgraph.shape[-1], size=(config.num_diffusion_timesteps,))
    selected_nodes = mask_subgraph.index_select(-1, selected_index)  # e.g. mask_subgraph[:, [1, 2, 5]]
    # Guarantee that noise has been added to every node.
    selected_nodes[:, [i for i in range(0, min(3, config.num_diffusion_timesteps + 1))]] = True

    betas =
get_beta_schedule(
        beta_schedule=config.beta_schedule,
        beta_start=config.beta_start,
        beta_end=config.beta_end,
        num_diffusion_timesteps=config.num_diffusion_timesteps,
    )
    betas = torch.from_numpy(betas).float()

    betas_nodes = betas * selected_nodes
    alpha_nodes = (1. - betas_nodes).cumprod(dim=-1)  # alpha_t for each node
    pyg_data.selected_nodes, pyg_data.alpha_nodes = selected_nodes, alpha_nodes
    return pyg_data

def check_alphas(alphas):
    for n, a in enumerate(alphas):
        if a == 0:
            print(f"Warning: bar alpha becomes zero at the {n}-th time step")
            break
    print("The smallest alpha is ", a.item())

--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
import py3Dmol
from rdkit import Chem


def visualize_mol(mol, size=(300, 300), surface=False, opacity=0.5):
    """Draw a molecule in 3D.

    Args:
    ----
    mol: rdMol, molecule to show
    size: tuple(int, int), canvas size
    surface: bool, display SAS
    opacity: float, opacity of surface, range 0.0-1.0
    Return:
    ----
    viewer: py3Dmol.view, a class for constructing embedded 3Dmol.js views in IPython notebooks.
    """
    mblock = Chem.MolToMolBlock(mol)
    viewer = py3Dmol.view(width=size[0], height=size[1])
    viewer.addModel(mblock, 'mol')
    viewer.setStyle({'stick': {}, 'sphere': {'radius': 0.35}})
    if surface:
        viewer.addSurface(py3Dmol.SAS, {'opacity': opacity})
    viewer.zoomTo()
    return viewer
--------------------------------------------------------------------------------
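Every branch of `get_beta_schedule` above returns a 1-D float64 array of betas, and the cumulative product of `(1 - beta)` gives the `\bar\alpha` values that `check_alphas` inspects for underflow to zero. A minimal standalone sketch of the "linear" branch (numpy only; the `beta_start`/`beta_end` values here are illustrative, not taken from the repo's configs):

```python
import numpy as np

# Sketch of the "linear" branch of get_beta_schedule, plus the cumulative
# alphas used by SubgraphNoiseTransform. beta_start/beta_end are
# illustrative values, not the repo's config values.
beta_start, beta_end, T = 1e-7, 2e-3, 500
betas = np.linspace(beta_start, beta_end, T, dtype=np.float64)
alpha_bars = np.cumprod(1.0 - betas)

assert betas.shape == (T,)
assert np.all(np.diff(betas) > 0)       # betas increase monotonically
assert np.all(np.diff(alpha_bars) < 0)  # bar-alpha decreases monotonically
assert alpha_bars[-1] > 0.0             # check_alphas would warn if this hit zero
```

With a short schedule like this, `alpha_bars` stays well away from zero; `check_alphas` exists because longer or more aggressive schedules can drive the cumulative product to exact floating-point zero, which would break downstream divisions by `alpha_nodes`.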