├── .gitignore ├── Readme.md ├── clof_environment.yml ├── confgen ├── confgf │ ├── __init__.py │ ├── dataset │ │ ├── __init__.py │ │ └── dataset.py │ ├── layers │ │ ├── __init__.py │ │ ├── clofnet.py │ │ ├── common.py │ │ ├── gat.py │ │ └── gin.py │ ├── models │ │ ├── __init__.py │ │ └── scorenet.py │ ├── runner │ │ ├── __init__.py │ │ └── clofnet_runner.py │ └── utils │ │ ├── __init__.py │ │ ├── chem.py │ │ ├── distgeom.py │ │ ├── evaluation.py │ │ ├── torch.py │ │ └── transforms.py ├── config │ ├── drugs_clofnet.yml │ └── qm9_clofnet.yml └── script │ ├── gen.py │ ├── get_rdkit_results.py │ ├── get_task1_results.py │ ├── get_task2_results.py │ ├── process_GEOM_dataset.py │ ├── process_iso17_dataset.py │ └── train.py ├── main_newtonian.py ├── models └── gcl.py └── newtonian ├── __init__.py ├── clof.py ├── dataloader.py ├── dataset ├── generate_dataset.py ├── script.sh └── synthetic_sim.py ├── dataset4newton.py ├── egnn.py ├── gnn.py └── layers.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | .idea* 141 | *.png 142 | data/* 143 | *.json 144 | *.DS_Store 145 | outputs_vae/* 146 | cpu*.s 147 | gpu*.shh 148 | push.sh 149 | *.npy 150 | *temp* 151 | qm9/lie_conv/data/* 152 | *cache/trans_Q/mutex* 153 | 154 | .vscode 155 | legacy 156 | run*.sh 157 | t*.sh 158 | upload_azure.sh 159 | sweep_statistic.py 160 | statistic.py 161 | eval_nbody.py 162 | eval.py 163 | graph.py 164 | main_nbody_v.py 165 | losess.py 166 | # LICENSE 167 | n_body_system/dataset/upload_azure.sh 168 | cp_dir.py 169 | 170 | test.py 171 | newtonian/legacy 172 | 173 | 174 | utils.py 175 | get_best_results.py 176 | 177 | scripts 178 | 179 | *.yaml 180 | *submit_train.py 181 | 182 | results 183 | dataset 184 | saved 185 | .amltconfig 186 | .amltignore 187 | amlt 188 | query_logs.py 189 | 190 | !confgen/confgf/dataset 191 | *_debug.yml -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # SE(3) Equivariant Graph Neural Networks with Complete Local Frames 2 | 3 | Reference implementation in PyTorch of the equivariant graph neural network (**ClofNet**). You can find the paper [here](https://arxiv.org/abs/2110.14811). 
4 | 5 | ## Run the code 6 | 7 | ### Build environment 8 | for newtonian system experiments 9 | ``` 10 | conda create -n clof python=3.7 -y 11 | conda activate clof 12 | conda install -y -c pytorch pytorch=1.7.0 torchvision torchaudio cudatoolkit=10.2 -y 13 | ``` 14 | for conformation generation task 15 | ``` 16 | conda install -y -c rdkit rdkit==2020.03.2.0 17 | conda install -y scikit-learn pandas decorator ipython networkx tqdm matplotlib 18 | conda install -y -c conda-forge easydict 19 | pip install pyyaml wandb 20 | pip install torch-scatter==2.0.6 -f https://pytorch-geometric.com/whl/torch-1.7.0+cu102.html 21 | pip install torch-sparse==0.6.8 -f https://pytorch-geometric.com/whl/torch-1.7.0+cu102.html 22 | pip install torch-cluster==1.5.9 -f https://pytorch-geometric.com/whl/torch-1.7.0+cu102.html 23 | pip install torch-spline-conv==1.2.0 -f https://pytorch-geometric.com/whl/torch-1.7.0+cu102.html 24 | pip install torch-geometric==1.6.3 25 | ``` 26 | 27 | ### Newtonian many-body system 28 | 29 | * This task is inspired by (Kipf et al., 2018; Fuchs et al., 2020; Satorras et al., 2021b), where a 5-body charged system is controlled by the electrostatic force field. Note that the force direction between any two particles is always along the radial direction in the original setting. To validate the effectiveness of ClofNet on modeling arbitrary force directions, we also impose two external force fields into the original system, a gravity field and a Lorentz-like dynamical force field, which can provide more complex and dynamical force directions. 30 | * The original source code for generating trajectories comes from [Kipf et al., 2018](https://github.com/ethanfetaya/NRI) and is modified by [EGNN](https://github.com/vgsatorras/egnn). We further extend the version of EGNN to three new settings, as described in Section 7.1 of our [paper](https://arxiv.org/abs/2110.14811). We sincerely thank the solid contribution of these two works. 
31 | 32 | #### Create Many-body dataset 33 | ``` 34 | cd newtonian/dataset 35 | bash script.sh 36 | ``` 37 | 38 | #### Run experiments 39 | * for the ES(5) setting, run 40 | ``` 41 | python -u main_newtonian.py --max_training_samples 3000 --norm_diff True --LR_decay True --lr 0.01 --outf saved/newtonian \ 42 | --data_mode small --decay 0.9 --epochs 400 --exp_name clof_vel_small_5body --model clof_vel --n_layers 4 --data_root 43 | ``` 44 | * for the ES(20) setting, run 45 | ``` 46 | python -u main_newtonian.py --max_training_samples 3000 --norm_diff True --LR_decay True --lr 0.01 --outf saved/newtonian \ 47 | --data_mode small_20body --decay 0.9 --epochs 600 --exp_name clof_vel_small_20body --model clof_vel --n_layers 4 --data_root 48 | ``` 49 | * for the G+ES(20) setting, run 50 | ``` 51 | python -u main_newtonian.py --max_training_samples 3000 --norm_diff True --LR_decay True --lr 0.01 --outf saved/newtonian \ 52 | --data_mode static_20body --decay 0.9 --epochs 200 --exp_name clof_vel_static_20body --model clof_vel --n_layers 4 --data_root 53 | ``` 54 | * for the L+ES(20) setting, run 55 | ``` 56 | python -u main_newtonian.py --max_training_samples 3000 --norm_diff True --LR_decay True --decay 0.9 --lr 0.01 --outf saved/newtonian \ 57 | --data_mode dynamic_20body --epochs 600 --exp_name clof_vel_dynamic_20body --model clof_vel --n_layers 4 --data_root 58 | ``` 59 | 60 | ### Conformation Generation 61 | Equilibrium conformation generation targets on predicting stable 3D structures from 2D molecular graphs. Following [ConfGF](https://arxiv.org/abs/2105.03902), we evaluate the proposed ClofNet on the GEOM-QM9 and GEOM-Drugs datasets ([Axelrod & Gomez-Bombarelli, 2020](https://arxiv.org/abs/2006.05531)) as well as the ISO17 dataset ([Sch¨utt et al., 2017](https://proceedings.neurips.cc/paper/2017/hash/303ed4c69846ab36c2904d3ba8573050-Abstract.html)). 
For the score-based generation framework, we build our algorithm based on the public codebase of [ConfGF](https://github.com/DeepGraphLearning/ConfGF). We sincerely thank their solid contribution for this field. 62 | 63 | #### Dataset 64 | * **Offical Dataset**: The offical raw GEOM dataset is avaiable [[here]](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/JNGTDF). 65 | 66 | * **Preprocessed dataset**: We use the preprocessed datasets (GEOM, ISO17) published by ConfGF([[google drive folder]](https://drive.google.com/drive/folders/10dWaj5lyMY0VY4Zl0zDPCa69cuQUGb-6?usp=sharing)). 67 | 68 | #### Train 69 | ``` 70 | cd confgen 71 | python -u script/train.py --config_path ./config/qm9_clofnet.yml 72 | ``` 73 | #### Generation 74 | ``` 75 | python -u script/gen.py --config_path ./config/qm9_clofnet.yml --generator EquiGF --eval_epoch [epoch] --start 0 --end 1 76 | ``` 77 | #### Evaluation 78 | ``` 79 | python -u script/get_task1_results.py --input /root/to/generation --core 10 --threshold 0.5 80 | python -u script/get_task1_results.py --input /casp/v-hezha1/workspace/EquiNODE/code_publish/ClofNet/confgen/generation/clofnet4qm9/EquiGF_s0e1epoch398min_sig0.000repeat2.pkl --core 10 --threshold 0.5 81 | ``` 82 | ## Cite 83 | Please cite our paper if you use the model or this code in your own work: 84 | ``` 85 | @inproceedings{weitao_clofnet_2021, 86 | title = {{SE(3)} Equivariant Graph Neural Networks with Complete Local Frames}, 87 | author = {Weitao Du and 88 | He Zhang and 89 | Yuanqi Du and 90 | Qi Meng and 91 | Wei Chen and 92 | Nanning Zheng and 93 | Bin Shao and 94 | Tie{-}Yan Liu}, 95 | booktitle={International Conference on Machine Learning, {ICML} 2022, 17-23 July 96 | 2022, Baltimore, Maryland, {USA}}, 97 | year = {2021} 98 | } 99 | ``` -------------------------------------------------------------------------------- /clof_environment.yml: -------------------------------------------------------------------------------- 1 | name: clof 2 | 
channels: 3 | - rdkit 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - backcall=0.2.0=pyhd3eb1b0_0 11 | - blas=1.0=mkl 12 | - bottleneck=1.3.5=py37h7deecbd_0 13 | - brotli=1.0.9=h5eee18b_7 14 | - brotli-bin=1.0.9=h5eee18b_7 15 | - bzip2=1.0.8=h7b6447c_0 16 | - ca-certificates=2023.11.17=hbcca054_0 17 | - cairo=1.16.0=hb05425b_5 18 | - certifi=2023.11.17=pyhd8ed1ab_0 19 | - cudatoolkit=10.2.89=hfd86e86_1 20 | - cycler=0.11.0=pyhd3eb1b0_0 21 | - cyrus-sasl=2.1.28=h9c0eb46_1 22 | - dbus=1.13.18=hb2f20db_0 23 | - decorator=5.1.1=pyhd3eb1b0_0 24 | - easydict=1.9=py_0 25 | - expat=2.5.0=h6a678d5_0 26 | - fftw=3.3.9=h27cfd23_1 27 | - fontconfig=2.14.1=h4c34cd2_2 28 | - fonttools=4.25.0=pyhd3eb1b0_0 29 | - freetype=2.12.1=h4a9f257_0 30 | - giflib=5.2.1=h5eee18b_3 31 | - glib=2.69.1=he621ea3_2 32 | - gst-plugins-base=1.14.1=h6a678d5_1 33 | - gstreamer=1.14.1=h5eee18b_1 34 | - icu=58.2=he6710b0_3 35 | - intel-openmp=2021.4.0=h06a4308_3561 36 | - ipython=7.31.1=py37h06a4308_1 37 | - jedi=0.18.1=py37h06a4308_1 38 | - joblib=1.1.1=py37h06a4308_0 39 | - jpeg=9e=h5eee18b_1 40 | - kiwisolver=1.4.4=py37h6a678d5_0 41 | - krb5=1.20.1=h568e23c_1 42 | - lcms2=2.12=h3be6417_0 43 | - ld_impl_linux-64=2.38=h1181459_1 44 | - lerc=3.0=h295c915_0 45 | - libboost=1.67.0=h46d08c1_4 46 | - libbrotlicommon=1.0.9=h5eee18b_7 47 | - libbrotlidec=1.0.9=h5eee18b_7 48 | - libbrotlienc=1.0.9=h5eee18b_7 49 | - libclang=14.0.6=default_hc6dbbc7_1 50 | - libclang13=14.0.6=default_he11475f_1 51 | - libcups=2.4.2=ha637b67_0 52 | - libdeflate=1.17=h5eee18b_1 53 | - libedit=3.1.20230828=h5eee18b_0 54 | - libevent=2.1.12=h8f2d780_0 55 | - libffi=3.4.4=h6a678d5_0 56 | - libgcc-ng=11.2.0=h1234567_1 57 | - libgfortran-ng=11.2.0=h00389a5_1 58 | - libgfortran5=11.2.0=h1234567_1 59 | - libgomp=11.2.0=h1234567_1 60 | - libllvm14=14.0.6=hdb19cb5_3 61 | - libpng=1.6.39=h5eee18b_0 62 | - libpq=12.15=h37d81fd_1 63 | - 
libstdcxx-ng=11.2.0=h1234567_1 64 | - libtiff=4.5.1=h6a678d5_0 65 | - libuuid=1.41.5=h5eee18b_0 66 | - libuv=1.44.2=h5eee18b_0 67 | - libwebp=1.2.4=h11a3e52_1 68 | - libwebp-base=1.2.4=h5eee18b_1 69 | - libxcb=1.15=h7f8727e_0 70 | - libxkbcommon=1.0.1=h5eee18b_1 71 | - libxml2=2.10.4=hcbfbd50_0 72 | - libxslt=1.1.37=h2085143_0 73 | - lz4-c=1.9.4=h6a678d5_0 74 | - matplotlib=3.5.3=py37h06a4308_0 75 | - matplotlib-base=3.5.3=py37hf590b9c_0 76 | - matplotlib-inline=0.1.6=py37h06a4308_0 77 | - mkl=2021.4.0=h06a4308_640 78 | - mkl-service=2.4.0=py37h7f8727e_0 79 | - mkl_fft=1.3.1=py37hd3c417c_0 80 | - mkl_random=1.2.2=py37h51133e4_0 81 | - munkres=1.1.4=py_0 82 | - mysql=5.7.24=he378463_2 83 | - ncurses=6.4=h6a678d5_0 84 | - networkx=2.6.3=pyhd3eb1b0_0 85 | - ninja=1.10.2=h06a4308_5 86 | - ninja-base=1.10.2=hd09550d_5 87 | - nspr=4.35=h6a678d5_0 88 | - nss=3.89.1=h6a678d5_0 89 | - numexpr=2.8.4=py37he184ba9_0 90 | - numpy=1.21.5=py37h6c91a56_3 91 | - numpy-base=1.21.5=py37ha15fc14_3 92 | - openssl=1.1.1w=h7f8727e_0 93 | - packaging=22.0=py37h06a4308_0 94 | - pandas=1.3.5=py37h8c16a72_0 95 | - parso=0.8.3=pyhd3eb1b0_0 96 | - pcre=8.45=h295c915_0 97 | - pexpect=4.8.0=pyhd3eb1b0_3 98 | - pickleshare=0.7.5=pyhd3eb1b0_1003 99 | - pillow=9.4.0=py37h6a678d5_0 100 | - pip=22.3.1=py37h06a4308_0 101 | - pixman=0.40.0=h7f8727e_1 102 | - ply=3.11=py37_0 103 | - prompt-toolkit=3.0.36=py37h06a4308_0 104 | - ptyprocess=0.7.0=pyhd3eb1b0_2 105 | - py-boost=1.67.0=py37h04863e7_4 106 | - pygments=2.11.2=pyhd3eb1b0_0 107 | - pyparsing=3.0.9=py37h06a4308_0 108 | - pyqt=5.15.7=py37h6a678d5_1 109 | - pyqt5-sip=12.11.0=py37h6a678d5_1 110 | - python=3.7.16=h7a1cb2a_0 111 | - python-dateutil=2.8.2=pyhd3eb1b0_0 112 | - pytorch=1.7.0=py3.7_cuda10.2.89_cudnn7.6.5_0 113 | - pytz=2022.7=py37h06a4308_0 114 | - qt-main=5.15.2=h5b8104b_9 115 | - qt-webengine=5.15.9=h9ab4d14_7 116 | - qtwebkit=5.212=h3fafdc1_5 117 | - rdkit=2020.03.2.0=py37hc20afe1_1 118 | - readline=8.2=h5eee18b_0 119 | - 
scikit-learn=1.0.2=py37h51133e4_1 120 | - scipy=1.7.3=py37h6c91a56_2 121 | - setuptools=65.6.3=py37h06a4308_0 122 | - sip=6.6.2=py37h6a678d5_0 123 | - six=1.16.0=pyhd3eb1b0_1 124 | - sqlite=3.41.2=h5eee18b_0 125 | - threadpoolctl=2.2.0=pyh0d69192_0 126 | - tk=8.6.12=h1ccaba5_0 127 | - toml=0.10.2=pyhd3eb1b0_0 128 | - torchaudio=0.7.0=py37 129 | - torchvision=0.8.1=py37_cu102 130 | - tornado=6.2=py37h5eee18b_0 131 | - tqdm=4.64.1=py37h06a4308_0 132 | - traitlets=5.7.1=py37h06a4308_0 133 | - typing_extensions=4.3.0=py37h06a4308_0 134 | - wcwidth=0.2.5=pyhd3eb1b0_0 135 | - wheel=0.38.4=py37h06a4308_0 136 | - xz=5.4.5=h5eee18b_0 137 | - zlib=1.2.13=h5eee18b_0 138 | - zstd=1.5.5=hc292b87_0 139 | - pip: 140 | - appdirs==1.4.4 141 | - ase==3.22.1 142 | - charset-normalizer==3.3.2 143 | - click==8.1.7 144 | - dataclasses==0.6 145 | - docker-pycreds==0.4.0 146 | - future==0.18.3 147 | - gitdb==4.0.11 148 | - gitpython==3.1.41 149 | - googledrivedownloader==0.4 150 | - h5py==3.8.0 151 | - idna==3.6 152 | - importlib-metadata==4.13.0 153 | - isodate==0.6.1 154 | - jinja2==3.1.3 155 | - llvmlite==0.39.1 156 | - markupsafe==2.1.3 157 | - numba==0.56.4 158 | - protobuf==4.24.4 159 | - psutil==5.9.7 160 | - python-louvain==0.16 161 | - pyyaml==6.0.1 162 | - rdflib==6.3.2 163 | - requests==2.31.0 164 | - sentry-sdk==1.39.2 165 | - setproctitle==1.3.3 166 | - smmap==5.0.1 167 | - torch-cluster==1.5.9 168 | - torch-geometric==1.6.3 169 | - torch-scatter==2.0.6 170 | - torch-sparse==0.6.8 171 | - torch-spline-conv==1.2.0 172 | - urllib3==2.0.7 173 | - wandb==0.16.2 174 | - zipp==3.15.0 175 | -------------------------------------------------------------------------------- /confgen/confgf/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" -------------------------------------------------------------------------------- /confgen/confgf/dataset/__init__.py: 
def rdmol_to_data(mol: Mol, smiles=None):
    """Convert an RDKit molecule carrying one conformer into a PyG ``Data``.

    Args:
        mol: RDKit ``Mol``; must have exactly one conformer (asserted).
        smiles: optional SMILES string; computed from ``mol`` when omitted.

    Returns:
        ``torch_geometric.data.Data`` with ``atom_type`` (long tensor of
        atomic numbers), ``pos`` (float32 conformer coordinates),
        ``edge_index``/``edge_type`` (bidirectional bond edges, sorted
        lexicographically by (source, target)), a deep-copied ``rdmol``
        and the ``smiles`` string.
    """
    assert mol.GetNumConformers() == 1
    N = mol.GetNumAtoms()

    pos = torch.tensor(mol.GetConformer(0).GetPositions(), dtype=torch.float32)

    # Node features: atomic numbers only.  (The original also accumulated
    # aromaticity/hybridization flags and an H-neighbor count, but none of
    # them were ever stored on the Data object — dead work, removed.)
    atomic_number = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    z = torch.tensor(atomic_number, dtype=torch.long)

    # Every chemical bond contributes two directed edges.
    row, col, edge_type = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [utils.BOND_TYPES[bond.GetBondType()]]

    edge_index = torch.tensor([row, col], dtype=torch.long)
    edge_type = torch.tensor(edge_type)

    # Sort edges by (source * N + target) so the layout is deterministic.
    perm = (edge_index[0] * N + edge_index[1]).argsort()
    edge_index = edge_index[:, perm]
    edge_type = edge_type[perm]

    if smiles is None:
        smiles = Chem.MolToSmiles(mol)

    data = Data(atom_type=z, pos=pos, edge_index=edge_index, edge_type=edge_type,
                rdmol=copy.deepcopy(mol), smiles=smiles)

    return data
def preprocess_iso17_dataset(base_path):
    """Load the pickled ISO17 split-0 train/test sets and convert each
    conformation to a PyG ``Data`` object.

    Args:
        base_path: directory containing ``iso17_split-0_{train,test}.pkl``.

    Returns:
        ``(all_train, all_test)`` — lists of ``Data`` objects, one per
        conformation, each tagged with its molecule's SMILES.
    """
    train_path = os.path.join(base_path, 'iso17_split-0_train.pkl')
    test_path = os.path.join(base_path, 'iso17_split-0_test.pkl')
    with open(train_path, 'rb') as fin:
        raw_train = pickle.load(fin)
    with open(test_path, 'rb') as fin:
        raw_test = pickle.load(fin)

    # SMILES per conformation; the de-duplicated sets give the molecule counts.
    smiles_list_train = [utils.mol_to_smiles(mol) for mol in raw_train]
    smiles_set_train = list(set(smiles_list_train))
    smiles_list_test = [utils.mol_to_smiles(mol) for mol in raw_test]
    smiles_set_test = list(set(smiles_list_test))

    print('preprocess train...')
    all_train = [rdmol_to_data(mol, smiles=smi)
                 for mol, smi in zip(tqdm(raw_train), smiles_list_train)]

    print('Train | find %d molecules with %d confs' % (len(smiles_set_train), len(all_train)))

    print('preprocess test...')
    all_test = [rdmol_to_data(mol, smiles=smi)
                for mol, smi in zip(tqdm(raw_test), smiles_list_test)]

    print('Test | find %d molecules with %d confs' % (len(smiles_set_test), len(all_test)))

    return all_train, all_test
drugs] 169 | conf_per_mol: keep mol that has at least conf_per_mol confs, and sampling the most probable conf_per_mol confs 170 | train_size ratio, val = test = (1-train_size) / 2 171 | tot_mol_size: max num of mols. The total number of final confs should be tot_mol_size * conf_per_mol 172 | seed: rand seed for RNG 173 | """ 174 | 175 | # set random seed 176 | if seed is None: 177 | seed = 2021 178 | np.random.seed(seed) 179 | random.seed(seed) 180 | 181 | 182 | # read summary file 183 | assert dataset_name in ['qm9', 'drugs'] 184 | summary_path = os.path.join(base_path, 'summary_%s.json' % dataset_name) 185 | with open(summary_path, 'r') as f: 186 | summ = json.load(f) 187 | 188 | # filter valid pickle path 189 | smiles_list = [] 190 | pickle_path_list = [] 191 | num_mols = 0 192 | num_confs = 0 193 | for smiles, meta_mol in tqdm(summ.items()): 194 | u_conf = meta_mol.get('uniqueconfs') 195 | if u_conf is None: 196 | continue 197 | pickle_path = meta_mol.get('pickle_path') 198 | if pickle_path is None: 199 | continue 200 | if u_conf < conf_per_mol: 201 | continue 202 | num_mols += 1 203 | num_confs += conf_per_mol 204 | smiles_list.append(smiles) 205 | pickle_path_list.append(pickle_path) 206 | 207 | random.shuffle(pickle_path_list) 208 | assert len(pickle_path_list) >= tot_mol_size, 'the length of all available mols is %d, which is smaller than tot mol size %d' % (len(pickle_path_list), tot_mol_size) 209 | 210 | pickle_path_list = pickle_path_list[:tot_mol_size] 211 | 212 | print('pre-filter: find %d molecules with %d confs, use %d molecules with %d confs' % (num_mols, num_confs, tot_mol_size, tot_mol_size*conf_per_mol)) 213 | 214 | 215 | # 1. select the most probable 'conf_per_mol' confs of each 2D molecule 216 | # 2. split the dataset based on 2D structure, i.e., test on unseen graphs 217 | train_data, val_data, test_data = [], [], [] 218 | val_size = test_size = (1. 
- train_size) / 2 219 | 220 | # generate train, val, test split indexes 221 | split_indexes = list(range(tot_mol_size)) 222 | random.shuffle(split_indexes) 223 | index2split = {} 224 | for i in range(0, int(len(split_indexes) * train_size)): 225 | index2split[split_indexes[i]] = 'train' 226 | for i in range(int(len(split_indexes) * train_size), int(len(split_indexes) * (train_size + val_size))): 227 | index2split[split_indexes[i]] = 'val' 228 | for i in range(int(len(split_indexes) * (train_size + val_size)), len(split_indexes)): 229 | index2split[split_indexes[i]] = 'test' 230 | 231 | 232 | num_mols = np.zeros(4, dtype=int) # (tot, train, val, test) 233 | num_confs = np.zeros(4, dtype=int) # (tot, train, val, test) 234 | 235 | 236 | bad_case = 0 237 | 238 | for i in tqdm(range(len(pickle_path_list))): 239 | 240 | with open(os.path.join(base_path, pickle_path_list[i]), 'rb') as fin: 241 | mol = pickle.load(fin) 242 | 243 | if mol.get('uniqueconfs') > len(mol.get('conformers')): 244 | bad_case += 1 245 | continue 246 | if mol.get('uniqueconfs') <= 0: 247 | bad_case += 1 248 | continue 249 | 250 | datas = [] 251 | smiles = mol.get('smiles') 252 | 253 | if mol.get('uniqueconfs') == conf_per_mol: 254 | # use all confs 255 | conf_ids = np.arange(mol.get('uniqueconfs')) 256 | else: 257 | # filter the most probable 'conf_per_mol' confs 258 | all_weights = np.array([_.get('boltzmannweight', -1.) 
for _ in mol.get('conformers')]) 259 | descend_conf_id = (-all_weights).argsort() 260 | conf_ids = descend_conf_id[:conf_per_mol] 261 | 262 | for conf_id in conf_ids: 263 | conf_meta = mol.get('conformers')[conf_id] 264 | data = rdmol_to_data(conf_meta.get('rd_mol'), smiles=smiles) 265 | labels = { 266 | 'totalenergy': conf_meta['totalenergy'], 267 | 'boltzmannweight': conf_meta['boltzmannweight'], 268 | } 269 | for k, v in labels.items(): 270 | data[k] = torch.tensor([v], dtype=torch.float32) 271 | data['idx'] = torch.tensor([i], dtype=torch.long) 272 | datas.append(data) 273 | assert len(datas) == conf_per_mol 274 | 275 | if index2split[i] == 'train': 276 | train_data.extend(datas) 277 | num_mols += [1, 1, 0, 0] 278 | num_confs += [len(datas), len(datas), 0, 0] 279 | elif index2split[i] == 'val': 280 | val_data.extend(datas) 281 | num_mols += [1, 0, 1, 0] 282 | num_confs += [len(datas), 0, len(datas), 0] 283 | elif index2split[i] == 'test': 284 | test_data.extend(datas) 285 | num_mols += [1, 0, 0, 1] 286 | num_confs += [len(datas), 0, 0, len(datas)] 287 | else: 288 | raise ValueError('unknown index2split value.') 289 | 290 | print('post-filter: find %d molecules with %d confs' % (num_mols[0], num_confs[0])) 291 | print('train size: %d molecules with %d confs' % (num_mols[1], num_confs[1])) 292 | print('val size: %d molecules with %d confs' % (num_mols[2], num_confs[2])) 293 | print('test size: %d molecules with %d confs' % (num_mols[3], num_confs[3])) 294 | print('bad case: %d' % bad_case) 295 | print('done!') 296 | 297 | return train_data, val_data, test_data, index2split 298 | 299 | 300 | def get_GEOM_testset(base_path, dataset_name, block, tot_mol_size=200, seed=None, confmin=50, confmax=500): 301 | """ 302 | base_path: directory that contains GEOM dataset 303 | dataset_name: dataset name, should be in [qm9, drugs] 304 | block: block the training and validation set 305 | tot_mol_size: size of the test set 306 | seed: rand seed for RNG 307 | confmin and 
confmax: range of the number of conformations 308 | """ 309 | 310 | #block smiles in train / val 311 | block_smiles = defaultdict(int) 312 | for block_ in block: 313 | for i in range(len(block_)): 314 | block_smiles[block_[i].smiles] = 1 315 | 316 | # set random seed 317 | if seed is None: 318 | seed = 2021 319 | np.random.seed(seed) 320 | random.seed(seed) 321 | 322 | 323 | # read summary file 324 | assert dataset_name in ['qm9', 'drugs'] 325 | summary_path = os.path.join(base_path, 'summary_%s.json' % dataset_name) 326 | with open(summary_path, 'r') as f: 327 | summ = json.load(f) 328 | 329 | # filter valid pickle path 330 | smiles_list = [] 331 | pickle_path_list = [] 332 | num_mols = 0 333 | num_confs = 0 334 | for smiles, meta_mol in tqdm(summ.items()): 335 | u_conf = meta_mol.get('uniqueconfs') 336 | if u_conf is None: 337 | continue 338 | pickle_path = meta_mol.get('pickle_path') 339 | if pickle_path is None: 340 | continue 341 | if u_conf < confmin or u_conf > confmax: 342 | continue 343 | if block_smiles[smiles] == 1: 344 | continue 345 | 346 | num_mols += 1 347 | num_confs += u_conf 348 | smiles_list.append(smiles) 349 | pickle_path_list.append(pickle_path) 350 | 351 | 352 | random.shuffle(pickle_path_list) 353 | assert len(pickle_path_list) >= tot_mol_size, 'the length of all available mols is %d, which is smaller than tot mol size %d' % (len(pickle_path_list), tot_mol_size) 354 | 355 | pickle_path_list = pickle_path_list[:tot_mol_size] 356 | 357 | print('pre-filter: find %d molecules with %d confs' % (num_mols, num_confs)) 358 | 359 | 360 | bad_case = 0 361 | all_test_data = [] 362 | num_valid_mol = 0 363 | num_valid_conf = 0 364 | 365 | for i in tqdm(range(len(pickle_path_list))): 366 | 367 | with open(os.path.join(base_path, pickle_path_list[i]), 'rb') as fin: 368 | mol = pickle.load(fin) 369 | 370 | if mol.get('uniqueconfs') > len(mol.get('conformers')): 371 | bad_case += 1 372 | continue 373 | if mol.get('uniqueconfs') <= 0: 374 | bad_case += 1 375 
class GEOMDataset(Dataset):
    """In-memory dataset over a list of conformation ``Data`` objects.

    Each sample is cloned and re-centered (centroid moved to the origin)
    before the optional ``transform`` is applied, so the cached graphs are
    never mutated.
    """

    def __init__(self, data=None, transform=None):
        super().__init__()
        self.data = data
        self.transform = transform
        # Vocabulary of atom / edge types observed across the whole dataset.
        self.atom_types = self._atom_types()
        self.edge_types = self._edge_types()

    def __getitem__(self, idx):
        sample = self.data[idx].clone()
        # Zero-center the conformation.
        sample.pos = sample.pos - sample.pos.mean(dim=0)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        return len(self.data)

    def _atom_types(self):
        """Sorted list of every atom type present in the dataset."""
        types = set()
        for graph in self.data:
            types.update(graph.atom_type.tolist())
        return sorted(types)

    def _edge_types(self):
        """Sorted list of every edge type present in the dataset."""
        types = set()
        for graph in self.data:
            types.update(graph.edge_type.tolist())
        return sorted(types)
class GEOMDataset_PackedConf(GEOMDataset):
    """GEOM dataset that packs all conformations of one molecule into a single
    data object: the graph structure is stored once, while every conformer's
    zero-centered coordinates are stacked in ``pos_ref``."""

    def __init__(self, data=None, transform=None):
        super(GEOMDataset_PackedConf, self).__init__(data, transform)
        self._pack_data_by_mol()

    def _pack_data_by_mol(self):
        """
        pack confs with same mol into a single data object
        """
        self._packed_data = defaultdict(list)
        # NOTE(review): hasattr is checked on the *container*, not on its
        # elements -- for a plain list this is always False, so grouping
        # always falls back to the smiles key. Confirm this is intended.
        if hasattr(self.data, 'idx'):
            for item in self.data:
                self._packed_data[item.idx.item()].append(item)
        else:
            for item in self.data:
                self._packed_data[item.smiles].append(item)
        print('got %d molecules with %d confs' % (len(self._packed_data), len(self.data)))

        # Save the graph structure of each molecule once, but keep every
        # conformer's centered coordinates stacked in pos_ref.
        new_data = []
        for k, v in self._packed_data.items():
            data = copy.deepcopy(v[0])
            all_pos = []
            for conf in v:
                all_pos.append(conf.pos - conf.pos.mean(dim=0))
            data.pos_ref = torch.cat(all_pos, 0)  # (num_conf*num_node, 3)
            data.num_pos_ref = torch.tensor([len(all_pos)], dtype=torch.long)
            # Per-conformer labels are meaningless on the packed object.
            if hasattr(data, 'totalenergy'):
                del data.totalenergy
            if hasattr(data, 'boltzmannweight'):
                del data.boltzmannweight
            new_data.append(data)
        self.new_data = new_data

    def __getitem__(self, idx):
        """Return a clone of the packed molecule, optionally transformed."""
        data = self.new_data[idx].clone()
        if self.transform is not None:
            data = self.transform(data)
        return data

    def __len__(self):
        return len(self.new_data)
class EquiLayer(MessagePassing):
    """Equivariant aggregation layer: sum-aggregates (optionally activated)
    per-edge vectors into per-node outputs via message passing."""

    def __init__(self, eps: float = 0., train_eps: bool = False,
                 activation="softplus", **kwargs):
        super(EquiLayer, self).__init__(aggr='add', **kwargs)
        self.initial_eps = eps

        # A string resolves to the torch.nn.functional callable; any
        # non-string value (e.g. False) disables the activation.
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = None

        if train_eps:
            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        """Aggregate edge_attr messages over edge_index; x is passed through
        to the propagate machinery as a (source, target) pair."""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # Node and edge feature dimensionalites need to match.
        if isinstance(edge_index, Tensor):
            assert edge_attr is not None
            # assert x[0].size(-1) == edge_attr.size(-1)
        elif isinstance(edge_index, SparseTensor):
            assert x[0].size(-1) == edge_index.size(-1)

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
        return out

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        # Without an activation, the layer forwards the edge features
        # unchanged (the node contribution x_j is deliberately dropped).
        if self.activation:
            return self.activation(x_j + edge_attr)
        else:
            # return x_j + edge_attr
            return edge_attr

    def __repr__(self):
        # Bug fix: the previous implementation formatted self.nn, an
        # attribute EquiLayer never defines (it was copied from GINEConv),
        # so repr() raised AttributeError.
        return '{}()'.format(self.__class__.__name__)
self.hidden_coff_dim), 97 | nn.Softplus(), 98 | nn.Linear(self.hidden_coff_dim, 3)) 99 | ) 100 | 101 | 102 | def coord2basis(self, data): 103 | coord_diff = data.pert_pos[data.edge_index[0]] - data.pert_pos[data.edge_index[1]] 104 | radial = torch.sum((coord_diff)**2, 1).unsqueeze(1) 105 | coord_cross = torch.cross(data.pert_pos[data.edge_index[0]], data.pert_pos[data.edge_index[1]]) 106 | 107 | norm = torch.sqrt(radial) + 1 108 | coord_diff = coord_diff / norm 109 | cross_norm = torch.sqrt(torch.sum((coord_cross)**2, 1).unsqueeze(1)) + 1 110 | coord_cross = coord_cross / cross_norm 111 | 112 | coord_vertical = torch.cross(coord_diff, coord_cross) 113 | 114 | return coord_diff, coord_cross, coord_vertical 115 | 116 | 117 | def forward(self, data, node_attr, edge_attr): 118 | """ 119 | Input: 120 | data: (torch_geometric.data.Data): batched graph 121 | node_attr: node feature tensor with shape (num_node, hidden) 122 | edge_attr: edge feature tensor with shape (num_edge, hidden) 123 | Output: 124 | node_attr 125 | graph feature 126 | """ 127 | 128 | hiddens = [] 129 | conv_input = node_attr # (num_node, hidden) 130 | 131 | for module_idx, convs in enumerate(self.transformers): 132 | for conv_idx, conv in enumerate(convs): 133 | hidden = conv(data.edge_index, conv_input, edge_attr) 134 | if conv_idx < len(convs) - 1 and self.activation is not None: 135 | hidden = self.activation(hidden) 136 | assert hidden.shape == conv_input.shape 137 | if self.short_cut and hidden.shape == conv_input.shape: 138 | hidden += conv_input 139 | 140 | hiddens.append(hidden) 141 | conv_input = hidden 142 | 143 | if self.concat_hidden: 144 | node_feature = torch.cat(hiddens, dim=-1) 145 | else: 146 | node_feature = hiddens[-1] 147 | 148 | h_row, h_col = node_feature[data.edge_index[0]], node_feature[data.edge_index[1]] # (num_edge, hidden) 149 | edge_feature = torch.cat([h_row*h_col, edge_attr], dim=-1) # (num_edge, 2 * hidden) 150 | ## generate gradient 151 | dynamic_coff = 
    def forward(self, data, node_attr, edge_attr):
        """
        Input:
            data: (torch_geometric.data.Data): batched graph
            node_attr: node feature tensor with shape (num_node, hidden)
            edge_attr: edge feature tensor with shape (num_edge, hidden)
        Output:
            node_attr
            graph feature
        """

        hiddens = []
        conv_input = node_attr # (num_node, hidden)

        # Outer loop: one transformer stack + one equivariant head per layer.
        for module_idx, convs in enumerate(self.transformers):
            for conv_idx, conv in enumerate(convs):
                hidden = conv(data.edge_index, conv_input, edge_attr)
                # Activation on every conv of the stack except the last.
                if conv_idx < len(convs) - 1 and self.activation is not None:
                    hidden = self.activation(hidden)
                assert hidden.shape == conv_input.shape
                # Residual connection (shapes always match here).
                if self.short_cut and hidden.shape == conv_input.shape:
                    hidden += conv_input

                hiddens.append(hidden)
                conv_input = hidden

            # NOTE(review): hiddens accumulates across *all* outer iterations,
            # so with concat_hidden the feature width grows per module --
            # confirm this is intended.
            if self.concat_hidden:
                node_feature = torch.cat(hiddens, dim=-1)
            else:
                node_feature = hiddens[-1]

            h_row, h_col = node_feature[data.edge_index[0]], node_feature[data.edge_index[1]] # (num_edge, hidden)
            edge_feature = torch.cat([h_row*h_col, edge_attr], dim=-1) # (num_edge, 2 * hidden)
            ## generate gradient
            # Three scalar coefficients per edge weight the three frame axes.
            dynamic_coff = self.dynamic_mlp_modules[module_idx](edge_feature)
            coord_diff, coord_cross, coord_vertical = self.coord2basis(data)
            basis_mix = dynamic_coff[:, :1] * coord_diff + dynamic_coff[:, 1:2] * coord_cross + dynamic_coff[:, 2:3] * coord_vertical

            # Sum the equivariant outputs of all layers into one gradient field.
            if module_idx == 0:
                gradient = self.equi_modules[module_idx](node_feature, data.edge_index, basis_mix)
            else:
                gradient += self.equi_modules[module_idx](node_feature, data.edge_index, basis_mix)

        return {
            "node_feature": node_feature,
            "gradient": gradient
        }
class MultiLayerPerceptron(nn.Module):
    """
    Multi-layer Perceptron.

    Note there is no activation or dropout in the last layer.

    Parameters:
        input_dim (int): input dimension
        hidden_dim (list of int): hidden dimensions
        activation (str or function, optional): activation function
        dropout (float, optional): dropout rate
    """

    def __init__(self, input_dim, hidden_dims, activation="relu", dropout=0):
        super(MultiLayerPerceptron, self).__init__()

        self.dims = [input_dim] + hidden_dims
        # Only string names are resolved; a callable is silently ignored
        # (matches the rest of this code base).
        self.activation = getattr(F, activation) if isinstance(activation, str) else None
        self.dropout = nn.Dropout(dropout) if dropout else None

        self.layers = nn.ModuleList(
            nn.Linear(in_dim, out_dim)
            for in_dim, out_dim in zip(self.dims[:-1], self.dims[1:])
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialize every weight matrix and zero every bias."""
        for layer in self.layers:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0.)

    def forward(self, input):
        """Apply the layers in order; activation/dropout everywhere but last."""
        x = input
        last = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:
                if self.activation:
                    x = self.activation(x)
                if self.dropout:
                    x = self.dropout(x)
        return x
class Transformer_layer(nn.Module):
    """Graph transformer block: edge-aware multi-head attention followed by a
    feed-forward network, each merged through a residual branch.

    NOTE(review): the LayerNorm is applied to the branch output and the raw
    input is added afterwards (x + norm(f(x))) -- nonstandard, but preserved.
    """

    def __init__(self, n_head, hidden_dim, dropout=0.2, activation="softplus"):
        super(Transformer_layer, self).__init__()

        # Resolved but unused in forward(); kept for interface parity.
        self.activation = getattr(F, activation) if isinstance(activation, str) else None

        assert hidden_dim % n_head == 0
        self.MHA = TransformerConv(
            in_channels=hidden_dim,
            out_channels=int(hidden_dim // n_head),
            heads=n_head,
            dropout=dropout,
            edge_dim=hidden_dim,
        )
        self.FFN = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, edge_index, node_attr, edge_attr):
        """Return updated node features; edge features only bias attention."""
        attn_out = self.MHA(node_attr, edge_index, edge_attr)
        node_attr = node_attr + self.norm1(attn_out)
        ffn_out = self.FFN(node_attr)
        node_attr = node_attr + self.norm2(ffn_out)
        return node_attr
class GINEConv(MessagePassing):
    """GIN-E convolution: activated (node + edge) messages are sum-aggregated,
    combined with a (1 + eps)-scaled self term, and passed through `nn`."""

    def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
                 activation="softplus", **kwargs):
        super(GINEConv, self).__init__(aggr='add', **kwargs)
        self.nn = nn
        self.initial_eps = eps

        # String names resolve to torch.nn.functional; anything else disables
        # the message activation.
        self.activation = getattr(F, activation) if isinstance(activation, str) else None

        if train_eps:
            self.eps = torch.nn.Parameter(torch.Tensor([eps]))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        """Run one round of message passing and apply the output network."""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # Node and edge feature dimensionalites need to match.
        if isinstance(edge_index, Tensor):
            assert edge_attr is not None
            assert x[0].size(-1) == edge_attr.size(-1)
        elif isinstance(edge_index, SparseTensor):
            assert x[0].size(-1) == edge_index.size(-1)

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        x_r = x[1]
        if x_r is not None:
            aggregated += (1 + self.eps) * x_r

        return self.nn(aggregated)

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        return self.activation(x_j + edge_attr) if self.activation else x_j + edge_attr

    def __repr__(self):
        return '{}(nn={})'.format(self.__class__.__name__, self.nn)
class GraphIsomorphismNetwork(torch.nn.Module):
    """Stack of GINEConv layers followed by a sum/mean graph readout."""

    def __init__(self, hidden_dim, num_convs=3, activation="softplus", readout="sum", short_cut=False, concat_hidden=False):
        super(GraphIsomorphismNetwork, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_convs = num_convs
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden

        self.activation = getattr(F, activation) if isinstance(activation, str) else None

        self.convs = nn.ModuleList()
        for _ in range(self.num_convs):
            self.convs.append(GINEConv(MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], \
                activation=activation), activation=activation))

        if readout == "sum":
            self.readout = SumReadout()
        elif readout == "mean":
            self.readout = MeanReadout()
        else:
            raise ValueError("Unknown readout `%s`" % readout)

    def forward(self, data, node_attr, edge_attr):
        """
        Input:
            data: (torch_geometric.data.Data): batched graph
            node_attr: node feature tensor with shape (num_node, hidden)
            edge_attr: edge feature tensor with shape (num_edge, hidden)
        Output:
            node_attr
            graph feature
        """
        hiddens = []
        layer_input = node_attr  # (num_node, hidden)
        last = len(self.convs) - 1

        for layer_idx, conv in enumerate(self.convs):
            hidden = conv(layer_input, data.edge_index, edge_attr)
            # Activation on every layer except the last.
            if layer_idx < last and self.activation is not None:
                hidden = self.activation(hidden)
            assert hidden.shape == layer_input.shape
            # Residual connection (the shape check always holds here).
            if self.short_cut and hidden.shape == layer_input.shape:
                hidden += layer_input
            hiddens.append(hidden)
            layer_input = hidden

        node_feature = torch.cat(hiddens, dim=-1) if self.concat_hidden else hiddens[-1]
        graph_feature = self.readout(data, node_feature)

        return {
            "graph_feature": graph_feature,
            "node_feature": node_feature
        }
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels.

    Projects a scalar input onto `embedding_size` fixed random frequencies
    and returns concatenated [sin, cos] features of width 2 * embedding_size.
    """

    def __init__(self, embedding_size=256, scale=1.0):
        super().__init__()
        # Frozen random frequencies: a non-trainable Parameter so they travel
        # with state_dict and .to(device).
        self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

    def forward(self, x):
        angles = x * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    @torch.no_grad()
    # extend the edge on the fly, second order: angle, third order: dihedral
    def extend_graph(self, data: Data, order=3):
        """Augment the molecular graph with virtual higher-order edges
        (graph-distance 2 = angle, 3 = dihedral) so the score network sees
        beyond covalent bonds. Virtual edges get types offset past the real
        bond-type vocabulary. Mutates `data` in place and returns it."""

        def binarize(x):
            # 1 where an entry is positive, else 0 (keeps matmul powers boolean).
            return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))

        def get_higher_order_adj_matrix(adj, order):
            """
            Args:
                adj:        (N, N)
                type_mat:   (N, N)
            """
            # adj_mats[k]: nodes reachable within k hops (self-loops included).
            adj_mats = [torch.eye(adj.size(0), dtype=torch.long, device=adj.device), \
                        binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device))]

            for i in range(2, order+1):
                adj_mats.append(binarize(adj_mats[i-1] @ adj_mats[1]))
            order_mat = torch.zeros_like(adj)

            # Entry (u, v) = shortest-path order in 1..order, 0 otherwise.
            for i in range(1, order+1):
                order_mat += (adj_mats[i] - adj_mats[i-1]) * i

            return order_mat

        num_types = len(utils.BOND_TYPES)

        N = data.num_nodes
        adj = to_dense_adj(data.edge_index).squeeze(0)
        adj_order = get_higher_order_adj_matrix(adj, order)  # (N, N)

        type_mat = to_dense_adj(data.edge_index, edge_attr=data.edge_type).squeeze(0)  # (N, N)
        # Virtual edges get type (num_types + order - 1); real edges keep theirs.
        type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))
        assert (type_mat * type_highorder == 0).all()
        type_new = type_mat + type_highorder

        new_edge_index, new_edge_type = dense_to_sparse(type_new)
        _, edge_order = dense_to_sparse(adj_order)

        data.bond_edge_index = data.edge_index  # Save original edges
        data.edge_index, data.edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N)  # modify data
        edge_index_1, data.edge_order = coalesce(new_edge_index, edge_order.long(), N, N)  # modify data
        data.is_bond = (data.edge_type < num_types)
        # Both coalesce calls must agree on the resulting edge ordering.
        assert (data.edge_index == edge_index_1).all()

        return data

    # @torch.no_grad()
    def get_distance(self, data: Data):
        """Store per-edge Euclidean distances on data.edge_length; returns data."""
        pos = data.pos
        row, col = data.edge_index
        d = (pos[row] - pos[col]).norm(dim=-1).unsqueeze(-1)  # (num_edge, 1)
        data.edge_length = d
        return data

    # @torch.no_grad()
    def get_perturb_distance(self, data: Data, p_pos):
        """Edge distances recomputed from perturbed positions p_pos, (num_edge, 1)."""
        pos = p_pos
        row, col = data.edge_index
        d = (pos[row] - pos[col]).norm(dim=-1).unsqueeze(-1)  # (num_edge, 1)
        return d

    def get_pred_distance(self, data: Data, p_pos):
        """Differentiable edge distances from p_pos; the +0.0001 keeps the
        sqrt gradient finite at zero length.
        NOTE(review): returns shape (num_edge,), unlike get_perturb_distance
        which returns (num_edge, 1) -- confirm callers expect this."""
        pos = p_pos
        row, col = data.edge_index
        d = torch.sqrt(torch.sum((pos[row] - pos[col])**2, dim=-1) + 0.0001)
        # d = (pos[row] - pos[col]).norm(dim=-1).unsqueeze(-1)  # (num_edge, 1)
        return d

    def coord2basis(self, data):
        """Per-edge local frame from perturbed coordinates: soft-normalized
        relative direction, soft-normalized cross product of the endpoint
        positions, and their cross product (+1 avoids division by zero)."""
        coord_diff = data.pert_pos[data.edge_index[0]] - data.pert_pos[data.edge_index[1]]
        radial = torch.sum((coord_diff)**2, 1).unsqueeze(1)
        coord_cross = torch.cross(data.pert_pos[data.edge_index[0]], data.pert_pos[data.edge_index[1]])

        norm = torch.sqrt(radial) + 1
        coord_diff = coord_diff / norm
        cross_norm = torch.sqrt(torch.sum((coord_cross)**2, 1).unsqueeze(1)) + 1
        coord_cross = coord_cross / cross_norm

        coord_vertical = torch.cross(coord_diff, coord_cross)

        return coord_diff, coord_cross, coord_vertical
    @torch.no_grad()
    def get_angle(self, data: Data, p_pos):
        """Per-edge angle between the endpoint position vectors: cosine from
        the dot product of the normalized positions, sine from the
        Pythagorean identity. Returns (num_edge, 2) = [cos, sin].
        NOTE(review): 1 - cos^2 can be slightly negative numerically, which
        would make sqrt produce NaN -- verify inputs."""
        pos = p_pos
        row, col = data.edge_index
        pos_normal = pos.clone().detach()
        pos_normal_norm = pos_normal.norm(dim=-1).unsqueeze(-1)
        pos_normal = pos_normal / (pos_normal_norm + 1e-5)
        cos_theta = torch.sum(pos_normal[row] * pos_normal[col], dim=-1, keepdim=True)
        sin_theta = torch.sqrt(1 - cos_theta ** 2)
        node_angles = torch.cat([cos_theta, sin_theta], dim=-1)
        return node_angles

    @torch.no_grad()
    def get_score(self, data: Data, d, sigma):
        """
        Input:
            data: torch geometric batched data object
            d: edge distance, shape (num_edge, 1)
            sigma: noise level, tensor (,)
        Output:
            log-likelihood gradient of distance, tensor with shape (num_edge, 1)
        """
        # generate common features
        node_attr = self.node_emb(data.atom_type)  # (num_node, hidden)
        edge_attr = self.edge_emb(data.edge_type)  # (num_edge, hidden)
        d_emb = self.dist_gaussian_fourier(d)
        d_emb = self.input_mlp(d_emb)  # (num_edge, hidden)
        edge_attr = d_emb * edge_attr  # (num_edge, hidden)

        # construct geometric features
        row, col = data.edge_index[0], data.edge_index[1]  # check if roe and col is right?
        coord_diff, coord_cross, coord_vertical = self.coord2basis(data)  # [E, 3]
        edge_basis = torch.cat([coord_diff.unsqueeze(1), coord_cross.unsqueeze(1), coord_vertical.unsqueeze(1)], dim=1)  # [E, 3]
        r_i, r_j = data.pert_pos[row], data.pert_pos[col]  # [E, 3]
        # Project each endpoint position onto the edge-local frame:
        # [E, 3, 3] x [E, 3, 1]
        coff_i = torch.matmul(edge_basis, r_i.unsqueeze(-1)).squeeze(-1)  # [E, 3]
        coff_j = torch.matmul(edge_basis, r_j.unsqueeze(-1)).squeeze(-1)  # [E, 3]
        coff_mul = coff_i * coff_j  # [E, 3]
        coff_i_norm = coff_i.norm(dim=-1, keepdim=True)
        coff_j_norm = coff_j.norm(dim=-1, keepdim=True)
        pesudo_cos = coff_mul.sum(dim=-1, keepdim=True) / (coff_i_norm + 1e-5) / (coff_j_norm + 1e-5)
        pesudo_sin = torch.sqrt(1 - pesudo_cos**2)
        psudo_angle = torch.cat([pesudo_sin, pesudo_cos], dim=-1)
        embed_i = self.get_embedding(coff_i)  # [E, C]
        embed_j = self.get_embedding(coff_j)  # [E, C]
        edge_embed = torch.cat([psudo_angle, embed_i, embed_j], dim=-1)
        edge_embed = self.project(edge_embed)

        # Invariant distance features enriched with frame-projection embeddings.
        edge_attr = edge_attr + edge_embed

        output = self.model(data, node_attr, edge_attr)
        scores = output["gradient"] * (1. / sigma)  # f_theta_sigma(x) = f_theta(x) / sigma, (num_edge, 1)
        return scores
    def get_embedding(self, coff_index):
        """Fourier-embed the 1st and 3rd frame coefficients and mix them with
        a linear layer. Returns [E, hidden]."""
        coff_embeds = []
        for i in [0, 2]:  # if i=1, then x=0
            coff_embeds.append(self.coff_gaussian_fourier(coff_index[:, i:i+1]))  # [E, 2C]
        coff_embeds = torch.cat(coff_embeds, dim=-1)  # [E, 6C]
        coff_embeds = self.coff_mlp(coff_embeds)

        return coff_embeds

    def forward(self, data):
        """
        Input:
            data: torch geometric batched data object
        Output:
            loss
        """
        # a workaround to get the current device, we assume all tensors in a model are on the same device.
        self.device = self.sigmas.device
        data = self.extend_graph(data, self.order)
        ## enable input gradient
        input_x = data.pos
        input_x.requires_grad = True

        data = self.get_distance(data)

        assert data.edge_index.size(1) == data.edge_length.size(0)
        node2graph = data.batch
        edge2graph = node2graph[data.edge_index[0]]

        # sample noise level
        noise_level = torch.randint(0, self.sigmas.size(0), (data.num_graphs,), device=self.device)  # (num_graph)
        used_sigmas = self.sigmas[noise_level]  # (num_graph)
        used_sigmas = used_sigmas[node2graph].unsqueeze(-1)  # (num_nodes, 1)

        if self.noise_type == 'rand':
            coord_noise = torch.randn_like(data.pos)
        else:
            raise NotImplementedError('noise type must in [distance_symm, distance_rand]')

        assert coord_noise.shape == data.pos.shape
        perturbed_pos = data.pos + coord_noise * used_sigmas
        data.pert_pos = perturbed_pos
        perturbed_d = self.get_perturb_distance(data, perturbed_pos)
        # Denoising score-matching target: scaled displacement back to the
        # clean coordinates.
        target = -1 / (used_sigmas ** 2) * (perturbed_pos - data.pos)

        # generate common features
        node_attr = self.node_emb(data.atom_type)  # (num_node, hidden)
        edge_attr = self.edge_emb(data.edge_type)  # (num_edge, hidden)
        d_emb = self.dist_gaussian_fourier(perturbed_d)
        d_emb = self.input_mlp(d_emb)  # (num_edge, hidden)
        edge_attr = d_emb * edge_attr  # (num_edge, hidden)

        # construct geometric features (same pipeline as get_score)
        row, col = data.edge_index[0], data.edge_index[1]  # check if roe and col is right?
        coord_diff, coord_cross, coord_vertical = self.coord2basis(data)  # [E, 3]
        edge_basis = torch.cat([coord_diff.unsqueeze(1), coord_cross.unsqueeze(1), coord_vertical.unsqueeze(1)], dim=1)  # [E, 3]
        r_i, r_j = data.pert_pos[row], data.pert_pos[col]  # [E, 3]
        coff_i = torch.matmul(edge_basis, r_i.unsqueeze(-1)).squeeze(-1)  # [E, 3]
        coff_j = torch.matmul(edge_basis, r_j.unsqueeze(-1)).squeeze(-1)  # [E, 3]
        coff_mul = coff_i * coff_j  # [E, 3]
        coff_i_norm = coff_i.norm(dim=-1, keepdim=True)
        coff_j_norm = coff_j.norm(dim=-1, keepdim=True)
        pesudo_cos = coff_mul.sum(dim=-1, keepdim=True) / (coff_i_norm + 1e-5) / (coff_j_norm + 1e-5)
        pesudo_sin = torch.sqrt(1 - pesudo_cos**2)
        psudo_angle = torch.cat([pesudo_sin, pesudo_cos], dim=-1)
        embed_i = self.get_embedding(coff_i)  # [E, C]
        embed_j = self.get_embedding(coff_j)  # [E, C]
        edge_embed = torch.cat([psudo_angle, embed_i, embed_j], dim=-1)
        edge_embed = self.project(edge_embed)
        edge_attr = edge_attr + edge_embed

        # estimate scores
        output = self.model(data, node_attr, edge_attr)
        scores = output["gradient"] * (1. / used_sigmas)
        loss_pos = 0.5 * torch.sum((scores - target) ** 2, -1) * (used_sigmas.squeeze(-1) ** self.anneal_power)  # (num_edge)
        loss_pos = scatter_mean(loss_pos, node2graph)  # (num_graph)

        loss_dict = {
            'position': loss_pos.mean(),
            'distance': torch.Tensor([0]).to(loss_pos.device),
        }
        return loss_dict
    def init_wandb(self, project, name):
        """Start a Weights & Biases run under the ClofNet entity."""
        wandb.init(project=project, name=name, entity="ClofNet")

    def save(self, checkpoint, epoch=None, var_list={}):
        """Serialize model/optimizer/scheduler state (plus any extra entries
        in var_list) to `<checkpoint>/checkpoint[<epoch>]`.

        NOTE(review): var_list={} is a mutable default argument; it is only
        read here, but confirm no caller mutates it."""
        state = {
            **var_list,
            "model": self._model.state_dict(),
            "optimizer": self._optimizer.state_dict(),
            "scheduler": self._scheduler.state_dict(),
            "config": self.config,
        }
        epoch = str(epoch) if epoch is not None else ""
        checkpoint = os.path.join(checkpoint, "checkpoint%s" % epoch)
        torch.save(state, checkpoint)

    def load(self, checkpoint, epoch=None, load_optimizer=False, load_scheduler=False):
        """Restore state written by `save`; optionally also restore the
        optimizer (moving its tensors onto the active device) and scheduler.
        Resumes bookkeeping via best_loss / start_epoch."""
        epoch = str(epoch) if epoch is not None else ""
        checkpoint = os.path.join(checkpoint, "checkpoint%s" % epoch)
        print("Load checkpoint from %s" % checkpoint)

        state = torch.load(checkpoint, map_location=self.device)
        self._model.load_state_dict(state["model"])
        # self._model.load_state_dict(state["model"], strict=False)
        self.best_loss = state["best_loss"]
        self.start_epoch = state["cur_epoch"] + 1

        if load_optimizer:
            self._optimizer.load_state_dict(state["optimizer"])
            if self.device.type == "cuda":
                # Optimizer state tensors come back on CPU; move them to the
                # training GPU so step() does not mix devices.
                for state in self._optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda(self.device)

        if load_scheduler:
            self._scheduler.load_state_dict(state["scheduler"])

    @torch.no_grad()
    def evaluate(self, split, verbose=0):
        """
        Evaluate the model.
        Parameters:
            split (str): split to evaluate. Can be ``train``, ``val`` or ``test``.
        """
        if split not in ["train", "val", "test"]:
            raise ValueError("split should be either train, val, or test.")

        test_set = getattr(self, "%s_set" % split)
        dataloader = DataLoader(
            test_set,
            batch_size=self.config.train.batch_size,
            shuffle=False,
            num_workers=self.config.train.num_workers,
        )
        model = self._model
        model.eval()
        # code here
        eval_start = time()
        eval_losses = []
        for batch in dataloader:
            if self.device.type == "cuda":
                batch = batch.to(self.device)

            loss_dict = model(batch)
            # Only the position term is tracked for model selection.
            eval_losses.append(loss_dict["position"].item())
        average_loss = sum(eval_losses) / len(eval_losses)

        if verbose:
            print(
                "Evaluate %s Position Loss: %.5f | Time: %.5f"
                % (split, average_loss, time() - eval_start)
            )
        return average_loss
loss_dict["distance"] * p2 #+ loss_dict["curl"] * p3 157 | if not loss.requires_grad: 158 | raise RuntimeError("loss doesn't require grad") 159 | self._optimizer.zero_grad() 160 | loss.backward() 161 | self._optimizer.step() 162 | batch_losses.append(loss.item()) 163 | batch_losses_pos.append(loss_dict["position"].item()) 164 | batch_losses_dist.append(loss_dict["distance"].item()) 165 | # batch_losses_curl.append(loss_dict["curl"].item()) 166 | if batch_cnt % self.config.train.log_interval == 0 or ( 167 | epoch == 0 and batch_cnt <= 10 168 | ): 169 | # if batch_cnt % self.config.train.log_interval == 0: 170 | print( 171 | "Epoch: %d | Step: %d | loss: %.3f | loss_pos: %.3f | loss_dist: %.3f | Lr: %.5f" 172 | % ( 173 | epoch + start_epoch, 174 | batch_cnt, 175 | batch_losses[-1], 176 | batch_losses_pos[-1], 177 | batch_losses_dist[-1], 178 | self._optimizer.param_groups[0]["lr"], 179 | ) 180 | ) 181 | 182 | train_losses.append(sum(batch_losses) / len(batch_losses)) 183 | train_losses_pos.append(sum(batch_losses_pos) / len(batch_losses_pos)) 184 | train_losses_dist.append(sum(batch_losses_dist) / len(batch_losses_dist)) 185 | 186 | if verbose: 187 | print( 188 | "Epoch: %d | Train Loss: %.5f | Passed Time: %.3f h" 189 | % ( 190 | epoch + start_epoch, 191 | train_losses[-1], 192 | (time() - train_start_time) / 3600, 193 | ) 194 | ) 195 | 196 | # evaluate 197 | if self.config.train.eval: 198 | average_eval_loss = self.evaluate("val", verbose=1) 199 | val_losses.append(average_eval_loss) 200 | else: 201 | # use train loss as surrogate loss 202 | average_eval_loss = train_losses[-1] 203 | val_losses.append(train_losses[-1]) 204 | 205 | if self.config.train.wandb.Enable: 206 | wandb.log( 207 | { 208 | "Train Loss": train_losses[-1], 209 | "Train Loss Pos": train_losses_pos[-1], 210 | "Train Loss Dist": train_losses_dist[-1], 211 | "Validate Loss Pos": average_eval_loss, 212 | "LR": self._optimizer.param_groups[0]["lr"], 213 | }, 214 | step=epoch, 215 | commit=False, 216 
| ) 217 | 218 | if self.config.train.scheduler.type == "plateau": 219 | self._scheduler.step(average_eval_loss) 220 | else: 221 | self._scheduler.step() 222 | 223 | if val_losses[-1] < best_loss: 224 | best_loss = val_losses[-1] 225 | if self.config.train.save: 226 | val_list = { 227 | "cur_epoch": epoch + start_epoch, 228 | "best_loss": best_loss, 229 | } 230 | self.save( 231 | self.config.train.save_path, epoch + start_epoch, val_list 232 | ) 233 | self.best_loss = best_loss 234 | self.start_epoch = start_epoch + num_epochs 235 | print("optimization finished.") 236 | print("Total time elapsed: %.5fs" % (time() - train_start)) 237 | 238 | @torch.no_grad() 239 | def convert_score_d(self, score_d, pos, edge_index, edge_length): 240 | dd_dr = (1.0 / edge_length) * ( 241 | pos[edge_index[0]] - pos[edge_index[1]] 242 | ) # (num_edge, 3) 243 | score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0) 244 | 245 | return score_pos 246 | 247 | @torch.no_grad() 248 | def position_Langevin_Dynamics( 249 | self, 250 | data, 251 | pos_init, 252 | scorenet, 253 | sigmas, 254 | n_steps_each=100, 255 | step_lr=0.00002, 256 | clip=1000, 257 | min_sigma=0, 258 | ): 259 | """ 260 | # 1. initial pos: (N, 3) 261 | # 2. get d: (num_edge, 1) 262 | # 3. get score of d: score_d = self.get_grad(d).view(-1) (num_edge) 263 | # 4. get score of pos: 264 | # dd_dr = (1/d) * (pos[edge_index[0]] - pos[edge_index[1]]) (num_edge, 3) 265 | # edge2node = edge_index[0] (num_edge) 266 | # score_pos = scatter_add(dd_dr * score_d, edge2node) (num_node, 3) 267 | # 5. 
update pos: 268 | # pos = pos + step_size * score_pos + noise 269 | """ 270 | scorenet.eval() 271 | pos_vecs = [] 272 | pos = pos_init 273 | cnt_sigma = 0 274 | for i, sigma in tqdm( 275 | enumerate(sigmas), total=sigmas.size(0), desc="Sampling positions" 276 | ): 277 | if sigma < min_sigma: 278 | break 279 | cnt_sigma += 1 280 | step_size = step_lr * (sigma / sigmas[-1]) ** 2 281 | for step in range(n_steps_each): 282 | data.pert_pos = pos 283 | d = utils.get_d_from_pos(pos, data.edge_index).unsqueeze( 284 | -1 285 | ) # (num_edge, 1) 286 | 287 | noise = torch.randn_like(pos) * torch.sqrt(step_size * 2) 288 | score_pos = scorenet.get_score(data, d, sigma) # (num_edge, 1) 289 | # score_pos = self.convert_score_d(score_d, pos, data.edge_index, d) 290 | score_pos = utils.clip_norm(score_pos, limit=clip) 291 | if score_pos.max().item() > 100: 292 | dd = score_pos.max().item() 293 | pos = pos + step_size * score_pos + noise # (num_node, 3) 294 | pos_vecs.append(pos) 295 | 296 | pos_vecs = torch.stack(pos_vecs, dim=0).view( 297 | cnt_sigma, n_steps_each, -1, 3 298 | ) # (sigams, 100, num_node, 3) 299 | 300 | return data, pos_vecs 301 | 302 | def EquiGF_generator(self, data, config, pos_init=None): 303 | 304 | """ 305 | The ConfGF generator that generates conformations using the score of atomic coordinates 306 | Return: 307 | The generated conformation (pos_gen) 308 | Distance of the generated conformation (d_recover) 309 | """ 310 | 311 | if pos_init is None: 312 | pos_init = torch.randn(data.num_nodes, 3).to(data.pos) 313 | 314 | data, pos_traj = self.position_Langevin_Dynamics( 315 | data, 316 | pos_init, 317 | self._model, 318 | self._model.sigmas.data.clone(), 319 | n_steps_each=config.steps_pos, 320 | step_lr=config.step_lr_pos, 321 | clip=config.clip, 322 | min_sigma=config.min_sigma, 323 | ) 324 | pos_gen = pos_traj[-1, -1] # (num_node, 3) fetch the lastest pos 325 | d_recover = utils.get_d_from_pos(pos_gen, data.edge_index) # (num_edges) 326 | 327 | data.pos_gen 
= pos_gen.to(data.pos) 328 | data.d_recover = d_recover.view(-1, 1).to(data.edge_length) 329 | return pos_gen, d_recover.view(-1), data, pos_traj 330 | 331 | @torch.no_grad() 332 | def position_pc_generation( 333 | self, 334 | data, 335 | pos_init, 336 | scorenet, 337 | sigmas, 338 | n_steps_each=100, 339 | step_lr=0.00002, 340 | clip=1000, 341 | min_sigma=0, 342 | ): 343 | """ 344 | # 1. initial pos: (N, 3) 345 | # 2. get d: (num_edge, 1) 346 | # 3. get score of d: score_d = self.get_grad(d).view(-1) (num_edge) 347 | # 4. get score of pos: 348 | # dd_dr = (1/d) * (pos[edge_index[0]] - pos[edge_index[1]]) (num_edge, 3) 349 | # edge2node = edge_index[0] (num_edge) 350 | # score_pos = scatter_add(dd_dr * score_d, edge2node) (num_node, 3) 351 | # 5. update pos: 352 | # pos = pos + step_size * score_pos + noise 353 | """ 354 | scorenet.eval() 355 | pos_vecs = [] 356 | pos = pos_init 357 | 358 | cnt_sigma = 0 359 | for i, sigma in tqdm( 360 | enumerate(sigmas), total=sigmas.size(0), desc="Sampling positions" 361 | ): 362 | if sigma < min_sigma: 363 | break 364 | step_size = step_lr * (sigma / sigmas[-1]) ** 2 365 | cnt_sigma += 1 366 | # corrector 367 | for step in range(n_steps_each): 368 | data.pert_pos = pos 369 | d = utils.get_d_from_pos(pos, data.edge_index).unsqueeze( 370 | -1 371 | ) # (num_edge, 1) 372 | 373 | noise = torch.randn_like(pos) * torch.sqrt(step_size * 2) 374 | score_pos = scorenet.get_score(data, d, sigma) # (num_edge, 1) 375 | score_pos = utils.clip_norm(score_pos, limit=clip) 376 | if score_pos.max().item() > 100: 377 | dd = score_pos.max().item() 378 | pos = pos + step_size * score_pos + noise # (num_node, 3) 379 | 380 | # predictor 381 | if cnt_sigma < sigmas.size(0): 382 | data.pert_pos = pos 383 | d = utils.get_d_from_pos(pos, data.edge_index).unsqueeze(-1) # (num_edge, 1) 384 | score_pos = scorenet.get_score(data, d, sigma) 385 | vec_sigma = torch.ones(pos.shape[0], device=pos.device) * sigma 386 | vec_adjacent_sigma = 
torch.ones(pos.shape[0], device=pos.device) * sigmas[cnt_sigma] 387 | f = torch.zeros_like(pos) 388 | G = torch.sqrt(vec_sigma ** 2 - vec_adjacent_sigma ** 2) 389 | 390 | rev_f = f - G[:, None] ** 2 * score_pos * 0.5 391 | rev_G = torch.zeros_like(G) 392 | z = torch.randn_like(pos) 393 | x_mean = pos - rev_f 394 | pos = x_mean + rev_G[:, None] * z 395 | 396 | return data, pos 397 | 398 | def PC_generator(self, data, config, pos_init=None): 399 | 400 | """ 401 | The ConfGF generator that generates conformations using the score of atomic coordinates 402 | Return: 403 | The generated conformation (pos_gen) 404 | Distance of the generated conformation (d_recover) 405 | """ 406 | 407 | if pos_init is None: 408 | pos_init = torch.randn(data.num_nodes, 3).to(data.pos) 409 | 410 | data, pos_traj = self.position_pc_generation( 411 | data, 412 | pos_init, 413 | self._model, 414 | self._model.sigmas.data.clone(), 415 | n_steps_each=config.steps_pos, 416 | step_lr=config.step_lr_pos, 417 | clip=config.clip, 418 | min_sigma=config.min_sigma, 419 | ) 420 | pos_gen = pos_traj 421 | d_recover = utils.get_d_from_pos(pos_gen, data.edge_index) # (num_edges) 422 | 423 | data.pos_gen = pos_gen.to(data.pos) 424 | data.d_recover = d_recover.view(-1, 1).to(data.edge_length) 425 | return pos_gen, d_recover.view(-1), data, pos_traj 426 | 427 | 428 | def generate_samples_from_smiles( 429 | self, smiles, generator, num_repeat=1, keep_traj=False, out_path=None 430 | ): 431 | 432 | if keep_traj: 433 | assert ( 434 | num_repeat == 1 435 | ), "to generate the trajectory of conformations, you must set num_repeat to 1" 436 | 437 | data = dataset.smiles_to_data(smiles) 438 | 439 | if data is None: 440 | raise ValueError("invalid smiles: %s" % smiles) 441 | 442 | return_data = copy.deepcopy(data) 443 | batch = utils.repeat_data(data, num_repeat).to(self.device) 444 | 445 | if generator == "EquiGF": 446 | _, _, batch, _ = self.EquiGF_generator(batch, self.config.test.gen) 447 | elif generator == 
"ODE": 448 | _, _, batch, _ = self.PC_generator(batch, self.config.test.gen) 449 | else: 450 | raise NotImplementedError 451 | 452 | batch = batch.to("cpu").to_data_list() 453 | pos_traj = pos_traj.view(-1, 3).to("cpu") 454 | pos_traj_step = pos_traj.size(0) // return_data.num_nodes 455 | 456 | all_pos = [] 457 | for i in range(len(batch)): 458 | all_pos.append(batch[i].pos_gen) 459 | return_data.pos_gen = torch.cat(all_pos, 0) # (num_repeat * num_node, 3) 460 | return_data.num_pos_gen = torch.tensor([len(all_pos)], dtype=torch.long) 461 | if keep_traj: 462 | return_data.pos_traj = pos_traj 463 | return_data.num_pos_traj = torch.tensor([pos_traj_step], dtype=torch.long) 464 | 465 | if out_path is not None: 466 | with open( 467 | os.path.join(out_path, "%s_%s.pkl" % (generator, return_data.smiles)), 468 | "wb", 469 | ) as fout: 470 | pickle.dump(return_data, fout) 471 | print("save generated %s samples to %s done!" % (generator, out_path)) 472 | 473 | print("pos generation of %s done" % return_data.smiles) 474 | 475 | return return_data 476 | 477 | def generate_samples_from_testset( 478 | self, start, end, eval_epoch, generator, num_repeat=None, out_path=None 479 | ): 480 | 481 | test_set = self.test_set 482 | generate_start = time() 483 | 484 | all_data_list = [] 485 | print("len of all data: %d" % len(test_set)) 486 | 487 | for i in tqdm(range(len(test_set))): 488 | if i < start or i >= end: 489 | continue 490 | return_data = copy.deepcopy(test_set[i]) 491 | num_repeat_ = ( 492 | num_repeat 493 | if num_repeat is not None 494 | else self.config.test.gen.repeat * test_set[i].num_pos_ref.item() 495 | ) 496 | batch = utils.repeat_data(test_set[i], num_repeat_).to(self.device) 497 | 498 | if generator == "EquiGF": 499 | _, _, batch, _ = self.EquiGF_generator(batch, self.config.test.gen) 500 | elif generator == "EquiPCGF": 501 | _, _, batch, _ = self.PC_generator(batch, self.config.test.gen) 502 | else: 503 | raise NotImplementedError 504 | 505 | batch = 
batch.to("cpu").to_data_list() 506 | 507 | all_pos = [] 508 | for i in range(len(batch)): 509 | all_pos.append(batch[i].pos_gen) 510 | return_data.pos_gen = torch.cat(all_pos, 0) # (num_repeat * num_node, 3) 511 | return_data.num_pos_gen = torch.tensor([len(all_pos)], dtype=torch.long) 512 | all_data_list.append(return_data) 513 | 514 | if out_path is not None: 515 | with open( 516 | os.path.join( 517 | out_path, 518 | "%s_s%de%depoch%dmin_sig%.3frepeat%d.pkl" 519 | % ( 520 | generator, 521 | start, 522 | end, 523 | eval_epoch, 524 | self.config.test.gen.min_sigma, 525 | self.config.test.gen.repeat, 526 | ), 527 | ), 528 | "wb", 529 | ) as fout: 530 | pickle.dump(all_data_list, fout) 531 | print("save generated %s samples to %s done!" % (generator, out_path)) 532 | print( 533 | "pos generation[%d-%d] done | Time: %.5f" 534 | % (start, end, time() - generate_start) 535 | ) 536 | 537 | return all_data_list 538 | 539 | 540 | def to_flattened_numpy(x): 541 | """Flatten a torch tensor `x` and convert it to numpy.""" 542 | return x.detach().cpu().numpy().reshape((-1,)) 543 | 544 | 545 | def from_flattened_numpy(x, shape): 546 | """Form a torch tensor with the given `shape` from a flattened numpy array `x`.""" 547 | return torch.from_numpy(x.reshape(shape)) 548 | 549 | -------------------------------------------------------------------------------- /confgen/confgf/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .chem import BOND_TYPES, BOND_NAMES, set_conformer_positions, draw_mol_image, update_data_rdmol_positions, \ 2 | update_data_pos_from_rdmol, set_rdmol_positions, set_rdmol_positions_, get_atom_symbol, mol_to_smiles, \ 3 | remove_duplicate_mols, get_atoms_in_ring, get_2D_mol, draw_mol_svg, GetBestRMSD 4 | from .distgeom import Embed3D, get_d_from_pos 5 | from .transforms import AddHigherOrderEdges, AddEdgeLength, AddPlaceHolder, AddEdgeName, AddAngleDihedral, CountNodesPerGraph 6 | from .torch import 
ExponentialLR_with_minLr, repeat_batch, repeat_data, get_optimizer, get_scheduler, clip_norm 7 | from .evaluation import evaluate_conf, evaluate_distance, get_rmsd_confusion_matrix, evaluate_conf_extend, evaluate_conf_prec 8 | 9 | __all__ = ["BOND_TYPES", "BOND_NAMES", "set_conformer_positions", "draw_mol_image", 10 | "update_data_rdmol_positions", "update_data_pos_from_rdmol", "set_rdmol_positions", 11 | "set_rdmol_positions_", "get_atom_symbol", "mol_to_smiles", "remove_duplicate_mols", 12 | "get_atoms_in_ring", "get_2D_mol", "draw_mol_svg", "GetBestRMSD", 13 | "Embed3D", "get_d_from_pos", 14 | "AddHigherOrderEdges", "AddEdgeLength", "AddPlaceHolder", "AddEdgeName", 15 | "AddAngleDihedral", "CountNodesPerGraph", 16 | "ExponentialLR_with_minLr", 17 | "repeat_batch", "repeat_data", 18 | "get_optimizer", "get_scheduler", "clip_norm", 19 | "evaluate_conf", "evaluate_distance", "get_rmsd_confusion_matrix"] 20 | -------------------------------------------------------------------------------- /confgen/confgf/utils/chem.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torchvision.transforms.functional import to_tensor 4 | 5 | import rdkit 6 | import rdkit.Chem.Draw 7 | from rdkit import Chem 8 | from rdkit.Chem import rdDepictor as DP 9 | from rdkit.Chem import PeriodicTable as PT 10 | from rdkit.Chem import rdMolAlign as MA 11 | from rdkit.Chem.rdchem import BondType as BT 12 | from rdkit.Chem.rdchem import Mol, GetPeriodicTable 13 | from rdkit.Chem.Draw import rdMolDraw2D as MD2 14 | from rdkit.Chem.rdmolops import RemoveHs 15 | from typing import List, Tuple 16 | 17 | 18 | 19 | BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())} 20 | BOND_NAMES = {i: t for i, t in enumerate(BT.names.keys())} 21 | 22 | 23 | def set_conformer_positions(conf, pos): 24 | for i in range(pos.shape[0]): 25 | conf.SetAtomPosition(i, pos[i].tolist()) 26 | return conf 27 | 28 | 29 | def 
draw_mol_image(rdkit_mol, tensor=False): 30 | rdkit_mol.UpdatePropertyCache() 31 | img = rdkit.Chem.Draw.MolToImage(rdkit_mol, kekulize=False) 32 | if tensor: 33 | return to_tensor(img) 34 | else: 35 | return img 36 | 37 | 38 | def update_data_rdmol_positions(data): 39 | for i in range(data.pos.size(0)): 40 | data.rdmol.GetConformer(0).SetAtomPosition(i, data.pos[i].tolist()) 41 | return data 42 | 43 | 44 | def update_data_pos_from_rdmol(data): 45 | new_pos = torch.FloatTensor(data.rdmol.GetConformer(0).GetPositions()).to(data.pos) 46 | data.pos = new_pos 47 | return data 48 | 49 | 50 | def set_rdmol_positions(rdkit_mol, pos): 51 | """ 52 | Args: 53 | rdkit_mol: An `rdkit.Chem.rdchem.Mol` object. 54 | pos: (N_atoms, 3) 55 | """ 56 | mol = copy.deepcopy(rdkit_mol) 57 | set_rdmol_positions_(mol, pos) 58 | return mol 59 | 60 | 61 | def set_rdmol_positions_(mol, pos): 62 | """ 63 | Args: 64 | rdkit_mol: An `rdkit.Chem.rdchem.Mol` object. 65 | pos: (N_atoms, 3) 66 | """ 67 | for i in range(pos.shape[0]): 68 | mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist()) 69 | return mol 70 | 71 | 72 | def get_atom_symbol(atomic_number): 73 | return PT.GetElementSymbol(GetPeriodicTable(), atomic_number) 74 | 75 | 76 | def mol_to_smiles(mol: Mol) -> str: 77 | return Chem.MolToSmiles(mol, allHsExplicit=True) 78 | 79 | 80 | def mol_to_smiles_without_Hs(mol: Mol) -> str: 81 | return Chem.MolToSmiles(Chem.RemoveHs(mol)) 82 | 83 | 84 | def remove_duplicate_mols(molecules: List[Mol]) -> List[Mol]: 85 | unique_tuples: List[Tuple[str, Mol]] = [] 86 | 87 | for molecule in molecules: 88 | duplicate = False 89 | smiles = mol_to_smiles(molecule) 90 | for unique_smiles, _ in unique_tuples: 91 | if smiles == unique_smiles: 92 | duplicate = True 93 | break 94 | 95 | if not duplicate: 96 | unique_tuples.append((smiles, molecule)) 97 | 98 | return [mol for smiles, mol in unique_tuples] 99 | 100 | 101 | def get_atoms_in_ring(mol): 102 | atoms = set() 103 | for ring in 
mol.GetRingInfo().AtomRings(): 104 | for a in ring: 105 | atoms.add(a) 106 | return atoms 107 | 108 | 109 | def get_2D_mol(mol): 110 | mol = copy.deepcopy(mol) 111 | DP.Compute2DCoords(mol) 112 | return mol 113 | 114 | 115 | def draw_mol_svg(mol,molSize=(450,150),kekulize=False): 116 | mc = Chem.Mol(mol.ToBinary()) 117 | if kekulize: 118 | try: 119 | Chem.Kekulize(mc) 120 | except: 121 | mc = Chem.Mol(mol.ToBinary()) 122 | if not mc.GetNumConformers(): 123 | DP.Compute2DCoords(mc) 124 | drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1]) 125 | drawer.DrawMolecule(mc) 126 | drawer.FinishDrawing() 127 | svg = drawer.GetDrawingText() 128 | # It seems that the svg renderer used doesn't quite hit the spec. 129 | # Here are some fixes to make it work in the notebook, although I think 130 | # the underlying issue needs to be resolved at the generation step 131 | # return svg.replace('svg:','') 132 | return svg 133 | 134 | 135 | def GetBestRMSD(probe, ref): 136 | probe = RemoveHs(probe) 137 | ref = RemoveHs(ref) 138 | rmsd = MA.GetBestRMS(probe, ref) 139 | return rmsd 140 | -------------------------------------------------------------------------------- /confgen/confgf/utils/distgeom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def embed_3D(d_target, edge_index, init_pos, edge_order=None, alpha=0.5, mu=0, step_size=None, num_steps=None, verbose=0): 5 | assert torch.is_grad_enabled, '`embed_3D` requires gradients' 6 | step_size = 8.0 if step_size is None else step_size 7 | num_steps = 1000 if num_steps is None else num_steps 8 | pos_vecs = [] 9 | 10 | d_target = d_target.view(-1) 11 | pos = init_pos.clone().requires_grad_(True) 12 | optimizer = torch.optim.Adam([pos], lr=step_size) 13 | 14 | if edge_order is not None: 15 | coef = alpha ** (edge_order.view(-1).float() - 1) 16 | else: 17 | coef = 1.0 18 | 19 | if mu > 0: 20 | noise = torch.randn_like(coef) * coef * mu + coef 21 | coef = torch.clamp_min(coef + 
noise, min=0) 22 | 23 | for i in range(num_steps): 24 | optimizer.zero_grad() 25 | d_new = torch.norm(pos[edge_index[0]] - pos[edge_index[1]], dim=1) 26 | #print(d_new) 27 | loss = (coef * ((d_target - d_new) ** 2)).sum() 28 | loss.backward() 29 | optimizer.step() 30 | pos_vecs.append(pos.detach()) 31 | 32 | pos_vecs = torch.stack(pos_vecs, dim=0) # (num_steps, num_node, 3) 33 | 34 | if verbose: 35 | print('Embed 3D: AvgLoss %.6f' % (loss.item() / d_target.size(0))) 36 | 37 | return pos_vecs, loss.detach() / d_target.size(0) 38 | 39 | 40 | class Embed3D(object): 41 | 42 | def __init__(self, alpha=0.5, mu=0, step_size=8.0, num_steps=1000, verbose=0): 43 | super().__init__() 44 | self.alpha = alpha 45 | self.mu = mu 46 | self.step_size = step_size 47 | self.num_steps = num_steps 48 | self.verbose = verbose 49 | 50 | def __call__(self, d_target, edge_index, init_pos, edge_order=None): 51 | return embed_3D( 52 | d_target, edge_index, init_pos, edge_order, 53 | alpha=self.alpha, 54 | mu=self.mu, 55 | step_size=self.step_size, 56 | num_steps=self.num_steps, 57 | verbose=self.verbose 58 | ) 59 | 60 | def get_d_from_pos(pos, edge_index): 61 | return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1) # (num_edge) 62 | -------------------------------------------------------------------------------- /confgen/confgf/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm.auto import tqdm 3 | 4 | import torch 5 | from torch_geometric.data import Data 6 | from rdkit import Chem 7 | from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule 8 | 9 | from confgf import utils 10 | 11 | 12 | def get_rmsd_confusion_matrix(data: Data, useFF=False, reverse=False): 13 | if not reverse: 14 | pos_ref = data.pos_ref.view(-1, data.num_nodes, 3) 15 | pos_gen = data.pos_gen.view(-1, data.num_nodes, 3) 16 | num_gen = pos_gen.size(0) 17 | num_ref = pos_ref.size(0) 18 | assert num_gen == 
data.num_pos_gen.item() 19 | assert num_ref == data.num_pos_ref.item() 20 | else: 21 | pos_ref = data.pos_gen.view(-1, data.num_nodes, 3) 22 | pos_gen = data.pos_ref.view(-1, data.num_nodes, 3) 23 | num_gen = pos_gen.size(0) 24 | num_ref = pos_ref.size(0) 25 | assert num_gen == data.num_pos_ref.item() 26 | assert num_ref == data.num_pos_gen.item() 27 | 28 | rmsd_confusion_mat = -1 * np.ones([num_ref, num_gen],dtype=np.float) 29 | 30 | for i in range(num_gen): 31 | gen_mol = utils.set_rdmol_positions(data.rdmol, pos_gen[i]) 32 | if useFF: 33 | #print('Applying FF on generated molecules...') 34 | MMFFOptimizeMolecule(gen_mol) 35 | for j in range(num_ref): 36 | ref_mol = utils.set_rdmol_positions(data.rdmol, pos_ref[j]) 37 | 38 | rmsd_confusion_mat[j,i] = utils.GetBestRMSD(gen_mol, ref_mol) 39 | 40 | return rmsd_confusion_mat 41 | 42 | 43 | def evaluate_conf(data: Data, useFF=False, threshold=0.5): 44 | rmsd_confusion_mat = get_rmsd_confusion_matrix(data, useFF=useFF) 45 | rmsd_ref_min = rmsd_confusion_mat.min(-1) 46 | return (rmsd_ref_min<=threshold).mean(), rmsd_ref_min.mean() 47 | 48 | def evaluate_conf_prec(data: Data, useFF=False, threshold=0.5): 49 | rmsd_confusion_mat = get_rmsd_confusion_matrix(data, useFF=useFF, reverse=True) 50 | rmsd_ref_min = rmsd_confusion_mat.min(-1) 51 | return (rmsd_ref_min<=threshold).mean(), rmsd_ref_min.mean() 52 | 53 | def evaluate_conf_extend(data: Data, useFF=False, threshold=0.5): 54 | rmsd_confusion_mat = get_rmsd_confusion_matrix(data, useFF=useFF) 55 | rmsd_ref_min = rmsd_confusion_mat.min(-1) 56 | rmsd_gen_min = rmsd_confusion_mat.min(0) 57 | recall_cov, recall_mat = (rmsd_ref_min<=threshold).mean(), rmsd_ref_min.mean() 58 | prec_cov, prec_mat = (rmsd_gen_min<=threshold).mean(), rmsd_gen_min.mean() 59 | return recall_cov, recall_mat, prec_cov, prec_mat 60 | 61 | def evaluate_distance(data: Data, ignore_H=True): 62 | data.pos_ref = data.pos_ref.view(-1, data.num_nodes, 3) # (N, num_node, 3) 63 | data.pos_gen = 
data.pos_gen.view(-1, data.num_nodes, 3) # (M, num_node, 3) 64 | num_ref = data.pos_ref.size(0) # N 65 | num_gen = data.pos_gen.size(0) # M 66 | assert num_gen == data.num_pos_gen.item() 67 | assert num_ref == data.num_pos_ref.item() 68 | smiles = data.smiles 69 | 70 | edge_index = data.edge_index 71 | atom_type = data.atom_type 72 | 73 | # compute generated length and ref length 74 | ref_lengths = (data.pos_ref[:, edge_index[0]] - data.pos_ref[:, edge_index[1]]).norm(dim=-1) # (N, num_edge) 75 | gen_lengths = (data.pos_gen[:, edge_index[0]] - data.pos_gen[:, edge_index[1]]).norm(dim=-1) # (M, num_edge) 76 | # print(ref_lengths.size(), gen_lengths.size()) 77 | #print(ref_lengths.size()) 78 | #print(gen_lengths.size()) 79 | 80 | stats_single = [] 81 | first = 1 82 | for i, (row, col) in enumerate(tqdm(edge_index.t())): 83 | if row >= col: 84 | continue 85 | if ignore_H and 1 in (atom_type[row].item(), atom_type[col].item()): 86 | continue 87 | gen_l = gen_lengths[:, i] 88 | ref_l = ref_lengths[:, i] 89 | if first: 90 | print(gen_l.size(), ref_l.size()) 91 | first = 0 92 | mmd = compute_mmd(gen_l.view(-1, 1).cuda(), ref_l.view(-1, 1).cuda()).item() 93 | stats_single.append({ 94 | 'edge_id': i, 95 | 'elems': '%s - %s' % (utils.get_atom_symbol(atom_type[row].item()), utils.get_atom_symbol(atom_type[col].item())), 96 | 'nodes': (row.item(), col.item()), 97 | 'gen_lengths': gen_l.cpu(), 98 | 'ref_lengths': ref_l.cpu(), 99 | 'mmd': mmd 100 | }) 101 | 102 | first = 1 103 | stats_pair = [] 104 | for i, (row_i, col_i) in enumerate(tqdm(edge_index.t())): 105 | if row_i >= col_i: 106 | continue 107 | if ignore_H and 1 in (atom_type[row_i].item(), atom_type[col_i].item()): 108 | continue 109 | for j, (row_j, col_j) in enumerate(edge_index.t()): 110 | if (row_i >= row_j) or (row_j >= col_j): 111 | continue 112 | if ignore_H and 1 in (atom_type[row_j].item(), atom_type[col_j].item()): 113 | continue 114 | 115 | gen_L = gen_lengths[:, (i,j)] # (N, 2) 116 | ref_L = ref_lengths[:, 
(i,j)] # (M, 2) 117 | if first: 118 | # print(gen_L.size(), ref_L.size()) 119 | first = 0 120 | mmd = compute_mmd(gen_L.cuda(), ref_L.cuda()).item() 121 | 122 | stats_pair.append({ 123 | 'edge_id': (i, j), 124 | 'elems': ( 125 | '%s - %s' % (utils.get_atom_symbol(atom_type[row_i].item()), utils.get_atom_symbol(atom_type[col_i].item())), 126 | '%s - %s' % (utils.get_atom_symbol(atom_type[row_j].item()), utils.get_atom_symbol(atom_type[col_j].item())), 127 | ), 128 | 'nodes': ( 129 | (row_i.item(), col_i.item()), 130 | (row_j.item(), col_j.item()), 131 | ), 132 | 'gen_lengths': gen_L.cpu(), 133 | 'ref_lengths': ref_L.cpu(), 134 | 'mmd': mmd 135 | }) 136 | 137 | edge_filter = edge_index[0] < edge_index[1] 138 | if ignore_H: 139 | for i, (row, col) in enumerate(edge_index.t()): 140 | if 1 in (atom_type[row].item(), atom_type[col].item()): 141 | edge_filter[i] = False 142 | 143 | gen_L = gen_lengths[:, edge_filter] # (N, Ef) 144 | ref_L = ref_lengths[:, edge_filter] # (M, Ef) 145 | # print(gen_L.size(), ref_L.size()) 146 | mmd = compute_mmd(gen_L.cuda(), ref_L.cuda()).item() 147 | 148 | stats_all = { 149 | 'gen_lengths': gen_L.cpu(), 150 | 'ref_lengths': ref_L.cpu(), 151 | 'mmd': mmd 152 | } 153 | return stats_single, stats_pair, stats_all 154 | 155 | def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 156 | ''' 157 | Params: 158 | source: n * len(x) 159 | target: m * len(y) 160 | Return: 161 | sum(kernel_val): Sum of various kernel matrices 162 | ''' 163 | n_samples = int(source.size()[0])+int(target.size()[0]) 164 | total = torch.cat([source, target], dim=0) 165 | total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 166 | total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 167 | 168 | L2_distance = ((total0-total1)**2).sum(2) 169 | 170 | if fix_sigma: 171 | bandwidth = fix_sigma 172 | else: 173 | bandwidth = torch.sum(L2_distance.data) / 
(n_samples**2-n_samples) 174 | 175 | bandwidth /= kernel_mul ** (kernel_num // 2) 176 | bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)] 177 | 178 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list] 179 | 180 | return sum(kernel_val)#/len(kernel_val) 181 | 182 | def compute_mmd(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 183 | ''' 184 | Params: 185 | source: (N, D) 186 | target: (M, D) 187 | Return: 188 | loss: MMD loss 189 | ''' 190 | batch_size = int(source.size()[0]) 191 | kernels = guassian_kernel(source, target, 192 | kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma) 193 | 194 | XX = kernels[:batch_size, :batch_size] 195 | YY = kernels[batch_size:, batch_size:] 196 | XY = kernels[:batch_size, batch_size:] 197 | YX = kernels[batch_size:, :batch_size] 198 | loss = torch.mean(XX) + torch.mean(YY) - torch.mean(XY) - torch.mean(YX) 199 | 200 | return loss 201 | 202 | 203 | """ 204 | Another implementation: 205 | https://github.com/martinepalazzo/kernel_methods/blob/master/kernel_methods.py 206 | """ -------------------------------------------------------------------------------- /confgen/confgf/utils/torch.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import warnings 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch_geometric.data import Data, Batch 8 | 9 | def clip_norm(vec, limit, p=2): 10 | norm = torch.norm(vec, dim=-1, p=2, keepdim=True) 11 | denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm)) 12 | return vec * denom 13 | 14 | 15 | def repeat_data(data: Data, num_repeat) -> Batch: 16 | datas = [copy.deepcopy(data) for i in range(num_repeat)] 17 | return Batch.from_data_list(datas) 18 | 19 | 20 | def repeat_batch(batch: Batch, num_repeat) -> Batch: 21 | datas = batch.to_data_list() 22 | new_data = [] 23 | for i in range(num_repeat): 24 | new_data 
+= copy.deepcopy(datas) 25 | return Batch.from_data_list(new_data) 26 | 27 | 28 | #customize exp lr scheduler with min lr 29 | class ExponentialLR_with_minLr(torch.optim.lr_scheduler.ExponentialLR): 30 | def __init__(self, optimizer, gamma, min_lr=1e-4, last_epoch=-1, verbose=False): 31 | self.gamma = gamma 32 | self.min_lr = min_lr 33 | super(ExponentialLR_with_minLr, self).__init__(optimizer, gamma, last_epoch, verbose) 34 | 35 | def get_lr(self): 36 | if not self._get_lr_called_within_step: 37 | warnings.warn("To get the last learning rate computed by the scheduler, " 38 | "please use `get_last_lr()`.", UserWarning) 39 | 40 | if self.last_epoch == 0: 41 | return self.base_lrs 42 | 43 | return [max(group['lr'] * self.gamma, self.min_lr) 44 | for group in self.optimizer.param_groups] 45 | 46 | def _get_closed_form_lr(self): 47 | return [max(base_lr * self.gamma ** self.last_epoch, self.min_lr) 48 | for base_lr in self.base_lrs] 49 | 50 | 51 | def get_optimizer(config, model): 52 | if config.type == "Adam": 53 | return torch.optim.Adam( 54 | filter(lambda p: p.requires_grad, model.parameters()), 55 | lr=config.lr, 56 | weight_decay=config.weight_decay) 57 | else: 58 | raise NotImplementedError('Optimizer not supported: %s' % config.type) 59 | 60 | 61 | 62 | 63 | def get_scheduler(config, optimizer): 64 | if config.type == 'plateau': 65 | return torch.optim.lr_scheduler.ReduceLROnPlateau( 66 | optimizer, 67 | factor=config.factor, 68 | patience=config.patience, 69 | ) 70 | elif config.type == 'expmin': 71 | return ExponentialLR_with_minLr( 72 | optimizer, 73 | gamma=config.factor, 74 | min_lr=config.min_lr, 75 | ) 76 | else: 77 | raise NotImplementedError('Scheduler not supported: %s' % config.type) 78 | -------------------------------------------------------------------------------- /confgen/confgf/utils/transforms.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch_geometric.data 
import Data 4 | from torch_geometric.utils import to_dense_adj, dense_to_sparse 5 | from torch_sparse import coalesce 6 | from confgf import utils 7 | 8 | 9 | class AddHigherOrderEdges(object): 10 | 11 | def __init__(self, order, num_types=len(utils.BOND_TYPES)): 12 | super().__init__() 13 | self.order = order 14 | self.num_types = num_types 15 | 16 | def binarize(self, x): 17 | return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x)) 18 | 19 | def get_higher_order_adj_matrix(self, adj, order): 20 | """ 21 | Args: 22 | adj: (N, N) 23 | type_mat: (N, N) 24 | """ 25 | adj_mats = [torch.eye(adj.size(0), dtype=torch.long, device=adj.device), \ 26 | self.binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device))] 27 | 28 | for i in range(2, order+1): 29 | adj_mats.append(self.binarize(adj_mats[i-1] @ adj_mats[1])) 30 | order_mat = torch.zeros_like(adj) 31 | 32 | for i in range(1, order+1): 33 | order_mat += (adj_mats[i] - adj_mats[i-1]) * i 34 | 35 | return order_mat 36 | 37 | def __call__(self, data: Data): 38 | 39 | 40 | N = data.num_nodes 41 | adj = to_dense_adj(data.edge_index).squeeze(0) 42 | adj_order = self.get_higher_order_adj_matrix(adj, self.order) # (N, N) 43 | 44 | type_mat = to_dense_adj(data.edge_index, edge_attr=data.edge_type).squeeze(0) # (N, N) 45 | type_highorder = torch.where(adj_order > 1, self.num_types + adj_order - 1, torch.zeros_like(adj_order)) 46 | assert (type_mat * type_highorder == 0).all() 47 | type_new = type_mat + type_highorder 48 | 49 | new_edge_index, new_edge_type = dense_to_sparse(type_new) 50 | _, edge_order = dense_to_sparse(adj_order) 51 | 52 | data.bond_edge_index = data.edge_index # Save original edges 53 | data.edge_index, data.edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data 54 | edge_index_1, data.edge_order = coalesce(new_edge_index, edge_order.long(), N, N) # modify data 55 | data.is_bond = (data.edge_type < self.num_types) 56 | assert (data.edge_index == 
edge_index_1).all() 57 | 58 | return data 59 | 60 | class AddEdgeLength(object): 61 | 62 | def __call__(self, data: Data): 63 | 64 | pos = data.pos 65 | row, col = data.edge_index 66 | d = (pos[row] - pos[col]).norm(dim=-1).unsqueeze(-1) # (num_edge, 1) 67 | data.edge_length = d 68 | return data 69 | 70 | 71 | # Add attribute placeholder for data object, so that we can use batch.to_data_list 72 | class AddPlaceHolder(object): 73 | def __call__(self, data: Data): 74 | data.pos_gen = -1. * torch.ones_like(data.pos) 75 | data.d_gen = -1. * torch.ones_like(data.edge_length) 76 | data.d_recover = -1. * torch.ones_like(data.edge_length) 77 | return data 78 | 79 | 80 | class AddEdgeName(object): 81 | 82 | def __init__(self, asymmetric=True): 83 | super().__init__() 84 | self.bonds = copy.deepcopy(utils.BOND_NAMES) 85 | self.bonds[len(utils.BOND_NAMES) + 1] = 'Angle' 86 | self.bonds[len(utils.BOND_NAMES) + 2] = 'Dihedral' 87 | self.asymmetric = asymmetric 88 | 89 | def __call__(self, data:Data): 90 | data.edge_name = [] 91 | for i in range(data.edge_index.size(1)): 92 | tail = data.edge_index[0, i] 93 | head = data.edge_index[1, i] 94 | if self.asymmetric and tail >= head: 95 | data.edge_name.append('') 96 | continue 97 | tail_name = utils.get_atom_symbol(data.atom_type[tail].item()) 98 | head_name = utils.get_atom_symbol(data.atom_type[head].item()) 99 | name = '%s_%s_%s_%d_%d' % ( 100 | self.bonds[data.edge_type[i].item()] if data.edge_type[i].item() in self.bonds else 'E'+str(data.edge_type[i].item()), 101 | tail_name, 102 | head_name, 103 | tail, 104 | head, 105 | ) 106 | if hasattr(data, 'edge_length'): 107 | name += '_%.3f' % (data.edge_length[i].item()) 108 | data.edge_name.append(name) 109 | return data 110 | 111 | 112 | 113 | class AddAngleDihedral(object): 114 | 115 | def __init__(self): 116 | super().__init__() 117 | 118 | @staticmethod 119 | def iter_angle_triplet(bond_mat): 120 | n_atoms = bond_mat.size(0) 121 | for j in range(n_atoms): 122 | for k in 
range(n_atoms): 123 | for l in range(n_atoms): 124 | if bond_mat[j, k].item() == 0 or bond_mat[k, l].item() == 0: continue 125 | if (j == k) or (k == l) or (j >= l): continue 126 | yield(j, k, l) 127 | 128 | @staticmethod 129 | def iter_dihedral_quartet(bond_mat): 130 | n_atoms = bond_mat.size(0) 131 | for i in range(n_atoms): 132 | for j in range(n_atoms): 133 | if i >= j: continue 134 | if bond_mat[i,j].item() == 0:continue 135 | for k in range(n_atoms): 136 | for l in range(n_atoms): 137 | if (k in (i,j)) or (l in (i,j)): continue 138 | if bond_mat[k,i].item() == 0 or bond_mat[l,j].item() == 0: continue 139 | yield(k, i, j, l) 140 | 141 | def __call__(self, data:Data): 142 | N = data.num_nodes 143 | if 'is_bond' in data: 144 | bond_mat = to_dense_adj(data.edge_index, edge_attr=data.is_bond).long().squeeze(0) > 0 145 | else: 146 | bond_mat = to_dense_adj(data.edge_index, edge_attr=data.edge_type).long().squeeze(0) > 0 147 | 148 | # Note: if the name of attribute contains `index`, it will automatically 149 | # increases during batching. 
150 | data.angle_index = torch.LongTensor(list(self.iter_angle_triplet(bond_mat))).t() 151 | data.dihedral_index = torch.LongTensor(list(self.iter_dihedral_quartet(bond_mat))).t() 152 | 153 | return data 154 | 155 | 156 | class CountNodesPerGraph(object): 157 | 158 | def __init__(self) -> None: 159 | super().__init__() 160 | 161 | def __call__(self, data): 162 | data.num_nodes_per_graph = torch.LongTensor([data.num_nodes]) 163 | return data -------------------------------------------------------------------------------- /confgen/config/drugs_clofnet.yml: -------------------------------------------------------------------------------- 1 | train: 2 | batch_size: 128 3 | seed: 2021 4 | epochs: 400 5 | shuffle: true 6 | resume_train: false 7 | eval: true 8 | num_workers: 3 9 | gpus: 10 | - 0 11 | - null 12 | - null 13 | - null 14 | anneal_power: 2.0 15 | save: true 16 | save_path: root/to/save 17 | resume_checkpoint: null 18 | resume_epoch: null 19 | log_interval: 400 20 | optimizer: 21 | type: Adam 22 | lr: 0.001 23 | weight_decay: 0.0000 24 | dropout: 0.0 25 | scheduler: 26 | type: plateau 27 | factor: 0.6 28 | patience: 10 29 | min_lr: 1e-4 30 | loss: 31 | position: 1 32 | distance: 0 33 | curl: 0 34 | wandb: 35 | Enable: True 36 | Project: Molecular-Generation 37 | Name: clofnet4drugs 38 | 39 | 40 | test: 41 | init_checkpoint: root/to/checkpoint 42 | output_path: root/to/generation_files 43 | gen: 44 | steps_pos: 100 45 | step_lr_pos: 0.000002 46 | clip: 1000 47 | min_sigma: 0.0 48 | verbose: 1 49 | repeat: 2 50 | 51 | 52 | data: 53 | base_path: root/to/dataset 54 | dataset: drugs 55 | train_set: train_data_39k.pkl 56 | val_set: val_data_5k.pkl 57 | test_set: test_data_200.pkl 58 | 59 | 60 | model: 61 | hidden_dim: 288 62 | num_convs: 4 63 | sigma_begin: 10 64 | sigma_end: 0.01 65 | num_noise_level: 50 66 | order: 3 67 | mlp_act: relu 68 | gnn_act: relu 69 | cutoff: 10.0 70 | short_cut: true 71 | concat_hidden: false 72 | noise_type: rand 73 | edge_encoder: mlp 74 
| 75 | 76 | -------------------------------------------------------------------------------- /confgen/config/qm9_clofnet.yml: -------------------------------------------------------------------------------- 1 | train: 2 | batch_size: 128 3 | seed: 2021 4 | epochs: 400 5 | shuffle: true 6 | resume_train: false 7 | eval: true 8 | num_workers: 3 9 | gpus: 10 | - 0 11 | - null 12 | - null 13 | - null 14 | anneal_power: 2.0 15 | save: true 16 | save_path: root/to/save 17 | resume_checkpoint: null 18 | resume_epoch: null 19 | log_interval: 400 20 | optimizer: 21 | type: Adam 22 | lr: 0.001 23 | weight_decay: 0.0000 24 | dropout: 0.0 25 | scheduler: 26 | type: plateau 27 | factor: 0.6 28 | patience: 10 29 | min_lr: 1e-4 30 | loss: 31 | position: 1 32 | distance: 0 33 | curl: 0 34 | wandb: 35 | Enable: False 36 | Project: Molecular-Generation 37 | Name: clofnet4qm9 38 | 39 | 40 | test: 41 | init_checkpoint: root/to/checkpoint 42 | output_path: root/to/generation_files 43 | gen: 44 | steps_pos: 100 45 | step_lr_pos: 0.000002 46 | clip: 1000 47 | min_sigma: 0.0 48 | verbose: 1 49 | repeat: 2 50 | 51 | 52 | data: 53 | base_path: root/to/dataset 54 | dataset: qm9 55 | train_set: train_data_40k.pkl 56 | val_set: val_data_5k.pkl 57 | test_set: test_data_200.pkl 58 | 59 | 60 | model: 61 | hidden_dim: 288 62 | num_convs: 4 63 | sigma_begin: 10 64 | sigma_end: 0.01 65 | num_noise_level: 50 66 | order: 3 67 | mlp_act: relu 68 | gnn_act: relu 69 | cutoff: 10.0 70 | short_cut: true 71 | concat_hidden: false 72 | noise_type: rand 73 | edge_encoder: mlp -------------------------------------------------------------------------------- /confgen/script/gen.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | import os 3 | import sys 4 | project_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) 5 | print('project path is {}'.format(project_path)) 6 | sys.path.append(project_path) 7 | import argparse 8 | import numpy as 
np 9 | import random 10 | import pickle 11 | import yaml 12 | from easydict import EasyDict 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | from torch.autograd import Variable 19 | from torch_geometric.data import Data, Dataset 20 | from torch_geometric.transforms import Compose 21 | 22 | from confgf import models, dataset, runner, utils 23 | 24 | 25 | if __name__ == '__main__': 26 | 27 | parser = argparse.ArgumentParser(description='clofnet') 28 | parser.add_argument('--config_path', type=str, help='path of dataset', required=True) 29 | parser.add_argument('--generator', type=str, help='type of generator [EquiGF, EquiPCGF]', required=True) 30 | parser.add_argument('--num_repeat', type=int, default=None, help='end idx of test generation') 31 | parser.add_argument('--eval_epoch', type=int, default=None, help='evaluation epoch') 32 | parser.add_argument('--start', type=int, default=-1, help='start idx of test generation') 33 | parser.add_argument('--end', type=int, default=-1, help='end idx of test generation') 34 | parser.add_argument('--smiles', type=str, default=None, help='smiles for generation') 35 | parser.add_argument('--seed', type=int, default=2021, help='overwrite config seed') 36 | 37 | args = parser.parse_args() 38 | with open(args.config_path, 'r') as f: 39 | config = yaml.safe_load(f) 40 | config = EasyDict(config) 41 | 42 | if args.seed != 2021: 43 | config.train.seed = args.seed 44 | 45 | if config.test.output_path is not None: 46 | config.test.output_path = os.path.join(config.test.output_path, config.train.Name) 47 | if not os.path.exists(config.test.output_path): 48 | os.makedirs(config.test.output_path) 49 | 50 | # check device 51 | gpus = list(filter(lambda x: x is not None, config.train.gpus)) 52 | assert torch.cuda.device_count() >= len(gpus), 'do you set the gpus in config correctly?' 
53 | device = torch.device(gpus[0]) if len(gpus) > 0 else torch.device('cpu') 54 | print("Let's use", len(gpus), "GPUs!") 55 | print("Using device %s as main device" % device) 56 | config.train.device = device 57 | config.train.gpus = gpus 58 | config.train.wandb.Enable = False 59 | 60 | print(config) 61 | 62 | # set random seed 63 | np.random.seed(config.train.seed) 64 | random.seed(config.train.seed) 65 | torch.manual_seed(config.train.seed) 66 | if torch.cuda.is_available(): 67 | torch.cuda.manual_seed(config.train.seed) 68 | torch.cuda.manual_seed_all(config.train.seed) 69 | torch.backends.cudnn.benchmark = True 70 | print('set seed for random, numpy and torch') 71 | 72 | 73 | load_path = os.path.join(config.data.base_path, '%s_processed' % config.data.dataset) 74 | print('loading data from %s' % load_path) 75 | 76 | train_data = [] 77 | val_data = [] 78 | test_data = [] 79 | 80 | if config.data.test_set is not None: 81 | with open(os.path.join(load_path, config.data.test_set), "rb") as fin: 82 | test_data = pickle.load(fin) 83 | else: 84 | raise ValueError("do you set the test data ?") 85 | 86 | print('train size : %d || val size: %d || test size: %d ' % (len(train_data), len(val_data), len(test_data))) 87 | print('loading data done!') 88 | 89 | transform = Compose([ 90 | utils.AddHigherOrderEdges(order=config.model.order), 91 | utils.AddEdgeLength(), 92 | utils.AddPlaceHolder(), 93 | utils.AddEdgeName() 94 | ]) 95 | train_data = dataset.GEOMDataset(data=train_data, transform=transform) 96 | val_data = dataset.GEOMDataset(data=val_data, transform=transform) 97 | test_data = dataset.GEOMDataset_PackedConf(data=test_data, transform=transform) 98 | print('len of test data: %d' % len(test_data)) 99 | 100 | model = models.EquiDistanceScoreMatch(config) 101 | 102 | optimizer = None 103 | scheduler = None 104 | 105 | solver = runner.EquiRunner(train_data, val_data, test_data, model, optimizer, scheduler, gpus, config) 106 | 107 | assert config.test.init_checkpoint is 
not None 108 | init_checkpoint = os.path.join(config.test.init_checkpoint, config.train.Name) 109 | solver.load(init_checkpoint, epoch=args.eval_epoch) 110 | 111 | if args.smiles is not None: 112 | solver.generate_samples_from_smiles(args.smiles, args.generator, \ 113 | num_repeat=1, keep_traj=True, 114 | out_path=config.test.output_path) 115 | 116 | if args.start != -1 and args.end != -1: 117 | solver.generate_samples_from_testset(args.start, args.end, args.eval_epoch, \ 118 | args.generator, num_repeat=args.num_repeat, \ 119 | out_path=config.test.output_path) 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /confgen/script/get_rdkit_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import torch 5 | import pickle 6 | 7 | import copy 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | import rdkit 12 | from rdkit.Chem import AllChem 13 | from rdkit import Chem 14 | 15 | from confgf import utils, dataset 16 | 17 | import multiprocessing 18 | from functools import partial 19 | 20 | def generate_conformers(mol, num_confs): 21 | mol = copy.deepcopy(mol) 22 | mol.RemoveAllConformers() 23 | assert mol.GetNumConformers() == 0 24 | 25 | AllChem.EmbedMultipleConfs( 26 | mol, 27 | numConfs=num_confs, 28 | maxAttempts=0, 29 | ignoreSmoothingFailures=True, 30 | ) 31 | if mol.GetNumConformers() != num_confs: 32 | print('Warning: Failure cases occured, generated: %d , expected: %d.' 
% (mol.GetNumConformers(), num_confs, )) 33 | 34 | return mol 35 | 36 | 37 | 38 | if __name__ == '__main__': 39 | 40 | parser = argparse.ArgumentParser(description='confgf') 41 | parser.add_argument('--input', type=str, required=True) 42 | parser.add_argument('--output', type=str, required=True) 43 | parser.add_argument('--start_idx', type=int, default=0) 44 | parser.add_argument('--num_samples', type=int, default=50) 45 | 46 | parser.add_argument('--eval', action='store_true', default=False) 47 | parser.add_argument('--core', type=int, default=6) 48 | parser.add_argument('--FF', action='store_true') 49 | parser.add_argument('--threshold', type=float, default=0.5) 50 | args = parser.parse_args() 51 | print(args) 52 | 53 | 54 | 55 | with open(args.input, 'rb') as f: 56 | data_raw = pickle.load(f) 57 | if 'pos_ref' in data_raw[0]: 58 | data_list = data_raw 59 | else: 60 | data_list = dataset.GEOMDataset_PackedConf(data_raw) 61 | 62 | 63 | generated_data_list = [] 64 | for i in tqdm(range(args.start_idx, len(data_list))): 65 | return_data = copy.deepcopy(data_list[i]) 66 | 67 | if args.num_samples > 0: 68 | num_confs = args.num_samples 69 | else: 70 | num_confs = -args.num_samples*return_data.num_pos_ref.item() 71 | mol = generate_conformers(return_data.rdmol, num_confs=num_confs) 72 | num_pos_gen = mol.GetNumConformers() 73 | all_pos = [] 74 | 75 | if num_pos_gen == 0: 76 | continue 77 | 78 | for j in range(num_pos_gen): 79 | all_pos.append(torch.tensor(mol.GetConformer(j).GetPositions(), dtype=torch.float32)) 80 | 81 | return_data.pos_gen = torch.cat(all_pos, 0) # (num_pos_gen * num_node, 3) 82 | return_data.num_pos_gen = torch.tensor([len(all_pos)], dtype=torch.long) 83 | generated_data_list.append(return_data) 84 | 85 | with open(args.output, "wb") as fout: 86 | pickle.dump(generated_data_list, fout) 87 | print('save generated conf to %s done!' 
% args.output) 88 | 89 | 90 | 91 | 92 | if args.eval: 93 | print('start getting results!') 94 | 95 | with open(args.output, 'rb') as fin: 96 | data_list = pickle.load(fin) 97 | bad_case = 0 98 | 99 | 100 | filtered_data_list = [] 101 | for i in tqdm(range(len(data_list))): 102 | if '.' in data_list[i].smiles: 103 | bad_case += 1 104 | continue 105 | filtered_data_list.append(data_list[i]) 106 | 107 | cnt_conf = 0 108 | for i in range(len(filtered_data_list)): 109 | cnt_conf += filtered_data_list[i].num_pos_ref 110 | print('%d bad cases, use %d mols with total %d confs' % (bad_case, len(filtered_data_list), cnt_conf)) 111 | 112 | 113 | pool = multiprocessing.Pool(args.core) 114 | 115 | func = partial(utils.evaluate_conf, useFF=args.FF, threshold=args.threshold) 116 | 117 | 118 | covs = [] 119 | mats = [] 120 | for result in tqdm(pool.imap(func, filtered_data_list), total=len(filtered_data_list)): 121 | covs.append(result[0]) 122 | mats.append(result[1]) 123 | covs = np.array(covs) 124 | mats = np.array(mats) 125 | 126 | print('Coverage Mean: %.4f | Coverage Median: %.4f | Match Mean: %.4f | Match Median: %.4f' % \ 127 | (covs.mean(), np.median(covs), mats.mean(), np.median(mats))) 128 | pool.close() 129 | pool.join() 130 | 131 | 132 | -------------------------------------------------------------------------------- /confgen/script/get_task1_results.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | import os 3 | import sys 4 | project_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) 5 | print('project path is {}'.format(project_path)) 6 | sys.path.append(project_path) 7 | from time import time 8 | from tqdm import tqdm 9 | import argparse 10 | import numpy as np 11 | import pickle 12 | import pandas as pd 13 | from confgf import utils 14 | 15 | import multiprocessing 16 | from functools import partial 17 | 18 | 19 | if __name__ == '__main__': 20 | 21 | parser = 
argparse.ArgumentParser(description='confgf') 22 | parser.add_argument('--input', type=str) 23 | parser.add_argument('--core', type=int, default=6) 24 | parser.add_argument('--threshold', type=float, default=0.5, help='threshold of COV score') 25 | parser.add_argument('--FF', action='store_true', help='only for rdkit') 26 | 27 | args = parser.parse_args() 28 | print(args) 29 | 30 | with open(args.input, 'rb') as fin: 31 | data_list = pickle.load(fin) 32 | # assert len(data_list) == 200 33 | 34 | bad_case = 0 35 | filtered_data_list = [] 36 | for i in tqdm(range(len(data_list))): 37 | if '.' in data_list[i].smiles: 38 | bad_case += 1 39 | continue 40 | filtered_data_list.append(data_list[i]) 41 | 42 | cnt_conf = 0 43 | for i in range(len(filtered_data_list)): 44 | cnt_conf += filtered_data_list[i].num_pos_ref.item() 45 | print('%d bad cases, use %d mols with total %d confs' % (bad_case, len(filtered_data_list), cnt_conf)) 46 | 47 | pool = multiprocessing.Pool(args.core) 48 | func = partial(utils.evaluate_conf, useFF=args.FF, threshold=args.threshold) 49 | 50 | covs = [] 51 | mats = [] 52 | for result in tqdm(pool.imap(func, filtered_data_list), total=len(filtered_data_list)): 53 | covs.append(result[0]) 54 | mats.append(result[1]) 55 | covs = np.array(covs) 56 | mats = np.array(mats) 57 | 58 | print('Coverage Mean: %.4f | Coverage Median: %.4f | Match Mean: %.4f | Match Median: %.4f' % \ 59 | (covs.mean(), np.median(covs), mats.mean(), np.median(mats))) 60 | pool.close() 61 | pool.join() 62 | 63 | results = pd.DataFrame(columns=['ID', 'coverage', 'match']) 64 | results['ID'] = range(len(covs)) 65 | results['coverage'] = covs 66 | results['match'] = mats 67 | root_id = '/'.join(args.input.split('/')[:-1]) 68 | exp_id = args.input.split('/')[-1] 69 | exp_id = exp_id.split('.')[0] 70 | save_path = os.path.join(root_id, exp_id + '.csv') 71 | results.to_csv(save_path) 72 | -------------------------------------------------------------------------------- 
/confgen/script/get_task2_results.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | import os 3 | import sys 4 | project_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) 5 | print('project path is {}'.format(project_path)) 6 | sys.path.append(project_path) 7 | from time import time 8 | from tqdm import tqdm 9 | import argparse 10 | import numpy as np 11 | import random 12 | import math 13 | import json 14 | import pickle 15 | import yaml 16 | from easydict import EasyDict 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import torch.optim as optim 22 | from torch.autograd import Variable 23 | from torch_geometric.data import Data, Dataset 24 | from torch_geometric.transforms import Compose 25 | from confgf import utils, dataset 26 | 27 | import multiprocessing 28 | from functools import partial 29 | 30 | if __name__ == '__main__': 31 | 32 | 33 | multiprocessing.set_start_method('spawn', force=True) 34 | 35 | parser = argparse.ArgumentParser(description='confgf') 36 | parser.add_argument('--input', type=str) 37 | parser.add_argument('--core', type=int, default=1, help='path of dataset') 38 | args = parser.parse_args() 39 | print(args) 40 | 41 | with open(args.input, 'rb') as fin: 42 | data_list = pickle.load(fin) 43 | print(len(data_list)) 44 | 45 | 46 | bad_case = 0 47 | 48 | filtered_data_list = [] 49 | for i in tqdm(range(len(data_list))): 50 | if '.' 
in data_list[i].smiles: 51 | bad_case += 1 52 | continue 53 | # filter corrupted mols with #confs less than 1000 54 | if data_list[i].num_pos_ref < 1000: 55 | bad_case += 1 56 | continue 57 | filtered_data_list.append(data_list[i]) 58 | 59 | cnt_conf = 0 60 | for i in range(len(filtered_data_list)): 61 | cnt_conf += filtered_data_list[i].num_pos_ref 62 | print('%d bad cases, use %d mols with total %d confs' % (bad_case, len(filtered_data_list), cnt_conf)) 63 | 64 | pool = multiprocessing.Pool(args.core) 65 | func = partial(utils.evaluate_distance, ignore_H=True) 66 | 67 | 68 | s_mmd_all = [] 69 | p_mmd_all = [] 70 | a_mmd_all = [] 71 | 72 | 73 | for result in tqdm(pool.imap(func, filtered_data_list), total=len(filtered_data_list)): 74 | stats_single, stats_pair, stats_all = result 75 | s_mmd_all += [e['mmd'] for e in stats_single] 76 | p_mmd_all += [e['mmd'] for e in stats_pair] 77 | a_mmd_all.append(stats_all['mmd']) 78 | 79 | print('SingleDist | Mean: %.4f | Median: %.4f | Min: %.4f | Max: %.4f' % \ 80 | (np.mean(s_mmd_all), np.median(s_mmd_all), np.min(s_mmd_all), np.max(s_mmd_all))) 81 | print('PairDist | Mean: %.4f | Median: %.4f | Min: %.4f | Max: %.4f' % \ 82 | (np.mean(p_mmd_all), np.median(p_mmd_all), np.min(p_mmd_all), np.max(p_mmd_all))) 83 | print('AllDist | Mean: %.4f | Median: %.4f | Min: %.4f | Max: %.4f' % \ 84 | (np.mean(a_mmd_all), np.median(a_mmd_all), np.min(a_mmd_all), np.max(a_mmd_all))) 85 | 86 | 87 | pool.close() 88 | pool.join() 89 | -------------------------------------------------------------------------------- /confgen/script/process_GEOM_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pickle 4 | 5 | from confgf import dataset 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--base_path', type=str) 11 | parser.add_argument('--dataset_name', type=str, choices=['qm9', 'drugs']) 12 | 
parser.add_argument('--tot_mol_size', type=int, default=50000) 13 | parser.add_argument('--conf_per_mol', type=int, default=5) 14 | parser.add_argument('--train_size', type=float, default=0.8) 15 | parser.add_argument('--test_mol_size', type=int, default=200) 16 | parser.add_argument('--confmin', type=int, default=50) 17 | parser.add_argument('--confmax', type=int, default=500) 18 | args = parser.parse_args() 19 | rdkit_folder_path = os.path.join(args.base_path, 'rdkit_folder') 20 | 21 | train_data, val_data, test_data, index2split = dataset.preprocess_GEOM_dataset(rdkit_folder_path, args.dataset_name, \ 22 | conf_per_mol=args.conf_per_mol, train_size=args.train_size, \ 23 | tot_mol_size=args.tot_mol_size, seed=2021) 24 | 25 | processed_data_path = os.path.join(args.base_path, '%s_processed' % args.dataset_name) 26 | os.makedirs(processed_data_path, exist_ok=True) 27 | 28 | # save train and val data 29 | with open(os.path.join(processed_data_path, 'train_data_%dk.pkl' % ((len(train_data) // args.conf_per_mol) // 1000)), "wb") as fout: 30 | pickle.dump(train_data, fout) 31 | print('save train %dk done' % ((len(train_data) // args.conf_per_mol) // 1000)) 32 | 33 | with open(os.path.join(processed_data_path, 'val_data_%dk.pkl' % ((len(val_data) // args.conf_per_mol) // 1000)), "wb") as fout: 34 | pickle.dump(val_data, fout) 35 | print('save val %dk done' % ((len(val_data) // args.conf_per_mol) // 1000)) 36 | del test_data 37 | 38 | # filter test data 39 | test_data = dataset.get_GEOM_testset(rdkit_folder_path, args.dataset_name, block=[train_data, val_data], \ 40 | tot_mol_size=args.test_mol_size, seed=2021, \ 41 | confmin=args.confmin, confmax=args.confmax) 42 | with open(os.path.join(processed_data_path, 'test_data_%d.pkl' % (args.test_mol_size)), "wb") as fout: 43 | pickle.dump(test_data, fout) 44 | print('save test %d done' % (args.test_mol_size)) 45 | 46 | 47 | -------------------------------------------------------------------------------- 
/confgen/script/process_iso17_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pickle 4 | 5 | from confgf import dataset 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--input', type=str) 11 | args = parser.parse_args() 12 | base_path = args.input 13 | 14 | train_data, test_data = dataset.preprocess_iso17_dataset(base_path) 15 | 16 | with open(os.path.join(base_path, 'iso17_split-0_train_processed.pkl'), "wb") as fout: 17 | pickle.dump(train_data, fout) 18 | print('save train done') 19 | 20 | with open(os.path.join(base_path, 'iso17_split-0_test_processed.pkl'), "wb") as fout: 21 | pickle.dump(test_data, fout) 22 | print('save test done') 23 | 24 | 25 | -------------------------------------------------------------------------------- /confgen/script/train.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | import os 3 | import sys 4 | project_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) 5 | print('project path is {}'.format(project_path)) 6 | sys.path.append(project_path) 7 | import argparse 8 | import numpy as np 9 | import random 10 | import pickle 11 | import yaml 12 | from easydict import EasyDict 13 | 14 | import torch 15 | from confgf import models, dataset, runner, utils 16 | 17 | 18 | if __name__ == '__main__': 19 | 20 | parser = argparse.ArgumentParser(description='clofnet') 21 | parser.add_argument('--config_path', type=str, help='path of dataset', required=True) 22 | parser.add_argument('--seed', type=int, default=2021, help='overwrite config seed') 23 | 24 | args = parser.parse_args() 25 | with open(args.config_path, 'r') as f: 26 | config = yaml.safe_load(f) 27 | config = EasyDict(config) 28 | 29 | if args.seed != 2021: 30 | config.train.seed = args.seed 31 | 32 | if config.train.save and config.train.save_path is not None: 33 | 
config.train.save_path = os.path.join(config.train.save_path, config.train.Name) 34 | if not os.path.exists(config.train.save_path): 35 | os.makedirs(config.train.save_path) 36 | 37 | 38 | # check device 39 | gpus = list(filter(lambda x: x is not None, config.train.gpus)) 40 | assert torch.cuda.device_count() >= len(gpus), 'do you set the gpus in config correctly?' 41 | device = torch.device(gpus[0]) if len(gpus) > 0 else torch.device('cpu') 42 | print("Let's use", len(gpus), "GPUs!") 43 | print("Using device %s as main device" % device) 44 | config.train.device = device 45 | config.train.gpus = gpus 46 | 47 | print(config) 48 | 49 | # set random seed 50 | np.random.seed(config.train.seed) 51 | random.seed(config.train.seed) 52 | torch.manual_seed(config.train.seed) 53 | if torch.cuda.is_available(): 54 | torch.cuda.manual_seed(config.train.seed) 55 | torch.cuda.manual_seed_all(config.train.seed) 56 | torch.backends.cudnn.benchmark = True 57 | print('set seed for random, numpy and torch') 58 | 59 | 60 | load_path = os.path.join(config.data.base_path, '%s_processed' % config.data.dataset) 61 | print('loading data from %s' % load_path) 62 | 63 | train_data = [] 64 | val_data = [] 65 | test_data = [] 66 | 67 | if config.data.train_set is not None: 68 | with open(os.path.join(load_path, config.data.train_set), "rb") as fin: 69 | train_data = pickle.load(fin) 70 | if config.data.val_set is not None: 71 | with open(os.path.join(load_path, config.data.val_set), "rb") as fin: 72 | val_data = pickle.load(fin) 73 | print('train size : %d || val size: %d || test size: %d ' % (len(train_data), len(val_data), len(test_data))) 74 | print('loading data done!') 75 | 76 | transform = None 77 | train_data = dataset.GEOMDataset(data=train_data, transform=transform) 78 | val_data = dataset.GEOMDataset(data=val_data, transform=transform) 79 | test_data = dataset.GEOMDataset_PackedConf(data=test_data, transform=transform) 80 | 81 | model = models.EquiDistanceScoreMatch(config) 82 | 
optimizer = utils.get_optimizer(config.train.optimizer, model) 83 | scheduler = utils.get_scheduler(config.train.scheduler, optimizer) 84 | 85 | solver = runner.EquiRunner(train_data, val_data, test_data, model, optimizer, scheduler, gpus, config) 86 | if config.train.resume_train: 87 | solver.load(config.train.resume_checkpoint, epoch=config.train.resume_epoch, load_optimizer=True, load_scheduler=True) 88 | solver.train() 89 | 90 | 91 | -------------------------------------------------------------------------------- /main_newtonian.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import logging 5 | import argparse 6 | 7 | import torch 8 | from torch import nn, optim 9 | from newtonian.dataset4newton import NBodyDataset 10 | from newtonian.gnn import GNN, RF_vel 11 | from newtonian.egnn import EGNN, EGNN_vel 12 | from newtonian.clof import ClofNet, ClofNet_vel, ClofNet_vel_gbf 13 | 14 | parser = argparse.ArgumentParser(description='VAE MNIST Example') 15 | parser.add_argument('--exp_name', type=str, default='exp_1', metavar='N', help='experiment_name') 16 | parser.add_argument('--batch_size', type=int, default=128, metavar='N', 17 | help='input batch size for training (default: 128)') 18 | parser.add_argument('--epochs', type=int, default=10000, metavar='N', 19 | help='number of epochs to train (default: 10)') 20 | parser.add_argument('--no-cuda', action='store_true', default=False, 21 | help='enables CUDA training') 22 | parser.add_argument('--seed', type=int, default=1, metavar='S', 23 | help='random seed (default: 1)') 24 | parser.add_argument('--log_interval', type=int, default=1, metavar='N', 25 | help='how many batches to wait before logging training status') 26 | parser.add_argument('--test_interval', type=int, default=5, metavar='N', 27 | help='how many epochs to wait before logging test') 28 | parser.add_argument('--outf', type=str, default='saved', metavar='N', 29 | 
help='folder to output') 30 | parser.add_argument('--data_mode', type=str, default='small', metavar='N', 31 | help='folder to dataset') 32 | parser.add_argument('--data_root', type=str, default='dataset/clofnet_dataset', metavar='N', 33 | help='folder to dataset root') 34 | parser.add_argument('--lr', type=float, default=5e-4, metavar='N', 35 | help='learning rate') 36 | parser.add_argument('--nf', type=int, default=64, metavar='N', 37 | help='learning rate') 38 | parser.add_argument('--model', type=str, default='egnn_vel', metavar='N', 39 | help='available models: gnn, baseline, linear, linear_vel, se3_transformer, egnn_vel, rf_vel, tfn') 40 | parser.add_argument('--attention', type=int, default=0, metavar='N', 41 | help='attention in the ae model') 42 | parser.add_argument('--n_layers', type=int, default=4, metavar='N', 43 | help='number of layers for the autoencoder') 44 | parser.add_argument('--degree', type=int, default=2, metavar='N', 45 | help='degree of the TFN and SE3') 46 | parser.add_argument('--max_training_samples', type=int, default=3000, metavar='N', 47 | help='maximum amount of training samples') 48 | parser.add_argument('--sweep_training', type=int, default=0, metavar='N', 49 | help='0 nor sweep, 1 sweep, 2 sweep small') 50 | parser.add_argument('--time_exp', type=int, default=0, metavar='N', 51 | help='timing experiment') 52 | parser.add_argument('--weight_decay', type=float, default=1e-12, metavar='N', 53 | help='timing experiment') 54 | parser.add_argument('--div', type=float, default=1, metavar='N', 55 | help='timing experiment') 56 | parser.add_argument('--norm_diff', type=eval, default=False, metavar='N', 57 | help='normalize_diff') 58 | parser.add_argument('--tanh', type=eval, default=False, metavar='N', 59 | help='use tanh') 60 | parser.add_argument('--LR_decay', type=eval, default=False, metavar='N', 61 | help='LR_decay') 62 | parser.add_argument('--decay', type=float, default=0.1, metavar='N', 63 | help='learning rate decay') 64 | 65 | 
def get_velocity_attr(loc, vel, rows, cols):
    """Project each sender node's velocity onto the unit vector pointing
    from the sender to the receiver of every edge.

    Args:
        loc: [N, D] node positions.
        vel: [N, D] node velocities.
        rows: [E] sender node indices (one per edge).
        cols: [E] receiver node indices (one per edge).

    Returns:
        [E, 1] tensor: for each edge, <vel[row], (loc[col] - loc[row]) / ||.||>.

    NOTE(review): if two nodes coincide the norm is zero and the result is
    NaN/inf — presumably the simulated data never has coincident nodes;
    confirm against the dataset generator.
    """
    diff = loc[cols] - loc[rows]
    norm = torch.norm(diff, p=2, dim=1).unsqueeze(1)
    u = diff / norm
    # The original also computed the receiver-side projection (vb) and then
    # discarded it; that dead computation is dropped here.
    va = torch.sum(vel[rows] * u, dim=1).unsqueeze(1)
    return va
shuffle=False, drop_last=False) 115 | 116 | 117 | if args.model == 'gnn': 118 | model = GNN(input_dim=6, hidden_nf=args.nf, n_layers=args.n_layers, device=device, recurrent=True) 119 | elif args.model == 'egnn': 120 | model = EGNN(in_node_nf=1, in_edge_nf=2, hidden_nf=args.nf, device='cpu', n_layers=args.n_layers) 121 | elif args.model == 'egnn_vel': 122 | model = EGNN_vel(in_node_nf=1, in_edge_nf=2, hidden_nf=args.nf, device=device, n_layers=args.n_layers, recurrent=True, norm_diff=args.norm_diff, tanh=args.tanh) 123 | elif args.model == 'rf_vel': 124 | model = RF_vel(hidden_nf=args.nf, edge_attr_nf=2, device=device, act_fn=nn.SiLU(), n_layers=args.n_layers) 125 | elif args.model == 'clof': 126 | model = ClofNet(in_node_nf=1, in_edge_nf=2, hidden_nf=args.nf, n_layers=args.n_layers, device=device, recurrent=True, norm_diff=args.norm_diff, tanh=args.tanh) 127 | elif args.model == 'clof_vel': 128 | model = ClofNet_vel(in_node_nf=1, in_edge_nf=2, hidden_nf=args.nf, n_layers=args.n_layers, device=device, recurrent=True, norm_diff=args.norm_diff, tanh=args.tanh) 129 | elif args.model == 'clof_vel_gbf': 130 | model = ClofNet_vel_gbf(in_node_nf=1, in_edge_nf=2, hidden_nf=args.nf, n_layers=args.n_layers, device=device, recurrent=True, norm_diff=args.norm_diff, tanh=args.tanh) 131 | else: 132 | raise Exception("Wrong model specified") 133 | 134 | logging.info(args) 135 | print(model) 136 | logging.info(model) 137 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 138 | step_size = int(args.epochs // 8) 139 | if args.LR_decay: 140 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=args.decay, last_epoch=-1) 141 | 142 | results = {'epochs': [], 'losess': []} 143 | best_val_loss = 1e8 144 | best_test_loss = 1e8 145 | best_epoch = 0 146 | for epoch in range(0, args.epochs): 147 | loss = train(model, optimizer, epoch, loader_train) 148 | if args.LR_decay: 149 | scheduler.step() 150 | if epoch % args.test_interval == 
def train(model, optimizer, epoch, loader, backprop=True):
    """Run one full pass over ``loader``.

    With ``backprop=True`` this trains the model; with ``backprop=False`` the
    same loop doubles as the validation/test pass (no parameter updates and
    no autograd graph).

    Args:
        model: the network selected by ``args.model``.
        optimizer: the optimizer stepping ``model``'s parameters.
        epoch: current epoch index (for logging only).
        loader: DataLoader yielding (loc, vel, edge_attr, charges, loc_end).
        backprop: whether to backpropagate and step the optimizer.

    Returns:
        Average per-sample MSE between predicted and target positions.

    Relies on module-level globals: ``args``, ``device``, ``loss_mse``,
    ``time_exp_dic``.
    NOTE(review): raises ZeroDivisionError on an empty loader — presumably
    never happens with drop_last/train splits; confirm.
    """
    if backprop:
        model.train()
    else:
        model.eval()

    res = {'epoch': epoch, 'loss': 0, 'coord_reg': 0, 'counter': 0}

    for batch_idx, data in enumerate(loader):
        batch_size, n_nodes, _ = data[0].size()
        data = [d.to(device) for d in data]
        # Flatten (batch, n_nodes, feat) -> (batch * n_nodes, feat).
        data = [d.view(-1, d.size(2)) for d in data]
        loc, vel, edge_attr, charges, loc_end = data  # charges unused by all models below

        edges = loader.dataset.get_edges(batch_size, n_nodes)
        edges = [edges[0].to(device), edges[1].to(device)]

        optimizer.zero_grad()

        if args.time_exp:
            # Fix: torch.cuda.synchronize() raises on CPU-only machines;
            # only synchronize when CUDA is actually in use.
            if args.cuda:
                torch.cuda.synchronize()
            t1 = time.time()

        # Fix: skip autograd bookkeeping entirely on eval passes; the loss
        # values are unchanged but memory and time are saved.
        with torch.set_grad_enabled(backprop):
            if args.model == 'gnn':
                nodes = torch.cat([loc, vel], dim=1)
                loc_pred = model(nodes, edges, edge_attr)
            elif args.model == 'egnn':
                nodes = torch.ones(loc.size(0), 1).to(device)  # all input nodes are set to 1
                rows, cols = edges
                loc_dist = torch.sum((loc[rows] - loc[cols])**2, 1).unsqueeze(1)  # squared relative distances
                vel_attr = get_velocity_attr(loc, vel, rows, cols).detach()
                edge_attr = torch.cat([edge_attr, loc_dist, vel_attr], 1).detach()  # concatenate all edge properties
                loc_pred = model(nodes, loc.detach(), edges, edge_attr)
            elif args.model == 'egnn_vel':
                nodes = torch.sqrt(torch.sum(vel ** 2, dim=1)).unsqueeze(1).detach()
                rows, cols = edges
                loc_dist = torch.sum((loc[rows] - loc[cols])**2, 1).unsqueeze(1)
                edge_attr = torch.cat([edge_attr, loc_dist], 1).detach()
                loc_pred = model(nodes, loc.detach(), edges, vel, edge_attr)
            elif args.model == 'rf_vel':
                rows, cols = edges
                vel_norm = torch.sqrt(torch.sum(vel ** 2, dim=1).unsqueeze(1)).detach()
                loc_dist = torch.sum((loc[rows] - loc[cols]) ** 2, 1).unsqueeze(1)
                edge_attr = torch.cat([edge_attr, loc_dist], 1).detach()
                loc_pred = model(vel_norm, loc.detach(), edges, vel, edge_attr)
            elif args.model in ['clof', 'clof_vel', 'clof_vel_gbf']:
                nodes = torch.sqrt(torch.sum(vel ** 2, dim=1)).unsqueeze(1).detach()
                rows, cols = edges
                loc_dist = torch.sum((loc[rows] - loc[cols])**2, 1).unsqueeze(1)
                edge_attr = torch.cat([edge_attr, loc_dist], 1).detach()
                loc_pred = model(nodes, loc.detach(), edges, vel, edge_attr, n_nodes=n_nodes)
            else:
                raise Exception("Wrong model")

            if args.time_exp:
                if args.cuda:
                    torch.cuda.synchronize()
                t2 = time.time()
                time_exp_dic['time'] += t2 - t1
                time_exp_dic['counter'] += 1
                print("Forward average time: %.6f" % (time_exp_dic['time'] / time_exp_dic['counter']))
                logging.info("Forward average time: %.6f" % (time_exp_dic['time'] / time_exp_dic['counter']))
            loss = loss_mse(loc_pred, loc_end)

        if backprop:
            loss.backward()
            optimizer.step()
        res['loss'] += loss.item() * batch_size
        res['counter'] += batch_size
        if batch_idx % args.log_interval == 0 and (args.model == "se3_transformer" or args.model == "tfn"):
            print('===> {} Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t'.format(loader.dataset.partition,
                epoch, batch_idx * batch_size, len(loader.dataset),
                100. * batch_idx / len(loader),
                loss.item()))
            logging.info('===> {} Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t'.format(loader.dataset.partition,
                epoch, batch_idx * batch_size, len(loader.dataset),
                100. * batch_idx / len(loader),
                loss.item()))

    if not backprop:
        prefix = "==> "
    else:
        prefix = ""
    print('%s epoch %d avg loss: %.5f LR: %.6f' % (prefix+loader.dataset.partition, epoch, res['loss'] / res['counter'], optimizer.param_groups[0]['lr']))
    logging.info('%s epoch %d avg loss: %.5f LR: %.6f' % (prefix+loader.dataset.partition, epoch, res['loss'] / res['counter'], optimizer.param_groups[0]['lr']))

    return res['loss'] / res['counter']
class MLP(nn.Module):
    """Four-layer perceptron: three LeakyReLU(0.2)-activated hidden layers of
    width ``nh``, followed by a linear projection to ``nout``."""

    def __init__(self, nin, nout, nh):
        super().__init__()
        # Assemble the stack incrementally; layer order matches a plain
        # Linear/LeakyReLU alternation ending in an unactivated projection.
        stack = [nn.Linear(nin, nh)]
        for _ in range(2):
            stack += [nn.LeakyReLU(0.2), nn.Linear(nh, nh)]
        stack += [nn.LeakyReLU(0.2), nn.Linear(nh, nout)]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the MLP to a batch of feature vectors."""
        return self.net(x)
class GCL(GCL_basic):
    """Plain (non-equivariant) graph convolution layer.

    Edge messages are an MLP over [h_i, h_j, edge_attr]; node updates are an
    MLP over [h_i, sum of incoming messages]. Optionally adds a residual
    connection (``recurrent``) and a sigmoid gate on the messages computed
    from |h_i - h_j| (``attention``).
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_nf=0, act_fn=nn.ReLU(), bias=True, attention=False, t_eq=False, recurrent=True):
        super().__init__()
        self.attention = attention
        self.t_eq = t_eq  # kept for interface compatibility; not read below
        self.recurrent = recurrent
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * input_nf + edges_in_nf, hidden_nf, bias=bias),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf, bias=bias),
            act_fn)
        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(input_nf, hidden_nf, bias=bias),
                act_fn,
                nn.Linear(hidden_nf, 1, bias=bias),
                nn.Sigmoid())
        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf, hidden_nf, bias=bias),
            act_fn,
            nn.Linear(hidden_nf, output_nf, bias=bias))

    def edge_model(self, source, target, edge_attr):
        """Per-edge message from the two endpoint features (+ optional edge attributes)."""
        pieces = [source, target] if edge_attr is None else [source, target, edge_attr]
        msg = self.edge_mlp(torch.cat(pieces, dim=1))
        if self.attention:
            gate = self.att_mlp(torch.abs(source - target))
            msg = msg * gate
        return msg

    def node_model(self, h, edge_index, edge_attr):
        """Update node features from the sum of messages on outgoing rows."""
        senders, _ = edge_index
        pooled = unsorted_segment_sum(edge_attr, senders, num_segments=h.size(0))
        updated = self.node_mlp(torch.cat([h, pooled], dim=1))
        return updated + h if self.recurrent else updated
class E_GCL(nn.Module):
    """E(n)-equivariant graph convolution layer (EGNN style).

    Invariant node features are updated by message passing over
    [h_i, h_j, ||x_i - x_j||^2, edge_attr]; coordinates are moved along the
    (optionally normalized) relative-position vectors, scaled by a learned
    per-edge coefficient whose final linear layer starts near zero.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, act_fn=nn.ReLU(), recurrent=True, coords_weight=1.0, attention=False, clamp=False, norm_diff=False, tanh=False, out_basis_dim=1):
        super(E_GCL, self).__init__()
        self.coords_weight = coords_weight
        self.recurrent = recurrent
        self.attention = attention
        self.norm_diff = norm_diff
        self.tanh = tanh

        # +1 edge scalar: the squared distance appended to each message input.
        self.edge_mlp = nn.Sequential(
            nn.Linear(input_nf * 2 + 1 + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        # Tiny-gain init keeps initial coordinate updates near zero.
        head = nn.Linear(hidden_nf, out_basis_dim, bias=False)
        torch.nn.init.xavier_uniform_(head.weight, gain=0.001)

        self.clamp = clamp
        stages = [nn.Linear(hidden_nf, hidden_nf), act_fn, head]
        if self.tanh:
            stages.append(nn.Tanh())
            # NOTE(review): Parameter * 3 produces a plain tensor, so
            # coords_range is NOT a registered learnable parameter; behavior
            # preserved as in the original.
            self.coords_range = nn.Parameter(torch.ones(1)) * 3
        self.coord_mlp = nn.Sequential(*stages)

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

    def edge_model(self, source, target, radial, edge_attr):
        """Per-edge message from endpoint features, squared distance, and edge attributes."""
        feats = [source, target, radial]
        if edge_attr is not None:
            feats.append(edge_attr)
        msg = self.edge_mlp(torch.cat(feats, dim=1))
        if self.attention:
            msg = msg * self.att_mlp(msg)
        return msg

    def node_model(self, x, edge_index, edge_attr, node_attr):
        """Update invariant node features from summed messages; returns (h, concat input)."""
        senders, _ = edge_index
        pooled = unsorted_segment_sum(edge_attr, senders, num_segments=x.size(0))
        if node_attr is not None:
            agg = torch.cat([x, pooled, node_attr], dim=1)
        else:
            agg = torch.cat([x, pooled], dim=1)
        out = self.node_mlp(agg)
        if self.recurrent:
            out = x + out
        return out, agg

    def coord_model(self, coord, edge_index, coord_diff, edge_feat):
        """Shift coordinates (in place) along relative-position vectors."""
        senders, _ = edge_index
        shift = coord_diff * self.coord_mlp(edge_feat)
        # Safety clamp: normally inactive, but caps exploding updates so one
        # bad step cannot destroy training.
        shift = torch.clamp(shift, min=-100, max=100)
        coord += unsorted_segment_mean(shift, senders, num_segments=coord.size(0)) * self.coords_weight
        return coord

    def coord2radial(self, edge_index, coord):
        """Squared edge lengths and (optionally normalized) per-edge relative positions."""
        senders, receivers = edge_index
        coord_diff = coord[senders] - coord[receivers]
        radial = torch.sum(coord_diff ** 2, 1).unsqueeze(1)
        if self.norm_diff:
            # +1 in the denominator avoids division by zero for coincident nodes.
            coord_diff = coord_diff / (torch.sqrt(radial) + 1)
        return radial, coord_diff

    def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None):
        """One message-passing step; returns (h, coord, edge_attr)."""
        senders, receivers = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)
        messages = self.edge_model(h[senders], h[receivers], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, messages)
        h, _ = self.node_model(h, edge_index, messages, node_attr)
        # Note: the INPUT edge_attr is returned, not the computed messages.
        return h, coord, edge_attr
class E_GCL_vel(E_GCL):
    """E_GCL variant that additionally transports coordinates along each
    node's velocity, scaled by a learned scalar function of the node features."""

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, act_fn=nn.ReLU(), recurrent=True, coords_weight=1.0, attention=False, norm_diff=False, tanh=False):
        super(E_GCL_vel, self).__init__(
            input_nf, output_nf, hidden_nf,
            edges_in_d=edges_in_d, nodes_att_dim=nodes_att_dim, act_fn=act_fn,
            recurrent=recurrent, coords_weight=coords_weight, attention=attention,
            norm_diff=norm_diff, tanh=tanh)
        self.norm_diff = norm_diff
        self.coord_mlp_vel = nn.Sequential(
            nn.Linear(input_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, 1))

    def forward(self, h, edge_index, coord, vel, edge_attr=None, node_attr=None):
        """One step: equivariant coordinate update plus a velocity term."""
        senders, receivers = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)
        messages = self.edge_model(h[senders], h[receivers], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, messages)
        # Velocity transport: x <- x + phi(h) * v (in place).
        coord += self.coord_mlp_vel(h) * vel
        h, _ = self.node_model(h, edge_index, messages, node_attr)
        return h, coord, edge_attr
class Clof_GCL(E_GCL):
    """Basic message-passing module of ClofNet.

    Extends E_GCL with a per-edge local frame
    (coord_diff, coord_cross, coord_vertical); the coordinate update is a
    learned linear combination of the three frame axes (out_basis_dim=3),
    plus a velocity term; node features get an outer residual and LayerNorm.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, act_fn=nn.ReLU(), recurrent=True, coords_weight=1.0, attention=False, norm_diff=False, tanh=False):
        # out_basis_dim=3: one learned coefficient per local-frame axis.
        E_GCL.__init__(self, input_nf, output_nf, hidden_nf, edges_in_d=edges_in_d, nodes_att_dim=nodes_att_dim, act_fn=act_fn, recurrent=recurrent, coords_weight=coords_weight, attention=attention, norm_diff=norm_diff, tanh=tanh, out_basis_dim=3)
        self.norm_diff = norm_diff
        self.coord_mlp_vel = nn.Sequential(
            nn.Linear(input_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, 1))

        # Deeper edge MLP than the parent's (three hidden layers).
        self.edge_mlp = nn.Sequential(
            nn.Linear(input_nf * 2 + 1 + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)
        self.layer_norm = nn.LayerNorm(hidden_nf)

    def coord2localframe(self, edge_index, coord):
        """Build the per-edge local frame.

        Returns (radial, coord_diff, coord_cross, coord_vertical): the squared
        edge length plus three frame vectors (normalized when norm_diff=True).
        """
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum(coord_diff ** 2, 1).unsqueeze(1)
        # Fix: pass dim=1 explicitly. Relying on torch.cross's implicit
        # dimension is deprecated and silently operates on dim 0 whenever the
        # number of edges happens to equal 3.
        coord_cross = torch.cross(coord[row], coord[col], dim=1)
        if self.norm_diff:
            norm = torch.sqrt(radial) + 1  # +1 avoids division by zero
            coord_diff = coord_diff / norm
            cross_norm = (torch.sqrt(torch.sum(coord_cross ** 2, 1).unsqueeze(1))) + 1
            coord_cross = coord_cross / cross_norm
        coord_vertical = torch.cross(coord_diff, coord_cross, dim=1)
        return radial, coord_diff, coord_cross, coord_vertical

    def coord_model(self, coord, edge_index, coord_diff, coord_cross, coord_vertical, edge_feat):
        """Update coordinates (in place) by a learned combination of the frame axes."""
        row, col = edge_index
        coff = self.coord_mlp(edge_feat)  # [E, 3]: one weight per axis
        trans = coord_diff * coff[:, :1] + coord_cross * coff[:, 1:2] + coord_vertical * coff[:, 2:3]
        # Safety clamp: normally inactive, but caps exploding updates so one
        # bad step cannot destroy training.
        trans = torch.clamp(trans, min=-100, max=100)
        agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
        coord += agg * self.coords_weight
        return coord

    def forward(self, h, edge_index, coord, vel, edge_attr=None, node_attr=None):
        """One message-passing step; returns (h, coord, edge_attr)."""
        row, col = edge_index
        residue = h
        # h = self.layer_norm(h)
        radial, coord_diff, coord_cross, coord_vertical = self.coord2localframe(edge_index, coord)
        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, coord_cross, coord_vertical, edge_feat)
        # Velocity transport term (in place).
        coord += self.coord_mlp_vel(h) * vel
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
        # NOTE(review): when recurrent=True, node_model already adds h, so
        # this is a double residual; preserved since training depends on it.
        h = residue + h
        h = self.layer_norm(h)
        return h, coord, edge_attr
def unsorted_segment_sum(data, segment_ids, num_segments):
    """Row-wise segmented sum (replicates TensorFlow's `unsorted_segment_sum`).

    Args:
        data: [E, F] tensor of per-edge values.
        segment_ids: [E] integer tensor with values in [0, num_segments).
        num_segments: number of output rows.

    Returns:
        [num_segments, F] tensor whose row s is the sum of all rows of `data`
        with id s; rows for empty segments are zero.
    """
    index = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    out = data.new_full((num_segments, data.size(1)), 0)
    out.scatter_add_(0, index, data)
    return out
def unsorted_segment_mean(data, segment_ids, num_segments):
    """Row-wise segmented mean of `data` grouped by `segment_ids`.

    Rows of the [num_segments, F] result are the per-segment averages;
    empty segments come back as zeros (the divisor is clamped to 1).
    """
    index = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    totals = data.new_full((num_segments, data.size(1)), 0)
    counts = data.new_full((num_segments, data.size(1)), 0)
    totals.scatter_add_(0, index, data)
    counts.scatter_add_(0, index, torch.ones_like(data))
    return totals / counts.clamp(min=1)
class ClofNet(nn.Module):
    """ClofNet for the N-body task: stacked Clof_GCL layers operating on
    centroid-centered coordinates, with per-edge local-frame scalarization
    of the geometry for rotation-invariant edge features."""

    def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4,
                 coords_weight=1.0, recurrent=True, norm_diff=True, tanh=False,
                 ):
        super(ClofNet, self).__init__()
        self.hidden_nf = hidden_nf
        self.device = device
        self.n_layers = n_layers
        self.embedding_node = nn.Linear(in_node_nf, self.hidden_nf)
        self.embedding_edge = nn.Sequential(nn.Linear(in_edge_nf, 8), act_fn)

        # 2 (pseudo sin/cos) + 3 + 3 (scalarized endpoint coords) + raw edge
        # attrs — presumably in_edge_nf = 2 here; confirm against callers.
        edge_embed_dim = 10
        self.fuse_edge = nn.Sequential(
            nn.Linear(edge_embed_dim, self.hidden_nf // 2), act_fn,
            nn.Linear(self.hidden_nf // 2, self.hidden_nf // 2), act_fn)

        self.norm_diff = norm_diff
        for layer_idx in range(self.n_layers):
            self.add_module(
                "gcl_%d" % layer_idx,
                Clof_GCL(
                    input_nf=self.hidden_nf,
                    output_nf=self.hidden_nf,
                    hidden_nf=self.hidden_nf,
                    edges_in_d=self.hidden_nf // 2,
                    act_fn=act_fn,
                    recurrent=recurrent,
                    coords_weight=coords_weight,
                    norm_diff=norm_diff,
                    tanh=tanh,
                ),
            )
        self.to(self.device)
        self.params = self.__str__()

    def __str__(self):
        """Print/log the trainable-parameter count and return it as a string."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        params = sum(np.prod(p.size()) for p in trainable)
        print('Network Size', params)
        logging.info('Network Size {}'.format(params))
        return str(params)

    def coord2localframe(self, edge_index, coord):
        """Per-edge local frame; each vector is returned with shape [E, 1, 3]."""
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum(coord_diff ** 2, 1).unsqueeze(1)
        coord_cross = torch.cross(coord[row], coord[col])
        if self.norm_diff:
            coord_diff = coord_diff / (torch.sqrt(radial) + 1)
            cross_norm = torch.sqrt(torch.sum(coord_cross ** 2, 1).unsqueeze(1)) + 1
            coord_cross = coord_cross / cross_norm
        coord_vertical = torch.cross(coord_diff, coord_cross)
        return coord_diff.unsqueeze(1), coord_cross.unsqueeze(1), coord_vertical.unsqueeze(1)

    def scalarization(self, edges, x):
        """Project endpoint coordinates into each edge's local frame and derive
        rotation-invariant angle features; returns an [E, 8] feature tensor."""
        coord_diff, coord_cross, coord_vertical = self.coord2localframe(edges, x)
        row, col = edges
        edge_basis = torch.cat([coord_diff, coord_cross, coord_vertical], dim=1)  # [E, 3, 3]
        coff_i = torch.matmul(edge_basis, x[row].unsqueeze(-1)).squeeze(-1)
        coff_j = torch.matmul(edge_basis, x[col].unsqueeze(-1)).squeeze(-1)
        # Angle information between the two projected vectors.
        coff_mul = coff_i * coff_j  # [E, 3]
        coff_i_norm = coff_i.norm(dim=-1, keepdim=True) + 1e-5
        coff_j_norm = coff_j.norm(dim=-1, keepdim=True) + 1e-5
        pseudo_cos = coff_mul.sum(dim=-1, keepdim=True) / coff_i_norm / coff_j_norm
        pseudo_sin = torch.sqrt(1 - pseudo_cos ** 2)
        pseudo_angle = torch.cat([pseudo_sin, pseudo_cos], dim=-1)
        return torch.cat([pseudo_angle, coff_i, coff_j], dim=-1)

    def forward(self, h, x, edges, vel, edge_attr, node_attr=None, n_nodes=5):
        """Predict updated coordinates; x/vel are flattened [B*n_nodes, 3]."""
        h = self.embedding_node(h)
        x = x.reshape(-1, n_nodes, 3)
        centroid = torch.mean(x, dim=1, keepdim=True)
        # Work in centroid-centered coordinates for translation equivariance.
        x_center = (x - centroid).reshape(-1, 3)
        coff_feat = self.scalarization(edges, x_center)
        edge_feat = self.fuse_edge(torch.cat([edge_attr, coff_feat], dim=-1))

        for layer_idx in range(self.n_layers):
            h, x_center, _ = self._modules["gcl_%d" % layer_idx](
                h, edges, x_center, vel, edge_attr=edge_feat, node_attr=node_attr)

        # Restore the centroid and flatten back to [B*n_nodes, 3].
        x = x_center.reshape(-1, n_nodes, 3) + centroid
        return x.reshape(-1, 3)
coord[col] 145 | radial = torch.sum((coord_diff)**2, 1).unsqueeze(1) 146 | coord_cross = torch.cross(coord[row], coord[col]) 147 | if self.norm_diff: 148 | norm = torch.sqrt(radial) + 1 149 | coord_diff = coord_diff / norm 150 | cross_norm = (torch.sqrt( 151 | torch.sum((coord_cross)**2, 1).unsqueeze(1))) + 1 152 | coord_cross = coord_cross / cross_norm 153 | coord_vertical = torch.cross(coord_diff, coord_cross) 154 | return coord_diff.unsqueeze(1), coord_cross.unsqueeze(1), coord_vertical.unsqueeze(1) 155 | 156 | def scalarization(self, edges, x, vel): 157 | coord_diff, coord_cross, coord_vertical = self.coord2localframe(edges, x) 158 | # Geometric Vectors Scalarization 159 | row, col = edges 160 | edge_basis = torch.cat([coord_diff, coord_cross, coord_vertical], dim=1) 161 | r_i = x[row] 162 | r_j = x[col] 163 | v_i = vel[row] 164 | v_j = vel[col] 165 | coff_i = torch.matmul(edge_basis, 166 | r_i.unsqueeze(-1)).squeeze(-1) 167 | coff_j = torch.matmul(edge_basis, 168 | r_j.unsqueeze(-1)).squeeze(-1) 169 | vel_i = torch.matmul(edge_basis, 170 | v_i.unsqueeze(-1)).squeeze(-1) 171 | vel_j = torch.matmul(edge_basis, 172 | v_j.unsqueeze(-1)).squeeze(-1) 173 | # Calculate angle information in local frames 174 | coff_mul = coff_i * coff_j # [E, 3] 175 | coff_i_norm = coff_i.norm(dim=-1, keepdim=True) 176 | coff_j_norm = coff_j.norm(dim=-1, keepdim=True) 177 | pesudo_cos = coff_mul.sum( 178 | dim=-1, keepdim=True) / (coff_i_norm + 1e-5) / (coff_j_norm + 1e-5) 179 | pesudo_sin = torch.sqrt(1 - pesudo_cos**2) 180 | pesudo_angle = torch.cat([pesudo_sin, pesudo_cos], dim=-1) 181 | coff_feat = torch.cat([pesudo_angle, coff_i, coff_j, vel_i, vel_j], 182 | dim=-1) #[E, 14] 183 | return coff_feat 184 | 185 | def forward(self, h, x, edges, vel, edge_attr, node_attr=None, n_nodes=5): 186 | h = self.embedding_node(h) 187 | x = x.reshape(-1, n_nodes, 3) 188 | centroid = torch.mean(x, dim=1, keepdim=True) 189 | x_center = (x - centroid).reshape(-1, 3) 190 | 191 | coff_feat = 
self.scalarization(edges, x_center, vel) 192 | edge_feat = torch.cat([edge_attr, coff_feat], dim=-1) 193 | edge_feat = self.fuse_edge(edge_feat) 194 | 195 | for i in range(0, self.n_layers): 196 | h, x_center, _ = self._modules["gcl_%d" % i]( 197 | h, edges, x_center, vel, edge_attr=edge_feat, node_attr=node_attr) 198 | 199 | x = x_center.reshape(-1, n_nodes, 3) + centroid 200 | x = x.reshape(-1, 3) 201 | return x 202 | 203 | 204 | class ClofNet_vel_gbf(nn.Module): 205 | def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, 206 | coords_weight=1.0, recurrent=True, norm_diff=True, tanh=False, 207 | ): 208 | super(ClofNet_vel_gbf, self).__init__() 209 | self.hidden_nf = hidden_nf 210 | self.device = device 211 | self.n_layers = n_layers 212 | self.embedding_node = nn.Linear(in_node_nf, self.hidden_nf) 213 | self.gbf = GaussianLayer(K=self.hidden_nf // 2, edge_types=8) 214 | edge_embed_dim = 14 215 | self.fuse_edge = nn.Sequential( 216 | nn.Linear(edge_embed_dim, self.hidden_nf // 2), act_fn, 217 | nn.Linear(self.hidden_nf // 2, self.hidden_nf // 2), act_fn) 218 | 219 | self.norm_diff = True 220 | for i in range(0, self.n_layers): 221 | self.add_module( 222 | "gcl_%d" % i, 223 | Clof_GCL( 224 | input_nf=self.hidden_nf, 225 | output_nf=self.hidden_nf, 226 | hidden_nf=self.hidden_nf, 227 | edges_in_d=self.hidden_nf // 2, 228 | act_fn=act_fn, 229 | recurrent=recurrent, 230 | coords_weight=coords_weight, 231 | norm_diff=norm_diff, 232 | tanh=tanh, 233 | ), 234 | ) 235 | self.to(self.device) 236 | self.params = self.__str__() 237 | 238 | def __str__(self): 239 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 240 | params = sum([np.prod(p.size()) for p in model_parameters]) 241 | print('Network Size', params) 242 | logging.info('Network Size {}'.format(params)) 243 | return str(params) 244 | 245 | def coord2localframe(self, edge_index, coord): 246 | row, col = edge_index 247 | coord_diff = coord[row] - 
coord[col] 248 | radial = torch.sum((coord_diff)**2, 1).unsqueeze(1) 249 | coord_cross = torch.cross(coord[row], coord[col]) 250 | if self.norm_diff: 251 | norm = torch.sqrt(radial) + 1 252 | coord_diff = coord_diff / norm 253 | cross_norm = (torch.sqrt( 254 | torch.sum((coord_cross)**2, 1).unsqueeze(1))) + 1 255 | coord_cross = coord_cross / cross_norm 256 | coord_vertical = torch.cross(coord_diff, coord_cross) 257 | return coord_diff.unsqueeze(1), coord_cross.unsqueeze(1), coord_vertical.unsqueeze(1) 258 | 259 | def embed_edge(self, edge_types, dist): 260 | edge_types = edge_types * 0.5 + 0.5 261 | return self.gbf(dist, edge_types.long()) 262 | 263 | def scalarization(self, edges, x, vel): 264 | coord_diff, coord_cross, coord_vertical = self.coord2localframe(edges, x) 265 | # Geometric Vectors Scalarization 266 | row, col = edges 267 | edge_basis = torch.cat([coord_diff, coord_cross, coord_vertical], dim=1) 268 | r_i = x[row] 269 | r_j = x[col] 270 | v_i = vel[row] 271 | v_j = vel[col] 272 | coff_i = torch.matmul(edge_basis, 273 | r_i.unsqueeze(-1)).squeeze(-1) 274 | coff_j = torch.matmul(edge_basis, 275 | r_j.unsqueeze(-1)).squeeze(-1) 276 | vel_i = torch.matmul(edge_basis, 277 | v_i.unsqueeze(-1)).squeeze(-1) 278 | vel_j = torch.matmul(edge_basis, 279 | v_j.unsqueeze(-1)).squeeze(-1) 280 | # Calculate angle information in local frames 281 | coff_mul = coff_i * coff_j # [E, 3] 282 | coff_i_norm = coff_i.norm(dim=-1, keepdim=True) 283 | coff_j_norm = coff_j.norm(dim=-1, keepdim=True) 284 | pesudo_cos = coff_mul.sum( 285 | dim=-1, keepdim=True) / (coff_i_norm + 1e-5) / (coff_j_norm + 1e-5) 286 | pesudo_sin = torch.sqrt(1 - pesudo_cos**2) 287 | pesudo_angle = torch.cat([pesudo_sin, pesudo_cos], dim=-1) 288 | coff_feat = torch.cat([pesudo_angle, coff_i, coff_j, vel_i, vel_j], 289 | dim=-1) #[E, 14] 290 | return coff_feat 291 | 292 | def forward(self, h, x, edges, vel, edge_attr, node_attr=None, n_nodes=5): 293 | 294 | h = self.embedding_node(h) 295 | x = 
x.reshape(-1, n_nodes, 3) 296 | centroid = torch.mean(x, dim=1, keepdim=True) 297 | x_center = (x - centroid).reshape(-1, 3) 298 | 299 | coff_feat = self.scalarization(edges, x_center, vel) 300 | # edge_feat = torch.cat([edge_attr, coff_feat], dim=-1) 301 | edge_embed = self.embed_edge(edge_attr[:, 0], edge_attr[:, 1]) 302 | edge_feat = self.fuse_edge(coff_feat) 303 | edge_feat = edge_feat + edge_embed 304 | for i in range(0, self.n_layers): 305 | h, x_center, _ = self._modules["gcl_%d" % i]( 306 | h, edges, x_center, vel, edge_attr=edge_feat, node_attr=node_attr) 307 | 308 | x = x_center.reshape(-1, n_nodes, 3) + centroid 309 | x = x.reshape(-1, 3) 310 | return x 311 | 312 | -------------------------------------------------------------------------------- /newtonian/dataloader.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class Dataloader(): 6 | def __init__(self, dataset, batch_size=1, slice=[0, 1e8], shuffle=True): 7 | self.dataset = dataset 8 | self.batch_size = batch_size 9 | self.n_nodes = self.dataset.get_n_nodes() 10 | self.edges = self.expand_edges(dataset.edges, batch_size, self.n_nodes) 11 | self.idxs_permuted = list(range(len(self.dataset))) 12 | self.shuffle = shuffle 13 | self.slice = slice 14 | if self.shuffle: 15 | random.shuffle(self.idxs_permuted) 16 | self.idx = 0 17 | 18 | def __iter__(self): 19 | return self 20 | 21 | def expand_edges(self, edges, batch_size, n_nodes): 22 | edges = [torch.LongTensor(edges[0]), torch.LongTensor(edges[1])] 23 | if batch_size == 1: 24 | return edges 25 | elif batch_size > 1: 26 | rows, cols = [], [] 27 | for i in range(batch_size): 28 | rows.append(edges[0] + n_nodes*i) 29 | cols.append(edges[1] + n_nodes*i) 30 | edges = [torch.cat(rows), torch.cat(cols)] 31 | return edges 32 | 33 | def __next__(self): 34 | if self.idx > len(self.dataset) - self.batch_size: 35 | self.idx = 0 36 | #random.shuffle(self.dataset.graphs) 37 | raise 
StopIteration # Done iterating. 38 | else: 39 | loc, vel, edge_attr, charges = self.dataset.data 40 | idx_permuted = self.idxs_permuted[self.idx:self.idx + self.batch_size] 41 | batched_data = loc[idx_permuted], vel[idx_permuted], edge_attr[idx_permuted], charges[idx_permuted] 42 | [loc_batch, vel_batch, edge_attr_batch, loc_end_batch, charges_batch] = self.cast_batch(list(batched_data)) 43 | 44 | self.idx += self.batch_size 45 | return loc_batch, vel_batch, edge_attr_batch, loc_end_batch, charges_batch 46 | 47 | def cast_batch(self, batched_data): 48 | #loc_batch, vel_batch, edges_batch, loc_end_batch = batched_data 49 | #if self.batch_size > 1: 50 | # raise Exception("To implement") 51 | batched_data = [d.contiguous().view(-1, d.size(2)) for d in batched_data] 52 | 53 | return batched_data 54 | #else: 55 | # return loc_batch[0], vel_batch[0], edges_batch[0], loc_end_batch[0] 56 | 57 | def __len__(self): 58 | return len(self.dataset) 59 | 60 | def partition(self): 61 | return self.dataset.partition 62 | 63 | 64 | if __name__ == "__main__": 65 | pass 66 | 67 | 68 | -------------------------------------------------------------------------------- /newtonian/dataset/generate_dataset.py: -------------------------------------------------------------------------------- 1 | from synthetic_sim import ChargedParticlesSim, DynamicSim, GravitySim, SpringSim, FixCharge 2 | import time 3 | import numpy as np 4 | import argparse 5 | from multiprocessing import Pool 6 | from tqdm import tqdm 7 | import os 8 | """ 9 | nbody: python -u generate_dataset.py --num-train 50000 --sample-freq 500 2>&1 | tee log_generating_100000.log & 10 | 11 | nbody_small: python -u generate_dataset.py --num-train 10000 --seed 43 --sufix small 2>&1 | tee log_generating_10000_small.log & 12 | 13 | """ 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--simulation', type=str, default='charged', 17 | help='What simulation to generate.') 18 | parser.add_argument('--num-train', type=int, 
default=10000, 19 | help='Number of training simulations to generate.') 20 | parser.add_argument('--num-valid', type=int, default=2000, 21 | help='Number of validation simulations to generate.') 22 | parser.add_argument('--num-test', type=int, default=2000, 23 | help='Number of test simulations to generate.') 24 | parser.add_argument('--length', type=int, default=5000, 25 | help='Length of trajectory.') 26 | parser.add_argument('--length_test', type=int, default=5000, 27 | help='Length of test set trajectory.') 28 | parser.add_argument('--sample-freq', type=int, default=100, 29 | help='How often to sample the trajectory.') 30 | parser.add_argument('--n_balls', type=int, default=5, 31 | help='Number of balls in the simulation.') 32 | parser.add_argument('--seed', type=int, default=42, 33 | help='Random seed.') 34 | parser.add_argument('--initial_vel', type=int, default=1, 35 | help='consider initial velocity') 36 | parser.add_argument('--sufix', type=str, default="", 37 | help='add a sufix to the name') 38 | parser.add_argument('--saved_dir', type=str, default="", 39 | help='add a directory to save') 40 | 41 | args = parser.parse_args() 42 | print(args) 43 | 44 | initial_vel_norm = 0.5 45 | if not args.initial_vel: 46 | initial_vel_norm = 1e-16 47 | 48 | if args.simulation == 'springs': 49 | sim = SpringSim(noise_var=0.0, n_balls=args.n_balls) 50 | suffix = '_springs' 51 | elif args.simulation == 'charged': 52 | sim = ChargedParticlesSim(noise_var=0.0, n_balls=args.n_balls, vel_norm=initial_vel_norm) 53 | suffix = '_charged' 54 | elif args.simulation == 'static': 55 | sim = GravitySim(noise_var=0.0, n_balls=args.n_balls, vel_norm=initial_vel_norm) 56 | suffix = '_static' 57 | elif args.simulation == 'dynamic': 58 | sim = DynamicSim(noise_var=0.0, n_balls=args.n_balls, vel_norm=initial_vel_norm) 59 | suffix = '_dynamic' 60 | elif args.simulation == 'fixcharge': 61 | sim = FixCharge(noise_var=0.0, n_balls=args.n_balls, vel_norm=initial_vel_norm) 62 | suffix = 
'_fixcharge' 63 | else: 64 | raise ValueError('Simulation {} not implemented'.format(args.simulation)) 65 | 66 | suffix += str(args.n_balls) + "_initvel%d" % args.initial_vel + args.sufix 67 | np.random.seed(args.seed) 68 | 69 | print(suffix) 70 | 71 | def generate_dataset(num_sims, length, sample_freq, multiprocess=False, num_workers=10): 72 | loc_all = list() 73 | vel_all = list() 74 | edges_all = list() 75 | charges_all = list() 76 | t = time.time() 77 | def collect(s): 78 | loc_all.append(s[0]) 79 | vel_all.append(s[1]) 80 | edges_all.append(s[2]) 81 | charges_all.append(s[3]) 82 | if len(loc_all) % 100 == 0: 83 | print("Iter: {}, Simulation time: {}".format(len(loc_all), time.time() - t)) 84 | 85 | if multiprocess: 86 | pool = Pool(num_workers) 87 | for i in range(num_sims): 88 | pool.apply_async(sim.sample_trajectory, (np.random.choice(list(range(1000000))), length, sample_freq), callback=collect) 89 | pool.close() 90 | pool.join() 91 | else: 92 | for i in tqdm(range(num_sims)): 93 | res = sim.sample_trajectory(np.random.choice(list(range(1000000))), T=length, sample_freq=sample_freq) 94 | collect(res) 95 | 96 | charges_all = np.stack(charges_all) 97 | loc_all = np.stack(loc_all) 98 | vel_all = np.stack(vel_all) 99 | edges_all = np.stack(edges_all) 100 | 101 | return loc_all, vel_all, edges_all, charges_all 102 | 103 | if __name__ == "__main__": 104 | multiprocess = True 105 | args.saved_dir = os.path.join(args.saved_dir, args.sufix) 106 | os.makedirs(args.saved_dir, exist_ok=True) 107 | if not args.saved_dir[-1] == '/': 108 | args.saved_dir = args.saved_dir + '/' 109 | 110 | print("Generating {} validation simulations".format(args.num_valid)) 111 | loc_valid, vel_valid, edges_valid, charges_valid = generate_dataset(args.num_valid, 112 | args.length, 113 | args.sample_freq, 114 | multiprocess=multiprocess) 115 | np.save(args.saved_dir + 'loc_valid' + suffix + '.npy', loc_valid) 116 | np.save(args.saved_dir + 'vel_valid' + suffix + '.npy', vel_valid) 117 | 
np.save(args.saved_dir + 'edges_valid' + suffix + '.npy', edges_valid) 118 | np.save(args.saved_dir + 'charges_valid' + suffix + '.npy', charges_valid) 119 | 120 | print("Generating {} test simulations".format(args.num_test)) 121 | loc_test, vel_test, edges_test, charges_test = generate_dataset(args.num_test, 122 | args.length_test, 123 | args.sample_freq, 124 | multiprocess=multiprocess) 125 | np.save(args.saved_dir + 'loc_test' + suffix + '.npy', loc_test) 126 | np.save(args.saved_dir + 'vel_test' + suffix + '.npy', vel_test) 127 | np.save(args.saved_dir + 'edges_test' + suffix + '.npy', edges_test) 128 | np.save(args.saved_dir + 'charges_test' + suffix + '.npy', charges_test) 129 | 130 | 131 | print("Generating {} training simulations".format(args.num_train)) 132 | loc_train, vel_train, edges_train, charges_train = generate_dataset(args.num_train, 133 | args.length, 134 | args.sample_freq, 135 | multiprocess=multiprocess) 136 | 137 | np.save(args.saved_dir + 'loc_train' + suffix + '.npy', loc_train) 138 | np.save(args.saved_dir + 'vel_train' + suffix + '.npy', vel_train) 139 | np.save(args.saved_dir + 'edges_train' + suffix + '.npy', edges_train) 140 | np.save(args.saved_dir + 'charges_train' + suffix + '.npy', charges_train) 141 | 142 | 143 | -------------------------------------------------------------------------------- /newtonian/dataset/script.sh: -------------------------------------------------------------------------------- 1 | saved_dir="data" # the path of output directory 2 | 3 | ''' 4 | suffix: the suffix of generated dataset (e.g., small_20body) 5 | simulation: the mode of force field 6 | 1. charged: electronstatic system (ES) 7 | 2. static: gravity + electronstatic force (G+ES) 8 | 3. 
dynamic: Lorentz + electrostatic force (L+ES)
int(max_samples) 29 | self.data, self.edges = self.load() 30 | self.frame_0 = 30 31 | self.frame_T = 40 32 | 33 | def load(self): 34 | loc = np.load(os.path.join(self.data_root, 'loc_' + self.sufix + '.npy')) 35 | vel = np.load(os.path.join(self.data_root, 'vel_' + self.sufix + '.npy')) 36 | edges = np.load(os.path.join(self.data_root, 'edges_' + self.sufix + '.npy')) 37 | charges = np.load(os.path.join(self.data_root, 'charges_' + self.sufix + '.npy')) 38 | 39 | loc, vel, edge_attr, edges, charges = self.preprocess(loc, vel, edges, charges) 40 | return (loc, vel, edge_attr, charges), edges 41 | 42 | 43 | def preprocess(self, loc, vel, edges, charges): 44 | # cast to torch and swap n_nodes <--> n_features dimensions 45 | loc, vel = torch.Tensor(loc).transpose(2, 3), torch.Tensor(vel).transpose(2, 3) 46 | n_nodes = loc.size(2) 47 | loc = loc[0:self.max_samples, :, :, :] # limit number of samples 48 | vel = vel[0:self.max_samples, :, :, :] # speed when starting the trajectory 49 | charges = charges[0:self.max_samples] 50 | edge_attr = [] 51 | 52 | #Initialize edges and edge_attributes 53 | rows, cols = [], [] 54 | for i in range(n_nodes): 55 | for j in range(n_nodes): 56 | if i != j: 57 | edge_attr.append(edges[:, i, j]) 58 | rows.append(i) 59 | cols.append(j) 60 | edges = [rows, cols] 61 | edge_attr = torch.Tensor(edge_attr).transpose(0, 1).unsqueeze(2) # swap n_nodes <--> batch_size and add nf dimension 62 | 63 | return torch.Tensor(loc), torch.Tensor(vel), torch.Tensor(edge_attr), edges, torch.Tensor(charges) 64 | 65 | def set_max_samples(self, max_samples): 66 | self.max_samples = int(max_samples) 67 | self.data, self.edges = self.load() 68 | 69 | def get_n_nodes(self): 70 | return self.data[0].size(1) 71 | 72 | def __getitem__(self, i): 73 | loc, vel, edge_attr, charges = self.data 74 | loc, vel, edge_attr, charges = loc[i], vel[i], edge_attr[i], charges[i] 75 | return loc[self.frame_0], vel[self.frame_0], edge_attr, charges, loc[self.frame_T] 76 | 77 | def 
__len__(self): 78 | return len(self.data[0]) 79 | 80 | def get_edges(self, batch_size, n_nodes): 81 | edges = [torch.LongTensor(self.edges[0]), torch.LongTensor(self.edges[1])] 82 | if batch_size == 1: 83 | return edges 84 | elif batch_size > 1: 85 | rows, cols = [], [] 86 | for i in range(batch_size): 87 | rows.append(edges[0] + n_nodes * i) 88 | cols.append(edges[1] + n_nodes * i) 89 | edges = [torch.cat(rows), torch.cat(cols)] 90 | return edges 91 | 92 | 93 | if __name__ == "__main__": 94 | NBodyDataset() -------------------------------------------------------------------------------- /newtonian/egnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models.gcl import E_GCL, E_GCL_vel 4 | import numpy as np 5 | import logging 6 | 7 | 8 | class EGNN(nn.Module): 9 | def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.LeakyReLU(0.2), n_layers=4, coords_weight=1.0): 10 | super(EGNN, self).__init__() 11 | self.hidden_nf = hidden_nf 12 | self.device = device 13 | self.n_layers = n_layers 14 | self.embedding = nn.Linear(in_node_nf, self.hidden_nf) 15 | for i in range(0, n_layers): 16 | self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, act_fn=act_fn, recurrent=True, coords_weight=coords_weight)) 17 | self.to(self.device) 18 | self.params = self.__str__() 19 | 20 | def __str__(self): 21 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 22 | params = sum([np.prod(p.size()) for p in model_parameters]) 23 | print('Network Size', params) 24 | logging.info('Network Size {}'.format(params)) 25 | return str(params) 26 | 27 | def forward(self, h, x, edges, edge_attr, vel=None): 28 | h = self.embedding(h) 29 | for i in range(0, self.n_layers): 30 | h, x, _ = self._modules["gcl_%d" % i](h, edges, x, edge_attr=edge_attr) 31 | return x 32 | 33 | 34 | class EGNN_vel(nn.Module): 35 | def 
__init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, coords_weight=1.0, recurrent=False, norm_diff=False, tanh=False): 36 | super(EGNN_vel, self).__init__() 37 | self.hidden_nf = hidden_nf 38 | self.device = device 39 | self.n_layers = n_layers 40 | self.embedding = nn.Linear(in_node_nf, self.hidden_nf) 41 | for i in range(0, n_layers): 42 | self.add_module("gcl_%d" % i, E_GCL_vel(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, act_fn=act_fn, coords_weight=coords_weight, recurrent=recurrent, norm_diff=norm_diff, tanh=tanh)) 43 | self.to(self.device) 44 | self.params = self.__str__() 45 | 46 | def __str__(self): 47 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 48 | params = sum([np.prod(p.size()) for p in model_parameters]) 49 | print('Network Size', params) 50 | logging.info('Network Size {}'.format(params)) 51 | return str(params) 52 | 53 | def forward(self, h, x, edges, vel, edge_attr): 54 | h = self.embedding(h) 55 | for i in range(0, self.n_layers): 56 | h, x, _ = self._modules["gcl_%d" % i](h, edges, x, vel, edge_attr=edge_attr) 57 | return x -------------------------------------------------------------------------------- /newtonian/gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models.gcl import GCL, GCL_rf_vel 4 | import numpy as np 5 | import logging 6 | 7 | class GNN(nn.Module): 8 | def __init__(self, input_dim, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, attention=0, recurrent=False): 9 | super(GNN, self).__init__() 10 | self.hidden_nf = hidden_nf 11 | self.device = device 12 | self.n_layers = n_layers 13 | for i in range(0, n_layers): 14 | self.add_module("gcl_%d" % i, GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_nf=1, act_fn=act_fn, attention=attention, recurrent=recurrent)) 15 | 16 | self.decoder = nn.Sequential(nn.Linear(hidden_nf, hidden_nf), 
17 | act_fn, 18 | nn.Linear(hidden_nf, 3)) 19 | self.embedding = nn.Sequential(nn.Linear(input_dim, hidden_nf)) 20 | self.to(self.device) 21 | self.params = self.__str__() 22 | 23 | def __str__(self): 24 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 25 | params = sum([np.prod(p.size()) for p in model_parameters]) 26 | print('Network Size', params) 27 | logging.info('Network Size {}'.format(params)) 28 | return str(params) 29 | 30 | def forward(self, nodes, edges, edge_attr=None): 31 | h = self.embedding(nodes) 32 | for i in range(0, self.n_layers): 33 | h, _ = self._modules["gcl_%d" % i](h, edges, edge_attr=edge_attr) 34 | return self.decoder(h) 35 | 36 | 37 | def get_velocity_attr(loc, vel, rows, cols): 38 | diff = loc[cols] - loc[rows] 39 | norm = torch.norm(diff, p=2, dim=1).unsqueeze(1) 40 | u = diff/norm 41 | va, vb = vel[rows] * u, vel[cols] * u 42 | va, vb = torch.sum(va, dim=1).unsqueeze(1), torch.sum(vb, dim=1).unsqueeze(1) 43 | return va 44 | 45 | 46 | 47 | class RF_vel(nn.Module): 48 | def __init__(self, hidden_nf, edge_attr_nf=0, device='cpu', act_fn=nn.SiLU(), n_layers=4): 49 | super(RF_vel, self).__init__() 50 | self.hidden_nf = hidden_nf 51 | self.device = device 52 | self.n_layers = n_layers 53 | for i in range(0, n_layers): 54 | self.add_module("gcl_%d" % i, GCL_rf_vel(nf=hidden_nf, edge_attr_nf=edge_attr_nf, act_fn=act_fn)) 55 | self.to(self.device) 56 | 57 | self.params = self.__str__() 58 | 59 | def __str__(self): 60 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 61 | params = sum([np.prod(p.size()) for p in model_parameters]) 62 | print('Network Size', params) 63 | logging.info('Network Size {}'.format(params)) 64 | return str(params) 65 | 66 | def forward(self, vel_norm, x, edges, vel, edge_attr): 67 | for i in range(0, self.n_layers): 68 | x, _ = self._modules["gcl_%d" % i](x, vel_norm, vel, edges, edge_attr) 69 | return x 
# -------------------- /newtonian/layers.py --------------------
from torch import nn
import torch
import pickle
import numpy as np
import math
import torch.nn.functional as F


def gaussian(x, mean, std):
    """Elementwise normal (Gaussian) probability density of ``x``.

    Fix: use ``math.pi`` instead of the truncated literal ``3.14159`` so the
    normalisation constant ``sqrt(2*pi)`` is exact to float precision.
    """
    a = (2 * math.pi) ** 0.5
    return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)


class GaussianLayer(nn.Module):
    """Gaussian basis expansion of a scalar edge feature.

    Each edge type owns a learned scalar affine (``mul``, ``bias``) applied to
    the input before it is expanded over ``K`` shared Gaussian kernels.
    """

    def __init__(self, K=128, edge_types=1024):
        super().__init__()
        self.K = K
        # One shared row of K means/stds (hence Embedding(1, K)).
        self.means = nn.Embedding(1, K)
        self.stds = nn.Embedding(1, K)
        # Per-edge-type affine transform of the input scalar.
        self.mul = nn.Embedding(edge_types, 1)
        self.bias = nn.Embedding(edge_types, 1)
        nn.init.uniform_(self.means.weight, 0, 3)
        nn.init.uniform_(self.stds.weight, 0, 3)
        nn.init.constant_(self.bias.weight, 0)
        nn.init.constant_(self.mul.weight, 1)

    def forward(self, x, edge_types):
        # x: [E] scalar per edge; edge_types: [E] long indices.
        mul = self.mul(edge_types)
        bias = self.bias(edge_types)
        x = mul * x.unsqueeze(-1) + bias
        x = x.expand(-1, self.K)
        mean = self.means.weight.float().view(-1)
        std = self.stds.weight.float().view(-1).abs() + 1e-5  # keep std strictly positive
        return gaussian(x.float(), mean, std).type_as(self.means.weight)


class E_GCL_vel(nn.Module):
    """Equivariant graph convolution layer with local-frame messages and
    multi-head attention over fixed-size graphs.

    NOTE(review): the original docstring described unrelated parameters
    (hidden_dim / num_nodes / global_agg / temp), apparently copied from
    another model; replaced with the actual constructor arguments.

    Args:
        input_node_nf: input node feature size.
        input_edge_nf: input edge feature size.
        output_nf: output node feature size.
        hidden_nf: hidden width; must be divisible by ``nhead``.
        nodes_att_dim: extra per-node attribute size concatenated in node_model.
        act_fn: activation module used inside the MLPs.
        recurrent: add a residual connection (plus LayerNorm) in node_model.
        coords_weight: scale applied to the aggregated coordinate update.
        attention: gate edge messages with a learned sigmoid weight.
        norm_diff: normalise the local-frame basis vectors.
        tanh: stored but not used in this layer; kept for interface parity.
        nhead: number of attention heads.
        n_points: number of nodes per graph (graphs are fixed-size).
    """

    def __init__(
        self,
        input_node_nf,
        input_edge_nf,
        output_nf,
        hidden_nf,
        nodes_att_dim=0,
        act_fn=nn.Softplus(),
        recurrent=True,
        coords_weight=1.0,
        attention=False,
        norm_diff=False,
        tanh=False,
        nhead=3,
        n_points=5,
    ):
        super(E_GCL_vel, self).__init__()
        self.coords_weight = coords_weight
        self.recurrent = recurrent
        self.attention = attention
        self.norm_diff = norm_diff  # (was assigned twice in the original; kept once)
        self.tanh = tanh
        self.nhead = nhead
        self.n_points = n_points
        assert (hidden_nf % nhead) == 0

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_node_nf * 2 + input_edge_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
        )

        self.edge_mlp_enhance = nn.Sequential(
            nn.Linear(hidden_nf * 2 + hidden_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
        )

        self.node_mlp = nn.Sequential(
            nn.Linear(input_node_nf + hidden_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf),
        )

        # Predicts 3 scalar coefficients over the per-edge local frame.
        self.coord_mlp_inner = nn.Sequential(
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, 3),
        )

        if self.attention:
            self.att_mlp = nn.Sequential(nn.Linear(hidden_nf, 1), nn.Sigmoid())

        self.fc_query = nn.Linear(hidden_nf, hidden_nf)
        self.fc_key = nn.Linear(2 * hidden_nf, hidden_nf)
        self.fc_value = nn.Linear(hidden_nf, hidden_nf)
        self.norm1 = nn.LayerNorm(hidden_nf)
        self.norm2 = nn.LayerNorm(hidden_nf)
        self.constant = 1  # additive constant that keeps frame normalisation finite

    def edge_model(self, source, target, edge_attr):
        # Per-edge message from endpoint features (+ optional edge features).
        if edge_attr is None:
            out = torch.cat([source, target], dim=1)
        else:
            out = torch.cat([source, target, edge_attr], dim=1)
        out = self.edge_mlp(out)
        if self.attention:
            att_val = self.att_mlp(out)
            out = out * att_val
        return out

    def edge_model_enhance(self, source, target, edge_attr):
        # Second-stage edge message used for the coordinate update.
        if edge_attr is None:
            out = torch.cat([source, target], dim=1)
        else:
            out = torch.cat([source, target, edge_attr], dim=1)
        out = self.edge_mlp_enhance(out)
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr):
        # Aggregate incoming edge messages per node, then update node features.
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
        agg = self.norm1(agg)
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        out = self.node_mlp(agg)
        if self.recurrent:
            # NOTE(review): original dump's indentation is ambiguous; norm2 is
            # kept inside the residual branch — confirm against upstream.
            out = x + out
            out = self.norm2(out)

        return out, agg

    def coord2radial(self, edge_index, coord):
        """Build the per-edge equivariant local frame (diff, cross, vertical).

        Fix: pass ``dim=-1`` to ``torch.cross`` explicitly. Without it, torch
        uses the *first* dimension of size 3, which silently selects the edge
        dimension whenever the number of edges happens to be 3 (and the
        implicit dim is deprecated in recent PyTorch).
        """
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum((coord_diff) ** 2, 1).unsqueeze(1)
        coord_cross = torch.cross(coord[row], coord[col], dim=-1)
        if self.norm_diff:
            norm = torch.sqrt(radial) + self.constant
            coord_diff = coord_diff / norm
            cross_norm = (
                torch.sqrt(torch.sum((coord_cross) ** 2, 1).unsqueeze(1)) + self.constant
            )
            coord_cross = coord_cross / cross_norm

        coord_vertical = torch.cross(coord_diff, coord_cross, dim=-1)

        return coord_diff, coord_cross, coord_vertical

    def acc_model_inner(
        self, coord, edge_index, coord_diff, coord_cross, coord_vertical, edge_feat
    ):
        """Inner force field: predicted coefficients over the local frame are
        combined into a translation, then mean-aggregated per target node."""
        row, col = edge_index
        basis_coff = self.coord_mlp_inner(edge_feat)
        trans = (
            coord_diff * basis_coff[:, :1]
            + coord_cross * basis_coff[:, 1:2]
            + coord_vertical * basis_coff[:, 2:3]
        )
        trans = torch.clamp(trans, min=-100, max=100)  # guard against exploding updates
        agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
        acc = agg * self.coords_weight

        return acc
    def transpose_for_scores(self, x):
        """Split the channel dim into attention heads.

        x has shape (B, N, C)
        return shape (B, nhead, N, C/nhead)
        """
        new_shape = x.shape[:-1] + (self.nhead, -1)
        x = x.view(*new_shape)
        return x.transpose(-3, -2)

    def forward(
        self,
        h,
        edge_index,
        coord,
        coord_pre,
        vel,
        edge_attr=None,
        node_attr=None,
        short_cut=False,
    ):
        """One layer update: attention-weighted message passing over node
        features, then an equivariant coordinate/acceleration update built
        from the local frame of ``coord_pre``.

        Returns (ACC, h, edge_feat): the coordinate update (optionally
        residual on ``coord`` when ``short_cut``), updated node features,
        and the enhanced edge features.

        NOTE(review): ``vel`` is accepted but not read in this method —
        confirm whether callers rely on it being ignored here.
        """
        # Message enhancement
        row, col = edge_index
        edge_feat = self.edge_model(h[row], h[col], edge_attr)

        # Attention-based message passing, inspired by SE3-Fold
        # Assumes a fully connected graph of n_points nodes per sample, so
        # each node has exactly N-1 incident edges — TODO confirm.
        B, N, C = int(h.shape[0]//self.n_points), self.n_points, edge_feat.shape[-1]
        query = h.reshape(B, N, C) # [B, N, C]
        m = edge_feat.reshape(B, N, -1, C) # [B, N, N-1, C]
        h_m = torch.cat([query.unsqueeze(2).repeat(1, 1, N - 1, 1), m], dim=-1).reshape(
            B, N * (N - 1), -1
        ) # [B, N*N-1, 2C]

        query = self.transpose_for_scores(
            self.fc_query(query)
        ) # [B, N, C] -> [B, A, N, C/A]
        key = self.transpose_for_scores(
            self.fc_key(h_m)
        ) # [B, N*(N-1), C] -> [B, A, N*(N-1), C/A]
        value = self.transpose_for_scores(
            self.fc_value(edge_feat.reshape(B, N * (N - 1), C))
        ) # [B, N*(N-1), C] -> [B, A, N*(N-1), C/A]

        key = key.reshape(B, self.nhead, N, N - 1, -1) # [B, A, N, (N-1), C/A]
        # Per-node scaled dot-product attention over its N-1 edge messages.
        attention_scores = torch.matmul(
            query.unsqueeze(-2), key.transpose(-1, -2)
        ).squeeze(
            -2
        ) # [B, A, N, N-1]
        attention_scores = attention_scores / math.sqrt(C / self.nhead)
        attention_weights = F.softmax(attention_scores, dim=-1) # [B, A, N, N-1]
        m_update = (
            attention_weights.reshape(B, self.nhead, -1).unsqueeze(-1) * value
        ) # [B, A, N*(N-1), C/A]
        att_edge_feat = m_update.transpose(-3, -2) # [B, N*(N-1), A, C/A]
        att_edge_feat = att_edge_feat.reshape(
            *att_edge_feat.shape[:-2], -1
        ) # [B, N*(N-1), C]
        att_edge_feat = att_edge_feat.reshape(-1, C) # [B*N*(N-1), C]
        h, agg = self.node_model(h, edge_index, att_edge_feat, node_attr)

        # Equivariant message aggregation
        # Note: local frames are built from coord_pre while the update is
        # aggregated onto coord (both passed through to acc_model_inner).
        edge_feat = edge_feat + self.edge_model_enhance(h[row], h[col], edge_attr)
        coord_diff, coord_cross, coord_vertical = self.coord2radial(
            edge_index, coord_pre
        )
        acc1 = self.acc_model_inner(
            coord, edge_index, coord_diff, coord_cross, coord_vertical, edge_feat
        )

        if short_cut:
            ACC = coord + acc1
        else:
            ACC = acc1

        return ACC, h, edge_feat


def unsorted_segment_sum(data, segment_ids, num_segments):
    """Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`.

    data: [E, F]; segment_ids: [E] target row per element; returns [num_segments, F].
    """
    result_shape = (num_segments, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result.scatter_add_(0, segment_ids, data)
    return result


def unsorted_segment_mean(data, segment_ids, num_segments):
    """Segment-wise mean of ``data`` rows grouped by ``segment_ids``.

    Counts are clamped to 1 so empty segments yield 0 instead of NaN.
    """
    result_shape = (num_segments, data.size(1))
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    count = data.new_full(result_shape, 0)
    result.scatter_add_(0, segment_ids, data)
    count.scatter_add_(0, segment_ids, torch.ones_like(data))
    return result / count.clamp(min=1)
--------------------------------------------------------------------------------