├── .flake8 ├── .github └── workflows │ ├── cli.yaml │ ├── notebooks.yaml │ ├── pre-commit.yaml │ └── unittest.yaml ├── .gitignore ├── .gradient └── available_ipus.py ├── .isort.cfg ├── .pre-commit-config.yaml ├── .pytest.ini ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── gdb ├── sortgdb.py └── sortgdb9.ipynb ├── generate.sh ├── images └── visualize_DFT_numerics.gif ├── notebooks ├── DFT-dataset-generation.ipynb ├── ERI-visualisation-JK.ipynb ├── ERI-visualisation-tril-indices.ipynb ├── SPICE-nao_nr.ipynb ├── binom_factor_table.ipynb ├── gammanu.ipynb ├── gto_integrals.ipynb ├── nanoDFT-demo.ipynb ├── plot_utils.py └── sparse_grid_ao.ipynb ├── pyscf_ipu ├── dft.py ├── electron_repulsion │ ├── LICENSE │ ├── __init__.py │ ├── cpu_int2e_sph.cpp │ ├── direct.py │ ├── direct.sh │ ├── int2e_sph.cpp │ └── popcint │ │ ├── libcint.c │ │ ├── libcint.cpp │ │ ├── libcint.py │ │ ├── libcint.sh │ │ └── readme.MD ├── exchange_correlation │ ├── LICENSE │ ├── __init__.py │ ├── b3lyp.py │ ├── b88.py │ ├── lda.py │ ├── lyp.py │ └── vwn.py ├── experimental │ ├── basis.py │ ├── binom_factor_table.py │ ├── device.py │ ├── integrals.py │ ├── interop.py │ ├── mesh.py │ ├── numerics.py │ ├── orbital.py │ ├── plot.py │ ├── primitive.py │ ├── special.py │ ├── structure.py │ ├── types.py │ └── units.py ├── nanoDFT │ ├── README.md │ ├── __init__.py │ ├── compute_eri_utils.py │ ├── compute_indices.cpp │ ├── cpu_sparse_symmetric_ERI.py │ ├── h2o.gif │ ├── intor_int2e_sph.cpp │ ├── nanoDFT.py │ ├── sparse_ERI.py │ ├── sparse_symmetric_ERI.py │ ├── sparse_symmetric_intor_ERI.py │ ├── symmetric_ERI.py │ └── utils.py └── pyscf_utils │ ├── build_grid.py │ ├── build_mol.py │ └── minao.py ├── qm1b ├── README.md └── datasheet.pdf ├── requirements_core.txt ├── requirements_cpu.txt ├── requirements_ipu.txt ├── requirements_test.txt ├── schnet_9m ├── README.md ├── data.py ├── device.py ├── model.py ├── qm1b.yaml ├── qm1b_dataset.py ├── requirements.txt ├── scaling_qm1b.png └── train.py ├── setup.py ├── setup.sh └── test ├── test_integrals.py ├── test_integrals_ipu.py ├── test_interop.py └── test_special.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | extend-ignore = E203 4 | 5 | copyright-check = True 6 | copyright-author = Graphcore Ltd 7 | -------------------------------------------------------------------------------- /.github/workflows/cli.yaml: -------------------------------------------------------------------------------- 1 | name: nanoDFT CLI 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | nanoDFT-cli: 10 | runs-on: ubuntu-20.04 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | with: 15 | python-version: "3.8.10" 16 | 17 | - name: Install default requirements 18 | run: | 19 | pip install -U pip 20 | pip install -e "." 
21 | 22 | - name: Log installed environment 23 | run: | 24 | python3 -m pip freeze 25 | 26 | - name: Test nanoDFT CLI on CPU 27 | run: | 28 | nanoDFT 29 | 30 | 31 | -------------------------------------------------------------------------------- /.github/workflows/notebooks.yaml: -------------------------------------------------------------------------------- 1 | name: pytest notebooks 2 | on: 3 | pull_request: 4 | push: 5 | branches: [main] 6 | 7 | jobs: 8 | pytest-container: 9 | runs-on: ubuntu-latest 10 | container: 11 | image: graphcore/pytorch:3.2.0-ubuntu-20.04 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - name : Install package dependencies 17 | run: | 18 | apt update -y 19 | apt install git -y 20 | 21 | - name: Install requirements 22 | run: | 23 | pip install -U pip 24 | pip install -e ".[test,ipu]" 25 | 26 | - name: Log installed environment 27 | run: | 28 | python3 -m pip freeze 29 | 30 | - name: Test nanoDFT demo notebook 31 | env: 32 | JAX_IPU_USE_MODEL: 1 33 | JAX_IPU_MODEL_NUM_TILES: 46 34 | run: | 35 | pytest --nbmake --nbmake-timeout=3000 notebooks/nanoDFT-demo.ipynb 36 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | name: Pre-Commit Checks 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | - uses: pre-commit/action@v3.0.0 -------------------------------------------------------------------------------- /.github/workflows/unittest.yaml: -------------------------------------------------------------------------------- 1 | name: unit tests 2 | on: 3 | pull_request: 4 | push: 5 | branches: [main] 6 | 7 | jobs: 8 | pytest-container: 9 | runs-on: ubuntu-latest 10 | container: 11 | image: graphcore/pytorch:3.2.0-ubuntu-20.04 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - name : Install package dependencies 17 | run: | 18 | apt update -y 19 | apt install git -y 20 | 21 | - name: Install requirements 22 | run: | 23 | pip install -U pip 24 | pip install -e ".[test,ipu]" 25 | 26 | - name: Log installed environment 27 | run: | 28 | python3 -m pip freeze 29 | 30 | - name: Run unit tests 31 | env: 32 | JAX_IPU_USE_MODEL: 1 33 | JAX_IPU_MODEL_NUM_TILES: 46 34 | JAX_PLATFORMS: cpu,ipu 35 | run: | 36 | pytest . 
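        # Note: JAX_IPU_USE_MODEL=1 with JAX_IPU_MODEL_NUM_TILES=46 runs this suite
        # on the IpuModel emulator (see CONTRIBUTING.md), so no IPU hardware is
        # required; on an IPU-equipped machine both variables can be omitted.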
37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | _tmp/ 2 | 3 | datasets/ 4 | .graph_profile/ 5 | .poptorch_cache/ 6 | .system_profile/ 7 | wandb/ 8 | .vscode/ 9 | slurm* 10 | *.popef 11 | 12 | .vscode/ 13 | data/generated/ 14 | cache/ 15 | 16 | *.csv 17 | *.pkl 18 | *.jpg 19 | *.npz 20 | *.smi 21 | _cache/ 22 | tmp/ 23 | 24 | # Byte-compiled / optimized / DLL files 25 | __pycache__/ 26 | *.py[cod] 27 | *$py.class 28 | 29 | # C extensions 30 | *.so 31 | 32 | # Distribution / packaging 33 | .Python 34 | build/ 35 | develop-eggs/ 36 | dist/ 37 | downloads/ 38 | eggs/ 39 | .eggs/ 40 | lib/ 41 | lib64/ 42 | parts/ 43 | sdist/ 44 | var/ 45 | wheels/ 46 | pip-wheel-metadata/ 47 | share/python-wheels/ 48 | *.egg-info/ 49 | .installed.cfg 50 | *.egg 51 | MANIFEST 52 | 53 | # PyInstaller 54 | # Usually these files are written by a python script from a template 55 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 56 | *.manifest 57 | *.spec 58 | 59 | # Installer logs 60 | pip-log.txt 61 | pip-delete-this-directory.txt 62 | 63 | # Unit test / coverage reports 64 | htmlcov/ 65 | .tox/ 66 | .nox/ 67 | .coverage 68 | .coverage.* 69 | .cache 70 | nosetests.xml 71 | coverage.xml 72 | *.cover 73 | *.py,cover 74 | .hypothesis/ 75 | .pytest_cache/ 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | local_settings.py 84 | db.sqlite3 85 | db.sqlite3-journal 86 | 87 | # Flask stuff: 88 | instance/ 89 | .webassets-cache 90 | 91 | # Scrapy stuff: 92 | .scrapy 93 | 94 | # Sphinx documentation 95 | docs/_build/ 96 | 97 | # PyBuilder 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # IPython 104 | profile_default/ 105 | ipython_config.py 106 | 107 | # pyenv 108 | .python-version 109 | 110 | # pipenv 111 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 112 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 113 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 114 | # install all needed dependencies. 115 | #Pipfile.lock 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | -------------------------------------------------------------------------------- /.gradient/available_ipus.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 
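# Prints the total number of IPUs attached to this machine: `gc-monitor -j`
# reports the attached IPU-M cards as JSON, and each IPU-M card carries 4 IPUs.
# A calling bash script can capture the printed value, e.g. (hypothetical usage):
#   NUM_IPUS=$(python .gradient/available_ipus.py)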
2 | import json
3 | import subprocess
4 | 
5 | j = subprocess.check_output(["gc-monitor", "-j"])
6 | data = json.loads(j)
7 | num_ipuMs = len(data["cards"])
8 | num_ipus = 4 * num_ipuMs
9 | 
10 | # to be captured as a variable in the bash script that calls this python script
11 | print(num_ipus)
12 | 
-------------------------------------------------------------------------------- /.isort.cfg: --------------------------------------------------------------------------------
1 | [settings]
2 | profile=black
3 | 
-------------------------------------------------------------------------------- /.pre-commit-config.yaml: --------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 |   rev: v4.4.0
4 |   hooks:
5 |   - id: end-of-file-fixer
6 |     files: (^pyscf_ipu/experimental/)|(^test/)|(\.(cfg|txt|flake8|ini))
7 |   - id: trailing-whitespace
8 |     files: (^pyscf_ipu/experimental/)|(^test/)
9 | 
10 | - repo: https://github.com/psf/black
11 |   rev: 23.9.1
12 |   hooks:
13 |   - id: black-jupyter
14 |     files: (^pyscf_ipu/experimental/)|(^test/)
15 |     name: Format code
16 | 
17 | - repo: https://github.com/pycqa/isort
18 |   rev: 5.12.0
19 |   hooks:
20 |   - id: isort
21 |     files: (^pyscf_ipu/experimental/)|(^test/)
22 |     name: Sort imports
23 | 
24 | - repo: https://github.com/PyCQA/flake8
25 |   rev: 6.1.0
26 |   hooks:
27 |   - id: flake8
28 |     files: (^pyscf_ipu/experimental/)|(^test/)
29 |     name: Check PEP8
30 | 
31 | - repo: https://github.com/PyCQA/flake8
32 |   rev: 6.1.0
33 |   hooks:
34 |   - id: flake8
35 |     args: [--select=C]
36 |     additional_dependencies: [flake8-copyright]
37 |     name: copyright check
38 | 
-------------------------------------------------------------------------------- /.pytest.ini: --------------------------------------------------------------------------------
1 | [pytest]
2 | addopts = -s -v --durations=10
3 | 
-------------------------------------------------------------------------------- /CONTRIBUTING.md: --------------------------------------------------------------------------------
1 | # Contributing to pyscf-ipu
2 | 
3 | This project is still evolving but at the moment is focused on a high-performance and easily hackable implementation of Gaussian basis set DFT.
4 | We hope this is useful for the generation of large-scale datasets needed
5 | for training machine-learning models. We are interested in hearing any and
6 | all feedback, so feel free to raise any questions, bugs encountered, or enhancement requests as [Issues](https://github.com/graphcore-research/pyscf-ipu/issues).
7 | 
8 | ## Setting up a development environment
9 | We recommend using the conda package manager as this can automatically enable
10 | the Graphcore Poplar SDK. This is particularly useful in VS Code, which can automatically
11 | activate the conda environment in a variety of scenarios:
12 | * visual debugging
13 | * running quick experiments in an interactive Jupyter window
14 | * using VS Code for Jupyter notebook development.
15 | 
16 | The following assumes that you have already set up an installation of conda and that
17 | the conda command is available on your system path. Refer to your preferred conda
18 | installer:
19 | * [miniforge installation](https://github.com/conda-forge/miniforge#install)
20 | * [conda installation documentation](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html).
21 | 
22 | 1. Create a new conda environment with the same Python version as required by the Poplar SDK.
23 | For example, on Ubuntu 20.04 use `python=3.8.10`
24 | ```bash
25 | conda create -n pyscf-ipu python=3.8.10
26 | ```
27 | 
28 | 2. Confirm that you have the Poplar SDK installed on your machine and store the location
29 |    in a temporary shell variable. The following will test that the SDK is found and
30 |    configured correctly:
31 | ```bash
32 | TMP_POPLAR_SDK=/path/to/sdk
33 | source $TMP_POPLAR_SDK/enable
34 | gc-monitor
35 | ```
36 | 3. Activate the environment and make POPLAR_SDK a persistent environment variable.
37 | ```bash
38 | conda activate pyscf-ipu
39 | conda env config vars set POPLAR_SDK=$TMP_POPLAR_SDK
40 | ```
41 | 
42 | 4. You have to reactivate the conda environment to use the `$POPLAR_SDK`
43 |    variable in the environment.
44 | ```bash
45 | conda deactivate
46 | conda activate pyscf-ipu
47 | ```
48 | 
49 | 5. Set up the conda environment to automatically enable the Poplar SDK whenever
50 |    the environment is activated.
51 | ```bash
52 | mkdir -p $CONDA_PREFIX/etc/conda/activate.d
53 | echo "source $POPLAR_SDK/enable" > $CONDA_PREFIX/etc/conda/activate.d/enable.sh
54 | ```
55 | 
56 | 6. Check that everything is working by reactivating the pyscf-ipu
57 |    environment _in a new shell_ and calling `gc-monitor`:
58 | ```bash
59 | conda deactivate
60 | conda activate pyscf-ipu
61 | gc-monitor
62 | ```
63 | 7. Install all required packages for developing JAX DFT:
64 | ```bash
65 | pip install -e ".[ipu,test]"
66 | ```
67 | 
68 | 8. Install the pre-commit hooks:
69 | ```bash
70 | pre-commit install
71 | ```
72 | 
73 | 9. Create a feature branch, make changes, and when you commit them the
74 |    pre-commit hooks will run.
75 | ```bash
76 | git checkout -b feature
77 | ...
78 | git push --set-upstream origin feature
79 | ```
80 | The last command will print a link that you can follow to open a PR.
81 | 
82 | 
83 | ## Testing
84 | Run all the tests using `pytest`:
85 | ```bash
86 | pytest
87 | ```
88 | We also use the nbmake package to check that our notebooks work in the `IpuModel` environment. These checks can also be run on IPU hardware-equipped machines, e.g.:
89 | ```bash
90 | pytest --nbmake --nbmake-timeout=3000 notebooks/nanoDFT-demo.ipynb
91 | ```
92 | 
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | :red_circle: :warning: **Experimental and non-official Graphcore product** :warning: :red_circle:
2 | 
3 | [![arXiv](https://img.shields.io/badge/arXiv-2311.01135-b31b1b.svg)](https://arxiv.org/abs/2311.01135)
4 | [![QM1B figshare+](https://img.shields.io/badge/figshare%2B-24459376-blue)](https://doi.org/10.25452/figshare.plus.24459376)
5 | [![notebook-tests](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/notebooks.yaml/badge.svg)](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/notebooks.yaml)
6 | [![nanoDFT CLI](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/cli.yaml/badge.svg)](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/cli.yaml)
7 | [![unit tests](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/unittest.yaml/badge.svg)](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/unittest.yaml)
8 | [![pre-commit checks](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/graphcore-research/pyscf-ipu/actions/workflows/pre-commit.yaml)
9 | 
10 | [**Installation guide**](#installation)
11 | | [**Example DFT Computations**](#example-dft-computations)
12 | | [**Generating data**](#generating-new-datasets)
13 | | [**Training SchNet**](#training-schnet-on-qm1b)
14 | | [**QM1B dataset**](qm1b/README.md)
15 | 
16 | 
17 | # PySCF on IPU
18 | 
19 | PySCF-IPU is built on top of the [PySCF](https://github.com/pyscf) package, porting some of the PySCF algorithms to the Graphcore [IPU](https://www.graphcore.ai/products/ipu).
20 | 
21 | 
22 | Only a small portion of PySCF is currently ported, specifically Restricted Kohn-Sham DFT (based on [RKS](https://github.com/pyscf/pyscf/blob/6c815a62bc2e5eae1488a1d0dbe84556dd54b922/pyscf/dft/rks.py#L531), [KohnShamDFT](https://github.com/pyscf/pyscf/blob/6c815a62bc2e5eae1488a1d0dbe84556dd54b922/pyscf/dft/rks.py#L280) and [hf.RHF](https://github.com/pyscf/pyscf/blob/6c815a62bc2e5eae1488a1d0dbe84556dd54b922/pyscf/scf/hf.py#L2044)).
23 | 
24 | The package is under active development to broaden its scope and applicability. Current limitations are:
25 | - The number of atomic orbitals must be at most 70: `mol.nao_nr() <= 70`.
26 | - Larger numerical errors due to `np.float32` instead of `np.float64`.
27 | - Limited support for `jax.grad(.)`.
28 | 
29 | ## QuickStart
30 | 
31 | ### For ML dataset generation (SynS & ML Workshop 2023)
32 | To generate datasets based on the paper __Repurposing Density Functional Theory to Suit Deep Learning__ [Link](https://icml.cc/virtual/2023/workshop/21476#wse-detail-28485) [PDF](https://syns-ml.github.io/2023/assets/papers/17.pdf) presented at the [Syns & ML Workshop, ICML 2023](https://syns-ml.github.io/2023/), the entry points are the notebook [DFT Dataset Generation](./notebooks/DFT-dataset-generation.ipynb) and the file [density_functional_theory.py](./density_functional_theory.py).
33 | 
34 | 
35 | ### For DFT teaching and learning: nanoDFT
36 | 
37 | We also provide a lightweight implementation of the SCF algorithm, optimized for readability and hackability, in the [nanoDFT demo](notebooks/nanoDFT-demo.ipynb) notebook and in the [nanoDFT](pyscf_ipu/nanoDFT/README.md) folder.
38 | 
39 | 
40 | 
41 | Additional notebooks in [notebooks](notebooks) demonstrate other aspects of the computation.
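The atomic-orbital limit in the list above can be checked ahead of time with plain PySCF. A minimal sketch (H2 in STO-3G is just a placeholder molecule, mirroring the example in `pyscf_ipu/electron_repulsion/popcint/readme.MD`):

```python
import pyscf

# Build the molecule you intend to run; H2 in a minimal basis shown here.
mol = pyscf.gto.Mole([["H", (0, 0, 0)], ["H", (0, 0, 1)]], basis="sto3g")
mol.build()

# pyscf-ipu currently supports at most 70 atomic orbitals.
assert mol.nao_nr() <= 70, f"too many atomic orbitals: {mol.nao_nr()}"
```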
42 | 
43 | ## Installation
44 | 
45 | PySCF on IPU requires Python 3.8, [JAX IPU experimental](https://github.com/graphcore-research/jax-experimental), the [TessellateIPU library](https://github.com/graphcore-research/tessellate-ipu) and [Graphcore Poplar SDK 3.2](https://www.graphcore.ai/downloads).
46 | 
47 | We recommend upgrading `pip` to the latest stable release to prepare your environment.
48 | ```bash
49 | pip install -U pip
50 | ```
51 | 
52 | This project is currently under active development.
53 | For CPU simulations, we recommend installing `pyscf-ipu` from the latest `main` branch:
54 | ```bash
55 | pip install pyscf-ipu[cpu]@git+https://github.com/graphcore-research/pyscf-ipu
56 | ```
57 | 
58 | and on IPU-equipped machines:
59 | ```bash
60 | pip install pyscf-ipu[ipu]@git+https://github.com/graphcore-research/pyscf-ipu
61 | ```
62 | 
63 | ## Example DFT Computations
64 | The following commands may be useful to check the installation. Each command runs a test case that compares PySCF against our DFT computation using different options.
65 | ```
66 | python density_functional_theory.py -methane -backend cpu # defaults to float64 as used in PySCF
67 | python density_functional_theory.py -methane -backend cpu -float32
68 | python density_functional_theory.py -methane -backend ipu -float32
69 | ```
70 | This will automatically compare our DFT against PySCF for methane `CH4` and report numerical errors.
71 | 
72 | 
73 | ## Generating New Datasets
74 | 
75 | This section contains an example of how to generate a DFT dataset based on GDB. This is not needed if you just want to train on the [QM1B dataset](qm1b/README.md).
76 | 
77 | Download the `gdb11.tgz` file from https://zenodo.org/record/5172018 and extract its contents into the `gdb/` directory:
78 | ```bash
79 | wget -p -O ./gdb/gdb11.tgz https://zenodo.org/record/5172018/files/gdb11.tgz\?download\=1
80 | tar -xvf ./gdb/gdb11.tgz --directory ./gdb/
81 | ```
82 | To utilize caching, you need to sort the SMILES strings by the number of hydrogens RDKit adds to them. This means molecules `i` and `i+1` in most cases have the same number of hydrogens, which allows our code to reuse/cache the computational graph for DFT. This can be done by running the following Python script:
83 | ```
84 | python ./gdb/sortgdb.py ./gdb/gdb11_size09.smi
85 | ```
86 | You can then start generating (locally on CPU) a dataset using the following command:
87 | ```bash
88 | python density_functional_theory.py -generate -save -fname dataset_name -level 0 -plevel 0 -gdb 9 -backend cpu -float32
89 | ```
90 | 
91 | You can speed up the generation by using IPUs. Please try the [DFT dataset generation notebook](https://ipu.dev/YX0jlK).
92 | 
93 | 
94 | ## Training SchNet on [QM1B](qm1b/README.md)
95 | 
96 | We used PySCF on IPU to generate the [QM1B dataset](qm1b/README.md) with one billion training examples (available via [figshare+](https://doi.org/10.25452/figshare.plus.24459376)).
97 | See [Training SchNet on QM1B](./schnet_9m/README.md) for an example implementation of a neural network trained on this dataset.
98 | 
99 | ## License
100 | 
101 | Copyright (c) 2023 Graphcore Ltd. The project is licensed under the [**Apache License 2.0**](LICENSE), with the exception of the folders `electron_repulsion/` and `exchange_correlation/`.
102 | 
103 | The library is built on top of the following main dependencies:
104 | 
105 | | Component | Description | License |
106 | | --- | --- | --- |
107 | | [pyscf](https://github.com/pyscf/pyscf) | Python-based Simulations of Chemistry Framework | [Apache License 2.0](https://github.com/pyscf/pyscf/blob/master/LICENSE) |
108 | | [libcint](https://github.com/sunqm/libcint/) | Open source library for analytical Gaussian integrals | [BSD 2-Clause “Simplified” License](https://github.com/sunqm/libcint/blob/master/LICENSE) |
109 | | [xcauto](https://github.com/dftlibs/xcauto) | Arbitrary order exchange-correlation functional derivatives | [MPL-2.0 license](https://github.com/dftlibs/xcauto/blob/master/LICENSE) |
110 | 
111 | 
112 | ## Cite
113 | Please use the following citation for the pyscf-ipu project:
114 | 
115 | ```
116 | @inproceedings{mathiasen2023qm1b,
117 |   title={Generating QM1B with PySCF$_{\text{IPU}}$},
118 |   author={Mathiasen, Alexander and Helal, Hatem and Klaeser, Kerstin and Balanca, Paul and Dean, Josef and Luschi, Carlo and Beaini, Dominique and Fitzgibbon, Andrew William and Masters, Dominic},
119 |   booktitle={Thirty-seventh Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
120 |   year={2023}
121 | }
122 | ```
123 | 
-------------------------------------------------------------------------------- /gdb/sortgdb.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | import argparse
3 | from tqdm import tqdm
4 | import numpy as np
5 | import pandas as pd
6 | 
7 | 
8 | from rdkit import Chem
9 | from rdkit.Chem import AllChem
10 | from rdkit import RDLogger
11 | 
12 | 
13 | lg = RDLogger.logger()
14 | lg.setLevel(RDLogger.CRITICAL)
15 | 
16 | 
17 | def sort_gdb(gdb_filename: str, keep_only_atoms_count: int = 9):
18 |     """Sort GDB SMILES strings by number of hydrogens, after keeping
19 |     only the molecules with a given count of heavy atoms.
20 | 
21 |     Returns:
22 |         Pandas dataframe of SMILES strings.
23 |     """
24 |     smiles = [a.split("\t")[0] for a in open(gdb_filename, "r").read().split("\n")]
25 |     smiles_filtered = []
26 |     num_hs = []
27 |     for smile in tqdm(smiles):
28 |         atoms = [a for a in list(smile.upper()) if a == "C" or a == "N" or a == "O" or a == "F"]
29 |         if len(atoms) != keep_only_atoms_count: continue
30 |         smiles_filtered.append(smile)
31 |         b = Chem.MolFromSmiles(smile)
32 |         b = Chem.AddHs(b)
33 |         atoms = [atom.GetSymbol() for atom in b.GetAtoms()]
34 |         num_hs.append( len([a for a in atoms if a.upper() == "H"]))
35 | 
36 |     # Sort by number of hydrogens.
37 |     num_hs = np.array(num_hs)
38 |     sorted_smiles = np.array(smiles_filtered)[np.argsort(num_hs)].tolist()
39 |     df = pd.DataFrame(sorted_smiles[1:])
40 |     return df
41 | 
42 | 
43 | if __name__ == "__main__":
44 |     parser = argparse.ArgumentParser(
45 |         prog="Filter and sort GDB by number of atoms", epilog="Provide GDB .smi filename."
46 |     )
47 |     parser.add_argument("filename")
48 |     args = parser.parse_args()
49 | 
50 |     gdb_filename = args.filename
51 |     assert gdb_filename.endswith(".smi")
52 |     gdb_sorted = sort_gdb(gdb_filename)
53 |     # Save output as csv.
54 | out_filename = gdb_filename.replace(".smi", "_sorted.csv") 55 | gdb_sorted.to_csv(out_filename, index=False, header=False) 56 | 57 | -------------------------------------------------------------------------------- /gdb/sortgdb9.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "100%|██████████| 444314/444314 [01:14<00:00, 5991.15it/s]\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "# Copyright (c) 2023 Graphcore Ltd. All rights reserved.\n", 18 | "from tqdm import tqdm \n", 19 | "from rdkit import Chem \n", 20 | "from rdkit.Chem import AllChem\n", 21 | "from rdkit import RDLogger\n", 22 | "lg = RDLogger.logger()\n", 23 | "lg.setLevel(RDLogger.CRITICAL)\n", 24 | "\n", 25 | "smiles = [a.split(\"\\t\")[0] for a in open(\"gdb11_size09.smi\", \"r\").read().split(\"\\n\")]\n", 26 | "\n", 27 | "smiles_9 = []\n", 28 | "\n", 29 | "num_hs = []\n", 30 | "for smile in tqdm(smiles):\n", 31 | " atoms = [a for a in list(smile.upper()) if a == \"C\" or a == \"N\" or a == \"O\" or a == \"F\"]\n", 32 | " if len(atoms) != 9: continue \n", 33 | " smiles_9.append(smile)\n", 34 | " b = Chem.MolFromSmiles(smile)\n", 35 | " b = Chem.AddHs(b) \n", 36 | " atoms = [atom.GetSymbol() for atom in b.GetAtoms()]\n", 37 | " num_hs.append( len([a for a in atoms if a.upper() == \"H\"]))" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 4, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "import numpy as np \n", 47 | "num_hs = np.array(num_hs)\n", 48 | "sorted_smiles = np.array(smiles_9)[np.argsort(num_hs)].tolist()" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 5, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "import pandas as pd\n", 58 | "df = pd.DataFrame(sorted_smiles[1:])\n", 59 | "df.to_csv('gdb11_size09_sorted.csv', index=False, header=False)" 60 | ] 61 | } 62 | ], 63 | "metadata": { 64 | "kernelspec": { 65 | "display_name": "dft", 66 | "language": "python", 67 | "name": "python3" 68 | }, 69 | "language_info": { 70 | "codemirror_mode": { 71 | "name": "ipython", 72 | "version": 3 73 | }, 74 | "file_extension": ".py", 75 | "mimetype": "text/x-python", 76 | "name": "python", 77 | "nbconvert_exporter": "python", 78 | "pygments_lexer": "ipython3", 79 | "version": "3.8.10" 80 | }, 81 | "orig_nbformat": 4 82 | }, 83 | "nbformat": 4, 84 | "nbformat_minor": 2 85 | } 86 | -------------------------------------------------------------------------------- /generate.sh: -------------------------------------------------------------------------------- 1 | # Install (takes 2-3 min) 2 | cd /notebooks/ 3 | pip install -q jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html 4 | pip install -q git+https://github.com/graphcore-research/tessellate-ipu.git@main 5 | pip install -r requirements.txt 6 | pip install -r requirements_ipu.txt 7 | apt update 8 | apt -y install tmux 9 | 10 | # Start data generation (takes 2-3 min to compile) 11 | tmux new-session -d -s 0 "XLA_IPU_PLATFORM_DEVICE_COUNT=1 python density_functional_theory.py -gdb 6 -randomSeed 0 -num_conformers 1000 -fname test -split 0 3 -threads 3 -threads_int 3 -generate -save -backend ipu -float32 -level 0 -plevel 0 -its 20 -skip_minao -choleskycpu -multv 2 -basis sto3g" 12 | tmux new-session -d -s 1 
"XLA_IPU_PLATFORM_DEVICE_COUNT=1 python density_functional_theory.py -gdb 6 -randomSeed 1 -num_conformers 1000 -fname test -split 0 3 -threads 3 -threads_int 3 -generate -save -backend ipu -float32 -level 0 -plevel 0 -its 20 -skip_minao -choleskycpu -multv 2 -basis sto3g" 13 | tmux new-session -d -s 2 "XLA_IPU_PLATFORM_DEVICE_COUNT=1 python density_functional_theory.py -gdb 6 -randomSeed 2 -num_conformers 1000 -fname test -split 0 3 -threads 3 -threads_int 3 -generate -save -backend ipu -float32 -level 0 -plevel 0 -its 20 -skip_minao -choleskycpu -multv 2 -basis sto3g" 14 | tmux new-session -d -s 3 "XLA_IPU_PLATFORM_DEVICE_COUNT=1 python density_functional_theory.py -gdb 6 -randomSeed 3 -num_conformers 1000 -fname test -split 0 3 -threads 3 -threads_int 3 -generate -save -backend ipu -float32 -level 0 -plevel 0 -its 20 -skip_minao -choleskycpu -multv 2 -basis sto3g" 15 | 16 | tmux list-sessions 17 | 18 | echo "Files are stored in data/generated/test/..." 19 | echo "You can inspect the individual generation through 'tmux attach-session -t 0'. " 20 | -------------------------------------------------------------------------------- /images/visualize_DFT_numerics.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/pyscf-ipu/43a9bb343acbd296aec1d1f64c309bb24920d98d/images/visualize_DFT_numerics.gif -------------------------------------------------------------------------------- /notebooks/plot_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import numpy as np 3 | import matplotlib as mpl 4 | import matplotlib.pyplot as plt 5 | 6 | def plot4D(x, N, name=''): 7 | norm = mpl.colors.Normalize(vmin=np.min(x), vmax=np.max(x)) 8 | 9 | fig, axs = plt.subplots(N, N) 10 | for i in range(N): 11 | for j in range(N): 12 | axs[i, j].imshow(x[i,j], norm=norm) 13 | axs[i, j].set_ylabel(f'i={i}') 14 | axs[i, j].set_xlabel(f'j={j}') 15 | 16 | for ax in axs.flat: 17 | ax.label_outer() 18 | 19 | fig.suptitle(name+' 4D (N, N, N, N)') 20 | plt.show() 21 | 22 | def plot2D(x, N, name='', show_values=False): 23 | norm = mpl.colors.Normalize(vmin=np.min(x), vmax=np.max(x)) 24 | fig, ax = plt.subplots() 25 | plt.imshow(x, norm=norm) 26 | if show_values: 27 | ixs, iys = np.meshgrid(np.arange(0, N, 1), np.arange(0, N, 1)) 28 | for iy, ix in zip(iys.flatten(), ixs.flatten()): 29 | ax.text(iy, ix, x[ix, iy], va='center', ha='center') # indices swapped to match image 30 | 31 | fig.suptitle(name) 32 | plt.show() 33 | -------------------------------------------------------------------------------- /pyscf_ipu/electron_repulsion/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012, Qiming Sun 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | 8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
9 | 10 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /pyscf_ipu/electron_repulsion/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -------------------------------------------------------------------------------- /pyscf_ipu/electron_repulsion/direct.sh: -------------------------------------------------------------------------------- 1 | clear 2 | rm gen.so 3 | 4 | cc cpu_int2e_sph.cpp -shared -fpic -o gen.so -lpoplar -lpoputil -fpermissive 5 | echo "Done compiling" 6 | #echo "Calling from python"; 7 | 8 | export XLA_IPU_PLATFORM_DEVICE_COUNT=1 9 | export POPLAR_ENGINE_OPTIONS="{ 10 | \"autoReport.outputExecutionProfile\": \"true\", 11 | \"autoReport.directory\": \"profs/\" 12 | }" 13 | export TF_POPLAR_FLAGS=--show_progress_bar=true 14 | python direct.py $@ 15 | -------------------------------------------------------------------------------- /pyscf_ipu/electron_repulsion/popcint/libcint.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "poplar/TileConstants.hpp" 8 | #include 9 | 10 | using namespace poplar; 11 | 12 | #ifdef __IPU__ 13 | // Use the IPU intrinsics 14 | #include 15 | #include 16 | #define NAMESPACE ipu 17 | #else 18 | // Use the std functions 19 | #include 20 | #define NAMESPACE std 21 | #endif 22 | 23 | #include "libcint.c" 24 | 25 | 26 | class Grad : public Vertex { 27 | public: 28 | // TODO: Change InOut to Input. 
29 | // Using InOut so it's float* instead of const float* (which would require changing 30k lines in libcint.c) 30 | InOut> mat; 31 | InOut> shls_slice; 32 | InOut> ao_loc; 33 | InOut> atm; 34 | InOut> bas; 35 | InOut> env; 36 | Input> natm; 37 | Input> nbas; 38 | Input> which_integral; 39 | Output> out; 40 | 41 | bool compute() { 42 | float * _env = env.data(); 43 | int *_bas = bas.data(); 44 | int *_atm = atm.data(); 45 | int *_shls_slice = shls_slice.data(); 46 | int *_ao_loc = ao_loc.data(); 47 | float * _mat = mat.data(); 48 | 49 | if (which_integral.data()[0] == INT1E_KIN){ 50 | GTOint2c( 51 | (int (*)(dtype *out, FINT *dims, FINT *shls, FINT *atm, FINT natm, FINT *bas, FINT nbas, dtype *env, CINTOpt *opt, dtype *cache)) 52 | int1e_kin_sph, out.data(), 1, 0, _shls_slice, _ao_loc, NULL, _atm, natm.data()[0], _bas, nbas.data()[0], _env); 53 | } 54 | else if (which_integral.data()[0] == INT1E_NUC){ 55 | GTOint2c( 56 | (int (*)(dtype *out, FINT *dims, FINT *shls, FINT *atm, FINT natm, FINT *bas, FINT nbas, dtype *env, CINTOpt *opt, dtype *cache)) 57 | int1e_nuc_sph, out.data(), 1, 0, _shls_slice, _ao_loc, NULL, _atm, natm.data()[0], _bas, nbas.data()[0], _env); 58 | } 59 | else if (which_integral.data()[0] == INT1E_OVLP){ 60 | GTOint2c( 61 | (int (*)(dtype *out, FINT *dims, FINT *shls, FINT *atm, FINT natm, FINT *bas, FINT nbas, dtype *env, CINTOpt *opt, dtype *cache)) 62 | int1e_ovlp_sph, out.data(), 1, 0, _shls_slice, _ao_loc, NULL, _atm, natm.data()[0], _bas, nbas.data()[0], _env); 63 | } 64 | if (which_integral.data()[0] == INT1E_OVLP_IP){ 65 | GTOint2c( 66 | (int (*)(dtype *out, FINT *dims, FINT *shls, FINT *atm, FINT natm, FINT *bas, FINT nbas, dtype *env, CINTOpt *opt, dtype *cache)) 67 | int1e_ipovlp_sph, out.data(), 3, 0, _shls_slice, _ao_loc, NULL, _atm, natm.data()[0], _bas, nbas.data()[0], _env); 68 | } 69 | else if (which_integral.data()[0] == INT1E_KIN_IP){ 70 | GTOint2c( 71 | (int (*)(dtype *out, FINT *dims, FINT *shls, FINT *atm, FINT natm, FINT *bas, FINT nbas, dtype *env, CINTOpt *opt, dtype *cache)) 72 | int1e_ipkin_sph, out.data(), 3, 0, _shls_slice, _ao_loc, NULL, _atm, natm.data()[0], _bas, nbas.data()[0], _env); 73 | } 74 | else if (which_integral.data()[0] == INT1E_NUC_IP){ 75 | GTOint2c( 76 | (int (*)(dtype *out, FINT *dims, FINT *shls, FINT *atm, FINT natm, FINT *bas, FINT nbas, dtype *env, CINTOpt *opt, dtype *cache)) 77 | int1e_ipnuc_sph, out.data(), 3, 0, _shls_slice, _ao_loc, NULL, _atm, natm.data()[0], _bas, nbas.data()[0], _env); 78 | } 79 | 80 | 81 | 82 | return true; 83 | } 84 | }; 85 | 86 | 87 | 88 | 89 | 90 | class Int2e : public Vertex { 91 | public: 92 | //"mat", "shls_slice", "ao_loc", "atm", "bas", "env" 93 | InOut> mat; 94 | InOut> shls_slice; 95 | InOut> ao_loc; 96 | InOut> atm; 97 | InOut> bas; 98 | InOut> env; 99 | Input> natm; 100 | Input> nbas; 101 | Input> which_integral; 102 | Input> comp; 103 | 104 | Output> out; 105 | 106 | bool compute() { 107 | float * _env = env.data(); 108 | int *_bas = bas.data(); 109 | int *_atm = atm.data(); 110 | int *_shls_slice = shls_slice.data(); 111 | int *_ao_loc = ao_loc.data(); 112 | float * _mat = mat.data(); 113 | 114 | GTOnr2e_fill_drv( 115 | (int (*)(...))int2e_sph, 116 | (void (*)(...))GTOnr2e_fill_s1, 117 | NULL, 118 | out.data(), comp.data()[0], _shls_slice, _ao_loc, NULL, 119 | _atm, natm.data()[0], 120 | _bas, nbas.data()[0], 121 | _env, which_integral.data()[0] 122 | ); 123 | 124 | 125 | return true; 126 | } 127 | }; 128 | 
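// Dispatch summary for the vertices above (which_integral -> libcint kernel):
//   Grad:  INT1E_KIN  -> int1e_kin_sph     INT1E_KIN_IP  -> int1e_ipkin_sph  (comp=3)
//          INT1E_NUC  -> int1e_nuc_sph     INT1E_NUC_IP  -> int1e_ipnuc_sph  (comp=3)
//          INT1E_OVLP -> int1e_ovlp_sph    INT1E_OVLP_IP -> int1e_ipovlp_sph (comp=3)
//   Int2e: fills the 4-index ERI tensor via GTOnr2e_fill_drv/GTOnr2e_fill_s1,
//          forwarding which_integral so the same vertex serves int2e_sph and its
//          ip1 gradient (see popcint/readme.MD for the Python-side entry points).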
-------------------------------------------------------------------------------- /pyscf_ipu/electron_repulsion/popcint/libcint.sh: -------------------------------------------------------------------------------- 1 | clear 2 | rm libcint.so 3 | 4 | g++ libcint.c -shared -fpic -o libcint.so -lpoplar -lpoputil -fpermissive 5 | echo "Done compiling. Calling C code from python. " 6 | 7 | XLA_IPU_PLATFORM_DEVICE_COUNT=1 TF_POPLAR_FLAGS=--show_progress_bar=true python libcint.py $@ -------------------------------------------------------------------------------- /pyscf_ipu/electron_repulsion/popcint/readme.MD: -------------------------------------------------------------------------------- 1 | # popcint 2 | Libcint (manually) compiled to IPU implementing 3 | 4 | ``` 5 | import pyscf 6 | mol = pyscf.gto.Mole([["H", (0,0,0)], ["H", (0,0,1)]], basis="sto3g") 7 | mol.build() 8 | mol.intor("int1e_nuc") # nuclear integral 9 | mol.intor("int1e_kin") # kinetic integral 10 | mol.intor("int1e_ovlp") # overlap integral 11 | 12 | mol.intor("int1e_ipnuc") # gradient of nuclear integral 13 | mol.intor("int1e_ipkin") # gradient of kinetic integral 14 | mol.intor("int1e_ipovlp") # gradient of overlap integral 15 | 16 | mol.intor("int2e_sph") # electron repulsion integral 17 | mol.intor("int2e_ip1_sph") # gradient (ip1) of electron repulsion integral 18 | ``` 19 | 20 | You can test all integrals with `./cpp_libcint.sh -all`. The C++ plumbing to run all integrals is in place and all (but kinetic) pass a simple H2 test-case in STO3G (this compiles and runs libcint.c both with CPU/G++ and IPU/tesselate). 21 | 22 | ``` 23 | > ./cpp_libcint.sh -all 24 | 25 | Compiling with C++ 26 | Done compiling. Calling C code from python. 27 | [N=2] 28 | 29 | [Nuclear Integral] 30 | CPU: 2.763163926555734e-07 31 | Compiling module jit_ipu_intor1e.0: 32 | [##################################################] 100% Compilation Finished [Elapsed: 00:00:15.6] 33 | IPU: 2.763163926555734e-07 34 | 35 | [Kinetic Integral] 36 | CPU: -1.8022852765753328e-08 37 | Compiling module jit_ipu_intor1e.1: 38 | [##################################################] 100% Compilation Finished [Elapsed: 00:00:15.9] 39 | IPU: -1.05722721688295e-08 40 | 41 | [Overlap Integral] 42 | CPU: -1.2445099606406274e-07 43 | Compiling module jit_ipu_intor1e.2: 44 | [##################################################] 100% Compilation Finished [Elapsed: 00:00:16.0] 45 | IPU: -6.484635128867211e-08 46 | 47 | [Grad Nuclear] 48 | CPU: 7.246001532124069e-08 49 | Compiling module jit_ipu_intor1e.3: 50 | [##################################################] 100% Compilation Finished [Elapsed: 00:00:15.8] 51 | IPU: 7.246001532124069e-08 52 | 53 | [Grad Kinetic] 54 | CPU: 0.22741775584087665 55 | [ 0. -0. -0. 0. 0. -0. 56 | -0. 0. 0. 0.19630939 -0.19630939 0. ] 57 | [-0.0000000e+00 5.6303712e-04 -5.6303712e-04 -0.0000000e+00 58 | -1.4645594e-01 2.4947241e-02 -2.2242269e-02 1.6426709e-01 59 | -0.0000000e+00 -3.1108368e-02 -1.6386819e-01 9.0011621e-01] 60 | Compiling module jit_ipu_intor1e.4: 61 | [##################################################] 100% Compilation Finished [Elapsed: 00:00:15.8] 62 | T[1.2]: inside CINTset_pairdata 63 | T[1.2]: inside CINTset_pairdata 64 | IPU: 0.1963094174861908 65 | [ 0. -0. -0. 0. 0. -0. 66 | -0. 0. 0. 0.19630939 -0.19630939 0. ] 67 | [-0. -0. -0. -0. -0. -0. 68 | -0. -0. -0.19630942 -0. -0. -0. 
] 69 | 70 | [Grad Overlap] 71 | CPU: 6.077975783780332e-08 72 | Compiling module jit_ipu_intor1e.5: 73 | [##################################################] 100% Compilation Finished [Elapsed: 00:00:15.5] 74 | IPU: 6.077975783780332e-08 75 | 76 | [Electron Repulsion Integral] 77 | CPU: -4.443460513425812e-08 78 | Compiling module jit_ipu_getints4c.6: 79 | [##################################################] 100% Compilation Finished [Elapsed: 00:00:15.8] 80 | IPU: -2.953344391265489e-08 81 | 82 | [Grad of Electron Repulsion Integral] 83 | CPU: 1.341920186359591e-07 84 | Compiling module jit_ipu_getints4c.7: 85 | [##################################################] 100% Compilation Finished [Elapsed: 00:00:15.6] 86 | IPU: 1.1929085744211143e-07 87 | ``` 88 | -------------------------------------------------------------------------------- /pyscf_ipu/exchange_correlation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -------------------------------------------------------------------------------- /pyscf_ipu/exchange_correlation/b3lyp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import jax.numpy as jnp 3 | import jax 4 | 5 | from pyscf_ipu.exchange_correlation.lda import __lda 6 | from pyscf_ipu.exchange_correlation.lyp import __lyp 7 | from pyscf_ipu.exchange_correlation.b88 import __b88 8 | from pyscf_ipu.exchange_correlation.vwn import __vwn 9 | 10 | CLIP_RHO_MIN = 1e-9 11 | CLIP_RHO_MAX = 1e12 12 | 13 | def b3lyp(rho, EPSILON_B3LYP=0): 14 | 15 | rho = jnp.concatenate([jnp.clip(rho[:1], CLIP_RHO_MIN, CLIP_RHO_MAX), rho[1:4]*2]) 16 | 17 | rho0 = rho.T[:, 0] 18 | norms = jnp.linalg.norm(rho[1:], axis=0).T**2+EPSILON_B3LYP 19 | 20 | def lda(rho0): return jax.vmap(jax.value_and_grad(lambda x: __lda(x)*0.08)) (rho0) 21 | def vwn(rho0): return jax.vmap(jax.value_and_grad(lambda x: __vwn(x)*0.19)) (rho0) 22 | 23 | # disabled gradient checkpointing 24 | #def b88(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: jax.checkpoint(__b88)(rho0, norm)*0.72, (0, 1))) (rho0, norms) 25 | #def lyp(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: jax.checkpoint(__lyp)(rho0, norm)*0.810, (0, 1))) (rho0, norms) 26 | 27 | def b88(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: __b88(rho0, norm)*0.72, (0,1)))(rho0, norms) 28 | def lyp(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: __lyp(rho0, norm)*0.810, (0,1)))(rho0, norms) 29 | 30 | e_xc_lda, v_rho_lda = jax.jit(lda)(rho0) 31 | e_xc_vwn, v_rho_vwn = jax.jit(vwn)(rho0) 32 | e_xc_b88, (v_rho_b88, v_norm_b88) = jax.jit(b88)(rho0, norms) 33 | e_xc_lyp, (v_rho_lyp, v_norm_lyp) = jax.jit(lyp)(rho0, norms) 34 | 35 | e_xc = e_xc_lda + (e_xc_vwn + e_xc_b88 + e_xc_lyp) / rho0 36 | v_xc_rho = v_rho_lda*4*rho0 + v_rho_vwn + v_rho_b88 + v_rho_lyp 37 | v_xc_norms = v_norm_b88 + v_norm_lyp 38 | 39 | return e_xc, v_xc_rho, v_xc_norms 40 | 41 | 42 | 43 | @jax.jit 44 | def do_lda(rho, EPSILON_B3LYP=0): 45 | rho0 = rho.T[:, 0] 46 | norms = jnp.linalg.norm(rho[1:], axis=0).T**2+EPSILON_B3LYP 47 | 48 | # simple wrapper to get names in popvision; lambda doesn't give different names.. 
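    # Since __lda returns the per-particle energy eps(rho) ~ rho^(1/3), we have
    # d(eps)/d(rho) = eps / (3 rho), so the XC potential v = d(rho*eps)/d(rho)
    # = (4/3)*eps = 4*rho*d(eps)/d(rho), hence the `v_rho_lda*4*rho[0]` below.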
49 |     def lda(rho0): return jax.vmap(jax.value_and_grad(lambda x: __lda(x)*0.08)) (rho0)
50 | 
51 |     e_xc_lda, v_rho_lda = jax.jit(lda)(rho0)
52 | 
53 |     e_xc = e_xc_lda
54 |     v_xc_rho = v_rho_lda*4*rho[0]
55 |     v_xc_norms = jnp.zeros(rho[0].shape)  # pure LDA: no v_norm_b88 + v_norm_lyp terms
56 | 
57 |     return e_xc, v_xc_rho, v_xc_norms
58 | 
59 | def plot(rho, b, a, g, grad, vnorm=None, name=""):  # b is pyscf, a is ours
60 | 
61 |     import matplotlib.pyplot as plt
62 |     import numpy as np
63 | 
64 |     fig, ax = plt.subplots(1, 3, figsize=(14, 4))
65 |     ax[0].plot(rho[0], -b, 'o', label="pyscf.eval_b3lyp", ms=7)
66 |     ax[0].plot(rho[0], -a, 'x', label="jax_b3lyp", ms=2)
67 | 
68 |     print(np.max(np.abs(b)-np.abs(a)))
69 | 
70 |     ax[1].plot(rho[0], np.abs(a-b), 'x', label="absolute error")
71 |     ax[1].set_yscale("log")
72 |     ax[1].set_xscale("log")
73 |     ax[1].set_xlabel("input to b3lyp")
74 |     ax[1].set_ylabel("absolute error")
75 |     ax[1].legend()
76 | 
77 |     ax[2].plot(rho[0], np.abs(a-b)/np.abs(b), 'x', label="relative error")
78 |     ax[2].set_yscale("log")
79 |     ax[2].set_xscale("log")
80 |     ax[2].set_xlabel("input to b3lyp")
81 |     ax[2].set_ylabel("relative absolute error")
82 |     ax[2].legend()
83 | 
84 |     ax[0].set_yscale("log")
85 |     ax[0].set_xscale("log")
86 |     ax[0].set_ylabel("output of b3lyp")
87 |     ax[0].set_xlabel("input to b3lyp")
88 |     ax[0].legend()
89 | 
90 |     plt.tight_layout()
91 |     ax[1].set_title("E_xc [%s]" % name)
92 |     plt.savefig("%s_1.jpg"%name)
93 | 
94 |     fig, ax = plt.subplots(1,3, figsize=(14, 4))
95 | 
96 |     ax[0].plot(rho[0], -g[0], 'o', label="pyscf grad", ms=7)
97 | 
98 |     if grad.ndim == 2: grad = grad[:, 0]
99 | 
100 |     ax[0].plot(rho[0], -grad, 'x', label="jax grad", ms=2)
101 | 
102 |     print(np.max(np.abs(g[0])-np.abs(grad)))
103 | 
104 |     ax[0].legend()
105 | 
106 |     ax[1].plot(rho[0], np.abs(g[0]-grad), 'x', label="absolute error")
107 |     ax[1].legend()
108 |     ax[1].set_xlabel("input")
109 |     ax[1].set_ylabel("absolute gradient error")
110 |     ax[1].set_title("d E_xc / d electron_density [%s]" % name)
111 | 
112 |     ax[2].plot(rho[0], np.abs(g[0]-grad)/np.abs(g[0]), 'x', label="relative error")
113 |     ax[2].legend()
114 |     ax[2].set_xlabel("input")
115 |     ax[2].set_ylabel("relative gradient error")
116 | 
117 | 
118 |     ax[0].set_yscale("log")
119 |     ax[0].set_xscale("log")
120 | 
121 |     ax[1].set_yscale("log")
122 |     ax[1].set_xscale("log")
123 | 
124 |     ax[2].set_yscale("log")
125 |     ax[2].set_xscale("log")
126 |     plt.tight_layout()
127 |     plt.savefig("%s_2.jpg"%name)
128 | 
129 |     if vnorm is not None:
130 |         fig, ax = plt.subplots(1,3, figsize=(14, 4))
131 | 
132 |         ax[0].plot(rho[0], np.abs(g[1]), 'o', label="pyscf grad", ms=7)
133 |         ax[0].plot(rho[0], np.abs(vnorm), 'x', label="jax grad", ms=2)
134 |         ax[0].legend()
135 | 
136 |         print(np.max(np.abs(g[1])-np.abs(vnorm)))
137 | 
138 |         ax[1].plot(rho[0], np.abs(g[1]-vnorm), 'x', label="absolute error")
139 |         ax[1].legend()
140 |         ax[1].set_xlabel("input")
141 |         ax[1].set_ylabel("absolute gradient error")
142 | 
143 |         ax[1].set_title("d E_xc / d norms [%s]" % name)
144 | 
145 |         ax[2].plot(rho[0], np.abs(g[1]-vnorm)/np.abs(g[1]), 'x', label="relative error")
146 |         ax[2].legend()
147 |         ax[2].set_xlabel("input")
148 |         ax[2].set_ylabel("relative gradient error")
149 | 
150 |         ax[0].set_yscale("log")
151 |         ax[0].set_xscale("log")
152 | 
153 |         ax[1].set_yscale("log")
154 |         ax[1].set_xscale("log")
155 | 
156 |         ax[2].set_yscale("log")
157 |         ax[2].set_xscale("log")
158 |         plt.tight_layout()
159 |         plt.savefig("%s_3.jpg"%name)
160 | 
--------------------------------------------------------------------------------
/pyscf_ipu/exchange_correlation/b88.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | # The functional definition in this file was ported to Python 3 | # from XCFun, which is Copyright Ulf Ekström and contributors 2009-2020 4 | # and provided under the Mozilla Public License (v2.0) 5 | # see also: 6 | # - https://github.com/dftlibs/xcfun 7 | # - https://github.com/dftlibs/xcfun/blob/master/LICENSE.md 8 | 9 | import jax.numpy as jnp 10 | import jax 11 | import numpy as np 12 | 13 | def __b88(a, gaa): 14 | # precompute 15 | c1 = (4.0 / 3.0) 16 | c2 = (-8.0 / 3.0) 17 | c3 = (-3.0 / 4.0) * (6.0 / np.pi) ** (1.0 / 3.0) * 2 18 | d = 0.0042 19 | d2 = d * 2. 20 | d12 = d *12. 21 | 22 | # actual compute 23 | log_a = jnp.log(a/2) 24 | na43 = jnp.exp(log_a * c1) 25 | chi2 = gaa / 4* jnp.exp(log_a * c2 ) 26 | chi = jnp.exp(jnp.log( chi2 ) / 2 ) 27 | b88 = -(d * na43 * chi2) / (1.0 + 6*d * chi * jnp.arcsinh(chi)) *2 28 | slaterx_a = c3 * na43 29 | return slaterx_a + b88 30 | -------------------------------------------------------------------------------- /pyscf_ipu/exchange_correlation/lda.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import jax.numpy as jnp 3 | import jax 4 | 5 | def __lda(rho): return -jnp.exp(1/3*jnp.log(rho) - 0.30305460484554375) 6 | -------------------------------------------------------------------------------- /pyscf_ipu/exchange_correlation/lyp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | # The functional definition in this file was ported to Python 3 | # from XCFun, which is Copyright Ulf Ekström and contributors 2009-2020 4 | # and provided under the Mozilla Public License (v2.0) 5 | # see also: 6 | # - https://github.com/dftlibs/xcfun 7 | # - https://github.com/dftlibs/xcfun/blob/master/LICENSE.md 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import numpy 12 | 13 | def __lyp(n, gnn): 14 | 15 | # precompute 16 | A = 0.04918 17 | B = 0.132 18 | C = 0.2533 19 | Dd = 0.349 20 | CF = 0.3 * (3.0 * numpy.pi * numpy.pi) ** (2.0 / 3.0) 21 | c0 = 2.0 ** (11.0 / 3.0) * (1/2)**(8/3) 22 | c1 = (1/3 + 1/8)*4 23 | 24 | # actual compute 25 | log_n = jnp.log(n) 26 | icbrtn = jnp.exp(log_n * (-1.0 / 3.0) ) 27 | 28 | P = 1.0 / (1.0 + Dd * icbrtn) 29 | omega = jnp.exp(-C * icbrtn) * P 30 | delta = icbrtn * (C + Dd * P) 31 | 32 | n_five_three = jnp.exp(log_n*(-5/3)) 33 | 34 | result = -A * ( 35 | n * P 36 | + B 37 | * omega 38 | * 1/ 4 *( 39 | 2 * CF * n * c0+ 40 | gnn * (60 - 14.0 * delta) /36 * n_five_three 41 | - gnn *c1 * n_five_three 42 | ) 43 | ) 44 | 45 | return result 46 | -------------------------------------------------------------------------------- /pyscf_ipu/exchange_correlation/vwn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
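# VWN parametrisation of the LDA correlation energy. In b3lyp.py above it enters with
# weight 0.19, complementing the 0.81-weighted LYP correlation (0.19 = 1 - 0.81).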
2 | # The functional definition in this file was ported to Python 3 | # from XCFun, which is Copyright Ulf Ekström and contributors 2009-2020 4 | # and provided under the Mozilla Public License (v2.0) 5 | # see also: 6 | # - https://github.com/dftlibs/xcfun 7 | # - https://github.com/dftlibs/xcfun/blob/master/LICENSE.md 8 | 9 | import jax.numpy as jnp 10 | import jax 11 | import numpy as np 12 | 13 | def __vwn(n): 14 | # Precompute stuff in np.float64 15 | p = np.array( [-0.10498, 0.0621813817393097900698817274255, 3.72744, 12.9352]) 16 | f = p[0] * p[2] / (p[0] * p[0] + p[0] * p[2] + p[3]) - 1.0 17 | f_inv_p1 = 1/f+1 18 | f_2 = f * 0.5 19 | sqrt = np.sqrt(4.0 * p[3] - p[2] * p[2]) 20 | precompute = p[2] * ( 1.0 / sqrt 21 | - p[0] 22 | / ( 23 | (p[0] * p[0] + p[0] * p[2] + p[3]) 24 | * sqrt 25 | / (p[2] + 2.0 * p[0]) 26 | ) 27 | ) 28 | log_s_c = np.log( 3.0 /(4*np.pi) ) / 6 29 | 30 | # Below cast to same dtype as input (allow easier comparison between f32/f64). 31 | dtype = n.dtype 32 | p = p.astype(dtype) 33 | f = f.astype(dtype) 34 | f_inv_p1 = (f_inv_p1).astype(dtype) 35 | f_2 = f_2.astype(dtype) 36 | sqrt = sqrt.astype(dtype) 37 | precompute = precompute.astype(dtype) 38 | log_s_c =log_s_c.astype(dtype) 39 | 40 | # compute stuff that depends on n 41 | log_s = - jnp.log(n) / 6 + log_s_c 42 | s_2 = jnp.exp( log_s *2) 43 | s = jnp.exp( log_s ) 44 | z = sqrt / (2.0 * s + p[2]) 45 | 46 | result = n * p[1] * ( 47 | log_s 48 | #+ f * jnp.log( jnp.sqrt( s_2 + p[2] * s + p[3] ) / (s-p[0])**(1/f+1) ) # problem with float, 1/f+1 was done in np which automatically sticks to float64 49 | + f * jnp.log( jnp.sqrt( s_2 + p[2] * s + p[3] ) / (s-p[0])**(f_inv_p1) ) 50 | + precompute * jnp.arctan(z) 51 | 52 | ) 53 | 54 | return result 55 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/basis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
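# Builds Gaussian basis sets: basisset() below fetches shell exponents and contraction
# coefficients from the Basis Set Exchange and expands each shell into cartesian Orbitals.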
2 | from typing import Tuple 3 | 4 | import chex 5 | import jax.numpy as jnp 6 | import numpy as np 7 | 8 | from .orbital import Orbital 9 | from .structure import Structure 10 | from .types import FloatN, FloatNx3, FloatNxM 11 | 12 | 13 | @chex.dataclass 14 | class Basis: 15 | orbitals: Tuple[Orbital] 16 | structure: Structure 17 | 18 | @property 19 | def num_orbitals(self) -> int: 20 | return len(self.orbitals) 21 | 22 | @property 23 | def num_primitives(self) -> int: 24 | return sum(ao.num_primitives for ao in self.orbitals) 25 | 26 | @property 27 | def occupancy(self) -> FloatN: 28 | # Assumes uncharged systems in restricted Kohn-Sham 29 | occ = jnp.full(self.num_orbitals, 2.0) 30 | mask = occ.cumsum() > self.structure.num_electrons 31 | occ = occ.at[mask].set(0.0) 32 | return occ 33 | 34 | def __call__(self, pos: FloatNx3) -> FloatNxM: 35 | return jnp.hstack([o(pos) for o in self.orbitals]) 36 | 37 | 38 | def basisset(structure: Structure, basis_name: str = "sto-3g"): 39 | from basis_set_exchange import get_basis 40 | from basis_set_exchange.sort import sort_basis 41 | 42 | LMN_MAP = { 43 | 0: [(0, 0, 0)], 44 | 1: [(1, 0, 0), (0, 1, 0), (0, 0, 1)], 45 | 2: [(2, 0, 0), (1, 1, 0), (1, 0, 1), (0, 2, 0), (0, 1, 1), (0, 0, 2)], 46 | } 47 | 48 | bse_basis = get_basis( 49 | basis_name, 50 | elements=structure.atomic_symbol, 51 | uncontract_spdf=True, 52 | uncontract_general=False, 53 | ) 54 | bse_basis = sort_basis(bse_basis)["elements"] 55 | orbitals = [] 56 | 57 | for a in range(structure.num_atoms): 58 | center = structure.position[a, :] 59 | shells = bse_basis[str(structure.atomic_number[a])]["electron_shells"] 60 | 61 | for s in shells: 62 | for lmn in LMN_MAP[s["angular_momentum"][0]]: 63 | ao = Orbital.from_bse( 64 | center=center, 65 | alphas=np.array(s["exponents"], dtype=np.float32), 66 | lmn=np.array(lmn, dtype=np.int32), 67 | coefficients=np.array(s["coefficients"], dtype=np.float32), 68 | ) 69 | orbitals.append(ao) 70 | 71 | return Basis( 72 | orbitals=orbitals, 73 | structure=structure, 74 | ) 75 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/binom_factor_table.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | # AUTOGENERATED from notebooks/binom_factor_table.ipynb 3 | # fmt: off 4 | # flake8: noqa 5 | # isort: skip_file 6 | from numpy import array 7 | binom_factor_table = ((array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 8 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 9 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 10 | 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 11 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 12 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 13 | 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 14 | 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]), array([1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 15 | 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 1, 1, 1, 1, 1, 2, 2, 16 | 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 17 | 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 18 | 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 19 | 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 20 | 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 21 | 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]), array([0, 0, 1, 0, 1, 2, 0, 1, 2, 3, 0, 0, 1, 1, 0, 1, 1, 2, 2, 0, 1, 1, 22 | 2, 2, 3, 3, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 1, 0, 1, 1, 2, 2, 0, 1, 23 | 1, 2, 2, 2, 3, 3, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 0, 1, 1, 2, 2, 24 | 2, 3, 3, 3, 4, 4, 4, 0, 1, 2, 0, 1, 1, 2, 2, 3, 3, 0, 1, 1, 2, 2, 25 | 2, 3, 3, 3, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 0, 1, 1, 26 | 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 0, 1, 2, 3, 0, 1, 1, 2, 2, 3, 3, 27 | 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 28 | 3, 3, 4, 4, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4]), array([ 4, 21, 4, 11, 21, 4, 2, 11, 21, 4, 15, 3, 4, 15, 1, 3, 21, 29 | 4, 15, 17, 1, 11, 3, 21, 4, 15, 19, 2, 17, 1, 11, 3, 21, 4, 30 | 15, 22, 15, 10, 3, 22, 4, 15, 23, 1, 10, 3, 21, 22, 4, 15, 9, 31 | 17, 23, 1, 10, 11, 3, 21, 22, 4, 15, 12, 9, 19, 2, 17, 23, 1, 32 | 10, 11, 3, 21, 22, 13, 22, 15, 14, 10, 13, 3, 22, 4, 15, 5, 14, 33 | 23, 1, 10, 13, 3, 21, 22, 4, 15, 16, 5, 9, 14, 17, 23, 1, 10, 34 | 11, 13, 3, 21, 22, 18, 12, 16, 5, 9, 19, 2, 14, 17, 23, 1, 10, 35 | 11, 13, 7, 13, 22, 15, 24, 7, 14, 10, 13, 3, 22, 4, 15, 20, 5, 36 | 24, 7, 14, 23, 1, 10, 13, 3, 21, 22, 6, 16, 20, 5, 9, 24, 7, 37 | 14, 17, 23, 1, 10, 11, 13, 8, 6, 18, 12, 16, 20, 5, 9, 19, 24, 38 | 2, 7, 14, 17, 23])), array([ 1., 1., 2., 1., 3., 3., 1., 4., 6., 4., 1., 1., 1., 39 | 1., 1., 2., 1., 2., 1., 1., 3., 1., 3., 3., 3., 1., 40 | 1., 1., 4., 6., 4., 4., 6., 4., 1., 1., 2., 1., 2., 41 | 1., 1., 2., 1., 2., 2., 4., 1., 1., 2., 2., 1., 2., 42 | 3., 6., 3., 1., 6., 3., 1., 3., 2., 1., 4., 2., 1., 43 | 8., 6., 12., 4., 4., 8., 6., 1., 1., 3., 3., 1., 3., 44 | 1., 3., 3., 1., 3., 1., 2., 3., 3., 6., 1., 6., 1., 45 | 3., 2., 3., 1., 3., 3., 3., 3., 9., 9., 9., 1., 1., 46 | 9., 3., 3., 1., 3., 4., 6., 12., 3., 1., 4., 12., 18., 47 | 18., 12., 4., 1., 1., 4., 6., 4., 1., 1., 4., 6., 4., 48 | 4., 6., 1., 4., 1., 4., 2., 1., 8., 6., 4., 12., 4., 49 | 8., 1., 6., 1., 4., 3., 12., 6., 3., 1., 12., 4., 18., 50 | 12., 18., 1., 4., 1., 4., 4., 6., 16., 6., 24., 24., 4., 51 | 4., 1., 1., 16., 16., 36.])) 52 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/device.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from functools import partial, wraps 3 | 4 | import numpy as np 5 | from jax import devices, jit 6 | 7 | 8 | def has_ipu() -> bool: 9 | try: 10 | return len(devices("ipu")) > 0 11 | except RuntimeError: 12 | pass 13 | 14 | return False 15 | 16 | 17 | ipu_jit = partial(jit, backend="ipu") 18 | 19 | 20 | def ipu_func(func): 21 | @wraps(func) 22 | def wrapper(*args, **kwargs): 23 | outputs = ipu_jit(func)(*args, **kwargs) 24 | 25 | if not isinstance(outputs, tuple): 26 | return np.asarray(outputs) 27 | 28 | return [np.asarray(o) for o in outputs] 29 | 30 | return wrapper 31 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/integrals.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from dataclasses import asdict 3 | from functools import partial 4 | from itertools import product as cartesian_product 5 | from typing import Callable 6 | 7 | import jax.numpy as jnp 8 | import numpy as np 9 | from jax import jit, tree_map, vmap 10 | from jax.ops import segment_sum 11 | 12 | from .basis import Basis 13 | from .orbital import batch_orbitals 14 | from .primitive import Primitive, product 15 | from .special import binom, binom_factor, factorial, factorial2, gammanu 16 | from .types import Float3, FloatN, FloatNx3, FloatNxN 17 | from .units import LMAX 18 | 19 | """ 20 | JAX implementation for integrals over Gaussian basis functions. Based upon the 21 | closed-form expressions derived in 22 | 23 | Taketa, H., Huzinaga, S., & O-ohata, K. (1966). Gaussian-expansion methods for 24 | molecular integrals. Journal of the physical society of Japan, 21(11), 2313-2324. 25 | 26 | 27 | Hereafter referred to as the "THO paper" 28 | 29 | Related work: 30 | 31 | [1] Augspurger JD, Dykstra CE. General quantum mechanical operators. An 32 | open-ended approach for one-electron integrals with Gaussian bases. Journal of 33 | computational chemistry. 1990 Jan;11(1):105-11. 
34 | 35 | 36 | [2] PyQuante: 37 | """ 38 | 39 | 40 | @partial(vmap, in_axes=(0, 0, 0, 0, None)) 41 | def overlap_axis(i: int, j: int, a: float, b: float, alpha: float) -> float: 42 | idx = [(s, t) for s in range(LMAX + 1) for t in range(2 * s + 1)] 43 | s, t = jnp.array(idx, dtype=jnp.uint32).T 44 | out = binom(i, 2 * s - t) * binom(j, t) 45 | out *= a ** (i - (2 * s - t)) * b ** (j - t) 46 | out *= factorial2(2 * s - 1) / (2 * alpha) ** s 47 | 48 | mask = (2 * s - i <= t) & (t <= j) 49 | out = jnp.where(mask, out, 0) 50 | return jnp.sum(out) 51 | 52 | 53 | def overlap_basis(b: Basis) -> FloatNxN: 54 | return integrate(b, vmap_overlap_primitives) 55 | 56 | 57 | def integrate(b: Basis, primitive_op: Callable) -> FloatNxN: 58 | def take_primitives(indices): 59 | p = tree_map(lambda x: jnp.take(x, indices, axis=0), primitives) 60 | c = jnp.take(coefficients, indices) 61 | return p, c 62 | 63 | primitives, coefficients, orbital_index = batch_orbitals(b.orbitals) 64 | ii, jj = jnp.triu_indices(b.num_primitives) 65 | lhs, cl = take_primitives(ii.reshape(-1)) 66 | rhs, cr = take_primitives(jj.reshape(-1)) 67 | aij = cl * cr * primitive_op(lhs, rhs) 68 | A = jnp.zeros((b.num_primitives, b.num_primitives)) 69 | A = A.at[ii, jj].set(aij) 70 | A = A + A.T - jnp.diag(jnp.diag(A)) 71 | index = orbital_index.reshape(1, -1) 72 | return segment_sum(segment_sum(A, index).T, index) 73 | 74 | 75 | def _overlap_primitives(a: Primitive, b: Primitive) -> float: 76 | p = product(a, b) 77 | pa = p.center - a.center 78 | pb = p.center - b.center 79 | out = jnp.power(jnp.pi / p.alpha, 1.5) * p.norm 80 | out *= jnp.prod(overlap_axis(a.lmn, b.lmn, pa, pb, p.alpha)) 81 | return out 82 | 83 | 84 | def _kinetic_primitives(a: Primitive, b: Primitive) -> float: 85 | t0 = b.alpha * (2 * jnp.sum(b.lmn) + 3) * _overlap_primitives(a, b) 86 | 87 | def offset_qn(ax: int, offset: int): 88 | lmn = b.lmn.at[ax].add(offset) 89 | return Primitive(**{**asdict(b), "lmn": lmn}) 90 | 91 | axes = jnp.arange(3) 92 | b1 = vmap(offset_qn, (0, None))(axes, 2) 93 | t1 = jnp.sum(vmap(_overlap_primitives, (None, 0))(a, b1)) 94 | 95 | b2 = vmap(offset_qn, (0, None))(axes, -2) 96 | t2 = jnp.sum(b.lmn * (b.lmn - 1) * vmap(_overlap_primitives, (None, 0))(a, b2)) 97 | return t0 - 2.0 * b.alpha**2 * t1 - 0.5 * t2 98 | 99 | 100 | def kinetic_basis(b: Basis) -> FloatNxN: 101 | return integrate(b, vmap_kinetic_primitives) 102 | 103 | 104 | def build_gindex(): 105 | vals = [ 106 | (i, r, u) 107 | for i in range(LMAX + 1) 108 | for r in range(i // 2 + 1) 109 | for u in range((i - 2 * r) // 2 + 1) 110 | ] 111 | i, r, u = jnp.array(vals).T 112 | return i, r, u 113 | 114 | 115 | def _nuclear_primitives(a: Primitive, b: Primitive, c: Float3): 116 | p = product(a, b) 117 | pa = p.center - a.center 118 | pb = p.center - b.center 119 | pc = p.center - c 120 | epsilon = 1.0 / (4.0 * p.alpha) 121 | 122 | @vmap 123 | def g_term(l1, l2, pa, pb, cp): 124 | i, r, u = build_gindex() 125 | index = i - 2 * r - u 126 | g = ( 127 | jnp.power(-1, i + u) 128 | * jnp.take(binom_factor(l1, l2, pa, pb), i) 129 | * factorial(i) 130 | * jnp.power(cp, index - u) 131 | * jnp.power(epsilon, r + u) 132 | ) / (factorial(r) * factorial(u) * factorial(index - u)) 133 | 134 | g = jnp.where(index <= l1 + l2, g, 0.0) 135 | return jnp.zeros(LMAX + 1).at[index].add(g) 136 | 137 | Gi, Gj, Gk = g_term(a.lmn, b.lmn, pa, pb, pc) 138 | 139 | ijk = jnp.arange(LMAX + 1) 140 | nu = ( 141 | ijk[:, jnp.newaxis, jnp.newaxis] 142 | + ijk[jnp.newaxis, :, jnp.newaxis] 143 | + ijk[jnp.newaxis, jnp.newaxis, 
:] 144 | ) 145 | 146 | W = ( 147 | Gi[:, jnp.newaxis, jnp.newaxis] 148 | * Gj[jnp.newaxis, :, jnp.newaxis] 149 | * Gk[jnp.newaxis, jnp.newaxis, :] 150 | * gammanu(nu, p.alpha * jnp.inner(pc, pc)) 151 | ) 152 | 153 | return -2.0 * jnp.pi / p.alpha * p.norm * jnp.sum(W) 154 | 155 | 156 | overlap_primitives = jit(_overlap_primitives) 157 | kinetic_primitives = jit(_kinetic_primitives) 158 | nuclear_primitives = jit(_nuclear_primitives) 159 | 160 | vmap_overlap_primitives = jit(vmap(_overlap_primitives)) 161 | vmap_kinetic_primitives = jit(vmap(_kinetic_primitives)) 162 | vmap_nuclear_primitives = jit(vmap(_nuclear_primitives)) 163 | 164 | 165 | @partial(vmap, in_axes=(None, 0, 0)) 166 | def nuclear_basis(b: Basis, c: FloatNx3, z: FloatN) -> FloatNxN: 167 | op = partial(_nuclear_primitives, c=c) 168 | op = vmap(op) 169 | op = jit(op) 170 | return z * integrate(b, op) 171 | 172 | 173 | def build_cindex(): 174 | vals = [ 175 | (i1, i2, r1, r2, u) 176 | for i1 in range(2 * LMAX + 1) 177 | for i2 in range(2 * LMAX + 1) 178 | for r1 in range(i1 // 2 + 1) 179 | for r2 in range(i2 // 2 + 1) 180 | for u in range((i1 + i2) // 2 - r1 - r2 + 1) 181 | ] 182 | i1, i2, r1, r2, u = jnp.array(vals).T 183 | return i1, i2, r1, r2, u 184 | 185 | 186 | def _eri_primitives(a: Primitive, b: Primitive, c: Primitive, d: Primitive) -> float: 187 | p = product(a, b) 188 | q = product(c, d) 189 | pa = p.center - a.center 190 | pb = p.center - b.center 191 | qc = q.center - c.center 192 | qd = q.center - d.center 193 | qp = q.center - p.center 194 | delta = 1 / (4.0 * p.alpha) + 1 / (4.0 * q.alpha) 195 | 196 | def H(l1, l2, a, b, i, r, gamma): 197 | # Note this should match THO Eq 3.5 but that seems to incorrectly show a 198 | # 1/(4 gamma) ^(i- 2r) term which is inconsistent with Eq 2.22. 199 | # Using (4 gamma)^(r - i) matches the reported expressions for H_L 200 | u = factorial(i) * jnp.take(binom_factor(l1, l2, a, b, 2 * LMAX), i) 201 | v = factorial(r) * factorial(i - 2 * r) * (4 * gamma) ** (i - r) 202 | return u / v 203 | 204 | def c_term(la, lb, lc, ld, pa, pb, qc, qd, qp): 205 | # THO Eq 2.22 and 3.4 206 | i1, i2, r1, r2, u = build_cindex() 207 | h = H(la, lb, pa, pb, i1, r1, p.alpha) * H(lc, ld, qc, qd, i2, r2, q.alpha) 208 | index = i1 + i2 - 2 * (r1 + r2) - u 209 | x = (-1) ** (i2 + u) * factorial(index + u) * qp ** (index - u) 210 | y = factorial(u) * factorial(index - u) * delta**index 211 | c = h * x / y 212 | 213 | mask = (i1 <= (la + lb)) & (i2 <= (lc + ld)) 214 | c = jnp.where(mask, c, 0.0) 215 | return segment_sum(c, index, num_segments=4 * LMAX + 1) 216 | 217 | # Manual vmap over cartesian axes (x, y, z) as ran into possible bug. 
218 | # See https://github.com/graphcore-research/pyscf-ipu/issues/105 219 | args = [a.lmn, b.lmn, c.lmn, d.lmn, pa, pb, qc, qd, qp] 220 | Ci, Cj, Ck = [c_term(*[v.at[i].get() for v in args]) for i in range(3)] 221 | 222 | ijk = jnp.arange(4 * LMAX + 1) 223 | nu = ( 224 | ijk[:, jnp.newaxis, jnp.newaxis] 225 | + ijk[jnp.newaxis, :, jnp.newaxis] 226 | + ijk[jnp.newaxis, jnp.newaxis, :] 227 | ) 228 | 229 | W = ( 230 | Ci[:, jnp.newaxis, jnp.newaxis] 231 | * Cj[jnp.newaxis, :, jnp.newaxis] 232 | * Ck[jnp.newaxis, jnp.newaxis, :] 233 | * gammanu(nu, jnp.inner(qp, qp) / (4.0 * delta)) 234 | ) 235 | 236 | return ( 237 | 2.0 238 | * jnp.pi**2 239 | / (p.alpha * q.alpha) 240 | * jnp.sqrt(jnp.pi / (p.alpha + q.alpha)) 241 | * p.norm 242 | * q.norm 243 | * jnp.sum(W) 244 | ) 245 | 246 | 247 | eri_primitives = jit(_eri_primitives) 248 | vmap_eri_primitives = jit(vmap(_eri_primitives)) 249 | 250 | 251 | def gen_ijkl(n: int): 252 | """ 253 | adapted from four-index transformations by S Wilson pg 257 254 | """ 255 | for idx in range(n): 256 | for jdx in range(idx + 1): 257 | for kdx in range(idx + 1): 258 | lmax = jdx if idx == kdx else kdx 259 | for ldx in range(lmax + 1): 260 | yield idx, jdx, kdx, ldx 261 | 262 | 263 | def eri_basis_sparse(b: Basis): 264 | indices = [] 265 | batch = [] 266 | offset = np.cumsum([o.num_primitives for o in b.orbitals]) 267 | offset = np.insert(offset, 0, 0) 268 | 269 | for count, idx in enumerate(gen_ijkl(b.num_orbitals)): 270 | mesh = [range(offset[i], offset[i + 1]) for i in idx] 271 | indices += list(cartesian_product(*mesh)) 272 | batch += [count] * (len(indices) - len(batch)) 273 | 274 | indices = jnp.array(indices, dtype=jnp.int32).T 275 | batch = jnp.array(batch, dtype=jnp.int32) 276 | primitives, coefficients, _ = batch_orbitals(b.orbitals) 277 | cijkl = jnp.stack([jnp.take(coefficients, idx) for idx in indices]).prod(axis=0) 278 | pijkl = [ 279 | tree_map(lambda x: jnp.take(x, idx, axis=0), primitives) for idx in indices 280 | ] 281 | eris = cijkl * vmap_eri_primitives(*pijkl) 282 | return segment_sum(eris, batch, num_segments=count + 1) 283 | 284 | 285 | def eri_basis(b: Basis): 286 | unique_eris = eri_basis_sparse(b) 287 | ii, jj, kk, ll = jnp.array(list(gen_ijkl(b.num_orbitals)), dtype=jnp.int32).T 288 | 289 | # Apply 8x permutation symmetry to build dense ERI from sparse ERI. 290 | eri_dense = jnp.empty((b.num_orbitals,) * 4, dtype=jnp.float32) 291 | eri_dense = eri_dense.at[ii, jj, kk, ll].set(unique_eris) 292 | eri_dense = eri_dense.at[ii, jj, ll, kk].set(unique_eris) 293 | eri_dense = eri_dense.at[jj, ii, kk, ll].set(unique_eris) 294 | eri_dense = eri_dense.at[jj, ii, ll, kk].set(unique_eris) 295 | eri_dense = eri_dense.at[kk, ll, ii, jj].set(unique_eris) 296 | eri_dense = eri_dense.at[kk, ll, jj, ii].set(unique_eris) 297 | eri_dense = eri_dense.at[ll, kk, ii, jj].set(unique_eris) 298 | eri_dense = eri_dense.at[ll, kk, jj, ii].set(unique_eris) 299 | return eri_dense 300 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/interop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
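# Conversion helpers between the experimental Structure/Basis types and PySCF's gto.Mole.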
2 | from typing import Tuple 3 | 4 | import numpy as np 5 | from periodictable import elements 6 | from pyscf import gto 7 | 8 | from .basis import Basis, basisset 9 | from .structure import Structure 10 | 11 | 12 | def to_pyscf( 13 | structure: Structure, basis_name: str = "sto-3g", unit: str = "Bohr" 14 | ) -> "gto.Mole": 15 | mol = gto.Mole(unit=unit, spin=structure.num_electrons % 2, cart=True) 16 | mol.atom = [ 17 | (symbol, pos) 18 | for symbol, pos in zip(structure.atomic_symbol, structure.position) 19 | ] 20 | mol.basis = basis_name 21 | mol.build(unit=unit) 22 | return mol 23 | 24 | 25 | def from_pyscf(mol: "gto.Mole") -> Tuple[Structure, Basis]: 26 | atomic_number = [] 27 | position = [] 28 | 29 | for i in range(mol.natm): 30 | sym, pos = mol.atom[i] 31 | atomic_number.append(elements.symbol(sym).number) 32 | position.append(pos) 33 | 34 | structure = Structure( 35 | atomic_number=np.array(atomic_number), 36 | position=np.array(position), 37 | is_bohr=mol.unit != "Angstrom", 38 | ) 39 | 40 | basis = basisset(structure, basis_name=mol.basis) 41 | 42 | return structure, basis 43 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/mesh.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from typing import Optional, Tuple, Union 3 | 4 | import jax.numpy as jnp 5 | 6 | from .basis import Basis 7 | from .types import FloatN, FloatNx3, FloatNxN 8 | 9 | 10 | def uniform_mesh( 11 | n: Union[int, Tuple] = 50, b: Union[float, Tuple] = 10.0, ndim: int = 3 12 | ): 13 | if isinstance(n, int): 14 | n = (n,) * ndim 15 | 16 | if isinstance(b, float): 17 | b = (b,) * ndim 18 | 19 | if not isinstance(n, (tuple, list)): 20 | raise ValueError("n must be an int or a tuple/list of ints") 21 | 22 | if len(n) != ndim: 23 | raise ValueError(f"n must be a tuple with {ndim} elements") 24 | 25 | if len(b) != ndim: 26 | raise ValueError(f"b must be a tuple with {ndim} elements") 27 | 28 | axes = [jnp.linspace(-bi, bi, ni) for bi, ni in zip(b, n)] 29 | mesh = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1) 30 | mesh = mesh.reshape(-1, ndim) 31 | return mesh, axes 32 | 33 | 34 | def electron_density( 35 | basis: Basis, mesh: FloatNx3, C: Optional[FloatNxN] = None 36 | ) -> FloatN: 37 | orbitals = molecular_orbitals(basis, mesh, C) 38 | density = jnp.sum(basis.occupancy * orbitals * orbitals, axis=-1) 39 | return density 40 | 41 | 42 | def molecular_orbitals( 43 | basis: Basis, mesh: FloatNx3, C: Optional[FloatNxN] = None 44 | ) -> FloatN: 45 | C = jnp.eye(basis.num_orbitals) if C is None else C 46 | orbitals = basis(mesh) @ C 47 | return orbitals 48 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/numerics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
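# Floating-point diagnostics: wrap a function so it runs in both float32 and float64
# (via jax's enable_x64) and report the largest deviation between the two results.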
2 | from functools import wraps 3 | from typing import Callable 4 | 5 | import jax.numpy as jnp 6 | import numpy as np 7 | from jax.experimental import enable_x64 8 | from jaxtyping import Array 9 | 10 | 11 | def apply_fpcast(v: Array, dtype: np.dtype): 12 | if isinstance(v, jnp.ndarray) and np.issubdtype(v, np.floating): 13 | return v.astype(dtype) 14 | 15 | return v 16 | 17 | 18 | def fpcast(func: Callable, dtype=jnp.float32): 19 | @wraps(func) 20 | def wrapper(*args, **kwargs): 21 | inputs = [apply_fpcast(v, dtype) for v in args] 22 | outputs = func(*inputs, **kwargs) 23 | return outputs 24 | 25 | return wrapper 26 | 27 | 28 | def compare_fp32_to_fp64(func: Callable): 29 | @wraps(func) 30 | def wrapper(*args, **kwargs): 31 | with enable_x64(): 32 | outputs_fp32 = fpcast(func, dtype=jnp.float32)(*args, **kwargs) 33 | outputs_fp64 = fpcast(func, dtype=jnp.float64)(*args, **kwargs) 34 | print_compare(func.__name__, outputs_fp32, outputs_fp64) 35 | return outputs_fp32 36 | 37 | return wrapper 38 | 39 | 40 | def print_compare(name: str, fp32, fp64): 41 | fp32 = [fp32] if isinstance(fp32, jnp.ndarray) else fp32 42 | fp64 = [fp64] if isinstance(fp64, jnp.ndarray) else fp64 43 | 44 | for idx, (low, high) in enumerate(zip(fp32, fp64)): 45 | low = np.asarray(low).astype(np.float64) 46 | high = np.asarray(high) 47 | print(f"{name} output {idx} has max |fp64 - fp32| = {np.abs(high - low).max()}") 48 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/orbital.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | 3 | from functools import partial 4 | from typing import Tuple 5 | 6 | import chex 7 | import jax.numpy as jnp 8 | from jax import tree_map, vmap 9 | 10 | from .primitive import Primitive, eval_primitive 11 | from .types import FloatN, FloatNx3 12 | 13 | 14 | @chex.dataclass 15 | class Orbital: 16 | primitives: Tuple[Primitive] 17 | coefficients: FloatN 18 | 19 | @property 20 | def num_primitives(self) -> int: 21 | return len(self.primitives) 22 | 23 | def __call__(self, pos: FloatNx3) -> FloatN: 24 | assert pos.shape[-1] == 3, "pos must have shape [N,3]" 25 | 26 | @partial(vmap, in_axes=(0, 0, None)) 27 | def eval_orbital(p: Primitive, coef: float, pos: FloatNx3): 28 | return coef * eval_primitive(p, pos) 29 | 30 | batch = tree_map(lambda *xs: jnp.stack(xs), *self.primitives) 31 | out = jnp.sum(eval_orbital(batch, self.coefficients, pos), axis=0) 32 | return out 33 | 34 | @staticmethod 35 | def from_bse(center, alphas, lmn, coefficients): 36 | coefficients = coefficients.reshape(-1) 37 | assert len(coefficients) == len(alphas), "Expecting same size vectors!"
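        # one Primitive per exponent; the shared contraction coefficients weight the
        # primitives when the Orbital is evaluated in __call__ above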
38 | p = [Primitive(center=center, alpha=a, lmn=lmn) for a in alphas] 39 | return Orbital(primitives=p, coefficients=coefficients) 40 | 41 | 42 | def batch_orbitals(orbitals: Tuple[Orbital]): 43 | primitives = [p for o in orbitals for p in o.primitives] 44 | primitives = tree_map(lambda *xs: jnp.stack(xs), *primitives) 45 | coefficients = jnp.concatenate([o.coefficients for o in orbitals]) 46 | orbital_index = jnp.concatenate( 47 | [ 48 | i * jnp.ones(o.num_primitives, dtype=jnp.int32) 49 | for i, o in enumerate(orbitals) 50 | ] 51 | ) 52 | return primitives, coefficients, orbital_index 53 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/plot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import numpy as np 3 | from numpy.typing import NDArray 4 | 5 | from .structure import Structure 6 | from .types import MeshAxes 7 | from .units import to_angstrom 8 | 9 | 10 | def plot_volume(structure: Structure, value: NDArray, axes: MeshAxes): 11 | """plots volumetric data value with molecular structure. 12 | 13 | Args: 14 | structure (Structure): molecular structure 15 | value (NDArray): the volume data to render 16 | axes (MeshAxes): the axes over which the data was sampled. 17 | 18 | Returns: 19 | py3DMol View object 20 | """ 21 | v = structure.view() 22 | v.addVolumetricData(cube_data(value, axes), "cube", build_transferfn(value)) 23 | return v 24 | 25 | 26 | def cube_data(value: NDArray, axes: MeshAxes) -> str: 27 | """Generate the cube file format as a string. See: 28 | 29 | https://paulbourke.net/dataformats/cube/ 30 | 31 | Args: 32 | value (NDArray): the volume data to serialise in the cube format 33 | axes (MeshAxes): the axes over which the data was sampled 34 | 35 | Returns: 36 | str: cube format representation of the volumetric data. 37 | """ 38 | axes = [to_angstrom(ax) for ax in axes] 39 | fmt = "cube format\n\n" 40 | x, y, z = axes 41 | nx, ny, nz = [ax.shape[0] for ax in axes] 42 | fmt += "0 " + " ".join([f"{v:12.6f}" for v in [x[0], y[0], z[0]]]) + "\n" 43 | fmt += f"{nx} " + " ".join([f"{v:12.6f}" for v in [x[1] - x[0], 0.0, 0.0]]) + "\n" 44 | fmt += f"{ny} " + " ".join([f"{v:12.6f}" for v in [0.0, y[1] - y[0], 0.0]]) + "\n" 45 | fmt += f"{nz} " + " ".join([f"{v:12.6f}" for v in [0.0, 0.0, z[1] - z[0]]]) + "\n" 46 | 47 | line = "" 48 | for i in range(len(value)): 49 | line += f"{value[i]:12.6f}" 50 | 51 | if i % 6 == 0: 52 | fmt += line + "\n" 53 | line = "" 54 | 55 | return fmt 56 | 57 | 58 | def build_transferfn(value: NDArray) -> dict: 59 | """Generate the 3dmol.js transferfn argument for a particular value. 60 | 61 | Tries to set isovalues to capture main features of the volume data. 62 | 63 | Args: 64 | value (NDArray): the volume data. 
65 | 66 | Returns: 67 | dict: containing transferfn 68 | """ 69 | v = np.percentile(value, [99.9, 75]) 70 | a = [0.02, 0.0005] 71 | return { 72 | "transferfn": [ 73 | {"color": "blue", "opacity": a[0], "value": -v[0]}, 74 | {"color": "blue", "opacity": a[1], "value": -v[1]}, 75 | {"color": "white", "opacity": 0.0, "value": 0.0}, 76 | {"color": "red", "opacity": a[1], "value": v[1]}, 77 | {"color": "red", "opacity": a[0], "value": v[0]}, 78 | ] 79 | } 80 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/primitive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from typing import Optional 3 | 4 | import chex 5 | import jax.numpy as jnp 6 | import numpy as np 7 | from scipy.special import gammaln 8 | 9 | from .types import Float3, FloatN, FloatNx3, Int3 10 | 11 | 12 | @chex.dataclass 13 | class Primitive: 14 | center: Float3 = np.zeros(3, dtype=np.float32) 15 | alpha: float = 1.0 16 | lmn: Int3 = np.zeros(3, dtype=np.int32) 17 | norm: Optional[float] = None 18 | 19 | def __post_init__(self): 20 | if self.norm is None: 21 | self.norm = normalize(self.lmn, self.alpha) 22 | 23 | @property 24 | def angular_momentum(self) -> int: 25 | return np.sum(self.lmn) 26 | 27 | def __call__(self, pos: FloatNx3) -> FloatN: 28 | return eval_primitive(self, pos) 29 | 30 | 31 | def normalize(lmn: Int3, alpha: float) -> float: 32 | L = np.sum(lmn) 33 | N = ((1 / 2) / alpha) ** (L + 3 / 2) 34 | N *= np.exp(np.sum(gammaln(lmn + 1 / 2))) 35 | return N**-0.5 36 | 37 | 38 | def product(a: Primitive, b: Primitive) -> Primitive: 39 | alpha = a.alpha + b.alpha 40 | center = (a.alpha * a.center + b.alpha * b.center) / alpha 41 | lmn = a.lmn + b.lmn 42 | c = a.norm * b.norm 43 | Rab = a.center - b.center 44 | c *= jnp.exp(-a.alpha * b.alpha / alpha * jnp.inner(Rab, Rab)) 45 | return Primitive(center=center, alpha=alpha, lmn=lmn, norm=c) 46 | 47 | 48 | def eval_primitive(p: Primitive, pos: FloatNx3) -> FloatN: 49 | assert pos.shape[-1] == 3, "pos must have shape [N,3]" 50 | pos_translated = pos[:, jnp.newaxis] - p.center 51 | v = p.norm * jnp.exp(-p.alpha * jnp.sum(pos_translated**2, axis=-1)) 52 | v *= jnp.prod(pos_translated**p.lmn, axis=-1) 53 | return v 54 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/special.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from functools import partial 3 | 4 | import jax.numpy as jnp 5 | import numpy as np 6 | from jax import lax 7 | from jax.ops import segment_sum 8 | from jax.scipy.special import betaln, gammainc, gammaln 9 | 10 | from .types import FloatN, IntN 11 | from .units import LMAX 12 | 13 | 14 | def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN: 15 | def body_fun(i, val): 16 | return val * jnp.where(i <= n, i, 1) 17 | 18 | return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) 19 | 20 | 21 | def factorial_gamma(n: IntN) -> IntN: 22 | """Approximate factorial by evaluating the gamma function in log-space. 23 | 24 | This approximation is exact for small integers (n < 10).
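    For example, n = 5 gives exp(gammaln(6)) = 120.0 up to float rounding, and jnp.rint recovers 5! = 120.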
25 | """ 26 | approx = jnp.exp(gammaln(n + 1)) 27 | return jnp.rint(approx) 28 | 29 | 30 | def factorial_lookup(n: IntN, nmax: int = LMAX) -> IntN: 31 | N = np.cumprod(np.arange(1, nmax + 1)) 32 | N = np.insert(N, 0, 1) 33 | N = jnp.array(N, dtype=jnp.uint32) 34 | return N.at[n.astype(jnp.uint32)].get() 35 | 36 | 37 | factorial = factorial_gamma 38 | 39 | 40 | def factorial2_fori(n: IntN, nmax: int = 2 * LMAX) -> IntN: 41 | def body_fun(i, val): 42 | return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1) 43 | 44 | return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) 45 | 46 | 47 | def factorial2_lookup(n: IntN, nmax: int = 2 * LMAX) -> IntN: 48 | stop = nmax + 1 if nmax % 2 == 0 else nmax + 2 49 | N = np.arange(1, stop).reshape(-1, 2) 50 | N = np.cumprod(N, axis=0).reshape(-1) 51 | N = np.insert(N, 0, 1) 52 | N = jnp.array(N) 53 | n = jnp.maximum(n, 0) 54 | return N.at[n].get() 55 | 56 | 57 | factorial2 = factorial2_lookup 58 | 59 | 60 | def binom_beta(x: IntN, y: IntN) -> IntN: 61 | approx = 1.0 / ((x + 1) * jnp.exp(betaln(x - y + 1, y + 1))) 62 | return jnp.rint(approx) 63 | 64 | 65 | def binom_fori(x: IntN, y: IntN, nmax: int = LMAX) -> IntN: 66 | bang = partial(factorial_fori, nmax=nmax) 67 | c = x * bang(x - 1) / (bang(y) * bang(x - y)) 68 | return jnp.where(x == y, 1, c) 69 | 70 | 71 | def binom_lookup(x: IntN, y: IntN, nmax: int = LMAX) -> IntN: 72 | bang = partial(factorial_lookup, nmax=nmax) 73 | c = x * bang(x - 1) / (bang(y) * bang(x - y)) 74 | return jnp.where(x == y, 1, c) 75 | 76 | 77 | binom = binom_lookup 78 | 79 | 80 | def gammanu_gamma(nu: IntN, t: FloatN, epsilon: float = 1e-10) -> FloatN: 81 | """ 82 | eq 2.11 from THO but simplified using SymPy and converted to jax 83 | 84 | t, u = symbols("t u", real=True, positive=True) 85 | nu = Symbol("nu", integer=True, nonnegative=True) 86 | 87 | expr = simplify(integrate(u ** (2 * nu) * exp(-t * u**2), (u, 0, 1))) 88 | f = lambdify((nu, t), expr, modules="scipy") 89 | ?f 90 | 91 | We evaluate this in log-space to avoid overflow/nan 92 | """ 93 | t = jnp.maximum(t, epsilon) 94 | x = nu + 0.5 95 | gn = jnp.log(0.5) - x * jnp.log(t) + jnp.log(gammainc(x, t)) + gammaln(x) 96 | return jnp.exp(gn) 97 | 98 | 99 | def gammanu_series(nu: IntN, t: FloatN, num_terms: int = 128) -> FloatN: 100 | """ 101 | eq 2.11 from THO but simplified as derived in equation 19 of gammanu.ipynb 102 | """ 103 | an = nu + 0.5 104 | tn = 1 / an 105 | total = jnp.full_like(nu, tn, dtype=jnp.float32) 106 | 107 | for _ in range(num_terms): 108 | an = an + 1 109 | tn = tn * t / an 110 | total = total + tn 111 | 112 | return jnp.exp(-t) / 2 * total 113 | 114 | 115 | gammanu = gammanu_series 116 | 117 | 118 | def binom_factor(i: int, j: int, a: float, b: float, lmax: int = LMAX) -> FloatN: 119 | """ 120 | Eq. 15 from Augspurger JD, Dykstra CE. General quantum mechanical operators. An 121 | open-ended approach for one-electron integrals with Gaussian bases. Journal of 122 | computational chemistry. 1990 Jan;11(1):105-11. 123 | 124 | """ 125 | s, t = jnp.tril_indices(lmax + 1) 126 | out = binom(i, s - t) * binom(j, t) * a ** (i - (s - t)) * b ** (j - t) 127 | mask = ((s - i) <= t) & (t <= j) 128 | out = jnp.where(mask, out, 0.0) 129 | return segment_sum(out, s, num_segments=lmax + 1) 130 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/structure.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
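# Minimal molecular structure container: atomic numbers plus positions stored in Bohr,
# with reference geometries (h2, water) and the nuclear repulsion energy below.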
2 | from typing import List 3 | 4 | import chex 5 | import numpy as np 6 | from periodictable import elements 7 | from py3Dmol import view 8 | 9 | from .types import FloatNx3, IntN 10 | from .units import to_angstrom, to_bohr 11 | 12 | 13 | @chex.dataclass 14 | class Structure: 15 | atomic_number: IntN 16 | position: FloatNx3 17 | is_bohr: bool = True 18 | 19 | def __post_init__(self): 20 | if not self.is_bohr: 21 | self.position = to_bohr(self.position) 22 | 23 | # single atom case 24 | self.position = np.atleast_2d(self.position) 25 | 26 | @property 27 | def num_atoms(self) -> int: 28 | return len(self.atomic_number) 29 | 30 | @property 31 | def atomic_symbol(self) -> List[str]: 32 | return [elements[z].symbol for z in self.atomic_number] 33 | 34 | @property 35 | def num_electrons(self) -> int: 36 | return np.sum(self.atomic_number) 37 | 38 | def to_xyz(self) -> str: 39 | xyz = f"{self.num_atoms}\n\n" 40 | sym = self.atomic_symbol 41 | pos = to_angstrom(self.position) 42 | 43 | for i in range(self.num_atoms): 44 | r = np.array2string(pos[i, :], separator="\t")[1:-1] 45 | xyz += f"{sym[i]}\t{r}\n" 46 | 47 | return xyz 48 | 49 | def view(self) -> "view": 50 | return view(data=self.to_xyz(), style={"stick": {"radius": 0.06}}) 51 | 52 | 53 | def molecule(name: str): 54 | name = name.lower() 55 | 56 | if name == "h2": 57 | return Structure( 58 | atomic_number=np.array([1, 1]), 59 | position=np.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]), 60 | ) 61 | 62 | if name == "water": 63 | r"""Single water molecule 64 | Structure of single water molecule calculated with DFT using B3LYP 65 | functional and 6-31+G** basis set """ 66 | return Structure( 67 | atomic_number=np.array([8, 1, 1]), 68 | position=np.array( 69 | [ 70 | [0.0000, 0.0000, 0.1165], 71 | [0.0000, 0.7694, -0.4661], 72 | [0.0000, -0.7694, -0.4661], 73 | ] 74 | ), 75 | is_bohr=False, 76 | ) 77 | 78 | raise NotImplementedError(f"No structure registered for: {name}") 79 | 80 | 81 | def nuclear_energy(structure: Structure) -> float: 82 | """Nuclear electrostatic interaction energy 83 | 84 | Evaluated by taking sum over all unique pairs of atom centers: 85 | 86 | sum_{j > i} z_i z_j / |r_i - r_j| 87 | 88 | where z_i is the charge of the ith atom (the atomic number). 89 | 90 | Args: 91 | structure (Structure): input structure 92 | 93 | Returns: 94 | float: the total nuclear repulsion energy 95 | """ 96 | idx, jdx = np.triu_indices(structure.num_atoms, 1) 97 | u = structure.atomic_number[idx] * structure.atomic_number[jdx] 98 | rij = structure.position[idx, :] - structure.position[jdx, :] 99 | return np.sum(u / np.linalg.norm(rij, axis=1)) 100 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/types.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from typing import Tuple 3 | 4 | from jaxtyping import Array, Float, Int 5 | 6 | Float3 = Float[Array, "3"] 7 | FloatNx3 = Float[Array, "N 3"] 8 | FloatN = Float[Array, "N"] 9 | FloatNxN = Float[Array, "N N"] 10 | FloatNxM = Float[Array, "N M"] 11 | Int3 = Int[Array, "3"] 12 | IntN = Int[Array, "N"] 13 | 14 | MeshAxes = Tuple[FloatN, FloatN, FloatN] 15 | -------------------------------------------------------------------------------- /pyscf_ipu/experimental/units.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
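# Unit conversions: positions are stored internally in Bohr; 1 Angstrom = 1.8897259886 Bohr
# (equivalently, 1 Bohr = 0.529177 Angstrom).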
2 | from jaxtyping import Array 3 | 4 | # Maximum value an individual component of the angular momentum lmn can take 5 | # Used for static ahead-of-time compilation of functions involving lmn. 6 | LMAX = 4 7 | 8 | BOHR_PER_ANGSTROM = 1.8897259886 9 | 10 | 11 | def to_angstrom(bohr_value: Array) -> Array: 12 | return bohr_value / BOHR_PER_ANGSTROM 13 | 14 | 15 | def to_bohr(angstrom_value: Array) -> Array: 16 | return angstrom_value * BOHR_PER_ANGSTROM 17 | -------------------------------------------------------------------------------- /pyscf_ipu/nanoDFT/README.md: -------------------------------------------------------------------------------- 1 | :red_circle: :warning: **Experimental and non-official Graphcore product** :warning: :red_circle: 2 | 3 | # nanoDFT 4 | 5 | Hackable DFT implementation made by machine learning researchers for machine learning researchers. nanoDFT tries to be small, clean, interpretable and educational. 6 | 7 | 8 | ### Minimal Working Example 9 | ``` 10 | python nanoDFT.py --structure_optimization True 11 | ``` 12 | 13 | 14 | -------------------------------------------------------------------------------- /pyscf_ipu/nanoDFT/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from .nanoDFT import * 3 | -------------------------------------------------------------------------------- /pyscf_ipu/nanoDFT/compute_eri_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import numpy as np 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | def reconstruct_ERI(ERI, nonzero_idx, N, sym=True): 7 | i, j, k, l = nonzero_idx[:, 0], nonzero_idx[:, 1], nonzero_idx[:, 2], nonzero_idx[:, 3] 8 | rec_ERI = np.zeros((N, N, N, N)) 9 | rec_ERI[i, j, k, l] = ERI[i, j, k, l] 10 | if sym: 11 | rec_ERI[j, i, k, l] = ERI[j, i, k, l] 12 | rec_ERI[i, j, l, k] = ERI[i, j, l, k] 13 | rec_ERI[j, i, l, k] = ERI[j, i, l, k] 14 | rec_ERI[k, l, i, j] = ERI[k, l, i, j] 15 | rec_ERI[k, l, j, i] = ERI[k, l, j, i] 16 | rec_ERI[l, k, i, j] = ERI[l, k, i, j] 17 | rec_ERI[l, k, j, i] = ERI[l, k, j, i] 18 | 19 | return rec_ERI 20 | 21 | def inverse_permutation(a): 22 | b = np.arange(a.shape[0]) 23 | b[a] = b.copy() 24 | return b 25 | 26 | def get_shapes(input_ijkl, bas): 27 | i_sh, j_sh, k_sh, l_sh = input_ijkl[0] 28 | BAS_SLOTS = 8 29 | NPRIM_OF = 2 30 | NCTR_OF = 3 31 | ANG_OF = 1 32 | GSHIFT = 4 33 | 34 | i_prim = bas.reshape(-1)[BAS_SLOTS*i_sh + NPRIM_OF] 35 | j_prim = bas.reshape(-1)[BAS_SLOTS*j_sh + NPRIM_OF] 36 | k_prim = bas.reshape(-1)[BAS_SLOTS*k_sh + NPRIM_OF] 37 | l_prim = bas.reshape(-1)[BAS_SLOTS*l_sh + NPRIM_OF] 38 | 39 | i_ctr = bas.reshape(-1)[BAS_SLOTS * i_sh + NCTR_OF] 40 | j_ctr = bas.reshape(-1)[BAS_SLOTS * j_sh + NCTR_OF] 41 | k_ctr = bas.reshape(-1)[BAS_SLOTS * k_sh + NCTR_OF] 42 | l_ctr = bas.reshape(-1)[BAS_SLOTS * l_sh + NCTR_OF] 43 | 44 | i_l = bas.reshape(-1)[BAS_SLOTS * i_sh + ANG_OF] 45 | j_l = bas.reshape(-1)[BAS_SLOTS * j_sh + ANG_OF] 46 | k_l = bas.reshape(-1)[BAS_SLOTS * k_sh + ANG_OF] 47 | l_l = bas.reshape(-1)[BAS_SLOTS * l_sh + ANG_OF] 48 | 49 | nfi = (i_l+1)*(i_l+2)/2 50 | nfj = (j_l+1)*(j_l+2)/2 51 | nfk = (k_l+1)*(k_l+2)/2 52 | nfl = (l_l+1)*(l_l+2)/2 53 | nf = nfi * nfk * nfl * nfj; 54 | n_comp = 1 55 | 56 | nc = i_ctr * j_ctr * k_ctr * l_ctr; 57 | lenl = nf * nc * n_comp; 58 | lenk = nf * i_ctr * j_ctr * k_ctr * n_comp; 59 | lenj = nf * i_ctr * j_ctr *
n_comp; 60 | leni = nf * i_ctr * n_comp; 61 | len0 = nf * n_comp; 62 | 63 | ng = [0, 0, 0, 0, 0, 1, 1, 1]; 64 | 65 | IINC=0 66 | JINC=1 67 | KINC=2 68 | LINC=3 69 | 70 | li_ceil = i_l + ng[IINC] 71 | lj_ceil = j_l + ng[JINC] 72 | lk_ceil = k_l + ng[KINC] 73 | ll_ceil = l_l + ng[LINC] 74 | nrys_roots = (li_ceil + lj_ceil + lk_ceil + ll_ceil)/2 + 1 75 | 76 | 77 | ibase = li_ceil > lj_ceil; 78 | kbase = lk_ceil > ll_ceil; 79 | if (nrys_roots <= 2): 80 | ibase = 0; 81 | kbase = 0; 82 | if (kbase) : 83 | dlk = lk_ceil + ll_ceil + 1; 84 | dll = ll_ceil + 1; 85 | else: 86 | dlk = lk_ceil + 1; 87 | dll = lk_ceil + ll_ceil + 1; 88 | 89 | if (ibase) : 90 | dli = li_ceil + lj_ceil + 1; 91 | dlj = lj_ceil + 1; 92 | else : 93 | dli = li_ceil + 1; 94 | dlj = li_ceil + lj_ceil + 1; 95 | 96 | g_stride_i = nrys_roots; 97 | g_stride_k = nrys_roots * dli; 98 | g_stride_l = nrys_roots * dli * dlk; 99 | g_stride_j = nrys_roots * dli * dlk * dll; 100 | g_size = nrys_roots * dli * dlk * dll * dlj; 101 | gbits = ng[GSHIFT]; 102 | leng = g_size*3*((1<<gbits)+1); 145 | if abab > I_max: 146 | I_max = abab 147 | 148 | # collect candidate pairs for s8 149 | considered_indices = [] 150 | tril_idx = np.tril_indices(N) 151 | for a, b in zip(tril_idx[0], tril_idx[1]): 152 | index_ab_s8 = a*(a+1)//2 + b 153 | index_s8 = index_ab_s8*(index_ab_s8+3)//2 154 | abab = np.abs(ERI_s8[index_s8]) 155 | if abab*I_max>=tolerance**2: 156 | considered_indices.append((a, b)) # collect candidate pairs for s8 157 | 158 | screened_indices_s8_4d = np.zeros(((len(considered_indices)*(len(considered_indices)+1)//2), 4), dtype=dtype) 159 | 160 | # generate s8 indices 161 | sid = 0 162 | for index, ab in enumerate(considered_indices): 163 | a, b = ab 164 | for cd in considered_indices[index:]: 165 | c, d = cd 166 | screened_indices_s8_4d[sid, :] = (a, b, c, d) 167 | sid += 1 168 | 169 | return screened_indices_s8_4d 170 | 171 | def remove_ortho(arr, nonzero_pattern, output_size, dtype=jnp.int16): 172 | assert dtype in [jnp.int16, jnp.int32, jnp.uint32] 173 | if dtype == jnp.int16: reinterpret_dtype = jnp.float16 174 | else: reinterpret_dtype = None 175 | 176 | def condition(i, j, k, l): 177 | return ~(nonzero_pattern[i] ^ nonzero_pattern[j]) ^ (nonzero_pattern[k] ^ nonzero_pattern[l]) 178 | 179 | def body_fun(carry, x): 180 | results, counter = carry 181 | x_reinterpret = jax.lax.bitcast_convert_type(x, dtype).astype(jnp.uint32) 182 | i, j, k, l = x_reinterpret 183 | 184 | def update_vals(carry): 185 | res, count, v = carry 186 | res = res.at[count].set(v) 187 | count = count + 1 188 | return res, count 189 | 190 | results, counter = jax.lax.cond(condition(i, j, k, l), (results, counter, x), update_vals, (results, counter), lambda x: x) 191 | return (results, counter), () 192 | 193 | 194 | init_results = jnp.zeros((output_size, arr.shape[1]), dtype=dtype) 195 | init_count = jnp.array(0, dtype=jnp.int32) 196 | 197 | if reinterpret_dtype is not None: 198 | init_results = jax.lax.bitcast_convert_type(init_results, reinterpret_dtype) 199 | arr = jax.lax.bitcast_convert_type(arr, reinterpret_dtype) 200 | 201 | (final_results, _), _ = jax.lax.scan(body_fun, (init_results, init_count), arr) 202 | 203 | final_results = jax.lax.bitcast_convert_type(final_results, dtype) 204 | 205 | return final_results 206 | vmap_remove_ortho = jax.vmap(remove_ortho, in_axes=(0, None, None), out_axes=0) 207 | 208 | def prepare_integrals_2_inputs(mol, itol): 209 | # Shapes/sizes.
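    # (atm, bas and env are PySCF's packed libcint tables; ao_loc maps a shell index to its
    # atomic-orbital offset, so ao_loc[i+1] - ao_loc[i] is the number of AOs in shell i.)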
210 | atm, bas, env = mol._atm, mol._bas, mol._env 211 | n_atm, n_bas, N = atm.shape[0], bas.shape[0], mol.nao_nr() 212 | ao_loc = np.cumsum(np.concatenate([np.zeros(1), (bas[:,1]*2+1) * bas[:,3] ])).astype(np.int32) 213 | n_ao_loc = np.prod(ao_loc.shape) 214 | shls_slice = (0, n_bas, 0, n_bas, 0, n_bas, 0, n_bas) 215 | shape = [1, N, N, N, N] 216 | 217 | # Initialize tensors for CPU libcint computation. 218 | buf = np.zeros(np.prod(shape)*2) 219 | out = np.zeros(shape) 220 | eri = np.zeros(shape).reshape(-1) 221 | ipu_eri = np.zeros(shape).reshape(-1) 222 | 223 | dtype = np.float32 #hardcoded 224 | buf, out, eri, ipu_eri, env = buf.astype(dtype), out.astype(dtype), eri.astype(dtype), ipu_eri.astype(dtype), env.astype(dtype) 225 | 226 | # The padded shape used to store output from all tiles. 227 | n_buf, n_eri, n_env = 81, 81, np.prod(env.shape) 228 | if mol.basis == "6-31G*": # has known error; please open github issue if you want to use 6-31G* 229 | n_buf = 5**4 230 | n_eri = 5**4 231 | 232 | # Compute how many distinct integrals after 8x symmetry. 233 | num_calls = 0 234 | for i in range(n_bas): 235 | for j in range(i+1): 236 | for k in range(i, n_bas): 237 | for l in range(k+1): 238 | # almost all 8-fold symmetry rules (except horizontal fade) 239 | num_calls+=1 240 | print('num_calls', num_calls) 241 | 242 | # Input/outputs for calling the IPU vertex. 243 | input_ijkl = np.zeros((num_calls, 4), dtype=np.int32) 244 | cpu_output = np.zeros((num_calls, n_eri), dtype=np.float32) 245 | output_sizes = np.zeros((num_calls, 5)) 246 | 247 | USE_TOLERANCE_THRESHOLD = True 248 | screened_indices_s8_4d = [] 249 | 250 | if USE_TOLERANCE_THRESHOLD: 251 | tolerance = itol 252 | print('computing ERI s8 to sample N*(N+1)/2 values... ', end='') 253 | ERI_s8 = mol.intor('int2e_sph', aosym='s8') 254 | print('done') 255 | 256 | # sample symmetry pattern and do safety check 257 | # if N % 2 == 0: 258 | # nonzero_seed = ERI[N-1, N-1, :N//2, 0] != 0 259 | # nonzero_seed = np.concatenate([nonzero_seed, np.flip(nonzero_seed)]) 260 | # else: 261 | # nonzero_seed = ERI[N-1, N-1, :(N+1)//2, 0] != 0 262 | # nonzero_seed = np.concatenate([nonzero_seed, np.flip(nonzero_seed[:-1])]) 263 | # if not np.equal(nonzero_seed, ERI[N-1, N-1, :, 0]!=0).all(): 264 | # print('# -------------------------------------------------------------- #') 265 | # print('# WARNING: Experimental symmetry pattern sample is inconsistent. 
#') 266 | # print('pred', nonzero_seed) 267 | # print('real', ERI[N-1, N-1, :, 0]!=0) 268 | # print('# -------------------------------------------------------------- #') 269 | 270 | # hardcoded symmetry pattern 271 | sym_pattern = np.array([(i+3)%5!=0 for i in range(N)]) 272 | nonzero_seed = sym_pattern 273 | 274 | if USE_TOLERANCE_THRESHOLD: 275 | # find max value 276 | I_max = 0 277 | tril_idx = np.tril_indices(N) 278 | for a, b in zip(tril_idx[0], tril_idx[1]): 279 | index_ab_s8 = a*(a+1)//2 + b 280 | index_s8 = index_ab_s8*(index_ab_s8+3)//2 281 | abab = np.abs(ERI_s8[index_s8]) 282 | if abab > I_max: 283 | I_max = abab 284 | 285 | # collect candidate pairs for s8 286 | considered_indices = [] 287 | tril_idx = np.tril_indices(N) 288 | for a, b in zip(tril_idx[0], tril_idx[1]): 289 | if USE_TOLERANCE_THRESHOLD: 290 | index_ab_s8 = a*(a+1)//2 + b 291 | index_s8 = index_ab_s8*(index_ab_s8+3)//2 292 | abab = np.abs(ERI_s8[index_s8]) 293 | if abab*I_max>=tolerance**2: 294 | considered_indices.append((a, b)) # collect candidate pairs for s8 295 | else: 296 | considered_indices.append((a, b)) # collect candidate pairs for s8 297 | considered_indices = set(considered_indices) 298 | 299 | print('n_bas', n_bas) 300 | print('ao_loc', ao_loc) 301 | 302 | # Fill input_ijkl and output_sizes with the necessary indices. 303 | n_items = 0 304 | n_all_integrals = 0 305 | n_new_integrals = 0 306 | for i in range(n_bas): 307 | for j in range(i+1): 308 | for k in range(i, n_bas): 309 | for l in range(k+1): 310 | di = ao_loc[i+1] - ao_loc[i] 311 | dj = ao_loc[j+1] - ao_loc[j] 312 | dk = ao_loc[k+1] - ao_loc[k] 313 | dl = ao_loc[l+1] - ao_loc[l] 314 | 315 | ia, ib = ao_loc[i], ao_loc[i+1] 316 | ja, jb = ao_loc[j], ao_loc[j+1] 317 | ka, kb = ao_loc[k], ao_loc[k+1] 318 | la, lb = ao_loc[l], ao_loc[l+1] 319 | 320 | n_all_integrals += di*dj*dk*dl 321 | 322 | found_nonzero = False 323 | # check i,j boxes 324 | for bi in range(ia, ib): 325 | for bj in range(ja, jb): 326 | if (bi, bj) in considered_indices: # if ij box is considered 327 | # check if kl pairs are considered 328 | for bk in range(ka, kb): 329 | if bk>=bi: # apply symmetry - tril fade vertical 330 | mla = la 331 | if bk == bi: 332 | mla = max(bj, la) # apply symmetry - tril fade horizontal 333 | for bl in range(mla, lb): 334 | if (bk, bl) in considered_indices: 335 | # apply grid pattern to find final nonzeros 336 | if ~(nonzero_seed[bi] ^ nonzero_seed[bj]) ^ (nonzero_seed[bk] ^ nonzero_seed[bl]): 337 | found_nonzero = True 338 | break 339 | if found_nonzero: break 340 | if found_nonzero: break 341 | if found_nonzero: break 342 | if not found_nonzero: continue 343 | 344 | n_new_integrals += di*dj*dk*dl 345 | 346 | input_ijkl[n_items] = [i, j, k, l] 347 | 348 | output_sizes[n_items] = [di, dj, dk, dl, di*dj*dk*dl] 349 | 350 | n_items += 1 351 | print('!!! saved', num_calls - n_items, 'calls i.e.', num_calls, '->', n_items) 352 | print('!!! saved', n_all_integrals - n_new_integrals, 'integrals i.e.', n_all_integrals, '->', n_new_integrals) 353 | 354 | num_calls = n_items 355 | input_ijkl = input_ijkl[:num_calls, :] 356 | cpu_output = cpu_output[:num_calls, :] 357 | output_sizes = output_sizes[:num_calls, :] 358 | 359 | # Prepare IPU inputs. 360 | # Merge all int/float inputs into separate arrays.
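    # input_ints layout: six header sizes [n_eri, n_buf, n_atm, n_bas, n_env, n_ao_loc],
    # followed by the flattened ao_loc, atm and bas arrays; input_floats carries env.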
361 | input_floats = env.reshape(1, -1) 362 | input_ints = np.zeros((1, 6+n_ao_loc +n_atm*6+n_bas*8), dtype=np.int32) 363 | start, stop = 0, 6 364 | input_ints[:, start:stop] = np.array( [n_eri, n_buf, n_atm, n_bas, n_env, n_ao_loc] ) 365 | start, stop = start+6, stop+n_ao_loc 366 | input_ints[:, start:stop] = ao_loc.reshape(-1) 367 | start, stop = start+n_ao_loc, stop + n_atm*6 368 | input_ints[:, start:stop] = atm.reshape(-1) 369 | start, stop = start+n_atm*6, stop + n_bas*8 370 | input_ints[:, start:stop] = bas.reshape(-1) 371 | 372 | sizes, counts = np.unique(output_sizes[:, -1], return_counts=True) 373 | sizes, counts = sizes.astype(np.int32), counts.astype(np.int32) 374 | 375 | indxs = np.argsort(output_sizes[:, -1]) 376 | sorted_output_sizes = output_sizes[indxs] 377 | input_ijkl = input_ijkl[indxs] 378 | 379 | sizes, counts = np.unique(output_sizes[:, -1], return_counts=True) 380 | sizes, counts = sizes.astype(np.int32), counts.astype(np.int32) 381 | start_index = 0 382 | inputs = [] 383 | shapes = [] 384 | for i, (size, count) in enumerate(zip(sizes, counts)): 385 | a = input_ijkl[start_index: start_index+count] 386 | tuples = tuple(map(tuple, a)) 387 | inputs.append(tuples) 388 | start_index += count 389 | 390 | tuple_ijkl = tuple(inputs) 391 | input_ijkl = inputs 392 | 393 | for i in range(len(sizes)): 394 | shapes.append(get_shapes(input_ijkl[i], bas)) 395 | 396 | return input_floats, input_ints, tuple_ijkl, tuple(shapes), tuple(sizes.tolist()), counts.tolist(), ao_loc, num_calls 397 | -------------------------------------------------------------------------------- /pyscf_ipu/nanoDFT/compute_indices.cpp: -------------------------------------------------------------------------------- 1 | #include <poplar/Vertex.hpp> 2 | #include <poplar/HalfFloat.hpp> 3 | #include <poplar/StackSizeDefs.hpp> 4 | #include <poplar/AvailableVTypes.h> 5 | #include <cassert> 6 | #include <print.h> 7 | #include "poplar/TileConstants.hpp" 8 | 9 | using namespace poplar; 10 | 11 | #ifdef __IPU__ 12 | // Use the IPU intrinsics 13 | #include <ipu_memory_intrinsics> 14 | #include <ipu_vector_math> 15 | #define NAMESPACE ipu 16 | #else 17 | // Use the std functions 18 | #include <cmath> 19 | #define NAMESPACE std 20 | #endif 21 | 22 | class IndicesIJKL : public Vertex { 23 | public: 24 | Input<Vector<unsigned>> i_; 25 | Input<Vector<unsigned>> j_; 26 | Input<Vector<unsigned>> k_; 27 | Input<Vector<unsigned>> l_; 28 | 29 | Input<Vector<unsigned>> sym_; 30 | Input<Vector<unsigned>> N_; 31 | 32 | Input<Vector<unsigned>> start_; 33 | Input<Vector<unsigned>> stop_; 34 | 35 | Output<Vector<unsigned>> out_; 36 | 37 | bool compute() { 38 | 39 | const uint32_t& N = N_[0]; 40 | const uint32_t& sym = sym_[0]; 41 | const uint32_t& start = start_[0]; 42 | const uint32_t& stop = stop_[0]; 43 | 44 | for (uint32_t iteration = start; iteration < stop; iteration++){ 45 | 46 | const uint32_t& i = i_[iteration]; 47 | const uint32_t& j = j_[iteration]; 48 | const uint32_t& k = k_[iteration]; 49 | const uint32_t& l = l_[iteration]; 50 | 51 | uint32_t& out = out_[iteration]; 52 | 53 | //_compute_symmetry(ij, kl, N, sym, out[iteration]); 54 | switch (sym) { 55 | case 0: { out = i*N+j; break; } 56 | case 1: { out = j*N+i; break; } 57 | case 2: { out = i*N+j; break; } 58 | case 3: { out = j*N+i; break; } 59 | case 4: { out = k*N+l; break; } 60 | case 5: { out = l*N+k; break; } 61 | case 6: { out = k*N+l; break; } 62 | case 7: { out = l*N+k; break; } 63 | 64 | 65 | //ss_indices_func_J = lambda i,j,k,l,symmetry: jnp.array([k*N+l, k*N+l, l*N+k, l*N+k, i*N+j, i*N+j, j*N+i, j*N+i])[symmetry] 66 | case 8: { out = k*N+l; break; } 67 | case 9: { out = k*N+l; break; } 68 | case 10: { out = l*N+k; break; } 69 | case 11: { out = l*N+k; break; } 70 | case 12: { out = i*N+j; break; } 71 | case 13: { out = i*N+j; break; } 72 | case 14: { out = j*N+i; break; }
73 | case 15: { out = j*N+i; break; } 74 | 75 | 76 | //dm_indices_func_K = lambda i,j,k,l,symmetry: jnp.array([k*N+j, k*N+i, l*N+j, l*N+i, i*N+l, i*N+k, j*N+l, j*N+k])[symmetry] 77 | case 16: { out = k*N+j; break; } 78 | case 17: { out = k*N+i; break; } 79 | case 18: { out = l*N+j; break; } 80 | case 19: { out = l*N+i; break; } 81 | case 20: { out = i*N+l; break; } 82 | case 21: { out = i*N+k; break; } 83 | case 22: { out = j*N+l; break; } 84 | case 23: { out = j*N+k; break; } 85 | 86 | 87 | //ss_indices_func_K = lambda i,j,k,l,symmetry: jnp.array([i*N+l, j*N+l, i*N+k, j*N+k, k*N+j, l*N+j, k*N+i, l*N+i])[symmetry] 88 | case 24: { out = i*N+l; break; } 89 | case 25: { out = j*N+l; break; } 90 | case 26: { out = i*N+k; break; } 91 | case 27: { out = j*N+k; break; } 92 | case 28: { out = k*N+j; break; } 93 | case 29: { out = l*N+j; break; } 94 | case 30: { out = k*N+i; break; } 95 | case 31: { out = l*N+i; break; } 96 | } 97 | 98 | } 99 | return true; 100 | } 101 | }; -------------------------------------------------------------------------------- /pyscf_ipu/nanoDFT/cpu_sparse_symmetric_ERI.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import pyscf 3 | import numpy as np 4 | import jax 5 | import jax.numpy as jnp 6 | import os.path as osp 7 | from functools import partial 8 | from icecream import ic 9 | jax.config.update('jax_platform_name', "cpu") 10 | #jax.config.update('jax_enable_x64', True) 11 | HYB_B3LYP = 0.2 12 | 13 | def get_i_j(val): 14 | i = (np.sqrt(1 + 8*val.astype(np.uint64)) - 1)//2 # no need for floor, integer division acts as floor. 15 | j = (((val - i) - (i**2 - val))//2) 16 | return i, j 17 | 18 | def cpu_ijkl(value, symmetry, N, f): 19 | i, j, k, l = value[0].astype(np.uint32), value[1].astype(np.uint32), value[2].astype(np.uint32), value[3].astype(np.uint32) 20 | return f(i,j,k,l,symmetry,N) 21 | cpu_ijkl = jax.vmap(cpu_ijkl, in_axes=(0, None, None, None)) 22 | 23 | from tessellate_ipu import create_ipu_tile_primitive, ipu_cycle_count, tile_map, tile_put_sharded, tile_put_replicated 24 | vertex_filename = osp.join(osp.dirname(__file__), "compute_indices.cpp") 25 | compute_indices= create_ipu_tile_primitive( 26 | "IndicesIJKL" , 27 | "IndicesIJKL" , 28 | inputs=["i_", "j_", "k_", "l_", "sym_", "N_", "start_", "stop_"], 29 | outputs={"out_": 0}, 30 | gp_filename=vertex_filename, 31 | perf_estimate=100, 32 | ) 33 | 34 | @partial(jax.jit, backend="ipu") 35 | def ipu_ijkl(nonzero_indices, symmetry, N): 36 | 37 | size = nonzero_indices.shape[0] 38 | total_threads = (1472-1) * 6 39 | remainder = size % total_threads 40 | 41 | i, j, k, l = [nonzero_indices[:, i].astype(np.uint32) for i in range(4)] 42 | 43 | if remainder != 0: 44 | i = jnp.pad(i, ((0, total_threads-remainder))) 45 | j = jnp.pad(j, ((0, total_threads-remainder))) 46 | k = jnp.pad(k, ((0, total_threads-remainder))) 47 | l = jnp.pad(l, ((0, total_threads-remainder))) 48 | 49 | i = i.reshape(total_threads, -1) 50 | j = j.reshape(total_threads, -1) 51 | k = k.reshape(total_threads, -1) 52 | l = l.reshape(total_threads, -1) 53 | 54 | stop = i.shape[1] 55 | 56 | tiles = tuple((np.arange(0,total_threads) % (1471) + 1).astype(np.uint32).tolist()) 57 | symmetry = tile_put_replicated(jnp.array(symmetry, dtype=jnp.uint32), tiles) 58 | N = tile_put_replicated(jnp.array(N, dtype=jnp.uint32), tiles) 59 | start = tile_put_replicated(jnp.array(0, dtype=jnp.uint32), tiles) 60 | stop = tile_put_replicated(jnp.array(stop, 
dtype=jnp.uint32), tiles) 61 | 62 | i = tile_put_sharded(i, tiles) 63 | j = tile_put_sharded(j, tiles) 64 | k = tile_put_sharded(k, tiles) 65 | l = tile_put_sharded(l, tiles) 66 | value = tile_map(compute_indices, i, j, k, l, symmetry, N, start, stop) 67 | 68 | return value.array.reshape(-1)[:size] 69 | 70 | def num_repetitions_fast(ij, kl): 71 | i, j = get_i_j(ij) 72 | k, l = get_i_j(kl) 73 | 74 | # compute: repetitions = 2^((i==j) + (k==l) + (k==i and l==j or k==j and l==i)) 75 | repetitions = 2**( 76 | np.equal(i,j).astype(np.uint64) + 77 | np.equal(k,l).astype(np.uint64) + 78 | (1 - ((1 - np.equal(k,i) * np.equal(l,j)) * 79 | (1- np.equal(k,j) * np.equal(l,i))).astype(np.uint64)) 80 | ) 81 | return repetitions 82 | 83 | indices_func = lambda i,j,k,l,symmetry,N: jnp.array([i*N+j, j*N+i, i*N+j, j*N+i, k*N+l, l*N+k, k*N+l, l*N+k, 84 | k*N+l, k*N+l, l*N+k, l*N+k, i*N+j, i*N+j, j*N+i, j*N+i, 85 | k*N+j, k*N+i, l*N+j, l*N+i, i*N+l, i*N+k, j*N+l, j*N+k, 86 | i*N+l, j*N+l, i*N+k, j*N+k, k*N+j, l*N+j, k*N+i, l*N+i])[symmetry] 87 | 88 | def sparse_symmetric_einsum(nonzero_distinct_ERI, nonzero_indices, dm, backend): 89 | dm = dm.reshape(-1) 90 | diff_JK = jnp.zeros(dm.shape) 91 | N = int(np.sqrt(dm.shape[0])) 92 | 93 | dnums = jax.lax.GatherDimensionNumbers( 94 | offset_dims=(), 95 | collapsed_slice_dims=(0,), 96 | start_index_map=(0,)) 97 | scatter_dnums = jax.lax.ScatterDimensionNumbers( 98 | update_window_dims=(), 99 | inserted_window_dims=(0,), 100 | scatter_dims_to_operand_dims=(0,)) 101 | 102 | def iteration(symmetry, vals): 103 | diff_JK = vals 104 | is_K_matrix = (symmetry >= 8) 105 | 106 | def sequentialized_iter(i, vals): 107 | # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16) 108 | # Trade-off: Using one function leads to smaller always-live memory. 109 | diff_JK = vals 110 | 111 | indices = nonzero_indices[i] 112 | 113 | indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32) 114 | eris = nonzero_distinct_ERI[i] 115 | 116 | if backend == "cpu": dm_indices = cpu_ijkl(indices, symmetry+is_K_matrix*8, N, indices_func).reshape(-1, 1) 117 | else: dm_indices = ipu_ijkl(indices, symmetry+is_K_matrix*8, N) .reshape(-1, 1) 118 | # dm_values = jnp.take(dm, indices, axis=0) # for our special case the 50 lines of code reduces to the one line below. 119 | print(indices.min(), indices.max(), dm_indices.min(), dm_indices.max(), dm.shape) 120 | dm_values = jax.lax.gather(dm, dm_indices, dimension_numbers=dnums, slice_sizes=(1,), mode=jax.lax.GatherScatterMode.FILL_OR_DROP) 121 | 122 | dm_values = dm_values.at[:].mul( eris ) # this is prod, but re-use variable for inplace update. 123 | 124 | if backend == "cpu": ss_indices = cpu_ijkl(indices, symmetry+8+is_K_matrix*8, N, indices_func) .reshape(-1,1) 125 | else: ss_indices = ipu_ijkl(indices, symmetry+8+is_K_matrix*8, N).astype(np.int32).reshape(-1,1) 126 | # diff_JK = diff_JK + jax.lax.segment_sum( ...) # for our special case the 100 lines of code reduces to the one line below. 
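# Note on the scatter_add below: it accumulates each dm_values entry into row ss_indices[n]
# of a fresh zero vector. mode=FILL_OR_DROP silently drops any out-of-bounds index, so the
# zero-padded entries added during preprocessing cannot corrupt diff_JK, while the
# indices_are_sorted/unique_indices flags are promises to XLA that allow a cheaper scatter lowering.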
127 | diff_JK = diff_JK + jax.lax.scatter_add(jnp.zeros((N**2,)), 128 | ss_indices, dm_values, 129 | scatter_dnums, indices_are_sorted=True, unique_indices=True, mode=jax.lax.GatherScatterMode.FILL_OR_DROP)\ 130 | *(-HYB_B3LYP/2)**is_K_matrix 131 | 132 | return diff_JK 133 | 134 | batches = nonzero_indices.shape[0] # Before pmap, tensor had shape (nipus, batches, -1) so [0]=batches after pmap 135 | diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK) 136 | #for i in range(batches): 137 | # diff_JK = sequentialized_iter(i, diff_JK) 138 | return diff_JK 139 | 140 | diff_JK = jax.lax.fori_loop(0, 16, iteration, diff_JK) 141 | #for i in range(16): 142 | # diff_JK = iteration(i, diff_JK) 143 | #return jax.lax.psum(diff_JK, axis_name="p") 144 | return diff_JK#[0] 145 | 146 | if __name__ == "__main__": 147 | import time 148 | import argparse 149 | parser = argparse.ArgumentParser(prog='', description='', epilog='') 150 | parser.add_argument('-backend', default="cpu"), 151 | parser.add_argument('-natm', default=3), 152 | parser.add_argument('-test', action="store_true") 153 | parser.add_argument('-prof', action="store_true") 154 | parser.add_argument('-batches', default=5) 155 | parser.add_argument('-nipu', default=16, type=int) 156 | parser.add_argument('-skip', action="store_true") 157 | 158 | args = parser.parse_args() 159 | backend = args.backend 160 | 161 | natm = int(args.natm) 162 | nipu = int(args.nipu) 163 | if backend == "cpu": nipu = 1 164 | 165 | start = time.time() 166 | 167 | mol = pyscf.gto.Mole(atom="".join(f"C 0 {1.54*j} {1.54*i};" for i in range(natm) for j in range(natm))) 168 | #mol = pyscf.gto.Mole(atom="".join(f"C 0 {15.4*j} {15.4*i};" for i in range(1) for j in range(75))) 169 | mol.build() 170 | N = mol.nao_nr() 171 | print("N %i"%mol.nao_nr()) 172 | print("NxN:", (N**2, N**2)) 173 | print("Naive operations: ", N**4*2/10**9, "[Giga]") 174 | if not args.skip: dense_ERI = mol.intor("int2e_sph", aosym="s1") 175 | distinct_ERI = mol.intor("int2e_sph", aosym="s8") 176 | distinct_ERI[np.abs(distinct_ERI)<1e-9] = 0 # zero out stuff 177 | dm = pyscf.scf.hf.init_guess_by_minao(mol) 178 | scale = HYB_B3LYP/2 179 | if not args.skip: 180 | J = np.einsum("ijkl,ji->kl", dense_ERI, dm) 181 | K = np.einsum("ijkl,jk->il", dense_ERI, dm) 182 | truth = J - K / 2 * HYB_B3LYP 183 | 184 | nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.uint64) 185 | nonzero_distinct_ERI = distinct_ERI[nonzero_indices].astype(np.float32) 186 | print("Nonzero Operations:", nonzero_indices.size*8*2/10**9, "[Giga]") 187 | ij, kl = get_i_j(nonzero_indices) 188 | 189 | 190 | 191 | rep = num_repetitions_fast(ij, kl) 192 | nonzero_distinct_ERI = nonzero_distinct_ERI / rep 193 | dm = dm.reshape(-1) 194 | diff_JK = np.zeros(dm.shape) 195 | 196 | batches = int(args.batches) # perhaps make 10 batches? 
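# Pad the flattened nonzero lists below so they reshape evenly into (nipu, batches, -1);
# np.pad appends zeros and a zero ERI value contributes nothing to the einsum, so the padded
# tail is numerically harmless. The entries are then sorted by kl, presumably so each batch
# gathers from a contiguous slice of the density matrix (the per-batch kl ranges are printed).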
197 | remainder = nonzero_indices.shape[0] % (nipu*batches) 198 | 199 | if remainder != 0: 200 | print(nipu*batches-remainder, ij.shape) 201 | ij = np.pad(ij, ((0,nipu*batches-remainder))) 202 | kl = np.pad(kl, ((0,nipu*batches-remainder))) 203 | nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0,nipu*batches-remainder)) 204 | 205 | indxs = np.argsort(kl) 206 | print(kl.shape) 207 | ij = ij[indxs] 208 | kl = kl[indxs] 209 | print(kl.shape) 210 | nonzero_distinct_ERI = nonzero_distinct_ERI[indxs] 211 | 212 | ij = ij.reshape(nipu, batches, -1) 213 | kl = kl.reshape(nipu, batches, -1) 214 | nonzero_distinct_ERI = nonzero_distinct_ERI.reshape(nipu, batches, -1) 215 | 216 | print("---") 217 | for i in range(batches): 218 | print(kl[:, i].min(), kl[:, i].max(), dm.shape) 219 | 220 | 221 | print("---") 222 | 223 | 224 | 225 | i, j = get_i_j(ij.reshape(-1)) 226 | k, l = get_i_j(kl.reshape(-1)) 227 | nonzero_indices = np.vstack([i,j,k,l]).T.reshape(nipu, batches, -1, 4).astype(np.int16) 228 | nonzero_indices = jax.lax.bitcast_convert_type(nonzero_indices, np.float16) 229 | 230 | diff_JK = jax.pmap(sparse_symmetric_einsum, in_axes=(0,0,None,None), static_broadcasted_argnums=(3,), backend=backend, axis_name="p")(nonzero_distinct_ERI, nonzero_indices, dm, args.backend) 231 | #diff_JK = sparse_symmetric_einsum(nonzero_distinct_ERI[0], nonzero_indices[0], dm, args.backend) 232 | 233 | if args.skip: 234 | exit() 235 | if args.nipu > 1: 236 | diff_JK = np.array(diff_JK[0]) 237 | 238 | diff_JK = diff_JK.reshape(N, N) 239 | print(diff_JK.reshape(-1)[::51]) 240 | print(truth.reshape(-1)[::51]) 241 | print(np.max(np.abs(diff_JK.reshape(-1) - truth.reshape(-1)))) 242 | assert np.allclose(diff_JK, truth, atol=1e-6) 243 | print("PASSED!") -------------------------------------------------------------------------------- /pyscf_ipu/nanoDFT/h2o.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/pyscf-ipu/43a9bb343acbd296aec1d1f64c309bb24920d98d/pyscf_ipu/nanoDFT/h2o.gif -------------------------------------------------------------------------------- /pyscf_ipu/nanoDFT/sparse_ERI.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import numpy as np 3 | import jax.numpy as jnp 4 | import os 5 | import pyscf 6 | import jax 7 | jax.config.update('jax_enable_x64', True) 8 | jax.config.update('jax_platform_name', "cpu") 9 | os.environ['OMP_NUM_THREADS'] = "16" 10 | 11 | # Construct molecule we can use for test case. 12 | mol = pyscf.gto.Mole(atom=[["C", (0, 0, i)] for i in range(16)], basis="sto3g") 13 | mol.build() 14 | N = mol.nao_nr() # N: number of atomic orbitals (AO) 15 | density_matrix = pyscf.scf.hf.init_guess_by_minao(mol) # (N, N) 16 | ERI = mol.intor("int2e_sph") # (N, N, N, N) 17 | print(ERI.shape, density_matrix.shape) 18 | 19 | # nanoDFT uses ERI in an "einsum" which is equal to matrix vector multiplication. 20 | truth = jnp.einsum('ijkl,ji->kl', ERI, density_matrix) # <--- einsum 21 | ERI = ERI.reshape(N**2, N**2) 22 | density_matrix = density_matrix.reshape(N**2) 23 | print(ERI.shape, density_matrix.shape) 24 | alternative = ERI @ density_matrix # <--- matrix vector mult 25 | alternative = alternative.reshape(N, N) 26 | 27 | assert np.allclose(truth, alternative) # they're equal! 28 | 29 | # First trick. The matrix is sparse! 
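# (Gaussian basis functions decay quickly with distance, so two-electron integrals between
# well-separated orbital pairs are numerically zero; for this linear chain of 16 carbons
# that makes the vast majority of the N^4 entries vanish.)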
30 | print(f"The matrix is {np.around(np.sum(ERI == 0)/ERI.size*100, 2)}% zeros!") 31 | 32 | def sparse_representation(ERI): 33 | rows, cols = np.nonzero(ERI) 34 | values = ERI[rows, cols] 35 | return rows, cols, values 36 | 37 | def sparse_mult(sparse, vector): 38 | rows, cols, values = sparse 39 | in_ = vector.take(cols, axis=0) 40 | prod = in_*values 41 | segment_sum = jax.ops.segment_sum(prod, rows, N**2) 42 | return segment_sum 43 | 44 | sparse_ERI = sparse_representation(ERI) 45 | res = jax.jit(sparse_mult, backend="cpu")(sparse_ERI, density_matrix).reshape(N, N) 46 | 47 | assert np.allclose(truth, res) 48 | 49 | # Problem: 50 | # If I increase molecule size to N=256 it all fits IPU but I get a memory spike. 51 | # I currently believe memory spike is caused by vector.take (and/or) jax.ops.segment_sum. 52 | # 53 | # Instead of doing all of ERI@density_matrix in sparse_mult at one go, we can just 54 | # have a for loop and to k rows at a time. This will require changing the sparse representation to 55 | # take a parameter k. 56 | # -------------------------------------------------------------------------------- /pyscf_ipu/nanoDFT/sparse_symmetric_ERI.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import pyscf 3 | import numpy as np 4 | import jax 5 | import jax.numpy as jnp 6 | import os.path as osp 7 | from functools import partial 8 | from icecream import ic 9 | jax.config.update('jax_platform_name', "cpu") 10 | #jax.config.update('jax_enable_x64', True) 11 | HYB_B3LYP = 0.2 12 | 13 | def get_i_j(val): 14 | i = (np.sqrt(1 + 8*val.astype(np.uint64)) - 1)//2 # no need for floor, integer division acts as floor. 15 | j = (((val - i) - (i**2 - val))//2) 16 | return i, j 17 | 18 | def cpu_ijkl(value, symmetry, N, f): 19 | i, j, k, l = value[0].astype(np.uint32), value[1].astype(np.uint32), value[2].astype(np.uint32), value[3].astype(np.uint32) 20 | return f(i,j,k,l,symmetry,N) 21 | cpu_ijkl = jax.vmap(cpu_ijkl, in_axes=(0, None, None, None)) 22 | 23 | @partial(jax.jit, backend="ipu") 24 | def ipu_ijkl(nonzero_indices, symmetry, N): 25 | from tessellate_ipu import create_ipu_tile_primitive, ipu_cycle_count, tile_map, tile_put_sharded, tile_put_replicated 26 | vertex_filename = osp.join(osp.dirname(__file__), "compute_indices.cpp") 27 | compute_indices= create_ipu_tile_primitive( 28 | "IndicesIJKL" , 29 | "IndicesIJKL" , 30 | inputs=["i_", "j_", "k_", "l_", "sym_", "N_", "start_", "stop_"], 31 | outputs={"out_": 0}, 32 | gp_filename=vertex_filename, 33 | perf_estimate=100, 34 | ) 35 | size = nonzero_indices.shape[0] 36 | total_threads = (1472-1) * 6 37 | remainder = size % total_threads 38 | 39 | i, j, k, l = [nonzero_indices[:, i].astype(np.uint32) for i in range(4)] 40 | 41 | if remainder != 0: 42 | i = jnp.pad(i, ((0, total_threads-remainder))) 43 | j = jnp.pad(j, ((0, total_threads-remainder))) 44 | k = jnp.pad(k, ((0, total_threads-remainder))) 45 | l = jnp.pad(l, ((0, total_threads-remainder))) 46 | 47 | i = i.reshape(total_threads, -1) 48 | j = j.reshape(total_threads, -1) 49 | k = k.reshape(total_threads, -1) 50 | l = l.reshape(total_threads, -1) 51 | 52 | stop = i.shape[1] 53 | 54 | tiles = tuple((np.arange(0,total_threads) % (1471) + 1).astype(np.uint32).tolist()) 55 | symmetry = tile_put_replicated(jnp.array(symmetry, dtype=jnp.uint32), tiles) 56 | N = tile_put_replicated(jnp.array(N, dtype=jnp.uint32), tiles) 57 | start = tile_put_replicated(jnp.array(0, dtype=jnp.uint32), 
tiles) 58 | stop = tile_put_replicated(jnp.array(stop, dtype=jnp.uint32), tiles) 59 | 60 | i = tile_put_sharded(i, tiles) 61 | j = tile_put_sharded(j, tiles) 62 | k = tile_put_sharded(k, tiles) 63 | l = tile_put_sharded(l, tiles) 64 | value = tile_map(compute_indices, i, j, k, l, symmetry, N, start, stop) 65 | 66 | return value.array.reshape(-1)[:size] 67 | 68 | def num_repetitions_fast(ij, kl): 69 | i, j = get_i_j(ij) 70 | k, l = get_i_j(kl) 71 | 72 | # compute: repetitions = 2^((i==j) + (k==l) + (k==i and l==j or k==j and l==i)) 73 | repetitions = 2**( 74 | np.equal(i,j).astype(np.uint64) + 75 | np.equal(k,l).astype(np.uint64) + 76 | (1 - ((1 - np.equal(k,i) * np.equal(l,j)) * 77 | (1- np.equal(k,j) * np.equal(l,i))).astype(np.uint64)) 78 | ) 79 | return repetitions 80 | 81 | indices_func = lambda i,j,k,l,symmetry,N: jnp.array([i*N+j, j*N+i, i*N+j, j*N+i, k*N+l, l*N+k, k*N+l, l*N+k, 82 | k*N+l, k*N+l, l*N+k, l*N+k, i*N+j, i*N+j, j*N+i, j*N+i, 83 | k*N+j, k*N+i, l*N+j, l*N+i, i*N+l, i*N+k, j*N+l, j*N+k, 84 | i*N+l, j*N+l, i*N+k, j*N+k, k*N+j, l*N+j, k*N+i, l*N+i])[symmetry] 85 | 86 | def sparse_symmetric_einsum(nonzero_distinct_ERI, nonzero_indices, dm, backend): 87 | dm = dm.reshape(-1) 88 | diff_JK = jnp.zeros(dm.shape) 89 | N = int(np.sqrt(dm.shape[0])) 90 | 91 | dnums = jax.lax.GatherDimensionNumbers( 92 | offset_dims=(), 93 | collapsed_slice_dims=(0,), 94 | start_index_map=(0,)) 95 | scatter_dnums = jax.lax.ScatterDimensionNumbers( 96 | update_window_dims=(), 97 | inserted_window_dims=(0,), 98 | scatter_dims_to_operand_dims=(0,)) 99 | 100 | def iteration(symmetry, vals): 101 | diff_JK = vals 102 | is_K_matrix = (symmetry >= 8) 103 | 104 | def sequentialized_iter(i, vals): 105 | # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16) 106 | # Trade-off: Using one function leads to smaller always-live memory. 107 | diff_JK = vals 108 | 109 | indices = nonzero_indices[i] 110 | print(nonzero_indices.shape, indices.shape) 111 | 112 | indices = jax.lax.bitcast_convert_type(indices, np.int16).astype(np.int32) 113 | eris = nonzero_distinct_ERI[i] 114 | 115 | if backend == "cpu": dm_indices = cpu_ijkl(indices, symmetry+is_K_matrix*8, N, indices_func).reshape(-1, 1) 116 | else: dm_indices = ipu_ijkl(indices, symmetry+is_K_matrix*8, N) .reshape(-1, 1) 117 | # dm_values = jnp.take(dm, indices, axis=0) # for our special case the 50 lines of code reduces to the one line below. 118 | dm_values = jax.lax.gather(dm, dm_indices, dimension_numbers=dnums, slice_sizes=(1,), mode=jax.lax.GatherScatterMode.FILL_OR_DROP) 119 | 120 | dm_values = dm_values.at[:].mul( eris ) # this is prod, but re-use variable for inplace update. 121 | 122 | if backend == "cpu": ss_indices = cpu_ijkl(indices, symmetry+8+is_K_matrix*8, N, indices_func) .reshape(-1,1) 123 | else: ss_indices = ipu_ijkl(indices, symmetry+8+is_K_matrix*8, N).astype(np.int32).reshape(-1,1) 124 | # diff_JK = diff_JK + jax.lax.segment_sum( ...) # for our special case the 100 lines of code reduces to the one line below. 
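# The trailing factor (-HYB_B3LYP/2)**is_K_matrix below scales the scatter result:
# exponent 0 during the J pass (symmetry 0-7) gives a factor of 1, and exponent 1 during
# the K pass (symmetry 8-15) gives -0.1, matching truth = J - K/2 * HYB_B3LYP.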
125 | diff_JK = diff_JK + jax.lax.scatter_add(jnp.zeros((N**2,)), 126 | ss_indices, dm_values, 127 | scatter_dnums, indices_are_sorted=True, unique_indices=True, mode=jax.lax.GatherScatterMode.FILL_OR_DROP)\ 128 | *(-HYB_B3LYP/2)**is_K_matrix 129 | 130 | return diff_JK 131 | 132 | batches = nonzero_indices.shape[0] # Before pmap, tensor had shape (nipus, batches, -1) so [0]=batches after pmap 133 | diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK) 134 | return diff_JK 135 | 136 | diff_JK = jax.lax.fori_loop(0, 16, iteration, diff_JK) 137 | return jax.lax.psum(diff_JK, axis_name="p") 138 | 139 | if __name__ == "__main__": 140 | import time 141 | import argparse 142 | parser = argparse.ArgumentParser(prog='', description='', epilog='') 143 | parser.add_argument('-backend', default="cpu"), 144 | parser.add_argument('-natm', default=3), 145 | parser.add_argument('-test', action="store_true") 146 | parser.add_argument('-prof', action="store_true") 147 | parser.add_argument('-batches', default=5) 148 | parser.add_argument('-nipu', default=16, type=int) 149 | parser.add_argument('-skip', action="store_true") 150 | 151 | args = parser.parse_args() 152 | backend = args.backend 153 | 154 | natm = int(args.natm) 155 | nipu = int(args.nipu) 156 | if backend == "cpu": nipu = 1 157 | 158 | start = time.time() 159 | 160 | mol = pyscf.gto.Mole(atom="".join(f"C 0 {1.54*j} {1.54*i};" for i in range(natm) for j in range(natm))) 161 | #mol = pyscf.gto.Mole(atom="".join(f"C 0 {15.4*j} {15.4*i};" for i in range(1) for j in range(75))) 162 | mol.build() 163 | N = mol.nao_nr() 164 | print("N %i"%mol.nao_nr()) 165 | print("NxN:", (N**2, N**2)) 166 | print("Naive operations: ", N**4*2/10**9, "[Giga]") 167 | if not args.skip: dense_ERI = mol.intor("int2e_sph", aosym="s1") 168 | distinct_ERI = mol.intor("int2e_sph", aosym="s8") 169 | distinct_ERI[np.abs(distinct_ERI)<1e-9] = 0 # zero out stuff 170 | dm = pyscf.scf.hf.init_guess_by_minao(mol) 171 | scale = HYB_B3LYP/2 172 | if not args.skip: 173 | J = np.einsum("ijkl,ji->kl", dense_ERI, dm) 174 | K = np.einsum("ijkl,jk->il", dense_ERI, dm) 175 | truth = J - K / 2 * HYB_B3LYP 176 | 177 | nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.uint64) 178 | nonzero_distinct_ERI = distinct_ERI[nonzero_indices].astype(np.float32) 179 | print("Nonzero Operations:", nonzero_indices.size*8*2/10**9, "[Giga]") 180 | ij, kl = get_i_j(nonzero_indices) 181 | rep = num_repetitions_fast(ij, kl) 182 | nonzero_distinct_ERI = nonzero_distinct_ERI / rep 183 | dm = dm.reshape(-1) 184 | diff_JK = np.zeros(dm.shape) 185 | 186 | batches = int(args.batches) # perhaps make 10 batches? 
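# The index arrays built below are stored as int16 and bitcast to float16 before the pmap;
# bitcast_convert_type is bit-exact, so the kernel recovers the integers with the inverse
# bitcast. (Presumably this halves transfer size versus int32 indices; the orbital indices
# fit in int16 because N is small here.)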
187 |     remainder = nonzero_indices.shape[0] % (nipu*batches)
188 | 
189 |     if remainder != 0:
190 |         print(nipu*batches-remainder, ij.shape)
191 |         ij = np.pad(ij, ((0,nipu*batches-remainder)))
192 |         kl = np.pad(kl, ((0,nipu*batches-remainder)))
193 |         nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0,nipu*batches-remainder))
194 | 
195 |     ij = ij.reshape(nipu, batches, -1)
196 |     kl = kl.reshape(nipu, batches, -1)
197 |     nonzero_distinct_ERI = nonzero_distinct_ERI.reshape(nipu, batches, -1)
198 | 
199 |     i, j = get_i_j(ij.reshape(-1))
200 |     k, l = get_i_j(kl.reshape(-1))
201 |     nonzero_indices = np.vstack([i,j,k,l]).T.reshape(nipu, batches, -1, 4).astype(np.int16)
202 |     nonzero_indices = jax.lax.bitcast_convert_type(nonzero_indices, np.float16)
203 | 
204 |     diff_JK = jax.pmap(sparse_symmetric_einsum, in_axes=(0,0,None,None), static_broadcasted_argnums=(3,), backend=backend, axis_name="p")(nonzero_distinct_ERI, nonzero_indices, dm, args.backend)
205 | 
206 |     if args.skip:
207 |         exit()
208 |     if args.nipu > 1:
209 |         diff_JK = np.array(diff_JK[0])
210 | 
211 |     diff_JK = diff_JK.reshape(N, N)
212 |     print(diff_JK.reshape(-1)[::51])
213 |     print(truth.reshape(-1)[::51])
214 |     print(np.max(np.abs(diff_JK.reshape(-1) - truth.reshape(-1))))
215 |     assert np.allclose(diff_JK, truth, atol=1e-6)
216 |     print("PASSED!")
-------------------------------------------------------------------------------- /pyscf_ipu/nanoDFT/symmetric_ERI.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | import numpy as np
3 | import jax.numpy as jnp
4 | import os
5 | import pyscf
6 | import jax
7 | jax.config.update('jax_enable_x64', True)
8 | jax.config.update('jax_platform_name', "cpu")
9 | os.environ['OMP_NUM_THREADS'] = "16"
10 | 
11 | # Construct molecule we can use for test case.
12 | mol = pyscf.gto.Mole(atom=[["C", (0, 0, i)] for i in range(8)], basis="sto3g")
13 | mol.build()
14 | N = mol.nao_nr() # N: number of atomic orbitals (AO)
15 | density_matrix = pyscf.scf.hf.init_guess_by_minao(mol) # (N, N)
16 | ERI = mol.intor("int2e_sph") # (N, N, N, N)
17 | print(ERI.shape, density_matrix.shape)
18 | 
19 | # ERI satisfies the following symmetry where we can interchange (k,l) with (l,k)
20 | # ERI[i,j,k,l]=ERI[i,j,l,k]
21 | i,j,k,l = 5,10,20,25
22 | print(ERI[i,j,k,l], ERI[i,j,l,k])
23 | assert np.allclose(ERI[i,j,k,l], ERI[i,j,l,k])
24 | 
25 | # It turns out all of the following eight indexings can be interchanged!
26 | # ERI[ijkl]=ERI[ijlk]=ERI[jikl]=ERI[jilk]=ERI[klij]=ERI[klji]=ERI[lkij]=ERI[lkji]
27 | print(ERI[i,j,k,l], ERI[i,j,l,k], ERI[j,i,k,l], ERI[j,i,l,k], ERI[k,l,i,j], ERI[k,l,j,i], ERI[l,k,i,j], ERI[l,k,j,i])
28 | 
29 | # Recall sparse_ERI.py uses the following matrix vector multiplication.
30 | ERI = ERI.reshape(N**2, N**2)
31 | density_matrix = density_matrix.reshape(N**2)
32 | truth = ERI @ density_matrix
33 | 
34 | # But most of ERI are zeros! We therefore use a sparse matrix multiplication (see below).
35 | print(f"The matrix is {np.around(np.sum(ERI == 0)/ERI.size*100, 2)}% zeros!")
36 | 
37 | def sparse_representation(ERI):
38 |     rows, cols = np.nonzero(ERI)
39 |     values = ERI[rows, cols]
40 |     return rows, cols, values
41 | 
42 | def sparse_mult(sparse, vector):
43 |     rows, cols, values = sparse
44 |     in_ = vector.take(cols, axis=0)
45 |     prod = in_*values
46 |     segment_sum = jax.ops.segment_sum(prod, rows, N**2)
47 |     return segment_sum
48 | 
49 | sparse_ERI = sparse_representation(ERI)
50 | res = jax.jit(sparse_mult, backend="cpu")(sparse_ERI, density_matrix)
51 | 
52 | assert np.allclose(truth, res)
53 | 
54 | # Here's the problem.
55 | # Most entries in ERI are repeated 8 times due to ERI[i,j,k,l]=ERI[i,j,l,k]=...
56 | # We can therefore save 8x memory by only storing each element once!
57 | # When we do matrix vector multiplication and need ERI[i,j,l,k] we then just look at ERI[i,j,k,l] instead.
58 | # This is what the ipu_einsum does in nanoDFT.
59 | # After looking at sparse_ERI.py, I think we may be able to do this using the same take/segment_sum tricks!
60 | # Because of this I think we may get a sufficiently efficient implementation in JAX (we may later choose to move it to Poplar).
61 | 
62 | # Potentially tricky part:
63 | # My first idea was to make a dictionary that maps ijkl -> ijlk (and so on).
64 | # We should do this, it's a useful exercise. Unfortunately, this will take as much memory as storing ERI,
65 | # so we won't win anything. We thus have to instead compute the ijkl translation on the fly.
66 | # A further caveat is that we'll need to sequentialize; if we compute the segment sum over the entire thing it will also take too much memory.
-------------------------------------------------------------------------------- /pyscf_ipu/pyscf_utils/build_grid.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
2 | import numpy as np
3 | import numpy
4 | from pyscf import gto
5 | 
6 | GROUP_BOX_SIZE = 1.2
7 | GROUP_BOUNDARY_PENALTY = 4.2
8 | def arg_group_grids(mol, coords, box_size=GROUP_BOX_SIZE):
9 |     '''
10 |     Partition the entire space into small boxes according to the input box_size.
11 |     Group the grids against these boxes.
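    Returns an argsort over the grid points that orders them box by box, so
    spatially close points end up adjacent in memory when coords/weights are
    reordered in build_grid below.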
12 | ''' 13 | import numpy 14 | atom_coords = mol.atom_coords() 15 | boundary = [atom_coords.min(axis=0) - GROUP_BOUNDARY_PENALTY, atom_coords.max(axis=0) + GROUP_BOUNDARY_PENALTY] 16 | # how many boxes inside the boundary 17 | boxes = ((boundary[1] - boundary[0]) * (1./box_size)).round().astype(int) 18 | tot_boxes = numpy.prod(boxes + 2) 19 | #logger.debug(mol, 'tot_boxes %d, boxes in each direction %s', tot_boxes, boxes) 20 | # box_size is the length of each edge of the box 21 | box_size = (boundary[1] - boundary[0]) / boxes 22 | frac_coords = (coords - boundary[0]) * (1./box_size) 23 | box_ids = numpy.floor(frac_coords).astype(int) 24 | box_ids[box_ids<-1] = -1 25 | box_ids[box_ids[:,0] > boxes[0], 0] = boxes[0] 26 | box_ids[box_ids[:,1] > boxes[1], 1] = boxes[1] 27 | box_ids[box_ids[:,2] > boxes[2], 2] = boxes[2] 28 | rev_idx, counts = numpy.unique(box_ids, axis=0, return_inverse=True, return_counts=True)[1:3] 29 | return rev_idx.argsort(kind='stable') 30 | 31 | from pyscf.dft import radi 32 | import numpy 33 | 34 | def original_becke(g): 35 | '''Becke, JCP 88, 2547 (1988); DOI:10.1063/1.454033''' 36 | g = (3 - g**2) * g * .5 37 | g = (3 - g**2) * g * .5 38 | g = (3 - g**2) * g * .5 39 | return g 40 | 41 | 42 | # TODO: refactor this to be jnp 43 | # will be easy to rerwite, main problem will be figuring out how to have it nicely interact with jax.jit 44 | def _get_partition(mol, atom_grids_tab, 45 | radii_adjust=None, atomic_radii=radi.BRAGG_RADII, 46 | becke_scheme=original_becke, concat=True): 47 | '''Generate the mesh grid coordinates and weights for DFT numerical integration. 48 | We can change radii_adjust, becke_scheme functions to generate different meshgrid. 49 | 50 | Kwargs: 51 | concat: bool 52 | Whether to concatenate grids and weights in return 53 | 54 | Returns: 55 | grid_coord and grid_weight arrays. grid_coord array has shape (N,3); 56 | weight 1D array has N elements. 
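    The returned weights implement Becke's fuzzy-cell partitioning: each atom's
    quadrature weights are scaled by pbecke[ia] / pbecke.sum(axis=0), so space
    covered by several overlapping atomic grids is integrated exactly once.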
57 | ''' 58 | if callable(radii_adjust) and atomic_radii is not None: 59 | f_radii_adjust = radii_adjust(mol, atomic_radii) 60 | else: 61 | f_radii_adjust = None 62 | atm_coords = numpy.asarray(mol.atom_coords() , order='C') 63 | atm_dist = gto.inter_distance(mol) 64 | 65 | from pyscf import lib 66 | 67 | def gen_grid_partition(coords): 68 | ngrids = coords.shape[0] 69 | #grid_dist = numpy.empty((mol.natm,ngrids)) 70 | grid_dist = numpy.empty((mol.natm,ngrids)) 71 | for ia in range(mol.natm): 72 | dc = coords - atm_coords[ia] 73 | grid_dist[ia] = numpy.sqrt(numpy.einsum('ij,ij->i',dc,dc)) 74 | pbecke = numpy.ones((mol.natm,ngrids)) 75 | for i in range(mol.natm): 76 | for j in range(i): 77 | g = 1/atm_dist[i,j] * (grid_dist[i]-grid_dist[j]) 78 | if f_radii_adjust is not None: 79 | g = f_radii_adjust(i, j, g) 80 | #g = becke_scheme(g)# gets passed the one which returns None 81 | g = original_becke(g) 82 | #print(g) 83 | pbecke[i] *= .5 * (1-g) 84 | pbecke[j] *= .5 * (1+g) 85 | return pbecke 86 | 87 | coords_all = [] 88 | weights_all = [] 89 | for ia in range(mol.natm): 90 | coords, vol = atom_grids_tab[mol.atom_symbol(ia)] 91 | coords = coords + atm_coords[ia] 92 | pbecke = gen_grid_partition(coords) 93 | weights = vol * pbecke[ia] * (1./pbecke.sum(axis=0)) 94 | coords_all.append(coords) 95 | weights_all.append(weights) 96 | 97 | if concat: 98 | coords_all = numpy.vstack(coords_all) 99 | weights_all = numpy.hstack(weights_all) 100 | return coords_all, weights_all 101 | 102 | def get_partition(self, mol, atom_grids_tab=None, 103 | radii_adjust=None, atomic_radii=radi.BRAGG_RADII, 104 | becke_scheme=original_becke, concat=True): 105 | if atom_grids_tab is None: 106 | atom_grids_tab = self.gen_atomic_grids(mol) 107 | return _get_partition(mol, atom_grids_tab, radii_adjust, atomic_radii, becke_scheme, concat=concat) 108 | 109 | def build_grid(self): 110 | with_non0tab=False 111 | sort_grids=True 112 | mol = self.mol 113 | 114 | atom_grids_tab = self.gen_atomic_grids( mol, self.atom_grid, self.radi_method, self.level, self.prune) 115 | self.coords, self.weights = get_partition(self, mol, atom_grids_tab, self.radii_adjust, self.atomic_radii, self.becke_scheme) 116 | idx = arg_group_grids(mol, self.coords) 117 | self.coords = self.coords[idx] 118 | self.weights = self.weights[idx] 119 | 120 | if self.alignment > 1: 121 | def _padding_size(ngrids, alignment): 122 | if alignment <= 1: 123 | return 0 124 | return (ngrids + alignment - 1) // alignment * alignment - ngrids 125 | 126 | padding = _padding_size(self.size, self.alignment) 127 | #logger.debug(self, 'Padding %d grids', padding) 128 | if padding > 0: 129 | self.coords = numpy.vstack( 130 | [self.coords, numpy.repeat([[1e4]*3], padding, axis=0)]) 131 | self.weights = numpy.hstack([self.weights, numpy.zeros(padding)]) 132 | 133 | self.screen_index = self.non0tab = None 134 | 135 | return self -------------------------------------------------------------------------------- /pyscf_ipu/pyscf_utils/build_mol.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
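# Trimmed-down re-implementation of pyscf's Mole.build: it reproduces the construction of
# the _atm/_bas/_env integral environment, basis/ECP formatting and spin checks, with the
# garbage collection, input dumping and logging of the original disabled or commented out.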
2 | import sys 3 | 4 | 5 | def build_mol(self, dump_input=True, parse_arg=True, 6 | verbose=None, output=None, max_memory=None, 7 | atom=None, basis=None, unit=None, nucmod=None, ecp=None, 8 | charge=None, spin=0, symmetry=None, symmetry_subgroup=None, 9 | cart=None, magmom=None): 10 | 11 | if sys.version_info >= (3,): 12 | unicode = str 13 | #print(unicode) 14 | #print("ASD") 15 | #exit() 16 | from pyscf import __config__ 17 | DISABLE_GC = getattr(__config__, 'DISABLE_GC', False) 18 | 19 | if not DISABLE_GC and False: 20 | gc.collect() # To release circular referred objects 21 | pass 22 | 23 | if isinstance(dump_input, (str, unicode)): 24 | sys.stderr.write('Assigning the first argument %s to mol.atom\n' % 25 | dump_input) 26 | dump_input, atom = True, dump_input 27 | 28 | if verbose is not None: self.verbose = verbose 29 | if output is not None: self.output = output 30 | if max_memory is not None: self.max_memory = max_memory 31 | if atom is not None: self.atom = atom 32 | if basis is not None: self.basis = basis 33 | if unit is not None: self.unit = unit 34 | if nucmod is not None: self.nucmod = nucmod 35 | if ecp is not None: self.ecp = ecp 36 | if charge is not None: self.charge = charge 37 | if spin != 0: self.spin = spin 38 | if symmetry is not None: self.symmetry = symmetry 39 | if symmetry_subgroup is not None: self.symmetry_subgroup = symmetry_subgroup 40 | if cart is not None: self.cart = cart 41 | if magmom is not None: self.magmom = magmom 42 | 43 | def _update_from_cmdargs_(mol): 44 | try: 45 | # Detect whether in Ipython shell 46 | __IPYTHON__ # noqa: 47 | return 48 | except Exception: 49 | pass 50 | 51 | if not mol._built: # parse cmdline args only once 52 | opts = cmd_args.cmd_args() 53 | 54 | if opts.verbose: 55 | mol.verbose = opts.verbose 56 | if opts.max_memory: 57 | mol.max_memory = opts.max_memory 58 | 59 | if opts.output: 60 | mol.output = opts.output 61 | 62 | 63 | self._atom = self.format_atom(self.atom, unit=self.unit) 64 | uniq_atoms = set([a[0] for a in self._atom]) 65 | 66 | if isinstance(self.basis, (str, unicode, tuple, list)): 67 | # specify global basis for whole molecule 68 | _basis = dict(((a, self.basis) for a in uniq_atoms)) 69 | elif 'default' in self.basis: 70 | default_basis = self.basis['default'] 71 | _basis = dict(((a, default_basis) for a in uniq_atoms)) 72 | _basis.update(self.basis) 73 | del (_basis['default']) 74 | else: 75 | _basis = self.basis 76 | self._basis = self.format_basis(_basis) 77 | 78 | # TODO: Consider ECP info in point group symmetry initialization 79 | if self.ecp: 80 | # Unless explicitly input, ECP should not be assigned to ghost atoms 81 | if isinstance(self.ecp, (str, unicode)): 82 | _ecp = dict([(a, str(self.ecp)) 83 | for a in uniq_atoms if not is_ghost_atom(a)]) 84 | elif 'default' in self.ecp: 85 | default_ecp = self.ecp['default'] 86 | _ecp = dict(((a, default_ecp) 87 | for a in uniq_atoms if not is_ghost_atom(a))) 88 | _ecp.update(self.ecp) 89 | del (_ecp['default']) 90 | else: 91 | _ecp = self.ecp 92 | self._ecp = self.format_ecp(_ecp) 93 | 94 | PTR_ENV_START = 20 95 | env = self._env[:PTR_ENV_START] 96 | self._atm, self._bas, self._env = \ 97 | self.make_env(self._atom, self._basis, env, self.nucmod, 98 | self.nucprop) 99 | self._atm, self._ecpbas, self._env = \ 100 | self.make_ecp_env(self._atm, self._ecp, self._env) 101 | 102 | if self.spin is None: 103 | self.spin = self.nelectron % 2 104 | else: 105 | # Access self.nelec in which the code checks whether the spin and 106 | # number of electrons are consistent. 
107 | self.nelec 108 | 109 | if not self.magmom: 110 | self.magmom = [0.,]*self.natm 111 | import numpy 112 | if self.spin == 0 and abs(numpy.sum(numpy.asarray(self.magmom)) - self.spin) > 1e-6: 113 | #don't check for unrestricted calcs. 114 | raise ValueError("mol.magmom is set incorrectly.") 115 | 116 | if self.symmetry: 117 | self._build_symmetry() 118 | 119 | #if dump_input and not self._built and self.verbose > logger.NOTE: 120 | # self.dump_input() 121 | 122 | '''if self.verbose >= logger.DEBUG3: 123 | logger.debug3(self, 'arg.atm = %s', self._atm) 124 | logger.debug3(self, 'arg.bas = %s', self._bas) 125 | logger.debug3(self, 'arg.env = %s', self._env) 126 | logger.debug3(self, 'ecpbas = %s', self._ecpbas)''' 127 | 128 | self._built = True 129 | return self -------------------------------------------------------------------------------- /pyscf_ipu/pyscf_utils/minao.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from pyscf import gto 3 | import numpy as np 4 | 5 | def minao(mol): 6 | from pyscf.scf import atom_hf 7 | from pyscf.scf import addons 8 | import time 9 | 10 | times = [] 11 | times.append(time.time()) 12 | 13 | def minao_basis(symb, nelec_ecp): 14 | occ = [] 15 | basis_ano = [] 16 | if gto.is_ghost_atom(symb): 17 | return occ, basis_ano 18 | 19 | stdsymb = gto.mole._std_symbol(symb) 20 | basis_add = gto.basis.load('ano', stdsymb) 21 | # coreshl defines the core shells to be removed in the initial guess 22 | coreshl = gto.ecp.core_configuration(nelec_ecp) 23 | #coreshl = (0,0,0,0) # it keeps all core electrons in the initial guess 24 | for l in range(4): 25 | ndocc, frac = atom_hf.frac_occ(stdsymb, l) 26 | assert ndocc >= coreshl[l] 27 | degen = l * 2 + 1 28 | occ_l = [2,]*(ndocc-coreshl[l]) + [frac,] 29 | occ.append(np.repeat(occ_l, degen)) 30 | basis_ano.append([l] + [b[:1] + b[1+coreshl[l]:ndocc+2] 31 | for b in basis_add[l][1:]]) 32 | occ = np.hstack(occ) 33 | 34 | if nelec_ecp > 0: 35 | if symb in mol._basis: 36 | input_basis = mol._basis[symb] 37 | elif stdsymb in mol._basis: 38 | input_basis = mol._basis[stdsymb] 39 | else: 40 | raise KeyError(symb) 41 | 42 | basis4ecp = [[] for i in range(4)] 43 | for bas in input_basis: 44 | l = bas[0] 45 | if l < 4: 46 | basis4ecp[l].append(bas) 47 | 48 | occ4ecp = [] 49 | for l in range(4): 50 | nbas_l = sum((len(bas[1]) - 1) for bas in basis4ecp[l]) 51 | ndocc, frac = atom_hf.frac_occ(stdsymb, l) 52 | ndocc -= coreshl[l] 53 | assert ndocc <= nbas_l 54 | 55 | occ_l = np.zeros(nbas_l) 56 | occ_l[:ndocc] = 2 57 | if frac > 0: 58 | occ_l[ndocc] = frac 59 | occ4ecp.append(np.repeat(occ_l, l * 2 + 1)) 60 | 61 | occ4ecp = np.hstack(occ4ecp) 62 | basis4ecp = lib.flatten(basis4ecp) 63 | 64 | atm1 = gto.Mole() 65 | atm2 = gto.Mole() 66 | atom = [[symb, (0.,0.,0.)]] 67 | atm1._atm, atm1._bas, atm1._env = atm1.make_env(atom, {symb:basis4ecp}, []) 68 | atm2._atm, atm2._bas, atm2._env = atm2.make_env(atom, {symb:basis_ano}, []) 69 | atm1._built = True 70 | atm2._built = True 71 | s12 = gto.intor_cross('int1e_ovlp', atm1, atm2) 72 | if abs(np.linalg.det(s12[occ4ecp>0][:,occ>0])) > .1: 73 | occ, basis_ano = occ4ecp, basis4ecp 74 | else: 75 | logger.debug(mol, 'Density of valence part of ANO basis ' 76 | 'will be used as initial guess for %s', symb) 77 | return occ, basis_ano 78 | 79 | # Issue 548 80 | if any(gto.charge(mol.atom_symbol(ia)) > 96 for ia in range(mol.natm)): 81 | logger.info(mol, 'MINAO initial guess is not available for 
super-heavy '
82 |                     'elements. "atom" initial guess is used.')
83 |         return init_guess_by_atom(mol)
84 | 
85 | 
86 |     times.append(time.time())
87 |     nelec_ecp_dic = dict([(mol.atom_symbol(ia), mol.atom_nelec_core(ia))
88 |                           for ia in range(mol.natm)])
89 |     times.append(time.time())
90 | 
91 |     basis = {}
92 |     occdic = {}
93 |     for symb, nelec_ecp in nelec_ecp_dic.items():
94 |         occ_add, basis_add = minao_basis(symb, nelec_ecp)
95 |         occdic[symb] = occ_add
96 |         basis[symb] = basis_add
97 | 
98 |     times.append(time.time())
99 | 
100 |     occ = []
101 |     new_atom = []
102 |     for ia in range(mol.natm):
103 |         #print(ia)
104 |         symb = mol.atom_symbol(ia)
105 |         if not gto.is_ghost_atom(symb):
106 |             occ.append(occdic[symb])
107 |             new_atom.append(mol._atom[ia])
108 |     occ = np.hstack(occ)
109 | 
110 |     times.append(time.time())
111 | 
112 |     pmol = gto.Mole()
113 |     times.append(time.time())
114 |     pmol._atm, pmol._bas, pmol._env = pmol.make_env(new_atom, basis, [])
115 |     times.append(time.time())
116 |     pmol._built = True
117 |     dm = addons.project_dm_nr2nr(pmol, np.diag(occ), mol)
118 | 
119 |     times.append(time.time())
120 | 
121 |     times = np.array(times)
122 |     # print("Minao timing:", np.around(times[1:]-times[:-1], 2))
123 | 
124 |     return dm
-------------------------------------------------------------------------------- /qm1b/README.md: --------------------------------------------------------------------------------
1 | # QM1B dataset (to be released soon)
2 | 
3 | QM1B is a low-resolution DFT dataset generated using [PySCF IPU](https://github.com/graphcore-research/pyscf-ipu). It is composed of one billion training examples containing 9-11 heavy atoms. It was created by taking 1.09M SMILES strings from the [GDB-11 database](https://zenodo.org/record/5172018) and computing molecular properties (e.g. HOMO-LUMO gap) for a set of up to 1000 conformers per molecule.
4 | 
5 | ## Dataset schema
6 | 
7 | The QM1B dataset is stored in the [open-source columnar Apache Parquet format](https://parquet.apache.org/), with the following schema:
8 | * `smile`: The SMILES string taken from GDB11. There are up to 1000 rows (i.e. conformers) with the same SMILES
9 | string.
10 | * `atoms`: String representing the atom symbols of the molecule, e.g. ”COOH”.
11 | * `z`: Integer representation of `atoms` used by SchNet (the atomic numbers).
12 | * `energy`: energy of the molecule computed by PySCF IPU (unit eV).
13 | * `homo`: The energy of the Highest Occupied Molecular Orbital (HOMO) (unit eV).
14 | * `lumo`: The energy of the Lowest Unoccupied Molecular Orbital (LUMO) (unit eV).
15 | * `N`: The number of atomic orbitals for the specific DFT computation (depends on the basis set STO3G).
16 | * `std`: The standard deviation of the energy over the last five iterations of running PySCF IPU, used as the
17 | convergence criterion std < 0.01 (unit eV).
18 | * `y`: The HOMO-LUMO gap (unit eV).
19 | * `pos`: The atom positions (unit Bohr).
20 | 
21 | ## Dataset exploration
22 | 
23 | Dataset exploration can easily be done using the Pandas library. For instance, to load the validation set:
24 | ```python
25 | import pandas as pd
26 | 
27 | # 20m entries in the validation set.
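# Sanity-check sketch (assuming the schema above, where y is the HOMO-LUMO gap,
# so y should equal lumo - homo for every row):
# df = pd.read_parquet("qm1b_val.parquet")
# print((df["lumo"] - df["homo"] - df["y"]).abs().max())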
28 | print(pd.read_parquet("qm1b_val.parquet").head()) 29 | ``` -------------------------------------------------------------------------------- /qm1b/datasheet.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/pyscf-ipu/43a9bb343acbd296aec1d1f64c309bb24920d98d/qm1b/datasheet.pdf -------------------------------------------------------------------------------- /requirements_core.txt: -------------------------------------------------------------------------------- 1 | # Core dependencies for pyscf-ipu 2 | # 3 | # See also: 4 | # requirements_cpu.txt for cpu backend configuration 5 | # requirements_ipu.txt for ipu backend configuration 6 | # requirements_test.txt for test-only dependencies 7 | numpy 8 | matplotlib 9 | pandas 10 | scipy 11 | h5py 12 | pubchempy 13 | pyscf==2.2.1 14 | icecream 15 | seaborn 16 | tqdm 17 | natsort 18 | rdkit 19 | jsonargparse[all] 20 | 21 | mogli 22 | imageio[ffmpeg] 23 | py3Dmol 24 | basis-set-exchange 25 | periodictable 26 | sympy 27 | 28 | # silence warnings about setuptools + numpy 29 | setuptools < 60.0 30 | 31 | jaxtyping==0.2.8 32 | chex 33 | -------------------------------------------------------------------------------- /requirements_cpu.txt: -------------------------------------------------------------------------------- 1 | # Runtime dependencies for pyscf-ipu with CPU backend 2 | # 3 | # See also: 4 | # requirements_core.txt for core configuration 5 | # requirements_ipu.txt for ipu backend configuration 6 | # requirements_test.txt for test-only dependencies 7 | 8 | jax==0.3.16 9 | jaxlib @ https://storage.googleapis.com/jax-releases/nocuda/jaxlib-0.3.15-cp38-none-manylinux2014_x86_64.whl 10 | -------------------------------------------------------------------------------- /requirements_ipu.txt: -------------------------------------------------------------------------------- 1 | # Runtime dependencies for pyscf-ipu with IPU backend 2 | # 3 | # See also: 4 | # requirements_core.txt for core runtime configuration 5 | # requirements_cpu.txt for cpu backend configuration 6 | # requirements_test.txt for test-only dependencies 7 | jax@https://github.com/graphcore-research/jax-experimental/releases/download/jax-v0.3.16-ipu-beta3-sdk3/jax-0.3.16+ipu-py3-none-any.whl 8 | jaxlib@https://github.com/graphcore-research/jax-experimental/releases/download/jax-v0.3.16-ipu-beta3-sdk3/jaxlib-0.3.15+ipu.sdk320-cp38-none-manylinux2014_x86_64.whl 9 | tessellate-ipu@git+https://github.com/graphcore-research/tessellate-ipu.git@main 10 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | # Test dependencies for pyscf-ipu 2 | # 3 | # See also: 4 | # requirements_core.txt for core runtime configuration 5 | # requirements_cpu.txt for cpu backend configuration 6 | # requirements_ipu.txt for ipu backend configuration 7 | black[jupyter] 8 | pytest 9 | nbmake 10 | pre-commit 11 | flake8 12 | flake8-copyright 13 | isort 14 | -------------------------------------------------------------------------------- /schnet_9m/README.md: -------------------------------------------------------------------------------- 1 | # Training SchNet on QM1B (to be released) 2 | 3 | This repository contains the implementation of the SchNet 9M trained on the QM1B 4 | dataset. 
We show that training a SchNet 9M model to predict the HL (HOMO-LUMO) gap keeps improving as
5 | the number of training samples approaches 500M.
6 | 
7 | ![scaling_qm1b](./scaling_qm1b.png)
8 | 
9 | ## Requirements
10 | This project requires Python 3.8 and Graphcore SDK 3.2. Additional
11 | dependencies can be installed into your environment with
12 | ```bash
13 | pip install -r requirements.txt
14 | ```
15 | 
16 | ## Training SchNet
17 | 
18 | ```bash
19 | python train.py
20 | ```
21 | 
22 | ## Complete Training Usage
23 | Complete usage is documented below.
24 | 
25 | ```bash
26 | python train.py --help
27 | usage: train.py [-h] [--config CONFIG] [--print_config[=flags]] [--seed SEED] [--learning_rate LEARNING_RATE]
28 |                 [--learning_rate_decay LEARNING_RATE_DECAY] [--update_lr_period UPDATE_LR_PERIOD]
29 |                 [--val_period VAL_PERIOD] [--pop_config CONFIG] [--pop_config.device_iterations DEVICE_ITERATIONS]
30 |                 [--pop_config.replication_factor REPLICATION_FACTOR]
31 |                 [--pop_config.gradient_accumulation GRADIENT_ACCUMULATION] [--pop_config.optimize_popart {true,false}]
32 |                 [--pop_config.cache_dir CACHE_DIR] [--pop_config.quiet {true,false}]
33 |                 [--pop_config.offload_optimizer_state {true,false}] [--pop_config.pipeline_splits PIPELINE_SPLITS]
34 |                 [--pop_config.available_memory_proportion AVAILABLE_MEMORY_PROPORTION]
35 |                 [--pop_config.use_stochastic_rounding {true,false}] [--data_config CONFIG] [--data_config.dataset DATASET]
36 |                 [--data_config.num_workers NUM_WORKERS] [--data_config.num_train NUM_TRAIN]
37 |                 [--data_config.num_test NUM_TEST] [--data_config.max_num_examples MAX_NUM_EXAMPLES]
38 |                 [--data_config.root_folder ROOT_FOLDER] [--data_config.shuffle {true,false}] [--model_config CONFIG]
39 |                 [--model_config.num_features NUM_FEATURES] [--model_config.num_filters NUM_FILTERS]
40 |                 [--model_config.num_interactions NUM_INTERACTIONS] [--model_config.num_gaussians NUM_GAUSSIANS]
41 |                 [--model_config.k K] [--model_config.cutoff CUTOFF] [--model_config.batch_size BATCH_SIZE]
42 |                 [--model_config.use_half {true,false}] [--model_config.recomputation_blocks RECOMPUTATION_BLOCKS]
43 |                 [--debug {true,false}] [--wandb_project WANDB_PROJECT] [--use_wandb {true,false}]
44 |                 [--only_compile {true,false}] [--warmup_steps WARMUP_STEPS] [--wandb_warmup {true,false}]
45 | 
46 | Minimal SchNet GNN training
47 | 
48 | optional arguments:
49 |   -h, --help            Show this help message and exit.
50 |   --config CONFIG       Path to a configuration file.
51 |   --print_config[=flags]
52 |                         Print the configuration after applying all other arguments and exit. The optional flags customizes
53 |                         the output and are one or more keywords separated by comma. The supported flags are: comments,
54 |                         skip_default, skip_null.
55 |   --seed SEED           the random number seed (type: int, default: 0)
56 |   --learning_rate LEARNING_RATE
57 |                         the learning rate used by the optimizer (type: float, default: 0.0005)
58 |   --learning_rate_decay LEARNING_RATE_DECAY
59 |                         exponential ratio to reduce the learning rate by (type: float, default: 0.96)
60 |   --update_lr_period UPDATE_LR_PERIOD
61 |                         the number of steps between learning rate updates.
Default: 32*300*batch_size = 36M so lr is 62 | increasing 0 to 4M warmup, then decreasing every 36M steps (1000/36 ~ 30 times) (type: int, 63 | default: 9600) 64 | --val_period VAL_PERIOD 65 | number of training steps before performing validation (type: int, default: 4000) 66 | --debug {true,false} enables additional logging (with perf overhead) (type: bool, default: False) 67 | --wandb_project WANDB_PROJECT 68 | (type: str, default: schnet-9m) 69 | --use_wandb {true,false} 70 | Use Weights and Biases to log benchmark results. (type: bool, default: True) 71 | --only_compile {true,false} 72 | Compile the and exit (no training) (type: bool, default: False) 73 | --warmup_steps WARMUP_STEPS 74 | set to 0 to turn off. openc uses 2-3 epochs with 2-3M molecules => 4-9M graphs; => steps = 10M / 75 | batchsize ~ --warmup_steps 2500 (type: int, default: 2500) 76 | --wandb_warmup {true,false} 77 | enable wandb logging of warmup steps (type: bool, default: False) 78 | 79 | configuration options for PopTorch: 80 | --pop_config CONFIG Path to a configuration file. 81 | --pop_config.device_iterations DEVICE_ITERATIONS 82 | (type: int, default: 32) 83 | --pop_config.replication_factor REPLICATION_FACTOR 84 | (type: int, default: 16) 85 | --pop_config.gradient_accumulation GRADIENT_ACCUMULATION 86 | (type: int, default: 1) 87 | --pop_config.optimize_popart {true,false} 88 | (type: bool, default: True) 89 | --pop_config.cache_dir CACHE_DIR 90 | (type: str, default: .poptorch_cache) 91 | --pop_config.quiet {true,false} 92 | (type: bool, default: True) 93 | --pop_config.offload_optimizer_state {true,false} 94 | (type: bool, default: False) 95 | --pop_config.pipeline_splits PIPELINE_SPLITS, --pop_config.pipeline_splits+ PIPELINE_SPLITS 96 | (type: Union[List[int], null], default: null) 97 | --pop_config.available_memory_proportion AVAILABLE_MEMORY_PROPORTION 98 | (type: float, default: 0.6) 99 | --pop_config.use_stochastic_rounding {true,false} 100 | (type: bool, default: True) 101 | 102 | options for data subsetting and loading: 103 | --data_config CONFIG Path to a configuration file. 104 | --data_config.dataset DATASET 105 | (type: str, default: qm9) 106 | --data_config.num_workers NUM_WORKERS 107 | (type: int, default: 1) 108 | --data_config.num_train NUM_TRAIN 109 | (type: int, default: 100000) 110 | --data_config.num_test NUM_TEST 111 | (type: int, default: 20000) 112 | --data_config.max_num_examples MAX_NUM_EXAMPLES 113 | (type: int, default: 1000000000) 114 | --data_config.root_folder ROOT_FOLDER 115 | (type: Union[str, null], default: null) 116 | --data_config.shuffle {true,false} 117 | (type: bool, default: True) 118 | 119 | model arcitecture options: 120 | --model_config CONFIG 121 | Path to a configuration file. 
122 | --model_config.num_features NUM_FEATURES 123 | (type: int, default: 1024) 124 | --model_config.num_filters NUM_FILTERS 125 | (type: int, default: 256) 126 | --model_config.num_interactions NUM_INTERACTIONS 127 | (type: int, default: 5) 128 | --model_config.num_gaussians NUM_GAUSSIANS 129 | (type: int, default: 200) 130 | --model_config.k K (type: int, default: 28) 131 | --model_config.cutoff CUTOFF 132 | (type: float, default: 15.0) 133 | --model_config.batch_size BATCH_SIZE 134 | (type: int, default: 8) 135 | --model_config.use_half {true,false} 136 | (type: bool, default: False) 137 | --model_config.recomputation_blocks RECOMPUTATION_BLOCKS, --model_config.recomputation_blocks+ RECOMPUTATION_BLOCKS 138 | (type: Union[List[int], null], default: null) 139 | ``` -------------------------------------------------------------------------------- /schnet_9m/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from dataclasses import dataclass 3 | from functools import partial 4 | from pprint import pprint 5 | from typing import Optional, Tuple 6 | 7 | import torch 8 | from qm1b_dataset import create_qm1b_loader 9 | from model import ModelConfig 10 | from poptorch import Options 11 | from poptorch_geometric.dataloader import CustomFixedSizeDataLoader 12 | from torch_geometric.data import Data, Dataset 13 | from torch_geometric.datasets import QM9 14 | 15 | 16 | @dataclass 17 | class DataConfig: 18 | dataset: str = "qm9" 19 | num_workers: int = 1 20 | num_train: int = 100000 21 | num_test: int = 20000 22 | max_num_examples: int = int(1e9) 23 | root_folder: Optional[str] = None 24 | shuffle: bool = True 25 | 26 | 27 | def prep_qm9(data, target=4, use_half: bool = False): 28 | """ 29 | Prepares QM9 molecules for training SchNet for HOMO-LUMO gap prediction 30 | task. Outputs a data object with attributes: 31 | z: the atomic number as a vector of integers with length [num_atoms] 32 | pos: the atomic position as a [num_atoms, 3] tensor of float32 values. 33 | y: the training target value. By default this will be the HOMO-LUMO gap 34 | energy in electronvolts (eV). 
35 | """ 36 | dtype = torch.float16 if use_half else torch.float32 37 | return Data( 38 | z=data.z, pos=data.pos.to(dtype), y=data.y[0, target].view(-1).to(dtype) 39 | ) 40 | 41 | 42 | def split(dataset: Dataset, data_config: DataConfig): 43 | # Select test set first to ensure we are sampling the same distribution 44 | split_ids = [ 45 | (0, data_config.num_test), 46 | (data_config.num_test, data_config.num_test + data_config.num_train), 47 | ] 48 | return [dataset[u:v] for u, v in split_ids] 49 | 50 | 51 | def create_pyg_dataset(config: DataConfig, use_half: bool = False): 52 | transform = partial(prep_qm9, use_half=use_half) 53 | if config.dataset == "qm9": 54 | root = "datasets/qm9" if config.root_folder is None else config.root_folder 55 | dataset = QM9(root=root, transform=transform) 56 | else: 57 | raise ValueError(f"Invalid dataset: {config.dataset}") 58 | 59 | if config.shuffle: 60 | dataset = dataset.shuffle() 61 | return split(dataset, config) 62 | 63 | 64 | def create_loader( 65 | data_config: DataConfig, 66 | model_config: ModelConfig, 67 | options: Tuple[Options], 68 | ): 69 | if data_config.dataset == "qm1b": 70 | return create_qm1b_loader(data_config, model_config, options) 71 | 72 | data_splits = create_pyg_dataset(data_config, model_config.use_half) 73 | 74 | loader_args = { 75 | "batch_size": model_config.batch_size, 76 | "num_workers": data_config.num_workers, 77 | "persistent_workers": data_config.num_workers > 0, 78 | "num_nodes": 32 * (model_config.batch_size - 1), 79 | "shuffle": data_config.shuffle, 80 | "collater_args": {"add_masks_to_batch": True}, 81 | } 82 | 83 | return [ 84 | CustomFixedSizeDataLoader(d, options=o, **loader_args) 85 | for d, o in zip(data_splits, options) 86 | ] 87 | 88 | 89 | def loader_info(loader, prefix=None): 90 | info = { 91 | "len": len(loader), 92 | "len_dataset": len(loader.dataset), 93 | "num_workers": loader.num_workers, 94 | "drop_last": loader.drop_last, 95 | "batch_size": loader.batch_size, 96 | "sampler": type(loader.sampler).__name__, 97 | } 98 | 99 | if prefix is not None: 100 | info = {f"{prefix}{k}": v for k, v in info.items()} 101 | 102 | pprint(info) 103 | return info 104 | 105 | 106 | def fake_batch( 107 | model_config: ModelConfig, 108 | options: Options, 109 | training: bool = True, 110 | ): 111 | num_graphs_per_batch = model_config.batch_size - 1 112 | combined_batch_size = ( 113 | options.replication_factor 114 | * options.device_iterations 115 | * options.Training.gradient_accumulation 116 | ) 117 | 118 | num_nodes = 32 * num_graphs_per_batch * combined_batch_size 119 | float_dtype = torch.float16 if model_config.use_half else torch.float32 120 | z = torch.zeros(num_nodes, dtype=torch.long) 121 | pos = torch.zeros(num_nodes, 3, dtype=float_dtype) 122 | batch = torch.zeros(num_nodes, dtype=torch.long) 123 | y = torch.zeros(model_config.batch_size * combined_batch_size, dtype=float_dtype) 124 | out = (z, pos, batch, y) 125 | return out if training else out[:-1] 126 | -------------------------------------------------------------------------------- /schnet_9m/device.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | from typing import Optional, List 3 | from torch import Tensor 4 | from torch.nn import Module 5 | from dataclasses import dataclass 6 | from poptorch import ( 7 | BeginBlock, 8 | Options, 9 | OutputMode, 10 | recomputationCheckpoint, 11 | setLogLevel, 12 | TensorLocationSettings, 13 | inferenceModel, 14 | trainingModel, 15 | ) 16 | from poptorch.optim import Optimizer 17 | 18 | 19 | @dataclass 20 | class PopConfig: 21 | device_iterations: int = 32 22 | replication_factor: int = 16 23 | gradient_accumulation: int = 1 24 | optimize_popart: bool = True 25 | cache_dir: str = ".poptorch_cache" 26 | quiet: bool = True 27 | offload_optimizer_state: bool = False 28 | pipeline_splits: Optional[List[int]] = None 29 | available_memory_proportion: float = 0.6 30 | use_stochastic_rounding: bool = True 31 | 32 | 33 | popart_options = { 34 | "defaultBufferingDepth": 4, 35 | "accumulateOuterFragmentSettings.schedule": 2, 36 | "replicatedCollectivesSettings.prepareScheduleForMergingCollectives": True, 37 | "replicatedCollectivesSettings.mergeAllReduceCollectives": True, 38 | } 39 | 40 | 41 | def configure_poptorch( 42 | config: PopConfig, debug: bool, model: Module, optimizer: Optimizer 43 | ):  # returns a (train_model, inference_model) pair of PopTorch executors 44 | options = Options() 45 | options.outputMode(OutputMode.All) 46 | options.deviceIterations(config.device_iterations) 47 | options.replicationFactor(config.replication_factor) 48 | options.Training.gradientAccumulation(config.gradient_accumulation) 49 | 50 | if config.offload_optimizer_state: 51 | options.TensorLocations.setOptimizerLocation( 52 | TensorLocationSettings().useOnChipStorage(False) 53 | ) 54 | 55 | options.Precision.enableStochasticRounding(config.use_stochastic_rounding) 56 | options.Precision.enableFloatingPointExceptions(debug) 57 | 58 | if not debug: 59 | options.enableExecutableCaching(config.cache_dir) 60 | 61 | if config.optimize_popart: 62 | for k, v in popart_options.items(): 63 | options._Popart.set(k, v) 64 | 65 | if config.quiet and not debug: 66 | setLogLevel("ERR") 67 | 68 | num_ipus = 1 69 | 70 | if config.pipeline_splits is not None: 71 | num_ipus = len(config.pipeline_splits) + 1 72 | 73 | options.setAvailableMemoryProportion( 74 | {f"IPU{i}": config.available_memory_proportion for i in range(num_ipus)} 75 | ) 76 | 77 | if config.pipeline_splits is not None: 78 | for index, block in enumerate(config.pipeline_splits): 79 | model.model.interactions[block] = BeginBlock( 80 | model.model.interactions[block], ipu_id=index + 1 81 | ) 82 | 83 | train_model = trainingModel(model, options, optimizer) 84 | options = options.clone() 85 | options.Training.gradientAccumulation(1) 86 | options.deviceIterations(config.device_iterations * config.gradient_accumulation) 87 | inference_model = inferenceModel(model.eval(), options) 88 | return train_model, inference_model 89 | 90 | 91 | def recomputation_checkpoint(module: Module): 92 | """Annotates the output of a module to be checkpointed instead of 93 | recomputed""" 94 | 95 | def recompute_outputs(module, inputs, outputs): 96 | if isinstance(outputs, Tensor): 97 | return recomputationCheckpoint(outputs) 98 | elif isinstance(outputs, tuple): 99 | return tuple(recomputationCheckpoint(y) for y in outputs) 100 | 101 | return module.register_forward_hook(recompute_outputs) 102 | -------------------------------------------------------------------------------- /schnet_9m/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
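# A minimal sketch (not part of the original file) checking the numerically
# stable identity behind FastShiftedSoftplus below:
#   softplus(x) = log(1 + exp(x)) = log1p(exp(-|x|)) + max(x, 0),
# so the shifted activation is that expression minus log(2).
import torch
import torch.nn.functional as F

x = torch.linspace(-20.0, 20.0, steps=101)
shift = torch.log(torch.tensor(2.0))
fast = torch.log1p(torch.exp(-x.abs())) + torch.clamp_min(x, 0.0) - shift
assert torch.allclose(fast, F.softplus(x) - shift, atol=1e-6)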
2 | from typing import Optional, List 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from torch_geometric.nn.models import schnet as pyg_schnet 7 | from torch_geometric.nn import to_fixed_size 8 | from dataclasses import dataclass 9 | from device import recomputation_checkpoint 10 | 11 | 12 | @dataclass 13 | class ModelConfig: 14 | num_features: int = 1024 15 | num_filters: int = 256 16 | num_interactions: int = 5 17 | num_gaussians: int = 200 18 | k: int = 28 19 | cutoff: float = 15.0 20 | batch_size: int = 8 21 | use_half: bool = False 22 | recomputation_blocks: Optional[List[int]] = None 23 | 24 | 25 | class TrainingModule(nn.Module): 26 | def __init__(self, model: nn.Module): 27 | """ 28 | Wrapper that evaluates the forward pass of SchNet followed by MAE loss 29 | """ 30 | super().__init__() 31 | self.model = model 32 | 33 | def forward(self, z, pos, batch, target=None): 34 | prediction = self.model(z, pos, batch).view(-1) 35 | 36 | # slice off the padding molecule 37 | prediction = prediction[0:-1] 38 | 39 | if not self.training: 40 | return prediction 41 | 42 | # Calculate MAE loss after slicing off padding molecule 43 | target = target[0:-1] 44 | return F.l1_loss(prediction, target) 45 | 46 | 47 | class FastShiftedSoftplus(nn.Module): 48 | def __init__(self, needs_cast): 49 | """ 50 | ShiftedSoftplus without the conditional used in native PyTorch softplus 51 | """ 52 | super().__init__() 53 | self.shift = torch.log(torch.tensor(2.0)).item() 54 | self.needs_cast = needs_cast 55 | 56 | def forward(self, x): 57 | x = x.float() if self.needs_cast else x 58 | u = torch.log1p(torch.exp(-x.abs())) 59 | v = torch.clamp_min(x, 0.0) 60 | out = u + v - self.shift 61 | out = out.half() if self.needs_cast else out 62 | return out 63 | 64 | @staticmethod 65 | def replace_activation(module: torch.nn.Module, needs_cast: bool): 66 | """ 67 | recursively find and replace instances of default ShiftedSoftplus 68 | """ 69 | for name, child in module.named_children(): 70 | if isinstance(child, pyg_schnet.ShiftedSoftplus): 71 | setattr(module, name, FastShiftedSoftplus(needs_cast)) 72 | else: 73 | FastShiftedSoftplus.replace_activation(child, needs_cast) 74 | 75 | 76 | class KNNInteractionGraph(torch.nn.Module): 77 | def __init__(self, k: int, cutoff: float = 10.0): 78 | super().__init__() 79 | self.k = k 80 | self.cutoff = cutoff 81 | 82 | def forward(self, pos: torch.Tensor, batch: torch.Tensor): 83 | """ 84 | k-nearest neighbors without dynamic tensor shapes 85 | 86 | :param pos (Tensor): Coordinates of each atom with shape 87 | [num_atoms, 3]. 88 | :param batch (LongTensor): Batch indices assigning each atom to 89 | a separate molecule with shape [num_atoms] 90 | 91 | This method calculates the full num_atoms x num_atoms pairwise distance 92 | matrix. Masking is used to remove: 93 | * self-interaction (the diagonal elements) 94 | * cross-terms (atoms interacting with atoms in different molecules) 95 | * atoms that are beyond the cutoff distance 96 | 97 | Finally topk is used to find the k-nearest neighbors and construct the 98 | edge_index and edge_weight. 
99 | """ 100 | pdist = F.pairwise_distance(pos[:, None], pos, eps=0) 101 | rows = arange_like(batch.shape[0], batch).view(-1, 1) 102 | cols = rows.view(1, -1) 103 | diag = rows == cols 104 | cross = batch.view(-1, 1) != batch.view(1, -1) 105 | outer = pdist > self.cutoff 106 | mask = diag | cross | outer 107 | pdist = pdist.masked_fill(mask, self.cutoff) 108 | edge_weight, indices = torch.topk(-pdist, k=self.k) 109 | rows = rows.expand_as(indices) 110 | edge_index = torch.vstack([indices.flatten(), rows.flatten()]) 111 | return edge_index, -edge_weight.flatten() 112 | 113 | 114 | def arange_like(n: int, ref: torch.Tensor) -> torch.Tensor: 115 | return torch.arange(n, device=ref.device, dtype=ref.dtype) 116 | 117 | 118 | def create_model(config: ModelConfig) -> TrainingModule: 119 | model = pyg_schnet.SchNet( 120 | hidden_channels=config.num_features, 121 | num_filters=config.num_filters, 122 | num_interactions=config.num_interactions, 123 | num_gaussians=config.num_gaussians, 124 | cutoff=config.cutoff, 125 | interaction_graph=KNNInteractionGraph(config.k, config.cutoff), 126 | ) 127 | 128 | model = to_fixed_size(model, config.batch_size) 129 | model = TrainingModule(model) 130 | print_model(model, config) 131 | 132 | FastShiftedSoftplus.replace_activation(model, config.use_half) 133 | if config.use_half: 134 | model = model.half() 135 | 136 | if config.recomputation_blocks is not None: 137 | for block_idx in config.recomputation_blocks: 138 | recomputation_checkpoint(model.model.interactions[block_idx]) 139 | 140 | return model 141 | 142 | 143 | def print_model(model: TrainingModule, config: ModelConfig): 144 | from torchinfo import summary 145 | 146 | num_nodes = 32 * (config.batch_size - 1) 147 | z = torch.zeros(num_nodes).long() 148 | pos = torch.zeros(num_nodes, 3) 149 | batch = torch.zeros(num_nodes).long() 150 | y = torch.zeros(config.batch_size) 151 | 152 | summary(model, input_data=[z, pos, batch, y]) 153 | -------------------------------------------------------------------------------- /schnet_9m/qm1b.yaml: -------------------------------------------------------------------------------- 1 | dataset: qm1b 2 | num_workers: 0 3 | num_train: 100000000 4 | num_test: 10000000 5 | max_num_examples: 1000000000 6 | root_folder: /net/group/research/datasets/qm1b 7 | shuffle: true -------------------------------------------------------------------------------- /schnet_9m/qm1b_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | import torch 3 | import os.path as osp 4 | from glob import glob 5 | from functools import cached_property 6 | from typing import Iterator, Optional 7 | 8 | import pyarrow as pa 9 | import pyarrow.dataset as pds 10 | import pyarrow.parquet as pq 11 | import numpy as np 12 | from natsort import natsorted 13 | from torch.utils.data import IterableDataset 14 | from tqdm import tqdm 15 | from download import download 16 | 17 | INT_MAX = np.iinfo(np.int32).max 18 | 19 | 20 | class QM1B(IterableDataset): 21 | def __init__( 22 | self, 23 | files, 24 | num_subset: Optional[int] = None, 25 | shuffle: bool = True, 26 | max_open_files: int = 128, 27 | ) -> None: 28 | super().__init__() 29 | self.files = files 30 | self.shuffle = shuffle 31 | self.max_open_files = max_open_files 32 | 33 | num_samples = np.array([pq.read_metadata(f).num_rows for f in self.files]) 34 | offsets = np.pad(num_samples.cumsum(), (1, 0)) 35 | total_len = offsets[-1] 36 | 37 | if num_subset is None: 38 | self.eager = False 39 | self.sample_indices = None 40 | self.len = total_len 41 | return 42 | 43 | # Assuming roughly equal number of molecules / file: 44 | # calculate how many molecules we should sample from each file 45 | num_files = len(self.files) 46 | subsets, rem = divmod(num_subset, num_files) 47 | subsets = np.full(num_files, subsets) 48 | subsets[:rem] += 1 49 | self.sample_indices = [ 50 | np.random.choice(n, s, replace=False) for n, s in zip(num_samples, subsets) 51 | ] 52 | 53 | sample_fraction = num_subset / total_len 54 | self.eager = sample_fraction < 0.5 or len(files) == 1 55 | self.len = num_subset 56 | 57 | def __len__(self) -> int: 58 | return self.len 59 | 60 | def __iter__(self) -> Iterator: 61 | if self.eager: 62 | yield from self.iter_examples(self.subset_table) 63 | return 64 | 65 | for shard in self.iter_shards(): 66 | yield from self.iter_examples(shard) 67 | 68 | @cached_property 69 | def subset_table(self): 70 | return pa.concat_tables([s for s in self.iter_shards()]) 71 | 72 | def iter_shards(self): 73 | splits = np.arange(self.max_open_files, len(self.files), self.max_open_files) 74 | 75 | if self.shuffle: 76 | order = np.random.permutation(len(self.files)) 77 | else: 78 | order = range(len(self.files)) 79 | 80 | for shard in np.array_split(order, splits): 81 | yield self.read_shard(shard) 82 | 83 | def read_shard(self, files_index): 84 | files = [self.files[i] for i in files_index] 85 | ds = pds.dataset(files) 86 | iter_batches = ds.to_batches( 87 | columns=["z", "pos", "y"], 88 | batch_size=INT_MAX, 89 | fragment_readahead=self.max_open_files, 90 | ) 91 | iter_batches = zip(files_index, iter_batches) 92 | batches = [] 93 | 94 | for idx, batch in tqdm(iter_batches, total=len(ds.files)): 95 | if self.sample_indices is not None: 96 | batch = batch.take(self.sample_indices[idx]) 97 | 98 | batches.append(batch) 99 | 100 | return pa.Table.from_batches(batches) 101 | 102 | def iter_examples(self, table): 103 | if self.shuffle: 104 | table = safe_permute(table) 105 | 106 | for batch in table.to_batches(): 107 | z = batch["z"].to_numpy(zero_copy_only=False) 108 | pos = batch["pos"].to_numpy(zero_copy_only=False) 109 | y = batch["y"].to_numpy(zero_copy_only=False) 110 | 111 | for idx in range(batch.num_rows): 112 | yield z[idx], pos[idx].reshape(-1, 3), y[idx], z[idx].shape[0] 113 | 114 | 115 | class QM1BBatch: 116 | def __init__(self, num_fixed_nodes: int) -> None: 117 | self.num_fixed_nodes = num_fixed_nodes 118 | self.z = [] 119 | self.pos = [] 120 | self.y = [] 121 | self.num_nodes = [] 122 | 
self.graphs_mask = [] 123 | self.batch = [] 124 | 125 | def __call__( 126 | self, 127 | z, 128 | pos, 129 | y, 130 | num_nodes, 131 | ): 132 | # Add the real graphs to our running tally 133 | self.z += z 134 | self.pos += pos 135 | self.y += y 136 | self.num_nodes += num_nodes 137 | 138 | # Insert padding values 139 | num_padding_nodes = self.num_fixed_nodes - sum(num_nodes) 140 | zpad = np.zeros(num_padding_nodes, dtype=z[0].dtype) 141 | self.z.append(zpad) 142 | pospad = np.zeros((num_padding_nodes, 3), dtype=pos[0].dtype) 143 | self.pos.append(pospad) 144 | self.y.append(0.0) 145 | self.num_nodes.append(num_padding_nodes) 146 | 147 | # Append batch assignment vector for this mini-batch 148 | self.batch += [ 149 | np.full(n, i) for i, n in enumerate([*num_nodes, num_padding_nodes]) 150 | ] 151 | 152 | self.graphs_mask += [np.pad(np.ones(len(num_nodes), dtype=bool), (0, 1))] 153 | 154 | def cat(self): 155 | self.z = torch.from_numpy(np.concatenate(self.z, dtype=np.int64)) 156 | self.pos = torch.from_numpy(np.concatenate(self.pos)) 157 | self.y = torch.from_numpy(np.array(self.y)) 158 | self.num_nodes = torch.from_numpy(np.array(self.num_nodes)) 159 | self.graphs_mask = torch.from_numpy(np.concatenate(self.graphs_mask)) 160 | self.batch = torch.from_numpy(np.concatenate(self.batch)) 161 | return self 162 | 163 | 164 | class QM1BCollator: 165 | def __init__(self, num_graphs_per_batch): 166 | self.num_graphs_per_batch = num_graphs_per_batch 167 | self.num_fixed_nodes = 32 * self.num_graphs_per_batch 168 | 169 | def __call__(self, data_list) -> QM1BBatch: 170 | batch = QM1BBatch(self.num_fixed_nodes) 171 | 172 | for idx in range(0, len(data_list), self.num_graphs_per_batch): 173 | mini_batch = data_list[idx : idx + self.num_graphs_per_batch] 174 | batch(*zip(*mini_batch)) 175 | return batch.cat() 176 | 177 | 178 | def safe_permute(T): 179 | # We could just do subset_table.take(perm) but that fails with arrow 11.0.0 180 | # ArrowInvalid: offset overflow while concatenating arrays 181 | perm = np.random.permutation(T.num_rows) 182 | offset = 0 183 | batches = [] 184 | 185 | for batch in T.to_batches(): 186 | mask = (perm >= offset) & (perm < offset + batch.num_rows) 187 | batches.append(batch.take(perm[mask] - offset)) 188 | offset += batch.num_rows 189 | 190 | return pa.Table.from_batches(batches) 191 | 192 | 193 | def combined_batch(options, num_graphs_per_batch): 194 | combined_batch_size = ( 195 | num_graphs_per_batch 196 | * options.replication_factor 197 | * options.device_iterations 198 | * options.Training.gradient_accumulation 199 | ) 200 | 201 | return combined_batch_size 202 | 203 | 204 | def create_qm1b_split_dataset( 205 | root_folder: str, 206 | shuffle: bool = False, 207 | num_test: Optional[int] = None, 208 | num_train: Optional[int] = None, 209 | ): 210 | train_folder = osp.join(root_folder, "shuffled_training") 211 | 212 | if not osp.exists(train_folder): 213 | download(root_folder) 214 | 215 | files = glob(osp.join(train_folder, "*.parquet")) 216 | files = natsorted(files) 217 | 218 | if shuffle: 219 | np.random.shuffle(files) 220 | 221 | val_file = osp.join(root_folder, "validation", "qm1b_validation.parquet") 222 | test = QM1B([val_file], num_subset=num_test, shuffle=shuffle) 223 | train = QM1B(files, num_subset=num_train, shuffle=shuffle) 224 | return test, train 225 | 226 | 227 | def create_qm1b_loader(data_config, model_config, options): 228 | from torch.utils.data import DataLoader 229 | 230 | if model_config.use_half: 231 | raise NotImplementedError() 232 | 233 | if 
data_config.root_folder is None: 234 | raise ValueError( 235 | "Must provide a data_config.root_folder for the QM1B dataset. " 236 | "This requires approximately 240GB of storage" 237 | ) 238 | 239 | num_graphs_per_batch = model_config.batch_size - 1 240 | val_batch_size, train_batch_size = [ 241 | combined_batch(opt, num_graphs_per_batch) for opt in options 242 | ] 243 | 244 | loader_args = { 245 | "collate_fn": QM1BCollator(num_graphs_per_batch), 246 | "drop_last": True, 247 | "num_workers": data_config.num_workers, 248 | "shuffle": False, 249 | } 250 | 251 | test, train = create_qm1b_split_dataset( 252 | data_config.root_folder, 253 | shuffle=data_config.shuffle, 254 | num_train=data_config.num_train, 255 | num_test=data_config.num_test, 256 | ) 257 | 258 | test_loader = DataLoader(test, batch_size=val_batch_size, **loader_args) 259 | train_loader = DataLoader(train, batch_size=train_batch_size, **loader_args) 260 | return test_loader, train_loader 261 | -------------------------------------------------------------------------------- /schnet_9m/requirements.txt: -------------------------------------------------------------------------------- 1 | awscli 2 | jsonargparse[all] 3 | jupyter 4 | natsort 5 | pandas 6 | periodictable 7 | py3Dmol 8 | pyarrow 9 | rdkit 10 | seaborn 11 | torchinfo 12 | torchmetrics 13 | wandb 14 | -------------------------------------------------------------------------------- /schnet_9m/scaling_qm1b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/pyscf-ipu/43a9bb343acbd296aec1d1f64c309bb24920d98d/schnet_9m/scaling_qm1b.png -------------------------------------------------------------------------------- /schnet_9m/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
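# A minimal sketch (not part of the original file) of the learning-rate
# schedule that train() below wires up: LinearLR warms up from 0.2x the base
# lr over warmup_steps, after which ExponentialLR decays it by gamma once
# every update_lr_period steps. Shortened horizons here are for illustration.
import torch
from torch.optim.lr_scheduler import ExponentialLR, LinearLR

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=0.0005)
decay = ExponentialLR(opt, gamma=0.96)
warmup = LinearLR(opt, start_factor=0.2, total_iters=10)

for step in range(30):
    if step < 10:
        warmup.step()
    elif step % 10 == 0:
        decay.step()

# warmed up to the base lr, then decayed twice (at steps 10 and 20)
assert abs(opt.param_groups[0]["lr"] - 0.0005 * 0.96**2) < 1e-9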
2 | import os.path as osp 3 | import poptorch 4 | import torch 5 | import wandb 6 | from data import DataConfig, create_loader, loader_info, fake_batch 7 | from device import PopConfig, configure_poptorch 8 | from jsonargparse import CLI 9 | from model import ModelConfig, create_model 10 | from torch.optim.lr_scheduler import ExponentialLR, LinearLR 11 | from torch_geometric import seed_everything 12 | from torchmetrics import MeanAbsoluteError 13 | from tqdm import tqdm 14 | 15 | 16 | def setup_run(args): 17 | if not args["use_wandb"]: 18 | return None 19 | 20 | run = wandb.init( 21 | project=args["wandb_project"], 22 | settings=wandb.Settings(console="wrap"), 23 | config=args, 24 | ) 25 | name = args["data_config"].dataset 26 | name += f"-{tqdm.format_sizeof(args['data_config'].num_train)}" 27 | run.name = name + "-" + run.name 28 | return run 29 | 30 | 31 | def validation(inference_model, loader): 32 | mae = MeanAbsoluteError() 33 | 34 | for data in loader: 35 | yhat = inference_model(data.z, data.pos, data.batch) 36 | mae.update(yhat, data.y[data.graphs_mask]) 37 | 38 | return float(mae.compute()) 39 | 40 | 41 | def save_checkpoint(train_model, optimizer, lr_schedule, id="final"): 42 | path = osp.join(wandb.run.dir, f"checkpoint_{id}.pt") 43 | torch.save( 44 | { 45 | "model": train_model.state_dict(), 46 | "optimizer": optimizer.state_dict(), 47 | "lr_schedule": lr_schedule.state_dict(), 48 | }, 49 | path, 50 | ) 51 | wandb.save(path, policy="now") 52 | 53 | 54 | def train( 55 | seed: int = 0, 56 | learning_rate: float = 0.0005, 57 | learning_rate_decay: float = 0.96, 58 | update_lr_period: int = 32 * 300, 59 | val_period: int = 4000, 60 | pop_config: PopConfig = PopConfig(), 61 | data_config: DataConfig = DataConfig(), 62 | model_config: ModelConfig = ModelConfig(), 63 | debug: bool = False, 64 | wandb_project: str = "schnet-9m", 65 | use_wandb: bool = True, 66 | only_compile: bool = False, 67 | warmup_steps: int = 2500, 68 | wandb_warmup: bool = False, 69 | ): 70 | """ 71 | Minimal SchNet GNN training 72 | 73 | Args: 74 | seed (int): the random number seed 75 | learning_rate (float): the learning rate used by the optimizer 76 | learning_rate_decay (float): multiplicative factor for exponential learning rate decay 77 | update_lr_period (int): the number of steps between learning rate updates. 78 | Default: 32*300*batch_size ~ 36M examples, so the lr increases 79 | over the 0-4M example warmup, then decays every 36M examples 80 | (1000/36 ~ 30 times) 81 | val_period (int): number of training steps before performing validation 82 | pop_config (PopConfig): configuration options for PopTorch 83 | data_config (DataConfig): options for data subsetting and loading 84 | model_config (ModelConfig): model architecture options 85 | debug (bool): enables additional logging (with perf overhead) 86 | use_wandb (bool): Use Weights and Biases to log benchmark results. 87 | only_compile (bool): Compile the model and exit (no training) 88 | warmup_steps (int): set to 0 to turn off.
openc uses 2-3 epochs with 2-3M 89 | molecules => 4-9M graphs; => steps = 10M / batchsize 90 | ~ --warmup_steps 2500 91 | wandb_warmup (bool): enable wandb logging of warmup steps 92 | """ 93 | run = setup_run(locals()) 94 | seed_everything(seed) 95 | model = create_model(model_config) 96 | optimizer = poptorch.optim.AdamW(model.parameters(), lr=learning_rate) 97 | lr_schedule = ExponentialLR(optimizer=optimizer, gamma=learning_rate_decay) 98 | 99 | warmup_lr_schedule = LinearLR( 100 | optimizer=optimizer, start_factor=0.2, total_iters=warmup_steps 101 | ) 102 | train_model, inference_model = configure_poptorch( 103 | pop_config, debug, model, optimizer 104 | ) 105 | 106 | if only_compile: 107 | _ = train_model(*fake_batch(model_config, train_model.options)) 108 | return 109 | 110 | seed_everything(seed) 111 | test_loader, train_loader = create_loader( 112 | data_config, model_config, (inference_model.options, train_model.options) 113 | ) 114 | 115 | if use_wandb: 116 | run.log(loader_info(test_loader, "test_loader/")) 117 | run.log(loader_info(train_loader, "train_loader/")) 118 | 119 | num_examples = 0 120 | bar = tqdm(total=data_config.max_num_examples, unit_scale=True) 121 | step = 0 122 | results = {} 123 | 124 | val_period = val_period // 10 125 | done = False 126 | 127 | while not done: 128 | for data in train_loader: 129 | if step % val_period == 0 or step < 2 or done: 130 | if step == 0: 131 | results["num_examples"] = 0 132 | if step != 0: 133 | train_model.detachFromDevice() 134 | results["val_mae"] = validation(inference_model, test_loader) 135 | inference_model.detachFromDevice() 136 | 137 | if use_wandb: 138 | run.log(results) 139 | 140 | if done: 141 | return 142 | 143 | if step == 10000: 144 | val_period = val_period * 10 145 | 146 | loss = train_model(data.z, data.pos, data.batch, data.y) 147 | num_batch_examples = int(data.graphs_mask.sum()) 148 | num_examples += num_batch_examples 149 | bar.update(num_batch_examples) 150 | step += 1 151 | 152 | results["train_loss"] = float(loss.mean()) 153 | results["num_examples"] = num_examples 154 | 155 | if step < warmup_steps: 156 | warmup_lr_schedule.step() 157 | train_model.setOptimizer(optimizer) 158 | results["lr"] = warmup_lr_schedule.get_last_lr()[0] 159 | if wandb_warmup: 160 | wandb.log(results) # log every warmup step 161 | elif step % update_lr_period == 0: 162 | lr_schedule.step() 163 | train_model.setOptimizer(optimizer) 164 | results["lr"] = lr_schedule.get_last_lr()[0] 165 | 166 | if num_examples > data_config.max_num_examples: 167 | save_checkpoint(train_model, optimizer, lr_schedule) 168 | done = True 169 | 170 | bar.set_postfix(**results) 171 | 172 | 173 | if __name__ == "__main__": 174 | CLI(train) 175 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
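# A minimal sketch (not part of the original file) of the filtering performed
# by read_requirements below: despite its name, remove_comments is a keep-
# predicate that drops blank lines, "#" comments, and "-" directives such as
# "-r" includes. Example lines are made up for illustration.
lines = ["numpy", "", "# pinned for CI", "-r requirements_core.txt", "jax"]
kept = [ln for ln in lines if len(ln) > 0 and not ln.startswith(("#", "-"))]
assert kept == ["numpy", "jax"]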
2 | from pathlib import Path 3 | 4 | from setuptools import setup 5 | 6 | __version__ = "0.0.1" 7 | 8 | 9 | def read_requirements(file): 10 | pwd = Path(__file__).parent.resolve() 11 | txt = (pwd / file).read_text(encoding="utf-8").split("\n") 12 | 13 | def remove_comments(line: str): 14 | return len(line) > 0 and not line.startswith(("#", "-")) 15 | 16 | return list(filter(remove_comments, txt)) 17 | 18 | 19 | install_requires = read_requirements("requirements_core.txt") 20 | cpu_requires = read_requirements("requirements_cpu.txt") 21 | ipu_requires = read_requirements("requirements_ipu.txt") 22 | test_requires = read_requirements("requirements_test.txt") 23 | 24 | setup( 25 | name="pyscf-ipu", 26 | version=__version__, 27 | description="PySCF on IPU", 28 | long_description="file: README.md", 29 | long_description_content_type="text/markdown", 30 | license="Apache License 2.0", 31 | author="Graphcore Research", 32 | author_email="contact@graphcore.ai", 33 | url="https://github.com/graphcore-research/pyscf-ipu", 34 | project_urls={ 35 | "Code": "https://github.com/graphcore-research/pyscf-ipu", 36 | }, 37 | classifiers=[ 38 | "Development Status :: 3 - Alpha", 39 | "Intended Audience :: Developers", 40 | "Topic :: Scientific/Engineering", 41 | "License :: OSI Approved :: Apache Software License", 42 | "Programming Language :: Python :: 3", 43 | ], 44 | install_requires=install_requires, 45 | extras_require={"cpu": cpu_requires, "ipu": ipu_requires, "test": test_requires}, 46 | python_requires=">=3.8", 47 | packages=["pyscf_ipu"], 48 | entry_points={"console_scripts": ["nanoDFT=pyscf_ipu.nanoDFT:main"]}, 49 | ) 50 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) 2022 Graphcore Ltd. All rights reserved. 3 | # Script to be sourced on launch of the Gradient Notebook 4 | 5 | DETECTED_NUMBER_OF_IPUS=$(python .gradient/available_ipus.py) 6 | if [[ "$1" == "test" ]]; then 7 | IPU_ARG="${DETECTED_NUMBER_OF_IPUS}" 8 | else 9 | IPU_ARG=${1:-"${DETECTED_NUMBER_OF_IPUS}"} 10 | fi 11 | 12 | export NUM_AVAILABLE_IPU=${IPU_ARG} 13 | export GRAPHCORE_POD_TYPE="pod${IPU_ARG}" 14 | 15 | export POPLAR_EXECUTABLE_CACHE_DIR="/tmp/exe_cache" 16 | export DATASET_DIR="/tmp/dataset_cache" 17 | export CHECKPOINT_DIR="/tmp/checkpoints" 18 | 19 | # mounted public dataset directory (path in the container) 20 | # in the Paperspace environment this would be ="/datasets" 21 | export PUBLIC_DATASET_DIR="/datasets" 22 | 23 | export POPTORCH_CACHE_DIR="${POPLAR_EXECUTABLE_CACHE_DIR}" 24 | export POPTORCH_LOG_LEVEL=ERR 25 | export RDMAV_FORK_SAFE=1 26 | 27 | export PIP_DISABLE_PIP_VERSION_CHECK=1 CACHE_DIR=/tmp 28 | jupyter lab --allow-root --ip=0.0.0.0 --no-browser --ServerApp.trust_xheaders=True \ 29 | --ServerApp.disable_check_xsrf=False --ServerApp.allow_remote_access=True \ 30 | --ServerApp.allow_origin='*' --ServerApp.allow_credentials=True 31 | -------------------------------------------------------------------------------- /test/test_integrals.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
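# A minimal sketch (not part of the original file) of where test_overlap's
# expected 0.6648 comes from: two normalized 1s Gaussian primitives with the
# same exponent alpha and separation R have overlap S = exp(-alpha * R**2 / 2),
# the equal-exponent case of S = (4ab / (a + b)**2)**0.75 * exp(-ab * R**2 / (a + b)).
import numpy as np

alpha = 0.270950 * 1.24 * 1.24  # exponent used in the Szabo & Ostlund exercise
R = 1.4                         # separation used in the test
assert np.isclose(np.exp(-alpha * R**2 / 2), 0.6648, atol=1e-4)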
2 | import jax.numpy as jnp 3 | import numpy as np 4 | import pytest 5 | from numpy.testing import assert_allclose 6 | 7 | from pyscf_ipu.experimental.basis import basisset 8 | from pyscf_ipu.experimental.integrals import ( 9 | eri_basis, 10 | eri_basis_sparse, 11 | eri_primitives, 12 | kinetic_basis, 13 | kinetic_primitives, 14 | nuclear_basis, 15 | nuclear_primitives, 16 | overlap_basis, 17 | overlap_primitives, 18 | ) 19 | from pyscf_ipu.experimental.interop import to_pyscf 20 | from pyscf_ipu.experimental.primitive import Primitive 21 | from pyscf_ipu.experimental.structure import molecule 22 | 23 | 24 | def test_overlap(): 25 | # Exercise 3.21 of "Modern quantum chemistry: introduction to advanced 26 | # electronic structure theory." by Szabo and Ostlund 27 | alpha = 0.270950 * 1.24 * 1.24 28 | a = Primitive(alpha=alpha) 29 | b = Primitive(alpha=alpha, center=jnp.array([1.4, 0.0, 0.0])) 30 | assert_allclose(overlap_primitives(a, a), 1.0, atol=1e-5) 31 | assert_allclose(overlap_primitives(b, b), 1.0, atol=1e-5) 32 | assert_allclose(overlap_primitives(b, a), 0.6648, atol=1e-5) 33 | 34 | 35 | @pytest.mark.parametrize("basis_name", ["sto-3g", "6-31+g", "6-31+g*"]) 36 | def test_water_overlap(basis_name): 37 | basis = basisset(molecule("water"), basis_name) 38 | actual_overlap = overlap_basis(basis) 39 | 40 | # Note: PySCF doesn't appear to normalise d basis functions in cartesian basis 41 | scfmol = to_pyscf(molecule("water"), basis_name=basis_name) 42 | expect_overlap = scfmol.intor("int1e_ovlp_cart") 43 | n = 1 / np.sqrt(np.diagonal(expect_overlap)) 44 | expect_overlap = n[:, None] * n[None, :] * expect_overlap 45 | assert_allclose(actual_overlap, expect_overlap, atol=1e-6) 46 | 47 | 48 | def test_kinetic(): 49 | # PyQuante test case for kinetic primitive integral 50 | p = Primitive() 51 | assert_allclose(kinetic_primitives(p, p), 1.5, atol=1e-6) 52 | 53 | # Reproduce the kinetic energy matrix for H2 using STO-3G basis set 54 | # See equation 3.230 of "Modern quantum chemistry: introduction to advanced 55 | # electronic structure theory." by Szabo and Ostlund 56 | h2 = molecule("h2") 57 | basis = basisset(h2, "sto-3g") 58 | actual = kinetic_basis(basis) 59 | expect = np.array([[0.7600, 0.2365], [0.2365, 0.7600]]) 60 | assert_allclose(actual, expect, atol=1e-4) 61 | 62 | 63 | @pytest.mark.parametrize( 64 | "basis_name", 65 | [ 66 | "sto-3g", 67 | "6-31+g", 68 | pytest.param( 69 | "6-31+g*", marks=pytest.mark.xfail(reason="Cartesian norm problem?") 70 | ), 71 | ], 72 | ) 73 | def test_water_kinetic(basis_name): 74 | basis = basisset(molecule("water"), basis_name) 75 | actual = kinetic_basis(basis) 76 | 77 | expect = to_pyscf(molecule("water"), basis_name=basis_name).intor("int1e_kin_cart") 78 | assert_allclose(actual, expect, atol=1e-4) 79 | 80 | 81 | def test_nuclear(): 82 | # PyQuante test case for nuclear attraction integral 83 | p = Primitive() 84 | c = jnp.zeros(3) 85 | assert_allclose(nuclear_primitives(p, p, c), -1.595769, atol=1e-5) 86 | 87 | # Reproduce the nuclear attraction matrix for H2 using STO-3G basis set 88 | # See equations 3.231 and 3.232 of Szabo and Ostlund 89 | h2 = molecule("h2") 90 | basis = basisset(h2, "sto-3g") 91 | actual = nuclear_basis(basis, h2.position, h2.atomic_number) 92 | expect = np.array( 93 | [ 94 | [[-1.2266, -0.5974], [-0.5974, -0.6538]], 95 | [[-0.6538, -0.5974], [-0.5974, -1.2266]], 96 | ] 97 | ) 98 | 99 | assert_allclose(actual, expect, atol=1e-4) 100 | 101 | 102 | def test_water_nuclear(): 103 | basis_name = "sto-3g" 104 | h2o =
molecule("water") 105 | basis = basisset(h2o, basis_name) 106 | actual = nuclear_basis(basis, h2o.position, h2o.atomic_number).sum(axis=0) 107 | expect = to_pyscf(h2o, basis_name=basis_name).intor("int1e_nuc_cart") 108 | assert_allclose(actual, expect, atol=1e-4) 109 | 110 | 111 | def test_eri(): 112 | # PyQuante test cases for ERI 113 | a, b, c, d = [Primitive()] * 4 114 | assert_allclose(eri_primitives(a, b, c, d), 1.128379, atol=1e-5) 115 | 116 | c, d = [Primitive(lmn=jnp.array([1, 0, 0]))] * 2 117 | assert_allclose(eri_primitives(a, b, c, d), 0.940316, atol=1e-5) 118 | 119 | # H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund 120 | h2 = molecule("h2") 121 | basis = basisset(h2, "sto-3g") 122 | 123 | actual = eri_basis(basis) 124 | expect = np.empty((2, 2, 2, 2), dtype=np.float32) 125 | expect[0, 0, 0, 0] = expect[1, 1, 1, 1] = 0.7746 126 | expect[0, 0, 1, 1] = expect[1, 1, 0, 0] = 0.5697 127 | expect[1, 0, 0, 0] = expect[0, 0, 0, 1] = 0.4441 128 | expect[0, 1, 0, 0] = expect[0, 0, 1, 0] = 0.4441 129 | expect[0, 1, 1, 1] = expect[1, 1, 1, 0] = 0.4441 130 | expect[1, 0, 1, 1] = expect[1, 1, 0, 1] = 0.4441 131 | expect[1, 0, 1, 0] = expect[0, 1, 1, 0] = 0.2970 132 | expect[0, 1, 0, 1] = expect[1, 0, 0, 1] = 0.2970 133 | assert_allclose(actual, expect, atol=1e-4) 134 | 135 | 136 | def is_mem_limited(): 137 | # Check if we are running on a limited memory host (e.g. github action) 138 | import psutil 139 | 140 | total_mem_gib = psutil.virtual_memory().total // 1024**3 141 | return total_mem_gib < 10 142 | 143 | 144 | @pytest.mark.parametrize("sparse", [True, False]) 145 | @pytest.mark.skipif(is_mem_limited(), reason="Not enough host memory!") 146 | def test_water_eri(sparse): 147 | basis_name = "sto-3g" 148 | h2o = molecule("water") 149 | basis = basisset(h2o, basis_name) 150 | actual = eri_basis_sparse(basis) if sparse else eri_basis(basis) 151 | aosym = "s8" if sparse else "s1" 152 | expect = to_pyscf(h2o, basis_name=basis_name).intor("int2e_cart", aosym=aosym) 153 | assert_allclose(actual, expect, atol=1e-4) 154 | -------------------------------------------------------------------------------- /test/test_integrals_ipu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | import jax.numpy as jnp 3 | import pytest 4 | from numpy.testing import assert_allclose 5 | 6 | from pyscf_ipu.experimental.device import has_ipu, ipu_func 7 | from pyscf_ipu.experimental.integrals import kinetic_primitives, overlap_primitives 8 | from pyscf_ipu.experimental.primitive import Primitive 9 | 10 | 11 | @pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") 12 | def test_overlap(): 13 | from pyscf_ipu.experimental.integrals import _overlap_primitives 14 | 15 | a, b = [Primitive()] * 2 16 | actual = ipu_func(_overlap_primitives)(a, b) 17 | assert_allclose(actual, overlap_primitives(a, b)) 18 | 19 | 20 | @pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") 21 | def test_kinetic(): 22 | from pyscf_ipu.experimental.integrals import _kinetic_primitives 23 | 24 | a, b = [Primitive()] * 2 25 | actual = ipu_func(_kinetic_primitives)(a, b) 26 | assert_allclose(actual, kinetic_primitives(a, b)) 27 | 28 | 29 | @pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") 30 | def test_nuclear(): 31 | from pyscf_ipu.experimental.integrals import _nuclear_primitives 32 | 33 | # PyQuante test case for nuclear attraction integral 34 | a, b = [Primitive()] * 2 35 | c = jnp.zeros(3) 36 | actual = ipu_func(_nuclear_primitives)(a, b, c) 37 | assert_allclose(actual, -1.595769, atol=1e-5) 38 | 39 | 40 | @pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!") 41 | def test_eri(): 42 | from pyscf_ipu.experimental.integrals import _eri_primitives 43 | 44 | # PyQuante test cases for ERI 45 | a, b, c, d = [Primitive()] * 4 46 | actual = ipu_func(_eri_primitives)(a, b, c, d) 47 | assert_allclose(actual, 1.128379, atol=1e-5) 48 | -------------------------------------------------------------------------------- /test/test_interop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | import jax.numpy as jnp 3 | import numpy as np 4 | import pytest 5 | from numpy.testing import assert_allclose 6 | 7 | from pyscf_ipu.experimental.basis import basisset 8 | from pyscf_ipu.experimental.interop import to_pyscf 9 | from pyscf_ipu.experimental.mesh import electron_density, uniform_mesh 10 | from pyscf_ipu.experimental.structure import molecule, nuclear_energy 11 | 12 | 13 | @pytest.mark.parametrize("basis_name", ["sto-3g", "6-31g**"]) 14 | def test_to_pyscf(basis_name): 15 | mol = molecule("water") 16 | basis = basisset(mol, basis_name) 17 | pyscf_mol = to_pyscf(mol, basis_name) 18 | assert basis.num_orbitals == pyscf_mol.nao 19 | 20 | 21 | def test_gto(): 22 | from pyscf.dft.numint import eval_rho 23 | 24 | # Atomic orbitals 25 | basis_name = "6-31+g" 26 | structure = molecule("water") 27 | basis = basisset(structure, basis_name) 28 | mesh, _ = uniform_mesh() 29 | actual = basis(mesh) 30 | 31 | mol = to_pyscf(structure, basis_name) 32 | expect_ao = mol.eval_gto("GTOval_cart", np.asarray(mesh)) 33 | assert_allclose(actual, expect_ao, atol=1e-6) 34 | 35 | # Molecular orbitals 36 | mf = mol.KS() 37 | mf.kernel() 38 | C = jnp.array(mf.mo_coeff, dtype=jnp.float32) 39 | actual = basis.occupancy * C @ C.T 40 | expect = jnp.array(mf.make_rdm1(), dtype=jnp.float32) 41 | assert_allclose(actual, expect, atol=1e-6) 42 | 43 | # Electron density 44 | actual = electron_density(basis, mesh, C) 45 | expect = eval_rho(mol, expect_ao, mf.make_rdm1(), "lda") 46 | assert_allclose(actual, expect, atol=1e-6) 47 | 48 | 49 | @pytest.mark.parametrize("name", ["water", "h2"]) 50 | def test_nuclear_energy(name): 51 | mol = molecule(name) 52 | actual = nuclear_energy(mol) 53 | expect = to_pyscf(mol).energy_nuc() 54 | assert_allclose(actual, expect) 55 | -------------------------------------------------------------------------------- /test/test_special.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
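# A minimal sketch (not part of the original file) of the reference values in
# test_factorial2 below: the double factorial n!! = n * (n - 2) * ..., with
# 0!! defined as 1. The helper name is illustrative.
def double_factorial(n: int) -> int:
    return 1 if n <= 0 else n * double_factorial(n - 2)

assert [double_factorial(n) for n in range(1, 9)] == [1, 2, 3, 8, 15, 48, 105, 384]
assert double_factorial(0) == 1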
2 | import jax.numpy as jnp 3 | import pytest 4 | from numpy.testing import assert_allclose 5 | 6 | from pyscf_ipu.experimental.special import ( 7 | binom_beta, 8 | binom_fori, 9 | binom_lookup, 10 | factorial2_fori, 11 | factorial2_lookup, 12 | factorial_fori, 13 | factorial_gamma, 14 | factorial_lookup, 15 | ) 16 | 17 | 18 | def test_factorial(): 19 | x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8]) 20 | expect = jnp.array([1, 2, 6, 24, 120, 720, 5040, 40320]) 21 | assert_allclose(factorial_fori(x, x[-1]), expect) 22 | assert_allclose(factorial_lookup(x, x[-1]), expect) 23 | assert_allclose(factorial_gamma(x), expect) 24 | 25 | 26 | def test_factorial2(): 27 | x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8]) 28 | expect = jnp.array([1, 2, 3, 8, 15, 48, 105, 384]) 29 | assert_allclose(factorial2_fori(x), expect) 30 | assert_allclose(factorial2_fori(0), 1) 31 | 32 | assert_allclose(factorial2_lookup(x), expect) 33 | assert_allclose(factorial2_lookup(0), 1) 34 | 35 | 36 | @pytest.mark.parametrize("binom_func", [binom_beta, binom_fori, binom_lookup]) 37 | def test_binom(binom_func): 38 | x = jnp.array([4, 4, 4, 4]) 39 | y = jnp.array([1, 2, 3, 4]) 40 | expect = jnp.array([4, 6, 4, 1]) 41 | assert_allclose(binom_func(x, y), expect) 42 | 43 | zero = jnp.array([0]) 44 | assert_allclose(binom_func(zero, y), jnp.zeros_like(x)) 45 | assert_allclose(binom_func(x, zero), jnp.ones_like(y)) 46 | assert_allclose(binom_func(y, y), jnp.ones_like(y)) 47 | 48 | one = jnp.array([1]) 49 | assert_allclose(binom_func(one, one), one) 50 | assert_allclose(binom_func(zero, -one), zero) 51 | assert_allclose(binom_func(zero, zero), one) 52 | --------------------------------------------------------------------------------
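# A minimal sketch (not part of the original repository) of the reference
# values in test_binom above: they are ordinary binomial coefficients, which
# the Python standard library reproduces directly.
from math import comb

assert [comb(4, k) for k in (1, 2, 3, 4)] == [4, 6, 4, 1]
assert comb(0, 0) == 1 and comb(4, 4) == 1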