├── .gitignore ├── LICENSE ├── README.md ├── assets ├── res_plot.png └── res_plot_unrot.png ├── atom_info ├── README.md ├── crystal.json └── qm9.json ├── configs ├── README.md ├── md.yml ├── pbc.yml └── qm9.yml ├── datasets ├── __init__.py ├── _base.py ├── density.py └── small_density.py ├── generate_dataset.py ├── inference.ipynb ├── main.py ├── models ├── __init__.py ├── _base.py ├── infgcn.py ├── orbital.py └── utils.py ├── requirements.txt ├── utils.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | data/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Chaoran Cheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # InfGCN for Electron Density Estimation 2 | 3 | By Chaoran Cheng, Oct 1, 2023 4 | 5 | [OpenReview](https://openreview.net/forum?id=EjiA3uWpnc), [ArXiv](https://arxiv.org/abs/2311.10908) 6 | 7 | Official implementation of the NeurIPS 23 spotlight paper *Equivariant Neural Operator Learning with Graphon 8 | Convolution* for 9 | modeling operators on continuous data. 10 | 11 | UPDATE: The pretrained model is available [here](https://uofi.box.com/s/8nfosxts1i8g643f8etdqqbtqlg5clhk). 12 | 13 | ## Requirements 14 | 15 | All codes are run with Python 3.9.15 and CUDA 11.6. Similar environment should also work, as this project does not rely 16 | on some rapidly changing packages. Other required packages are listed in `requirements.txt`. 17 | 18 | ## Datasets 19 | 20 | ### QM9 21 | 22 | The QM9 dataset contains 133885 small molecules consisting of C, H, O, N, and F. The QM9 electron density dataset was 23 | built by Jørgensen et al. ([paper](https://www.nature.com/articles/s41524-022-00863-y)) and was publicly available 24 | via [Figshare](https://data.dtu.dk/articles/dataset/QM9_Charge_Densities_and_Energies_Calculated_with_VASP/16794500). 25 | Each tarball needs to be extracted, but the inner lz4 compression should be kept. We provided code to read the 26 | compressed lz4 file. 27 | 28 | ### Cubic 29 | 30 | The Cubic dataset contains electron charge density for 16421 (after filtering) cubic crystal system cells. The dataset 31 | was built by Wang et al. ([paper](https://www.nature.com/articles/s41597-022-01158-z)) and was publicly available 32 | via [Figshare](https://springernature.figshare.com/collections/Large_scale_dataset_of_real_space_electronic_charge_density_of_cubic_inorganic_materials_from_density_functional_theory_DFT_calculations/5368343). 
33 | Each tarball needs to be extracted, but the inner xz compression should be kept. We provided code to read the compressed 34 | xz file. 35 | 36 | **WARNING:** A considerable proportion of the samples uses the rhombohedral lattice system (i.e., primitive rhomhedral 37 | cell instead of unit cubic cell). Some visualization tools (including `plotly`) may not be able to handle this. 38 | 39 | ### MD 40 | 41 | The MD dataset contains 6 small molecules (ethanol, benzene, phenol, resorcinol, ethane, malonaldehyde) with different 42 | geometries sampled from molecular dynamics (MD). The dataset was curated 43 | from [here](https://www.nature.com/articles/s41467-020-19093-1) by Bogojeski et al. 44 | and [here](https://arxiv.org/abs/1609.02815) by Brockherde et al. The dataset is publicly available at the Quantum 45 | Machine [website](http://www.quantum-machine.org/datasets/). 46 | 47 | We assume the data is stored in the `//_/` directory, where `mol_name` should be 48 | one of the molecules mentioned above and split should be either `train` or `test`. The directory should contain the 49 | following files: 50 | 51 | - `structures.npy` contains the coordinates of the atoms. 52 | - `dft_densities.npy` contains the voxelized electron charge density data. 53 | 54 | This is the format for the latter four molecules (you can safely ignore other files). For the former two 55 | molecules, run `python generate_dataset.py` to generate the correctly formatted data. You can also specify the data 56 | directory with `--root` and the output directory with `--out`. 57 | 58 | All MD datasets assume a cubic box with side length of 20 Bohr and 50 grids per side. The densities are store as Fourier 59 | coefficients, and we provided code to convert them. 60 | 61 | ## Running the code 62 | 63 | Most hyperparameters are specified in the config files. More parameters in the YAML file is self-explanatory. 64 | See [this readme](configs/README.md) for more details on modifying the config files. 
Free feel to modify the config 65 | files to suit your needs or to add new models. The pretrained model together with a sample electron density file is 66 | available [here](https://uofi.box.com/s/8nfosxts1i8g643f8etdqqbtqlg5clhk). 67 | 68 | ### Training 69 | 70 | To train the model, run 71 | 72 | ```bash 73 | python main.py configs/qm9.yml --savename test 74 | ``` 75 | 76 | ### Evaluation 77 | 78 | To evaluate the model, run 79 | 80 | ```bash 81 | python main.py configs/qm9.yml --savename test --mode inf --resume 82 | ``` 83 | 84 | ### Inference 85 | 86 | To see the visualization of the predicted density, run [inference.ipynb](inference.ipynb) with JupyterLab or Jupyter 87 | Notebook. 88 | 89 | ### Extending to other models 90 | 91 | To utilize the code for other (GNN-based) models, you need to register the model class in using 92 | the `models.register_model` decorator. Your model's `forward` function should take same arguments as our InfGCN, but the 93 | initialization arguments can be different (see the [instructions](configs/README.md) on modifying the config file). 94 | 95 | ## Result 96 | 97 | The below figures demonstrate the normalized mean absolute error (NMAE) vs the model size of our model and all the 98 | baseline model on the QM9 dataset. Here, `s0` to `s6` refer to the maximum degree of spherical harmonics used in the 99 | model (InfGCN is `s7`). `no-res` refers to the model without residual connection and `fc` refers to the model without 100 | fully-connected tensor product. The pink points are interpolation GNNs and oranges points are neural operators. 101 | 102 |
103 | QM9 Rotated 104 | QM9 Unrotated 105 |
106 | 107 | ## Citation 108 | 109 | If you find this code useful, please cite our paper 110 | 111 | ```bibtex 112 | @InProceedings{Cheng2023infgcn, 113 | title={Equivariant Neural Operator Learning with Graphon Convolution}, 114 | author={Chaoran Cheng and Jian Peng}, 115 | booktitle={Advances in Neural Information Processing Systems 37: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, December 10-16, 2023}, 116 | month={December}, 117 | year={2023}, 118 | } 119 | ``` 120 | -------------------------------------------------------------------------------- /assets/res_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccr-cheng/InfGCN-pytorch/9aa4f8d4e4b4c21f30bf6378f5b2892a37a6feeb/assets/res_plot.png -------------------------------------------------------------------------------- /assets/res_plot_unrot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ccr-cheng/InfGCN-pytorch/9aa4f8d4e4b4c21f30bf6378f5b2892a37a6feeb/assets/res_plot_unrot.png -------------------------------------------------------------------------------- /atom_info/README.md: -------------------------------------------------------------------------------- 1 | # Getting Atom Information 2 | 3 | By Chaoran Cheng, Oct 1, 2023 4 | 5 | This folder contains the atom information for the datasets. 6 | 7 | Some baselines like CNN and GKN need atom information for building the initial input feature function. To let these models capture the different atom types, we applied atom-specific Gaussian parameters. The atom information are store in the JSON file with fields of `name` (chemical symbol), `atom_num` (atomic number) and `radius` (covalent radius). 
The information can be obtained with `RDKit` package with the following example code: 8 | 9 | ```python 10 | from rdkit import Chem 11 | 12 | pt = Chem.GetPeriodicTable() 13 | atoms = ['C', 'H', 'N', 'O', 'F'] 14 | atom_info = [ 15 | { 16 | 'name': a, 17 | 'atom_num': pt.GetAtomicNumber(a), 18 | 'radius': round(pt.GetRcovalent(a) / 0.529177, 5) # convert to Bohr 19 | } for a in atoms 20 | ] 21 | ``` 22 | -------------------------------------------------------------------------------- /atom_info/crystal.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "O", 4 | "atom_num": 8, 5 | "radius": 1.28501 6 | }, 7 | { 8 | "name": "Li", 9 | "atom_num": 3, 10 | "radius": 1.28501 11 | }, 12 | { 13 | "name": "K", 14 | "atom_num": 19, 15 | "radius": 2.51334 16 | }, 17 | { 18 | "name": "Mg", 19 | "atom_num": 12, 20 | "radius": 2.0787 21 | }, 22 | { 23 | "name": "Al", 24 | "atom_num": 13, 25 | "radius": 2.55113 26 | }, 27 | { 28 | "name": "F", 29 | "atom_num": 9, 30 | "radius": 1.20943 31 | }, 32 | { 33 | "name": "Na", 34 | "atom_num": 11, 35 | "radius": 1.83304 36 | }, 37 | { 38 | "name": "Rb", 39 | "atom_num": 37, 40 | "radius": 2.7779 41 | }, 42 | { 43 | "name": "Cu", 44 | "atom_num": 29, 45 | "radius": 2.87238 46 | }, 47 | { 48 | "name": "In", 49 | "atom_num": 49, 50 | "radius": 3.08025 51 | }, 52 | { 53 | "name": "Cs", 54 | "atom_num": 55, 55 | "radius": 3.15584 56 | }, 57 | { 58 | "name": "Cl", 59 | "atom_num": 17, 60 | "radius": 1.87083 61 | }, 62 | { 63 | "name": "Zn", 64 | "atom_num": 30, 65 | "radius": 2.7401 66 | }, 67 | { 68 | "name": "Ni", 69 | "atom_num": 28, 70 | "radius": 2.83459 71 | }, 72 | { 73 | "name": "N", 74 | "atom_num": 7, 75 | "radius": 1.28501 76 | }, 77 | { 78 | "name": "Ag", 79 | "atom_num": 47, 80 | "radius": 3.00467 81 | }, 82 | { 83 | "name": "Tl", 84 | "atom_num": 81, 85 | "radius": 2.92908 86 | }, 87 | { 88 | "name": "Au", 89 | "atom_num": 79, 90 | "radius": 2.83459 91 | }, 92 | { 93 | 
"name": "Ga", 94 | "atom_num": 31, 95 | "radius": 2.30547 96 | }, 97 | { 98 | "name": "Rh", 99 | "atom_num": 45, 100 | "radius": 2.7401 101 | }, 102 | { 103 | "name": "Hg", 104 | "atom_num": 80, 105 | "radius": 3.21254 106 | }, 107 | { 108 | "name": "Cd", 109 | "atom_num": 48, 110 | "radius": 3.19364 111 | }, 112 | { 113 | "name": "Pd", 114 | "atom_num": 46, 115 | "radius": 2.83459 116 | }, 117 | { 118 | "name": "Sb", 119 | "atom_num": 51, 120 | "radius": 2.759 121 | }, 122 | { 123 | "name": "Co", 124 | "atom_num": 27, 125 | "radius": 2.51334 126 | }, 127 | { 128 | "name": "Mn", 129 | "atom_num": 25, 130 | "radius": 2.55113 131 | }, 132 | { 133 | "name": "Fe", 134 | "atom_num": 26, 135 | "radius": 2.53223 136 | }, 137 | { 138 | "name": "Sn", 139 | "atom_num": 50, 140 | "radius": 2.759 141 | }, 142 | { 143 | "name": "Si", 144 | "atom_num": 14, 145 | "radius": 2.26767 146 | }, 147 | { 148 | "name": "Y", 149 | "atom_num": 39, 150 | "radius": 3.36371 151 | }, 152 | { 153 | "name": "Ru", 154 | "atom_num": 44, 155 | "radius": 2.64562 156 | }, 157 | { 158 | "name": "Ca", 159 | "atom_num": 20, 160 | "radius": 1.87083 161 | }, 162 | { 163 | "name": "Sc", 164 | "atom_num": 21, 165 | "radius": 2.72121 166 | }, 167 | { 168 | "name": "Ir", 169 | "atom_num": 77, 170 | "radius": 2.49444 171 | }, 172 | { 173 | "name": "Ti", 174 | "atom_num": 22, 175 | "radius": 2.7779 176 | }, 177 | { 178 | "name": "Ba", 179 | "atom_num": 56, 180 | "radius": 2.53223 181 | }, 182 | { 183 | "name": "Br", 184 | "atom_num": 35, 185 | "radius": 2.28657 186 | }, 187 | { 188 | "name": "Bi", 189 | "atom_num": 83, 190 | "radius": 2.91018 191 | }, 192 | { 193 | "name": "Ge", 194 | "atom_num": 32, 195 | "radius": 2.21098 196 | }, 197 | { 198 | "name": "Pt", 199 | "atom_num": 78, 200 | "radius": 2.83459 201 | }, 202 | { 203 | "name": "Yb", 204 | "atom_num": 70, 205 | "radius": 3.66607 206 | }, 207 | { 208 | "name": "S", 209 | "atom_num": 16, 210 | "radius": 1.92752 211 | }, 212 | { 213 | "name": "Sr", 214 | 
"atom_num": 38, 215 | "radius": 2.11649 216 | }, 217 | { 218 | "name": "La", 219 | "atom_num": 57, 220 | "radius": 3.53379 221 | }, 222 | { 223 | "name": "C", 224 | "atom_num": 6, 225 | "radius": 1.28501 226 | }, 227 | { 228 | "name": "Nd", 229 | "atom_num": 60, 230 | "radius": 3.42041 231 | }, 232 | { 233 | "name": "Tm", 234 | "atom_num": 69, 235 | "radius": 3.25033 236 | }, 237 | { 238 | "name": "Lu", 239 | "atom_num": 71, 240 | "radius": 3.25033 241 | }, 242 | { 243 | "name": "Er", 244 | "atom_num": 68, 245 | "radius": 3.26923 246 | }, 247 | { 248 | "name": "Ho", 249 | "atom_num": 67, 250 | "radius": 3.28812 251 | }, 252 | { 253 | "name": "Ce", 254 | "atom_num": 58, 255 | "radius": 3.4582 256 | }, 257 | { 258 | "name": "V", 259 | "atom_num": 23, 260 | "radius": 2.51334 261 | }, 262 | { 263 | "name": "As", 264 | "atom_num": 33, 265 | "radius": 2.28657 266 | }, 267 | { 268 | "name": "Zr", 269 | "atom_num": 40, 270 | "radius": 2.94797 271 | }, 272 | { 273 | "name": "Sm", 274 | "atom_num": 62, 275 | "radius": 3.40151 276 | }, 277 | { 278 | "name": "Pr", 279 | "atom_num": 59, 280 | "radius": 3.4393 281 | }, 282 | { 283 | "name": "Dy", 284 | "atom_num": 66, 285 | "radius": 3.30702 286 | }, 287 | { 288 | "name": "I", 289 | "atom_num": 53, 290 | "radius": 2.64562 291 | }, 292 | { 293 | "name": "B", 294 | "atom_num": 5, 295 | "radius": 1.56847 296 | }, 297 | { 298 | "name": "Cr", 299 | "atom_num": 24, 300 | "radius": 2.55113 301 | }, 302 | { 303 | "name": "P", 304 | "atom_num": 15, 305 | "radius": 1.4173 306 | }, 307 | { 308 | "name": "Pb", 309 | "atom_num": 82, 310 | "radius": 2.91018 311 | }, 312 | { 313 | "name": "Ta", 314 | "atom_num": 73, 315 | "radius": 2.70231 316 | }, 317 | { 318 | "name": "Os", 319 | "atom_num": 76, 320 | "radius": 2.58893 321 | }, 322 | { 323 | "name": "Nb", 324 | "atom_num": 41, 325 | "radius": 2.7968 326 | }, 327 | { 328 | "name": "Pm", 329 | "atom_num": 61, 330 | "radius": 3.40151 331 | }, 332 | { 333 | "name": "Se", 334 | "atom_num": 34, 
335 | "radius": 2.30547 336 | }, 337 | { 338 | "name": "H", 339 | "atom_num": 1, 340 | "radius": 0.43464 341 | }, 342 | { 343 | "name": "Hf", 344 | "atom_num": 72, 345 | "radius": 2.96687 346 | }, 347 | { 348 | "name": "Be", 349 | "atom_num": 4, 350 | "radius": 0.6614 351 | }, 352 | { 353 | "name": "Eu", 354 | "atom_num": 63, 355 | "radius": 3.76056 356 | }, 357 | { 358 | "name": "Mo", 359 | "atom_num": 42, 360 | "radius": 2.7779 361 | }, 362 | { 363 | "name": "Tb", 364 | "atom_num": 65, 365 | "radius": 3.32592 366 | }, 367 | { 368 | "name": "Gd", 369 | "atom_num": 64, 370 | "radius": 3.38261 371 | }, 372 | { 373 | "name": "Te", 374 | "atom_num": 52, 375 | "radius": 2.7779 376 | }, 377 | { 378 | "name": "Tc", 379 | "atom_num": 43, 380 | "radius": 2.55113 381 | }, 382 | { 383 | "name": "W", 384 | "atom_num": 74, 385 | "radius": 2.58893 386 | }, 387 | { 388 | "name": "Re", 389 | "atom_num": 75, 390 | "radius": 2.55113 391 | }, 392 | { 393 | "name": "U", 394 | "atom_num": 92, 395 | "radius": 2.98577 396 | }, 397 | { 398 | "name": "Ac", 399 | "atom_num": 89, 400 | "radius": 3.55269 401 | }, 402 | { 403 | "name": "Th", 404 | "atom_num": 90, 405 | "radius": 3.38261 406 | }, 407 | { 408 | "name": "Pa", 409 | "atom_num": 91, 410 | "radius": 3.04246 411 | }, 412 | { 413 | "name": "Pu", 414 | "atom_num": 94, 415 | "radius": 2.89128 416 | }, 417 | { 418 | "name": "Np", 419 | "atom_num": 93, 420 | "radius": 2.92908 421 | } 422 | ] -------------------------------------------------------------------------------- /atom_info/qm9.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "C", 4 | "atom_num": 6, 5 | "radius": 1.28501 6 | }, 7 | { 8 | "name": "H", 9 | "atom_num": 1, 10 | "radius": 0.43464 11 | }, 12 | { 13 | "name": "O", 14 | "atom_num": 8, 15 | "radius": 1.28501 16 | }, 17 | { 18 | "name": "N", 19 | "atom_num": 7, 20 | "radius": 1.28501 21 | }, 22 | { 23 | "name": "F", 24 | "atom_num": 9, 25 | "radius": 1.20943 26 
| } 27 | ] -------------------------------------------------------------------------------- /configs/README.md: -------------------------------------------------------------------------------- 1 | # Modifying the config files 2 | 3 | By Chaoran Cheng, Oct 1, 2023 4 | 5 | 6 | 7 | This folder contains the configuration files in YAML format. 8 | 9 | Most arguments in the YAML files are self-explanatory. There are a few more things to note: 10 | 11 | - In the `datasets` field, `type` field specifies the dataset type. All other fields are passed to the dataset class. Note that all train, validation, and test datasets are constructed. You may also pass a `validation` or `test` field to overwrite the arguments for the corresponding datasets (like adding rotation during inference). If not provided, the arguments will be the same as the train dataset. 12 | - In the `model` field, `type` field specifies the model type. All other fields are passed to the model class. See each model class docstring for more details of each argument. 13 | - To expand the model class, make sure the fields under the `model` field in the YAML file match the arguments of your model class's `__init__` function (except for the `type` field). 14 | 15 | -------------------------------------------------------------------------------- /configs/md.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seed: 42 3 | train_samples: 1024 4 | val_samples: 2048 5 | max_iter: 2000 6 | batch_size: 64 7 | log_freq: 10000 8 | val_freq: 20 9 | save_freq: 100 10 | max_grad_norm: 100. 11 | optimizer: 12 | type: adam 13 | lr: 5.e-3 14 | weight_decay: 0. 
15 | beta1: 0.9 16 | beta2: 0.999 17 | scheduler: 18 | type: plateau 19 | factor: 0.5 20 | patience: 5 21 | min_lr: 1.e-5 22 | 23 | test: 24 | batch_size: 16 25 | inf_samples: 4096 26 | num_infer: null 27 | num_vis: 2 28 | 29 | datasets: 30 | type: small_density 31 | root: ./data/small_ecd 32 | mol_name: malonaldehyde 33 | 34 | model: 35 | type: infgcn 36 | n_atom_type: 3 37 | num_radial: 16 38 | num_spherical: 7 39 | radial_embed_size: 64 40 | radial_hidden_size: 128 41 | num_radial_layer: 2 42 | num_gcn_layer: 3 43 | cutoff: 3. 44 | grid_cutoff: 3. 45 | is_fc: false 46 | gauss_start: 0.5 47 | gauss_end: 5. 48 | -------------------------------------------------------------------------------- /configs/pbc.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seed: 42 3 | train_samples: 1024 4 | val_samples: 2048 5 | max_iter: 10000 6 | batch_size: 32 7 | log_freq: 10 8 | val_freq: 100 9 | save_freq: 500 10 | max_grad_norm: 100. 11 | optimizer: 12 | type: adam 13 | lr: 5.e-3 14 | weight_decay: 0. 15 | beta1: 0.9 16 | beta2: 0.999 17 | scheduler: 18 | type: plateau 19 | factor: 0.5 20 | patience: 5 21 | min_lr: 1.e-5 22 | 23 | test: 24 | batch_size: 16 25 | inf_samples: 4096 26 | num_infer: null 27 | num_vis: 2 28 | 29 | datasets: 30 | type: density 31 | root: ./data/crystal 32 | split_file: data_split.json 33 | atom_file: ./atom_info/crystal.json 34 | extension: json 35 | compression: xz 36 | pbc: false # We do not need to tile the atoms when loading the data. 37 | # Instead, we will put constraint during basis expansion. 38 | 39 | model: 40 | type: infgcn 41 | n_atom_type: 84 42 | num_radial: 16 43 | num_spherical: 7 44 | radial_embed_size: 64 45 | radial_hidden_size: 128 46 | num_radial_layer: 2 47 | num_gcn_layer: 3 48 | cutoff: 5. 49 | grid_cutoff: 5. 50 | is_fc: false 51 | gauss_start: 0.5 52 | gauss_end: 5. 
53 | residual: false 54 | pbc: true 55 | -------------------------------------------------------------------------------- /configs/qm9.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seed: 42 3 | train_samples: 1024 4 | val_samples: 1024 5 | max_iter: 40000 6 | batch_size: 64 7 | log_freq: 20 8 | val_freq: 200 9 | save_freq: 2000 10 | max_grad_norm: 100. 11 | optimizer: 12 | type: adam 13 | lr: 1.e-3 14 | weight_decay: 0. 15 | beta1: 0.9 16 | beta2: 0.999 17 | scheduler: 18 | type: plateau 19 | factor: 0.5 20 | patience: 10 21 | min_lr: 1.e-5 22 | 23 | test: 24 | batch_size: 16 25 | inf_samples: 4096 26 | num_infer: 100 27 | num_vis: 2 28 | 29 | datasets: 30 | type: density 31 | root: ./data/QM9 32 | split_file: data_split.json 33 | atom_file: ./atom_info/qm9.json 34 | extension: CHGCAR 35 | compression: lz4 36 | pbc: false 37 | test: 38 | rotate: true 39 | 40 | model: 41 | type: infgcn 42 | n_atom_type: 5 43 | num_radial: 16 44 | num_spherical: 7 45 | radial_embed_size: 64 46 | radial_hidden_size: 128 47 | num_radial_layer: 2 48 | num_gcn_layer: 3 49 | cutoff: 3. 50 | grid_cutoff: 3. 51 | is_fc: false 52 | gauss_start: 0.5 53 | gauss_end: 5. 
54 | residual: true 55 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn.utils.rnn import pad_sequence 6 | import numpy as np 7 | from torch_geometric.data import Batch 8 | 9 | from ._base import get_dataset, register_dataset 10 | from .density import DensityDataset 11 | from .small_density import SmallDensityDataset 12 | 13 | 14 | class DensityCollator: 15 | def __init__(self, n_samples=None): 16 | self.n_samples = n_samples 17 | 18 | def __call__(self, batch): 19 | g, densities, grid_coord, infos = zip(*batch) 20 | g = Batch.from_data_list(g) 21 | 22 | if self.n_samples is None: 23 | densities = pad_sequence(densities, batch_first=True, padding_value=-1) 24 | grid_coord = pad_sequence(grid_coord, batch_first=True, padding_value=0.) 25 | return g, densities, grid_coord, infos 26 | 27 | sampled_density, sampled_grid = [], [] 28 | for d, coord in zip(densities, grid_coord): 29 | idx = random.sample(range(d.size(0)), self.n_samples) 30 | sampled_density.append(d[idx]) 31 | sampled_grid.append(coord[idx]) 32 | sampled_density = torch.stack(sampled_density, dim=0) 33 | sampled_grid = torch.stack(sampled_grid, dim=0) 34 | return g, sampled_density, sampled_grid, infos 35 | 36 | 37 | class DensityVoxelCollator: 38 | def __call__(self, batch): 39 | g, densities, grid_coord, infos = zip(*batch) 40 | g = Batch.from_data_list(g) 41 | shapes = [info['shape'] for info in infos] 42 | max_shape = np.array(shapes).max(0) 43 | 44 | padded_density, padded_grid = [], [] 45 | for den, grid, shape in zip(densities, grid_coord, shapes): 46 | padded_density.append(F.pad(den.view(*shape), ( 47 | 0, max_shape[2] - shape[2], 48 | 0, max_shape[1] - shape[1], 49 | 0, max_shape[0] - shape[0] 50 | ), value=-1)) 51 | padded_grid.append(F.pad(grid.view(*shape, 3), ( 52 
| 0, 0, 53 | 0, max_shape[2] - shape[2], 54 | 0, max_shape[1] - shape[1], 55 | 0, max_shape[0] - shape[0] 56 | ), value=0.)) 57 | densities = torch.stack(padded_density, dim=0) 58 | grid_coord = torch.stack(padded_grid, dim=0) 59 | return g, densities, grid_coord, infos 60 | 61 | 62 | __all__ = [ 63 | 'get_dataset', 'register_dataset', 64 | 'DensityDataset', 'SmallDensityDataset', 65 | 'DensityCollator', 'DensityVoxelCollator' 66 | ] 67 | -------------------------------------------------------------------------------- /datasets/_base.py: -------------------------------------------------------------------------------- 1 | _DATASET_DICT = {} 2 | 3 | 4 | def register_dataset(name): 5 | def decorator(cls): 6 | _DATASET_DICT[name] = cls 7 | return cls 8 | 9 | return decorator 10 | 11 | 12 | def get_dataset(cfg): 13 | d_cfg = cfg.copy() 14 | d_type = d_cfg.pop('type') 15 | train_cfg = d_cfg.pop('train', {}) 16 | val_cfg = d_cfg.pop('validation', {}) 17 | test_cfg = d_cfg.pop('test', {}) 18 | return ( 19 | _DATASET_DICT[d_type](split='train', **train_cfg, **d_cfg), 20 | _DATASET_DICT[d_type](split='validation', **val_cfg, **d_cfg), 21 | _DATASET_DICT[d_type](split='test', **test_cfg, **d_cfg), 22 | ) 23 | -------------------------------------------------------------------------------- /datasets/density.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | from e3nn import o3 8 | import torch.nn.functional as F 9 | from torch.utils.data import Dataset 10 | from torch_geometric.data import Data 11 | 12 | from ._base import register_dataset 13 | 14 | Bohr = 0.529177 # Bohr radius in angstrom 15 | 16 | 17 | def pbc_expand(atom_type, atom_coord): 18 | """ 19 | Expand the atoms by periodic boundary condition to eight directions in the neighboring cells. 
20 | :param atom_type: atom types, tensor of shape (n_atom,) 21 | :param atom_coord: atom coordinates, tensor of shape (n_atom, 3) 22 | :return: expanded atom types and coordinates 23 | """ 24 | exp_type, exp_coord = [], [] 25 | exp_direction = torch.FloatTensor([ 26 | [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], 27 | [0, 1, 1], [1, 0, 1], [1, 1, 0], [1, 1, 1] 28 | ]) 29 | for a_type, a_coord in zip(atom_type, atom_coord): 30 | for direction in exp_direction: 31 | new_coord = a_coord + direction 32 | if (new_coord <= 1).all(): 33 | exp_type.append(a_type) 34 | exp_coord.append(new_coord) 35 | return torch.LongTensor(exp_type), torch.stack(exp_coord, dim=0) 36 | 37 | 38 | def rotate_voxel(shape, cell, density, rotated_grid): 39 | """ 40 | Rotate the volumetric data using trilinear interpolation. 41 | :param shape: voxel shape, tensor of shape (3,) 42 | :param cell: cell vectors, tensor of shape (3, 3) 43 | :param density: original density, tensor of shape (n_grid,) 44 | :param rotated_grid: rotated grid coordinates, tensor of shape (n_grid, 3) 45 | :return: rotated density, tensor of shape (n_grid,) 46 | """ 47 | density = density.view(1, 1, *shape) 48 | rotated_grid = rotated_grid.view(1, *shape, 3) 49 | shape = torch.FloatTensor(shape) 50 | grid_cell = cell / shape.view(3, 1) 51 | normalized_grid = (2 * rotated_grid @ torch.linalg.inv(grid_cell) - shape + 1) / (shape - 1) 52 | return F.grid_sample(density, torch.flip(normalized_grid, [-1]), 53 | mode='bilinear', align_corners=False).view(-1) 54 | 55 | 56 | @register_dataset('density') 57 | class DensityDataset(Dataset): 58 | def __init__(self, root, split, split_file, atom_file, extension='CHGCAR', 59 | compression='lz4', rotate=False, pbc=False): 60 | """ 61 | The density dataset contains volumetric data of molecules. 
        :param root: data root
        :param split: data split, can be 'train', 'validation', 'test'
        :param split_file: the data split file containing file names of the split
        :param atom_file: atom information file
        :param extension: raw data file extension, can be 'CHGCAR', 'cube', 'json'
        :param compression: raw data compression, can be 'lz4', 'xz', or None (no compression)
        :param rotate: whether to rotate the molecule and the volumetric data
        :param pbc: whether the data satisfy the periodic boundary condition
        """
        super(DensityDataset, self).__init__()
        self.root = root
        self.split = split
        self.extension = extension
        self.compression = compression
        self.rotate = rotate
        self.pbc = pbc

        # File suffix, e.g. '.CHGCAR.lz4' or '.cube'.
        self.file_pattern = f'.{extension}'
        if compression is not None:
            self.file_pattern += f'.{compression}'
        with open(os.path.join(root, split_file)) as f:
            # reverse the order so that larger molecules are tested first
            self.file_list = list(reversed(json.load(f)[split]))
        with open(atom_file) as f:
            atom_info = json.load(f)
        atom_list = [info['name'] for info in atom_info]
        self.atom_name2idx = {name: idx for idx, name in enumerate(atom_list)}
        # Compressed files are opened in binary mode, so element symbols read
        # from them arrive as bytes tokens; register bytes keys as well.
        self.atom_name2idx.update({name.encode(): idx for idx, name in enumerate(atom_list)})
        self.atom_num2idx = {info['atom_num']: idx for idx, info in enumerate(atom_info)}
        self.idx2atom_num = {idx: info['atom_num'] for idx, info in enumerate(atom_info)}

        # Dispatch on file format: each reader returns (graph, density, grid, info).
        if extension == 'CHGCAR':
            self.read_func = self.read_chgcar
        elif extension == 'cube':
            self.read_func = self.read_cube
        elif extension == 'json':
            self.read_func = self.read_json
        else:
            raise TypeError(f'Unknown extension {extension}')

        # Matching opener for the compression scheme (plain open if None).
        if compression == 'lz4':
            import lz4.frame
            self.open = lz4.frame.open
        elif compression == 'xz':
            import lzma
            self.open = lzma.open
        else:
            self.open = open

    def __getitem__(self, item):
        file_name = f'{(self.file_list[item])}{self.file_pattern}'
        with self.open(os.path.join(self.root, file_name)) as f:
            g, density, grid_coord, info = self.read_func(f)
        info['file_name'] = file_name

        if self.rotate:
            # Random global rotation for data augmentation.
            rot = o3.rand_matrix()
            center = info['cell'].sum(dim=0) / 2
            # Atoms are rotated by R^T about the cell center; the density volume
            # is resampled by querying the original volume at grid points rotated
            # by R, which produces the volume rotated consistently with the atoms.
            g.pos = (g.pos - center) @ rot.t() + center
            rotated_grid = (grid_coord - center) @ rot + center
            density = rotate_voxel(info['shape'], info['cell'], density, rotated_grid)
            info['rot'] = rot
        return g, density, grid_coord, info

    def __len__(self):
        return len(self.file_list)

    def read_cube(self, fileobj):
        """Read atoms and data from CUBE file."""
        if self.pbc:
            raise NotImplementedError('PBC not implemented for cube files')

        readline = fileobj.readline
        readline()  # the first comment line
        readline()  # the second comment line

        # Third line contains actual system information:
        line = readline().split()
        n_atom = int(line[0])

        # Origin around which the volumetric data is centered
        # (at least in FHI aims):
        origin = torch.FloatTensor([float(x) for x in line[1::]])

        shape = []
        cell = torch.empty(3, 3, dtype=torch.float)
        # the upcoming three lines contain the cell information
        # (each line: number of voxels along the axis, then the axis vector)
        for i in range(3):
            n, x, y, z = [float(s) for s in readline().split()]
            shape.append(int(n))
            cell[i] = torch.FloatTensor([x, y, z])
        # Build the full (nx*ny*nz, 3) lattice of Cartesian grid points by
        # broadcasting the per-axis coordinates.
        x_coord = torch.arange(shape[0]).unsqueeze(-1) * cell[0]
        y_coord = torch.arange(shape[1]).unsqueeze(-1) * cell[1]
        z_coord = torch.arange(shape[2]).unsqueeze(-1) * cell[2]
        grid_coord = x_coord.view(-1, 1, 1, 3) + y_coord.view(1, -1, 1, 3) + z_coord.view(1, 1, -1, 3)
        # NOTE(review): the grid is shifted by -origin while the atom
        # coordinates below are left untouched — confirm this matches the
        # convention of the tool that generated the cube files.
        grid_coord = grid_coord.view(-1, 3) - origin

        atom_type = torch.empty(n_atom, dtype=torch.long)
        atom_coord = torch.empty(n_atom, 3, dtype=torch.float)
        for i in range(n_atom):
            # Each atom line: atomic number, (unused) charge, x, y, z.
            line = readline().split()
            atom_type[i] = self.atom_num2idx[int(line[0])]
            atom_coord[i] = torch.FloatTensor([float(s) for s in line[2:]])

        g = Data(x=atom_type, pos=atom_coord)
        # The remainder of the file is the flattened volumetric data
        # (outer loop X, inner loop Z per the cube convention).
        density = torch.FloatTensor([float(s) for s in fileobj.read().split()])
        return g, density, grid_coord, {'shape': shape, 'cell': cell, 'origin': origin}

    def read_chgcar(self, fileobj):
        """Read atoms and data from CHGCAR file."""
        readline = fileobj.readline
        readline()  # the first comment line
        scale = float(readline())  # the scaling factor (lattice constant)

        # the upcoming three lines contain the cell information
        cell = torch.empty(3, 3, dtype=torch.float)
        for i in range(3):
            cell[i] = torch.FloatTensor([float(s) for s in readline().split()])
        cell = cell * scale

        # the sixth line specifies the constituting elements
        elements = readline().split()
        # the seventh line supplies the number of atoms per atomic species
        n_atoms = [int(s) for s in readline().split()]
        # the eighth line is always "Direct" in our application
        readline()

        tot_atoms = sum(n_atoms)
        atom_type = torch.empty(tot_atoms, dtype=torch.long)
        atom_coord = torch.empty(tot_atoms, 3, dtype=torch.float)
        # the upcoming lines contains the atomic positions in fractional coordinates
        idx = 0
        for elem, n in zip(elements, n_atoms):
            # elem may be str or bytes depending on compression; both keys exist.
            atom_type[idx:idx + n] = self.atom_name2idx[elem]
            for _ in range(n):
                atom_coord[idx] = torch.FloatTensor([float(s) for s in readline().split()])
                idx += 1
        if self.pbc:
            atom_type, atom_coord = pbc_expand(atom_type, atom_coord)
        # the coordinates are fractional, convert them to cartesian
        atom_coord = atom_coord @ cell
        g = Data(x=atom_type, pos=atom_coord)

        readline()  # an empty line
        shape = [int(s) for s in readline().split()]  # grid size
        n_grid = shape[0] * shape[1] * shape[2]
        # the grids are corner-aligned
        x_coord = torch.linspace(0, shape[0] - 1, shape[0]).unsqueeze(-1) / shape[0] * cell[0]
        y_coord = torch.linspace(0, shape[1] - 1, shape[1]).unsqueeze(-1) / shape[1] * cell[1]
        z_coord = torch.linspace(0, shape[2] - 1, shape[2]).unsqueeze(-1) / shape[2] * cell[2]
        grid_coord = x_coord.view(-1, 1, 1, 3) + y_coord.view(1, -1, 1, 3) + z_coord.view(1, 1, -1, 3)
        grid_coord = grid_coord.view(-1, 3)

        # the augmented occupancies are ignored
        density = torch.FloatTensor([float(s) for s in fileobj.read().split()[:n_grid]])
        # the value stored is the charge within a grid instead of the charge density
        # divide the charge by the grid volume to get the density
        volume = torch.linalg.det(cell).abs()
        density = density / volume
        # CHGCAR file stores the density as Z-Y-X, convert them to X-Y-Z
        density = density.view(shape[2], shape[1], shape[0]).transpose(0, 2).contiguous().view(-1)
        return g, density, grid_coord, {'shape': shape, 'cell': cell}

    def read_json(self, fileobj):
        """Read atoms and data from JSON file."""

        def read_2d_tensor(s):
            # Convert a nested list of numbers/strings into a float tensor.
            return torch.FloatTensor([[float(x) for x in line] for line in s])

        data = json.load(fileobj)
        scale = float(data['vector'][0][0])
        cell = read_2d_tensor(data['lattice'][0]) * scale
        elements = data['elements'][0]
        n_atoms = [int(s) for s in data['elements_number'][0]]

        tot_atoms = sum(n_atoms)
        atom_coord = read_2d_tensor(data['coordinates'][0])
        atom_type = torch.empty(tot_atoms, dtype=torch.long)
        idx = 0
        for elem, n in zip(elements, n_atoms):
            atom_type[idx:idx + n] = self.atom_name2idx[elem]
            idx += n
        if self.pbc:
            atom_type, atom_coord = pbc_expand(atom_type, atom_coord)
        # Fractional -> Cartesian, same as in read_chgcar.
        atom_coord = atom_coord @ cell
        g = Data(x=atom_type, pos=atom_coord)

        shape = [int(s) for s in data['FFTgrid'][0]]
        # Corner-aligned grid, identical construction to read_chgcar.
        x_coord = torch.linspace(0, shape[0] - 1, shape[0]).unsqueeze(-1) / shape[0] * cell[0]
        y_coord = torch.linspace(0, shape[1] - 1, shape[1]).unsqueeze(-1) / shape[1] * cell[1]
        z_coord = torch.linspace(0, shape[2] - 1, shape[2]).unsqueeze(-1) / shape[2] * cell[2]
        grid_coord = x_coord.view(-1, 1, 1, 3) + y_coord.view(1, -1, 1, 3) + z_coord.view(1, 1, -1, 3)
        grid_coord = grid_coord.view(-1, 3)

        n_grid = shape[0] * shape[1] * shape[2]
        # Density values are stored 10 per line.
        n_line = (n_grid + 9) // 10
        # Entries like '*****' (presumably formatter overflow — verify against
        # the data generator) are replaced with 0.
        density = torch.FloatTensor([
            float(s) if not s.startswith('*') else 0.
            for line in data['chargedensity'][0][:n_line]
            for s in line
        ]).view(-1)[:n_grid]
        # Stored values are per-voxel charge; divide by cell volume for density.
        volume = torch.linalg.det(cell).abs()
        density = density / volume
        # Z-Y-X storage -> X-Y-Z, same as CHGCAR.
        density = density.view(shape[2], shape[1], shape[0]).transpose(0, 2).contiguous().view(-1)
        return g, density, grid_coord, {'shape': shape, 'cell': cell}

    # TODO: cube files are in unit of Bohr
    def write_cube(self, fileobj, atom_type, atom_coord, density, info):
        """Write a cube file."""
        fileobj.write('Cube file written on ' + time.strftime('%c'))
        fileobj.write("\nOUTER LOOP: X, MIDDLE LOOP: Y, INNER LOOP: Z\n")

        cell = info['cell']
        shape = info['shape']
        origin = info.get('origin', np.zeros(3))
        fileobj.write('{0:5}{1:12.6f}{2:12.6f}{3:12.6f}\n'.format(len(atom_type), *origin))

        # One line per axis: voxel count and voxel step vector (cell axis / count).
        for s, c in zip(shape, cell):
            d = c / s
            fileobj.write('{0:5}{1:12.6f}{2:12.6f}{3:12.6f}\n'.format(s, *d))

        # NOTE(review): if atom_type is a torch tensor, indexing idx2atom_num
        # with a 0-dim tensor element will fail (dict keys are ints) — callers
        # presumably pass a list/ndarray of ints; confirm.
        for Z, (x, y, z) in zip(atom_type, atom_coord):
            Z = self.idx2atom_num[Z]
            # Cube atom line: atomic number, charge (repeated Z here), x, y, z.
            fileobj.write(
                '{0:5}{1:12.6f}{2:12.6f}{3:12.6f}{4:12.6f}\n'.format(Z, Z, x, y, z)
            )
        # Assumes density is a numpy array (tofile is a numpy method).
        density.tofile(fileobj, sep='\n', format='%e')
import os

import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
import numpy as np

from ._base import register_dataset

# Fixed atom-type index sequence for each supported MD molecule, matching the
# atom order in the structure files. Inferred from stoichiometry
# (e.g. benzene C6H6 -> six 0s, six 1s): 0 = C, 1 = H, 2 = O.
ATOM_TYPES = {
    'benzene': torch.LongTensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]),
    'ethanol': torch.LongTensor([0, 0, 2, 1, 1, 1, 1, 1, 1]),
    'phenol': torch.LongTensor([0, 0, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1]),
    'resorcinol': torch.LongTensor([0, 0, 0, 0, 0, 0, 2, 1, 2, 1, 1, 1, 1, 1]),
    'ethane': torch.LongTensor([0, 0, 1, 1, 1, 1, 1, 1]),
    'malonaldehyde': torch.LongTensor([2, 0, 0, 0, 2, 1, 1, 1, 1]),
}


@register_dataset('small_density')
class SmallDensityDataset(Dataset):
    def __init__(self, root, mol_name, split):
        """
        Density dataset for small molecules in the MD datasets.
        Note that the validation and test splits are the same.
        :param root: data root
        :param mol_name: name of the molecule
        :param split: data split, can be 'train', 'validation', 'test'
        """
        super(SmallDensityDataset, self).__init__()
        assert mol_name in ('benzene', 'ethanol', 'phenol', 'resorcinol', 'ethane', 'malonaldehyde')
        self.root = root
        self.mol_name = mol_name
        self.split = split
        # No separate validation data exists; reuse the test split for it.
        if split == 'validation':
            split = 'test'

        self.n_grid = 50  # number of grid points along each dimension
        self.grid_size = 20.  # box size in Bohr
        self.data_path = os.path.join(root, mol_name, f'{mol_name}_{split}')

        self.atom_type = ATOM_TYPES[mol_name]
        # (n_frame, n_atom, 3) atom positions for every MD snapshot.
        self.atom_coords = torch.FloatTensor(np.load(os.path.join(self.data_path, 'structures.npy')))
        # Densities are stored as FFT coefficients; decode them once up front.
        self.densities = self._convert_fft(np.load(os.path.join(self.data_path, 'dft_densities.npy')))
        self.grid_coord = self._generate_grid()

    def _convert_fft(self, fft_coeff):
        # The raw data are stored in Fourier basis, we need to convert them back.
        # The packing below unpacks real-valued storage into complex
        # coefficients one axis at a time (first half carries the packed
        # real/imag pair, second half is the conjugate-symmetric part), then
        # applies an inverse FFT along that axis. The steps mutate `d` in
        # place, so their order is significant.
        print(f'Precomputing {self.split} density from FFT coefficients ...')
        fft_coeff = torch.FloatTensor(fft_coeff).to(torch.complex64)
        d = fft_coeff.view(-1, self.n_grid, self.n_grid, self.n_grid)
        hf = self.n_grid // 2
        # first dimension
        d[:, :hf] = (d[:, :hf] - d[:, hf:] * 1j) / 2
        d[:, hf:] = torch.flip(d[:, 1:hf + 1], [1]).conj()
        d = torch.fft.ifft(d, dim=1)
        # second dimension
        d[:, :, :hf] = (d[:, :, :hf] - d[:, :, hf:] * 1j) / 2
        d[:, :, hf:] = torch.flip(d[:, :, 1:hf + 1], [2]).conj()
        d = torch.fft.ifft(d, dim=2)
        # third dimension
        d[..., :hf] = (d[..., :hf] - d[..., hf:] * 1j) / 2
        d[..., hf:] = torch.flip(d[..., 1:hf + 1], [3]).conj()
        d = torch.fft.ifft(d, dim=3)
        # The final flip reverses the flattened voxel order — presumably to
        # match the raw data's grid ordering; confirm against the generator.
        return torch.flip(d.real.view(-1, self.n_grid ** 3), [-1]).detach()

    def _generate_grid(self):
        # Grid points run from grid_size/n_grid to grid_size inclusive
        # (origin excluded, far edge included) along each axis.
        x = torch.linspace(self.grid_size / self.n_grid, self.grid_size, self.n_grid)
        return torch.stack(torch.meshgrid(x, x, x, indexing='ij'), dim=-1).view(-1, 3).detach()

    def __getitem__(self, item):
        # Same (graph, density, grid, info) tuple layout as DensityDataset.
        info = {
            'cell': torch.eye(3) * self.grid_size,
            'shape': [self.n_grid, self.n_grid, self.n_grid]
        }
        return (
            Data(x=self.atom_type, pos=self.atom_coords[item]),
            self.densities[item], self.grid_coord, info
        )

    def __len__(self):
        return self.atom_coords.shape[0]
import os
import argparse
from pathlib import Path

import numpy as np


def read_xyz(file):
    """Read all frames from a multi-frame XYZ trajectory file.

    Each frame consists of an atom-count line, a comment line, and then one
    line per atom of the form ``<symbol> <x> <y> <z>``.

    :param file: an opened text file object positioned at the first frame
    :return: float array of shape (n_frame, n_atom, 3) with atom coordinates
    :raises ValueError: if a frame is malformed (previously such frames were
        silently dropped because ValueError doubled as the EOF signal)
    """
    all_coords = []
    while True:
        header = file.readline()
        if not header.strip():
            # EOF (readline returns '') or a blank trailing line: stop cleanly
            # instead of relying on int('') raising ValueError. The old
            # exception-driven loop also caught StopIteration, which readline
            # never raises, and swallowed genuine parse errors mid-file.
            break
        n_atom = int(header)
        file.readline()  # skip the per-frame comment line
        # Column 0 is the element symbol; columns 1-3 are x, y, z.
        frame = [
            [float(x) for x in file.readline().split()[1:4]]
            for _ in range(n_atom)
        ]
        all_coords.append(frame)
    all_coords = np.array(all_coords, dtype=float)
    print(all_coords.shape)
    return all_coords


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default='./data')
    parser.add_argument('--out', type=str, default='./data')
    args = parser.parse_args()

    root = Path(args.root)
    out = Path(args.out)  # NOTE: currently unused; outputs go under root
    # Convert the raw 300K MD data into the directory layout expected by
    # SmallDensityDataset: <mol>/<mol>_{train,test}/{structures,dft_densities}.npy
    for mol in ['ethane', 'malonaldehyde']:
        den = np.loadtxt(root / f'{mol}_300K/densities.txt')
        train_dir = root / f'{mol}/{mol}_train/'
        os.makedirs(train_dir, exist_ok=True)
        np.save(train_dir / 'dft_densities.npy', den)
        with open(root / f'{mol}_300K/structures.xyz') as f:
            np.save(train_dir / 'structures.npy', read_xyz(f))

        den = np.loadtxt(root / f'{mol}_300K-test/densities.txt')
        test_dir = root / f'{mol}/{mol}_test/'
        os.makedirs(test_dir, exist_ok=True)
        np.save(test_dir / 'dft_densities.npy', den)
        with open(root / f'{mol}_300K-test/structures.xyz') as f:
            np.save(test_dir / 'structures.npy', read_xyz(f))
    print('Done')
torch.utils.tensorboard 10 | from torch.nn.utils import clip_grad_norm_ 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | torch.multiprocessing.set_sharing_strategy('file_system') 14 | 15 | from datasets import get_dataset, DensityCollator, DensityVoxelCollator 16 | from models import get_model 17 | from utils import load_config, seed_all, get_optimizer, get_scheduler, count_parameters 18 | from visualize import draw_stack 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser(description='InfGCN Training/Inference') 22 | parser.add_argument('config', type=str, help='config file path') 23 | parser.add_argument('--mode', type=str, choices=['train', 'inf'], default='train', 24 | help='running mode: train or inf') 25 | parser.add_argument('--device', type=str, default='cuda', help='cuda or cpu') 26 | parser.add_argument('--logdir', type=str, default='./logs', help='log directory') 27 | parser.add_argument('--savename', type=str, default='test', help='save name') 28 | parser.add_argument('--resume', type=str, default=None, help='checkpoint path to resume from') 29 | args = parser.parse_args() 30 | 31 | # Load configs 32 | config = load_config(args.config) 33 | seed_all(config.train.seed) 34 | print(config) 35 | logdir = os.path.join(args.logdir, args.savename) 36 | if not os.path.exists(logdir): 37 | os.makedirs(logdir, exist_ok=True) 38 | writer = SummaryWriter(logdir) 39 | 40 | # Data 41 | print('Loading datasets...') 42 | use_voxel = config.model.type == 'cnn' 43 | if use_voxel: 44 | train_collator = val_collator = inf_collator = DensityVoxelCollator() 45 | else: 46 | train_collator = DensityCollator(config.train.train_samples) 47 | val_collator = DensityCollator(config.train.val_samples) 48 | inf_collator = DensityCollator() 49 | train_set, val_set, test_set = get_dataset(config.datasets) 50 | train_loader = DataLoader(train_set, config.train.batch_size, shuffle=True, 51 | num_workers=32, collate_fn=train_collator) 52 | val_loader 
= DataLoader(val_set, config.train.batch_size, shuffle=False, 53 | num_workers=32, collate_fn=val_collator) 54 | inf_loader = DataLoader(val_set, 2, shuffle=True, num_workers=2, collate_fn=inf_collator) 55 | 56 | # Model 57 | print('Building model...') 58 | model = get_model(config.model).to(args.device) 59 | print(f'Number of parameters: {count_parameters(model)}') 60 | 61 | # Optimizer & Scheduler 62 | optimizer = get_optimizer(config.train.optimizer, model) 63 | scheduler = get_scheduler(config.train.scheduler, optimizer) 64 | criterion = nn.MSELoss().to(args.device) 65 | optimizer.zero_grad() 66 | 67 | # Resume 68 | if args.resume is not None: 69 | print(f'Resuming from checkpoint: {args.resume}') 70 | ckpt = torch.load(args.resume, map_location=args.device) 71 | model.load_state_dict(ckpt['model']) 72 | if 'optimizer' in ckpt: 73 | print('Resuming optimizer states...') 74 | optimizer.load_state_dict(ckpt['optimizer']) 75 | if 'scheduler' in ckpt: 76 | print('Resuming scheduler states...') 77 | scheduler.load_state_dict(ckpt['scheduler']) 78 | 79 | global_step = 0 80 | 81 | 82 | def train(): 83 | global global_step 84 | 85 | epoch = 0 86 | while True: 87 | model.train() 88 | epoch_losses = [] 89 | for g, density, grid_coord, infos in train_loader: 90 | g = g.to(args.device) 91 | density, grid_coord = density.to(args.device), grid_coord.to(args.device) 92 | pred = model(g.x, g.pos, grid_coord, g.batch, infos) 93 | if use_voxel: 94 | mask = (density > 0).float() 95 | pred = pred * mask 96 | density = density * mask 97 | loss = criterion(pred, density) 98 | mae = torch.abs(pred.detach() - density).sum() / density.sum() 99 | epoch_losses.append(loss.item()) 100 | loss.backward() 101 | grad_norm = clip_grad_norm_(model.parameters(), config.train.max_grad_norm) 102 | optimizer.step() 103 | optimizer.zero_grad() 104 | 105 | # Logging 106 | writer.add_scalar('train/loss', loss.item(), global_step) 107 | writer.add_scalar('train/mae', mae.item(), global_step) 108 | 
writer.add_scalar('train/grad', grad_norm.item(), global_step) 109 | writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], global_step) 110 | if global_step % config.train.log_freq == 0: 111 | print(f'Epoch {epoch} Step {global_step} train loss {loss.item():.6f},' 112 | f' train mae {mae.item():.6f}') 113 | global_step += 1 114 | if global_step % config.train.val_freq == 0: 115 | avg_val_loss = validate(val_loader) 116 | inference(inf_loader, 1, config.test.num_vis, config.test.inf_samples) 117 | 118 | if config.train.scheduler.type == 'plateau': 119 | scheduler.step(avg_val_loss) 120 | else: 121 | scheduler.step() 122 | 123 | model.train() 124 | torch.save({ 125 | 'model': model.state_dict(), 126 | 'step': global_step, 127 | }, os.path.join(logdir, 'latest.pt')) 128 | if global_step % config.train.save_freq == 0: 129 | ckpt_path = os.path.join(logdir, f'{global_step}.pt') 130 | torch.save({ 131 | 'config': config, 132 | 'model': model.state_dict(), 133 | 'optimizer': optimizer.state_dict(), 134 | 'scheduler': scheduler.state_dict(), 135 | 'avg_val_loss': avg_val_loss, 136 | }, ckpt_path) 137 | if global_step >= config.train.max_iter: 138 | return 139 | 140 | epoch_loss = sum(epoch_losses) / len(epoch_losses) 141 | print(f'Epoch {epoch} train loss {epoch_loss:.6f}') 142 | epoch += 1 143 | 144 | 145 | def validate(dataloader, split='val'): 146 | with torch.no_grad(): 147 | model.eval() 148 | 149 | val_losses = [] 150 | val_mae, val_cnt = 0., 0. 
151 | for g, density, grid_coord, infos in tqdm(dataloader, total=len(dataloader)): 152 | g = g.to(args.device) 153 | density, grid_coord = density.to(args.device), grid_coord.to(args.device) 154 | pred = model(g.x, g.pos, grid_coord, g.batch, infos) 155 | if use_voxel: 156 | mask = (density > 0).float() 157 | pred = pred * mask 158 | density = density * mask 159 | loss = criterion(pred, density) 160 | val_losses.append(loss.item()) 161 | val_mae += torch.abs(pred - density).sum().item() 162 | val_cnt += density.sum().item() 163 | val_loss = sum(val_losses) / len(val_losses) 164 | val_mae = val_mae / val_cnt 165 | 166 | writer.add_scalar(f'{split}/loss', val_loss, global_step) 167 | writer.add_scalar(f'{split}/mae', val_mae, global_step) 168 | print(f'Step {global_step} {split} loss {val_loss:.6f}, {split} mae {val_mae:.6f}') 169 | return val_loss 170 | 171 | 172 | def inference_batch(g, density, grid_coord, infos, grid_batch_size=None): 173 | with torch.no_grad(): 174 | model.eval() 175 | if grid_batch_size is None: 176 | preds = model(g.x, g.pos, grid_coord, g.batch, infos) 177 | else: 178 | preds = [] 179 | for grid in grid_coord.split(grid_batch_size, dim=1): 180 | preds.append(model(g.x, g.pos, grid.contiguous(), g.batch, infos)) 181 | preds = torch.cat(preds, dim=1) 182 | mask = (density > 0).float() 183 | preds = preds * mask 184 | density = density * mask 185 | diff = torch.abs(preds - density) 186 | sum_idx = tuple(range(1, density.dim())) 187 | loss = diff.pow(2).sum(sum_idx) / mask.sum(sum_idx) 188 | mae = diff.sum(sum_idx) / density.sum(sum_idx) 189 | return preds, loss, mae 190 | 191 | 192 | def inference(dataloader, num_infer=None, num_vis=2, samples=None): 193 | inf_loss, inf_mae = [], [] 194 | num_infer = num_infer or len(dataloader) 195 | for idx, (g, density, grid_coord, infos) in tqdm(enumerate(dataloader), total=num_infer): 196 | if idx >= num_infer: 197 | break 198 | 199 | g = g.to(args.device) 200 | density, grid_coord = 
density.to(args.device), grid_coord.to(args.device) 201 | pred, loss, mae = inference_batch(g, density, grid_coord, infos, samples) 202 | inf_loss.append(loss.detach().cpu().numpy()) 203 | inf_mae.append(mae.detach().cpu().numpy()) 204 | 205 | if idx == 0: 206 | for vis_idx, (p, d, info) in enumerate(zip(pred, density, infos)): 207 | if vis_idx >= num_vis: 208 | break 209 | 210 | shape = info['shape'] 211 | mask = g.batch == vis_idx 212 | atom_type, coord = g.x[mask], g.pos[mask] 213 | grid_cell = (info['cell'] / torch.FloatTensor(shape).view(3, 1)).to(args.device) 214 | coord = coord @ torch.linalg.inv(grid_cell) 215 | if use_voxel: 216 | d = d[:shape[0], :shape[1], :shape[2]] 217 | p = p[:shape[0], :shape[1], :shape[2]] 218 | else: 219 | num_voxel = shape[0] * shape[1] * shape[2] 220 | d, p = d[:num_voxel].view(*shape), p[:num_voxel].view(*shape) 221 | writer.add_image(f'inf/gt_{vis_idx}', draw_stack(d, atom_type, coord), global_step) 222 | writer.add_image(f'inf/pred_{vis_idx}', draw_stack(p, atom_type, coord), global_step) 223 | writer.add_image(f'inf/diff_{vis_idx}', draw_stack(d - p, atom_type, coord), global_step) 224 | inf_loss = np.concatenate(inf_loss, axis=0).mean() 225 | inf_mae = np.concatenate(inf_mae, axis=0).mean() 226 | writer.add_scalar('inf/loss', inf_loss, global_step) 227 | writer.add_scalar('inf/mae', inf_mae, global_step) 228 | print(f'Step {global_step} inference loss {inf_loss:.6f}, inference mae {inf_mae:.6f}') 229 | 230 | 231 | try: 232 | if args.mode == 'train': 233 | # inference(inf_loader, 1, config.test.num_vis, config.test.inf_samples) 234 | train() 235 | print('Training finished!') 236 | 237 | if args.mode == 'inf' and args.resume is None: 238 | print('[WARNING]: inference mode without loading a pretrained model') 239 | test_loader = DataLoader(test_set, config.test.batch_size, shuffle=False, 240 | num_workers=16, collate_fn=inf_collator) 241 | inference(test_loader, config.test.num_infer, config.test.num_vis, 
config.test.inf_samples) 242 | except KeyboardInterrupt: 243 | print('Terminating...') 244 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base import get_model, register_model 2 | from .infgcn import InfGCN 3 | -------------------------------------------------------------------------------- /models/_base.py: -------------------------------------------------------------------------------- 1 | _MODEL_DICT = {} 2 | 3 | 4 | def register_model(name): 5 | def decorator(cls): 6 | _MODEL_DICT[name] = cls 7 | return cls 8 | 9 | return decorator 10 | 11 | 12 | def get_model(cfg): 13 | m_cfg = cfg.copy() 14 | m_type = m_cfg.pop('type') 15 | return _MODEL_DICT[m_type](**m_cfg) 16 | -------------------------------------------------------------------------------- /models/infgcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_scatter import scatter 4 | from torch_geometric.nn import radius_graph, radius 5 | from e3nn import o3 6 | from e3nn.math import soft_one_hot_linspace 7 | from e3nn.nn import FullyConnectedNet, Extract, Activation 8 | 9 | from .orbital import GaussianOrbital 10 | from ._base import register_model 11 | 12 | 13 | class ScalarActivation(nn.Module): 14 | """ 15 | Use the invariant scalar features to gate higher order equivariant features. 16 | Adapted from `e3nn.nn.Gate`. 
    """

    def __init__(self, irreps_in, act_scalars, act_gates):
        """
        :param irreps_in: input representations
        :param act_scalars: scalar activation function
        :param act_gates: gate activation function (for higher order features)
        """
        super(ScalarActivation, self).__init__()
        self.irreps_in = o3.Irreps(irreps_in)
        self.num_spherical = len(self.irreps_in)

        # First irrep block holds the invariant scalars; the rest are gated.
        irreps_scalars = self.irreps_in[0:1]
        irreps_gates = irreps_scalars * (self.num_spherical - 1)
        irreps_gated = self.irreps_in[1:]
        self.act_scalars = Activation(irreps_scalars, [act_scalars])
        self.act_gates = Activation(irreps_gates, [act_gates] * (self.num_spherical - 1))
        # Split the input into the scalar part and the gated (l > 0) part.
        self.extract = Extract(
            self.irreps_in,
            [irreps_scalars, irreps_gated],
            instructions=[(0,), tuple(range(1, self.irreps_in.lmax + 1))]
        )
        self.mul = o3.ElementwiseTensorProduct(irreps_gates, irreps_gated)

    def forward(self, features):
        scalars, gated = self.extract(features)
        scalars_out = self.act_scalars(scalars)
        if gated.shape[-1]:
            # Each l > 0 block is multiplied by an activated copy of the scalars.
            gates = self.act_gates(scalars.repeat(1, self.num_spherical - 1))
            gated_out = self.mul(gates, gated)
            features = torch.cat([scalars_out, gated_out], dim=-1)
        else:
            # Purely scalar input (lmax == 0): nothing to gate.
            features = scalars_out
        return features


class NormActivation(nn.Module):
    """
    Use the norm of the higher order equivariant features to gate themselves.
    Idea from the TFN paper.
    """

    def __init__(self, irreps_in, act_scalars=torch.nn.functional.silu, act_vectors=torch.sigmoid):
        """
        :param irreps_in: input representations
        :param act_scalars: scalar activation function
        :param act_vectors: vector activation function (for the norm of higher order features)
        """
        super(NormActivation, self).__init__()
        self.irreps_in = o3.Irreps(irreps_in)
        self.scalar_irreps = self.irreps_in[0:1]
        self.vector_irreps = self.irreps_in[1:]
        self.act_scalars = act_scalars
        self.act_vectors = act_vectors
        # Feature-dimension offset where the l > 0 components begin.
        self.scalar_idx = self.irreps_in[0].mul

        # 'uuu' diagonal product of each l > 0 block with itself yields the
        # per-channel squared norm as a scalar (0e) output.
        inner_out = o3.Irreps([(mul, (0, 1)) for mul, _ in self.vector_irreps])
        self.inner_prod = o3.TensorProduct(
            self.vector_irreps, self.vector_irreps, inner_out, [
                (i, i, i, 'uuu', False) for i in range(len(self.vector_irreps))
            ]
        )
        self.mul = o3.ElementwiseTensorProduct(inner_out, self.vector_irreps)

    def forward(self, features):
        scalars = self.act_scalars(features[..., :self.scalar_idx])
        vectors = features[..., self.scalar_idx:]
        # Per-channel norms; epsilon keeps the sqrt differentiable at zero.
        norm = torch.sqrt(self.inner_prod(vectors, vectors) + 1e-8)
        act = self.act_vectors(norm)
        vectors_out = self.mul(act, vectors)
        return torch.cat([scalars, vectors_out], dim=-1)


class GCNLayer(nn.Module):
    def __init__(self, irreps_in, irreps_out, irreps_edge, radial_embed_size, num_radial_layer, radial_hidden_size,
                 is_fc=True, use_sc=True, irrep_normalization='component', path_normalization='element'):
        r"""
        A single InfGCN layer for Tensor Product-based message passing.
        If the tensor product is fully connected, we have (for every path)

        .. math::
            z_w=\sum_{uv}w_{uvw}x_u\otimes y_v=\sum_{u}w_{uw}x_u \otimes y

        Else, we have

        .. math::
            z_u=x_u\otimes \sum_v w_{uv}y_v=w_u (x_u\otimes y)

        Here, uvw are radial (channel) indices of the first input, second input, and output, respectively.
        Notice that in our model, the second input is always the spherical harmonics of the edge vector,
        so the index v can be safely ignored.

        :param irreps_in: irreducible representations of input node features
        :param irreps_out: irreducible representations of output node features
        :param irreps_edge: irreducible representations of edge features
        :param radial_embed_size: embedding size of the edge length
        :param num_radial_layer: number of hidden layers in the radial network
        :param radial_hidden_size: hidden size of the radial network
        :param is_fc: whether to use fully connected tensor product
        :param use_sc: whether to use self-connection
        :param irrep_normalization: representation normalization passed to the `o3.FullyConnectedTensorProduct`
        :param path_normalization: path normalization passed to the `o3.FullyConnectedTensorProduct`
        """
        super(GCNLayer, self).__init__()
        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_edge = o3.Irreps(irreps_edge)
        self.radial_embed_size = radial_embed_size
        self.num_radial_layer = num_radial_layer
        self.radial_hidden_size = radial_hidden_size
        self.is_fc = is_fc
        self.use_sc = use_sc

        if self.is_fc:
            # Weights come from the radial MLP below, hence
            # internal_weights=False / shared_weights=False.
            self.tp = o3.FullyConnectedTensorProduct(
                self.irreps_in, self.irreps_edge, self.irreps_out,
                internal_weights=False, shared_weights=False,
                irrep_normalization=irrep_normalization,
                path_normalization=path_normalization,
            )
        else:
            # 'uvu' instructions: one weighted path per valid
            # (input irrep, edge irrep) pair that can produce the output irrep.
            instr = [
                (i_1, i_2, i_out, 'uvu', True)
                for i_1, (_, ir_1) in enumerate(self.irreps_in)
                for i_2, (_, ir_edge) in enumerate(self.irreps_edge)
                for i_out, (_, ir_out) in enumerate(self.irreps_out)
                if ir_out in ir_1 * ir_edge
            ]
            self.tp = o3.TensorProduct(
                self.irreps_in, self.irreps_edge, self.irreps_out, instr,
                internal_weights=False, shared_weights=False,
                irrep_normalization=irrep_normalization,
                path_normalization=path_normalization,
            )
        # Radial MLP mapping edge-length embeddings to tensor-product weights.
        self.fc = FullyConnectedNet(
            [radial_embed_size] + num_radial_layer * [radial_hidden_size] + [self.tp.weight_numel],
            torch.nn.functional.silu
        )
        self.sc = None
        if self.use_sc:
            # Self-connection: an equivariant linear skip path.
            self.sc = o3.Linear(self.irreps_in, self.irreps_out)

    def forward(self, edge_index, node_feat, edge_feat, edge_embed, dim_size=None):
        """
        :param edge_index: (src, dst) index tensors of the message-passing edges
        :param node_feat: node features gathered from src, scattered to dst
        :param edge_feat: spherical harmonics of the edge vectors
        :param edge_embed: radial embedding of the edge lengths
        :param dim_size: number of output nodes for the scatter
        :return: aggregated output node features
        """
        src, dst = edge_index
        weight = self.fc(edge_embed)
        out = self.tp(node_feat[src], edge_feat, weight=weight)
        # Sum messages per destination node.
        out = scatter(out, dst, dim=0, dim_size=dim_size, reduce='sum')
        if self.use_sc:
            out = out + self.sc(node_feat)
        return out


def pbc_vec(vec, cell):
    """
    Apply periodic boundary condition to the vector
    :param vec: original vector of (N, K, 3)
    :param cell: cell frame of (N, 3, 3)
    :return: shortest vector of (N, K, 3)
    """
    # Minimum-image convention: wrap fractional coordinates to [-0.5, 0.5).
    coord = vec @ torch.linalg.inv(cell)
    coord = coord - torch.round(coord)
    pbc_vec = coord @ cell
    # detach: the wrapping is not differentiated through.
    return pbc_vec.detach()


@register_model('infgcn')
class InfGCN(nn.Module):
    def __init__(self, n_atom_type, num_radial, num_spherical, radial_embed_size, radial_hidden_size,
                 num_radial_layer=2, num_gcn_layer=3, cutoff=3.0, grid_cutoff=3.0, is_fc=True,
                 gauss_start=0.5, gauss_end=5.0, activation='norm', residual=True, pbc=False, **kwargs):
        """
        Implement the InfGCN model for electron density estimation
        :param n_atom_type: number of atom types
        :param num_radial: number of radial basis
        :param num_spherical: maximum number of spherical harmonics for each radial basis,
            number of spherical basis will be (num_spherical + 1)^2
        :param radial_embed_size: embedding size of the edge length
        :param radial_hidden_size: hidden size of the radial network
        :param num_radial_layer: number of hidden layers in the radial network
        :param num_gcn_layer: number of InfGCN layers
        :param cutoff: cutoff distance for building the molecular graph
        :param grid_cutoff: cutoff distance for building the grid-atom graph
        :param is_fc: whether the InfGCN layer should use fully connected tensor product
        :param gauss_start: start coefficient of the Gaussian radial basis
        :param gauss_end: end coefficient of the Gaussian radial basis
        :param activation: activation type for the InfGCN layer, can be ['scalar', 'norm']
        :param residual: whether to use the residue prediction layer
        :param pbc: whether the data satisfy the periodic boundary condition
        """
        super(InfGCN, self).__init__()
        self.n_atom_type = n_atom_type
        self.num_radial = num_radial
        self.num_spherical = num_spherical
        self.radial_embed_size = radial_embed_size
        self.radial_hidden_size = radial_hidden_size
        self.num_radial_layer = num_radial_layer
        self.num_gcn_layer = num_gcn_layer
        self.cutoff = cutoff
        self.grid_cutoff = grid_cutoff
        self.is_fc = is_fc
        self.gauss_start = gauss_start
        self.gauss_end = gauss_end
        self.activation = activation
        self.residual = residual
        self.pbc = pbc

        assert activation in ['scalar', 'norm']

        # Initial node features are invariant scalars: one per radial channel.
        self.embedding = nn.Embedding(n_atom_type, num_radial)
        self.irreps_sh = o3.Irreps.spherical_harmonics(num_spherical, p=1)
        # Hidden features: num_radial copies of every spherical-harmonic irrep.
        self.irreps_feat = (self.irreps_sh * num_radial).sort().irreps.simplify()
        self.gcns = nn.ModuleList([
            GCNLayer(
                (f'{num_radial}x0e' if i == 0 else self.irreps_feat), self.irreps_feat, self.irreps_sh,
                radial_embed_size, num_radial_layer, radial_hidden_size, is_fc=is_fc, **kwargs
            ) for i in range(num_gcn_layer)
        ])
        if self.activation == 'scalar':
            self.act = ScalarActivation(self.irreps_feat, torch.nn.functional.silu, torch.sigmoid)
        else:
            self.act = NormActivation(self.irreps_feat)
        self.residue = None
        if self.residual:
            # Extra layer predicting a scalar ('0e') correction at grid points.
            self.residue = GCNLayer(
                self.irreps_feat, '0e', self.irreps_sh,
                radial_embed_size, num_radial_layer, radial_hidden_size,
                is_fc=True, use_sc=False, **kwargs
            )
        self.orbital = GaussianOrbital(gauss_start, gauss_end, num_radial, num_spherical)

    def forward(self, atom_types, atom_coord, grid, batch, infos):
        """
        Network forward
        :param atom_types: atom types of (N,)
        :param atom_coord: atom coordinates of (N, 3)
        :param grid: coordinates at grid points of (G, K, 3)
        :param batch: batch index for each node of (N,)
        :param infos: list of dictionary containing additional information
        :return: predicted value at each grid point of (G, K)
        """
        # Embedding
        cell = torch.stack([info['cell'] for info in infos], dim=0).to(batch.device)
        feat = self.embedding(atom_types)
        edge_index = radius_graph(atom_coord, self.cutoff, batch, loop=False)
        src, dst = edge_index
        edge_vec = atom_coord[src] - atom_coord[dst]
        edge_len = edge_vec.norm(dim=-1) + 1e-8
        # Direction is normalized manually, hence normalize=False here.
        edge_feat = o3.spherical_harmonics(
            list(range(self.num_spherical + 1)), edge_vec / edge_len[..., None],
            normalize=False, normalization='integral'
        )
        # Soft one-hot embedding of edge lengths on a Gaussian basis.
        edge_embed = soft_one_hot_linspace(
            edge_len, start=0.0, end=self.cutoff,
            number=self.radial_embed_size, basis='gaussian', cutoff=False
        ).mul(self.radial_embed_size ** 0.5)

        # GCN
        for i, gcn in enumerate(self.gcns):
            feat = gcn(edge_index, feat, edge_feat, edge_embed, dim_size=atom_types.size(0))
            # No nonlinearity after the last layer.
            if i != self.num_gcn_layer - 1:
                feat = self.act(feat)

        # Residue
        n_graph, n_sample = grid.size(0), grid.size(1)
        if self.residual:
            # Bipartite grid-point <- atom graph within grid_cutoff.
            grid_flat = grid.view(-1, 3)
            grid_batch = torch.arange(n_graph, device=grid.device).repeat_interleave(n_sample)
            grid_dst, node_src = radius(atom_coord, grid_flat, self.grid_cutoff, batch, grid_batch)
            grid_edge = grid_flat[grid_dst] - atom_coord[node_src]
            grid_len = torch.norm(grid_edge, dim=-1) + 1e-8
            # NOTE(review): grid_len already carries +1e-8, so the extra +1e-8
            # in the division below is redundant (harmless).
            grid_edge_feat = o3.spherical_harmonics(
                list(range(self.num_spherical + 1)), grid_edge / (grid_len[..., None] + 1e-8),
                normalize=False, normalization='integral'
            )
            grid_edge_embed = soft_one_hot_linspace(
                grid_len, start=0.0, end=self.grid_cutoff,
                number=self.radial_embed_size, basis='gaussian', cutoff=False
            ).mul(self.radial_embed_size ** 0.5)
            residue = self.residue(
                (node_src, grid_dst), feat, grid_edge_feat, grid_edge_embed, dim_size=grid_flat.size(0)
            )
        else:
            residue = 0.

        # Orbital
        # Vector from each atom to each of its graph's sample points: (N, K, 3).
        sample_vec = grid[batch] - atom_coord.unsqueeze(-2)
        if self.pbc:
            sample_vec = pbc_vec(sample_vec, cell[batch])
        orbital = self.orbital(sample_vec)
        # Per-atom contribution = <basis values, coefficients>; sum per graph.
        density = (orbital * feat.unsqueeze(1)).sum(dim=-1)
        density = scatter(density, batch, dim=0, reduce='sum')
        if self.residual:
            density = density + residue.view(*density.size())
        return density


# --- models/orbital.py ---
import math

import torch
import torch.nn as nn
from e3nn import o3

from .utils import BroadcastGTOTensor


class GaussianOrbital(nn.Module):
    r"""
    Gaussian-type orbital

    ..
