├── .github └── workflows │ ├── pre-commit.yaml │ ├── publish.yaml │ ├── release.yaml │ └── unittest.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── _config.yml ├── _static │ └── style.css ├── _templates │ ├── class.rst │ └── module.rst ├── _toc.yml ├── api.rst ├── batching.ipynb ├── boys_function.ipynb ├── gammanu.ipynb ├── gto_integrals.ipynb ├── intro.md ├── misc.md ├── optim.ipynb ├── prologue.md ├── quadrature.ipynb ├── quirks.ipynb ├── references.bib └── tour.ipynb ├── mess ├── __init__.py ├── autograd_integrals.py ├── basis.py ├── binom_factor_table.py ├── hamiltonian.py ├── integrals.py ├── interop.py ├── mesh.py ├── numerics.py ├── orbital.py ├── orthnorm.py ├── package_utils.py ├── plot.py ├── primitive.py ├── scf.py ├── special.py ├── structure.py ├── types.py ├── units.py ├── xcfunctional.py └── zeropad_integrals.py ├── pyproject.toml └── test ├── conftest.py ├── test_autograd_integrals.py ├── test_benchmark.py ├── test_hamiltonian.py ├── test_integrals.py ├── test_interop.py ├── test_special.py └── test_xcfunctional.py /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | name: Pre-Commit Checks 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: actions/setup-python@v5 14 | with: 15 | python-version: "3.11" 16 | cache: "pip" 17 | - uses: pre-commit/action@v3.0.1 18 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish MESS Book 2 | 3 | on: 4 | # Runs on pushes targeting the default branch 5 | push: 6 | branches: ["main"] 7 | 8 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 9 | permissions: 10 | contents: read 11 | 
pages: write 12 | id-token: write 13 | 14 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 15 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 16 | concurrency: 17 | group: "pages" 18 | cancel-in-progress: false 19 | 20 | jobs: 21 | # Single deploy job since we're just deploying 22 | deploy: 23 | environment: 24 | name: github-pages 25 | url: ${{ steps.deployment.outputs.page_url }} 26 | runs-on: ubuntu-latest 27 | steps: 28 | - name: Checkout 29 | uses: actions/checkout@v4 30 | 31 | - name: Set up Python 3.11 32 | uses: actions/setup-python@v5 33 | with: 34 | python-version: '3.11' 35 | cache: pip 36 | 37 | - name: Install requirements 38 | run: | 39 | pip install -U pip 40 | pip install .[dev] 41 | 42 | - name: Build book 43 | run: | 44 | jupyter-book build docs 45 | 46 | - name: Setup Pages 47 | uses: actions/configure-pages@v3 48 | 49 | - name: Upload artifact 50 | uses: actions/upload-pages-artifact@v2 51 | with: 52 | path: "docs/_build/html" 53 | 54 | - name: Deploy to GitHub Pages 55 | id: deployment 56 | uses: actions/deploy-pages@v2 57 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: release 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | release-version: 7 | description: "A valid Semver version string" 8 | required: true 9 | 10 | permissions: 11 | contents: write 12 | pull-requests: write 13 | 14 | jobs: 15 | release: 16 | # Do not release if not triggered from the default branch 17 | if: github.ref == format('refs/heads/{0}', github.event.repository.default_branch) 18 | 19 | runs-on: ubuntu-latest 20 | timeout-minutes: 30 21 | 22 | defaults: 23 | run: 24 | shell: bash -l {0} 25 | 26 | steps: 27 | - name: Checkout the code 28 | uses: actions/checkout@v4 29 | 30 | - name: Setup 
mamba 31 | uses: mamba-org/setup-micromamba@v1 32 | with: 33 | environment-name: my_env 34 | cache-environment: true 35 | cache-downloads: true 36 | create-args: >- 37 | python=3.11 38 | pip 39 | semver 40 | python-build 41 | setuptools_scm 42 | 43 | - name: Check the version is valid semver 44 | run: | 45 | RELEASE_VERSION="${{ inputs.release-version }}" 46 | 47 | { 48 | pysemver check $RELEASE_VERSION 49 | } || { 50 | echo "The version '$RELEASE_VERSION' is not a valid Semver version string." 51 | echo "Please use a valid semver version string. More details at https://semver.org/" 52 | echo "The release process is aborted." 53 | exit 1 54 | } 55 | 56 | - name: Check the version is higher than the latest one 57 | run: | 58 | # Retrieve the git tags first 59 | git fetch --prune --unshallow --tags &> /dev/null 60 | 61 | RELEASE_VERSION="${{ inputs.release-version }}" 62 | LATEST_VERSION=$(git describe --abbrev=0 --tags) 63 | 64 | IS_HIGHER_VERSION=$(pysemver compare $RELEASE_VERSION $LATEST_VERSION) 65 | 66 | if [ "$IS_HIGHER_VERSION" != "1" ]; then 67 | echo "The version '$RELEASE_VERSION' is not higher than the latest version '$LATEST_VERSION'." 68 | echo "The release process is aborted." 
69 | exit 1 70 | fi 71 | 72 | - name: Build Changelog 73 | id: github_release 74 | uses: mikepenz/release-changelog-builder-action@v4 75 | env: 76 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 77 | with: 78 | toTag: "main" 79 | 80 | - name: Configure git 81 | run: | 82 | git config --global user.name "${GITHUB_ACTOR}" 83 | git config --global user.email "${GITHUB_ACTOR}@users.noreply.github.com" 84 | 85 | - name: Create and push git tag 86 | env: 87 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 88 | run: | 89 | # Tag the release 90 | git tag -a "${{ inputs.release-version }}" -m "Release version ${{ inputs.release-version }}" 91 | 92 | # Checkout the git tag 93 | git checkout "${{ inputs.release-version }}" 94 | 95 | # Push the modified changelogs 96 | git push origin main 97 | 98 | # Push the tags 99 | git push origin "${{ inputs.release-version }}" 100 | 101 | - name: Install library 102 | run: | 103 | pip install -U pip 104 | python -m pip install --no-deps . 105 | 106 | - name: Build the wheel and sdist 107 | run: python -m build --no-isolation 108 | 109 | - name: Publish package to PyPI 110 | uses: pypa/gh-action-pypi-publish@release/v1 111 | with: 112 | password: ${{ secrets.PYPI_API_TOKEN }} 113 | packages-dir: dist/ 114 | 115 | - name: Create GitHub Release 116 | uses: softprops/action-gh-release@de2c0eb89ae2a093876385947365aca7b0e5f844 117 | with: 118 | tag_name: ${{ inputs.release-version }} 119 | body: ${{steps.github_release.outputs.changelog}} 120 | -------------------------------------------------------------------------------- /.github/workflows/unittest.yaml: -------------------------------------------------------------------------------- 1 | name: unit tests 2 | on: 3 | pull_request: 4 | push: 5 | branches: [main] 6 | 7 | jobs: 8 | pytest-container: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: actions/setup-python@v5 14 | with: 15 | python-version: "3.11" 16 | cache: "pip" 17 | 18 | - name: Install 
requirements 19 | run: | 20 | pip install -U pip 21 | pip install .[dev] 22 | 23 | - name: Log installed environment 24 | run: | 25 | python3 -m pip freeze 26 | 27 | - name: Run unit tests 28 | run: | 29 | pytest . 30 | 31 | - name: Build MESS book 32 | run: | 33 | jupyter-book build docs 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE 2 | .vscode 3 | 4 | # jupyter-book derived files 5 | docs/_autosummary 6 | docs/_build 7 | .benchmarks/ 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/#use-with-ide 118 | .pdm.toml 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.11 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.6.0 7 | hooks: 8 | - id: end-of-file-fixer 9 | - id: trailing-whitespace 10 | - id: check-yaml 11 | 12 | 13 | - repo: https://github.com/astral-sh/ruff-pre-commit 14 | rev: v0.6.8 15 | hooks: 16 | - id: ruff 17 | types_or: [python, jupyter] 18 | args: [--fix, --exit-non-zero-on-fix, --preview] 19 | 20 | - id: ruff-format 21 | args: [--preview] 22 | 23 | - repo: https://github.com/executablebooks/mdformat 24 | rev: 0.7.17 25 | hooks: 26 | - id: mdformat 27 | # exclusions to keep mdformat from breaking 28 | # - docs/intro.md: grid layout 29 | # - README.md: github admonition 30 | exclude: "docs/intro.md|README.md" 31 | additional_dependencies: 32 | - mdformat-gfm 33 | - mdformat-admon 34 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MESS 2 | 3 | We are interested in hearing any and all feedback so feel free to raise any questions, 4 | bugs encountered, or enhancement requests as 5 | [Issues](https://github.com/valence-labs/mess/issues). 6 | 7 | ## Setting up a development environment 8 | 9 | The following assumes that you have already set up an install of conda and that the 10 | conda command is available on your system path. Refer to your preferred conda installer: 11 | 12 | - [miniforge installation](https://github.com/conda-forge/miniforge#install) 13 | - [conda installation documentation](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html). 14 | 15 | 1.
Create a new conda environment with the minimum python version required: 16 | 17 | ```bash 18 | conda create -n mess python=3.11 19 | ``` 20 | 21 | 1. Install all required packages for developing MESS: 22 | 23 | ```bash 24 | pip install -e .[dev] 25 | ``` 26 | 27 | 1. Install the pre-commit hooks 28 | 29 | ```bash 30 | pre-commit install 31 | ``` 32 | 33 | 1. Create a feature branch, make changes, and when you commit them the pre-commit hooks 34 | will run. 35 | 36 | ```bash 37 | git checkout -b feature 38 | ... 39 | git push --set-upstream origin feature 40 | ``` 41 | 42 | The last command will print a link that you can follow to open a PR. 43 | 44 | ## Testing 45 | 46 | Run all the tests using `pytest` 47 | 48 | ```bash 49 | pytest 50 | ``` 51 | 52 | ## Building Documentation 53 | 54 | From the project root, you can build the documentation with: 55 | 56 | ```bash 57 | jupyter-book build docs 58 | ``` 59 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Graphcore Research 4 | Copyright (c) 2024 Valence Labs 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MESS: Modern Electronic Structure Simulations 2 | 3 | > [!IMPORTANT] 4 | > :hammer: :skull: :warning: :wrench:\ 5 | > This project is a constantly evolving work in progress.\ 6 | > Expect bugs and surprising performance cliffs.\ 7 | > :hammer: :skull: :warning: :wrench: 8 | 9 | [![docs](https://img.shields.io/badge/MESS-docs-blue?logo=bookstack)](https://valence-labs.github.io/mess) 10 | [![arXiv](https://img.shields.io/badge/arXiv-2406.03121-b31b1b.svg)](https://arxiv.org/abs/2406.03121) 11 | [![unit tests](https://github.com/valence-labs/mess/actions/workflows/unittest.yaml/badge.svg)](https://github.com/valence-labs/mess/actions/workflows/unittest.yaml) 12 | [![pre-commit checks](https://github.com/valence-labs/mess/actions/workflows/pre-commit.yaml/badge.svg)](https://github.com/valence-labs/mess/actions/workflows/pre-commit.yaml) 13 | 14 | ## Motivation 15 | 16 | MESS is primarily motivated by the need to demystify the underpinnings of electronic 17 | structure simulations. The target audience is the collective molecular machine learning 18 | community. We identify this community as anyone working towards accelerating our 19 | understanding of atomistic processes by combining physical models (e.g. density 20 | functional theory) with methods that learn from data (e.g. 
deep neural networks). 21 | 22 | ## Minimal Example 23 | 24 | Calculate the ground state energy of a single water molecule using the 6-31g basis set 25 | and the [local density approximation (LDA)](https://en.wikipedia.org/wiki/Local-density_approximation): 26 | ```python 27 | from mess import Hamiltonian, basisset, minimise, molecule 28 | 29 | mol = molecule("water") 30 | basis = basisset(mol, basis_name="6-31g") 31 | H = Hamiltonian(basis, xc_method="lda") 32 | E, C, sol = minimise(H) 33 | E 34 | ``` 35 | 36 | ## License 37 | 38 | The reader is encouraged to fork, edit, remix, and use the contents however they find 39 | most useful. All content is covered by the permissive [MIT License](./LICENSE) to 40 | encourage this. Our aim is to encourage a truly interdisciplinary approach to accelerate 41 | our understanding of molecular scale processes. 42 | 43 | ## Installing 44 | 45 | We recommend installing directly from the main branch on GitHub and sharing any 46 | feedback as [issues](https://github.com/valence-labs/mess/issues). 47 | 48 | ``` 49 | pip install git+https://github.com/valence-labs/mess.git 50 | ``` 51 | 52 | Requires Python 3.11+. We recommend [installing JAX](https://jax.readthedocs.io/en/latest/installation.html) for your target system (e.g. CPU, GPU, etc.).
53 | 54 | 55 | ## Citation 56 | If you found this library to be useful in academic work, then please cite our 57 | [arXiv paper](https://arxiv.org/abs/2406.03121) 58 | ``` 59 | @misc{helal2024mess, 60 | title={MESS: Modern Electronic Structure Simulations}, 61 | author={Hatem Helal and Andrew Fitzgibbon}, 62 | year={2024}, 63 | eprint={2406.03121}, 64 | archivePrefix={arXiv}, 65 | primaryClass={cs.LG} 66 | } 67 | ``` 68 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | # Book settings 2 | # Learn more at https://jupyterbook.org/customize/config.html 3 | 4 | title: "MESS" 5 | author: "MESS Authors" 6 | copyright: "2024" 7 | 8 | # Force re-execution of notebooks on each build. 9 | # See https://jupyterbook.org/content/execute.html 10 | execute: 11 | execute_notebooks: force 12 | exclude_patterns: 13 | - 'tour.ipynb' 14 | - 'batching.ipynb' 15 | 16 | # Define the name of the latex output file for PDF builds 17 | latex: 18 | latex_engine: xelatex 19 | latex_documents: 20 | targetname: book.tex 21 | 22 | parse: 23 | myst_enable_extensions: 24 | - amsmath 25 | - dollarmath 26 | - colon_fence 27 | 28 | # Add a bibtex file so that we can create citations 29 | bibtex_bibfiles: 30 | - references.bib 31 | 32 | # Information about where the book exists on the web 33 | repository: 34 | url: https://github.com/valence-labs/mess 35 | path_to_book: docs 36 | branch: main 37 | 38 | # Add GitHub buttons to your book 39 | # See https://jupyterbook.org/customize/config.html#add-a-link-to-your-repository 40 | html: 41 | use_issues_button: true 42 | use_repository_button: true 43 | 44 | launch_buttons: 45 | colab_url: "https://colab.research.google.com" 46 | 47 | sphinx: 48 | extra_extensions: 49 | - "sphinx.ext.autodoc" 50 | - "sphinx.ext.autosummary" 51 | - "sphinx.ext.mathjax" 52 | - "sphinx.ext.napoleon" 53 | - "sphinx.ext.viewcode" 54 | - 
"sphinx_design" 55 | config: 56 | add_module_names: False 57 | autosummary_generate: True 58 | autodoc_typehints: "description" 59 | autodoc_class_signature: "separated" 60 | templates_path: "_templates" 61 | html_theme: sphinx_book_theme 62 | html_theme_options: 63 | navigation_with_keys: False 64 | repository_url: https://github.com/valence-labs/mess 65 | repository_branch: main 66 | path_to_docs: docs 67 | launch_buttons: 68 | colab_url: https://colab.research.google.com 69 | html_show_copyright: False 70 | html_static_path: ["_static"] 71 | html_css_files: ["style.css"] 72 | mathjax_path: https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 73 | -------------------------------------------------------------------------------- /docs/_static/style.css: -------------------------------------------------------------------------------- 1 | table.autosummary.longtable.table.autosummary { 2 | --bs-table-bg: initial; 3 | } 4 | -------------------------------------------------------------------------------- /docs/_templates/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :members: 7 | 8 | {% block methods %} 9 | .. automethod:: __init__ 10 | 11 | {% if methods %} 12 | .. rubric:: {{ _('Methods') }} 13 | 14 | .. autosummary:: 15 | {% for item in methods %} 16 | ~{{ name }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block attributes %} 22 | {% if attributes %} 23 | .. rubric:: {{ _('Attributes') }} 24 | 25 | .. 
autosummary:: 26 | {% for item in attributes %} 27 | ~{{ name }}.{{ item }} 28 | {%- endfor %} 29 | {% endif %} 30 | {% endblock %} 31 | -------------------------------------------------------------------------------- /docs/_templates/module.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | 5 | {% block attributes %} 6 | {% if attributes %} 7 | .. rubric:: {{ _('Module Attributes') }} 8 | 9 | .. autosummary:: 10 | :toctree: 11 | {% for item in attributes %} 12 | {{ item }} 13 | {%- endfor %} 14 | {% endif %} 15 | {% endblock %} 16 | 17 | {% block functions %} 18 | {% if functions %} 19 | .. rubric:: {{ _('Functions') }} 20 | 21 | .. autosummary:: 22 | :toctree: 23 | {% for item in functions %} 24 | {{ item }} 25 | {%- endfor %} 26 | {% endif %} 27 | {% endblock %} 28 | 29 | {% block classes %} 30 | {% if classes %} 31 | .. rubric:: {{ _('Classes') }} 32 | 33 | .. autosummary:: 34 | :toctree: 35 | :template: class.rst 36 | {% for item in classes %} 37 | {{ item }} 38 | {%- endfor %} 39 | {% endif %} 40 | {% endblock %} 41 | 42 | {% block exceptions %} 43 | {% if exceptions %} 44 | .. rubric:: {{ _('Exceptions') }} 45 | 46 | .. autosummary:: 47 | :toctree: 48 | {% for item in exceptions %} 49 | {{ item }} 50 | {%- endfor %} 51 | {% endif %} 52 | {% endblock %} 53 | 54 | {% block modules %} 55 | {% if modules %} 56 | .. rubric:: Modules 57 | 58 | .. 
autosummary:: 59 | :toctree: 60 | :template: module.rst 61 | :recursive: 62 | {% for item in modules %} 63 | {{ item }} 64 | {%- endfor %} 65 | {% endif %} 66 | {% endblock %} 67 | -------------------------------------------------------------------------------- /docs/_toc.yml: -------------------------------------------------------------------------------- 1 | # Table of contents 2 | # Learn more at https://jupyterbook.org/customize/toc.html 3 | 4 | format: jb-book 5 | root: intro 6 | parts: 7 | - caption: Getting Started 8 | chapters: 9 | - file: tour 10 | - file: prologue 11 | - file: optim 12 | - file: batching 13 | - file: quirks 14 | - caption: Miscellaneous Things that Might be Useful 15 | chapters: 16 | - file: gto_integrals 17 | - file: gammanu 18 | - caption: API Documentation 19 | chapters: 20 | - file: api 21 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | MESS Package 2 | ================= 3 | .. 
autosummary:: 4 | :toctree: _autosummary 5 | :template: module.rst 6 | :recursive: 7 | 8 | mess 9 | -------------------------------------------------------------------------------- /docs/boys_function.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import matplotlib.pyplot as plt" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 12, 16 | "metadata": {}, 17 | "outputs": [ 18 | { 19 | "data": { 20 | "text/plain": [ 21 | "[]" 22 | ] 23 | }, 24 | "execution_count": 12, 25 | "metadata": {}, 26 | "output_type": "execute_result" 27 | }, 28 | { 29 | "data": { 30 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjUAAAGdCAYAAADqsoKGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAz7ElEQVR4nO3de3RUZYL3+19VQqpiIAlJJEUwCGJaYAik5RLCeESHvIaWM5oWx0DTA8NkQfcM0EBmFMNBULvnDa0LGxkYM5xp57JeMvAyB2k7h0l3DIrTQ5pLAkdxhLFpNSBUAGkSDOZa+/yBtWM14VKk9t65fD9r1bJ617N3PbXXbvPzuboMwzAEAADQy7mdrgAAAEAkEGoAAECfQKgBAAB9AqEGAAD0CYQaAADQJxBqAABAn0CoAQAAfQKhBgAA9AnRTlfALoFAQGfOnNGgQYPkcrmcrg4AALgFhmHo8uXLSktLk9t947aYfhNqzpw5o/T0dKerAQAAbsOpU6d011133bBMvwk1gwYNknT1psTHxztcGwAAcCsaGxuVnp5u/h2/kX4TaoJdTvHx8YQaAAB6mVsZOsJAYQAA0CcQagAAQJ9wW6Fmy5YtGjFihLxer7Kzs3Xw4MEblt+5c6dGjx4tr9erzMxM7dmzJ+Tz559/XqNHj1ZcXJwGDx6s3NxcHThwIKTMxYsXNW/ePMXHxysxMVGFhYX64osvbqf6AACgDwo71OzYsUNFRUVat26damtrNWHCBOXl5encuXNdlt+/f7/mzp2rwsJCHTlyRPn5+crPz9exY8fMMt/4xje0efNmvf/++/rVr36lESNG6JFHHtH58+fNMvPmzdMHH3ygyspKlZeX691339XixYtv4ycDAIC+yGUYhhHOCdnZ2Zo8ebI2b94s6er6L+np6Vq2bJmeffbZa8oXFBSoqalJ5eXl5rGpU6cqKytLpaWlXX5HY2OjEhIS9NZbb2nGjBn68MMPNXbsWB06dEiTJk2SJFVUVOjRRx/V6dOnlZaWdtN6B6/Z0NDAQGEAAHqJcP5+h9VS09raqpqaGuXm5nZewO1Wbm6uqquruzynuro6pLwk5eXlXbd8a2urtm7dqoSEBE2YMMG8RmJiohloJCk3N1dut/uabqqglpYWNTY2hrwAAEDfFVaouXDhgjo6OpSamhpyPDU1VX6/v8tz/H7/L
ZUvLy/XwIED5fV69ZOf/ESVlZVKSUkxrzFkyJCQ8tHR0UpKSrru95aUlCghIcF8sfAeAAB9W4+Z/fTwww/r6NGj2r9/v2bOnKmnnnrquuN0bkVxcbEaGhrM16lTpyJYWwAA0NOEFWpSUlIUFRWl+vr6kOP19fXy+XxdnuPz+W6pfFxcnO69915NnTpVP/3pTxUdHa2f/vSn5jV+P+C0t7fr4sWL1/1ej8djLrTHgnsAAPR9YYWamJgYTZw4UVVVVeaxQCCgqqoq5eTkdHlOTk5OSHlJqqysvG75r1+3paXFvMalS5dUU1Njfr53714FAgFlZ2eH8xMAAEAfFfY2CUVFRVqwYIEmTZqkKVOmaOPGjWpqatLChQslSfPnz9ewYcNUUlIiSVq+fLmmT5+uDRs2aNasWdq+fbsOHz6srVu3SpKampr0N3/zN3rsscc0dOhQXbhwQVu2bNFnn32mP/mTP5EkjRkzRjNnztSiRYtUWlqqtrY2LV26VHPmzLmlmU8AAKDvCzvUFBQU6Pz581q7dq38fr+ysrJUUVFhDgauq6sL2Rp82rRpKisr05o1a7R69WplZGRo9+7dGjdunCQpKipKx48f1z//8z/rwoULSk5O1uTJk/Uf//Ef+oM/+APzOtu2bdPSpUs1Y8YMud1uzZ49W5s2beru7wcAAH1E2OvU9FZWrVPzUf1l/evBUxoS79H3p4+K2HUBAICF69TgWmcamvX6f36sN4+ecboqAAD0a4SabvJGX72Fze0dDtcEAID+jVDTTd4BUZKk5lZCDQAATiLUdJMZatoDDtcEAID+jVDTTbHBUNNGSw0AAE4i1HSTd8BXY2raOtRPJpIBANAjEWq6yfNVS03AkFo76IICAMAphJpuCrbUSFJzG6EGAACnEGq6KSbKLbfr6vsWxtUAAOAYQk03uVyuzhlQtNQAAOAYQk0EBEPNl7TUAADgGEJNBJirChNqAABwDKEmArwxrFUDAIDTCDUR4I1mVWEAAJxGqImA4LTuL9n/CQAAxxBqIiA4ULiFnboBAHAMoSYC2P8JAADnEWoigHVqAABwHqEmAjzBMTW01AAA4BhCTQR46X4CAMBxhJoIiKX7CQAAxxFqIiA4pZuWGgAAnEOoiQBz8T1CDQAAjiHURABjagAAcB6hJgI6935iTA0AAE4h1ESAuUs3KwoDAOAYQk0EBLuf2PsJAADnEGoiwBxTwy7dAAA4hlATAcF1aloYKAwAgGMINRHAOjUAADiPUBMB5pgaQg0AAI4h1ERAZ0sNY2oAAHAKoSYCWHwPAADnEWoiIBhqWtoDCgQMh2sDAED/RKiJgGCoka4GGwAAYD9CTQQEVxSW6IICAMAphJoIiI5ya0CUSxJbJQAA4BRCTYR4o9nUEgAAJxFqIsTD/k8AADiKUBMh5lo1dD8BAOAIQk2ExLJWDQAAjiLURIi5Vg1jagAAcAShJkKC3U/s/wQAgDMINRHCVgkAADiLUBMhnaGG7icAAJxAqIkQWmoAAHDWbYWaLVu2aMSIEfJ6vcrOztbBgwdvWH7nzp0aPXq0vF6vMjMztWfPHvOztrY2rVq1SpmZmYqLi1NaWprmz5+vM2fOhFxjxIgRcrlcIa/169ffTvUtEdwqgTE1AAA4I+xQs2PHDhUVFWndunWqra3VhAkTlJeXp3PnznVZfv/+/Zo7d64KCwt15MgR5efnKz8/X8eOHZMkXblyRbW1tXruuedUW1urXbt26cSJE3rssceuudaLL76os2fPmq9ly5aFW33LdM5+ItQAAOAEl2EYRjgnZGdna/Lkydq8ebMkKRAIKD09XcuWLdOzzz57TfmCggI1NTWpvLzcPDZ16lRlZWWptLS0y+84dOiQpkyZok8//VTDhw+XdLWlZsWKFVqxYkU41TU1NjYqISFBDQ0Nio+Pv61r3Mj/3POhtr77Wy1+8B6tfnRMxK8PAEB/FM7f77BaalpbW1VTU6Pc3NzOC7jdys3NVXV1d
ZfnVFdXh5SXpLy8vOuWl6SGhga5XC4lJiaGHF+/fr2Sk5P1zW9+Uy+//LLa29uve42WlhY1NjaGvKwU7H5iTA0AAM6IDqfwhQsX1NHRodTU1JDjqampOn78eJfn+P3+Lsv7/f4uyzc3N2vVqlWaO3duSCL7wQ9+oPvvv19JSUnav3+/iouLdfbsWb3yyitdXqekpEQvvPBCOD+vW9j7CQAAZ4UVaqzW1tamp556SoZh6LXXXgv5rKioyHw/fvx4xcTE6Hvf+55KSkrk8XiuuVZxcXHIOY2NjUpPT7es7ubsp3amdAMA4ISwQk1KSoqioqJUX18fcry+vl4+n6/Lc3w+3y2VDwaaTz/9VHv37r1pv1l2drba29v1ySef6L777rvmc4/H02XYsQp7PwEA4KywxtTExMRo4sSJqqqqMo8FAgFVVVUpJyeny3NycnJCyktSZWVlSPlgoPnoo4/01ltvKTk5+aZ1OXr0qNxut4YMGRLOT7CMuUs3oQYAAEeE3f1UVFSkBQsWaNKkSZoyZYo2btyopqYmLVy4UJI0f/58DRs2TCUlJZKk5cuXa/r06dqwYYNmzZql7du36/Dhw9q6daukq4HmySefVG1trcrLy9XR0WGOt0lKSlJMTIyqq6t14MABPfzwwxo0aJCqq6u1cuVKffe739XgwYMjdS+6hcX3AABwVtihpqCgQOfPn9fatWvl9/uVlZWliooKczBwXV2d3O7OBqBp06aprKxMa9as0erVq5WRkaHdu3dr3LhxkqTPPvtMb775piQpKysr5LvefvttPfTQQ/J4PNq+fbuef/55tbS0aOTIkVq5cmXImBmndbbUMKYGAAAnhL1OTW9l9To1+09e0Hf+7wPKGDJQlUXTI359AAD6I8vWqcH1dc5+ovsJAAAnEGoixBsdXKeG7icAAJxAqImQ4Jga9n4CAMAZhJoIiY2h+wkAACcRaiIk2P3U1mGovYMuKAAA7EaoiZDgQGGJrRIAAHACoSZCPNGdt5IF+AAAsB+hJkLcbpcZbAg1AADYj1ATQZ1bJdD9BACA3Qg1EcSmlgAAOIdQE0GxbGoJAIBjCDURRPcTAADOIdREkIeWGgAAHEOoiSDvV7OfviTUAABgO0JNBJlbJRBqAACwHaEmgoJbJbCiMAAA9iPURBA7dQMA4BxCTQQFZz992UqoAQDAboSaCDKndLcTagAAsBuhJoJYpwYAAOcQaiKIbRIAAHAOoSaCzDE1hBoAAGxHqImg4N5PLXQ/AQBgO0JNBNH9BACAcwg1EcTsJwAAnEOoiSBPNOvUAADgFEJNBHXu/cSYGgAA7EaoiaDgLt10PwEAYD9CTQR5mf0EAIBjCDURxDo1AAA4h1ATQbHmNgmEGgAA7EaoiaCvr1NjGIbDtQEAoH8h1ESQ56uWmoAhtXUQagAAsBOhJoKCLTUS42oAALAboSaCYqLccruuvm8h1AAAYCtCTQS5XK7OrRKY1g0AgK0INRHG/k8AADiDUBNhwVWF2f8JAAB7EWoizBvDWjUAADiBUBNh3uhg9xNjagAAsBOhJsK+vgAfAACwD6EmwrxslQAAgCMINRHG/k8AADiDUBNhrFMDAIAzCDUR5mFMDQAAjiDURFiwpYa9nwAAsNdthZotW7ZoxIgR8nq9ys7O1sGDB29YfufOnRo9erS8Xq8yMzO1Z88e87O2tjatWrVKmZmZiouLU1pamubPn68zZ86EXOPixYuaN2+e4uPjlZiYqMLCQn3xxRe3U31LxdL9BACAI8IONTt27FBRUZHWrVun2tpaTZgwQXl5eTp37lyX5ffv36+5c+eqsLBQR44cUX5+vvLz83Xs2DFJ0pUrV1RbW6vnnntOtbW12rVrl06cOKHHHnss5Drz5s3TBx98oMrKSpWXl+vdd9/V4sWLb+MnW4sp3QAAOMNlGIYRzgnZ2dmaPHmyNm/eLEkKBAJKT0/XsmXL9Oyzz15TvqCgQE1NTSovLzePTZ06VVlZW
SotLe3yOw4dOqQpU6bo008/1fDhw/Xhhx9q7NixOnTokCZNmiRJqqio0KOPPqrTp08rLS3tpvVubGxUQkKCGhoaFB8fH85PDsvfVn2kDZX/rblT0lXyxHjLvgcAgP4gnL/fYbXUtLa2qqamRrm5uZ0XcLuVm5ur6urqLs+prq4OKS9JeXl51y0vSQ0NDXK5XEpMTDSvkZiYaAYaScrNzZXb7daBAwe6vEZLS4saGxtDXnYwx9Sw9xMAALYKK9RcuHBBHR0dSk1NDTmempoqv9/f5Tl+vz+s8s3NzVq1apXmzp1rJjK/368hQ4aElIuOjlZSUtJ1r1NSUqKEhATzlZ6efku/sbs6935iTA0AAHbqUbOf2tra9NRTT8kwDL322mvdulZxcbEaGhrM16lTpyJUyxsL7tLd3E5LDQAAdooOp3BKSoqioqJUX18fcry+vl4+n6/Lc3w+3y2VDwaaTz/9VHv37g3pN/P5fNcMRG5vb9fFixev+70ej0cej+eWf1uksE0CAADOCKulJiYmRhMnTlRVVZV5LBAIqKqqSjk5OV2ek5OTE1JekiorK0PKBwPNRx99pLfeekvJycnXXOPSpUuqqakxj+3du1eBQEDZ2dnh/ATLda5TQ/cTAAB2CqulRpKKioq0YMECTZo0SVOmTNHGjRvV1NSkhQsXSpLmz5+vYcOGqaSkRJK0fPlyTZ8+XRs2bNCsWbO0fft2HT58WFu3bpV0NdA8+eSTqq2tVXl5uTo6OsxxMklJSYqJidGYMWM0c+ZMLVq0SKWlpWpra9PSpUs1Z86cW5r5ZKfgOjUttNQAAGCrsENNQUGBzp8/r7Vr18rv9ysrK0sVFRXmYOC6ujq53Z0NQNOmTVNZWZnWrFmj1atXKyMjQ7t379a4ceMkSZ999pnefPNNSVJWVlbId7399tt66KGHJEnbtm3T0qVLNWPGDLndbs2ePVubNm26nd9sKdapAQDAGWGvU9Nb2bVOzbHPGvR//u2v5Iv36terZ1j2PQAA9AeWrVODmwu21LD3EwAA9iLURBiznwAAcAahJsKCoaalPaB+0rMHAECPQKiJsGCoka4GGwAAYA9CTYQFVxSW2P8JAAA7EWoiLDrKrQFRLklslQAAgJ0INRbwRrOpJQAAdiPUWMDDDCgAAGxHqLEAa9UAAGA/Qo0FYmmpAQDAdoQaC5hr1TCmBgAA2xBqLMCmlgAA2I9QY4FgSw1jagAAsA+hxgKd+z/R/QQAgF0INRZgU0sAAOxHqLFAcKsEVhQGAMA+hBoLmC017P0EAIBtCDUWiI35KtSwSzcAALYh1FjA7H5iTA0AALYh1FiAvZ8AALAfocYCnevU0P0EAIBdCDUWYO8nAADsR6ixANskAABgP0KNBdjQEgAA+xFqLBBsqWHvJwAA7EOosQDbJAAAYD9CjQXMUMM2CQAA2IZQYwFvNLt0AwBgN0KNBczZT+z9BACAbQg1Fujc+4lQAwCAXQg1Fgh2P7V1GOoIGA7XBgCA/oFQY4HgQGGJGVAAANiFUGMBT3TnbWWtGgAA7EGosYDb7TKDDS01AADYg1Bjkc4F+JjWDQCAHQg1FmFTSwAA7EWosQhbJQAAYC9CjUVi6X4CAMBWhBqLeGipAQDAVoQai3iDs59YVRgAAFsQaiwSHFPzJfs/AQBgC0KNRcwxNe2MqQEAwA6EGosEp3S3MKYGAABbEGoswpRuAADsRaixiDmmhlADAIAtCDUWYZsEAADsdVuhZsuWLRoxYoS8Xq+ys7N18ODBG5bfuXOnRo8eLa/Xq8zMTO3Zsyfk8127dumRRx5RcnKyXC6Xjh49es01HnroIblcrpDX97///dupvi3YJgEAAHuFHWp27NihoqIirVu3TrW1tZowYYLy8vJ07ty5Lsvv379fc+fOVWFhoY4cOaL8/Hzl5+fr2LFjZpmmpiY98MAD+vGPf3zD7160aJHOnj1rvl566aVwq28bWmoAALBX2KHmlVde0aJFi
7Rw4UKNHTtWpaWluuOOO/T66693Wf7VV1/VzJkz9fTTT2vMmDH64Q9/qPvvv1+bN282y/zpn/6p1q5dq9zc3Bt+9x133CGfz2e+4uPjw62+bczF92ipAQDAFmGFmtbWVtXU1ISED7fbrdzcXFVXV3d5TnV19TVhJS8v77rlb2Tbtm1KSUnRuHHjVFxcrCtXroR9DbvExjD7CQAAO0WHU/jChQvq6OhQampqyPHU1FQdP368y3P8fn+X5f1+f1gV/c53vqO7775baWlpeu+997Rq1SqdOHFCu3bt6rJ8S0uLWlpazP/d2NgY1vd1l9n9xDYJAADYIqxQ46TFixeb7zMzMzV06FDNmDFDJ0+e1KhRo64pX1JSohdeeMHOKobwRDOmBgAAO4XV/ZSSkqKoqCjV19eHHK+vr5fP5+vyHJ/PF1b5W5WdnS1J+s1vftPl58XFxWpoaDBfp06d6tb3hSs4+4m9nwAAsEdYoSYmJkYTJ05UVVWVeSwQCKiqqko5OTldnpOTkxNSXpIqKyuvW/5WBad9Dx06tMvPPR6P4uPjQ152iqX7CQAAW4Xd/VRUVKQFCxZo0qRJmjJlijZu3KimpiYtXLhQkjR//nwNGzZMJSUlkqTly5dr+vTp2rBhg2bNmqXt27fr8OHD2rp1q3nNixcvqq6uTmfOnJEknThxQpLMWU4nT55UWVmZHn30USUnJ+u9997TypUr9eCDD2r8+PHdvglWCI6paaH7CQAAW4QdagoKCnT+/HmtXbtWfr9fWVlZqqioMAcD19XVye3ubACaNm2aysrKtGbNGq1evVoZGRnavXu3xo0bZ5Z58803zVAkSXPmzJEkrVu3Ts8//7xiYmL01ltvmQEqPT1ds2fP1po1a277h1uNvZ8AALCXyzAMw+lK2KGxsVEJCQlqaGiwpSvq08+bNP3ld3RHTJT+68WZln8fAAB9UTh/v9n7ySKxX2up6Se5EQAARxFqLOL5KtQEDKmtg1ADAIDVCDUWCU7plpgBBQCAHQg1FomJcsvluvq+mbVqAACwHKHGIi6X62vjapjWDQCA1Qg1FmL/JwAA7EOosZA3+urtZa0aAACsR6ixULClhv2fAACwHqHGQp3dT4ypAQDAaoQaCwWnddP9BACA9Qg1FmL/JwAA7EOosRChBgAA+xBqLMQ6NQAA2IdQYyEPY2oAALANocZCXlpqAACwDaHGQt7or9apoaUGAADLEWosFBtD9xMAAHYh1Fgo2FLTwt5PAABYjlBjIcbUAABgH0KNhYIrCrP3EwAA1iPUWKhz7ydCDQAAViPUWIgVhQEAsA+hxkKMqQEAwD6EGguxSzcAAPYh1Fgolu4nAABsQ6ixEN1PAADYh1BjIbP7idlPAABYjlBjIU9w7yfWqQEAwHKEGgvFxgS3SQjIMAyHawMAQN9GqLFQcEyNdDXYAAAA6xBqLOSN7ry9zIACAMBahBoLRUe5Fe12SZK+JNQAAGApQo3FYpnWDQCALQg1FvOwAB8AALYg1FiMrRIAALAHocZiwRlQjKkBAMBahBqLBcfUtDCmBgAASxFqLEb3EwAA9iDUWMzc1JL9nwAAsBShxmLmmJpWup8AALASocZiXqZ0AwBgC0KNxYJbJdD9BACAtQg1FvOyojAAALYg1FgsNobuJwAA7ECosZjZ/USoAQDAUoQai7H3EwAA9iDUWIwxNQAA2OO2Qs2WLVs0YsQIeb1eZWdn6+DBgzcsv3PnTo0ePVper1eZmZnas2dPyOe7du3SI488ouTkZLlcLh09evSaazQ3N2vJkiVKTk7WwIEDNXv2bNXX199O9W0Vy95PAADYIuxQs2PHDhUVFWndunWqra3VhAkTlJeXp3PnznVZfv/+/Zo7d64KCwt15MgR5efnKz8/X8eOHTPLNDU16YEHHtCPf/zj637vypUr9fOf/1w7d+7Uvn37dObMGT3xxBPhVt92bJMAAIA9XIZhGOGckJ2drcmTJ2vz5s2SpEAgo
PT0dC1btkzPPvvsNeULCgrU1NSk8vJy89jUqVOVlZWl0tLSkLKffPKJRo4cqSNHjigrK8s83tDQoDvvvFNlZWV68sknJUnHjx/XmDFjVF1dralTp9603o2NjUpISFBDQ4Pi4+PD+cndsuf9s/rLbbWaMiJJ//v7ObZ9LwAAfUE4f7/DaqlpbW1VTU2NcnNzOy/gdis3N1fV1dVdnlNdXR1SXpLy8vKuW74rNTU1amtrC7nO6NGjNXz48Otep6WlRY2NjSEvJ5gtNSy+BwCApcIKNRcuXFBHR4dSU1NDjqempsrv93d5jt/vD6v89a4RExOjxMTEW75OSUmJEhISzFd6evotf18kde79RKgBAMBKfXb2U3FxsRoaGszXqVOnHKkHu3QDAGCP6HAKp6SkKCoq6ppZR/X19fL5fF2e4/P5wip/vWu0trbq0qVLIa01N7qOx+ORx+O55e+wijeaKd0AANghrJaamJgYTZw4UVVVVeaxQCCgqqoq5eR0PQg2JycnpLwkVVZWXrd8VyZOnKgBAwaEXOfEiROqq6sL6zpOYPYTAAD2CKulRpKKioq0YMECTZo0SVOmTNHGjRvV1NSkhQsXSpLmz5+vYcOGqaSkRJK0fPlyTZ8+XRs2bNCsWbO0fft2HT58WFu3bjWvefHiRdXV1enMmTOSrgYW6WoLjc/nU0JCggoLC1VUVKSkpCTFx8dr2bJlysnJuaWZT05i7ycAAOwRdqgpKCjQ+fPntXbtWvn9fmVlZamiosIcDFxXVye3u7MBaNq0aSorK9OaNWu0evVqZWRkaPfu3Ro3bpxZ5s033zRDkSTNmTNHkrRu3To9//zzkqSf/OQncrvdmj17tlpaWpSXl6e/+7u/u60fbadg91Nbh6GOgKEot8vhGgEA0DeFvU5Nb+XUOjVftnZozNoKSdIHL+QpzhN2jgQAoN+ybJ0ahM8T3XmL6YICAMA6hBqLud0uM9iw/xMAANYh1NiAnboBALAeocYGTOsGAMB6hBobBFtqWlhVGAAAyxBqbBBr7v9E9xMAAFYh1NjAM4AF+AAAsBqhxgber2Y/saklAADWIdTYgNlPAABYj1BjA3NMDd1PAABYhlBjg+CU7hZCDQAAliHU2MDLQGEAACxHqLEBY2oAALAeocYGXsbUAABgOUKNDdgmAQAA6xFqbED3EwAA1iPU2IDF9wAAsB6hxgaxMV+11LQSagAAsAqhxgZm9xMtNQAAWIZQYwNPNGNqAACwGqHGBsx+AgDAeoQaG7D3EwAA1iPU2CA4pqaF7icAACxDqLEBez8BAGA9Qo0NGFMDAID1CDU2+PqYGsMwHK4NAAB9E6HGBvGxAyRJAUO6dKXN4doAANA3EWps4B0QpbQEryTptxe+cLg2AAD0TYQam9xz50BJ0snzTQ7XBACAvolQY5N77oyTJP2WUAMAgCUINTa5JyUYauh+AgDACoQamwS7n357gZYaAACsQKixSbD76dPPm9TewcrCAABEGqHGJmkJsfIOcKutw9Dp333pdHUAAOhzCDU2cbtdGpH81bgapnUDABBxhBobjQqOq2EGFAAAEUeosVFwXA1r1QAAEHmEGht1rlVD9xMAAJFGqLHRPSlM6wYAwCqEGhsFW2rOX27R5WY2tgQAIJIINTYa5B2gOwd5JDFYGACASCPU2MzcLoFp3QAARBShxmb3MK0bAABLEGpsNordugEAsAShxmada9XQ/QQAQCQRamwWnNb98YUmBQKGw7UBAKDvuK1Qs2XLFo0YMUJer1fZ2dk6ePDgDcvv3LlTo0ePltfrVWZmpvbs2RPyuWEYWrt2rYYOHarY2Fjl5ubqo48+CikzYsQIuVyukNf69etvp/qOumtwrAZEudTSHtBnl9jYEgCASAk71OzYsUNFRUVat26damtrNWHCBOXl5encuXNdlt+/f7/mzp2rwsJCHTlyRPn5+crPz9exY8fMMi+99JI2bdqk0tJSHThwQHFxccrLy1Nzc
3PItV588UWdPXvWfC1btizc6jsuOsqtu82NLRlXAwBApIQdal555RUtWrRICxcu1NixY1VaWqo77rhDr7/+epflX331Vc2cOVNPP/20xowZox/+8Ie6//77tXnzZklXW2k2btyoNWvW6PHHH9f48eP1L//yLzpz5ox2794dcq1BgwbJ5/OZr7i4uPB/cQ9gTutmXA0AABETVqhpbW1VTU2NcnNzOy/gdis3N1fV1dVdnlNdXR1SXpLy8vLM8h9//LH8fn9ImYSEBGVnZ19zzfXr1ys5OVnf/OY39fLLL6u9vf26dW1paVFjY2PIq6dgWjcAAJEXHU7hCxcuqKOjQ6mpqSHHU1NTdfz48S7P8fv9XZb3+/3m58Fj1ysjST/4wQ90//33KykpSfv371dxcbHOnj2rV155pcvvLSkp0QsvvBDOz7ONubElC/ABABAxYYUaJxUVFZnvx48fr5iYGH3ve99TSUmJPB7PNeWLi4tDzmlsbFR6erotdb0Z1qoBACDywup+SklJUVRUlOrr60OO19fXy+fzdXmOz+e7YfngP8O5piRlZ2ervb1dn3zySZefezwexcfHh7x6iuC07rMNzbrSev0uNAAAcOvCCjUxMTGaOHGiqqqqzGOBQEBVVVXKycnp8pycnJyQ8pJUWVlplh85cqR8Pl9ImcbGRh04cOC615Sko0ePyu12a8iQIeH8hB5hcFyMBt8xQBKtNQAARErY3U9FRUVasGCBJk2apClTpmjjxo1qamrSwoULJUnz58/XsGHDVFJSIklavny5pk+frg0bNmjWrFnavn27Dh8+rK1bt0qSXC6XVqxYoR/96EfKyMjQyJEj9dxzzyktLU35+fmSrg42PnDggB5++GENGjRI1dXVWrlypb773e9q8ODBEboV9rrnzoGq+fR3+u2FJo0bluB0dQAA6PXCDjUFBQU6f/681q5dK7/fr6ysLFVUVJgDfevq6uR2dzYATZs2TWVlZVqzZo1Wr16tjIwM7d69W+PGjTPLPPPMM2pqatLixYt16dIlPfDAA6qoqJDX65V0tStp+/btev7559XS0qKRI0dq5cqVIWNmept7UuKuhhqmdQMAEBEuwzD6xVr9jY2NSkhIUENDQ48YX/PaOyf144rjemxCmjbN/abT1QEAoEcK5+83ez85hGndAABEFqHGIcFp3R+fb1I/aSwDAMBShBqHDE+KU5TbpabWDtU3tjhdHQAAej1CjUNiot1KHxwriT2gAACIBEKNg4J7QJ1kt24AALqNUOMgdusGACByCDUOGjWE3boBAIgUQo2DzJYapnUDANBthBoHBcfUnP7dl2pu63C4NgAA9G6EGgelDIzRIG+0DEP69PMrTlcHAIBejVDjIJfLZbbWMFgYAIDuIdQ4bJQ5robBwgAAdAehxmHBPaBOnqOlBgCA7iDUOIwF+AAAiAxCjcPM3brPf8HGlgAAdAOhxmEjkuPkckmXm9t14YtWp6sDAECvRahxmHdAlIYlsrElAADdRajpAcxp3YyrAQDgthFqegA2tgQAoPsINT3AKHOwMC01AADcLkJND0D3EwAA3Ueo6QGC07rrLl5Ra3vA4doAANA7EWp6AF+8V3fERKkjYKjuIhtbAgBwOwg1PYDL5dJIBgsDANAthJoegnE1AAB0D6Gmh2BaNwAA3UOo6SHuYVo3AADdQqjpIUbR/QQAQLcQanqI4EDhi02tunSFjS0BAAgXoaaHiPNEmxtb/vKDeodrAwBA70Oo6UEWTLtbkvTjiuNquNLmcG0AAOhdCDU9yMI/HKmMIQP1eVOrXv7lcaerAwBAr0Ko6UEGRLn14uPjJEnbDtTp/dMNDtcIAIDeg1DTw+SMStbjWWkyDGnNz44pEDCcrhIAAL0CoaYH+r8eHaNBnmj9f6cuafuhU05XBwCAXoFQ0wMNifdq5f/4hiTppV8c18UmpngDAHAzhJoean7O3RrtG6RLV9r0UgWDhgEAuBlCTQ8VHeXWj
/KvDhrefuiUaut+53CNAADo2Qg1PdikEUl6cuJdkqS1PzumDgYNAwBwXYSaHu7Zb41WvDdaxz5r1LYDnzpdHQAAeixCTQ+XMtCjp/PukyS9/IsTuvBFi8M1AgCgZyLU9ALfyb5b44bF63Jzu0r2MGgYAICuEGp6gSi3Sz98fJxcLun/qT2tQ59cdLpKAAD0OISaXuKbwwdrzuR0SdJzu4+prSPgcI0AAOhZCDW9yDN5o5V4xwAd91/WQy+/o3/4j9+qsZndvAEAkAg1vcrguBht+JMJSoqL0WeXvtSP/t8PNa1kr178+X/p1MUrTlcPAABH3Vao2bJli0aMGCGv16vs7GwdPHjwhuV37typ0aNHy+v1KjMzU3v27An53DAMrV27VkOHDlVsbKxyc3P10UcfhZS5ePGi5s2bp/j4eCUmJqqwsFBffPHF7VS/V5sxJlX7n/0jrX8iUxlDBuqLlna9/p8fa/rLb+sv/leNDn9yUYbBejYAgP4n7FCzY8cOFRUVad26daqtrdWECROUl5enc+fOdVl+//79mjt3rgoLC3XkyBHl5+crPz9fx44dM8u89NJL2rRpk0pLS3XgwAHFxcUpLy9Pzc3NZpl58+bpgw8+UGVlpcrLy/Xuu+9q8eLFt/GTez/vgCjNmTJcv1z5oP75z6fo/8hIUcCQ/v2YX0+WVit/y3/qZ0c/k7+hmV2+AQD9hssI8z/rs7OzNXnyZG3evFmSFAgElJ6ermXLlunZZ5+9pnxBQYGamppUXl5uHps6daqysrJUWloqwzCUlpamv/qrv9Jf//VfS5IaGhqUmpqqf/qnf9KcOXP04YcfauzYsTp06JAmTZokSaqoqNCjjz6q06dPKy0t7ab1bmxsVEJCghoaGhQfHx/OT+4V/rv+sl7/1cfadeQztbZ3DiKOiXIrLdGruwbfobsGx371uvp+2OBYJcbGyBPtltvtcrD2AAB0LZy/39HhXLi1tVU1NTUqLi42j7ndbuXm5qq6urrLc6qrq1VUVBRyLC8vT7t375Ykffzxx/L7/crNzTU/T0hIUHZ2tqqrqzVnzhxVV1crMTHRDDSSlJubK7fbrQMHDujb3/72Nd/b0tKilpbOheoaGxvD+am9zjdSB2n97PH667z7tO3Xddp99DPVXbyi1o6APvn8ij75/MZjbjzRbnkHRCl2QJS8A66+9371fkCUW1Ful6JcLrm/+meUO/hecrtdcrtccklyuSSXXFf/6ZIUfP+1z4JcX8tRX49ULpe9AcvmrwOAPmvUnQP13al3O/b9YYWaCxcuqKOjQ6mpqSHHU1NTdfx414vC+f3+Lsv7/X7z8+CxG5UZMmRIaMWjo5WUlGSW+X0lJSV64YUXbvGX9R0pAz1anpuh5bkZau8IqP5yi05fvKLTv/vyq9dX7y9d0ZlLzeZ+Ui3tAbW0B9TwJbOpAAC358Fv3Nl7Qk1vUlxcHNJC1NjYqPT0dAdrZL/oKLeGJcZqWGKssrv4vCNgqLmtQ1+2dajZfAXM//1l69XP2joMBQxDgYChjuA/A4Y6DHUeMwwFOzKNr94H+zWvvu/8XOr8zCzQ1fHruJUOU+OWrgQAiKQRyXGOfn9YoSYlJUVRUVGqr68POV5fXy+fz9flOT6f74blg/+sr6/X0KFDQ8pkZWWZZX5/IHJ7e7suXrx43e/1eDzyeDy3/uP6oSi3S3GeaMV5+my2BQD0I2HNfoqJidHEiRNVVVVlHgsEAqqqqlJOTk6X5+Tk5ISUl6TKykqz/MiRI+Xz+ULKNDY26sCBA2aZnJwcXbp0STU1NWaZvXv3KhAIKDu7qzYIAADQ34T9n+hFRUVasGCBJk2apClTpmjjxo1qamrSwoULJUnz58/XsGHDVFJSIklavny5pk+frg0bNmjWrFnavn27Dh8+rK1bt0q6Oih0xYoV+tGPfqSMjAyNHDlSzz33nNLS0pSfny9JGjNmjGbOnKlFixaptLRUbW1tWrp0qebMmXNLM
58AAEDfF3aoKSgo0Pnz57V27Vr5/X5lZWWpoqLCHOhbV1cnt7uzAWjatGkqKyvTmjVrtHr1amVkZGj37t0aN26cWeaZZ55RU1OTFi9erEuXLumBBx5QRUWFvF6vWWbbtm1aunSpZsyYIbfbrdmzZ2vTpk3d+e0AAKAPCXudmt6qr69TAwBAXxTO32/2fgIAAH0CoQYAAPQJhBoAANAnEGoAAECfQKgBAAB9AqEGAAD0CYQaAADQJxBqAABAn0CoAQAAfUK/2Z45uHByY2OjwzUBAAC3Kvh3+1Y2QOg3oeby5cuSpPT0dIdrAgAAwnX58mUlJCTcsEy/2fspEAjozJkzGjRokFwul9PV6VEaGxuVnp6uU6dOsS/WbeIedg/3r/u4h93D/es+q+6hYRi6fPmy0tLSQjbM7kq/aalxu9266667nK5GjxYfH8//mbuJe9g93L/u4x52D/ev+6y4hzdroQlioDAAAOgTCDUAAKBPINRAHo9H69atk8fjcboqvRb3sHu4f93HPewe7l/39YR72G8GCgMAgL6NlhoAANAnEGoAAECfQKgBAAB9AqEGAAD0CYSafuTdd9/VH//xHystLU0ul0u7d+8O+fzP/uzP5HK5Ql4zZ850prI9UElJiSZPnqxBgwZpyJAhys/P14kTJ0LKNDc3a8mSJUpOTtbAgQM1e/Zs1dfXO1TjnuVW7t9DDz10zTP4/e9/36Ea9zyvvfaaxo8fby5ulpOTo3//9383P+f5u7Gb3T+ev/CtX79eLpdLK1asMI85+RwSavqRpqYmTZgwQVu2bLlumZkzZ+rs2bPm61//9V9trGHPtm/fPi1ZskS//vWvVVlZqba2Nj3yyCNqamoyy6xcuVI///nPtXPnTu3bt09nzpzRE0884WCte45buX+StGjRopBn8KWXXnKoxj3PXXfdpfXr16umpkaHDx/WH/3RH+nxxx/XBx98IInn72Zudv8knr9wHDp0SH//93+v8ePHhxx39Dk00C9JMt54442QYwsWLDAef/xxR+rTG507d86QZOzbt88wDMO4dOmSMWDAAGPnzp1mmQ8//NCQZFRXVztVzR7r9++fYRjG9OnTjeXLlztXqV5o8ODBxj/8wz/w/N2m4P0zDJ6/cFy+fNnIyMgwKisrQ+6b088hLTUI8c4772jIkCG677779Bd/8Rf6/PPPna5Sj9XQ0CBJSkpKkiTV1NSora1Nubm5ZpnRo0dr+PDhqq6udqSOPdnv37+gbdu2KSUlRePGjVNxcbGuXLniRPV6vI6ODm3fvl1NTU3Kycnh+QvT79+/IJ6/W7NkyRLNmjUr5HmTnP/3YL/Z0BI3N3PmTD3xxBMaOXKkTp48qdWrV+tb3/qWqqurFRUV5XT1epRAIKAVK1boD//wDzVu3DhJkt/vV0xMjBITE0PKpqamyu/3O1DLnqur+ydJ3/nOd3T33XcrLS1N7733nlatWqUTJ05o165dDta2Z3n//feVk5Oj5uZmDRw4UG+88YbGjh2ro0eP8vzdguvdP4nn71Zt375dtbW1OnTo0DWfOf3vQUINTHPmzDHfZ2Zmavz48Ro1apTeeecdzZgxw8Ga9TxLlizRsWPH9Ktf/crpqvRK17t/ixcvNt9nZmZq6NChmjFjhk6ePKlRo0bZXc0e6b777tPRo0fV0NCgf/u3f9OCBQu0b98+p6vVa1zv/o0dO5bn7xacOnVKy5cvV2Vlpbxer9PVuQbdT7iue+65RykpKfrNb37jdFV6lKVLl6q8vFxvv/227rrrLvO4z+dTa2urLl26FFK+vr5ePp/P5lr2XNe7f13Jzs6WJJ7Br4mJidG9996riRMnqqSkRBMmTNCrr77K83eLrnf/usLzd62amhqdO3dO999/v6KjoxUdHa19+/Zp06ZNio6OVmpqqqPPIaEG13X69Gl9/vnnGjp0qNNV6REMw9DSpUv1xhtvaO/evRo5cmTI5xMnTtSAAQNUVVVlHjtx4oTq6
upC+uz7q5vdv64cPXpUkngGbyAQCKilpYXn7zYF719XeP6uNWPGDL3//vs6evSo+Zo0aZLmzZtnvnfyOaT7qR/54osvQv6L4+OPP9bRo0eVlJSkpKQkvfDCC5o9e7Z8Pp9OnjypZ555Rvfee6/y8vIcrHXPsWTJEpWVlelnP/uZBg0aZPYPJyQkKDY2VgkJCSosLFRRUZGSkpIUHx+vZcuWKScnR1OnTnW49s672f07efKkysrK9Oijjyo5OVnvvfeeVq5cqQcffPCaKaP9VXFxsb71rW9p+PDhunz5ssrKyvTOO+/oF7/4Bc/fLbjR/eP5uzWDBg0KGQcnSXFxcUpOTjaPO/ocWj6/Cj3G22+/bUi65rVgwQLjypUrxiOPPGLceeedxoABA4y7777bWLRokeH3+52udo/R1b2TZPzjP/6jWebLL780/vIv/9IYPHiwcccddxjf/va3jbNnzzpX6R7kZvevrq7OePDBB42kpCTD4/EY9957r/H0008bDQ0Nzla8B/nzP/9z4+677zZiYmKMO++805gxY4bxy1/+0vyc5+/GbnT/eP5u3+9PhXfyOXQZhmFYH50AAACsxZgaAADQJxBqAABAn0CoAQAAfQKhBgAA9AmEGgAA0CcQagAAQJ9AqAEAAH0CoQYAAPQJhBoAANAnEGoAAECfQKgBAAB9AqEGAAD0Cf8/eqWvSW7ZwTMAAAAASUVORK5CYII=", 31 | "text/plain": [ 32 | "
" 33 | ] 34 | }, 35 | "metadata": {}, 36 | "output_type": "display_data" 37 | } 38 | ], 39 | "source": [ 40 | "def factorial2(n):\n", 41 | " from scipy.special import factorial2\n", 42 | "\n", 43 | " return factorial2(np.maximum(n, 0))\n", 44 | "\n", 45 | "\n", 46 | "def F(n, x):\n", 47 | " from scipy.special import gammainc, gamma\n", 48 | "\n", 49 | " xs = np.where(x == 0, 1.0, x)\n", 50 | " Fn = gammainc(n + 0.5, xs) * gamma(n + 0.5) / (2 * xs ** (n + 0.5))\n", 51 | " F0 = 1 / (2 * n + 1)\n", 52 | " return np.where(x == 0, F0, Fn)\n", 53 | "\n", 54 | "\n", 55 | "def Fc(n, x):\n", 56 | " from scipy.special import gammaincc, gamma\n", 57 | "\n", 58 | " xs = np.where(x == 0, 1.0, x)\n", 59 | " Fn = (gamma(n + 0.5) - gamma(n + 0.5) * gammaincc(n + 0.5, xs)) / (\n", 60 | " 2 * xs ** (n + 0.5)\n", 61 | " )\n", 62 | " F0 = 1 / (2 * n + 1)\n", 63 | " return np.where(x == 0, F0, Fn)\n", 64 | "\n", 65 | "\n", 66 | "def Fasymp(n, x):\n", 67 | " return (\n", 68 | " factorial2(np.maximum(2 * n - 1, 1))\n", 69 | " / 2 ** (n + 1)\n", 70 | " * np.sqrt(np.pi / x ** (2 * n + 1))\n", 71 | " )\n", 72 | "\n", 73 | "\n", 74 | "def F0(x):\n", 75 | " from scipy.special import erf\n", 76 | "\n", 77 | " xs = np.where(x == 0, 1.0, x)\n", 78 | " return np.where(x == 0, 1.0, np.sqrt(np.pi / xs) * erf(np.sqrt(xs)) / 2)\n", 79 | "\n", 80 | "\n", 81 | "n = 30\n", 82 | "x = np.linspace(12, 40)\n", 83 | "# plt.plot(x, F(n, x))\n", 84 | "# plt.plot(x, Fc(n, x))\n", 85 | "plt.plot(x, np.abs(Fasymp(n, x) - F(n, x)))\n", 86 | "# plt.plot(x, F0(x))" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [] 95 | } 96 | ], 97 | "metadata": { 98 | "kernelspec": { 99 | "display_name": "mess", 100 | "language": "python", 101 | "name": "python3" 102 | }, 103 | "language_info": { 104 | "codemirror_mode": { 105 | "name": "ipython", 106 | "version": 3 107 | }, 108 | "file_extension": ".py", 109 | "mimetype": "text/x-python", 110 | 
"name": "python", 111 | "nbconvert_exporter": "python", 112 | "pygments_lexer": "ipython3", 113 | "version": "3.9.18" 114 | } 115 | }, 116 | "nbformat": 4, 117 | "nbformat_minor": 2 118 | } 119 | -------------------------------------------------------------------------------- /docs/gto_integrals.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import sympy as sp\n", 11 | "import IPython.display as ipd\n", 12 | "import py3Dmol\n", 13 | "\n", 14 | "from mess.structure import Structure\n", 15 | "from mess.basis import basisset\n", 16 | "from mess.mesh import uniform_mesh, molecular_orbitals\n", 17 | "from mess.plot import plot_volume, plot_isosurfaces, plot_molecule\n", 18 | "from itertools import product" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "$$\n", 26 | "\\def\\br{\\mathbf{r}}\n", 27 | "\\def\\bA{\\mathbf{A}}\n", 28 | "\\def\\normp{N}\n", 29 | "\\def\\pgto{p_{\\text{GTO}}}\n", 30 | "\\def\\op{\\mathcal{Q}}\n", 31 | "$$\n", 32 | "\n", 33 | "# Integrals over Gaussian Type Orbitals (GTO)\n", 34 | "\n", 35 | "We have a molecule comprising $M$ atoms (Mnemonic: $M$ for $\\text{ato}M$),\n", 36 | "each with atomic number $Z_m$ and position $\\br_m \\in \\mathbb{R}^3$\n", 37 | "\\begin{equation}\n", 38 | "\\mathcal{A} = \\{ (Z_m, \\br_m) \\}_{m=1}^{M}\n", 39 | "\\end{equation}\n", 40 | "Each atom in the molecule contributes both negatively charged electrons that are bound \n", 41 | "to a positively charged nucleus. Both the number of electrons and the charge on the\n", 42 | "nucleus can be found from the periodic table. For a concrete example, a single Oxygen\n", 43 | "atom has atomic number 8 and this implies a charge of $Z = 8$ located at the position\n", 44 | "$\\br$. 
This charge on the nucleus is balanced by eight negatively charged electrons.\n", 45 | "\n", 46 | "\n", 47 | "We will represent the electron density of the molecule from the eigenstates:\n", 48 | "\\begin{equation}\n", 49 | "\\rho(\\br) = \\sum_{m=1}^M f_m | \\psi_m(\\br) |^2\n", 50 | "\\end{equation}\n", 51 | "Atom $m$ has $I(Z_m)$ basis functions (determined by $Z_m$ and the basis set in use),\n", 52 | "with coefficients $C = [c_{mi}]_{i=1}^{I(Z_m)}$:\n", 53 | "\\begin{equation}\n", 54 | "\\psi_m(\\br) = \\sum_{i=1}^{I(Z_m)} c_{mi}~ \\phi_{Z_m,i}(\\br - \\br_m)\n", 55 | "\\end{equation}\n", 56 | "We can think of $C$ as a jagged array of coefficients, in terms of which the overall system is described by\n", 57 | "\\begin{equation}\n", 58 | "\\psi(\\br) = \\sum_{m=1}^M \\sum_{i=1}^{I(Z_m)} c_{mi} \\phi_{Z_m,i}(\\br - \\br_m),\n", 59 | "\\end{equation}\n", 60 | "and we can think of the task of DFT as being to determine the values $C$ which minimize\n", 61 | "total energy. That is: the task of DFT is to determine $\\psi$, and here $\\psi$ is \n", 62 | "specified by $C$, so the task is to determine $C$, for example by iteratively solving\n", 63 | "Kohn-Sham equations using the self-consistent field (SCF) method.\n", 64 | "\n", 65 | "\n", 67 | "\n", 68 | "Each function $\\phi_{Z,i}$ is defined by the basis set as a fixed linear combination \n", 69 | "of primitive functions $p(\\br; \\nu)$, known as a \"contraction\":\n", 70 | "\\begin{equation}\n", 71 | "\\forall_{i=1}^{I(Z)}: \\phi_{Z,i}(\\br) = \\sum_{k=1}^{K(Z,i)} d_{Z,i,k} ~ p_{f(Z,i,k)}(\\br; \\nu_{Z,i,k})\n", 72 | "\\end{equation}\n", 73 | "The values of $d,\\nu$ and the function type $f$ are read from standard tables of basis sets.\n", 74 | "In general the lengths $K(z,i)$ of the contractions have been pre-optimised alongside the contraction coefficients $d_{z,i,k}$ and the primitive exponents $\\alpha_\\mu$ to best approximate atomic orbitals. 
A range of Gaussian basis sets covering the periodic table with different compute-time-vs-accuracy tradeoffs are available from the [basis set exchange](https://www.basissetexchange.org/). We use their Python API in this project to provide programmatic access to these basis sets.\n", 75 | "As these functions $\\phi$ are often taken as approximations of atomic orbitals, \n", 76 | "this is referred to as the linear combination of atomic orbitals (LCAO) method in the literature.\n", 77 | "\n", 78 | "In this notebook, we will consider only the function type \"GTO\", for Gaussian-type Orbital,\n", 79 | "where the parameter packet $\\nu$ comprises three non-negative integers $(l,m,n)$, \n", 80 | "called \"angular momentum quantum numbers\",\n", 81 | "and a real value $\\alpha$, the \"exponent\".\n", 82 | "The primitive is\n", 83 | "\\begin{equation}\n", 84 | "\\pgto(\\br; \\nu) = \\pgto(\\br; l,m,n,\\alpha) = \\normp(l,m,n,\\alpha) ~ x^l y^m z^n \\exp(-\\alpha\\|\\br\\|^2)\n", 85 | "\\end{equation}\n", 86 | "where the normalizing constant $\\normp(l,m,n,\\alpha)$ is a function of $\\alpha$ and $(l, m, n)$\n", 87 | "and is chosen so that the square of the function integrates to 1, as derived later in this notebook.\n", 88 | "In the following, as we are dealing only with GTO, we will simply write $p$ instead of $\\pgto$.\n", 89 | "\n", 90 | "Noting that a linear combination (LC) of LCs is just another LC, we will often \n", 91 | "contract the two sets of coefficients, writing\n", 92 | "\\begin{equation}\n", 93 | "\\psi_m(\\br) = \\sum_{i=1}^{I(Z_m)} c_{mi} \\sum_{k=1}^{K(Z_m,i)} d_{Z_m,i,k} ~ p(\\br - \\br_m; \\nu_{Z_m,i,k})\n", 94 | "\\end{equation}\n", 95 | "as\n", 96 | "\\begin{equation}\n", 97 | "\\psi_m(\\br) = \\sum_{j=1}^{I_m} a_{mj}~ p(\\br - \\br_m; \\nu_{mj})\n", 98 | "\\end{equation}\n", 99 | "where $I_m$ is just the total number of different $\\nu$ values in the basis set for atom $Z_m$.\n", 100 | "For most basis sets, this will mean $I_m = \\sum_i K(Z_m,i)$.\n", 101 | "\n", 102 | "### 
This notebook\n", 103 | "This notebook derives several key computations involved in DFT.\n", 104 | "In particular, DFT involves integrals of the form\n", 105 | "\\begin{equation}\n", 106 | "\\langle\\psi|\\op|\\psi\\rangle = \\int \\int \\psi(\\br_1) ~\\op \\psi(\\br_2) ~g(\\br_1, \\br_2) d \\br_1 d \\br_2\n", 107 | "\\end{equation}\n", 108 | "where $\\op\\psi$ is a transformation of function $\\psi$ by an operator $\\op$, such as gradient $\\nabla$, and the function $g$ is some function of a pair of points, e.g. $g(\\mathbf r, \\mathbf s) = \\|\\mathbf r - \\mathbf s\\|^{-1}$.\n", 109 | "Such integrals can be written in terms of the per-atom (or \"per-center\") functions $\\psi_m$\n", 110 | "\\begin{equation}\n", 111 | "\\langle\\psi|\\op|\\psi\\rangle = \\sum_{m_1=1}^M \\sum_{m_2=1}^M \\int \\int \\psi_{m_1}(\\br_1) ~\\op \\psi_{m_2}(\\br_2) ~g(\\br_1, \\br_2) d \\br_1 d \\br_2\n", 112 | "\\end{equation}\n", 113 | "and then in terms of the primitives $p$, where again $\\op p$ is the operator $\\op$ applied to $p(\\br;...)$.\n", 114 | "\\begin{equation}\n", 115 | "\\langle\\psi|\\op|\\psi\\rangle = \\sum_{m_1=1}^M \\sum_{m_2=1}^M \n", 116 | " \\sum_{j_1=1}^{I_{m_1}} \\sum_{j_2=1}^{I_{m_2}} \n", 117 | " a_{m_1,j_1} a_{m_2,j_2} \n", 118 | " \\int \\int p(\\br_1 - \\br_{m_1}; \\nu_{m_1,j_1}) ~\\op p(\\br_2 - \\br_{m_2}; \\nu_{m_2,j_2}) ~ g(\\br_1, \\br_2) d \\br_1 d \\br_2\n", 119 | "\\end{equation}\n", 120 | "all of which depends on the ability to compute integrals of the form\n", 121 | "\\begin{equation}\n", 122 | "\\def\\bB{\\mathbf B}\n", 123 | "\\int \\int p(\\br_1 - \\bA; \\nu_A) ~\\op p(\\br_2 - \\bB; \\nu_B) ~g(\\br_1, \\br_2) ~d \\br_1 d \\br_2\n", 124 | "\\end{equation}\n", 125 | "We will describe how to do so in the following, but first let's briefly look at the basis functions for a simple example." 
126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "## Basis set for a single-atom molecule\n", 133 | "\n", 134 | "Before deriving the integrals, we use the `basisset` function to build the atomic orbitals of a single oxygen atom $M=1$, $Z_1 = 8$, illustrating typical values for the numbers of orbitals and consequent numbers of primitives.\n", 135 | "The `Basis` object built by `basisset` consists of a list of `Orbital` objects which are defined by a set of `coefficients` and corresponding `Primitive` objects.\n", 136 | "\n", 137 | "A basis set should define the following quantities:\n", 138 | "\n", 139 | "For each atomic number $Z$, the number of basis functions, or \"orbitals\", is $I(Z)$.\n", 140 | "\n", 141 | "For each basis function $i \\in \\{1..I(Z)\\}$, the number of primitives is $K(Z, i)$." 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "Z = 8\n", 151 | "basisname = \"sto-3g\"\n", 152 | "oxygen = Structure(atomic_number=np.array([Z]), position=np.zeros(3))\n", 153 | "basis = basisset(oxygen, basisname)\n", 154 | "print(f\"The {basisname} basis for {Z=} has I(Z)={basis.num_orbitals}\")" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "The number of atomic orbitals follows from the [Aufbau principle](https://en.wikipedia.org/wiki/Aufbau_principle). Applying this to a single Oxygen atom predicts the electron configuration $1s^2 2s^2 2p^4$. Keeping in mind that the $2p$ subshell consists of three orbitals, we arrive at five atomic orbitals for a single oxygen atom.\n", 162 | "\n", 163 | "This rule applies when using a minimal basis set such as `\"sto-3g\"` which uses a\n", 164 | "contraction of 3 Gaussians to approximate each atomic orbital. 
From this we expect a total of 15 primitives in the basis set for a single oxygen atom:" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "print(f\"{basisname} for {Z=} has {len(basis.orbitals)} basis functions/orbitals\")\n", 174 | "for i, o in enumerate(basis.orbitals):\n", 175 | " terms = (\n", 176 | " f\"{coef:.3f} * r**{p.lmn} * exp(-{p.alpha:.2f} r**2)\"\n", 177 | " for coef, p in zip(o.coefficients, o.primitives)\n", 178 | " )\n", 179 | " print(f\"phi_({Z},{i})=\", \" + \".join(terms))\n", 180 | "\n", 181 | "# Now count unique prims\n", 182 | "prims = set()\n", 183 | "for o in basis.orbitals:\n", 184 | " for p in o.primitives:\n", 185 | " prims = set.union(prims, {(*p.lmn.tolist(), float(p.alpha))})\n", 186 | "print(f\"Found {len(prims)} unique primitives for {Z=} in {basisname}\")" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": {}, 192 | "source": [ 193 | "For a basis set such as `6-31+G` we have a different number of primitives in each basis function: " 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "basisname = \"6-31+G\"\n", 203 | "oxygen = Structure(atomic_number=np.array([Z]), position=np.zeros(3))\n", 204 | "basis = basisset(oxygen, basisname)\n", 205 | "print(\n", 206 | " f\"The {basisname} basis for {Z=} has \"\n", 207 | " f\"I(Z)={basis.num_orbitals} basis functions/orbitals\"\n", 208 | ")\n", 209 | "\n", 210 | "for i, o in enumerate(basis.orbitals):\n", 211 | " terms = (\n", 212 | " f\"{coef:.3f} * r**{p.lmn} * exp(-{p.alpha:.2f} r**2)\"\n", 213 | " for coef, p in zip(o.coefficients, o.primitives)\n", 214 | " )\n", 215 | " print(f\"phi_({Z},{i})=\", \" + \".join(terms))\n", 216 | "\n", 217 | "# Now count unique prims\n", 218 | "prims = set()\n", 219 | "for o in basis.orbitals:\n", 220 | " for p in o.primitives:\n", 221 | 
" prims = set.union(prims, {(*p.lmn.tolist(), float(p.alpha))})\n", 222 | "print(f\"Found {len(prims)} unique primitives for {Z=} in {basisname}\")" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "We can plot these atomic orbitals by first evaluating the basis functions on a regular grid and using py3DMol to render isosurfaces. The $1s$ and $2s$ orbitals have spherical symmetry so we plot the $2px$ orbital below." 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "# Evaluate the molecular orbitals on a mesh -> [num_mesh_points, num_orbitals]\n", 239 | "mesh = uniform_mesh(n=32, b=3.0)\n", 240 | "orbitals = molecular_orbitals(basis, mesh.points)\n", 241 | "orbitals = np.asarray(orbitals)\n", 242 | "\n", 243 | "# Mapping between orbital label -> index in basis set\n", 244 | "# TODO: what is this mapping for 6-31+G?\n", 245 | "# orbital_idx = dict(zip([\"1s\", \"2s\", \"2px\", \"2py\", \"2pz\"], range(5)))\n", 246 | "\n", 247 | "view = py3Dmol.view()\n", 248 | "plot_molecule(view, oxygen)\n", 249 | "plot_volume(view, orbitals[:, 4], mesh.axes)\n", 250 | "view.zoomTo()" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "# Plot again, this time as isosurfaces\n", 260 | "view = py3Dmol.view()\n", 261 | "plot_molecule(view, oxygen)\n", 262 | "plot_isosurfaces(view, orbitals[:, 4], mesh.axes, percentiles=[98, 75])\n", 263 | "view.zoomTo()" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "$$\n", 271 | "\\def\\br{\\mathbf{r}}\n", 272 | "\\def\\bA{\\mathbf{A}}\n", 273 | "\\def\\normp{N}\n", 274 | "\\def\\pgto{p_{\\text{GTO}}}\n", 275 | "\\def\\op{\\mathcal{Q}}\n", 276 | "\\def\\ldp{h} % Read this \"1-D primitive\"\n", 277 | "\\def\\ldpi{H} % read this \"1-D primitive's 
integral\"\n", 278 | "$$\n", 279 | "\n", 280 | "## Integrals over a single primitive\n", 281 | "\n", 282 | "The GTO primitive defined as\n", 283 | "\\begin{equation}\n", 284 | "p(\\br; l,m,n,\\alpha) = \\normp(l,m,n,\\alpha) ~ x^l y^m z^n e^{-\\alpha\\|\\br\\|^2}\n", 285 | "\\end{equation}\n", 286 | "contains a normalization $\\normp$, which should be chosen so that $p^2$ integrates to 1:\n", 287 | "\\begin{equation}\n", 288 | "\\int p(\\br; l,m,n,\\alpha)^2 d\\br = 1\n", 289 | "\\end{equation}\n", 290 | "That is, \n", 291 | "\\begin{equation}\n", 292 | "\\normp(l,m,n,\\alpha)^2 \\int x^{2l} y^{2m} z^{2n} e^{-2\\alpha\\|\\br\\|^2} d\\br = 1\n", 293 | "\\end{equation}\n", 294 | "We note that $x^{2l} y^{2m} z^{2n} e^{-2\\alpha\\|\\br\\|^2}$ can be written\n", 295 | "\\begin{equation}\n", 296 | "x^{2l} e^{-2\\alpha x^2} \\cdot\n", 297 | "y^{2m} e^{-2\\alpha y^2} \\cdot\n", 298 | "z^{2n} e^{-2\\alpha z^2}\n", 299 | "= \n", 300 | "\\ldp(x;2l,2\\alpha) ~ \\ldp(y;2m,2\\alpha) ~ \\ldp(z;2n,2\\alpha)\n", 301 | "\\end{equation}\n", 302 | "for\n", 303 | "\\begin{equation}\n", 304 | "\\tag{defh}\n", 305 | "\\ldp(t; 2k, \\eta) = t^{2k} e^{-\\eta t^2}\n", 306 | "\\end{equation}\n", 307 | "So the integral\n", 308 | "\\begin{equation}\n", 309 | "\\int x^{2l} y^{2m} z^{2n} \\exp(-2\\alpha\\|\\br\\|^2) d\\br = \\int \\ldp(x;2l,2\\alpha) ~ \\ldp(y;2m,2\\alpha) ~ \\ldp(z;2n,2\\alpha) dx dy dz\n", 310 | "\\end{equation}\n", 311 | "is the product of three independent integrals\n", 312 | "\\begin{equation}\n", 313 | "\\int \\ldp(x;2l,2\\alpha) dx \\cdot \n", 314 | "\\int \\ldp(y;2m,2\\alpha) dy \\cdot\n", 315 | "\\int \\ldp(z;2n,2\\alpha) dz\n", 316 | "\\end{equation}\n", 317 | "Where each is of the form\n", 318 | "\\begin{equation}\n", 319 | "\\tag{defH}\n", 320 | "H(2k, \\eta) = \\int t^{2k} e^{-\\eta t^2} dt\n", 321 | "\\end{equation}\n", 322 | "These one dimensional integrals have a known analytic solution [see equation 42 of gaussian 
integral](https://mathworld.wolfram.com/GaussianIntegral.html), but we will lean on [SymPy](https://docs.sympy.org/latest/index.html) to help us derive a formula for normalising our primitive Gaussian functions." 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": {}, 328 | "source": [ 329 | "First, let's define the integral:" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "sp.init_printing(use_unicode=True)\n", 339 | "\n", 340 | "eta = sp.Symbol(\"eta\", positive=True, real=True)\n", 341 | "k = sp.Symbol(\"k\", integer=True, nonnegative=True)\n", 342 | "t = sp.Symbol(\"t\", real=True)\n", 343 | "H = sp.Function(\"H\")\n", 344 | "\n", 345 | "H_na = sp.Integral(t ** (2 * k) * sp.exp(-eta * t**2), (t, -sp.oo, sp.oo))\n", 346 | "ipd.display(sp.Eq(H(2 * k, eta), H_na))" 347 | ] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "metadata": {}, 352 | "source": [ 353 | "Confirm that the equation displayed is the same as that in (defH).\n", 354 | "\n", 355 | "And then solve it:" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "ipd.display(sp.Eq(H(2 * k, eta), H_na.doit().simplify()))" 365 | ] 366 | }, 367 | { 368 | "cell_type": "markdown", 369 | "metadata": {}, 370 | "source": [ 371 | "Recall that we are looking for the normalization $\\normp(l,m,n,\\alpha)$, which is given by\n", 372 | "\\begin{equation}\n", 373 | "\\normp(l,m,n,\\alpha) = \\left(\\int x^{2l} y^{2m} z^{2n} \\exp(-2\\alpha\\|\\br\\|^2) d\\br\\right)^{-\\frac12} = \\biggl(H(2l,2\\alpha)~H(2m,2\\alpha)~H(2n,2\\alpha)\\biggr)^{-\\frac12}\n", 374 | "\\end{equation}" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "N = sp.Function(\"N\")\n", 384 | "a = sp.Symbol(\"alpha\", positive=True, 
real=True)\n", 385 | "l, m, n = sp.symbols(\"l m n\", integer=True, nonnegative=True)\n", 386 | "\n", 387 | "H_l = H_na.subs({k: l, eta: 2 * a})\n", 388 | "H_m = H_na.subs({k: m, eta: 2 * a})\n", 389 | "H_n = H_na.subs({k: n, eta: 2 * a})\n", 390 | "\n", 391 | "N_val = H_l * H_m * H_n\n", 392 | "ipd.display(sp.Eq(N(l, m, n, a) ** (-2), N_val))\n", 393 | "ipd.display(sp.Eq(N(l, m, n, a) ** (-2), N_val.doit().simplify()))" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": {}, 399 | "source": [ 400 | "We might finally ask - what is $H$ for odd arguments? And we can see: the integrand is odd, so it is zero.\n", 401 | "And just in case, let's check: " 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": {}, 408 | "outputs": [], 409 | "source": [ 410 | "H_odd = sp.Integral(t ** (2 * k + 1) * sp.exp(-eta * t**2), (t, -sp.oo, sp.oo))\n", 411 | "ipd.display(sp.Eq(H_odd, H_odd.doit()))" 412 | ] 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "metadata": {}, 417 | "source": [ 418 | "We can use SymPy to generate a python function that uses the [gamma](https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.gamma.html) function from scipy" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": null, 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "f = sp.lambdify((l, m, n, a), N_val.doit().simplify(), modules=\"scipy\")\n", 428 | "help(f)" 429 | ] 430 | }, 431 | { 432 | "cell_type": "markdown", 433 | "metadata": {}, 434 | "source": [ 435 | "$$\n", 436 | "\\def\\br{\\mathbf{r}}\n", 437 | "\\def\\bA{\\mathbf{A}}\n", 438 | "$$\n", 439 | "\n", 440 | "## Integrals over sets of basis functions\n", 441 | "\n", 442 | "The above visualisation uses the formulas defined earlier to numerically evaluate\n", 443 | "each atomic orbital from a linear combintation of primitive Gaussians. 
Introducing\n", 444 | "the basis set allows us to replace the problem of solving a system of partial differential equations (e.g. the Kohn-Sham equations) with an algebraic system of equations that can be solved using standard matrix eigenvalue techniques.\n", 445 | "The individual elements of the matrices in this system are single integrals over pairs of\n", 446 | "primitives:\n", 447 | "\\begin{equation}\n", 448 | "M_{AB} = \\int p(\\br - \\bA; \\nu_A) ~\\op p(\\br - \\mathbf{B}; \\nu_B) d\\br.\n", 449 | "\\end{equation}\n", 450 | "The operator $\\op$ in its most general definition is\n", 451 | "a function that transforms the $\\psi(\\br)$ function to represent a physical observable.\n", 452 | "That sounds fancy but is just a complicated way of saying we are breaking\n", 453 | "down something we can measure (e.g. the energy, dipole moment, etc.) into components\n", 454 | "that we can evaluate through these integrals. \n", 455 | "\n", 456 | "## The overlap integral\n", 457 | "\n", 458 | "The simplest operator $\\op$ is simply an \n", 459 | "identity mapping, $\\op p(\\br) := p(\\br)$, used in conjunction with the trivial $g(\\br_1, \\br_2) = 1$, that we name the _overlap integral_:\n", 460 | "\\begin{equation}\n", 461 | "S_{\\mu\\nu} = \\int p(\\br - \\br_\\mu; l_\\mu, m_\\mu, n_\\mu, \\alpha_\\mu) ~ p(\\br - \\br_\\nu; l_\\nu, m_\\nu, n_\\nu, \\alpha_\\nu) ~ d\\br.\n", 462 | "\\end{equation}\n", 463 | "\n", 464 | "Gaussian basis functions of the form above have a few convenient properties that help\n", 465 | "make it easier to evaluate integrals of the form in equation (5). 
To show this\n", 466 | "we start by expanding the overlap integral into its full three-dimensional form:\n", 467 | "\\begin{equation}\n", 468 | "\\tilde{S}_{\\mu \\nu} = \\iiint p_\\mu(\\br) p_\\nu(\\br) dx dy dz,\n", 469 | "\\end{equation}\n", 470 | "and observing that a primitive $p_\\mu(\\br) = p(\\br - \\br_\\mu; l_\\mu, m_\\mu, n_\\mu, \\alpha_\\mu)$ can be written as a product of Cartesian components:\n", 471 | "\\begin{equation}\n", 472 | "p_\\mu(\\br) = N_\\mu\n", 473 | "h(x - x_\\mu; l_\\mu, \\alpha_\\mu)\n", 474 | "h(y - y_\\mu; m_\\mu, \\alpha_\\mu)\n", 475 | "h(z - z_\\mu; n_\\mu, \\alpha_\\mu)\n", 476 | "\\end{equation}\n", 477 | "where we are using $h$ from (defh), but not assuming that $l,m,n$ are even.\n", 478 | "We directly see that we can separate the three-dimensional integral in equation (9) as a\n", 479 | "product of three one-dimensional integrals:\n", 480 | "\\begin{equation}\n", 481 | "\\tilde{S}_{\\mu \\nu} = N_\\mu N_\\nu \\tilde{S}_{\\mu \\nu}^{(x)} \\tilde{S}_{\\mu \\nu}^{(y)} \\tilde{S}_{\\mu \\nu}^{(z)}\n", 482 | "\\end{equation}\n", 483 | "where we introduce the one-dimensional overlap as\n", 484 | "\\begin{equation}\n", 485 | "\\def\\twelve{\n", 486 | "\\tilde{S}_{\\mu \\nu}^{(x)} = \\int_{-\\infty}^\\infty \n", 487 | "\\left((t - x_\\mu)^{l_\\mu} e^{-\\alpha_\\mu (t - x_\\mu)^2} \\right)\n", 488 | "\\left((t - x_\\nu)^{l_\\nu} e^{-\\alpha_\\nu (t - x_\\nu)^2} \\right) dt\n", 489 | "}\n", 490 | "\\twelve,\n", 491 | "\\end{equation}\n", 492 | "and use the same definition for the y and z components. 
We note that if $\\br_\\mu=\\br_\\nu$, we have the normalization integral from (defH), because we can gather the $t - x_\\mu$ terms\n", 493 | "\\begin{equation}\n", 494 | "\\tilde{S}_{\\mu \\nu}^{(x)} = \\int_{-\\infty}^\\infty \n", 495 | "(t - x_\\mu)^{l_\\mu + l_\\nu} e^{-(\\alpha_\\mu + \\alpha_\\nu) (t - x_\\mu)^2} dt\n", 496 | "\\end{equation}\n", 497 | "and substitute $u = t - x_\\mu$, $du = dt$ to get\n", 498 | "\\begin{equation}\n", 499 | "\\tilde{S}_{\\mu \\nu}^{(x)} = \\int_{-\\infty}^\\infty \n", 500 | "u^{l_\\mu + l_\\nu} e^{-(\\alpha_\\mu + \\alpha_\\nu) u^2} du = H(l_\\mu + l_\\nu, \\alpha_\\mu + \\alpha_\\nu)\n", 501 | "\\end{equation}\n", 502 | "and in particular\n", 503 | "\\begin{equation}\n", 504 | "\\tilde{S}_{\\mu \\mu}^{(x)} = \\int_{-\\infty}^\\infty \n", 505 | "u^{2l_\\mu} e^{-2\\alpha_\\mu u^2} du = H(2l_\\mu, 2\\alpha_\\mu)\n", 506 | "\\end{equation}\n", 507 | "as before." 508 | ] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": {}, 513 | "source": [ 514 | "### The hard case: non-coincident centres\n", 515 | "Let us return our attention to the general overlap elements in equation (12), \n", 516 | "repeated here: \n", 517 | "$$\n", 518 | "(12) = \\twelve\n", 519 | "$$\n", 520 | "but let's replace $\\nu,\\mu$ subscripts with more visually distinguishable symbols:\n", 521 | "\\begin{equation}\n", 522 | "\\tilde{S}((A,\\alpha,i), (B,\\beta,j)) = \\int_{-\\infty}^\\infty \n", 523 | "\\left((t - A)^i e^{-\\alpha (t - A)^2} \\right)\n", 524 | "\\left((t - B)^j e^{-\\beta (t - B)^2} \\right) dt\n", 525 | "\\end{equation}\n", 526 | "and grouping the polynomial terms and the Gaussian terms we get\n", 527 | "\\begin{equation}\n", 528 | "\\tilde{S}((A,\\alpha,i), (B,\\beta,j)) = \\int_{-\\infty}^\\infty \n", 529 | "\\biggl((t - A)^i (t - B)^j \\biggr)\n", 530 | "\\biggl( e^{-\\alpha (t - A)^2} e^{-\\beta (t - B)^2} \\biggr)\n", 531 | "dt\\\\\n", 532 | "\\end{equation}\n", 533 | "To make progress we start by rewriting the product of 
two\n", 534 | "Gaussian functions as a single Gaussian:\n", 535 | "\\begin{equation}\n", 536 | "e^{-\\alpha (t - A)^2} e^{-\\beta (t - B)^2} = \n", 537 | "G ~ e^{-\\gamma (t - C)^2}, \\\\\n", 538 | "\\gamma = \\alpha + \\beta \\\\\n", 539 | "C = \\frac{\\alpha A + \\beta B}{\\gamma}\\\\\n", 540 | "G = \\exp\\left(-\\frac{\\alpha \\beta}{\\gamma} (A - B)^2\\right) \\\\\n", 541 | "\\end{equation}\n", 542 | "Next we rewrite the product of the leading polynomial terms in terms of the Gaussian product center $C$:\n", 543 | "\\begin{equation}\n", 544 | "\\def\\CA{C_A}\n", 545 | "\\def\\CB{C_B}\n", 546 | "\\def\\tc{t_C}\n", 547 | "(t - A)^i (t - B)^j = (t - C + C - A)^i (t - C + C - B)^j = (\\tc + \\CA)^i (\\tc + \\CB)^j,\n", 548 | "\\end{equation}\n", 549 | "\n", 550 | "where we introduce\n", 551 | "\\begin{equation}\n", 552 | "\\tc = t - C\\\\\n", 553 | "\\CA = C - A\\\\\n", 554 | "\\CB = C - B.\n", 555 | "\\end{equation}\n", 556 | "There isn't a universal term that covers the product of binomial expansions that we have arrived at in equation (17),\n", 557 | "and using the [binomial theorem](https://en.wikipedia.org/wiki/Binomial_theorem) we have:\n", 558 | "\\begin{equation}\n", 559 | "(\\tc + \\CA)^i = \\sum_{k=0}^i \\binom{i}{k} \\tc^{i-k} \\CA^k \\\\ \n", 560 | "(\\tc + \\CB)^j = \\sum_{l=0}^j \\binom{j}{l} \\tc^{j-l} \\CB^l \n", 561 | "\\end{equation}\n", 562 | "\n", 563 | "\\begin{equation}\n", 564 | "(\\tc + \\CA)^i (\\tc + \\CB)^j = \\sum_{k=0}^i \\sum_{l=0}^j \\binom{i}{k} \\binom{j}{l} \\tc^{i-k} \\tc^{j-l} \\CA^k \\CB^l \n", 565 | "\\end{equation}\n", 566 | "we can use SymPy to help us take the product of these expansions for some representative values of $i$ and $j$" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": null, 572 | "metadata": {}, 573 | "outputs": [], 574 | "source": [ 575 | "x, a, b = sp.symbols(\"t, a, b\", real=True)\n", 576 | "i, j = sp.symbols(\"i, j\", nonnegative=True, integers=True)\n", 577 | "poly_terms = (x + 
a) ** i * (x + b) ** j\n", 578 | "poly_terms" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": null, 584 | "metadata": {}, 585 | "outputs": [], 586 | "source": [ 587 | "# output monomial terms with ascending orders x^n, n=0, 1, 2, ...\n", 588 | "output = r\"\\begin{align*}\"\n", 589 | "for ival, jval in product(range(3), range(3)):\n", 590 | " p = sp.simplify(sp.expand(poly_terms.subs(i, ival).subs(j, jval)))\n", 591 | " output += f\"(a + t)^{ival} (b + t)^{jval} &= \"\n", 592 | " output += \" + \".join([\n", 593 | " sp.latex(p.coeff(x, n) * x**n) for n in range(ival + jval + 1)\n", 594 | " ])\n", 595 | " output += r\"\\\\\"\n", 596 | "\n", 597 | "output += r\"\\end{align*}\"\n", 598 | "ipd.Latex(output)" 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "metadata": {}, 604 | "source": [ 605 | "By using the symmetry properties of the binomial expansion we can rewrite the terms in the series of Equation (20) as:\n", 606 | "\\begin{equation}\n", 607 | "\\binom{i}{k} \\binom{j}{l} \\tc^{i-k} \\tc^{j-l} \\CA^k \\CB^l = \\binom{i}{k} \\binom{j}{l} \\tc^k \\tc^l \\CA^{i-k} \\CB^{j-l}.\n", 608 | "\\end{equation}\n", 609 | "Recognising that we want to collect terms of same power $\\tc^s$ we introduce the constraint that $k+l=s$ allows us to\n", 610 | "write the terms as:\n", 611 | "\\begin{equation}\n", 612 | "\\binom{i}{s-l} \\binom{j}{l} \\tc^s \\CA^{i-(s-l)} \\CB^{j-l}.\n", 613 | "\\end{equation}\n", 614 | "Replacing the summation variable $l=0, 1,\\cdots, j$ with $t$ and incorporating the constraint we arrive at:\n", 615 | "\\begin{equation}\n", 616 | "(\\tc + \\CA)^i (\\tc + \\CB)^j = \n", 617 | "\\sum_{s = 0}^{i+j} \\tc^s \n", 618 | "\\sum_{\\substack{t=0 \\\\ s - i \\le t \\le j}}^s \n", 619 | "\\binom{i}{s-t} \\binom{j}{t} \\CA^{i-(s-t)} \\CB^{j-t},\n", 620 | "\\end{equation}\n", 621 | "Finally we can write this series as:\n", 622 | "\\begin{equation}\n", 623 | "(\\tc + \\CA)^i (\\tc + \\CB)^j = \\sum_{s = 0}^{i+j} B(i, j, 
\\CA, \\CB, s) \\tc^s,\n", 624 | "\\end{equation}\n", 625 | "where we introduce $B(i, j, \\CA, \\CB, s)$ as the coefficient of $\\tc^s$ in the expansion:\n", 626 | "\\begin{equation}\n", 627 | "B(i, j, \\CA, \\CB, s) = \\sum_{\\substack{t=0 \\\\ s - i \\le t \\le j}}^s \\binom{i}{s-t} \\binom{j}{t} \\CA^{i - (s - t)} \\CB^{j-t}\n", 628 | "\\end{equation}\n", 629 | "An evaluation strategy for this variable-sized loop is explored in an accompanying [notebook](./binom_factor_table.ipynb).\n", 630 | "\n", 631 | "Combining this result with the one we derived earlier for the product of two\n", 632 | "Gaussian functions in Equation (16) gives us:\n", 633 | "\\begin{equation}\n", 634 | "\\tilde{S}_{\\mu \\nu}^{(x)} = \\int_{-\\infty}^\\infty \n", 635 | "\\left((t - X_\\mu)^{l_\\mu} e^{-\\alpha_\\mu (t - X_\\mu)^2} \\right)\n", 636 | "\\left((t - X_\\nu)^{l_\\nu} e^{-\\alpha_\\nu (t - X_\\nu)^2} \\right) dt = \\\\\n", 637 | "\\int_{-\\infty}^\\infty \n", 638 | " e^{-\\alpha_\\mu \\alpha_\\nu (X_\\mu - X_\\nu)^2 / \\gamma} e^{-\\gamma \\tc^2}\n", 639 | "\\sum_{s=0}^{l_\\mu+ l_\\nu} B(l_\\mu, l_\\nu, C_\\mu, C_\\nu, s) \\tc^s dt\n", 640 | "\\end{equation}\n", 641 | "Swapping the order of integration and the summation we get\n", 642 | "\\begin{equation}\n", 643 | " e^{-\\alpha_\\mu \\alpha_\\nu (X_\\mu - X_\\nu)^2 / \\gamma} \n", 644 | "\\sum_{s=0}^{l_\\mu+ l_\\nu} B(l_\\mu, l_\\nu, C_\\mu, C_\\nu, s) \n", 645 | "\\int_{-\\infty}^\\infty \\tc^s e^{-\\gamma \\tc^2} dt\n", 646 | "\\end{equation}\n", 647 | "and see that we need to evaluate\n", 648 | "integrals of the form that we already defined above at (defH)\n", 649 | "\\begin{equation}\n", 650 | "H(s,\\eta) = \\int_{-\\infty}^\\infty t^s e^{-\\eta t^2} dt\n", 651 | "\\end{equation}" 652 | ] 653 | }, 654 | { 655 | "cell_type": "markdown", 656 | "metadata": {}, 657 | "source": [ 658 | "And we recall that for $s$ odd, this is zero, so we evaluate it only for $s$ even." 
659 | ] 660 | }, 661 | { 662 | "cell_type": "markdown", 663 | "metadata": {}, 664 | "source": [ 665 | "This result can also be written in terms of the [double factorial](https://en.wikipedia.org/wiki/Double_factorial#Additional_identities), which gives us two possible computation strategies for this integral:\n", 666 | "\\begin{equation}\n", 667 | "H(2s, \\eta) = \\int_{-\\infty}^{\\infty} t^{2s} e^{-\\eta t^2} dt \\\\\n", 668 | "= \\eta^{-s - \\frac{1}{2}} \\Gamma\\left(s + \\frac{1}{2}\\right) \\\\\n", 669 | "= \\frac{(2s-1)!!}{(2\\eta)^s} \\sqrt{\\frac{\\pi}{\\eta}}.\n", 670 | "\\end{equation}\n", 671 | "The last form agrees with Equation (3.15) derived by [Fermann and Valeev](http://arxiv.org/abs/2007.12057).\n", 672 | "\n", 673 | "Using the function $H(2s, \\eta)$ allows us to write the one-dimensional overlap integral as:\n", 674 | "\\begin{equation}\n", 675 | "\\tilde{S}_{\\mu \\nu}^{(x)} = \n", 676 | "e^{-\\alpha_\\mu \\alpha_\\nu (X_\\mu - X_\\nu)^2 / \\gamma}\n", 677 | "\\sum_{s=0}^{\\lfloor(l_\\mu + l_\\nu)/2 \\rfloor} B(l_\\mu, l_\\nu, CA_x, CB_x, 2s)\\;H(2s, \\gamma)\n", 678 | "\\end{equation}\n", 679 | "Substituting this back into Equation (11) gives us the overlap of two primitive Gaussians:\n", 680 | "\\begin{equation}\n", 681 | "\\tilde{S}_{\\mu \\nu} = \\iiint p_\\mu(\\br) p_\\nu(\\br) dx dy dz \\\\\n", 682 | "= N_\\mu N_\\nu \\exp\\left(-\\frac{\\alpha_\\mu \\alpha_\\nu |\\mathbf{A}-\\mathbf{B}|^2}{\\alpha_\\mu + \\alpha_\\nu}\\right) \\times\\\\\n", 683 | "\\sum_{s=0}^{\\lfloor(l_\\mu + l_\\nu)/2 \\rfloor} B(l_\\mu, l_\\nu, CA_x, CB_x, 2s)\\;H(2s, \\gamma) \\times \\\\\n", 684 | "\\sum_{s=0}^{\\lfloor(m_\\mu + m_\\nu)/2 \\rfloor} B(m_\\mu, m_\\nu, CA_y, CB_y, 2s)\\;H(2s, \\gamma) \\times \\\\\n", 685 | "\\sum_{s=0}^{\\lfloor(n_\\mu + n_\\nu)/2 \\rfloor} B(n_\\mu, n_\\nu, CA_z, CB_z, 2s)\\;H(2s, \\gamma)\n", 686 | "\\end{equation}" 687 | ] 688 | } 689 | ], 690 | "metadata": { 691 | "kernelspec": { 692 | "display_name": "hgei", 693 | "language": "python", 
694 | "name": "python3" 695 | }, 696 | "language_info": { 697 | "codemirror_mode": { 698 | "name": "ipython", 699 | "version": 3 700 | }, 701 | "file_extension": ".py", 702 | "mimetype": "text/x-python", 703 | "name": "python", 704 | "nbconvert_exporter": "python", 705 | "pygments_lexer": "ipython3", 706 | "version": "3.9.18" 707 | }, 708 | "orig_nbformat": 4 709 | }, 710 | "nbformat": 4, 711 | "nbformat_minor": 2 712 | } 713 | -------------------------------------------------------------------------------- /docs/intro.md: -------------------------------------------------------------------------------- 1 | # MESS: Modern Electronic Structure Simulations 2 | 3 | :::{note} 4 | This project is a work in progress. 5 | Software features marked with a 🌈 indicate that 6 | this functionality is still in the planning phase. 7 | ::: 8 | 9 | Welcome to MESS, a python framework for exploring the exciting interface 10 | between machine learning, electronic structure, and algorithms. 11 | Our main focus is building a fully hackable implementation of 12 | [Density Functional Theory (DFT)](https://en.wikipedia.org/wiki/Density_functional_theory). 13 | 14 | Target applications include: 15 | * high-throughput DFT simulations for efficient large-scale molecular dataset generation 16 | * exploration of hybrid machine learned/electronic structure simulations 17 | 18 | Within DFT there are many different approximations for handling quantum-mechanical 19 | interactions. These are collectively known as exchange-correlation functionals and 20 | MESS provides a few common implementations: 21 | * LDA, PBE, PBE0, B3LYP 22 | * dispersion corrections 🌈 23 | * machine-learned exchange-correlation functionals 🌈 24 | 25 | This project is built on 26 | [JAX](https://jax.readthedocs.io/en/latest/) to support rapid 27 | prototyping of high-performance simulations. 
MESS benefits from many features of JAX: 28 | * Hardware Acceleration 29 | * Automatic Differentiation 30 | * Program transformations such as JIT compilation and automatic vectorisation. 31 | * Flexible floating point numeric formats 32 | 33 | 34 | ## Minimal Example 35 | 36 | Calculate the ground state energy of a single water molecule using the 6-31g basis set 37 | and the [local density approximation (LDA)](https://en.wikipedia.org/wiki/Local-density_approximation): 38 | ```python 39 | from mess import Hamiltonian, basisset, minimise, molecule 40 | 41 | mol = molecule("water") 42 | basis = basisset(mol, basis_name="6-31g") 43 | H = Hamiltonian(basis, xc_method="lda") 44 | E, C, sol = minimise(H) 45 | E 46 | ``` 47 | 48 | ## Next Steps 49 | 50 | ::::{grid} 51 | :gutter: 2 52 | 53 | :::{grid-item-card} {material-regular}`map;2em` Learn 54 | :link: tour 55 | :link-type: doc 56 | ::: 57 | 58 | :::{grid-item-card} {material-regular}`construction;2em` Build 59 | :link: api 60 | :link-type: doc 61 | 62 | ::: 63 | 64 | :::: 65 | -------------------------------------------------------------------------------- /docs/misc.md: -------------------------------------------------------------------------------- 1 | # Miscellaneous Mathematics 2 | -------------------------------------------------------------------------------- /docs/prologue.md: -------------------------------------------------------------------------------- 1 | # Prologue 2 | 3 | Regardless of the domain every new project appears to face a choice: 4 | 5 | - build on top of a pre-existing framework 6 | - start fresh and build a new framework 7 | 8 | We think this is actually a false dichotomy but is still useful to demonstrate the 9 | motivation for the MESS project. The curious reader may be wondering: There is no 10 | shortage of software packages covering a broad range of electronic structure 11 | simulations. 
Arguably there are too many packages that create a fractured ecosystem of 12 | solutions that are not easily interoperable with one another. This is not a new problem 13 | and the work of MESS is likely making this problem worse not better! 14 | 15 | The fundamental objective of MESS is to reimagine what a hybrid electronic structure and 16 | machine learning framework might look like on future hardware architectures. This 17 | viewpoint is what we think puts the **modern** in MESS. Beyond this our hope is to begin 18 | to demystify the inner workings of electronic structure methods to help accelerate the 19 | work of the molecular machine learning community. 20 | 21 | To even begin to climb towards these lofty goals we start with a few constraints that we 22 | think might help accelerate the ascent. These constraints are shamelessly borrowed from 23 | what we think are some of the factors that have helped accelerate recent progress in 24 | machine learning across multiple domains. 25 | 26 | - hardware acceleration 27 | - automatic differentiation 28 | - high-level interpreted programming languages 29 | 30 | On the last point, the reader who cut their teeth on electronic structure packages 31 | implemented in Fortran programs that can trace their origins to the 1980s (or even 32 | earlier!) may have the intuition that using an interpreted language would come with an 33 | unacceptable performance penalty. However, using an interpreted language is a proven approach first introduced by MATLAB 34 | {cite}`moler2020history` to accelerate numerical linear algebra before spreading 35 | throughout computational sciences. NumPy {cite}`harris2020array` was developed as an 36 | open-source alternative to this array-centred programming framework and is now 37 | completely ubiquitous throughout scientific computing. 
38 | 39 | ```{bibliography} 40 | :style: unsrt 41 | ``` 42 | -------------------------------------------------------------------------------- /docs/quirks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Quirks and Features to be Aware of \n", 8 | "\n", 9 | "This project uses \n", 10 | "[JAX](https://jax.readthedocs.io/en/latest/) to support both rapid\n", 11 | "prototyping and efficient hardware accelerated simulations. \n", 12 | "\n", 13 | "MESS makes use of a number of JAX ecosystem projects:\n", 14 | "* [equinox](https://docs.kidger.site/equinox/)\n", 15 | "* [optax](https://optax.readthedocs.io/en/latest/)\n", 16 | "* [optimistix](https://docs.kidger.site/optimistix/)\n", 17 | "\n", 18 | "These projects were initially designed with neural networks in mind but we have found\n", 19 | "their abstractions are easily adapted and reused in quantum chemistry simulations. \n", 20 | "\n", 21 | "## JIT Compilation Cache\n", 22 | "\n", 23 | "We use\n", 24 | "[Just In Time compilation](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)\n", 25 | "from JAX throughout MESS. JAX supports a persistent compilation cache which can allow\n", 26 | "MESS to reuse previously compiled programs to save a bit of simulation startup time.\n", 27 | "By default the cache will be stored in `~/.cache/mess` and can be customised\n", 28 | "by setting the `MESS_CACHE_DIR` environment variable.\n", 29 | "\n", 30 | "\n", 31 | "## Floating Point Precision\n", 32 | "\n", 33 | "Standalone [JAX uses single-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision) floating point numbers by default as opposed to the widely\n", 34 | "used double-precision format used in scientific simulations. 
MESS will default to using\n", 35 | "double-precision and this behaviour can be customised in a number of ways: \n", 36 | "\n", 37 | "* setting the environment variable `MESS_ENABLE_FP64=0` before starting the main \n", 38 | "python process\n", 39 | "* using the context manager [jax.experimental.disable_x64](https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.disable_x64.html)\n", 40 | "\n", 41 | "These methods can be used to investigate mixed-precision electronic structure \n", 42 | "simulations." 43 | ] 44 | } 45 | ], 46 | "metadata": { 47 | "kernelspec": { 48 | "display_name": "jax", 49 | "language": "python", 50 | "name": "python3" 51 | }, 52 | "language_info": { 53 | "codemirror_mode": { 54 | "name": "ipython", 55 | "version": 3 56 | }, 57 | "file_extension": ".py", 58 | "mimetype": "text/x-python", 59 | "name": "python", 60 | "nbconvert_exporter": "python", 61 | "pygments_lexer": "ipython3", 62 | "version": "3.10.13" 63 | } 64 | }, 65 | "nbformat": 4, 66 | "nbformat_minor": 2 67 | } 68 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | --- 2 | --- 3 | 4 | 5 | @article{pople1999nobel, 6 | title={{Nobel Lecture: Quantum Chemical Models}}, 7 | author={Pople, John A}, 8 | journal={Reviews of Modern Physics}, 9 | year={1999}, 10 | } 11 | 12 | @article{kohn1999nobel, 13 | title={{Nobel Lecture: Electronic Structure of Matter—Wave Functions and Density Functionals}}, 14 | author={Kohn, Walter}, 15 | journal={Reviews of Modern Physics}, 16 | year={1999}, 17 | } 18 | 19 | @article{nature2023editorial, 20 | title={For chemists, the AI revolution has yet to happen}, 21 | volume={617}, 22 | rights={2023 Springer Nature Limited}, 23 | DOI={10.1038/d41586-023-01612-x}, 24 | number={7961}, 25 | journal={Nature}, 26 | year={2023}, 27 | month=may, 28 | pages={438-438}, 29 | url= {https://doi.org/10.1038/d41586-023-01612-x} 
30 | } 31 | 32 | @article{moler2020history, 33 | title={A history of MATLAB}, 34 | author={Moler, Cleve and Little, Jack}, 35 | journal={Proceedings of the ACM on Programming Languages}, 36 | volume={4}, 37 | number={HOPL}, 38 | pages={1--67}, 39 | year={2020}, 40 | publisher={ACM New York, NY, USA}, 41 | url = {https://doi.org/10.1145/3386331} 42 | } 43 | 44 | @article{harris2020array, 45 | title = {Array programming with {NumPy}}, 46 | author = {Charles R. Harris and K. Jarrod Millman and St{\'{e}}fan J. 47 | van der Walt and Ralf Gommers and Pauli Virtanen and David 48 | Cournapeau and Eric Wieser and Julian Taylor and Sebastian 49 | Berg and Nathaniel J. Smith and Robert Kern and Matti Picus 50 | and Stephan Hoyer and Marten H. van Kerkwijk and Matthew 51 | Brett and Allan Haldane and Jaime Fern{\'{a}}ndez del 52 | R{\'{i}}o and Mark Wiebe and Pearu Peterson and Pierre 53 | G{\'{e}}rard-Marchant and Kevin Sheppard and Tyler Reddy and 54 | Warren Weckesser and Hameer Abbasi and Christoph Gohlke and 55 | Travis E. Oliphant}, 56 | year = {2020}, 57 | month = sep, 58 | journal = {Nature}, 59 | volume = {585}, 60 | number = {7825}, 61 | pages = {357--362}, 62 | publisher = {Springer Science and Business Media {LLC}}, 63 | url = {https://doi.org/10.1038/s41586-020-2649-2} 64 | } 65 | -------------------------------------------------------------------------------- /mess/__init__.py: -------------------------------------------------------------------------------- 1 | """MESS: Modern Electronic Structure Simulations 2 | 3 | The mess simulation environment can be customised by setting the following environment 4 | variables: 5 | * ``MESS_ENABLE_FP64``: enables float64 precision. Defaults to ``True``. 6 | * ``MESS_CACHE_DIR``: JIT compilation cache location. 
Defaults to ``~/.cache/mess`` 7 | """ 8 | 9 | from importlib.metadata import PackageNotFoundError, version 10 | 11 | try: 12 | __version__ = version("mess") 13 | except PackageNotFoundError: # Package is not installed 14 | __version__ = "dev" 15 | 16 | from mess.basis import basisset 17 | from mess.hamiltonian import Hamiltonian, minimise 18 | from mess.structure import molecule 19 | 20 | __all__ = ["molecule", "Hamiltonian", "minimise", "basisset"] 21 | 22 | 23 | def parse_bool(value: str) -> bool: 24 | value = value.lower() 25 | 26 | if value in ("y", "yes", "t", "true", "on", "1"): 27 | return True 28 | elif value in ("n", "no", "f", "false", "off", "0"): 29 | return False 30 | else: 31 | raise ValueError("Failed to parse {value} as a boolean") 32 | 33 | 34 | def _setup_env(): 35 | import os 36 | import os.path as osp 37 | 38 | from jax import config 39 | from jax.experimental.compilation_cache import compilation_cache as cc 40 | 41 | enable_fp64 = parse_bool(os.environ.get("MESS_ENABLE_FP64", "True")) 42 | config.update("jax_enable_x64", enable_fp64) 43 | 44 | cache_dir = str(os.environ.get("MESS_CACHE_DIR", osp.expanduser("~/.cache/mess"))) 45 | cc.set_cache_dir(cache_dir) 46 | 47 | 48 | _setup_env() 49 | -------------------------------------------------------------------------------- /mess/autograd_integrals.py: -------------------------------------------------------------------------------- 1 | """automatic differentiation of atomic orbital integrals""" 2 | 3 | from functools import partial 4 | from typing import Callable 5 | 6 | import equinox as eqx 7 | import jax.numpy as jnp 8 | from jax import grad, jit, tree, vmap 9 | from jax.ops import segment_sum 10 | 11 | from mess.basis import Basis 12 | from mess.integrals import _kinetic_primitives, _nuclear_primitives, _overlap_primitives 13 | from mess.primitive import Primitive 14 | from mess.types import Float3, Float3xNxN 15 | 16 | 17 | def grad_integrate_primitives(a: Primitive, b: Primitive, operator: 
Callable) -> Float3: 18 | def f(center): 19 | return operator(eqx.combine(center, arest), b) 20 | 21 | acenter, arest = eqx.partition(a, lambda x: id(x) == id(a.center)) 22 | return grad(f)(acenter).center 23 | 24 | 25 | def grad_overlap_primitives(a: Primitive, b: Primitive) -> Float3: 26 | return grad_integrate_primitives(a, b, _overlap_primitives) 27 | 28 | 29 | def grad_kinetic_primitives(a: Primitive, b: Primitive) -> Float3: 30 | return grad_integrate_primitives(a, b, _kinetic_primitives) 31 | 32 | 33 | def grad_nuclear_primitives(a: Primitive, b: Primitive, c: Float3) -> Float3: 34 | def n(lhs, rhs): 35 | return _nuclear_primitives(lhs, rhs, c) 36 | 37 | return grad_integrate_primitives(a, b, n) 38 | 39 | 40 | @partial(jit, static_argnums=1) 41 | def grad_integrate_basis(basis: Basis, operator: Callable) -> Float3xNxN: 42 | def take_primitives(indices): 43 | p = tree.map(lambda x: jnp.take(x, indices, axis=0), basis.primitives) 44 | c = jnp.take(basis.coefficients, indices) 45 | return p, c 46 | 47 | ii, jj = jnp.meshgrid(*[jnp.arange(basis.num_primitives)] * 2, indexing="ij") 48 | lhs, cl = take_primitives(ii.reshape(-1)) 49 | rhs, cr = take_primitives(jj.reshape(-1)) 50 | out = vmap(operator)(lhs, rhs) 51 | 52 | out = cl * cr * out.T 53 | out = out.reshape(3, basis.num_primitives, basis.num_primitives) 54 | out = jnp.rollaxis(out, 1) 55 | out = segment_sum(out, basis.orbital_index, num_segments=basis.num_orbitals) 56 | out = jnp.rollaxis(out, -1) 57 | out = segment_sum(out, basis.orbital_index, num_segments=basis.num_orbitals) 58 | return jnp.rollaxis(out, -1) 59 | 60 | 61 | def grad_overlap_basis(basis: Basis) -> Float3xNxN: 62 | return grad_integrate_basis(basis, grad_overlap_primitives) 63 | 64 | 65 | def grad_kinetic_basis(basis: Basis) -> Float3xNxN: 66 | return grad_integrate_basis(basis, grad_kinetic_primitives) 67 | 68 | 69 | def grad_nuclear_basis(basis: Basis) -> Float3xNxN: 70 | def n(atomic_number, position): 71 | def op(pi, pj): 72 | return 
atomic_number * grad_nuclear_primitives(pi, pj, position) 73 | 74 | return grad_integrate_basis(basis, op) 75 | 76 | out = vmap(n)(basis.structure.atomic_number, basis.structure.position) 77 | return jnp.sum(out, axis=0) 78 | -------------------------------------------------------------------------------- /mess/basis.py: -------------------------------------------------------------------------------- 1 | """basis sets of Gaussian type orbitals""" 2 | 3 | from typing import Tuple 4 | 5 | import equinox as eqx 6 | import jax.numpy as jnp 7 | import numpy as np 8 | import pandas as pd 9 | from jax import jit 10 | from jax.ops import segment_sum 11 | 12 | from mess.orbital import Orbital, batch_orbitals 13 | from mess.primitive import Primitive 14 | from mess.structure import Structure 15 | from mess.types import FloatN, FloatNx3, FloatNxM, FloatNxN, IntN, default_fptype 16 | 17 | 18 | class Basis(eqx.Module): 19 | orbitals: Tuple[Orbital] 20 | structure: Structure 21 | primitives: Primitive 22 | coefficients: FloatN 23 | orbital_index: IntN 24 | basis_name: str = eqx.field(static=True) 25 | max_L: int = eqx.field(static=True) 26 | 27 | @property 28 | def num_orbitals(self) -> int: 29 | return len(self.orbitals) 30 | 31 | @property 32 | def num_primitives(self) -> int: 33 | return sum(ao.num_primitives for ao in self.orbitals) 34 | 35 | @property 36 | def occupancy(self) -> FloatN: 37 | # Assumes uncharged systems in restricted Kohn-Sham 38 | occ = jnp.full(self.num_orbitals, 2.0) 39 | mask = occ.cumsum() > self.structure.num_electrons 40 | occ = jnp.where(mask, 0.0, occ) 41 | return occ 42 | 43 | def to_dataframe(self) -> pd.DataFrame: 44 | def fixer(x): 45 | # simple workaround for storing 2d array as a pandas column 46 | return [x[i, :] for i in range(x.shape[0])] 47 | 48 | df = pd.DataFrame() 49 | df["orbital"] = self.orbital_index 50 | df["coefficient"] = self.coefficients 51 | df["norm"] = self.primitives.norm 52 | df["center"] = fixer(self.primitives.center) 53 
| df["lmn"] = fixer(self.primitives.lmn) 54 | df["alpha"] = self.primitives.alpha 55 | df.index.name = "primitive" 56 | return df 57 | 58 | def density_matrix(self, C: FloatNxN) -> FloatNxN: 59 | """Evaluate the density matrix from the molecular orbital coefficients 60 | 61 | Args: 62 | C (FloatNxN): the molecular orbital coefficients 63 | 64 | Returns: 65 | FloatNxN: the density matrix. 66 | """ 67 | return jnp.einsum("k,ik,jk->ij", self.occupancy, C, C) 68 | 69 | @jit 70 | def __call__(self, pos: FloatNx3) -> FloatNxM: 71 | prim = self.coefficients[jnp.newaxis, :] * self.primitives(pos) 72 | orb = segment_sum(prim.T, self.orbital_index, num_segments=self.num_orbitals) 73 | return orb.T 74 | 75 | def __repr__(self) -> str: 76 | return repr(self.to_dataframe()) 77 | 78 | def _repr_html_(self) -> str | None: 79 | df = self.to_dataframe() 80 | return df._repr_html_() 81 | 82 | def __hash__(self) -> int: 83 | return hash(self.primitives) 84 | 85 | 86 | def basisset(structure: Structure, basis_name: str = "sto-3g") -> Basis: 87 | """Factory function for building a basis set for a structure. 88 | 89 | Args: 90 | structure (Structure): Used to define the basis function parameters. 91 | basis_name (str, optional): Basis set name to look up on the 92 | `basis set exchange `_. 93 | Defaults to ``sto-3g``. 
94 | 95 | Returns: 96 | Basis constructed from inputs 97 | """ 98 | from basis_set_exchange import get_basis 99 | from basis_set_exchange.sort import sort_basis 100 | 101 | # fmt: off 102 | LMN_MAP = { 103 | 0: [(0, 0, 0)], 104 | 1: [(1, 0, 0), (0, 1, 0), (0, 0, 1)], 105 | 2: [(2, 0, 0), (1, 1, 0), (1, 0, 1), (0, 2, 0), (0, 1, 1), (0, 0, 2)], 106 | 3: [(3, 0, 0), (2, 1, 0), (2, 0, 1), (1, 2, 0), (1, 1, 1), 107 | (1, 0, 2), (0, 3, 0), (0, 2, 1), (0, 1, 2), (0, 0, 3)], 108 | } 109 | # fmt: on 110 | 111 | bse_basis = get_basis( 112 | basis_name, 113 | elements=structure.atomic_number.tolist(), 114 | uncontract_spdf=True, 115 | uncontract_general=True, 116 | ) 117 | bse_basis = sort_basis(bse_basis)["elements"] 118 | orbitals = [] 119 | 120 | for a in range(structure.num_atoms): 121 | center = structure.position[a, :] 122 | shells = bse_basis[str(structure.atomic_number[a])]["electron_shells"] 123 | 124 | for s in shells: 125 | for lmn in LMN_MAP[s["angular_momentum"][0]]: 126 | ao = Orbital.from_bse( 127 | center=center, 128 | alphas=np.array(s["exponents"], dtype=default_fptype()), 129 | lmn=np.array(lmn, dtype=np.int32), 130 | coefficients=np.array(s["coefficients"], dtype=default_fptype()), 131 | ) 132 | orbitals.append(ao) 133 | 134 | primitives, coefficients, orbital_index = batch_orbitals(orbitals) 135 | 136 | return Basis( 137 | orbitals=orbitals, 138 | structure=structure, 139 | primitives=primitives, 140 | coefficients=coefficients, 141 | orbital_index=orbital_index, 142 | basis_name=basis_name, 143 | max_L=int(np.max(primitives.lmn)), 144 | ) 145 | 146 | 147 | def basis_iter(basis: Basis): 148 | from jax import tree 149 | 150 | from mess.special import triu_indices 151 | 152 | def take_primitives(indices): 153 | p = tree.map(lambda x: jnp.take(x, indices, axis=0), basis.primitives) 154 | c = jnp.take(basis.coefficients, indices) 155 | return p, c 156 | 157 | ii, jj = triu_indices(basis.num_primitives) 158 | lhs, cl = take_primitives(ii.reshape(-1)) 159 | 
rhs, cr = take_primitives(jj.reshape(-1)) 160 | return (ii, cl, lhs), (jj, cr, rhs) 161 | -------------------------------------------------------------------------------- /mess/binom_factor_table.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED from notebooks/binom_factor_table.ipynb 2 | # fmt: off 3 | # flake8: noqa 4 | # isort: skip_file 5 | from numpy import array 6 | binom_factor_table = ((array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 7 | 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 8 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 9 | 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 10 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 11 | 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12 | 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 13 | 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]), array([1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 14 | 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 1, 1, 1, 1, 1, 2, 2, 15 | 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 16 | 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 17 | 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 18 | 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 19 | 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 20 | 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]), array([0, 0, 1, 0, 1, 2, 0, 1, 2, 3, 0, 0, 1, 1, 0, 1, 1, 2, 2, 0, 1, 1, 21 | 2, 2, 3, 3, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 1, 0, 1, 1, 2, 2, 0, 1, 22 | 1, 2, 2, 2, 3, 3, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 0, 1, 1, 2, 2, 23 | 2, 3, 3, 3, 4, 4, 4, 0, 1, 2, 0, 1, 1, 2, 2, 3, 3, 0, 1, 1, 2, 2, 24 | 2, 3, 3, 3, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 0, 1, 1, 25 | 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 0, 1, 2, 3, 0, 
1, 1, 2, 2, 3, 3, 26 | 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 27 | 3, 3, 4, 4, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4]), array([ 4, 21, 4, 11, 21, 4, 2, 11, 21, 4, 15, 3, 4, 15, 1, 3, 21, 28 | 4, 15, 17, 1, 11, 3, 21, 4, 15, 19, 2, 17, 1, 11, 3, 21, 4, 29 | 15, 22, 15, 10, 3, 22, 4, 15, 23, 1, 10, 3, 21, 22, 4, 15, 9, 30 | 17, 23, 1, 10, 11, 3, 21, 22, 4, 15, 12, 9, 19, 2, 17, 23, 1, 31 | 10, 11, 3, 21, 22, 13, 22, 15, 14, 10, 13, 3, 22, 4, 15, 5, 14, 32 | 23, 1, 10, 13, 3, 21, 22, 4, 15, 16, 5, 9, 14, 17, 23, 1, 10, 33 | 11, 13, 3, 21, 22, 18, 12, 16, 5, 9, 19, 2, 14, 17, 23, 1, 10, 34 | 11, 13, 7, 13, 22, 15, 24, 7, 14, 10, 13, 3, 22, 4, 15, 20, 5, 35 | 24, 7, 14, 23, 1, 10, 13, 3, 21, 22, 6, 16, 20, 5, 9, 24, 7, 36 | 14, 17, 23, 1, 10, 11, 13, 8, 6, 18, 12, 16, 20, 5, 9, 19, 24, 37 | 2, 7, 14, 17, 23])), array([ 1., 1., 2., 1., 3., 3., 1., 4., 6., 4., 1., 1., 1., 38 | 1., 1., 2., 1., 2., 1., 1., 3., 1., 3., 3., 3., 1., 39 | 1., 1., 4., 6., 4., 4., 6., 4., 1., 1., 2., 1., 2., 40 | 1., 1., 2., 1., 2., 2., 4., 1., 1., 2., 2., 1., 2., 41 | 3., 6., 3., 1., 6., 3., 1., 3., 2., 1., 4., 2., 1., 42 | 8., 6., 12., 4., 4., 8., 6., 1., 1., 3., 3., 1., 3., 43 | 1., 3., 3., 1., 3., 1., 2., 3., 3., 6., 1., 6., 1., 44 | 3., 2., 3., 1., 3., 3., 3., 3., 9., 9., 9., 1., 1., 45 | 9., 3., 3., 1., 3., 4., 6., 12., 3., 1., 4., 12., 18., 46 | 18., 12., 4., 1., 1., 4., 6., 4., 1., 1., 4., 6., 4., 47 | 4., 6., 1., 4., 1., 4., 2., 1., 8., 6., 4., 12., 4., 48 | 8., 1., 6., 1., 4., 3., 12., 6., 3., 1., 12., 4., 18., 49 | 12., 18., 1., 4., 1., 4., 4., 6., 16., 6., 24., 24., 4., 50 | 4., 1., 1., 16., 16., 36.])) 51 | -------------------------------------------------------------------------------- /mess/hamiltonian.py: -------------------------------------------------------------------------------- 1 | """Many electron Hamiltonian with Density Functional Theory or Hartree-Fock.""" 2 | 3 | from typing import Literal, Optional, Tuple, get_args 4 | from 
functools import partial 5 | 6 | import equinox as eqx 7 | import jax 8 | import jax.numpy as jnp 9 | import jax.numpy.linalg as jnl 10 | import optimistix as optx 11 | from jaxtyping import Array, ScalarLike 12 | 13 | from mess.basis import Basis 14 | from mess.integrals import eri_basis, kinetic_basis, nuclear_basis, overlap_basis 15 | from mess.interop import to_pyscf 16 | from mess.mesh import Mesh, density, density_and_grad, xcmesh_from_pyscf 17 | from mess.orthnorm import symmetric 18 | from mess.structure import nuclear_energy 19 | from mess.types import FloatNxN, OrthNormTransform 20 | from mess.xcfunctional import ( 21 | gga_correlation_lyp, 22 | gga_correlation_pbe, 23 | gga_exchange_b88, 24 | gga_exchange_pbe, 25 | lda_correlation_vwn, 26 | lda_exchange, 27 | ) 28 | 29 | xcstr = Literal["lda", "pbe", "pbe0", "b3lyp", "hfx"] 30 | IntegralBackend = Literal["mess", "pyscf_cart", "pyscf_sph"] 31 | 32 | 33 | class OneElectron(eqx.Module): 34 | overlap: FloatNxN 35 | kinetic: FloatNxN 36 | nuclear: FloatNxN 37 | 38 | def __init__(self, basis: Basis, backend: IntegralBackend = "mess"): 39 | """Build the one-electron integral matrices (overlap, kinetic, nuclear) for a basis set. 40 | 41 | Args: 42 | basis (Basis): the basis set used to build the one-electron integrals 43 | backend (IntegralBackend, optional): Integral backend used. Defaults to "mess".
44 | 45 | Note: 46 | The "mess" backend sums the nuclear attraction matrices over all atoms. 47 | The "pyscf_*" backends rescale the overlap matrix to a unit diagonal; 48 | the kinetic and nuclear matrices are used as returned by PySCF. 49 | Any other backend value matches neither branch, so no integrals are 50 | assigned. 51 | """ 52 | if backend == "mess": 53 | self.overlap = overlap_basis(basis) 54 | self.kinetic = kinetic_basis(basis) 55 | self.nuclear = nuclear_basis(basis).sum(axis=0) 56 | elif backend.startswith("pyscf_"): 57 | mol = to_pyscf(basis.structure, basis.basis_name) 58 | kind = backend.split("_")[1] 59 | S = jnp.array(mol.intor(f"int1e_ovlp_{kind}")) 60 | N = 1 / jnp.sqrt(jnp.diagonal(S)) 61 | self.overlap = N[:, jnp.newaxis] * N[jnp.newaxis, :] * S 62 | self.kinetic = jnp.array(mol.intor(f"int1e_kin_{kind}")) 63 | self.nuclear = jnp.array(mol.intor(f"int1e_nuc_{kind}")) 64 | 65 | 66 | class TwoElectron(eqx.Module): 67 | eri: Array 68 | 69 | def __init__(self, basis: Basis, backend: str = "mess"): 70 | """Build the two-electron repulsion integrals (ERIs) for a basis set. 71 | 72 | Args: 73 | basis (Basis): the basis set used to build the electron repulsion integrals 74 | backend (str, optional): Integral backend used. Defaults to "mess". 75 | """ 76 | super().__init__() 77 | if backend == "mess": 78 | self.eri = eri_basis(basis) 79 | elif backend.startswith("pyscf_"): 80 | mol = to_pyscf(basis.structure, basis.basis_name) 81 | kind = backend.split("_")[1] 82 | self.eri = jnp.array(mol.intor(f"int2e_{kind}", aosym="s1")) 83 | 84 | def coloumb(self, P: FloatNxN) -> FloatNxN: 85 | """Build the Coulomb matrix (classical electrostatic) from the density matrix.
86 | 87 | Args: 88 | P (FloatNxN): the density matrix 89 | 90 | Returns: 91 | FloatNxN: Coulomb matrix 92 | """ 93 | return jnp.einsum("kl,ijkl->ij", P, self.eri) 94 | 95 | def exchange(self, P: FloatNxN) -> FloatNxN: 96 | """Build the quantum-mechanical exchange matrix from the density matrix. 97 | 98 | Args: 99 | P (FloatNxN): the density matrix 100 | 101 | Returns: 102 | FloatNxN: Exchange matrix 103 | """ 104 | return jnp.einsum("ij,ikjl->kl", P, self.eri) 105 | 106 | 107 | class HartreeFockExchange(eqx.Module): 108 | two_electron: TwoElectron 109 | 110 | def __init__(self, two_electron: TwoElectron): 111 | self.two_electron = two_electron 112 | 113 | def __call__(self, P: FloatNxN) -> ScalarLike: 114 | K = self.two_electron.exchange(P) 115 | return -0.25 * jnp.sum(P * K) 116 | 117 | 118 | class LDA(eqx.Module): 119 | basis: Basis 120 | mesh: Mesh 121 | 122 | def __init__(self, basis: Basis): 123 | self.basis = basis 124 | self.mesh = xcmesh_from_pyscf(basis.structure) 125 | 126 | def __call__(self, P: FloatNxN) -> ScalarLike: 127 | rho = density(self.basis, self.mesh, P) 128 | eps_xc = lda_exchange(rho) + lda_correlation_vwn(rho) 129 | E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, eps_xc) 130 | return E_xc 131 | 132 | 133 | class PBE(eqx.Module): 134 | basis: Basis 135 | mesh: Mesh 136 | 137 | def __init__(self, basis: Basis): 138 | self.basis = basis 139 | self.mesh = xcmesh_from_pyscf(basis.structure) 140 | 141 | def __call__(self, P: FloatNxN) -> ScalarLike: 142 | rho, grad_rho = density_and_grad(self.basis, self.mesh, P) 143 | eps_xc = gga_exchange_pbe(rho, grad_rho) + gga_correlation_pbe(rho, grad_rho) 144 | E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, eps_xc) 145 | return E_xc 146 | 147 | 148 | class PBE0(eqx.Module): 149 | basis: Basis 150 | mesh: Mesh 151 | hfx: HartreeFockExchange 152 | 153 | def __init__(self, basis: Basis, two_electron: TwoElectron): 154 | self.basis = basis 155 | self.mesh = xcmesh_from_pyscf(basis.structure) 156 |
self.hfx = HartreeFockExchange(two_electron) 157 | 158 | def __call__(self, P: FloatNxN) -> ScalarLike: 159 | rho, grad_rho = density_and_grad(self.basis, self.mesh, P) 160 | e = 0.75 * gga_exchange_pbe(rho, grad_rho) + gga_correlation_pbe(rho, grad_rho) 161 | E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, e) 162 | return E_xc + 0.25 * self.hfx(P) 163 | 164 | 165 | class B3LYP(eqx.Module): 166 | basis: Basis 167 | mesh: Mesh 168 | hfx: HartreeFockExchange 169 | 170 | def __init__(self, basis: Basis, two_electron: TwoElectron): 171 | self.basis = basis 172 | self.mesh = xcmesh_from_pyscf(basis.structure) 173 | self.hfx = HartreeFockExchange(two_electron) 174 | 175 | def __call__(self, P: FloatNxN) -> ScalarLike: 176 | rho, grad_rho = density_and_grad(self.basis, self.mesh, P) 177 | eps_x = 0.08 * lda_exchange(rho) + 0.72 * gga_exchange_b88(rho, grad_rho) 178 | vwn_c = (1 - 0.81) * lda_correlation_vwn(rho) 179 | lyp_c = 0.81 * gga_correlation_lyp(rho, grad_rho) 180 | b3lyp_xc = eps_x + vwn_c + lyp_c 181 | E_xc = jnp.einsum("i,i,i", self.mesh.weights, rho, b3lyp_xc) 182 | return E_xc + 0.2 * self.hfx(P) 183 | 184 | 185 | def build_xcfunc( 186 | xc_method: xcstr, basis: Basis, two_electron: Optional[TwoElectron] = None 187 | ) -> eqx.Module: 188 | if two_electron is None and xc_method in ("pbe0", "b3lyp"): 189 | raise ValueError( 190 | f"Hybrid functional {xc_method} requires providing TwoElectron integrals" 191 | ) 192 | 193 | match xc_method: 194 | case "lda": 195 | return LDA(basis) 196 | case "pbe": 197 | return PBE(basis) 198 | case "pbe0": 199 | return PBE0(basis, two_electron) 200 | case "b3lyp": 201 | return B3LYP(basis, two_electron) 202 | case "hfx": 203 | return HartreeFockExchange(two_electron) 204 | case _: 205 | methods = get_args(xcstr) 206 | methods = ", ".join(methods) 207 | msg = f"Unsupported exchange-correlation option: {xc_method}." 
208 | msg += f"\nMust be one of the following: {methods}" 209 | raise ValueError(msg) 210 | 211 | 212 | class Hamiltonian(eqx.Module): 213 | X: FloatNxN 214 | H_core: FloatNxN 215 | basis: Basis 216 | two_electron: TwoElectron 217 | xcfunc: eqx.Module 218 | 219 | def __init__( 220 | self, 221 | basis: Basis, 222 | ont: OrthNormTransform = symmetric, 223 | xc_method: xcstr = "lda", 224 | backend: IntegralBackend = "pyscf_cart", 225 | ): 226 | super().__init__() 227 | self.basis = basis 228 | one_elec = OneElectron(basis, backend=backend) 229 | S = one_elec.overlap 230 | self.X = ont(S) 231 | self.H_core = one_elec.kinetic + one_elec.nuclear 232 | self.two_electron = TwoElectron(basis, backend=backend) 233 | self.xcfunc = build_xcfunc(xc_method, basis, self.two_electron) 234 | 235 | def __call__(self, P: FloatNxN) -> ScalarLike: 236 | E_core = jnp.sum(self.H_core * P) 237 | E_xc = self.xcfunc(P) 238 | J = self.two_electron.coloumb(P) 239 | E_es = 0.5 * jnp.sum(J * P) 240 | E = E_core + E_xc + E_es 241 | return E 242 | 243 | def orthonormalise(self, Z: FloatNxN) -> FloatNxN: 244 | C = self.X @ jnl.qr(Z).Q 245 | return C 246 | 247 | 248 | @partial(jax.jit, static_argnames=("max_steps", "atol", "rtol")) 249 | def minimise( 250 | H: Hamiltonian, 251 | max_steps: Optional[int] = None, 252 | atol: float = 1e-6, 253 | rtol: float = 1e-5, 254 | ) -> Tuple[ScalarLike, FloatNxN, optx.Solution]: 255 | """Solve for the electronic coefficients that minimise the total energy 256 | 257 | This function takes a Hamiltonian built for a given basis set and molecular 258 | structure, and finds the electronic coefficients that minimise the total energy. 259 | The optimisation is performed using the BFGS algorithm. 260 | 261 | Args: 262 | H (Hamiltonian): The Hamiltonian for a given basis set and molecular structure. 263 | max_steps (Optional[int]): Maximum number of minimizer steps. Defaults to None. 264 | atol (float): Absolute tolerance for convergence. Defaults to 1e-6. 
265 | rtol (float): Relative tolerance for convergence. Defaults to 1e-5. 266 | 267 | Returns: 268 | Tuple[ScalarLike, FloatNxN, optimistix.Solution]: A tuple containing: 269 | - total energy in atomic units 270 | - coefficient matrix C that minimizes the Hamiltonian 271 | - the optimistix.Solution object 272 | """ 273 | 274 | def f(Z, _): 275 | C = H.orthonormalise(Z) 276 | P = H.basis.density_matrix(C) 277 | return H(P) 278 | 279 | solver = optx.BFGS(atol=atol, rtol=rtol) 280 | Z = jnp.eye(H.basis.num_orbitals) 281 | sol = optx.minimise(f, solver, Z, max_steps=max_steps) 282 | C = H.orthonormalise(sol.value) 283 | P = H.basis.density_matrix(C) 284 | E_elec = H(P) 285 | E_total = E_elec + nuclear_energy(H.basis.structure) 286 | return E_total, C, sol 287 | -------------------------------------------------------------------------------- /mess/integrals.py: -------------------------------------------------------------------------------- 1 | """ 2 | JAX implementation for integrals over Gaussian basis functions. 3 | 4 | Based upon the closed-form expressions derived in 5 | 6 | Taketa, H., Huzinaga, S., & O-ohata, K. (1966). Gaussian-expansion methods for 7 | molecular integrals. Journal of the physical society of Japan, 21(11), 2313-2324. 8 | 9 | 10 | Hereafter referred to as the "THO paper" 11 | 12 | Related work: 13 | 14 | [1] Augspurger JD, Dykstra CE. General quantum mechanical operators. An 15 | open-ended approach for one-electron integrals with Gaussian bases. Journal of 16 | computational chemistry. 1990 Jan;11(1):105-11. 
17 | 18 | """ 19 | 20 | from dataclasses import asdict 21 | from functools import partial 22 | from itertools import product as cartesian_product 23 | from more_itertools import batched 24 | from typing import Callable 25 | 26 | import jax.numpy as jnp 27 | import numpy as np 28 | from jax import jit, tree, vmap 29 | from jax.ops import segment_sum 30 | 31 | from mess.basis import Basis, basis_iter 32 | from mess.primitive import Primitive, product 33 | from mess.special import ( 34 | binom, 35 | binom_factor, 36 | factorial, 37 | factorial2, 38 | gammanu, 39 | allpairs_indices, 40 | ) 41 | from mess.types import Float3, FloatNxN 42 | from mess.units import LMAX 43 | 44 | BinaryPrimitiveOp = Callable[[Primitive, Primitive], float] 45 | 46 | 47 | @partial(jit, static_argnums=(0, 1)) 48 | def integrate_dense(basis: Basis, primitive_op: BinaryPrimitiveOp) -> FloatNxN: 49 | (ii, cl, lhs), (jj, cr, rhs) = basis_iter(basis) 50 | aij = cl * cr * vmap(primitive_op)(lhs, rhs) 51 | A = jnp.zeros((basis.num_primitives, basis.num_primitives)) 52 | A = A.at[ii, jj].set(aij) 53 | A = A + A.T - jnp.diag(jnp.diag(A)) 54 | index = basis.orbital_index.reshape(1, basis.num_primitives) 55 | out = segment_sum(A, index, num_segments=basis.num_orbitals) 56 | out = segment_sum(out.T, index, num_segments=basis.num_orbitals) 57 | return out 58 | 59 | 60 | @partial(jit, static_argnums=(0, 1)) 61 | def integrate_sparse(basis: Basis, primitive_op: BinaryPrimitiveOp) -> FloatNxN: 62 | offset = [0] + [o.num_primitives for o in basis.orbitals] 63 | offset = np.cumsum(offset) 64 | ii, jj = allpairs_indices(basis.num_orbitals) 65 | indices = [] 66 | batch = [] 67 | 68 | for count, idx in enumerate(zip(ii, jj)): 69 | mesh = [range(offset[i], offset[i + 1]) for i in idx] 70 | indices += list(cartesian_product(*mesh)) 71 | batch += [count] * (len(indices) - len(batch)) 72 | 73 | indices = np.array(indices, dtype=np.int32).T 74 | batch = np.array(batch, dtype=np.int32) 75 | cij = 
jnp.stack([jnp.take(basis.coefficients, idx) for idx in indices]).prod(axis=0) 76 | pij = [ 77 | tree.map(lambda x: jnp.take(x, idx, axis=0), basis.primitives) 78 | for idx in indices 79 | ] 80 | aij = segment_sum(cij * vmap(primitive_op)(*pij), batch, num_segments=count + 1) 81 | 82 | A = jnp.zeros_like(aij, shape=(basis.num_orbitals, basis.num_orbitals)) 83 | A = A.at[ii, jj].set(aij) 84 | A = A + A.T - jnp.diag(jnp.diag(A)) 85 | return A 86 | 87 | 88 | integrate = integrate_dense 89 | 90 | 91 | def _overlap_primitives(a: Primitive, b: Primitive) -> float: 92 | @vmap 93 | def overlap_axis(i: int, j: int, a: float, b: float) -> float: 94 | idx = [(s, t) for s in range(LMAX + 1) for t in range(2 * s + 1)] 95 | s, t = jnp.array(idx, dtype=jnp.uint32).T 96 | out = binom(i, 2 * s - t) * binom(j, t) 97 | out *= a ** jnp.maximum(i - (2 * s - t), 0) * b ** jnp.maximum(j - t, 0) 98 | out *= factorial2(2 * s - 1) / (2 * p.alpha) ** s 99 | 100 | mask = (2 * s - i <= t) & (t <= j) 101 | out = jnp.where(mask, out, 0) 102 | return jnp.sum(out) 103 | 104 | p = product(a, b) 105 | pa = p.center - a.center 106 | pb = p.center - b.center 107 | out = jnp.power(jnp.pi / p.alpha, 1.5) * p.norm 108 | out *= jnp.prod(overlap_axis(a.lmn, b.lmn, pa, pb)) 109 | return out 110 | 111 | 112 | def overlap_basis(basis: Basis) -> FloatNxN: 113 | return integrate(basis, _overlap_primitives) 114 | 115 | 116 | def _kinetic_primitives(a: Primitive, b: Primitive) -> float: 117 | t0 = b.alpha * (2 * jnp.sum(b.lmn) + 3) * _overlap_primitives(a, b) 118 | 119 | def offset_qn(ax: int, offset: int): 120 | lmn = b.lmn.at[ax].add(offset) 121 | return Primitive(**{**asdict(b), "lmn": lmn}) 122 | 123 | axes = jnp.arange(3) 124 | b1 = vmap(offset_qn, (0, None))(axes, 2) 125 | t1 = jnp.sum(vmap(_overlap_primitives, (None, 0))(a, b1)) 126 | 127 | b2 = vmap(offset_qn, (0, None))(axes, -2) 128 | t2 = jnp.sum(b.lmn * (b.lmn - 1) * vmap(_overlap_primitives, (None, 0))(a, b2)) 129 | return t0 - 2.0 * b.alpha**2 * t1 
- 0.5 * t2 130 | 131 | 132 | def kinetic_basis(b: Basis) -> FloatNxN: 133 | return integrate(b, _kinetic_primitives) 134 | 135 | 136 | def build_gindex(): 137 | vals = [ 138 | (i, r, u) 139 | for i in range(LMAX + 1) 140 | for r in range(i // 2 + 1) 141 | for u in range((i - 2 * r) // 2 + 1) 142 | ] 143 | i, r, u = jnp.array(vals).T 144 | return i, r, u 145 | 146 | 147 | gindex = build_gindex() 148 | 149 | 150 | def _nuclear_primitives(a: Primitive, b: Primitive, c: Float3): 151 | p = product(a, b) 152 | pa = p.center - a.center 153 | pb = p.center - b.center 154 | pc = p.center - c 155 | epsilon = 1.0 / (4.0 * p.alpha) 156 | 157 | @vmap 158 | def g_term(l1, l2, pa, pb, cp): 159 | i, r, u = gindex 160 | index = i - 2 * r - u 161 | g = ( 162 | jnp.power(-1, i + u) 163 | * jnp.take(binom_factor(l1, l2, pa, pb), i) 164 | * factorial(i) 165 | * jnp.power(cp, index - u) 166 | * jnp.power(epsilon, r + u) 167 | ) / (factorial(r) * factorial(u) * factorial(index - u)) 168 | 169 | g = jnp.where(index <= l1 + l2, g, 0.0) 170 | return segment_sum(g, index, num_segments=LMAX + 1) 171 | 172 | Gi, Gj, Gk = g_term(a.lmn, b.lmn, pa, pb, pc) 173 | ids = jnp.arange(3 * LMAX + 1) 174 | ijk = jnp.arange(LMAX + 1) 175 | nu = ( 176 | ijk[:, jnp.newaxis, jnp.newaxis] 177 | + ijk[jnp.newaxis, :, jnp.newaxis] 178 | + ijk[jnp.newaxis, jnp.newaxis, :] 179 | ) 180 | 181 | W = ( 182 | Gi[:, jnp.newaxis, jnp.newaxis] 183 | * Gj[jnp.newaxis, :, jnp.newaxis] 184 | * Gk[jnp.newaxis, jnp.newaxis, :] 185 | * jnp.take(gammanu(ids, p.alpha * jnp.inner(pc, pc)), nu) 186 | ) 187 | 188 | return -2.0 * jnp.pi / p.alpha * p.norm * jnp.sum(W) 189 | 190 | 191 | overlap_primitives = jit(_overlap_primitives) 192 | kinetic_primitives = jit(_kinetic_primitives) 193 | nuclear_primitives = jit(_nuclear_primitives) 194 | 195 | 196 | @partial(jit, static_argnums=0) 197 | def nuclear_basis(basis: Basis): 198 | def n(atomic_number, position): 199 | def op(pi, pj): 200 | return atomic_number * _nuclear_primitives(pi, 
pj, position) 201 | 202 | return integrate(basis, op) 203 | 204 | return vmap(n)(basis.structure.atomic_number, basis.structure.position) 205 | 206 | 207 | def build_cindex(): 208 | vals = [ 209 | (i1, i2, r1, r2, u) 210 | for i1 in range(2 * LMAX + 1) 211 | for i2 in range(2 * LMAX + 1) 212 | for r1 in range(i1 // 2 + 1) 213 | for r2 in range(i2 // 2 + 1) 214 | for u in range((i1 + i2) // 2 - r1 - r2 + 1) 215 | ] 216 | i1, i2, r1, r2, u = jnp.array(vals).T 217 | return i1, i2, r1, r2, u 218 | 219 | 220 | cindex = build_cindex() 221 | 222 | 223 | def _eri_primitives(a: Primitive, b: Primitive, c: Primitive, d: Primitive) -> float: 224 | p = product(a, b) 225 | q = product(c, d) 226 | pa = p.center - a.center 227 | pb = p.center - b.center 228 | qc = q.center - c.center 229 | qd = q.center - d.center 230 | qp = q.center - p.center 231 | delta = 1 / (4.0 * p.alpha) + 1 / (4.0 * q.alpha) 232 | 233 | def H(l1, l2, a, b, i, r, gamma): 234 | # Note this should match THO Eq 3.5 but that seems to incorrectly show a 235 | # 1/(4 gamma)^(i- 2r) term which is inconsistent with Eq 2.22. 
236 | # Using (4 gamma)^(r - i) matches the reported expressions for H_L 237 | u = factorial(i) * jnp.take(binom_factor(l1, l2, a, b, 2 * LMAX), i) 238 | v = factorial(r) * factorial(i - 2 * r) * (4 * gamma) ** (i - r) 239 | return u / v 240 | 241 | @vmap 242 | def c_term(la, lb, lc, ld, pa, pb, qc, qd, qp): 243 | # THO Eq 2.22 and 3.4 244 | i1, i2, r1, r2, u = cindex 245 | h = H(la, lb, pa, pb, i1, r1, p.alpha) * H(lc, ld, qc, qd, i2, r2, q.alpha) 246 | index = i1 + i2 - 2 * (r1 + r2) - u 247 | x = (-1) ** (i2 + u) * factorial(index + u) * qp ** (index - u) 248 | y = factorial(u) * factorial(index - u) * delta**index 249 | c = h * x / y 250 | 251 | mask = (i1 <= (la + lb)) & (i2 <= (lc + ld)) 252 | c = jnp.where(mask, c, 0.0) 253 | return segment_sum(c, index, num_segments=4 * LMAX + 1) 254 | 255 | Ci, Cj, Ck = c_term(a.lmn, b.lmn, c.lmn, d.lmn, pa, pb, qc, qd, qp) 256 | 257 | ijk = jnp.arange(4 * LMAX + 1) 258 | nu = ( 259 | ijk[:, jnp.newaxis, jnp.newaxis] 260 | + ijk[jnp.newaxis, :, jnp.newaxis] 261 | + ijk[jnp.newaxis, jnp.newaxis, :] 262 | ) 263 | ids = jnp.arange(12 * LMAX + 1) 264 | 265 | W = ( 266 | Ci[:, jnp.newaxis, jnp.newaxis] 267 | * Cj[jnp.newaxis, :, jnp.newaxis] 268 | * Ck[jnp.newaxis, jnp.newaxis, :] 269 | * jnp.take(gammanu(ids, jnp.inner(qp, qp) / (4.0 * delta)), nu) 270 | ) 271 | 272 | return ( 273 | 2.0 274 | * jnp.pi**2 275 | / (p.alpha * q.alpha) 276 | * jnp.sqrt(jnp.pi / (p.alpha + q.alpha)) 277 | * p.norm 278 | * q.norm 279 | * jnp.sum(W) 280 | ) 281 | 282 | 283 | eri_primitives = jit(_eri_primitives) 284 | vmap_eri_primitives = jit(vmap(_eri_primitives)) 285 | 286 | 287 | def gen_ijkl(n: int): 288 | """Adapted from four-index transformations by S Wilson pg 257""" 289 | for idx in range(n): 290 | for jdx in range(idx + 1): 291 | for kdx in range(idx + 1): 292 | lmax = jdx if idx == kdx else kdx 293 | for ldx in range(lmax + 1): 294 | yield idx, jdx, kdx, ldx 295 | 296 | 297 | @partial(jit, static_argnums=0) 298 | def eri_basis_sparse(b: 
Basis): 299 | indices = [] 300 | batch = [] 301 | offset = np.cumsum([o.num_primitives for o in b.orbitals]) 302 | offset = np.insert(offset, 0, 0) 303 | 304 | for count, idx in enumerate(gen_ijkl(b.num_orbitals)): 305 | mesh = [range(offset[i], offset[i + 1]) for i in idx] 306 | indices += list(cartesian_product(*mesh)) 307 | batch += [count] * (len(indices) - len(batch)) 308 | 309 | indices = jnp.array(indices, dtype=jnp.int32).T 310 | batch = jnp.array(batch, dtype=jnp.int32) 311 | cijkl = jnp.stack([jnp.take(b.coefficients, idx) for idx in indices]).prod(axis=0) 312 | pijkl = [ 313 | tree.map(lambda x: jnp.take(x, idx, axis=0), b.primitives) for idx in indices 314 | ] 315 | eris = cijkl * vmap_eri_primitives(*pijkl) 316 | return segment_sum(eris, batch, num_segments=count + 1) 317 | 318 | 319 | def eri_basis(b: Basis): 320 | unique_eris = eri_basis_sparse_batched(b, batch_size=1024 * 1024) 321 | ii, jj, kk, ll = jnp.array(list(gen_ijkl(b.num_orbitals)), dtype=jnp.int32).T 322 | 323 | # Apply 8x permutation symmetry to build dense ERI from sparse ERI. 
324 | eri_dense = jnp.empty_like(unique_eris, shape=(b.num_orbitals,) * 4) 325 | eri_dense = eri_dense.at[ii, jj, kk, ll].set(unique_eris) 326 | eri_dense = eri_dense.at[ii, jj, ll, kk].set(unique_eris) 327 | eri_dense = eri_dense.at[jj, ii, kk, ll].set(unique_eris) 328 | eri_dense = eri_dense.at[jj, ii, ll, kk].set(unique_eris) 329 | eri_dense = eri_dense.at[kk, ll, ii, jj].set(unique_eris) 330 | eri_dense = eri_dense.at[kk, ll, jj, ii].set(unique_eris) 331 | eri_dense = eri_dense.at[ll, kk, ii, jj].set(unique_eris) 332 | eri_dense = eri_dense.at[ll, kk, jj, ii].set(unique_eris) 333 | return eri_dense 334 | 335 | 336 | def eri_basis_sparse_batched(basis: Basis, batch_size: int): 337 | batched_eris = [] 338 | offset = np.cumsum([o.num_primitives for o in basis.orbitals]) 339 | offset = np.insert(offset, 0, 0) 340 | 341 | for ijkl_batch in batched(gen_ijkl(basis.num_orbitals), batch_size): 342 | indices = [] 343 | batch = [] 344 | for count, idx in enumerate(ijkl_batch): 345 | mesh = [range(offset[i], offset[i + 1]) for i in idx] 346 | indices += list(cartesian_product(*mesh)) 347 | batch += [count] * (len(indices) - len(batch)) 348 | 349 | indices = jnp.array(indices, dtype=jnp.int32).T 350 | batch = jnp.array(batch, dtype=jnp.int32) 351 | cijkl = jnp.stack([jnp.take(basis.coefficients, idx) for idx in indices]).prod( 352 | axis=0 353 | ) 354 | pijkl = [ 355 | tree.map(lambda x: jnp.take(x, idx, axis=0), basis.primitives) 356 | for idx in indices 357 | ] 358 | eris = cijkl * vmap_eri_primitives(*pijkl) 359 | batched_eris += [segment_sum(eris, batch, num_segments=count + 1)] 360 | 361 | return jnp.hstack(batched_eris) 362 | -------------------------------------------------------------------------------- /mess/interop.py: -------------------------------------------------------------------------------- 1 | """Interoperation tools for working across MESS, PySCF.""" 2 | 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | from periodictable import elements 7 | 
from pyscf import gto 8 | 9 | from mess.basis import Basis, basisset 10 | from mess.structure import Structure 11 | from mess.units import to_bohr 12 | from mess.package_utils import requires_package 13 | 14 | 15 | def to_pyscf(structure: Structure, basis_name: str = "sto-3g") -> "gto.Mole": 16 | mol = gto.Mole(unit="Bohr", spin=structure.num_electrons % 2, cart=True) 17 | mol.atom = [ 18 | (symbol, pos) 19 | for symbol, pos in zip(structure.atomic_symbol, structure.position) 20 | ] 21 | mol.basis = basis_name 22 | mol.build(unit="Bohr") 23 | return mol 24 | 25 | 26 | def from_pyscf(mol: "gto.Mole") -> Tuple[Structure, Basis]: 27 | atoms = [(elements.symbol(sym).number, pos) for sym, pos in mol.atom] 28 | atomic_number, position = [np.array(x) for x in zip(*atoms)] 29 | 30 | if mol.unit == "Angstrom": 31 | position = to_bohr(position) 32 | 33 | structure = Structure(atomic_number, position) 34 | 35 | basis = basisset(structure, basis_name=mol.basis) 36 | 37 | return structure, basis 38 | 39 | 40 | @requires_package("pyquante2") 41 | def from_pyquante(name: str) -> Structure: 42 | """Load molecular structure from pyquante2.geo.samples module 43 | 44 | Args: 45 | name (str): Possible names include ch4, c6h6, aspirin, caffeine, hmx, petn, 46 | prozan, rdx, taxol, tylenol, viagara, zoloft 47 | 48 | Returns: 49 | Structure 50 | """ 51 | from pyquante2.geo import samples 52 | 53 | pqmol = getattr(samples, name) 54 | atomic_number, position = zip(*[(a.Z, a.r) for a in pqmol]) 55 | atomic_number, position = [np.asarray(x) for x in (atomic_number, position)] 56 | return Structure(atomic_number, position) 57 | -------------------------------------------------------------------------------- /mess/mesh.py: -------------------------------------------------------------------------------- 1 | """Discretised sampling of orbitals and charge density.""" 2 | 3 | from typing import Optional, Tuple, Union 4 | 5 | import equinox as eqx 6 | import jax.numpy as jnp 7 | from pyscf import 
dft 8 | from jax import vjp 9 | 10 | from mess.basis import Basis 11 | from mess.interop import to_pyscf 12 | from mess.structure import Structure 13 | from mess.types import FloatN, FloatNx3, FloatNxN, MeshAxes 14 | 15 | 16 | class Mesh(eqx.Module): 17 | points: FloatNx3 18 | weights: Optional[FloatN] = None 19 | axes: Optional[MeshAxes] = None 20 | 21 | 22 | def uniform_mesh( 23 | n: Union[int, Tuple] = 50, b: Union[float, Tuple] = 10.0, ndim: int = 3 24 | ) -> Mesh: 25 | if isinstance(n, int): 26 | n = (n,) * ndim 27 | 28 | if isinstance(b, float): 29 | b = (b,) * ndim 30 | 31 | if not isinstance(n, (tuple, list)): 32 | raise ValueError("Expected an integer ") 33 | 34 | if len(n) != ndim: 35 | raise ValueError("n must be a tuple with {ndim} elements") 36 | 37 | if len(b) != ndim: 38 | raise ValueError("b must be a tuple with {ndim} elements") 39 | 40 | axes = [jnp.linspace(-bi, bi, ni) for bi, ni in zip(b, n)] 41 | points = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1) 42 | points = points.reshape(-1, ndim) 43 | return Mesh(points, axes=axes) 44 | 45 | 46 | def density(basis: Basis, mesh: Mesh, P: Optional[FloatNxN] = None) -> FloatN: 47 | P = jnp.diag(basis.occupancy) if P is None else P 48 | orbitals = basis(mesh.points) 49 | return jnp.einsum("ij,pi,pj->p", P, orbitals, orbitals) 50 | 51 | 52 | def density_and_grad( 53 | basis: Basis, mesh: Mesh, P: Optional[FloatNxN] = None 54 | ) -> Tuple[FloatN, FloatNx3]: 55 | def f(points): 56 | return density(basis, eqx.combine(points, rest), P) 57 | 58 | points, rest = eqx.partition(mesh, lambda x: id(x) == id(mesh.points)) 59 | rho, df = vjp(f, points) 60 | grad_rho = df(jnp.ones_like(rho))[0].points 61 | return rho, grad_rho 62 | 63 | 64 | def molecular_orbitals( 65 | basis: Basis, mesh: Mesh, C: Optional[FloatNxN] = None 66 | ) -> FloatN: 67 | C = jnp.eye(basis.num_orbitals) if C is None else C 68 | orbitals = basis(mesh.points) @ C 69 | return orbitals 70 | 71 | 72 | def xcmesh_from_pyscf(structure: 
Structure, level: int = 3) -> Mesh: 73 | grids = dft.gen_grid.Grids(to_pyscf(structure)) 74 | grids.level = level 75 | grids.build() 76 | return Mesh(points=grids.coords, weights=grids.weights) 77 | -------------------------------------------------------------------------------- /mess/numerics.py: -------------------------------------------------------------------------------- 1 | """Function decorators to automate converting between numeric formats.""" 2 | 3 | from functools import wraps 4 | from typing import Callable 5 | 6 | import jax.numpy as jnp 7 | import numpy as np 8 | from jax.experimental import enable_x64 9 | from jaxtyping import Array 10 | 11 | 12 | def apply_fpcast(v: Array, dtype: np.dtype): 13 | if isinstance(v, jnp.ndarray) and np.issubdtype(v, np.floating): 14 | return v.astype(dtype) 15 | 16 | return v 17 | 18 | 19 | def fpcast(func: Callable, dtype=jnp.float32): 20 | @wraps(func) 21 | def wrapper(*args, **kwargs): 22 | inputs = [apply_fpcast(v, dtype) for v in args] 23 | outputs = func(*inputs, **kwargs) 24 | return outputs 25 | 26 | return wrapper 27 | 28 | 29 | def compare_fp32_to_fp64(func: Callable): 30 | @wraps(func) 31 | def wrapper(*args, **kwargs): 32 | with enable_x64(): 33 | outputs_fp32 = fpcast(func, dtype=jnp.float32)(*args, **kwargs) 34 | outputs_fp64 = fpcast(func, dtype=jnp.float64)(*args, **kwargs) 35 | print_compare(func.__name__, outputs_fp32, outputs_fp64) 36 | return outputs_fp32 37 | 38 | return wrapper 39 | 40 | 41 | def print_compare(name: str, fp32, fp64): 42 | fp32 = [fp32] if isinstance(fp32, jnp.ndarray) else fp32 43 | fp64 = [fp64] if isinstance(fp64, jnp.ndarray) else fp64 44 | 45 | for idx, (low, high) in enumerate(zip(fp32, fp64)): 46 | low = np.asarray(low).astype(np.float64) 47 | high = np.asarray(high) 48 | print(f"{name} output {idx} has max |fp64 - fp32| = {np.abs(high - low).max()}") 49 | -------------------------------------------------------------------------------- /mess/orbital.py: 
-------------------------------------------------------------------------------- 1 | """Container for a linear combination of Gaussian Primitives (aka contraction).""" 2 | 3 | from functools import partial 4 | from typing import Tuple 5 | 6 | import equinox as eqx 7 | import jax.numpy as jnp 8 | from jax import tree, vmap 9 | 10 | from mess.primitive import Primitive, eval_primitive 11 | from mess.types import FloatN, FloatNx3 12 | 13 | 14 | class Orbital(eqx.Module): 15 | primitives: Tuple[Primitive] 16 | coefficients: FloatN 17 | 18 | @property 19 | def num_primitives(self) -> int: 20 | return len(self.primitives) 21 | 22 | def __call__(self, pos: FloatNx3) -> FloatN: 23 | pos = jnp.atleast_2d(pos) 24 | assert pos.ndim == 2 and pos.shape[1] == 3, "pos must have shape [N,3]" 25 | 26 | @partial(vmap, in_axes=(0, 0, None)) 27 | def eval_orbital(p: Primitive, coef: float, pos: FloatNx3): 28 | return coef * eval_primitive(p, pos) 29 | 30 | batch = tree.map(lambda *xs: jnp.stack(xs), *self.primitives) 31 | out = jnp.sum(eval_orbital(batch, self.coefficients, pos), axis=0) 32 | return out 33 | 34 | @staticmethod 35 | def from_bse(center, alphas, lmn, coefficients): 36 | coefficients = coefficients.reshape(-1) 37 | assert len(coefficients) == len(alphas), "Expecting same size vectors!" 
38 | p = [Primitive(center=center, alpha=a, lmn=lmn) for a in alphas] 39 | return Orbital(primitives=p, coefficients=coefficients) 40 | 41 | 42 | def batch_orbitals(orbitals: Tuple[Orbital]): 43 | primitives = [p for o in orbitals for p in o.primitives] 44 | primitives = tree.map(lambda *xs: jnp.stack(xs), *primitives) 45 | coefficients = jnp.concatenate([o.coefficients for o in orbitals]) 46 | orbital_index = jnp.concatenate([ 47 | i * jnp.ones(o.num_primitives, dtype=jnp.int32) for i, o in enumerate(orbitals) 48 | ]) 49 | return primitives, coefficients, orbital_index 50 | -------------------------------------------------------------------------------- /mess/orthnorm.py: -------------------------------------------------------------------------------- 1 | r"""Orthonormal transformation. 2 | 3 | Evaluates the transformation matrix :math:`X` that satisfies 4 | 5 | .. math:: \mathbf{X}^T \mathbf{S} \mathbf{X} = \mathbb{I} 6 | 7 | where :math:`\mathbf{S}` is the overlap matrix of the non-orthonormal basis and 8 | :math:`\mathbb{I}` is the identity matrix. 9 | 10 | This module implements a few commonly used orthonormalisation transforms. 11 | """ 12 | 13 | import jax.numpy as jnp 14 | import jax.numpy.linalg as jnl 15 | 16 | from mess.types import FloatNxN 17 | 18 | 19 | def canonical(S: FloatNxN) -> FloatNxN: 20 | r"""Canonical orthonormal transformation 21 | 22 | .. math:: \mathbf{X} = \mathbf{U} \mathbf{s}^{-1/2} 23 | 24 | where :math:`\mathbf{U}` and :math:`\mathbf{s}` are the eigenvectors and 25 | eigenvalues of the overlap matrix :math:`\mathbf{S}`. 26 | 27 | Args: 28 | S (FloatNxN): overlap matrix for the non-orthonormal basis. 29 | 30 | Returns: 31 | FloatNxN: canonical orthonormal transformation matrix 32 | """ 33 | s, U = jnl.eigh(S) 34 | s = jnp.diag(jnp.power(s, -0.5)) 35 | return U @ s 36 | 37 | 38 | def symmetric(S: FloatNxN) -> FloatNxN: 39 | r"""Symmetric orthonormal transformation 40 | 41 | .. 
math:: \mathbf{X} = \mathbf{U} \mathbf{s}^{-1/2} \mathbf{U}^T 42 | 43 | where :math:`\mathbf{U}` and :math:`\mathbf{s}` are the eigenvectors and 44 | eigenvalues of the overlap matrix :math:`\mathbf{S}`. 45 | 46 | Args: 47 | S (FloatNxN): overlap matrix for the non-orthonormal basis. 48 | 49 | Returns: 50 | FloatNxN: symmetric orthonormal transformation matrix 51 | """ 52 | s, U = jnl.eigh(S) 53 | s = jnp.diag(jnp.power(s, -0.5)) 54 | return U @ s @ U.T 55 | 56 | 57 | def cholesky(S: FloatNxN) -> FloatNxN: 58 | r"""Cholesky orthonormal transformation 59 | 60 | .. math:: \mathbf{X} = (\mathbf{L}^{-1})^T 61 | 62 | where :math:`\mathbf{L}` is the lower triangular matrix that satisfies the Cholesky 63 | decomposition of the overlap matrix :math:`\mathbf{S}`. 64 | 65 | Args: 66 | S (FloatNxN): overlap matrix for the non-orthonormal basis. 67 | 68 | Returns: 69 | FloatNxN: cholesky orthonormal transformation matrix 70 | """ 71 | L = jnl.cholesky(S) 72 | return jnl.inv(L).T 73 | -------------------------------------------------------------------------------- /mess/package_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from functools import wraps 3 | from typing import Any, Callable, TypeVar 4 | 5 | F = TypeVar("F", bound=Callable[..., Any]) 6 | 7 | 8 | class MissingOptionalDependencyError(BaseException): 9 | """ 10 | An exception raised when an optional dependency is required 11 | but cannot be found. 12 | 13 | Attributes 14 | ---------- 15 | library_name 16 | The name of the missing library. 17 | """ 18 | 19 | def __init__(self, library_name: str): 20 | """ 21 | 22 | Parameters 23 | ---------- 24 | library_name 25 | The name of the missing library. 26 | license_issue 27 | Whether the library was importable but was unusable due 28 | to a missing license. 29 | """ 30 | 31 | message = f"The required {library_name} module could not be imported." 
32 | 33 | super(MissingOptionalDependencyError, self).__init__(message) 34 | 35 | self.library_name = library_name 36 | 37 | 38 | def has_package(package_name: str) -> bool: 39 | """ 40 | Helper function to generically check if a Python package is installed. 41 | Intended to be used to check for optional dependencies. 42 | 43 | Parameters 44 | ---------- 45 | package_name : str 46 | The name of the Python package to check the availability of 47 | 48 | Returns 49 | ------- 50 | package_available : bool 51 | Boolean indicator if the package is available or not 52 | 53 | Examples 54 | -------- 55 | >>> has_numpy = has_package('numpy') 56 | >>> has_numpy 57 | True 58 | >>> has_foo = has_package('other_non_installed_package') 59 | >>> has_foo 60 | False 61 | """ 62 | try: 63 | importlib.import_module(package_name) 64 | except ModuleNotFoundError: 65 | return False 66 | return True 67 | 68 | 69 | def requires_package(package_name: str) -> Callable[..., Any]: 70 | """ 71 | Helper function to denote that a funciton requires some optional 72 | dependency. A function decorated with this decorator will raise 73 | `MissingOptionalDependencyError` if the package is not found by 74 | `importlib.import_module()`. 75 | 76 | Parameters 77 | ---------- 78 | package_name : str 79 | The name of the module to be imported. 
80 | 81 | Raises 82 | ------ 83 | MissingOptionalDependencyError 84 | 85 | """ 86 | 87 | def inner_decorator(function: F) -> F: 88 | @wraps(function) 89 | def wrapper(*args, **kwargs): 90 | import importlib 91 | 92 | try: 93 | importlib.import_module(package_name) 94 | except ImportError: 95 | raise MissingOptionalDependencyError(library_name=package_name) 96 | except Exception as e: 97 | raise e 98 | 99 | return function(*args, **kwargs) 100 | 101 | return wrapper 102 | 103 | return inner_decorator 104 | -------------------------------------------------------------------------------- /mess/plot.py: -------------------------------------------------------------------------------- 1 | """Visualisations of molecular structures and volumetric data.""" 2 | 3 | import numpy as np 4 | import py3Dmol 5 | from more_itertools import chunked 6 | from numpy.typing import NDArray 7 | 8 | from mess.structure import Structure 9 | from mess.types import MeshAxes 10 | from mess.units import to_angstrom 11 | 12 | 13 | def plot_molecule(view: py3Dmol.view, structure: Structure) -> py3Dmol.view: 14 | """Plots molecular structure. 15 | 16 | Args: 17 | view (py3Dmol.view): py3DMol View to which to add visualizer 18 | structure (Structure): molecular structure 19 | 20 | Returns: 21 | py3DMol View object 22 | 23 | """ 24 | xyz = f"{structure.num_atoms}\n\n" 25 | sym = structure.atomic_symbol 26 | pos = to_angstrom(structure.position) 27 | 28 | for i in range(structure.num_atoms): 29 | r = np.array2string(pos[i, :], separator="\t")[1:-1] 30 | xyz += f"{sym[i]}\t{r}\n" 31 | 32 | view.addModel(xyz) 33 | style = "stick" if structure.num_atoms > 1 else "sphere" 34 | view.setStyle({style: {"radius": 0.1}}) 35 | return view 36 | 37 | 38 | def plot_volume(view: py3Dmol.view, value: NDArray, axes: MeshAxes): 39 | """Plots volumetric data value with molecular structure. 
40 | 41 | Volumetric render using https://3dmol.csb.pitt.edu/doc/VolumetricRendererSpec.html 42 | 43 | Args: 44 | view (py3Dmol.view): py3DMol View to which to add visualizer 45 | value (NDArray): the volume data to render 46 | axes (MeshAxes): the axes over which the data was sampled. 47 | 48 | Returns: 49 | py3DMol View object 50 | 51 | """ 52 | 53 | s = cube_data(value, axes) 54 | view.addVolumetricData(s, "cube", build_transferfn(value)) 55 | return view 56 | 57 | 58 | def plot_isosurfaces( 59 | view: py3Dmol.view, value: NDArray, axes: MeshAxes, percentiles=[95, 75] 60 | ): 61 | """Plots volumetric data value with molecular structure. 62 | 63 | IsoSurface render using https://3dmol.csb.pitt.edu/doc/IsoSurfaceSpec.html 64 | 65 | Args: 66 | view (py3Dmol.view): py3DMol View to which to add visualizer 67 | value (NDArray): the volume data to render 68 | axes (MeshAxes): the axes over which the data was sampled. 69 | percentiles (tuple): percentiles at which to draw isosurfaces 70 | 71 | Returns: 72 | py3DMol View object 73 | 74 | Note: 75 | 3Dmol does not currently implement full transparency, so only two 76 | percentiles are accepted, with the inner one being rendered with full opacity. 77 | - https://github.com/3dmol/3Dmol.js/issues/224 78 | """ 79 | assert len(percentiles) == 2 80 | 81 | voldata_str = cube_data(value, axes) 82 | 83 | v = np.percentile(np.abs(value), tuple(reversed(sorted(percentiles)))) 84 | for sign in [-1, 1]: 85 | for isovalind in (0, 1): 86 | isoval = sign * v[isovalind] 87 | tf = { 88 | "isoval": isoval, 89 | "smoothness": 2, 90 | "opacity": 0.9 if isovalind == 1 else 1.0, 91 | "voldata": voldata_str, 92 | "volformat": "cube", 93 | "volscheme": {"gradient": "rwb", "min": -v[0], "max": v[0]}, 94 | } 95 | view.addVolumetricData(voldata_str, "cube", tf) 96 | 97 | return view 98 | 99 | 100 | def cube_data(value: NDArray, axes: MeshAxes) -> str: 101 | """Generate the cube file format as a string. 
See: 102 | 103 | https://paulbourke.net/dataformats/cube/ 104 | 105 | Args: 106 | value (NDArray): the volume data to serialise in the cube format 107 | axes (MeshAxes): the axes over which the data was sampled 108 | 109 | Returns: 110 | str: cube format representation of the volumetric data. 111 | """ 112 | # The first two lines of the header are comments, they are generally ignored by 113 | # parsing packages or used as two default labels. 114 | fmt = "cube format\n\n" 115 | 116 | axes = [to_angstrom(ax) for ax in axes] 117 | x, y, z = axes 118 | # The third line has the number of atoms included in the file followed by the 119 | # position of the origin of the volumetric data. 120 | fmt += f"0 {cube_format_vec([x[0], y[0], z[0]])}\n" 121 | 122 | # The next three lines give the number of voxels along each axis (x, y, z) 123 | # followed by the axis vector. Note this means the volume need not be aligned 124 | # with the coordinate axis, indeed it also means it may be sheared although most 125 | # volumetric packages won't support that. 126 | # The length of each vector is the length of the side of the voxel thus allowing 127 | # non cubic volumes. 128 | # If the sign of the number of voxels in a dimension is positive then the 129 | # units are Bohr, if negative then Angstroms. 130 | nx, ny, nz = [ax.shape[0] for ax in axes] 131 | dx, dy, dz = [ax[1] - ax[0] for ax in axes] 132 | fmt += f"{nx} {cube_format_vec([dx, 0.0, 0.0])}\n" 133 | fmt += f"{ny} {cube_format_vec([0.0, dy, 0.0])}\n" 134 | fmt += f"{nz} {cube_format_vec([0.0, 0.0, dz])}\n" 135 | 136 | # The last section in the header is one line for each atom consisting of 5 137 | # numbers, the first is the atom number, the second is the charge, and the last 138 | # three are the x,y,z coordinates of the atom center. 139 | pass # Number of atoms = 0 above 140 | 141 | # The volumetric data is straightforward, one floating point number for each 142 | # volumetric element. 
Traditionally the grid is arranged with the x axis as 143 | # the outer loop and the z axis as the inner loop 144 | for vals in chunked(value, 6): 145 | fmt += f"{cube_format_vec(vals)}\n" 146 | 147 | return fmt 148 | 149 | 150 | def cube_format_vec(vals): 151 | """ 152 | From https://paulbourke.net/dataformats/cube, floats are formatted 12.6 153 | """ 154 | # Benchmarks showed this is 4x faster than numpy.printoptions... 155 | return " ".join([f"{v:12.6f}" for v in vals]) 156 | 157 | 158 | def build_transferfn(value: NDArray) -> dict: 159 | """Generate the 3dmol.js transferfn argument for a particular value. 160 | 161 | Tries to set isovalues to capture main features of the volume data. 162 | 163 | Args: 164 | value (NDArray): the volume data. 165 | 166 | Returns: 167 | dict: containing transferfn 168 | """ 169 | v = np.percentile(value, [99.9, 75]) 170 | a = [0.02, 0.0005] 171 | return { 172 | "transferfn": [ 173 | {"color": "blue", "opacity": a[0], "value": -v[0]}, 174 | {"color": "blue", "opacity": a[1], "value": -v[1]}, 175 | {"color": "white", "opacity": 0.0, "value": 0.0}, 176 | {"color": "red", "opacity": a[1], "value": v[1]}, 177 | {"color": "red", "opacity": a[0], "value": v[0]}, 178 | ] 179 | } 180 | -------------------------------------------------------------------------------- /mess/primitive.py: -------------------------------------------------------------------------------- 1 | """Primitive Gaussian type orbitals""" 2 | 3 | from typing import Optional 4 | 5 | import equinox as eqx 6 | import jax.numpy as jnp 7 | import numpy as np 8 | from jax import jit 9 | from jax.scipy.special import gammaln 10 | 11 | from mess.types import Float3, FloatN, FloatNx3, Int3, asintarray, asfparray 12 | 13 | 14 | class Primitive(eqx.Module): 15 | center: Float3 = eqx.field(converter=asfparray, default=(0.0, 0.0, 0.0)) 16 | alpha: float = eqx.field(converter=asfparray, default=1.0) 17 | lmn: Int3 = eqx.field(converter=asintarray, default=(0, 0, 0)) 18 | norm: 
Optional[float] = None 19 | 20 | def __post_init__(self): 21 | if self.norm is None: 22 | self.norm = normalize(self.lmn, self.alpha) 23 | 24 | def __check_init__(self): 25 | names = ["center", "alpha", "lmn", "norm"] 26 | shapes = [(3,), (), (3,), ()] 27 | dtypes = [jnp.floating, jnp.floating, jnp.integer, jnp.floating] 28 | 29 | for name, shape, dtype in zip(names, shapes, dtypes): 30 | value = getattr(self, name) 31 | if value.shape != shape or not jnp.issubdtype(value, dtype): 32 | raise ValueError( 33 | f"Invalid value for {name}.\n" 34 | f"Expecting {dtype} array with shape {shape}. " 35 | f"Got {value.dtype} with shape {value.shape}" 36 | ) 37 | 38 | @property 39 | def angular_momentum(self) -> int: 40 | return np.sum(self.lmn) 41 | 42 | def __call__(self, pos: FloatNx3) -> FloatN: 43 | return eval_primitive(self, pos) 44 | 45 | def __hash__(self) -> int: 46 | values = [] 47 | for k, v in vars(self).items(): 48 | if k.startswith("__") or v is None: 49 | continue 50 | 51 | values.append(v.tobytes()) 52 | 53 | return hash(b"".join(values)) 54 | 55 | 56 | @jit 57 | def normalize(lmn: Int3, alpha: float) -> float: 58 | L = jnp.sum(lmn) 59 | N = ((1 / 2) / alpha) ** (L + 3 / 2) 60 | N *= jnp.exp(jnp.sum(gammaln(lmn + 1 / 2))) 61 | return N**-0.5 62 | 63 | 64 | def product(a: Primitive, b: Primitive) -> Primitive: 65 | alpha = a.alpha + b.alpha 66 | center = (a.alpha * a.center + b.alpha * b.center) / alpha 67 | lmn = a.lmn + b.lmn 68 | c = a.norm * b.norm 69 | Rab = a.center - b.center 70 | c *= jnp.exp(-a.alpha * b.alpha / alpha * jnp.inner(Rab, Rab)) 71 | return Primitive(center=center, alpha=alpha, lmn=lmn, norm=c) 72 | 73 | 74 | def eval_primitive(p: Primitive, pos: FloatNx3) -> FloatN: 75 | pos = jnp.atleast_2d(pos) 76 | assert pos.ndim == 2 and pos.shape[1] == 3, "pos must have shape [N,3]" 77 | pos_translated = pos[:, jnp.newaxis] - p.center 78 | v = p.norm * jnp.exp(-p.alpha * jnp.sum(pos_translated**2, axis=-1)) 79 | v *= jnp.prod(pos_translated**p.lmn, 
axis=-1) 80 | return jnp.squeeze(v) 81 | -------------------------------------------------------------------------------- /mess/scf.py: -------------------------------------------------------------------------------- 1 | """Vanilla self-consistent field solver implementation.""" 2 | 3 | import jax.numpy as jnp 4 | import jax.numpy.linalg as jnl 5 | from jax.lax import while_loop 6 | 7 | from mess.basis import Basis 8 | from mess.integrals import eri_basis, kinetic_basis, nuclear_basis, overlap_basis 9 | from mess.structure import nuclear_energy 10 | from mess.orthnorm import cholesky 11 | from mess.types import OrthNormTransform 12 | 13 | 14 | def scf( 15 | basis: Basis, 16 | otransform: OrthNormTransform = cholesky, 17 | max_iters: int = 32, 18 | tolerance: float = 1e-4, 19 | ): 20 | """ """ 21 | # init 22 | Hcore = kinetic_basis(basis) + nuclear_basis(basis).sum(axis=0) 23 | S = overlap_basis(basis) 24 | eri = eri_basis(basis) 25 | 26 | # initial guess for MO coeffs 27 | X = otransform(S) 28 | C = X @ jnl.eigh(X.T @ Hcore @ X)[1] 29 | 30 | # setup self-consistent iteration as a while loop 31 | counter = 0 32 | E = 0.0 33 | E_prev = 2 * tolerance 34 | scf_args = (counter, E, E_prev, C) 35 | 36 | def while_cond(scf_args): 37 | counter, E, E_prev, _ = scf_args 38 | return (counter < max_iters) & (jnp.abs(E - E_prev) > tolerance) 39 | 40 | def while_body(scf_args): 41 | counter, E, E_prev, C = scf_args 42 | P = basis.occupancy * C @ C.T 43 | J = jnp.einsum("kl,ijkl->ij", P, eri) 44 | K = jnp.einsum("ij,ikjl->kl", P, eri) 45 | G = J - 0.5 * K 46 | H = Hcore + G 47 | C = X @ jnl.eigh(X.T @ H @ X)[1] 48 | E_prev = E 49 | E = 0.5 * jnp.sum(Hcore * P) + 0.5 * jnp.sum(H * P) 50 | return (counter + 1, E, E_prev, C) 51 | 52 | _, E_electronic, _, _ = while_loop(while_cond, while_body, scf_args) 53 | E_nuclear = nuclear_energy(basis.structure) 54 | return E_nuclear + E_electronic 55 | -------------------------------------------------------------------------------- 
/mess/special.py: -------------------------------------------------------------------------------- 1 | """Special mathematical functions not readily available in JAX.""" 2 | 3 | from functools import partial 4 | from itertools import combinations_with_replacement 5 | 6 | import jax.numpy as jnp 7 | import numpy as np 8 | from jax import lax, vmap 9 | from jax.ops import segment_sum 10 | from jax.scipy.special import betaln, gammainc, gammaln, erfc 11 | 12 | from mess.types import FloatN, IntN 13 | from mess.units import LMAX 14 | 15 | 16 | def factorial_fori(n: IntN, nmax: int = LMAX) -> IntN: 17 | def body_fun(i, val): 18 | return val * jnp.where(i <= n, i, 1) 19 | 20 | return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) 21 | 22 | 23 | def factorial_gamma(n: IntN) -> IntN: 24 | """Appoximate factorial by evaluating the gamma function in log-space. 25 | 26 | This approximation is exact for small integers (n < 10). 27 | """ 28 | approx = jnp.exp(gammaln(n + 1)) 29 | return jnp.rint(approx) 30 | 31 | 32 | def factorial_lookup(n: IntN, nmax: int = LMAX) -> IntN: 33 | N = np.cumprod(np.arange(1, nmax + 1)) 34 | N = np.insert(N, 0, 1) 35 | N = jnp.array(N, dtype=jnp.uint32) 36 | return N.at[n.astype(jnp.uint32)].get() 37 | 38 | 39 | factorial = factorial_gamma 40 | 41 | 42 | def factorial2_fori(n: IntN, nmax: int = 2 * LMAX) -> IntN: 43 | def body_fun(i, val): 44 | return val * jnp.where((i <= n) & (n % 2 == i % 2), i, 1) 45 | 46 | return lax.fori_loop(1, nmax + 1, body_fun, jnp.ones_like(n)) 47 | 48 | 49 | def factorial2_lookup(n: IntN, nmax: int = 2 * LMAX) -> IntN: 50 | stop = nmax + 1 if nmax % 2 == 0 else nmax + 2 51 | N = np.arange(1, stop).reshape(-1, 2) 52 | N = np.cumprod(N, axis=0).reshape(-1) 53 | N = np.insert(N, 0, 1) 54 | N = jnp.array(N) 55 | n = jnp.maximum(n, 0) 56 | return N.at[n].get() 57 | 58 | 59 | factorial2 = factorial2_lookup 60 | 61 | 62 | def binom_beta(x: IntN, y: IntN) -> IntN: 63 | approx = 1.0 / ((x + 1) * jnp.exp(betaln(x - y + 
1, y + 1))) 64 | return jnp.rint(approx) 65 | 66 | 67 | def binom_fori(x: IntN, y: IntN, nmax: int = LMAX) -> IntN: 68 | bang = partial(factorial_fori, nmax=nmax) 69 | c = x * bang(x - 1) / (bang(y) * bang(x - y)) 70 | return jnp.where(x == y, 1, c) 71 | 72 | 73 | def binom_lookup(x: IntN, y: IntN, nmax: int = LMAX) -> IntN: 74 | bang = partial(factorial_lookup, nmax=nmax) 75 | c = x * bang(x - 1) / (bang(y) * bang(x - y)) 76 | return jnp.where(x == y, 1, c) 77 | 78 | 79 | binom = binom_lookup 80 | 81 | 82 | def gammanu_gamma(nu: IntN, t: FloatN, epsilon: float = 1e-10) -> FloatN: 83 | """Eq 2.11 from THO but simplified using SymPy and converted to jax 84 | 85 | t, u = symbols("t u", real=True, positive=True) 86 | nu = Symbol("nu", integer=True, nonnegative=True) 87 | 88 | expr = simplify(integrate(u ** (2 * nu) * exp(-t * u**2), (u, 0, 1))) 89 | f = lambdify((nu, t), expr, modules="scipy") 90 | ?f 91 | 92 | We evaulate this in log-space to avoid overflow/nan 93 | """ 94 | t = jnp.maximum(t, epsilon) 95 | x = nu + 0.5 96 | gn = jnp.log(0.5) - x * jnp.log(t) + jnp.log(gammainc(x, t)) + gammaln(x) 97 | return jnp.exp(gn) 98 | 99 | 100 | def gammanu_series(nu: IntN, t: FloatN, num_terms: int = 128) -> FloatN: 101 | """Eq 2.11 from THO but simplified as derived in equation 19 of gammanu.ipynb""" 102 | an = nu + 0.5 103 | tn = 1 / an 104 | total = 1 / an 105 | 106 | for _ in range(num_terms): 107 | an = an + 1 108 | tn = tn * t / an 109 | total = total + tn 110 | 111 | return 0.5 * jnp.exp(-t) * total 112 | 113 | 114 | def gammanu_frac_vmap(nu: IntN, t: FloatN, num_terms: int = 128) -> FloatN: 115 | def scalar_fn(nu): 116 | n = jnp.arange(1, num_terms + 1, dtype=t.dtype) 117 | terms = jnp.where(nu >= n, t**n / jnp.cumprod(n - 0.5), 0.0) 118 | q = erfc(jnp.sqrt(t)) + jnp.exp(-t) / jnp.sqrt(jnp.pi * t) * jnp.sum(terms) 119 | lnout = ( 120 | jnp.log(0.5) - (nu + 0.5) * jnp.log(t) + jnp.log(1 - q) + gammaln(nu + 0.5) 121 | ) 122 | return jnp.exp(lnout) 123 | 124 | out_flt = 
vmap(scalar_fn)(nu.reshape(-1)) 125 | return out_flt.reshape(nu.shape) 126 | 127 | 128 | def gammanu_lax_series(nu: IntN, t: FloatN) -> FloatN: 129 | def cond_fn(vals): 130 | return jnp.any(vals[0]) 131 | 132 | def body_fn(vals): 133 | enabled, an, tn, total = vals 134 | an = an + 1 135 | tn = tn * (t / an) 136 | total = total + tn 137 | enabled = enabled & ((tn / total) > jnp.finfo(t.dtype).eps) 138 | return (enabled, an, tn, total) 139 | 140 | a0 = nu + 0.5 141 | t0 = 1.0 / a0 142 | total = 1.0 / t0 143 | 144 | init_vals = ( 145 | jnp.ones(nu.shape, dtype=bool), 146 | a0, 147 | t0, 148 | total, 149 | ) 150 | 151 | _, _, _, total = lax.while_loop(cond_fn, body_fn, init_vals) 152 | return jnp.exp(jnp.log(0.5) + jnp.log(total) - t) 153 | 154 | 155 | def gammanu_lax_frac(nu: IntN, t: FloatN) -> FloatN: 156 | def cond_fn(vals): 157 | return jnp.any(vals[0]) 158 | 159 | def body_fn(vals): 160 | enabled, term, n, q = vals 161 | term = term * t / (n - 0.5) 162 | enabled = nu >= n 163 | q = jnp.where(enabled, q + term, q) 164 | n = n + 1.0 165 | return (enabled, term, n, q) 166 | 167 | enabled = jnp.ones(nu.shape, dtype=bool) 168 | term = jnp.full(nu.shape, jnp.exp(-t) / jnp.sqrt(jnp.pi * t)) 169 | n = jnp.ones(nu.shape) 170 | q = jnp.full(nu.shape, erfc(jnp.sqrt(t))) 171 | init_vals = (enabled, term, n, q) 172 | _, _, _, q = lax.while_loop(cond_fn, body_fn, init_vals) 173 | lnout = jnp.log(0.5) - (nu + 0.5) * jnp.log(t) + jnp.log(1 - q) + gammaln(nu + 0.5) 174 | return jnp.exp(lnout) 175 | 176 | 177 | def gammanu_select(nu: IntN, t: FloatN, threshold: float = 50.0) -> FloatN: 178 | """Select between different implementation strategies for evaluation of gammanu 179 | 180 | Args: 181 | nu (IntN): 182 | t (FloatN): 183 | 184 | Returns: 185 | FloatN 186 | """ 187 | y0 = 1 / (2 * nu + 1) 188 | y1 = gammanu_series(nu, jnp.minimum(t, threshold)) 189 | y2 = ( 190 | factorial2(2 * nu - 1) 191 | / 2 ** (nu + 1) 192 | * jnp.sqrt(jnp.pi / jnp.where(t >= threshold, t, 1.0) ** (2 * nu 
+ 1)) 193 | ) 194 | 195 | return jnp.select( 196 | (t == 0, t < threshold, t >= threshold), 197 | (y0, y1, y2), 198 | ) 199 | 200 | 201 | gammanu = gammanu_select 202 | 203 | 204 | def triu_indices(n: int): 205 | out = [] 206 | for i in range(n): 207 | for j in range(i, n): 208 | out.append((i, j)) 209 | i, j = np.asarray(out).T 210 | return i, j 211 | 212 | 213 | def tril_indices(n: int): 214 | out = [] 215 | for i in range(n): 216 | for j in range(i + 1): 217 | out.append((i, j)) 218 | i, j = np.asarray(out).T 219 | return i, j 220 | 221 | 222 | def allpairs_indices(n: int): 223 | pairs_gen = combinations_with_replacement(range(n), r=2) 224 | i, j = np.asarray(list(pairs_gen)).T 225 | return i, j 226 | 227 | 228 | def binom_factor(i: int, j: int, a: float, b: float, lmax: int = LMAX) -> FloatN: 229 | """Eq. 15 from Augspurger JD, Dykstra CE. General quantum mechanical operators. An 230 | open-ended approach for one-electron integrals with Gaussian bases. Journal of 231 | computational chemistry. 1990 Jan;11(1):105-11. 
232 | 233 | """ 234 | s, t = tril_indices(lmax + 1) 235 | apow = jnp.maximum(i - (s - t), 0) 236 | bpow = jnp.maximum(j - t, 0) 237 | out = binom(i, s - t) * binom(j, t) * a**apow * b**bpow 238 | mask = ((s - i) <= t) & (t <= j) 239 | out = jnp.where(mask, out, 0.0) 240 | return segment_sum(out, s, num_segments=lmax + 1) 241 | -------------------------------------------------------------------------------- /mess/structure.py: -------------------------------------------------------------------------------- 1 | """Container for molecular structures""" 2 | 3 | from typing import List 4 | 5 | import equinox as eqx 6 | import numpy as np 7 | import jax.numpy as jnp 8 | from jax import value_and_grad 9 | from periodictable import elements 10 | 11 | from mess.types import FloatNx3, IntN 12 | from mess.units import to_bohr 13 | 14 | 15 | class Structure(eqx.Module): 16 | atomic_number: IntN 17 | position: FloatNx3 18 | 19 | def __post_init__(self): 20 | # single atom case 21 | self.atomic_number = np.atleast_1d(self.atomic_number) 22 | self.position = np.atleast_2d(self.position) 23 | 24 | @property 25 | def num_atoms(self) -> int: 26 | return len(self.atomic_number) 27 | 28 | @property 29 | def atomic_symbol(self) -> List[str]: 30 | return [elements[z].symbol for z in self.atomic_number] 31 | 32 | @property 33 | def num_electrons(self) -> int: 34 | return np.sum(self.atomic_number) 35 | 36 | def _repr_html_(self): 37 | import py3Dmol 38 | from mess.plot import plot_molecule 39 | 40 | v = py3Dmol.view() 41 | plot_molecule(v, self) 42 | return v._repr_html_() 43 | 44 | 45 | def molecule(name: str) -> Structure: 46 | """Builds a few sample molecules 47 | 48 | Args: 49 | name (str): either "h2" or "water". More to be added. 
50 | 51 | Raises: 52 | NotImplementedError: _description_ 53 | 54 | Returns: 55 | Structure: the built molecule as a Structure object 56 | """ 57 | 58 | name = name.lower() 59 | 60 | if name == "h2": 61 | return Structure( 62 | atomic_number=np.array([1, 1]), 63 | position=np.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]), 64 | ) 65 | 66 | if name == "water": 67 | r"""Single water molecule 68 | Structure of single water molecule calculated with DFT using B3LYP 69 | functional and 6-31+G** basis set """ 70 | return Structure( 71 | atomic_number=np.array([8, 1, 1]), 72 | position=to_bohr( 73 | np.array([ 74 | [0.0000, 0.0000, 0.1165], 75 | [0.0000, 0.7694, -0.4661], 76 | [0.0000, -0.7694, -0.4661], 77 | ]) 78 | ), 79 | ) 80 | 81 | raise NotImplementedError(f"No structure registered for: {name}") 82 | 83 | 84 | def nuclear_energy(structure: Structure) -> float: 85 | r"""Nuclear electrostatic interaction energy 86 | 87 | Evaluated by taking sum over all unique pairs of atom centers: 88 | 89 | .. math:: \\sum_{j > i} \\frac{Z_i Z_j}{|\\mathbf{r}_i - \\mathbf{r}_j|} 90 | 91 | where :math:`z_i` is the charge of the ith atom (the atomic number). 
92 | 93 | Args: 94 | structure (Structure): input structure 95 | 96 | Returns: 97 | float: the total nuclear repulsion energy 98 | """ 99 | idx, jdx = jnp.triu_indices(structure.num_atoms, 1) 100 | u = structure.atomic_number[idx] * structure.atomic_number[jdx] 101 | rij = structure.position[idx, :] - structure.position[jdx, :] 102 | return jnp.sum(u / jnp.linalg.norm(rij, axis=1)) 103 | 104 | 105 | def nuclear_energy_and_force(structure: Structure): 106 | @value_and_grad 107 | def energy_and_grad(pos, rest): 108 | return nuclear_energy(eqx.combine(pos, rest)) 109 | 110 | pos, rest = eqx.partition(structure, lambda x: id(x) == id(structure.position)) 111 | E, grad = energy_and_grad(pos, rest) 112 | return E, -grad.position 113 | 114 | 115 | def cubic_hydrogen(n: int) -> Structure: 116 | """ 117 | Builds a Structure of hydrogen atoms arranged in a simple cubic lattice. 118 | 119 | Args: 120 | n (int): The number of hydrogen atoms for the cubic cell. For example, n=4 will 121 | build a 4x4x4 cubic lattice. 122 | 123 | Raises: 124 | ValueError: If n is less than 1. 125 | 126 | Returns: 127 | Structure: A Structure object representing the cubic lattice of hydrogen atoms. 128 | """ 129 | if n < 1: 130 | raise ValueError("Expect at least one hydrogen atom in cubic lattice") 131 | 132 | b = 1.4 * np.arange(0, n) 133 | pos = np.stack(np.meshgrid(b, b, b)).reshape(3, -1).T 134 | pos = np.round(pos - np.mean(pos, axis=0), decimals=3) 135 | return Structure(np.ones(pos.shape[0], dtype=np.int64), pos) 136 | -------------------------------------------------------------------------------- /mess/types.py: -------------------------------------------------------------------------------- 1 | """Types used throughout MESS 2 | 3 | Note: 4 | ``N`` represents the number of atomic orbitals. 
5 | """ 6 | 7 | from functools import partial 8 | from typing import Tuple, Callable 9 | 10 | import jax.numpy as jnp 11 | from jax import config 12 | from jaxtyping import Array, Float, Int 13 | 14 | Float3 = Float[Array, "3"] 15 | FloatNx3 = Float[Array, "N 3"] 16 | FloatN = Float[Array, "N"] 17 | Float3xNxN = Float[Array, "3 N N"] 18 | FloatNxN = Float[Array, "N N"] 19 | FloatNxM = Float[Array, "N M"] 20 | Int3 = Int[Array, "3"] 21 | IntN = Int[Array, "N"] 22 | 23 | MeshAxes = Tuple[FloatN, FloatN, FloatN] 24 | 25 | asintarray = partial(jnp.asarray, dtype=jnp.int32) 26 | 27 | OrthNormTransform = Callable[[FloatNxN], FloatNxN] 28 | 29 | 30 | def default_fptype(): 31 | return jnp.float64 if config.x64_enabled else jnp.float32 32 | 33 | 34 | asfparray = partial(jnp.asarray, dtype=default_fptype()) 35 | -------------------------------------------------------------------------------- /mess/units.py: -------------------------------------------------------------------------------- 1 | """Conversion between Bohr and Angstrom units 2 | 3 | Note: 4 | MESS uses atomic units internally so these conversions are only necessary when 5 | working with external packages. 6 | """ 7 | 8 | from jaxtyping import Array 9 | 10 | # Maximum value an individual component of the angular momentum lmn can take 11 | # Used for static ahead-of-time compilation of functions involving lmn. 
12 | LMAX = 4 13 | 14 | BOHR_PER_ANGSTROM = 1.0 / 0.529177210903 15 | 16 | 17 | def to_angstrom(bohr_value: Array) -> Array: 18 | return bohr_value / BOHR_PER_ANGSTROM 19 | 20 | 21 | def to_bohr(angstrom_value: Array) -> Array: 22 | return angstrom_value * BOHR_PER_ANGSTROM 23 | -------------------------------------------------------------------------------- /mess/xcfunctional.py: -------------------------------------------------------------------------------- 1 | """Core functions for common approximations to the exchange-correlation functional""" 2 | 3 | import jax 4 | import numpy as np 5 | import jax.numpy as jnp 6 | import jax.numpy.linalg as jnl 7 | from mess.types import FloatN, FloatNx3 8 | 9 | 10 | def fzeta(z): 11 | u = (1 + z) ** (4 / 3) + (1 - z) ** (4 / 3) - 2 12 | v = 2 * (2 ** (1 / 3) - 1) 13 | return u / v 14 | 15 | 16 | def d2fdz20(): 17 | f = jax.grad(jax.grad(fzeta)) 18 | return float(f(0.0)) 19 | 20 | 21 | F2 = d2fdz20() 22 | 23 | default_threshold = 1e-15 24 | 25 | 26 | def lda_exchange(rho: FloatN, threshold: float = default_threshold) -> FloatN: 27 | mask = rho > threshold 28 | rho = jnp.where(mask, rho, 0.0) 29 | Cx = (3 / 4) * (3 / np.pi) ** (1 / 3) 30 | eps_x = -Cx * rho ** (1 / 3) 31 | eps_x = jnp.where(mask, eps_x, 0.0) 32 | return eps_x 33 | 34 | 35 | def lda_correlation_vwn( 36 | rho: FloatN, threshold: float = default_threshold, use_rpa: bool = True 37 | ) -> FloatN: 38 | A, x0, b, c = vwn_coefs(use_rpa) 39 | 40 | # avoid divide by zero when rho = 0 by replacing with 1.0 41 | mask = jnp.abs(rho) > threshold 42 | rho = jnp.where(mask, rho, 1.0) 43 | rs = jnp.power(3 / (4 * jnp.pi * rho), 1 / 3).reshape(-1, 1) 44 | x = jnp.sqrt(rs).reshape(-1, 1) 45 | X = rs + b * x + c 46 | X0 = x0**2 + b * x0 + c 47 | Q = np.sqrt(4 * c - b**2) 48 | 49 | u = jnp.log(x**2 / X) + 2 * b / Q * jnp.arctan(Q / (2 * x + b)) 50 | v = jnp.log((x - x0) ** 2 / X) + 2 * (b + 2 * x0) / Q * jnp.arctan(Q / (2 * x + b)) 51 | ec = A * (u - b * x0 / X0 * v) 52 | e0, e1, 
alpha = ec.T 53 | beta = F2 * (e1 - e0) / alpha - 1 54 | z = jnp.zeros_like(rho) # restricted ks, should be rho_up - rho_down 55 | eps_c = e0 + alpha * fzeta(z) / F2 * (1 + beta * z**4) 56 | eps_c = jnp.where(mask, eps_c, 0.0) 57 | return eps_c 58 | 59 | 60 | def vwn_coefs(use_rpa: bool = True): 61 | # paramagnetic (eps_0) / ferromagnetic (eps_1) / spin stiffness (alpha) 62 | A = np.array([0.0310907, 0.5 * 0.0310907, -1 / (6 * np.pi**2)]) 63 | 64 | if use_rpa: 65 | x0 = np.array([-0.409286, -0.743294, -0.228344]) 66 | b = np.array([13.0720, 20.1231, 1.06835]) 67 | c = np.array([42.7198, 101.578, 11.4813]) 68 | else: 69 | # https://math.nist.gov/DFTdata/atomdata/node5.html 70 | x0 = np.array([-0.10498, -0.32500, -4.75840e-3]) 71 | b = np.array([3.72744, 7.06042, 1.13107]) 72 | c = np.array([12.9352, 18.0578, 13.0045]) 73 | return A, x0, b, c 74 | 75 | 76 | def lda_correlation_pw(rho: FloatN, threshold: float = default_threshold) -> FloatN: 77 | p = np.ones(3) 78 | A = np.array([0.031091, 0.015545, 0.016887]) 79 | a1 = np.array([0.21370, 0.20548, 0.11125]) 80 | b1 = np.array([7.5957, 14.1189, 10.357]) 81 | b2 = np.array([3.5876, 6.1977, 3.6231]) 82 | b3 = np.array([1.6382, 3.3662, 0.88026]) 83 | b4 = np.array([0.49294, 0.62517, 0.49671]) 84 | 85 | # avoid divide by zero when rho = 0 by replacing with 1.0 86 | mask = jnp.abs(rho) > threshold 87 | rho = jnp.where(mask, rho, 1.0) 88 | rs = jnp.power(3 / (4 * jnp.pi * rho), 1 / 3).reshape(-1, 1) 89 | v = 2 * A * (b1 * jnp.sqrt(rs) + b2 * rs + b3 * rs ** (3 / 2) + b4 * rs ** (p + 1)) 90 | G = -2 * A * (1 + a1 * rs) * jnp.log(1 + 1 / v) 91 | e0, e1, alpha = G.T 92 | beta = F2 * (e1 - e0) / alpha - 1 93 | z = jnp.zeros_like(rho) # restricted ks, should be rho_up - rho_down 94 | eps_c = e0 + alpha * fzeta(z) / F2 * (1 + beta * z**4) 95 | eps_c = jnp.where(mask, eps_c, 0.0) 96 | return eps_c 97 | 98 | 99 | def gga_exchange_b88( 100 | rho: FloatN, grad_rho: FloatNx3, threshold: float = default_threshold 101 | ) -> FloatN: 102 
| beta = jnp.asarray(0.0042 * 2 ** (1 / 3)) 103 | # avoid divide by zero when rho = 0 by replacing with 1.0 104 | mask = jnp.abs(rho) > threshold 105 | rho = jnp.where(mask, rho, 1.0) 106 | x = jnl.norm(grad_rho, axis=1) / rho ** (4 / 3) 107 | d = 1 + 6 * beta * x * jnp.arcsinh(2 ** (1 / 3) * x) 108 | eps_x = lda_exchange(rho) - beta * rho ** (1 / 3) * x**2 / d 109 | eps_x = jnp.where(mask, eps_x, 0.0) 110 | return eps_x 111 | 112 | 113 | def gga_exchange_pbe( 114 | rho: FloatN, grad_rho: FloatNx3, threshold: float = default_threshold 115 | ) -> FloatN: 116 | beta = np.asarray(0.066725) # Eq 4 117 | mu = beta * np.pi**2 / 3 # Eq 12 118 | kappa = np.asarray(0.8040) # Eq 14 119 | 120 | # avoid divide by zero when rho = 0 by replacing with 1.0 121 | mask = jnp.abs(rho) > threshold 122 | rho = jnp.where(mask, rho, 1.0) 123 | kf = (3 * np.pi**2 * rho) ** (1 / 3) 124 | s = jnl.norm(grad_rho, axis=1) / (2 * kf * rho) 125 | F = 1 + kappa - kappa / (1 + mu * s**2 / kappa) 126 | F = jnp.where(mask, F, 0.0) 127 | return lda_exchange(rho) * F 128 | 129 | 130 | def gga_correlation_pbe( 131 | rho: FloatN, grad_rho: FloatNx3, threshold: float = default_threshold 132 | ) -> FloatN: 133 | beta = np.asarray(0.066725) 134 | gamma = (1 - np.log(2.0)) / np.pi**2 135 | z = jnp.zeros_like(rho) # restricted ks, should be (rho_up - rho_down) / rho 136 | phi = 0.5 * (jnp.power(1 + z, 2 / 3) + jnp.power(1 - z, 2 / 3)) 137 | ec_pw = lda_correlation_pw(rho, threshold) 138 | # avoid divide by zero when rho = 0 by replacing with 1.0 139 | mask = jnp.abs(rho) > threshold 140 | rho = jnp.where(mask, rho, 1.0) 141 | mask_ecpw = jnp.where(mask, ec_pw, 1.0) 142 | A = beta / gamma * (jnp.exp(-mask_ecpw / (gamma * phi**3)) - 1) ** -1 # Eq 8 143 | kf = (3 * np.pi**2 * rho) ** (1 / 3) 144 | ks = jnp.sqrt(4 * kf / np.pi) 145 | t = jnl.norm(grad_rho, axis=1) / (2 * phi * ks * rho) 146 | u = 1 + beta / gamma * t**2 * (1 + A * t**2) / (1 + A * t**2 + A**2 * t**4) 147 | H = gamma * phi**3 * jnp.log(u) # Eq 7 
148 | H = jnp.where(mask, H, 0.0) 149 | return ec_pw + H 150 | 151 | 152 | def gga_correlation_lyp( 153 | rho: FloatN, grad_rho: FloatNx3, threshold: float = default_threshold 154 | ) -> FloatN: 155 | a = np.asarray(0.04918) 156 | b = np.asarray(0.132) 157 | c = np.asarray(0.2533) 158 | d = np.asarray(0.349) 159 | CF = 0.3 * (3 * np.pi**2) ** (2 / 3) 160 | 161 | # avoid divide by zero when rho = 0 by replacing with 1.0 162 | mask = jnp.abs(rho) > threshold 163 | rho = jnp.where(mask, rho, 1.0) 164 | v = 1 + d * rho ** (-1 / 3) 165 | omega = jnp.exp(-c * rho ** (-1 / 3)) / v * rho ** (-11 / 3) 166 | delta = c * rho ** (-1 / 3) + d * rho ** (-1 / 3) / v 167 | g = (1 / 24 + 7 * delta / 72) * rho * jnl.norm(grad_rho, axis=1) ** 2 168 | 169 | eps_c = -a / v - a * b * omega * (CF * rho ** (11 / 3) - g) 170 | eps_c = jnp.where(mask, eps_c, 0.0) 171 | return eps_c 172 | -------------------------------------------------------------------------------- /mess/zeropad_integrals.py: -------------------------------------------------------------------------------- 1 | """(experimental) Gaussian orbital integrals without array padding.""" 2 | 3 | from functools import partial 4 | 5 | import jax.numpy as jnp 6 | from jax import jit 7 | 8 | from mess.basis import Basis 9 | from mess.integrals import integrate 10 | from mess.primitive import Primitive, product 11 | from mess.types import FloatNxN 12 | 13 | 14 | @partial(jit, static_argnums=0) 15 | def overlap_basis_zeropad(basis: Basis) -> FloatNxN: 16 | def op(a, b): 17 | return _overlap_primitives_zeropad(a, b, basis.max_L) 18 | 19 | return integrate(basis, op) 20 | 21 | 22 | @partial(jit, static_argnums=2) 23 | def overlap_context(i: int, j: int, max_L: int): 24 | from mess.special import binom, factorial2 25 | 26 | def gen(): 27 | for s in range(max_L + 1): 28 | for t in range(2 * s + 1): 29 | yield s, t 30 | 31 | s, t = jnp.asarray(list(gen())).T 32 | mask = (2 * s - i <= t) & (t <= j) 33 | s = jnp.where(mask, s, 0) 34 | t = 
jnp.where(mask, t, 0) 35 | w = binom(i, 2 * s - t) * binom(j, t) * factorial2(2 * s - 1) 36 | return s, t, w 37 | 38 | 39 | def _overlap_primitives_zeropad(a: Primitive, b: Primitive, max_L: int) -> float: 40 | def overlap_axis(i: int, j: int, a: float, b: float) -> float: 41 | s, t, w = overlap_context(i, j, max_L) 42 | out = w * a ** (i - (2 * s - t)) * b ** (j - t) / (2 * p.alpha) ** s 43 | return jnp.sum(out) 44 | 45 | p = product(a, b) 46 | pa = p.center - a.center 47 | pb = p.center - b.center 48 | out = jnp.power(jnp.pi / p.alpha, 1.5) * p.norm 49 | 50 | for ax in range(3): 51 | out *= overlap_axis(a.lmn[ax], b.lmn[ax], pa[ax], pb[ax]) 52 | 53 | return out 54 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64", "setuptools_scm>=8"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "mess-jax" 7 | authors = [{ name = "Hatem Helal", email = "hatem@valencelabs.com" }] 8 | description = "MESS: Modern Electronic Structure Simulations" 9 | readme = "README.md" 10 | license = { text = "MIT License" } 11 | requires-python = ">=3.11" 12 | classifiers = [ 13 | "Development Status :: 3 - Alpha", 14 | "Intended Audience :: Developers", 15 | "Topic :: Scientific/Engineering", 16 | "License :: OSI Approved :: MIT License", 17 | "Programming Language :: Python :: 3", 18 | ] 19 | dependencies = [ 20 | "equinox", 21 | "jax[cpu]", 22 | "jaxtyping", 23 | "more-itertools", 24 | "optax", 25 | "optimistix", 26 | "pandas", 27 | "periodictable", 28 | "pyarrow", 29 | "pyscf==2.6.2", 30 | "py3Dmol", 31 | "basis_set_exchange", 32 | "sympy", 33 | "importlib-resources", 34 | ] 35 | dynamic = ["version"] 36 | 37 | [project.optional-dependencies] 38 | dev = [ 39 | "jupyter-book", 40 | "tqdm", 41 | "ipywidgets", 42 | "pytest", 43 | "pytest-benchmark", 44 | "pre-commit", 45 | "ruff", 46 | 
"mdformat-gfm", 47 | "seaborn", 48 | ] 49 | 50 | [project.urls] 51 | Website = "https://github.com/valence-labs/mess" 52 | "Source Code" = "https://github.com/valence-labs/mess" 53 | "Bug Tracker" = "https://github.com/valence-labs/mess/issues" 54 | Documentation = "https://valence-labs.github.io/mess/" 55 | 56 | [tool.setuptools] 57 | include-package-data = true 58 | 59 | [tool.setuptools_scm] 60 | fallback_version = "dev" 61 | 62 | [tool.pytest.ini_options] 63 | addopts = "-s -v --durations=10" 64 | filterwarnings = [ 65 | "error", 66 | 'ignore:Since PySCF\-2\.3, B3LYP \(and B3P86\) are changed.*:UserWarning', 67 | 'ignore:Function mol\.dumps drops attribute spin.*:UserWarning', 68 | 'ignore:scatter inputs have incompatible types.*:FutureWarning' 69 | ] 70 | 71 | [tool.ruff] 72 | extend-include = ["*.ipynb"] 73 | 74 | [tool.ruff.lint] 75 | select = ["E", "F"] 76 | ignore = ["E741"] 77 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from jax import config 4 | from jax.experimental.compilation_cache import compilation_cache as cc 5 | 6 | 7 | def pytest_sessionstart(session): 8 | cache_dir = osp.expanduser("~/.cache/mess") 9 | print(f"Initializing JAX compilation cache dir: {cache_dir}") 10 | cc.set_cache_dir(cache_dir) 11 | config.update("jax_persistent_cache_min_compile_time_secs", 0.1) 12 | 13 | 14 | def is_mem_limited(): 15 | # Check if we are running on a limited memory host (e.g. 
github action) 16 | import psutil 17 | 18 | total_mem_gib = psutil.virtual_memory().total // 1024**3 19 | return total_mem_gib < 10 20 | -------------------------------------------------------------------------------- /test/test_autograd_integrals.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.testing import assert_allclose 3 | 4 | from mess.basis import basisset 5 | from mess.interop import to_pyscf 6 | from mess.autograd_integrals import ( 7 | grad_overlap_basis, 8 | grad_kinetic_basis, 9 | grad_nuclear_basis, 10 | ) 11 | from mess.structure import molecule 12 | 13 | 14 | def test_nuclear_gradients(): 15 | basis_name = "sto-3g" 16 | h2 = molecule("h2") 17 | scfmol = to_pyscf(h2, basis_name) 18 | basis = basisset(h2, basis_name) 19 | 20 | actual = grad_overlap_basis(basis) 21 | expect = scfmol.intor("int1e_ipovlp_cart", comp=3) 22 | assert_allclose(actual, expect, atol=1e-6) 23 | 24 | actual = grad_kinetic_basis(basis) 25 | expect = scfmol.intor("int1e_ipkin_cart", comp=3) 26 | assert_allclose(actual, expect, atol=1e-6) 27 | 28 | # TODO: investigate possible inconsistency in libcint outputs? 
29 | actual = grad_nuclear_basis(basis) 30 | expect = scfmol.intor("int1e_ipnuc_cart", comp=3) 31 | expect = -np.moveaxis(expect, 1, 2) 32 | assert_allclose(actual, expect, atol=1e-5) 33 | -------------------------------------------------------------------------------- /test/test_benchmark.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import pytest 3 | 4 | from mess.basis import basisset 5 | from mess.hamiltonian import Hamiltonian, minimise 6 | from mess.zeropad_integrals import overlap_basis_zeropad 7 | from mess.integrals import ( 8 | eri_basis_sparse, 9 | kinetic_basis, 10 | nuclear_basis, 11 | overlap_basis, 12 | ) 13 | from mess.structure import molecule 14 | from mess.interop import from_pyquante 15 | from mess.package_utils import has_package 16 | from conftest import is_mem_limited 17 | 18 | 19 | @pytest.mark.parametrize("func", [overlap_basis, overlap_basis_zeropad, kinetic_basis]) 20 | @pytest.mark.skipif( 21 | not has_package("pyquante2"), reason="Missing Optional Dependency: pyquante2" 22 | ) 23 | def test_benzene(func, benchmark): 24 | mol = from_pyquante("c6h6") 25 | basis = basisset(mol, "def2-TZVPPD") 26 | basis = jax.device_put(basis) 27 | 28 | def harness(): 29 | return func(basis).block_until_ready() 30 | 31 | benchmark(harness) 32 | 33 | 34 | @pytest.mark.parametrize("mol_name", ["h2", "water"]) 35 | @pytest.mark.parametrize( 36 | "func", [overlap_basis, kinetic_basis, nuclear_basis, eri_basis_sparse] 37 | ) 38 | def test_integrals(mol_name, func, benchmark): 39 | mol = molecule(mol_name) 40 | basis = basisset(mol) 41 | basis = jax.device_put(basis) 42 | 43 | def harness(): 44 | return func(basis).block_until_ready() 45 | 46 | benchmark(harness) 47 | 48 | 49 | @pytest.mark.parametrize("mol_name", ["h2", "water"]) 50 | @pytest.mark.skipif(is_mem_limited(), reason="Not enough host memory!") 51 | def test_minimise_ks(benchmark, mol_name): 52 | # TODO: investigate test failure with cpu 
backend and float32 53 | from jax.experimental import enable_x64 54 | 55 | with enable_x64(True): 56 | mol = molecule(mol_name) 57 | basis = basisset(mol, "6-31g") 58 | H = Hamiltonian(basis) 59 | H = jax.device_put(H) 60 | 61 | def harness(): 62 | E, C, _ = minimise(H) 63 | return E.block_until_ready(), C.block_until_ready() 64 | 65 | benchmark(harness) 66 | -------------------------------------------------------------------------------- /test/test_hamiltonian.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from jax.experimental import enable_x64 4 | from numpy.testing import assert_allclose 5 | from pyscf import dft 6 | 7 | from mess.basis import basisset 8 | from mess.hamiltonian import Hamiltonian 9 | from mess.interop import to_pyscf 10 | from mess.structure import Structure 11 | 12 | cases = { 13 | "hfx": "hf,", 14 | "lda": "slater,vwn_rpa", 15 | "pbe": "gga_x_pbe,gga_c_pbe", 16 | "pbe0": "pbe0", 17 | "b3lyp": "b3lyp", 18 | } 19 | 20 | 21 | @pytest.mark.parametrize("inputs", cases.items(), ids=cases.keys()) 22 | def test_energy(inputs): 23 | with enable_x64(True): 24 | xc_method, scfxc = inputs 25 | mol = Structure(np.asarray(2), np.zeros(3)) 26 | basis_name = "6-31g" 27 | basis = basisset(mol, basis_name) 28 | scfmol = to_pyscf(mol, basis_name=basis_name) 29 | s = dft.RKS(scfmol, xc=scfxc) 30 | s.kernel() 31 | P = np.asarray(s.make_rdm1()) 32 | 33 | H = Hamiltonian(basis=basis, xc_method=xc_method) 34 | actual = H(P) 35 | expect = s.energy_tot() 36 | assert_allclose(actual, expect, atol=1e-6) 37 | -------------------------------------------------------------------------------- /test/test_integrals.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | import pytest 4 | from numpy.testing import assert_allclose 5 | 6 | from mess.basis import basisset 7 | from mess.integrals import ( 8 | 
eri_basis, 9 | eri_basis_sparse, 10 | eri_primitives, 11 | kinetic_basis, 12 | kinetic_primitives, 13 | nuclear_basis, 14 | nuclear_primitives, 15 | overlap_basis, 16 | overlap_primitives, 17 | ) 18 | from mess.interop import to_pyscf 19 | from mess.primitive import Primitive 20 | from mess.structure import molecule 21 | from conftest import is_mem_limited 22 | 23 | 24 | def test_overlap(): 25 | # Exercise 3.21 of "Modern quantum chemistry: introduction to advanced 26 | # electronic structure theory."" by Szabo and Ostlund 27 | alpha = 0.270950 * 1.24 * 1.24 28 | a = Primitive(alpha=alpha) 29 | b = Primitive(alpha=alpha, center=jnp.array([1.4, 0.0, 0.0])) 30 | assert_allclose(overlap_primitives(a, a), 1.0, atol=1e-5) 31 | assert_allclose(overlap_primitives(b, b), 1.0, atol=1e-5) 32 | assert_allclose(overlap_primitives(b, a), 0.6648, atol=1e-5) 33 | 34 | 35 | @pytest.mark.parametrize("basis_name", ["sto-3g", "6-31+g", "6-31+g*"]) 36 | def test_water_overlap(basis_name): 37 | basis = basisset(molecule("water"), basis_name) 38 | actual_overlap = overlap_basis(basis) 39 | 40 | # Note: PySCF doesn't appear to normalise d basis functions in cartesian basis 41 | scfmol = to_pyscf(molecule("water"), basis_name=basis_name) 42 | expect_overlap = scfmol.intor("int1e_ovlp_cart") 43 | n = 1 / np.sqrt(np.diagonal(expect_overlap)) 44 | expect_overlap = n[:, None] * n[None, :] * expect_overlap 45 | assert_allclose(actual_overlap, expect_overlap, atol=1e-5) 46 | 47 | 48 | def test_kinetic(): 49 | # PyQuante test case for kinetic primitive integral 50 | p = Primitive() 51 | assert_allclose(kinetic_primitives(p, p), 1.5, atol=1e-5) 52 | 53 | # Reproduce the kinetic energy matrix for H2 using STO-3G basis set 54 | # See equation 3.230 of "Modern quantum chemistry: introduction to advanced 55 | # electronic structure theory."" by Szabo and Ostlund 56 | h2 = molecule("h2") 57 | basis = basisset(h2, "sto-3g") 58 | actual = kinetic_basis(basis) 59 | expect = np.array([[0.7600, 0.2365], 
[0.2365, 0.7600]]) 60 | assert_allclose(actual, expect, atol=1e-4) 61 | 62 | 63 | @pytest.mark.parametrize( 64 | "basis_name", 65 | [ 66 | "sto-3g", 67 | "6-31+g", 68 | pytest.param( 69 | "6-31+g*", marks=pytest.mark.xfail(reason="Cartesian norm problem?") 70 | ), 71 | ], 72 | ) 73 | def test_water_kinetic(basis_name): 74 | basis = basisset(molecule("water"), basis_name) 75 | actual = kinetic_basis(basis) 76 | 77 | expect = to_pyscf(molecule("water"), basis_name=basis_name).intor("int1e_kin_cart") 78 | assert_allclose(actual, expect, atol=1e-4) 79 | 80 | 81 | def test_nuclear(): 82 | # PyQuante test case for nuclear attraction integral 83 | p = Primitive() 84 | c = jnp.zeros(3) 85 | assert_allclose(nuclear_primitives(p, p, c), -1.595769, atol=1e-5) 86 | 87 | # Reproduce the nuclear attraction matrix for H2 using STO-3G basis set 88 | # See equation 3.231 and 3.232 of Szabo and Ostlund 89 | h2 = molecule("h2") 90 | basis = basisset(h2, "sto-3g") 91 | actual = nuclear_basis(basis) 92 | expect = np.array([ 93 | [[-1.2266, -0.5974], [-0.5974, -0.6538]], 94 | [[-0.6538, -0.5974], [-0.5974, -1.2266]], 95 | ]) 96 | 97 | assert_allclose(actual, expect, atol=1e-4) 98 | 99 | 100 | def test_water_nuclear(): 101 | basis_name = "sto-3g" 102 | h2o = molecule("water") 103 | basis = basisset(h2o, basis_name) 104 | actual = nuclear_basis(basis).sum(axis=0) 105 | expect = to_pyscf(h2o, basis_name=basis_name).intor("int1e_nuc_cart") 106 | assert_allclose(actual, expect, atol=1e-3) 107 | 108 | 109 | def test_eri(): 110 | # PyQuante test cases for ERI 111 | a, b, c, d = [Primitive()] * 4 112 | assert_allclose(eri_primitives(a, b, c, d), 1.128379, atol=1e-5) 113 | 114 | c, d = [Primitive(lmn=jnp.array([1, 0, 0]))] * 2 115 | assert_allclose(eri_primitives(a, b, c, d), 0.940316, atol=1e-5) 116 | 117 | # H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund 118 | h2 = molecule("h2") 119 | basis = basisset(h2, "sto-3g") 120 | 121 | actual = eri_basis(basis) 122 | expect = 
np.empty((2, 2, 2, 2), dtype=np.float32) 123 | expect[0, 0, 0, 0] = expect[1, 1, 1, 1] = 0.7746 124 | expect[0, 0, 1, 1] = expect[1, 1, 0, 0] = 0.5697 125 | expect[1, 0, 0, 0] = expect[0, 0, 0, 1] = 0.4441 126 | expect[0, 1, 0, 0] = expect[0, 0, 1, 0] = 0.4441 127 | expect[0, 1, 1, 1] = expect[1, 1, 1, 0] = 0.4441 128 | expect[1, 0, 1, 1] = expect[1, 1, 0, 1] = 0.4441 129 | expect[1, 0, 1, 0] = expect[0, 1, 1, 0] = 0.2970 130 | expect[0, 1, 0, 1] = expect[1, 0, 0, 1] = 0.2970 131 | assert_allclose(actual, expect, atol=1e-4) 132 | 133 | 134 | @pytest.mark.parametrize("sparse", [True, False]) 135 | @pytest.mark.skipif(is_mem_limited(), reason="Not enough host memory!") 136 | def test_water_eri(sparse): 137 | basis_name = "sto-3g" 138 | h2o = molecule("water") 139 | basis = basisset(h2o, basis_name) 140 | actual = eri_basis_sparse(basis) if sparse else eri_basis(basis) 141 | aosym = "s8" if sparse else "s1" 142 | expect = to_pyscf(h2o, basis_name=basis_name).intor("int2e_cart", aosym=aosym) 143 | assert_allclose(actual, expect, atol=1e-4) 144 | -------------------------------------------------------------------------------- /test/test_interop.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | import pytest 4 | from numpy.testing import assert_allclose 5 | 6 | from mess.basis import basisset 7 | from mess.interop import to_pyscf 8 | from mess.mesh import density, density_and_grad, uniform_mesh 9 | from mess.structure import molecule, nuclear_energy 10 | 11 | 12 | @pytest.mark.parametrize("basis_name", ["sto-3g", "6-31g**"]) 13 | def test_to_pyscf(basis_name): 14 | mol = molecule("water") 15 | basis = basisset(mol, basis_name) 16 | pyscf_mol = to_pyscf(mol, basis_name) 17 | assert basis.num_orbitals == pyscf_mol.nao 18 | 19 | 20 | def test_gto(): 21 | from pyscf.dft.numint import eval_rho, eval_ao 22 | from jax.experimental import enable_x64 23 | 24 | with enable_x64(True): 25 | # 
Run these comparisons to PySCF in fp64 26 | # Atomic orbitals 27 | basis_name = "6-31+g" 28 | structure = molecule("water") 29 | basis = basisset(structure, basis_name) 30 | mesh = uniform_mesh() 31 | actual = basis(mesh.points) 32 | 33 | mol = to_pyscf(structure, basis_name) 34 | expect_ao = eval_ao(mol, np.asarray(mesh.points)) 35 | assert_allclose(actual, expect_ao, atol=1e-7) 36 | 37 | # Density Matrix 38 | mf = mol.RKS() 39 | mf.kernel() 40 | C = jnp.array(mf.mo_coeff) 41 | P = basis.density_matrix(C) 42 | expect = jnp.array(mf.make_rdm1()) 43 | assert_allclose(P, expect) 44 | 45 | # Electron density 46 | actual = density(basis, mesh, P) 47 | expect = eval_rho(mol, expect_ao, mf.make_rdm1(), xctype="lda") 48 | assert_allclose(actual, expect, atol=1e-7) 49 | 50 | # Electron density and gradient 51 | rho, grad_rho = density_and_grad(basis, mesh, P) 52 | ao_and_grad = eval_ao(mol, np.asarray(mesh.points), deriv=1) 53 | expect = eval_rho(mol, ao_and_grad, mf.make_rdm1(), xctype="gga") 54 | expect_rho = expect[0, :] 55 | expect_grad = expect[1:, :].T 56 | assert_allclose(rho, expect_rho, atol=1e-7) 57 | assert_allclose(grad_rho, expect_grad, atol=1e-6) 58 | 59 | 60 | @pytest.mark.parametrize("name", ["water", "h2"]) 61 | def test_nuclear_energy(name): 62 | mol = molecule(name) 63 | actual = nuclear_energy(mol) 64 | expect = to_pyscf(mol).energy_nuc() 65 | assert_allclose(actual, expect) 66 | -------------------------------------------------------------------------------- /test/test_special.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import pytest 3 | from numpy.testing import assert_allclose 4 | 5 | from mess.special import ( 6 | binom_beta, 7 | binom_fori, 8 | binom_lookup, 9 | factorial2_fori, 10 | factorial2_lookup, 11 | factorial_fori, 12 | factorial_gamma, 13 | factorial_lookup, 14 | ) 15 | 16 | 17 | def test_factorial(): 18 | x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8]) 19 | expect = 
jnp.array([1, 2, 6, 24, 120, 720, 5040, 40320]) 20 | assert_allclose(factorial_fori(x, x[-1]), expect) 21 | assert_allclose(factorial_lookup(x, x[-1]), expect) 22 | assert_allclose(factorial_gamma(x), expect) 23 | 24 | 25 | def test_factorial2(): 26 | x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8]) 27 | expect = jnp.array([1, 2, 3, 8, 15, 48, 105, 384]) 28 | assert_allclose(factorial2_fori(x), expect) 29 | assert_allclose(factorial2_fori(0), 1) 30 | 31 | assert_allclose(factorial2_lookup(x), expect) 32 | assert_allclose(factorial2_lookup(0), 1) 33 | 34 | 35 | @pytest.mark.parametrize("binom_func", [binom_beta, binom_fori, binom_lookup]) 36 | def test_binom(binom_func): 37 | x = jnp.array([4, 4, 4, 4]) 38 | y = jnp.array([1, 2, 3, 4]) 39 | expect = jnp.array([4, 6, 4, 1]) 40 | assert_allclose(binom_func(x, y), expect) 41 | 42 | zero = jnp.array([0]) 43 | assert_allclose(binom_func(zero, y), jnp.zeros_like(x)) 44 | assert_allclose(binom_func(x, zero), jnp.ones_like(y)) 45 | assert_allclose(binom_func(y, y), jnp.ones_like(y)) 46 | 47 | one = jnp.array([1]) 48 | assert_allclose(binom_func(one, one), one) 49 | assert_allclose(binom_func(zero, -one), zero) 50 | assert_allclose(binom_func(zero, zero), one) 51 | -------------------------------------------------------------------------------- /test/test_xcfunctional.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | import pytest 5 | from jax.experimental import enable_x64 6 | from numpy.testing import assert_allclose 7 | from pyscf import dft 8 | 9 | from mess.basis import basisset 10 | from mess.mesh import density_and_grad, xcmesh_from_pyscf 11 | from mess.structure import Structure 12 | from mess.xcfunctional import ( 13 | gga_correlation_lyp, 14 | gga_correlation_pbe, 15 | gga_exchange_b88, 16 | gga_exchange_pbe, 17 | lda_correlation_pw, 18 | lda_correlation_vwn, 19 | lda_exchange, 20 | ) 21 | 22 | # Define test cases with a 
mapping from test identifier to test arguments 23 | lda_cases = { 24 | "lda_exchange": (lda_exchange, "slater,"), 25 | "lda_correlation_vwn5": (partial(lda_correlation_vwn, use_rpa=False), ",vwn5"), 26 | "lda_correlation_vwn_rpa": (partial(lda_correlation_vwn, use_rpa=True), ",vwn_rpa"), 27 | "lda_correlation_pw": (lda_correlation_pw, ",lda_c_pw"), 28 | } 29 | 30 | gga_cases = { 31 | "gga_exchange_b88": (gga_exchange_b88, "gga_x_b88,"), 32 | "gga_exchange_pbe": (gga_exchange_pbe, "gga_x_pbe,"), 33 | "gga_correlation_pbe": (gga_correlation_pbe, ",gga_c_pbe"), 34 | "gga_correlation_lyp": (gga_correlation_lyp, ",gga_c_lyp"), 35 | } 36 | 37 | 38 | @pytest.fixture 39 | def helium_density(): 40 | with enable_x64(True): 41 | mol = Structure(np.asarray(2), np.zeros(3)) 42 | basis_name = "6-31g" 43 | basis = basisset(mol, basis_name) 44 | mesh = xcmesh_from_pyscf(mol) 45 | rho, grad_rho = [np.asarray(t) for t in density_and_grad(basis, mesh)] 46 | yield rho, grad_rho 47 | 48 | 49 | @pytest.mark.parametrize("xcfunc,scfstr", lda_cases.values(), ids=lda_cases.keys()) 50 | def test_lda(helium_density, xcfunc, scfstr): 51 | rho, _ = helium_density 52 | actual = xcfunc(rho) 53 | expect = dft.libxc.eval_xc(scfstr, rho)[0] 54 | assert_allclose(actual, expect, atol=1e-7) 55 | 56 | 57 | @pytest.mark.parametrize("xcfunc,scfstr", gga_cases.values(), ids=gga_cases.keys()) 58 | def test_gga(helium_density, xcfunc, scfstr): 59 | rho, grad_rho = helium_density 60 | scfin = np.concatenate([rho[:, None], grad_rho], axis=1).T 61 | 62 | actual = xcfunc(rho, grad_rho) 63 | expect = dft.libxc.eval_xc(scfstr, scfin, deriv=1)[0] 64 | assert_allclose(actual, expect, atol=1e-6) 65 | --------------------------------------------------------------------------------