class GaussianOrbital(nn.Module):
    r"""
    Gaussian-type orbital

    .. math::
        \psi_{n\ell m}(\mathbf{r})=\sqrt{\frac{2(2a_n)^{\ell+3/2}}{\Gamma(\ell+3/2)}}
        \exp(-a_n r^2) r^\ell Y_{\ell}^m(\hat{\mathbf{r}})

    """

    def __init__(self, gauss_start, gauss_end, num_gauss, lmax=7):
        super(GaussianOrbital, self).__init__()
        self.gauss_start = gauss_start
        self.gauss_end = gauss_end
        self.num_gauss = num_gauss
        self.lmax = lmax

        # Broadcast helpers: radial values are indexed (l, c), angular by m;
        # both are expanded to the common (l, c, m) layout before multiplying.
        self.lc2lcm = BroadcastGTOTensor(lmax, num_gauss, src='lc', dst='lcm')
        self.m2lcm = BroadcastGTOTensor(lmax, num_gauss, src='m', dst='lcm')
        self.gauss: torch.Tensor
        self.lognorm: torch.Tensor

        self.register_buffer('gauss', torch.linspace(gauss_start, gauss_end, num_gauss))
        self.register_buffer('lognorm', self._generate_lognorm())

    def _generate_lognorm(self):
        """Log of the normalization constants, flattened over (l, c)."""
        exponents = (torch.arange(self.lmax + 1) + 1.5).unsqueeze(-1)  # (l, 1)
        log_num = exponents * torch.log(2 * self.gauss).unsqueeze(0) + math.log(2)  # (l, c)
        log_den = torch.special.gammaln(exponents)
        # sqrt of the normalizer, taken in log space for stability
        return ((log_num - log_den) / 2).view(-1)  # (l * c)

    def forward(self, vec):
        """
        Evaluate the basis functions
        :param vec: un-normalized vectors of (..., 3)
        :return: basis values of (..., (l+1)^2 * c)
        """
        # angular part: real spherical harmonics of the unit direction
        r = vec.norm(dim=-1) + 1e-8
        angular = o3.spherical_harmonics(
            list(range(self.lmax + 1)), vec / r[..., None],
            normalize=False, normalization='integral'
        )

        # radial part, assembled in log space: -a_n r^2 + l*log(r) + lognorm
        r = r.unsqueeze(-1)
        gauss_part = -self.gauss * (r * r)  # (..., c)
        degrees = torch.arange(self.lmax + 1, dtype=torch.float, device=vec.device)
        power_part = degrees * torch.log(r)  # (..., l)
        log_radial = gauss_part.unsqueeze(-2) + power_part.unsqueeze(-1)  # (..., l, c)
        log_radial = log_radial.view(*log_radial.size()[:-2], -1)  # (..., l * c)
        radial = torch.exp(log_radial + self.lognorm * torch.ones_like(r))
        return self.lc2lcm(radial) * self.m2lcm(angular)  # (..., (l+1)^2 * c)
class BroadcastGTOTensor(nn.Module):
    r"""
    Broadcast between spherical tensors of the Gaussian Type Orbitals (GTOs):

    .. math::
        \{a_{clm}, 1\le c\le c_{max}, 0\le\ell\le\ell_{max}, -\ell\le m\le\ell\}

    For efficiency reason, the feature tensor is indexed by l, c, m.
    For example, for lmax = 3, cmax = 2, we have a tensor of 1s2s 1p2p 1d2d 1f2f.
    Currently, we support the following broadcasting:
    lc -> lcm;
    m -> lcm.
    """

    def __init__(self, lmax, cmax, src='lc', dst='lcm'):
        super(BroadcastGTOTensor, self).__init__()
        assert src in ['lc', 'm']
        assert dst in ['lcm']
        self.src = src
        self.dst = dst
        self.lmax = lmax
        self.cmax = cmax

        # Source layout size: (l, c) pairs, or (l, m) spherical components.
        self.src_dim = (lmax + 1) * cmax if src == 'lc' else (lmax + 1) ** 2
        self.dst_dim = (lmax + 1) ** 2 * cmax

        # Precomputed gather indices mapping source entries to (l, c, m).
        if src == 'lc':
            index_table = self._generate_lc2lcm_indices()
        else:
            index_table = self._generate_m2lcm_indices()
        self.register_buffer('indices', index_table)

    def _generate_lc2lcm_indices(self):
        r"""
        lc -> lcm
        .. math::
            1s2s 1p2p → 1s2s 1p_x1p_y1p_z2p_x2p_y2p_z
            [0, 1, 2, 2, 2, 3, 3, 3]

        :return: (lmax+1)^2 * cmax
        """
        table = []
        for l in range(self.lmax + 1):
            for c in range(self.cmax):
                # each (l, c) radial entry is repeated for all 2l+1 values of m
                table.extend([l * self.cmax + c] * (2 * l + 1))
        return torch.LongTensor(table)

    def _generate_m2lcm_indices(self):
        r"""
        m -> lcm
        .. math::
            s p_x p_y p_z → 1s2s 1p_x1p_y1p_z2p_x2p_y2p_z
            [0, 0, 1, 2, 3, 1, 2, 3]

        :return: (lmax+1)^2 * cmax
        """
        table = []
        for l in range(self.lmax + 1):
            for _ in range(self.cmax):
                # every channel reuses the same block of 2l+1 angular entries,
                # which start at offset l^2 in the spherical layout
                table.extend(l * l + m for m in range(2 * l + 1))
        return torch.LongTensor(table)

    def forward(self, x):
        """
        Apply broadcasting to x.
        :param x: (..., src_dim)
        :return: (..., dst_dim)
        """
        assert x.size(-1) == self.src_dim, f'Input dimension mismatch! ' \
            f'Should be {self.src_dim}, but got {x.size(-1)} instead!'
        if self.src == self.dst:
            return x
        return x[..., self.indices]
def get_optimizer(cfg, model):
    """Build the optimizer described by *cfg* for *model*'s parameters."""
    # Guard clause: only Adam is currently supported.
    if cfg.type != 'adam':
        raise NotImplementedError(f'Optimizer not supported: {cfg.type}')
    return torch.optim.Adam(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        betas=(cfg.beta1, cfg.beta2),
    )


def get_scheduler(cfg, optimizer):
    """Build the learning-rate scheduler described by *cfg* for *optimizer*."""
    if cfg.type == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=cfg.factor,
            patience=cfg.patience,
            min_lr=cfg.min_lr,
        )
    if cfg.type == 'step':
        return torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=cfg.step_size,
            gamma=cfg.gamma,
        )
    if cfg.type == 'multistep':
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=cfg.milestones,
            gamma=cfg.gamma,
        )
    if cfg.type == 'exp':
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer,
            gamma=cfg.gamma,
        )
    raise NotImplementedError(f'Scheduler not supported: {cfg.type}')


def count_parameters(model):
    """Count the number of parameters in a model."""
    total = 0
    for p in model.parameters():
        total += p.numel()
    return total
def draw_stack(density, atom_type=None, atom_coord=None, dim=-1):
    """
    Draw a 2D density map along a specific axis.
    :param density: density data tensor; it is summed along *dim* before
        plotting (NOTE(review): imshow needs a 2D result, so this presumably
        expects a single (nx, ny, nz) volume rather than a batch — confirm)
    :param atom_type: atom types, tensor of shape (batch_size, n_atom)
    :param atom_coord: atom coordinates, tensor of shape (batch_size, n_atom, 3)
    :param dim: axis along which to sum
    :return: an image tensor
    """
    plt.figure(figsize=(3, 3))
    projected = density.sum(dim).detach().cpu().numpy()
    plt.imshow(projected, cmap='viridis')
    plt.colorbar()
    if atom_type is not None:
        # the two axes remaining after projecting out `dim`
        remaining = [axis for axis in range(3) if axis != dim % 3]
        coords = atom_coord.detach().cpu().numpy()
        colors = cmap(atom_type.detach().cpu().numpy())
        plt.scatter(coords[:, remaining[1]], coords[:, remaining[0]], c=colors, alpha=0.8)

    # render the figure into an in-memory JPEG and convert to a tensor
    buf = io.BytesIO()
    plt.savefig(buf, format='jpg')
    buf.seek(0)
    image = ToTensor()(PIL.Image.open(buf))
    plt.close()
    return image