├── .gitattributes ├── quax.png ├── MANIFEST.in ├── quax ├── __init__.py ├── integrals │ ├── __init__.py │ ├── makefile │ ├── utils.h │ ├── tei.py │ └── oei.py ├── methods │ ├── __init__.py │ ├── mp2.py │ ├── energy_utils.py │ ├── basis_utils.py │ ├── ccsd_t.py │ ├── hartree_fock.py │ ├── ccsd.py │ ├── mp2f12.py │ └── ints.py ├── constants.py ├── utils.py └── core.py ├── .github ├── PULL_REQUEST_TEMPLATE.md ├── workflows │ └── continuous_integration.yml └── CONTRIBUTING.md ├── .codecov.yml ├── devtools ├── conda-envs │ └── test_env.yaml └── README.md ├── setup.cfg ├── tests ├── test_energies.py ├── test_gradients.py ├── test_hessians.py └── test_dipoles.py ├── LICENSE ├── pyproject.toml ├── .gitignore ├── CODE_OF_CONDUCT.md └── README.md /.gitattributes: -------------------------------------------------------------------------------- 1 | quax/_version.py export-subst -------------------------------------------------------------------------------- /quax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CCQC/Quax/HEAD/quax.png -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include CODE_OF_CONDUCT.md 2 | include MANIFEST.in 3 | include LICENSE 4 | 5 | graft quax 6 | global-exclude *.py[cod] __pycache__ *.so -------------------------------------------------------------------------------- /quax/__init__.py: -------------------------------------------------------------------------------- 1 | from . import integrals 2 | from . import constants 3 | 4 | from . import methods 5 | from . import core 6 | from . import utils 7 | 8 | -------------------------------------------------------------------------------- /quax/integrals/__init__.py: -------------------------------------------------------------------------------- 1 | from . import tei 2 | from . import oei 3 | from . 
import libint_interface 4 | 5 | from .tei import TEI 6 | from .oei import OEI 7 | 8 | -------------------------------------------------------------------------------- /quax/methods/__init__.py: -------------------------------------------------------------------------------- 1 | from . import energy_utils 2 | from . import hartree_fock 3 | from . import mp2 4 | from . import mp2f12 5 | from . import ccsd 6 | from . import ccsd_t 7 | from . import ints 8 | 9 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | Provide a brief description of the PR's purpose here. 3 | 4 | ## Todos 5 | Notable points that this PR has either accomplished or will accomplish. 6 | - [ ] TODO 1 7 | 8 | ## Questions 9 | - [ ] Question1 10 | 11 | ## Status 12 | - [ ] Ready to go -------------------------------------------------------------------------------- /.codecov.yml: -------------------------------------------------------------------------------- 1 | # Codecov configuration to make it a bit less noisy 2 | coverage: 3 | status: 4 | patch: false 5 | project: 6 | default: 7 | threshold: 50% 8 | comment: 9 | layout: "header" 10 | require_changes: false 11 | branches: null 12 | behavior: default 13 | flags: null 14 | paths: null -------------------------------------------------------------------------------- /devtools/conda-envs/test_env.yaml: -------------------------------------------------------------------------------- 1 | name: test-quax 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | # Libint dependencies 6 | - libint 7 | - pybind11 8 | - eigen 9 | - boost 10 | - hdf5 11 | # Quax dependencies 12 | - psi4 13 | - jax 14 | - jaxlib 15 | - h5py 16 | # Testing 17 | - pytest 18 | - pytest-cov 19 | - codecov 20 | -------------------------------------------------------------------------------- /setup.cfg: 
-------------------------------------------------------------------------------- 1 | # Helper file to handle all configs 2 | 3 | [coverage:run] 4 | # .coveragerc to control coverage.py and pytest-cov 5 | omit = 6 | # Omit the tests 7 | */tests/* 8 | 9 | [yapf] 10 | # YAPF, in .style.yapf files this shows up as "[style]" header 11 | COLUMN_LIMIT = 119 12 | INDENT_WIDTH = 4 13 | USE_TABS = False 14 | 15 | [flake8] 16 | # Flake8, PyFlakes, etc 17 | max-line-length = 119 18 | 19 | [aliases] 20 | test = pytest -------------------------------------------------------------------------------- /quax/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | 5 | # Get absolute module path 6 | module_path = os.path.dirname(os.path.abspath(__file__)) 7 | 8 | # Check if libint interface is found 9 | libint_imported = False 10 | lib = re.compile("libint_interface\.cpython.+") 11 | for path in os.listdir(module_path + "/integrals"): 12 | if lib.match(path): 13 | from . 
import integrals 14 | libint_imported = True 15 | 16 | if not libint_imported: 17 | sys.exit("Libint is a required dependency!") 18 | -------------------------------------------------------------------------------- /tests/test_energies.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test energy computations 3 | """ 4 | import quax 5 | import psi4 6 | import pytest 7 | import numpy as np 8 | 9 | molecule = psi4.geometry(""" 10 | 0 1 11 | O -0.000007070942 0.125146536460 0.000000000000 12 | H -1.424097055410 -0.993053750648 0.000000000000 13 | H 1.424209276385 -0.993112599269 0.000000000000 14 | units bohr 15 | """) 16 | basis_name = 'sto-3g' 17 | psi4.set_options({'basis': basis_name, 18 | 'scf_type': 'pk', 19 | 'mp2_type':'conv', 20 | 'e_convergence': 1e-10, 21 | 'd_convergence':1e-10, 22 | 'puream': 0}) 23 | 24 | def test_hartree_fock(method='hf'): 25 | psi_e = psi4.energy(method + '/' + basis_name) 26 | quax_e = quax.core.energy(molecule, basis_name, method) 27 | assert np.allclose(psi_e, quax_e) 28 | 29 | def test_mp2(method='mp2'): 30 | psi_e = psi4.energy(method + '/' + basis_name) 31 | quax_e = quax.core.energy(molecule, basis_name, method) 32 | assert np.allclose(psi_e, quax_e) 33 | 34 | def test_ccsd(method='ccsd'): 35 | psi_e = psi4.energy(method + '/' + basis_name) 36 | quax_e = quax.core.energy(molecule, basis_name, method) 37 | assert np.allclose(psi_e, quax_e) 38 | 39 | def test_ccsd_t(method='ccsd(t)'): 40 | psi_e = psi4.energy(method + '/' + basis_name) 41 | quax_e = quax.core.energy(molecule, basis_name, method) 42 | assert np.allclose(psi_e, quax_e) 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Center for Computational Quantum Chemistry 4 | All rights reserved. 
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "setuptools-git-versioning"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | # Self-descriptive entries which should always be present 6 | # https://packaging.python.org/en/latest/specifications/declaring-project-metadata/ 7 | [project] 8 | name = "quax" 9 | description = "Arbitrary order derivatives of electronic structure computations." 10 | dynamic = ["version"] 11 | readme = "README.md" 12 | authors = [ 13 | { name = "Adam Abbott", email = "adabbott@uga.edu" }, 14 | { name = "Erica Mitchell", email = "emitchell@uga.edu" } 15 | ] 16 | license = { text = "BSD-3C" } 17 | # See https://pypi.org/classifiers/ 18 | classifiers = [ 19 | "License :: OSI Approved :: BSD License", 20 | "Programming Language :: Python :: 3", 21 | "Development Status :: 4 - Beta", 22 | "Intended Audience :: Science/Research" 23 | ] 24 | requires-python = ">=3.9" 25 | # Declare any run-time dependencies that should be installed with the package. 26 | dependencies = [ 27 | "importlib-resources;python_version>'3.8'", 28 | "numpy>=1.23,<2.0", 29 | "jax>=0.4.19", 30 | "jaxlib>=0.4.19", 31 | "h5py>=2.8.0", 32 | "scipy>=1.9" 33 | ] 34 | 35 | # Update the urls once the hosting is set up. 
36 | [project.urls] 37 | "Source" = "https://github.com/CCQC/Quax/" 38 | #"Documentation" = "Quax.readthedocs.io/" 39 | 40 | [project.optional-dependencies] 41 | test = [ 42 | "pytest>=6.1.2", 43 | "pytest-cov" 44 | ] 45 | 46 | [tool.setuptools] 47 | zip-safe = false 48 | 49 | [tool.setuptools.packages.find] 50 | namespaces = false 51 | where = ["."] 52 | 53 | [tool.setuptools.package-data] 54 | quax = ["integrals/*.so"] 55 | 56 | [tool.setuptools-git-versioning] 57 | enabled = true 58 | dev_template = "{tag}.{ccount}+git.{sha}" 59 | dirty_template = "{tag}.{ccount}+git.{sha}.dirty" -------------------------------------------------------------------------------- /quax/integrals/makefile: -------------------------------------------------------------------------------- 1 | # NOTE: These paths below need to be edited such that they point to a set of 2 | # Eigen headers, Python headers, Pybind11 headers, Libint API headers libint2.h libint2.hpp, 3 | # the rest of the Libint2 headers, and the library location of libint2.a 4 | CC := g++ 5 | # Options passed to compiler, add "-fopenmp" if intending to use OpenMP 6 | CFLAGS := -O3 -fPIC -fopenmp 7 | # Libint prefix location (where /include, /include/libint2, /lib, /share are located) 8 | LIBINT_PREFIX := $(shell python3-config --prefix) 9 | # Conda prefix location, it is suggested to use conda to install nearly all dependencies 10 | CONDA_PREFIX := $(shell python3-config --prefix) 11 | 12 | I1 := $(LIBINT_PREFIX)/include 13 | I2 := $(LIBINT_PREFIX)/include/libint2 14 | L1 := $(LIBINT_PREFIX)/lib 15 | # Eigen headers location 16 | I3 := $(CONDA_PREFIX)/include/eigen3 17 | # HDF5 headers, static and shared libraries 18 | I6 := $(CONDA_PREFIX)/include 19 | L2 := $(CONDA_PREFIX)/lib 20 | # Edit path in quotes to be same location as L2 definition above 21 | RPATH := -Wl,-rpath,"$(CONDA_PREFIX)/lib" 22 | 23 | # This 'TARGETS' suffix should be set to whatever is returned by the command `python3-config --extension-suffix` entered 
on command line. 24 | # and it should match the python version used by the pybind11 include paths queried below. 25 | TARGETS := libint_interface$(shell python3-config --extension-suffix) 26 | OBJ := libint_interface.o 27 | 28 | # Rest is boilerplate. Do not edit unless you know what you're doing. 29 | .PHONY: all clean 30 | 31 | all: $(TARGETS) 32 | 33 | clean: 34 | 	rm -f $(OBJ) 35 | 36 | $(OBJ): %.o : %.cc $(DEPS) 37 | 	$(CC) -c $< -o $@ $(CFLAGS) -I $(I1) -I $(I2) -I $(I3) $(shell python3 -m pybind11 --includes) -I $(I6) -lint2 -L $(L1) -lhdf5 -lhdf5_cpp -L $(L2) $(RPATH) 38 | $(TARGETS): $(OBJ) 39 | 	$(CC) $^ -o $@ $(CFLAGS) -shared -I $(I1) -I $(I2) -I $(I3) $(shell python3 -m pybind11 --includes) -I $(I6) -lint2 -L $(L1) -lhdf5 -lhdf5_cpp -L $(L2) $(RPATH) 40 | 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | .pytest_cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # Pycharm settings 99 | .idea 100 | *.iml 101 | *.iws 102 | *.ipr 103 | 104 | # Ignore devcontainer 105 | /.devcontainer 106 | 107 | # Ignore VSCode settings 108 | /.vscode 109 | 110 | # Ignore Sublime Text settings 111 | *.sublime-workspace 112 | *.sublime-project 113 | 114 | # vim swap 115 | *.swp 116 | 117 | # mkdocs documentation 118 | /site 119 | 120 | # mypy 121 | .mypy_cache/ 122 | 123 | # profraw files from LLVM? Unclear exactly what triggers this 124 | # There are reports this comes from LLVM profiling, but also Xcode 9. 
125 | *profraw -------------------------------------------------------------------------------- /tests/test_gradients.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test gradient computations 3 | """ 4 | import quax 5 | import psi4 6 | import pytest 7 | import numpy as np 8 | 9 | molecule = psi4.geometry(""" 10 | 0 1 11 | O -0.000007070942 0.125146536460 0.000000000000 12 | H -1.424097055410 -0.993053750648 0.000000000000 13 | H 1.424209276385 -0.993112599269 0.000000000000 14 | units bohr 15 | """) 16 | basis_name = 'sto-3g' 17 | psi4.set_options({'basis': basis_name, 18 | 'scf_type': 'pk', 19 | 'mp2_type':'conv', 20 | 'e_convergence': 1e-10, 21 | 'd_convergence': 1e-10, 22 | 'puream': 0, 23 | 'points': 5, 24 | 'fd_project': False}) 25 | 26 | options = {'damping': True, 'spectral_shift': False} 27 | 28 | def test_hartree_fock_gradient(method='hf'): 29 | psi_deriv = np.round(np.asarray(psi4.gradient(method + '/' + basis_name)), 10) 30 | quax_deriv = quax.core.geom_deriv(molecule, basis_name, method, deriv_order=1, options=options).reshape(-1,3) 31 | quax_partial0 = quax.core.geom_deriv(molecule, basis_name, method, deriv_order=1, partial=(0,), options=options) 32 | assert np.allclose(psi_deriv, quax_deriv) 33 | assert np.allclose(psi_deriv[0,0], quax_partial0) 34 | 35 | def test_mp2_gradient(method='mp2'): 36 | psi_deriv = np.round(np.asarray(psi4.gradient(method + '/' + basis_name)), 10) 37 | quax_deriv = quax.core.geom_deriv(molecule, basis_name, method, deriv_order=1, options=options).reshape(-1,3) 38 | quax_partial0 = quax.core.geom_deriv(molecule, basis_name, method, deriv_order=1, partial=(0,), options=options) 39 | assert np.allclose(psi_deriv, quax_deriv) 40 | assert np.allclose(psi_deriv[0,0], quax_partial0) 41 | 42 | def test_ccsd_t_gradient(method='ccsd(t)'): 43 | psi_deriv = np.round(np.asarray(psi4.gradient(method + '/' + basis_name)), 10) 44 | quax_deriv = quax.core.geom_deriv(molecule, basis_name, 
method, deriv_order=1, options=options).reshape(-1,3) 45 | quax_partial0 = quax.core.geom_deriv(molecule, basis_name, method, deriv_order=1, partial=(0,), options=options) 46 | assert np.allclose(psi_deriv, quax_deriv) 47 | assert np.allclose(psi_deriv[0,0], quax_partial0) 48 | 49 | 50 | -------------------------------------------------------------------------------- /devtools/README.md: -------------------------------------------------------------------------------- 1 | # Development, testing, and deployment tools 2 | 3 | This directory contains a collection of tools for running Continuous Integration (CI) tests, 4 | conda installation, and other development tools not directly related to the coding process. 5 | 6 | 7 | ## Manifest 8 | 9 | ### Continuous Integration 10 | 11 | You should test your code, but do not feel compelled to use these specific programs. 12 | 13 | ### Conda Environment: 14 | 15 | This directory contains the files to setup the Conda environment for testing purposes 16 | 17 | * `conda-envs`: directory containing the YAML file(s) which fully describe Conda Environments, their dependencies, and those dependency provenance's 18 | * `test_env.yaml`: Simple test environment file with base dependencies. Channels are not specified here and therefore respect global Conda configuration 19 | 20 | ### Additional Scripts: 21 | 22 | This directory contains OS agnostic helper scripts which don't fall in any of the previous categories 23 | * `scripts` 24 | * `create_conda_env.py`: Helper program for spinning up new conda environments based on a starter file with Python Version and Env. Name command-line options 25 | 26 | 27 | ## How to contribute changes 28 | - Clone the repository if you have write access to the main repo, fork the repository if you are a collaborator. 
29 | - Make a new branch with `git checkout -b {your branch name}` 30 | - Make changes and test your code 31 | - Ensure that the test environment dependencies (`conda-envs`) line up with the build and deploy dependencies (`conda-recipe/meta.yaml`) 32 | - Push the branch to the repo (either the main or your fork) with `git push -u origin {your branch name}` 33 | * Note that `origin` is the default name assigned to the remote, yours may be different 34 | - Make a PR on GitHub with your changes 35 | - We'll review the changes and get your code into the repo after lively discussion! 36 | 37 | 38 | ## Checklist for updates 39 | - [ ] Make sure there is an/are issue(s) opened for your specific update 40 | - [ ] Create the PR, referencing the issue 41 | - [ ] Debug the PR as needed until tests pass 42 | - [ ] Tag the final, debugged version 43 | * `git tag -a X.Y.Z [latest pushed commit] && git push --follow-tags` 44 | - [ ] Get the PR merged in 45 | -------------------------------------------------------------------------------- /.github/workflows/continuous_integration.yml: -------------------------------------------------------------------------------- 1 | name: Continuous Integration 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | schedule: 11 | # Weekly tests run on main by default: 12 | # Scheduled workflows run on the latest commit on the default or base branch. 13 | # (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule) 14 | - cron: "0 0 * * 0" 15 | # Scheduled workflows are automatically disabled when no repository activity has occurred in 60 day. 
16 | 17 | jobs: 18 | test: 19 | name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }} 20 | runs-on: ${{ matrix.os }} 21 | strategy: 22 | matrix: 23 | os: [ubuntu-latest] 24 | python-version: ["3.10", "3.11", "3.12"] 25 | 26 | steps: 27 | - uses: actions/checkout@v4 28 | 29 | - name: Additional info about the build 30 | shell: bash 31 | run: | 32 | uname -a 33 | df -h 34 | ulimit -a 35 | 36 | - name: Create Environment 37 | uses: mamba-org/setup-micromamba@v1 38 | with: 39 | environment-file: devtools/conda-envs/test_env.yaml 40 | environment-name: test 41 | condarc: | 42 | channels: 43 | - conda-forge 44 | create-args: >- 45 | python=${{ matrix.python-version }} 46 | 47 | - name: Build integrals 48 | shell: bash -l {0} 49 | run: | 50 | cd $GITHUB_WORKSPACE/quax/integrals 51 | make 52 | cd $GITHUB_WORKSPACE 53 | 54 | - name: Install package 55 | # conda setup requires this special shell 56 | shell: bash -l {0} 57 | run: | 58 | python -m pip install . 59 | micromamba list 60 | 61 | - name: Run tests 62 | # conda setup requires this special shell 63 | shell: bash -l {0} 64 | run: | 65 | pytest -v --cov=quax --cov-report=xml --color=yes tests/ 66 | 67 | - name: CodeCov 68 | uses: codecov/codecov-action@v1 69 | with: 70 | file: ./coverage.xml 71 | flags: unittests 72 | name: codecov-${{ matrix.os }}-py${{ matrix.python-version }} 73 | 74 | -------------------------------------------------------------------------------- /tests/test_hessians.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test hessian computations 3 | """ 4 | import quax 5 | import psi4 6 | import pytest 7 | import numpy as np 8 | 9 | molecule = psi4.geometry(""" 10 | 0 1 11 | O -0.000007070942 0.125146536460 0.000000000000 12 | H -1.424097055410 -0.993053750648 0.000000000000 13 | H 1.424209276385 -0.993112599269 0.000000000000 14 | units bohr 15 | """) 16 | basis_name = 'sto-3g' 17 | psi4.set_options({'basis': basis_name, 18 | 'scf_type': 
'pk', 19 | 'mp2_type':'conv', 20 | 'e_convergence': 1e-10, 21 | 'd_convergence': 1e-10, 22 | 'puream': 0, 23 | 'points': 5, 24 | 'fd_project': False}) 25 | 26 | options = {'damping': True, 'spectral_shift': False} 27 | 28 | def test_hartree_fock_hessian(method='hf'): 29 | psi_deriv = np.round(np.asarray(psi4.hessian(method + '/' + basis_name)), 10) 30 | n = psi_deriv.shape[0] 31 | quax_deriv = quax.core.geom_deriv(molecule, basis_name, method, deriv_order=2, options=options).reshape(n,n) 32 | quax_partial00 = quax.core.geom_deriv(molecule, basis_name, method, deriv_order=2, partial=(0,0), options=options) 33 | assert np.allclose(psi_deriv, quax_deriv) 34 | assert np.allclose(psi_deriv[0,0], quax_partial00) 35 | 36 | def test_mp2_hessian(method='mp2'): 37 | psi_deriv = np.round(np.asarray(psi4.hessian(method + '/' + basis_name, dertype='gradient')), 10) 38 | n = psi_deriv.shape[0] 39 | quax_deriv = quax.core.geom_deriv(molecule, basis_name, method, deriv_order=2, options=options).reshape(n,n) 40 | quax_partial00 = quax.core.geom_deriv(molecule, basis_name, method, deriv_order=2, partial=(0,0), options=options) 41 | assert np.allclose(psi_deriv, quax_deriv, atol=5e-7) 42 | assert np.allclose(psi_deriv[0,0], quax_partial00) 43 | 44 | def test_ccsd_t_hessian(method='ccsd(t)'): 45 | psi_deriv = np.round(np.asarray(psi4.hessian(method + '/' + basis_name, dertype='energy')), 10) 46 | n = psi_deriv.shape[0] 47 | quax_deriv = quax.core.geom_deriv(molecule, basis_name, method, deriv_order=2, options=options).reshape(n,n) 48 | quax_partial00 = quax.core.geom_deriv(molecule, basis_name, method, deriv_order=2, partial=(0,0), options=options) 49 | assert np.allclose(psi_deriv, quax_deriv, atol=5e-7) 50 | assert np.allclose(psi_deriv[0,0], quax_partial00) 51 | 52 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 
| We welcome contributions from external contributors, and this document 4 | describes how to merge code changes into Quax. 5 | 6 | ## Getting Started 7 | 8 | * Make sure you have a [GitHub account](https://github.com/signup/free). 9 | * [Fork](https://help.github.com/articles/fork-a-repo/) this repository on GitHub. 10 | * On your local machine, 11 | [clone](https://help.github.com/articles/cloning-a-repository/) your fork of 12 | the repository. 13 | 14 | ## Making Changes 15 | 16 | * Add some really awesome code to your local fork. It's usually a [good 17 | idea](http://blog.jasonmeridth.com/posts/do-not-issue-pull-requests-from-your-master-branch/) 18 | to make changes on a 19 | [branch](https://help.github.com/articles/creating-and-deleting-branches-within-your-repository/) 20 | with the branch name relating to the feature you are going to add. 21 | * When you are ready for others to examine and comment on your new feature, 22 | navigate to your fork of Quax on GitHub and open a [pull 23 | request](https://help.github.com/articles/using-pull-requests/) (PR). Note that 24 | after you launch a PR from one of your fork's branches, all 25 | subsequent commits to that branch will be added to the open pull request 26 | automatically. Each commit added to the PR will be validated for 27 | mergeability, compilation and test suite compliance; the results of these tests 28 | will be visible on the PR page. 29 | * If you're providing a new feature, you must add test cases and documentation. 30 | * When the code is ready to go, make sure you run the test suite using pytest. 31 | * When you're ready to be considered for merging, check the "Ready to go" 32 | box on the PR page to let the Quax devs know that the changes are complete. 33 | The code will not be merged until this box is checked, the continuous 34 | integration returns checkmarks, 35 | and multiple core developers give "Approved" reviews. 
36 | 37 | # Additional Resources 38 | 39 | * [General GitHub documentation](https://help.github.com/) 40 | * [PR best practices](http://codeinthehole.com/writing/pull-requests-and-other-good-practices-for-teams-using-github/) 41 | * [A guide to contributing to software packages](http://www.contribution-guide.org) 42 | * [Thinkful PR example](http://www.thinkful.com/learn/github-pull-request-tutorial/#Time-to-Submit-Your-First-PR) -------------------------------------------------------------------------------- /quax/methods/mp2.py: -------------------------------------------------------------------------------- 1 | import jax 2 | jax.config.update("jax_enable_x64", True) 3 | import jax.numpy as jnp 4 | from jax.lax import fori_loop 5 | 6 | from .energy_utils import partial_tei_transformation, cartesian_product 7 | from .hartree_fock import restricted_hartree_fock 8 | 9 | def restricted_mp2(*args, options, deriv_order=0, return_aux_data=False): 10 | if options['electric_field'] == 1: 11 | efield, geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path = args 12 | scf_args = efield, geom, basis_set, nelectrons, nuclear_charges, xyz_path 13 | elif options['electric_field'] == 2: 14 | efield_grad, efield, geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path = args 15 | scf_args = efield_grad, efield, geom, basis_set, nelectrons, nuclear_charges, xyz_path 16 | else: 17 | geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path = args 18 | scf_args = (geom, basis_set, nelectrons, nuclear_charges, xyz_path) 19 | 20 | E_scf, C, eps, G = restricted_hartree_fock(*scf_args, options=options, deriv_order=deriv_order, return_aux_data=True) 21 | 22 | # Load keyword options 23 | ndocc = nelectrons // 2 24 | ncore = nfrzn // 2 25 | 26 | print("Running MP2 Computation...") 27 | nvirt = G.shape[0] - ndocc 28 | 29 | G = partial_tei_transformation(G, C[:,ncore:ndocc], C[:,ndocc:], C[:,ncore:ndocc], C[:,ndocc:]) 30 | 31 | # Create tensor dim (occ,vir,occ,vir) of all 
possible orbital energy denominators 32 | eps_occ, eps_vir = eps[ncore:ndocc], eps[ndocc:] 33 | e_denom = jnp.reciprocal(eps_occ.reshape(-1, 1, 1, 1) - eps_vir.reshape(-1, 1, 1) + eps_occ.reshape(-1, 1) - eps_vir) 34 | 35 | # Loop algo (lower memory, but tei transform is the memory bottleneck) 36 | # Create all combinations of four loop variables to make XLA compilation easier 37 | indices = cartesian_product(jnp.arange(ndocc-ncore), jnp.arange(ndocc-ncore), jnp.arange(nvirt), jnp.arange(nvirt)) 38 | 39 | def loop_mp2(idx, mp2_corr): 40 | i,j,a,b = indices[idx] 41 | mp2_corr += G[i, a, j, b] * (2 * G[i, a, j, b] - G[i, b, j, a]) * e_denom[i, a, j, b] 42 | return mp2_corr 43 | 44 | dE_mp2 = fori_loop(0, indices.shape[0], loop_mp2, 0.0) # MP2 correlation 45 | 46 | if return_aux_data: 47 | #print("MP2 Energy: ", E_scf + dE_mp2) 48 | return E_scf + dE_mp2, C, eps, G 49 | else: 50 | return E_scf + dE_mp2 51 | 52 | -------------------------------------------------------------------------------- /quax/methods/energy_utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | jax.config.update("jax_enable_x64", True) 3 | import jax.numpy as jnp 4 | from functools import partial 5 | 6 | def nuclear_repulsion(geom, nuclear_charges): 7 | """ 8 | Compute the nuclear repulsion energy in a.u. 
9 | """ 10 | natom = nuclear_charges.shape[0] 11 | nuc = 0 12 | for i in range(natom): 13 | for j in range(i): 14 | nuc += nuclear_charges[i] * nuclear_charges[j] / jnp.linalg.norm(geom[i] - geom[j]) 15 | return nuc 16 | 17 | def symmetric_orthogonalization(S, cutoff = 1.0e-12): 18 | """ 19 | Compute the symmetric orthogonalization transform U = S^(-1/2) 20 | where S is the overlap matrix 21 | """ 22 | evals, evecs = jnp.linalg.eigh(S) 23 | 24 | sqrtm = jnp.diag(jnp.where(abs(evals) > cutoff, 1 / jnp.sqrt(abs(evals)), 0.0)) 25 | 26 | A = evecs @ sqrtm @ evecs.T 27 | return A 28 | 29 | def cholesky_orthogonalization(S): 30 | """ 31 | Compute the canonical orthogonalization transform U = VL^(-1/2) 32 | where V is the eigenvectors and L diagonal inverse sqrt eigenvalues of the overlap matrix 33 | by way of cholesky decomposition 34 | Scharfenberg, Peter; A New Algorithm for the Symmetric (Lowdin) Orthonormalization; Int J. Quant. Chem. 1977 35 | """ 36 | return jnp.linalg.inv(jnp.linalg.cholesky(S)).T 37 | 38 | def old_tei_transformation(G, C): 39 | """ 40 | Transform TEI's to MO basis. 41 | This algorithm is worse than below, since it creates intermediate arrays in memory. 42 | """ 43 | G = jnp.einsum('pqrs, pP, qQ, rR, sS -> PQRS', G, C, C, C, C, optimize='optimal') 44 | return G 45 | 46 | @jax.jit 47 | def transform(C, G): 48 | return jnp.tensordot(C, G, axes=[(0,),(3,)]) 49 | 50 | def tei_transformation(G, C): 51 | """ 52 | New algo for TEI transform 53 | It's faster than psi4.MintsHelper.mo_transform() for basis sets <~120. 
54 | """ 55 | G = transform(C, G) 56 | G = transform(C, G) 57 | G = transform(C, G) 58 | G = transform(C, G) 59 | return G 60 | 61 | def old_partial_tei_transformation(G, Ci, Cj, Ck, Cl): 62 | G = jnp.einsum('pqrs, pP, qQ, rR, sS -> PQRS', G, Ci, Cj, Ck, Cl, optimize='optimal') 63 | return G 64 | 65 | def partial_tei_transformation(G, Ci, Cj, Ck, Cl): 66 | """ 67 | New algo for Partial TEI transform 68 | """ 69 | G = transform(Cl, G) 70 | G = transform(Ck, G) 71 | G = transform(Cj, G) 72 | G = transform(Ci, G) 73 | return G 74 | 75 | def cartesian_product(*arrays): 76 | ''' 77 | JAX-friendly version of cartesian product. 78 | ''' 79 | tmp = jnp.asarray(jnp.meshgrid(*arrays, indexing='ij')).reshape(len(arrays),-1).T 80 | return tmp 81 | -------------------------------------------------------------------------------- /tests/test_dipoles.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test dipole computations 3 | """ 4 | import quax 5 | import psi4 6 | import pytest 7 | import numpy as np 8 | 9 | # Comment out if you have an installation of Libint with Cartesian multipole derivatives 10 | pytestmark = pytest.mark.skip("Requires Libint Cartesian multipole derivatives") 11 | 12 | molecule = psi4.geometry(""" 13 | 0 1 14 | O -0.000007070942 0.125146536460 0.000000000000 15 | H -1.424097055410 -0.993053750648 0.000000000000 16 | H 1.424209276385 -0.993112599269 0.000000000000 17 | units bohr 18 | """) 19 | basis_name = 'sto-3g' 20 | psi4.set_options({'basis': basis_name, 21 | 'scf_type': 'pk', 22 | 'mp2_type':'conv', 23 | 'e_convergence': 1e-10, 24 | 'd_convergence':1e-10, 25 | 'puream': 0}) 26 | 27 | options = {'damping': True, 'spectral_shift': False} 28 | efield = np.zeros((3)) 29 | 30 | def findif_dipole(method, pert): 31 | lambdas = [pert, -pert, 2.0*pert, -2.0*pert] 32 | dip_vec = np.zeros((3)) 33 | 34 | for i in range(3): 35 | pert_vec = [0, 0, 0] 36 | energies = [] 37 | for l in lambdas: 38 | pert_vec[i] = l 39 | 
def test_hartree_fock_dipole(method='hf'):
    """HF dipole: Quax analytic field derivative vs. Psi4 finite differences."""
    # Finite-difference reference built from perturbed Psi4 energies.
    ref_dipole = findif_dipole(method, 0.0005)
    # Full analytic first derivative with respect to the external field.
    full_deriv = quax.core.efield_deriv(molecule, basis_name, method, efield=efield,
                                        deriv_order=1, options=options).reshape(-1, 3)
    # Single partial derivative (field component 0 only).
    partial_x = quax.core.efield_deriv(molecule, basis_name, method, efield=efield,
                                       deriv_order=1, partial=(0,), options=options)
    assert np.allclose(ref_dipole, full_deriv)
    assert np.allclose(ref_dipole[0], partial_x)
environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, 8 | body size, disability, ethnicity, gender identity and expression, level of 9 | experience, nationality, personal appearance, race, religion, or sexual 10 | identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment include: 15 | 16 | * Using welcoming and inclusive language 17 | * Being respectful of differing viewpoints and experiences 18 | * Gracefully accepting constructive criticism 19 | * Focusing on what is best for the community 20 | * Showing empathy towards other community members 21 | 22 | Examples of unacceptable behavior by participants include: 23 | 24 | * The use of sexualized language or imagery and unwelcome sexual attention or advances 25 | * Trolling, insulting/derogatory comments, and personal or political attacks 26 | * Public or private harassment 27 | * Publishing others' private information, such as a physical or electronic address, without explicit permission 28 | * Other conduct which could reasonably be considered inappropriate in a professional setting 29 | 30 | ## Our Responsibilities 31 | 32 | Project maintainers are responsible for clarifying the standards of acceptable 33 | behavior and are expected to take appropriate and fair corrective action in 34 | response to any instances of unacceptable behavior. 35 | 36 | Project maintainers have the right and responsibility to remove, edit, or 37 | reject comments, commits, code, wiki edits, issues, and other contributions 38 | that are not aligned to this Code of Conduct, or to ban temporarily or 39 | permanently any contributor for other behaviors that they deem inappropriate, 40 | threatening, offensive, or harmful. 
41 | 42 | Moreover, project maintainers will strive to offer feedback and advice to 43 | ensure quality and consistency of contributions to the code. Contributions 44 | from outside the group of project maintainers are strongly welcomed but the 45 | final decision as to whether commits are merged into the codebase rests with 46 | the team of project maintainers. 47 | 48 | ## Scope 49 | 50 | This Code of Conduct applies both within project spaces and in public spaces 51 | when an individual is representing the project or its community. Examples of 52 | representing a project or community include using an official project e-mail 53 | address, posting via an official social media account, or acting as an 54 | appointed representative at an online or offline event. Representation of a 55 | project may be further defined and clarified by project maintainers. 56 | 57 | ## Enforcement 58 | 59 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 60 | reported by contacting the project team at '{{cookiecutter.author_email}}'. The project team will 61 | review and investigate all complaints, and will respond in a way that it deems 62 | appropriate to the circumstances. The project team is obligated to maintain 63 | confidentiality with regard to the reporter of an incident. Further details of 64 | specific enforcement policies may be posted separately. 65 | 66 | Project maintainers who do not follow or enforce the Code of Conduct in good 67 | faith may face temporary or permanent repercussions as determined by other 68 | members of the project's leadership. 
def build_RIBS(molecule, basis_set, cabs_name):
    """
    Build the combined (RI) basis set used by the CABS procedure.

    Parameters
    ----------
    molecule : psi4.core.Molecule
        Molecule the basis is constructed for.
    basis_set : psi4.core.BasisSet
        Orbital basis set; its name is combined with the auxiliary set.
    cabs_name : str
        Name of the CABS auxiliary basis (Libint suffix 'cabs').

    Returns
    -------
    psi4.core.BasisSet
        Combined orbital + auxiliary basis set.

    Raises
    ------
    ValueError
        If `cabs_name` is not a string naming an F12-style basis set.
    """
    # Libint uses the suffix 'cabs' but Psi4 uses 'optri'
    basis_name = basis_set.name()
    try:
        psi4_name = cabs_name.upper().replace('CABS', 'OPTRI')
    except AttributeError as exc:
        # Narrowed from a bare `except:`; only a non-string cabs_name can fail here.
        raise ValueError("Must use a cc-pVXZ-F12 or aug-cc-pVXZ basis set for F12 methods.") from exc

    keys = ["BASIS", "CABS_BASIS"]
    targets = [basis_name, psi4_name]
    roles = ["ORBITAL", "F12"]
    others = [basis_name, basis_name]

    # Creates combined basis set in Python
    ao_union = psi4.driver.qcdb.libmintsbasisset.BasisSet.pyconstruct_combined(molecule.save_string_xyz(), keys, targets, roles, others)
    ao_union['name'] = cabs_name
    ribs_set = psi4.core.BasisSet.construct_from_pydict(molecule, ao_union, 0)

    print("Basis name: ", cabs_name.upper())
    print("Number of basis functions: ", ribs_set.nbf())

    return ribs_set
def F_ij(s, m):
    """
    Build the m x m matrix with entries 1 / (s[j]**2 - s[i]**2) off the
    diagonal and exactly 0 on the diagonal.

    Can be numerically unstable if singular values are degenerate.
    """
    def element(row, col):
        return jax.lax.cond(row == col,
                            lambda: 0.,
                            lambda: 1 / (s[col]**2 - s[row]**2))

    # Map over columns inside rows so the output is indexed [row, col].
    over_cols = jax.vmap(element, (None, 0))
    over_rows = jax.vmap(over_cols, (0, None))

    idx = jnp.arange(m)
    return over_rows(idx, idx)
def how_many_derivs(k, n):
    """
    How many unique Cartesian derivatives exist for k atoms at nth order.

    This is the number of combinations with repetition of n picks from the
    3*k Cartesian coordinates, i.e. C(3*k + n - 1, n). Computed with exact
    integer arithmetic; the previous float-division version could drift for
    large k and n.

    Parameters
    ----------
    k : int
        Number of atoms (3*k differentiable coordinates).
    n : int
        Order of differentiation.

    Returns
    -------
    int
        Number of unique nth-order partial derivatives.
    """
    from math import comb  # local import: leaves module-level imports untouched
    if n == 0:
        # Zeroth order: only the undifferentiated quantity itself.
        return 1
    return comb(3 * k + n - 1, n)
56 | """ 57 | dim = deriv_vec.shape[0] 58 | vals = np.arange(dim, dtype=int) 59 | deriv_order = np.sum(deriv_vec) 60 | 61 | deriv_vecs = [] 62 | for c in itertools.combinations_with_replacement(vals, deriv_order): 63 | tmp_deriv_vec = np.zeros_like(deriv_vec, dtype=int) 64 | for i in c: 65 | tmp_deriv_vec[i] += 1 66 | deriv_vecs.append(tmp_deriv_vec) 67 | 68 | deriv_vecs = np.asarray(deriv_vecs) 69 | idx = np.argwhere(np.all(deriv_vecs==deriv_vec,axis=1)).reshape(-1)[0] 70 | return idx 71 | 72 | # Sum over all partitions of the set range(deriv_order) 73 | def partition(collection): 74 | if len(collection) == 1: 75 | yield [collection] 76 | return 77 | first = collection[0] 78 | for smaller in partition(collection[1:]): 79 | # insert `first` in each of the subpartition's subsets 80 | for n, subset in enumerate(smaller): 81 | yield smaller[:n] + [[ first ] + subset] + smaller[n+1:] 82 | # put `first` in its own subset 83 | yield [[first]] + smaller 84 | 85 | def get_required_deriv_vecs(natoms, deriv_order, address): 86 | """ 87 | Simulates the Faa Di Bruno formula, giving a set of partial derivative operators which are required 88 | to find a particular higher order partial derivative operator, as defined by `deriv_order` and `address`. 89 | 90 | The returned partial derivative operators are each represented by vectors of length NCART where NCART is 3 * natom. 91 | The value of each index in these vectors describes how many times to differentiate wrt that particular cartesian coordinate. 92 | For example, a 2 atom system has atoms A,B and all derivative vectors have indices which correspond to the coordinates: [Ax, Ay, Az, Bx, By, Bz]. 93 | A derivative vector [1,0,0,2,0,0] therefore represents the partial derivative (d^3)/(dAx dBx dBx). 94 | 95 | Parameters 96 | ---------- 97 | natoms : int 98 | The number of atoms in the system. 
The cartesian nuclear derivative tensor for this `natom` system 99 | has a dimension size 3 * natom 100 | deriv_order : int 101 | The order of differentiation. The cartesian nuclear derivative tensor for this `natom` system 102 | has rank `deriv_order` and dimension size 3 * natoms 103 | address : tuple of int 104 | A tuple of integers which describe which cartesian partial derivative 105 | we wish to compute. Each integer in the tuple is in the range [0, NCART-1] 106 | Returns 107 | ------- 108 | partial_derivatives : arr 109 | An array of partial derivatives of dimensions (npartials, NCART) 110 | """ 111 | address = list(address) 112 | deriv_vecs = [] 113 | nparams = natoms * 3 114 | for p in partition(address): 115 | for sub in p: 116 | # List of zeros 117 | deriv_vec = [0] * nparams 118 | for i in sub: 119 | deriv_vec[i] += 1 120 | deriv_vecs.append(deriv_vec) 121 | partial_derivatives = np.unique(np.asarray(deriv_vecs), axis=0) 122 | return partial_derivatives 123 | 124 | 125 | -------------------------------------------------------------------------------- /quax/methods/ccsd_t.py: -------------------------------------------------------------------------------- 1 | import jax 2 | jax.config.update("jax_enable_x64", True) 3 | import jax.numpy as jnp 4 | from jax.lax import while_loop 5 | 6 | from .ccsd import rccsd 7 | 8 | def perturbative_triples(T1, T2, V, fock_Od, fock_Vd): 9 | Voooo, Vooov, Voovv, Vovov, Vovvv, Vvvvv = V 10 | o,v = T1.shape 11 | delta_o = jnp.eye(o) 12 | delta_v = jnp.eye(v) 13 | 14 | def inner_func(i, j, k): 15 | delta_ij = delta_o[i, j] 16 | delta_jk = delta_o[j, k] 17 | W = jnp.einsum('dab,cd', Vovvv[i, :, :, :], T2[k, j, :, :]) 18 | W += jnp.einsum('dac,bd', Vovvv[i, :, :, :], T2[j, k, :, :]) 19 | W += jnp.einsum('dca,bd', Vovvv[k, :, :, :], T2[j, i, :, :]) 20 | W += jnp.einsum('dcb,ad', Vovvv[k, :, :, :], T2[i, j, :, :]) 21 | W += jnp.einsum('dbc,ad', Vovvv[j, :, :, :], T2[i, k, :, :]) 22 | W += jnp.einsum('dba,cd', Vovvv[j, :, :, :], 
T2[k, i, :, :]) 23 | W -= jnp.einsum('lc,lab', Vooov[:, k, j, :], T2[i, :, :, :]) 24 | W -= jnp.einsum('lb,lac', Vooov[:, j, k, :], T2[i, :, :, :]) 25 | W -= jnp.einsum('lb,lca', Vooov[:, j, i, :], T2[k, :, :, :]) 26 | W -= jnp.einsum('la,lcb', Vooov[:, i, j, :], T2[k, :, :, :]) 27 | W -= jnp.einsum('la,lbc', Vooov[:, i, k, :], T2[j, :, :, :]) 28 | W -= jnp.einsum('lc,lba', Vooov[:, k, i, :], T2[j, :, :, :]) 29 | V = W + jnp.einsum('bc,a', Voovv[j, k, :, :], T1[i, :]) \ 30 | + jnp.einsum('ac,b', Voovv[i, k, :, :], T1[j, :]) \ 31 | + jnp.einsum('ab,c', Voovv[i, j, :, :], T1[k, :]) 32 | 33 | 34 | delta_occ = 2 - delta_ij - delta_jk 35 | Dd_occ = fock_Od[i] + fock_Od[j] + fock_Od[k] 36 | 37 | def loop_a(arr0): 38 | a_0, b_0, c_0, pT_contribution_0 = arr0 39 | b_0 = 0 40 | 41 | def loop_b(arr1): 42 | a_1, b_1, c_1, pT_contribution_1 = arr1 43 | c_1 = 0 44 | delta_vir = 1 + delta_v[a_1, b_1] 45 | 46 | def loop_c(arr2): 47 | a_2, b_2, c_2, delta_vir_2, pT_contribution_2 = arr2 48 | delta_vir_2 = delta_vir_2 + delta_v[b_2,c_2] 49 | Dd = Dd_occ - (fock_Vd[a_2] + fock_Vd[b_2] + fock_Vd[c_2]) 50 | X = W[a_2, b_2, c_2] * V[a_2, b_2, c_2] + W[a_2, c_2, b_2] * V[a_2, c_2, b_2] + W[b_2, a_2, c_2] * V[b_2, a_2, c_2] \ 51 | + W[b_2, c_2, a_2] * V[b_2, c_2, a_2] + W[c_2, a_2, b_2] * V[c_2, a_2, b_2] + W[c_2, b_2, a_2] * V[c_2, b_2, a_2] 52 | Y = (V[a_2, b_2, c_2] + V[b_2, c_2, a_2] + V[c_2, a_2, b_2]) 53 | Z = (V[a_2, c_2, b_2] + V[b_2, a_2, c_2] + V[c_2, b_2, a_2]) 54 | E = (Y - 2 * Z) * (W[a_2, b_2, c_2] + W[b_2, c_2, a_2] + W[c_2, a_2, b_2]) \ 55 | + (Z - 2 * Y) * (W[a_2, c_2, b_2] + W[b_2, a_2, c_2] + W[c_2, b_2, a_2]) + 3 * X 56 | pT_contribution_2 += E * delta_occ / (Dd * delta_vir_2) 57 | c_2 += 1 58 | return (a_2, b_2, c_2, delta_vir_2, pT_contribution_2) 59 | 60 | a_1_, b_1_, c_1_, delta_vir_, pT_contribution_1_ = while_loop(lambda arr2: arr2[2] < arr2[1] + 1, loop_c, (a_1, b_1, c_1, delta_vir, pT_contribution_1)) 61 | b_1_ += 1 62 | return (a_1_, b_1_, c_1_, 
def rccsd_t(*args, options, deriv_order=0):
    """
    Restricted CCSD(T): run RCCSD, then add the perturbative triples correction.

    The positional argument list depends on options['electric_field'],
    mirroring `rccsd`:
      1: (efield, geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path)
      2: (efield_grad, efield, geom, basis_set, nelectrons, nfrzn,
          nuclear_charges, xyz_path)
      otherwise: (geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path)

    Returns the total CCSD(T) energy.
    """
    # BUG FIX: both branches previously tested the truthiness of
    # options['electric_field'], so the field-gradient unpacking (8 args) was
    # unreachable and value 2 crashed on a 7-name unpack. Dispatch on == 1 /
    # == 2 exactly as restricted_hartree_fock and rccsd do.
    if options['electric_field'] == 1:
        efield, geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path = args
        ccsd_args = efield, geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path
    elif options['electric_field'] == 2:
        efield_grad, efield, geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path = args
        ccsd_args = efield_grad, efield, geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path
    else:
        geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path = args
        ccsd_args = (geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path)

    E_ccsd, T1, T2, V, fock_Od, fock_Vd = rccsd(*ccsd_args, options=options, deriv_order=deriv_order, return_aux_data=True)

    print("Running (T) Correction...")
    pT = perturbative_triples(T1, T2, V, fock_Od, fock_Vd)
    return E_ccsd + pT
eigenvalues 37 | # (JAX cannot differentiate degenerate eigenvalue eigh) 38 | def form_shift(): 39 | fudge = jnp.asarray(jnp.linspace(0, 1, nbf)) * 1.e-9 40 | return jnp.diag(fudge) 41 | 42 | shift = jax.lax.cond(spectral_shift, lambda: form_shift(), lambda: jnp.zeros_like(S)) 43 | 44 | # Shifting eigenspectrum requires lower convergence. 45 | convergence = jax.lax.cond(spectral_shift, lambda: 1.0e-9, lambda: 1.0e-10) 46 | 47 | H = T + V 48 | Enuc = nuclear_repulsion(geom.reshape(-1,3), nuclear_charges) 49 | 50 | if options['electric_field'] == 1: 51 | Mu_XYZ = compute_dipole_ints(geom, basis_set, basis_set, xyz_path, deriv_order, options) 52 | H += jnp.einsum('x,xij->ij', efield, Mu_XYZ, optimize = 'optimal') 53 | elif options['electric_field'] == 2: 54 | Mu_Th = compute_quadrupole_ints(geom, basis_set, basis_set, xyz_path, deriv_order, options) 55 | H += jnp.einsum('x,xij->ij', efield, Mu_Th[:3, :, :], optimize = 'optimal') 56 | H += jnp.einsum('x,xij->ij', efield_grad[jnp.triu_indices(3)], Mu_Th[3:, :, :], optimize = 'optimal') 57 | 58 | def rhf_iter(F, D): 59 | E_scf = jnp.einsum('pq,pq->', F + H, D) + Enuc 60 | Fp = A.T @ F @ A 61 | Fp += shift 62 | eps, C2 = jnp.linalg.eigh(Fp) 63 | C = A @ C2 64 | Cocc = C[:, :ndocc] 65 | D = Cocc @ Cocc.T 66 | return E_scf, D, C, eps 67 | 68 | def DIIS_Err(F, D, S): 69 | diis_e = jnp.einsum('ij,jk,kl->il', F, D, S) - jnp.einsum('ij,jk,kl->il', S, D, F) 70 | diis_e = A @ diis_e @ A 71 | return jnp.mean(diis_e ** 2) ** 0.5 72 | 73 | def scf_procedure(carry): 74 | iter, de_, drms_, eps_, C_, D_old, D_, e_old = carry 75 | 76 | D_ = jax.lax.cond(damping and (iter < 10), lambda: D_old * damp_factor + D_ * damp_factor, lambda: D_) 77 | D_old = jnp.copy(D_) 78 | # Build JK matrix: 2 * J - K 79 | JK = 2 * jk_build(G, D_) 80 | JK -= jk_build(G.transpose((0,2,1,3)), D_) 81 | # Build Fock 82 | F = H + JK 83 | # Compute energy, transform Fock and diagonalize, get new density 84 | e_scf, D_, C_, eps_ = rhf_iter(F, D_) 85 | 86 | de_, 
drms_ = jax.lax.cond(iter + 1 == maxit, lambda: (1.e-15, 1.e-15), lambda: (e_old - e_scf, DIIS_Err(F, D_, S))) 87 | 88 | return (iter + 1, de_, drms_, eps_, C_, D_old, D_, e_scf) 89 | 90 | # Create Guess Density 91 | D = jax.lax.cond(options['guess_core'], lambda: jnp.copy(H), lambda: jnp.zeros_like(H)) 92 | JK = 2 * jk_build(G, D) 93 | JK -= jk_build(G.transpose((0,2,1,3)), D) 94 | F = H + JK 95 | E_init, D_init, C_init, eps_init = rhf_iter(F, D) 96 | 97 | # Perform SCF Procedure 98 | iteration, _, _, eps, C, _, D, E_scf = jax.lax.while_loop(lambda arr: (abs(arr[1]) > convergence) | (arr[2] > convergence), 99 | scf_procedure, (0, 1.0, 1.0, eps_init, C_init, D, D_init, E_init)) 100 | # (iter, dE, dRMS, eps, C, D_old, D, E_scf) 101 | print(iteration, " RHF iterations performed") 102 | 103 | if options['electric_field'] > 0: 104 | E_scf += jnp.einsum('x,q,qx->', efield, nuclear_charges, geom.reshape(-1,3), optimize = 'optimal') 105 | if options['electric_field'] > 1: 106 | E_scf += jnp.einsum('ab,q,qa,qb->', jnp.triu(efield_grad), nuclear_charges, 107 | geom.reshape(-1,3), geom.reshape(-1,3), optimize = 'optimal') 108 | 109 | # If many orbitals are degenerate, warn that higher order derivatives may be unstable 110 | tmp = jnp.round(eps, 6) 111 | ndegen_orbs = tmp.shape[0] - jnp.unique(tmp).shape[0] 112 | if (ndegen_orbs / nbf) > 0.20: 113 | print("Hartree-Fock warning: More than 20% of orbitals have degeneracies. 
Higher order derivatives may be unstable due to eigendecomposition AD rule") 114 | 115 | if not return_aux_data: 116 | return E_scf 117 | else: 118 | # print("RHF Energy: ", E_scf) 119 | return E_scf, C, eps, G 120 | 121 | -------------------------------------------------------------------------------- /quax/integrals/utils.h: -------------------------------------------------------------------------------- 1 | // Utility functions for libint_interface 2 | 3 | // Creates atom objects from xyz file path 4 | std::vector get_atoms(std::string xyzfilename) 5 | { 6 | std::ifstream input_file(xyzfilename); 7 | std::vector atoms = libint2::read_dotxyz(input_file); 8 | return atoms; 9 | } 10 | 11 | // Creates a combined basis set 12 | libint2::BasisSet make_ao_cabs(std::vector atoms, 13 | std::string obs_name, libint2::BasisSet cabs) 14 | { 15 | // Create OBS 16 | obs_name.erase(obs_name.end() - 5, obs_name.end()); 17 | auto obs = libint2::BasisSet(obs_name, atoms); 18 | obs.set_pure(false); // use cartesian gaussians 19 | 20 | auto obs_idx = obs.atom2shell(atoms); 21 | auto cabs_idx = cabs.atom2shell(atoms); 22 | 23 | std::vector> el_bases(36); // Only consider atoms up to Kr 24 | for (size_t i = 0; i < atoms.size(); i++) { 25 | if (el_bases[atoms[i].atomic_number].empty()) { 26 | std::vector tmp; 27 | 28 | for(long int& idx : obs_idx[i]) { 29 | tmp.push_back(obs[idx]); 30 | } 31 | for(long int& idx : cabs_idx[i]) { 32 | tmp.push_back(cabs[idx]); 33 | } 34 | 35 | stable_sort(tmp.begin(), tmp.end(), [](const auto& a, const auto& b) -> bool 36 | { 37 | return a.contr[0].l < b.contr[0].l; 38 | }); 39 | 40 | el_bases[atoms[i].atomic_number] = tmp; 41 | } 42 | } 43 | 44 | // Create CABS, union of orbital and auxiliary basis AOs 45 | cabs = libint2::BasisSet(atoms, el_bases); 46 | cabs.set_pure(false); 47 | return cabs; 48 | } 49 | 50 | // Used to make contracted Gaussian-type geminal for F12 methods 51 | std::vector> make_cgtg(double exponent) { 52 | // The fitting coefficients 
// Returns square of cgtg: the self-product of an expansion
// sum_i c_i exp(-e_i r), folded into the unique i <= j terms with a factor
// of 2 on the cross terms.
//
// Improvements: input taken by const reference (the previous signature copied
// the whole vector), output reserved up-front, pairs emplaced directly.
std::vector<std::pair<double, double>> take_square(const std::vector<std::pair<double, double>>& input) {
    const std::size_t n = input.size();
    std::vector<std::pair<double, double>> output;
    output.reserve(n * (n + 1) / 2);  // number of unique (i, j) pairs with i <= j
    for (std::size_t i = 0; i < n; ++i) {
        const double e_i = input[i].first;
        const double c_i = input[i].second;
        for (std::size_t j = i; j < n; ++j) {
            const double e_j = input[j].first;
            const double c_j = input[j].second;
            const double scale = (i == j) ? 1.0 : 2.0;  // off-diagonal terms appear twice
            output.emplace_back(e_i + e_j, scale * c_i * c_j);
        }
    }
    return output;
}
// Cartesian product of a list of integer vectors: every way of picking one
// element from each input vector, in input order. Used to enumerate nuclear
// derivative shell-index combinations (see the explanatory comment above).
std::vector<std::vector<int>> cartesian_product (const std::vector<std::vector<int>>& v) {
    // Start from the single empty combination and extend it one factor at a time.
    std::vector<std::vector<int>> partial = {{}};
    for (const auto& factor : v) {
        std::vector<std::vector<int>> extended;
        extended.reserve(partial.size() * factor.size());
        for (const auto& prefix : partial) {
            for (const int elem : factor) {
                extended.push_back(prefix);
                extended.back().push_back(elem);
            }
        }
        partial = std::move(extended);
    }
    return partial;
}
// Recursively generate all combinations-with-repetition of size k drawn from
// inp[i..n), appending each completed combination to `result`. `out` is the
// in-progress combination (backtracking buffer).
//
// Improvement: `inp` is taken by const reference — the previous signature
// copied the entire pool at every level of the recursion. The size compare is
// also made signed/unsigned clean.
void cwr_recursion(const std::vector<int> &inp,
                   std::vector<int> &out,
                   std::vector<std::vector<int>> &result,
                   int k, int i, int n)
{
    // base case: if combination size is k, add to result
    if (out.size() == static_cast<std::size_t>(k)) {
        result.push_back(out);
        return;
    }
    // extend with each candidate from position i onward; starting the
    // recursion at j (not j + 1) is what allows repetition
    for (int j = i; j < n; j++) {
        out.push_back(inp[j]);
        cwr_recursion(inp, out, result, k, j, n);
        // backtrack - remove current element from solution
        out.pop_back();
    }
}
168 | // These are upper triangle indices, and the length of them is the total number of derivatives 169 | vector out; 170 | vector> combos; 171 | cwr_recursion(inp, out, combos, deriv_order, 0, nparams); 172 | return combos; 173 | } -------------------------------------------------------------------------------- /quax/methods/ccsd.py: -------------------------------------------------------------------------------- 1 | import jax 2 | jax.config.update("jax_enable_x64", True) 3 | import jax.numpy as jnp 4 | 5 | from .energy_utils import tei_transformation 6 | from .hartree_fock import restricted_hartree_fock 7 | 8 | def rccsd(*args, options, deriv_order=0, return_aux_data=False): 9 | if options['electric_field'] == 1: 10 | efield, geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path = args 11 | scf_args = efield, geom, basis_set, nelectrons, nuclear_charges, xyz_path 12 | elif options['electric_field'] == 2: 13 | efield_grad, efield, geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path = args 14 | scf_args = efield_grad, efield, geom, basis_set, nelectrons, nuclear_charges, xyz_path 15 | else: 16 | geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path = args 17 | scf_args = (geom, basis_set, nelectrons, nuclear_charges, xyz_path) 18 | 19 | # Load keywords 20 | ndocc = nelectrons // 2 21 | ncore = nfrzn // 2 22 | E_scf, C, eps, V = restricted_hartree_fock(*scf_args, options=options, deriv_order=deriv_order, return_aux_data=True) 23 | 24 | print("Running CCSD Computation...") 25 | nbf = V.shape[0] 26 | nvir = nbf - ndocc 27 | 28 | o = slice(ncore, ndocc) 29 | v = slice(ndocc, nbf) 30 | 31 | # Save slices of two-electron repulsion integrals in MO basis 32 | V = tei_transformation(V, C) 33 | V = jnp.swapaxes(V,1,2) 34 | V = (V[o,o,o,o], V[o,o,o,v], V[o,o,v,v], V[o,v,o,v], V[o,v,v,v], V[v,v,v,v]) 35 | 36 | fock_Od = eps[o] 37 | fock_Vd = eps[v] 38 | 39 | # Oribital energy denominators 40 | D = 1.0 / (fock_Od.reshape(-1, 1, 1, 1) + 
fock_Od.reshape(-1, 1, 1) - fock_Vd.reshape(-1, 1) - fock_Vd) 41 | d = 1.0 / (fock_Od.reshape(-1, 1) - fock_Vd) 42 | 43 | # Initial Amplitudes 44 | 45 | maxit = options['maxit'] 46 | def ccsd_procedure(arr): 47 | iter, de_, T1_,T2_, e_old = arr 48 | 49 | T1_,T2_ = rccsd_iter(T1_,T2_, V, d, D) 50 | e_ccsd = rccsd_energy(T1_,T2_, V[2]) 51 | 52 | de_ = jax.lax.cond(iter + 1 == maxit, lambda: 1.e-12, lambda: e_ccsd - e_old) 53 | 54 | return (iter + 1, de_, T1_, T2_, e_ccsd) 55 | 56 | iteration, _, T1, T2, E_ccsd = jax.lax.while_loop(lambda arr: abs(arr[1]) > 1e-10, ccsd_procedure, 57 | (0, 1.0, jnp.zeros((ndocc - ncore, nvir)), D * V[2], 0.0)) # (iter, dE, T1, T2, E_ccsd) 58 | 59 | print(iteration, " CCSD iterations performed") 60 | if return_aux_data: 61 | #print("CCSD Correlation Energy: ", E_ccsd) 62 | #print("CCSD Total Energy: ", E_ccsd + E_scf) 63 | return E_scf + E_ccsd, T1, T2, V, fock_Od, fock_Vd 64 | else: 65 | return E_scf + E_ccsd 66 | 67 | # Not a lot of memory use here compared to ccsd iterations, safe to jit-compile this. 68 | @jax.jit 69 | def rccsd_energy(T1, T2, Voovv): 70 | E_ccsd = 0.0 71 | E_ccsd -= jnp.tensordot(T1, jnp.tensordot(T1, Voovv, [(0, 1), (1, 2)]), [(0, 1), (0, 1)]) 72 | E_ccsd -= jnp.tensordot(T2, Voovv, [(0, 1, 2, 3), (1, 0, 2, 3)]) 73 | E_ccsd += 2.0*jnp.tensordot(T2, Voovv, [(0, 1, 2, 3),(0, 1, 2, 3)]) 74 | E_ccsd += 2.0*jnp.tensordot(T1, jnp.tensordot(T1, Voovv, [(0, 1), (0, 2)]), [(0, 1), (0, 1)]) 75 | return E_ccsd 76 | 77 | # Jit compiling ccsd is a BAD IDEA. 
78 | # TODO consider breaking up function and jit compiling those which do not use more memory than TEI transformation 79 | def rccsd_iter(T1, T2, V, d, D): 80 | Voooo, Vooov, Voovv, Vovov, Vovvv, Vvvvv = V 81 | 82 | newT1 = jnp.zeros(T1.shape) 83 | newT2 = jnp.zeros(T2.shape) 84 | 85 | # T1 equation 86 | newT1 += jnp.tensordot(T1, Voovv, [(0, 1), (0, 2)]) 87 | newT1 += jnp.tensordot(T2, Vovvv, [(1, 2, 3), (0, 3, 2)]) 88 | newT1 -= jnp.tensordot(Vooov, T2, [(0, 1, 3), (0, 1, 3)]) 89 | newT1 -= jnp.einsum('kc, la, lkic -> ia', T1, T1, Vooov, optimize = 'optimal') 90 | newT1 += jnp.einsum('kc, id, kacd -> ia', T1, T1, Vovvv, optimize = 'optimal') 91 | newT1 -= jnp.einsum('kc, ilad, lkcd -> ia', T1, T2, Voovv, optimize = 'optimal') 92 | newT1 -= jnp.einsum('kc, liad, klcd -> ia', T1, T2, Voovv, optimize = 'optimal') 93 | newT1 -= jnp.einsum('ic, lkad, lkcd -> ia', T1, T2, Voovv, optimize = 'optimal') 94 | newT1 -= jnp.einsum('la, ikdc, klcd -> ia', T1, T2, Voovv, optimize = 'optimal') 95 | newT1 -= jnp.einsum('kc, id, la, klcd -> ia', T1, T1, T1, Voovv, optimize = 'optimal') 96 | newT1 += 2.0 * jnp.einsum('kc, ilad, klcd -> ia', T1, T2, Voovv, optimize = 'optimal') 97 | newT1 *= 2.0 98 | 99 | newT1 -= jnp.tensordot(T1, Vovov, [(0, 1), (2, 1)]) 100 | newT1 -= jnp.tensordot(T2, Vovvv, [(0, 2, 3), (0, 3, 2)]) 101 | newT1 += jnp.tensordot(Vooov, T2, [(0, 1, 3), (1, 0, 3)]) 102 | newT1 -= jnp.einsum('kc, id, kadc -> ia', T1, T1, Vovvv, optimize = 'optimal') 103 | newT1 += jnp.einsum('kc, la, klic -> ia', T1, T1, Vooov, optimize = 'optimal') 104 | newT1 += jnp.einsum('kc, liad, lkcd -> ia', T1, T2, Voovv, optimize = 'optimal') 105 | newT1 += jnp.einsum('ic, lkad, klcd -> ia', T1, T2, Voovv, optimize = 'optimal') 106 | newT1 += jnp.einsum('la, ikcd, klcd -> ia', T1, T2, Voovv, optimize = 'optimal') 107 | newT1 += jnp.einsum('kc, id, la, lkcd -> ia', T1, T1, T1, Voovv, optimize = 'optimal') 108 | 109 | # T2 equation 110 | newT2 -= jnp.einsum('ikac, ljbd, klcd -> ijab', T2, 
T2, Voovv, optimize = 'optimal') 111 | newT2 -= jnp.einsum('lkac, ijdb, klcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 112 | newT2 -= jnp.einsum('ikac, jlbd, lkcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 113 | newT2 -= jnp.einsum('kiac, jlbd, klcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 114 | newT2 -= jnp.einsum('ijac, klbd, klcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 115 | newT2 += 2.0 * jnp.einsum('ikac, jlbd, klcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 116 | newT2 *= 2.0 117 | 118 | # Reducing Vvvvv contractions to tensordot is especially productive. 119 | # TODO try reducing Vovvv as well. Also check if removing jit makes this optimization moot... 120 | newT2 += Voovv 121 | newT2 += jnp.tensordot(T1, jnp.tensordot(T1, Vvvvv, [(1, ), (1, )]), [(1, ), (1, )]) 122 | newT2 += jnp.tensordot(T2, Vvvvv, [(2, 3), (0, 1)]) 123 | newT2 += jnp.einsum('ka, lb, ijkl -> ijab', T1, T1, Voooo, optimize = 'optimal') 124 | newT2 += jnp.tensordot(T2, Voooo, [(0, 1), (2, 3)]).transpose((2, 3, 0, 1)) 125 | newT2 -= jnp.einsum('ic, jd, ka, kbcd -> ijab', T1, T1, T1, Vovvv, optimize = 'optimal') 126 | newT2 -= jnp.einsum('ic, jd, kb, kadc -> ijab', T1, T1, T1, Vovvv, optimize = 'optimal') 127 | newT2 += jnp.einsum('ic, ka, lb, lkjc -> ijab', T1, T1, T1, Vooov, optimize = 'optimal') 128 | newT2 += jnp.einsum('jc, ka, lb, klic -> ijab', T1, T1, T1, Vooov, optimize = 'optimal') 129 | newT2 += jnp.einsum('klac, ijdb, klcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 130 | newT2 += jnp.einsum('kiac, ljdb, lkcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 131 | newT2 += jnp.einsum('ikac, ljbd, lkcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 132 | newT2 += jnp.einsum('kiac, ljbd, klcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 133 | newT2 += jnp.einsum('ijac, lkbd, klcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 134 | newT2 += jnp.einsum('kjac, ildb, lkcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 135 | newT2 += jnp.einsum('ijdc, lkab, 
klcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 136 | newT2 += jnp.einsum('ic, jd, ka, lb, klcd -> ijab', T1, T1, T1, T1, Voovv, optimize = 'optimal') 137 | newT2 += jnp.einsum('ic, jd, lkab, lkcd -> ijab', T1, T1, T2, Voovv, optimize = 'optimal') 138 | newT2 += jnp.einsum('ka, lb, ijdc, lkcd -> ijab', T1, T1, T2, Voovv, optimize = 'optimal') 139 | 140 | P_OVVO = jnp.tensordot(T2, Voovv, [(1, 3),(0, 2)]).transpose((0, 2, 1, 3)) 141 | P_OVVO -= jnp.einsum('lb, ikac, lkjc -> ijab', T1, T2, Vooov, optimize = 'optimal') 142 | P_OVVO += jnp.einsum('jc, ikad, kbdc -> ijab', T1, T2, Vovvv, optimize = 'optimal') 143 | P_OVVO += jnp.einsum('kc, ijad, kbcd -> ijab', T1, T2, Vovvv, optimize = 'optimal') 144 | P_OVVO -= jnp.einsum('kc, ilab, lkjc -> ijab', T1, T2, Vooov, optimize = 'optimal') 145 | P_OVVO -= jnp.einsum('kc, jd, ilab, klcd -> ijab', T1, T1, T2, Voovv, optimize = 'optimal') 146 | P_OVVO -= jnp.einsum('kc, la, ijdb, klcd -> ijab', T1, T1, T2, Voovv, optimize = 'optimal') 147 | P_OVVO -= jnp.einsum('ic, ka, jlbd, klcd -> ijab', T1, T1, T2, Voovv, optimize = 'optimal') 148 | P_OVVO -= jnp.einsum('ikdc, ljab, klcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 149 | P_OVVO *= 2.0 150 | 151 | P_OVVO -= jnp.tensordot(T1, Vooov, [(0, ), (2, )]).transpose((2, 1, 3, 0)) 152 | P_OVVO += jnp.tensordot(T1, Vovvv, [(1, ), (1, )]).transpose((1, 0, 2, 3)) 153 | P_OVVO -= jnp.tensordot(T2, Voovv, [(0, 3), (0, 2)]).transpose((0, 2, 1, 3)) 154 | P_OVVO -= jnp.einsum('ic, ka, kjcb -> ijab', T1, T1, Voovv, optimize = 'optimal') 155 | P_OVVO -= jnp.einsum('ic, kb, jcka -> ijab', T1, T1, Vovov, optimize = 'optimal') 156 | P_OVVO -= jnp.tensordot(T2, Vovov, [(1, 3), (2, 1)]).transpose((0, 2, 1, 3)) 157 | P_OVVO -= jnp.tensordot(T2, Vovov, [(0, 3), (2, 1)]).transpose((2, 0, 1, 3)) 158 | P_OVVO += jnp.einsum('lb, kiac, lkjc -> ijab', T1, T2, Vooov, optimize = 'optimal') 159 | P_OVVO -= jnp.einsum('jc, ikdb, kacd -> ijab', T1, T2, Vovvv, optimize = 'optimal') 160 | P_OVVO -= 
jnp.einsum('jc, kiad, kbdc -> ijab', T1, T2, Vovvv, optimize = 'optimal') 161 | P_OVVO -= jnp.einsum('jc, ikad, kbcd -> ijab', T1, T2, Vovvv, optimize = 'optimal') 162 | P_OVVO += jnp.einsum('jc, lkab, lkic -> ijab', T1, T2, Vooov, optimize = 'optimal') 163 | P_OVVO += jnp.einsum('lb, ikac, kljc -> ijab', T1, T2, Vooov, optimize = 'optimal') 164 | P_OVVO -= jnp.einsum('ka, ijdc, kbdc -> ijab', T1, T2, Vovvv, optimize = 'optimal') 165 | P_OVVO += jnp.einsum('ka, ilcb, lkjc -> ijab', T1, T2, Vooov, optimize = 'optimal') 166 | P_OVVO -= jnp.einsum('kc, ijad, kbdc -> ijab', T1, T2, Vovvv, optimize = 'optimal') 167 | P_OVVO += jnp.einsum('kc, ilab, kljc -> ijab', T1, T2, Vooov, optimize = 'optimal') 168 | P_OVVO += jnp.einsum('jkcd, ilab, klcd -> ijab', T2, T2, Voovv, optimize = 'optimal') 169 | P_OVVO += jnp.einsum('kc, jd, ilab, lkcd -> ijab', T1, T1, T2, Voovv, optimize = 'optimal') 170 | P_OVVO += jnp.einsum('kc, la, ijdb, lkcd -> ijab', T1, T1, T2, Voovv, optimize = 'optimal') 171 | P_OVVO += jnp.einsum('ic, ka, ljbd, klcd -> ijab', T1, T1, T2, Voovv, optimize = 'optimal') 172 | P_OVVO += jnp.einsum('ic, ka, ljdb, lkcd -> ijab', T1, T1, T2, Voovv, optimize = 'optimal') 173 | P_OVVO += jnp.einsum('ic, lb, kjad, klcd -> ijab', T1, T1, T2, Voovv, optimize = 'optimal') 174 | 175 | newT2 += P_OVVO 176 | newT2 += P_OVVO.transpose((1, 0, 3, 2)) 177 | 178 | newT1 *= d 179 | newT2 *= D 180 | return newT1, newT2 181 | 182 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Quax: Quantum Chemistry, powered by JAX 2 | ![Continuous Integration](https://github.com/CCQC/Quax/actions/workflows/continuous_integration.yml/badge.svg) 3 | [![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) 4 | 5 | ![Screenshot](quax.png) 6 | 7 | You have found Quax. 
The paper outlining this work was just [recently published](https://pubs.acs.org/doi/abs/10.1021/acs.jpclett.1c00607). 8 | This library supports a simple and clean API for obtaining higher-order energy derivatives of electronic 9 | structure computations such as Hartree-Fock, second-order Møller-Plesset perturbation theory (MP2), 10 | explicitly correlated MP2 (MP2-F12), and coupled cluster with singles, doubles, and perturbative triples 11 | excitations [CCSD(T)]. 12 | Whereas most codes support only analytic gradient and occasionally Hessian computations, this code can 13 | compute analytic derivatives of arbitrary order for both geometric derivatives and electric field derivatives. 14 | We use [JAX](https://github.com/google/jax) for automatically differentiating electronic structure computations. 15 | The code can be easily extended to support other methods, for example 16 | using the guidance offered by the [Psi4Numpy project](https://github.com/psi4/psi4numpy). 17 | 18 | If you are interested in obtaining electronic energy derivatives with Quax, 19 | but are wary of and/or not familiar with the concept of automatic differentiation, 20 | we recommend [this video](https://www.youtube.com/watch?v=wG_nF1awSSY) for a brief primer. 21 | 22 | We should also note this project is mostly intended as an experimental proof-of-concept. 23 | While it can be used for research applications, users should always take steps to verify the accuracy of the results, 24 | either by checking energies and derivatives against standard electronic structure codes or by using finite differences. 25 | Generally, if the energy and gradient are correct, higher order derivatives are most likely correct to a high degree of numerical precision. 26 | Note however the caveat (described below) that systems with highly degenerate orbitals will likely be numerically unstable at high derivative orders. 27 | 28 | ### Using Quax 29 | The Quax API is very simple. 
We use Psi4 to handle molecule data like coordinates, charge, multiplicity, and basis set 30 | information. Once a Psi4 Molecule object is defined, energies, derivatives, and partial derivatives can be computed with a single line of code. 31 | In the following example, we perform Hartree-Fock computations with a sto-3g basis set: we compute the energy, gradient, Hessian, and single elements 32 | of the gradient and Hessian: 33 | 34 | ```python 35 | import quax 36 | import psi4 37 | 38 | molecule = psi4.geometry(""" 39 | 0 1 40 | O 0.0 0.0 0.0 41 | H 0.0 0.0 1.0 42 | H 0.0 1.0 0.0 43 | units bohr 44 | """) 45 | 46 | energy = quax.core.energy(molecule, 'sto-3g', 'hf') 47 | print(energy) 48 | gradient = quax.core.geom_deriv(molecule, 'sto-3g', 'hf', deriv_order=1) 49 | print(gradient) 50 | hessian = quax.core.geom_deriv(molecule, 'sto-3g', 'hf', deriv_order=2) 51 | print(hessian) 52 | 53 | dz1 = quax.core.geom_deriv(molecule, 'sto-3g', 'hf', deriv_order=1, partial=(2,)) 54 | print(dz1) 55 | 56 | dz1_dz2 = quax.core.geom_deriv(molecule, 'sto-3g', 'hf', deriv_order=2, partial=(2,5)) 57 | print(dz1_dz2) 58 | 59 | print('Partial gradient matches gradient element: ', dz1 == gradient[2]) 60 | print('Partial hessian matches hessian element: ', dz1_dz2 == hessian[2,5]) 61 | ``` 62 | 63 | Above, in the `quax.core.geom_deriv` function calls, the `partial` arguments describe the address of the element in the _n_th order derivative 64 | tensor you want to compute. The dimensions of a derivative tensor correspond to the row-wise flattened Cartesian coordinates, with 0-based indexing. 65 | For _N_ Cartesian coordinates, gradient is a size _N_ vector, Hessian a _N_ by _N_ matrix, and cubic and quartic derivative tensors are rank-3 and rank-4 tensors with dimension size _N_. 66 | 67 | Speaking of which, the Quax API currently supports up to 4th-order full-derivatives of energy methods, and up to 6th-order partial derivatives. 
68 | A full quartic derivative tensor at CCSD(T) can be computed like so: 69 | 70 | ```python 71 | import quax 72 | import psi4 73 | 74 | molecule = psi4.geometry(''' 75 | 0 1 76 | H 0.0 0.0 -0.80000000000 77 | H 0.0 0.0 0.80000000000 78 | units bohr 79 | ''') 80 | 81 | quartic = quax.core.geom_deriv(molecule, '6-31g', 'ccsd(t)', deriv_order=4) 82 | ``` 83 | 84 | Perhaps that's too expensive/slow. You can instead compute quartic partial derivatives: 85 | 86 | ```python 87 | import quax 88 | import psi4 89 | 90 | molecule = psi4.geometry(''' 91 | 0 1 92 | H 0.0 0.0 -0.80000000000 93 | H 0.0 0.0 0.80000000000 94 | units bohr 95 | ''') 96 | 97 | dz1_dz1_dz2_dz2 = quax.core.geom_deriv(molecule, '6-31g', 'ccsd(t)', deriv_order=4, partial=(2,2,5,5)) 98 | ``` 99 | 100 | Similar computations can be split across multiple nodes in an embarassingly parallel fashion, and one can take full advantage of symmetry so that only the unique elements are computed. 101 | The full quartic derivative tensor can then be constructed with the results. 102 | 103 | It's important to note that full derivative tensor computations may easily run into memory issues. 104 | For example, the two-electron integrals fourth derivative tensor used in the above computation 105 | for _n_ basis functions and _N_ cartesian coordinates at derivative order _k_ contains _n_4 * _N_k double precision floating point numbers, which requires a great deal of memory. 106 | Not only that, but the regular two-electron integrals, and the first, second, and third-order derivative tensors are also held in memory. 107 | The above computation therefore, from having 4 basis functions, stores 5 arrays associated with the two-electron integrals at run time: 108 | each of shapes (4,4,4,4), (4,4,4,4,6), (4,4,4,4,6,6), (4,4,4,4,6,6,6), (4,4,4,4,6,6,6,6). 109 | These issues also arise in the simulataneous storage of the old and new T1 and T2 amplitudes during coupled cluster iterations. 
110 | Obviously, for large basis sets and molecules, these arrays get very big very fast. 111 | Unless you have impressive computing resources, partial derivatives are recommended for higher order derivatives. 112 | 113 | ### Caveats 114 | The Libint interface is a necessary dependency for Quax. However, compiling Libint for support for very high order 115 | derivatives (5th, 6th) takes a very long time and causes the library size to be very large (sometimes so large it's uncompilable). 116 | We will incrementally roll out improvements which allow user specification for how to handle higher-order integral derivatives. 117 | Contributions and suggestions are welcome. 118 | 119 | Also, we do not recommend computing derivatives of systems with many degenerate orbitals. 120 | The reason for this is because automatically differentiating through eigendecomposition involves denominators of eigenvalue differences, which blow up in the degenerate case. 121 | We cheat our way around this by shifting the eigenspectrum to lift the degeneracy, but this only works for systems with moderate degeneracy. 122 | Workarounds for this are coming soon. 123 | 124 | # Installation Instructions 125 | 126 | ### Anaconda Environment installation instructions 127 | To use Quax, only a few dependencies are needed. We recommend using a clean Anaconda environment: 128 | ``` 129 | conda create -n quax python=3.10 130 | conda activate quax 131 | ``` 132 | 133 | Then install the dependencies into your new environment, all can be installed alongside Psi4: 134 | ``` 135 | conda install psi4 python=3.10 -c conda-forge 136 | ``` 137 | 138 | The Libint interface must be built before installing: 139 | ``` 140 | cd quax/integrals/ 141 | make 142 | cd ../../ 143 | ``` 144 | More can be found below if a custom Libint binary is wanted. 145 | 146 | Finally install Quax! 147 | ``` 148 | python -m pip install . 
149 | ``` 150 | 151 | ### Building the Libint Interface 152 | A [Docker image](https://hub.docker.com/r/ericacmitchell/libint_derivs) has been made for Libint with up to 2nd-order derivatives and maximum angular momentum of 5 for standard integrals, Cartesian-multipole integrals, and F12-type integrals. 153 | 154 | Otherwise, for the Libint interface, you must install those dependencies as well. 155 | ``` 156 | conda install libstdcxx-ng gcc_linux-64 gxx_linux-64 ninja boost eigen3 gmp bzip2 cmake pybind11 157 | ``` 158 | 159 | We note here that the default gcc version (4.8) that comes with `conda install gcc` is not recent enough to successfully compile the Quax-Libint interface. 160 | You must instead use a more modern compiler. To do this in Anaconda, we need to use 161 | `x86_64-conda_cos6-linux-gnu-gcc` as our compiler instead of gcc. 162 | This is available by installing `gcc_linux-64` and `gxx_linux-64`. 163 | Feel free to try other more advanced compilers. gcc >= 7.0 appears to work great. 164 | 165 | ### Building Libint 166 | Libint can be built to support specific maximum angular momentum, different types of integrals, and certain derivative orders. 167 | The following build procedure supports up to _d_ functions and 4th-order derivatives. For more details, 168 | see the [Libint](https://github.com/evaleev/libint) repo. 169 | Note this build takes quite a long time! (on the order of hours to a couple days) 170 | In the future we will look into supplying pre-built libint tarballs by some means.
171 | 172 | ``` 173 | git clone https://github.com/evaleev/libint.git 174 | cd libint 175 | ./autogen.sh 176 | mkdir BUILD 177 | cd BUILD 178 | mkdir PREFIX 179 | ../configure --prefix=/path/to/libint/build/PREFIX --with-max-am=2 --with-opt-am=0 --enable-1body=4 --enable-eri=4 --with-multipole-max-order=0 --enable-eri3=no --enable-eri2=no --enable-g12=no --enable-g12dkh=no --with-pic --enable-static --enable-single-evaltype --enable-generic-code --disable-unrolling 180 | 181 | make export 182 | ``` 183 | 184 | The above will produce a file of the form `libint-*.tgz`, containing your custom Libint library that needs to be compiled. 185 | 186 | ### Compiling Libint 187 | Now, given a Libint tarball which supports the desired maximum angular momentum and derivative order, 188 | we need to unpack the library, `cd` into it, and `mkdir PREFIX` where the headers and static library will be stored. 189 | The position independent code flag is required for Libint to play nice with pybind11. 190 | The `-j4` flag instructs how many processors to use in compilation, and can be adjusted according to your system. The `--target check` runs the Libint test suite; it is not required. 191 | The --target check runs test suite, and finally the install command installs the headers and static library into the PREFIX directory. 192 | ``` 193 | tar -xvf libint_*.tgz 194 | cd libint-*/ 195 | mkdir PREFIX 196 | cmake -GNinja . -DCMAKE_INSTALL_PREFIX=/path/to/libint/PREFIX/ -DCMAKE_POSITION_INDEPENDENT_CODE=ON 197 | cmake --build . -- -j4 198 | cmake --build . --target check 199 | cmake --build . --target install 200 | ``` 201 | 202 | Note that the following cmake command may not find various libraries for the dependencies of Libint. 203 | `cmake -GNinja . 
-DCMAKE_INSTALL_PREFIX=/path/to/libint/PREFIX/ -DCMAKE_POSITION_INDEPENDENT_CODE=ON` 204 | To fix this, you may need to explicitly point to it 205 | `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/libint/dependency/lib/` 206 | and then run the above cmake command. 207 | If using Anaconda, the path is probably in the environment directory `/path/to/envs/quax/lib/`. 208 | 209 | ### Compiling the Libint-Quax interface 210 | Once Libint is installed, the makefile in `quax/integrals/makefile` needs to be edited with your compiler and the proper paths specifying the locations 211 | of headers and libraries for Libint, pybind11, HDF5, and python. 212 | 213 | The `LIBINT_PREFIX` path in the makefile is wherever you installed the headers and the static library `lib/libint2.a`. 214 | All of the required headers and libraries should be discoverable in the Anaconda environment's include and lib paths. 215 | After editing the paths appropriately and setting the CC compiler to `x86_64-conda_cos6-linux-gnu-gcc`, or 216 | if you have a nice modern compiler available, use that. 217 | 218 | Running `make` in the directory `quax/integrals/` to compile the Libint interface. 219 | 220 | ### Citing Quax 221 | If you use Quax in your research, we would appreciate a citation: 222 | ``` 223 | @article{abbott2021, 224 | title={Arbitrary-Order Derivatives of Quantum Chemical Methods via Automatic Differentiation}, 225 | author={Abbott, Adam S and Abbott, Boyi Z and Turney, Justin M and Schaefer III, Henry F}, 226 | journal={The Journal of Physical Chemistry Letters}, 227 | volume={12}, 228 | pages={3232--3239}, 229 | year={2021}, 230 | publisher={ACS Publications} 231 | } 232 | ``` 233 | We also kindly request you give credit to the projects which make up the dependencies of Quax. 
234 | -------------------------------------------------------------------------------- /quax/methods/mp2f12.py: -------------------------------------------------------------------------------- 1 | import jax 2 | jax.config.update("jax_enable_x64", True) 3 | import jax.numpy as jnp 4 | from jax.lax import fori_loop, cond 5 | 6 | from .basis_utils import build_CABS 7 | from .ints import compute_f12_oeints, compute_f12_teints, compute_dipole_ints, compute_quadrupole_ints 8 | from .energy_utils import partial_tei_transformation, cartesian_product 9 | from .mp2 import restricted_mp2 10 | 11 | def restricted_mp2_f12(*args, options, deriv_order=0): 12 | if options['electric_field'] == 1: 13 | efield, geom, basis_set, cabs_set, nelectrons, nfrzn, nuclear_charges, xyz_path = args 14 | fields = (efield,) 15 | mp2_args = efield, geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path 16 | elif options['electric_field'] == 2: 17 | efield_grad, efield, geom, basis_set, cabs_set, nelectrons, nfrzn, nuclear_charges, xyz_path = args 18 | fields = (efield_grad, efield) 19 | mp2_args = efield_grad, efield, geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path 20 | else: 21 | geom, basis_set, cabs_set, nelectrons, nfrzn, nuclear_charges, xyz_path = args 22 | fields = None 23 | mp2_args = (geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path) 24 | 25 | E_mp2, C_obs, eps, G = restricted_mp2(*mp2_args, options=options, deriv_order=deriv_order, return_aux_data=True) 26 | ndocc = nelectrons // 2 27 | ncore = nfrzn // 2 28 | eps_occ, eps_vir = eps[:ndocc], eps[ndocc:] 29 | 30 | print("Running MP2-F12 Computation...") 31 | C_cabs = build_CABS(geom, basis_set, cabs_set, xyz_path, deriv_order, options) 32 | C_mats = (C_obs[:, :ndocc], C_obs, C_cabs) # C_occ, C_obs, C_cabs 33 | 34 | nobs = C_obs.shape[0] 35 | spaces = (ndocc, nobs, C_cabs.shape[0]) # ndocc, nobs, nri 36 | 37 | # Fock 38 | f, fk, k = form_Fock(geom, basis_set, cabs_set, C_mats, spaces, fields, xyz_path, 
deriv_order, options) 39 | 40 | # V Intermediate 41 | V = form_V(geom, basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options)\ 42 | 43 | # X Intermediate 44 | X = form_X(geom, basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options) 45 | 46 | # C Intermediate 47 | C = form_C(geom, basis_set, cabs_set, f[nobs:, ndocc:nobs], C_mats, spaces, xyz_path, deriv_order, options) 48 | 49 | # B Intermediate 50 | B = form_B(geom, basis_set, cabs_set, f, k, fk[:ndocc, :], C_mats, spaces, xyz_path, deriv_order, options) 51 | 52 | D = -1.0 / (eps_occ.reshape(-1, 1, 1, 1) + eps_occ.reshape(-1, 1, 1) - eps_vir.reshape(-1, 1) - eps_vir) 53 | G = jnp.swapaxes(G, 1, 2) 54 | 55 | indices = jnp.asarray(jnp.triu_indices(ndocc)).reshape(2,-1).T 56 | 57 | def loop_energy(idx, f12_corr): 58 | i, j = indices[idx] 59 | kd = cond(i == j, lambda: 1.0, lambda: 2.0) 60 | 61 | D_ij = D[i, j, :, :] 62 | 63 | GD_ij = jnp.einsum('ab,ab->ab', G[i - ncore, j - ncore, :, :], D_ij, optimize='optimal') 64 | V_ij = V[i, j, :, :] - jnp.einsum('klab,ab->kl', C, GD_ij, optimize='optimal') 65 | 66 | V_s = 0.25 * (t_(i, j, i, j) + t_(i, j, j, i)) * kd * (V_ij[i, j] + V_ij[j, i]) 67 | 68 | V_t = 0.25 * cond(i != j, lambda: (t_(i, j, i, j) - t_(i, j, j, i)) 69 | * kd * (V_ij[i, j] - V_ij[j, i]), lambda: 0.0) 70 | 71 | CD_ij = jnp.einsum('mnab,ab->mnab', C, D_ij, optimize='optimal') 72 | B_ij = B - (X * (f[i, i] + f[j, j])) - jnp.einsum('klab,mnab->klmn', C, CD_ij, optimize='optimal') 73 | 74 | B_s = 0.125 * (t_(i, j, i, j) + t_(i, j, j, i)) * kd \ 75 | * (B_ij[i, j, i, j] + B_ij[j, i, i, j]) \ 76 | * (t_(i, j, i, j) + t_(i, j, j, i)) * kd 77 | 78 | B_t = 0.125 * cond(i != j, lambda: (t_(i, j, i, j) - t_(i, j, j, i)) * kd 79 | * (B_ij[i, j, i, j] - B_ij[j, i, i, j]) 80 | * (t_(i, j, i, j) - t_(i, j, j, i)) * kd, 81 | lambda: 0.0) 82 | 83 | f12_corr += kd * (2.0 * V_s + B_s) # Singlet Pair Energy 84 | f12_corr += 3.0 * kd * (2.0 * V_t + B_t) # Triplet Pair Energy 85 | return f12_corr 86 | 
87 | def frzn_core(idx, accu): 88 | accu += ndocc - idx 89 | return accu 90 | 91 | start = fori_loop(0, ncore, frzn_core, 0) 92 | dE_f12 = fori_loop(start, indices.shape[0], loop_energy, 0.0) 93 | 94 | E_s = cabs_singles(f, spaces) 95 | 96 | return E_mp2 + dE_f12 + E_s 97 | 98 | # CABS Singles 99 | def cabs_singles(f, spaces): 100 | ndocc, _, nri = spaces 101 | all_vir = nri - ndocc 102 | 103 | e_ij, C_ij = jnp.linalg.eigh(f[:ndocc, :ndocc]) 104 | e_AB, C_AB = jnp.linalg.eigh(f[ndocc:, ndocc:]) 105 | 106 | f_iA = C_ij.T @ f[:ndocc, ndocc:] @ C_AB 107 | 108 | indices = cartesian_product(jnp.arange(ndocc), jnp.arange(all_vir)) 109 | 110 | def loop_singles(idx, singles): 111 | i, A = indices[idx] 112 | singles += 2 * f_iA[i, A]**2 / (e_ij[i] - e_AB[A]) 113 | return singles 114 | E_s = fori_loop(0, indices.shape[0], loop_singles, 0.0) 115 | 116 | return E_s 117 | 118 | # Fixed Amplitude Ansatz 119 | @jax.jit 120 | def t_(p, q, r, s): 121 | return jnp.select( 122 | [(p == q) & (p == r) & (p == s), (p == r) & (q == s), (p == s) & (q == r)], 123 | [0.5, 0.375, 0.125], 124 | default = jnp.nan 125 | ) 126 | 127 | # One-Electron Integrals 128 | def one_body_mo_computer(geom, bs1, bs2, C1, C2, fields, xyz_path, deriv_order, options): 129 | """ 130 | General one-body MO computer 131 | that computes the AOs and 132 | transforms to MOs 133 | """ 134 | T, V = compute_f12_oeints(geom, bs1, bs2, xyz_path, deriv_order, options, False) 135 | AO = T + V 136 | 137 | if options['electric_field'] == 1: 138 | Mu_XYZ = compute_dipole_ints(geom, bs1, bs2, xyz_path, deriv_order, options) 139 | AO += jnp.einsum('x,xij->ij', fields[0], Mu_XYZ, optimize = 'optimal') 140 | elif options['electric_field'] == 2: 141 | Mu_Th = compute_quadrupole_ints(geom, bs1, bs2, xyz_path, deriv_order, options) 142 | AO += jnp.einsum('x,xij->ij', fields[0], Mu_Th[:3, :, :], optimize = 'optimal') 143 | AO += jnp.einsum('x,xij->ij', fields[1][jnp.triu_indices(3)], Mu_Th[3:, :, :], optimize = 'optimal') 144 | 145 | 
MO = C1.T @ AO @ C2 146 | return MO 147 | 148 | def form_h(geom, basis_set, cabs_set, C_mats, spaces, fields, xyz_path, deriv_order, options): 149 | _, nobs, nri = spaces 150 | _, C_obs, C_cabs = C_mats 151 | 152 | tv = jnp.zeros((nri, nri)) 153 | 154 | mo1 = one_body_mo_computer(geom, basis_set, basis_set, C_obs, C_obs, fields, xyz_path, deriv_order, options) 155 | tv = tv.at[:nobs, :nobs].set(mo1) # 156 | del mo1 157 | 158 | mo2 = one_body_mo_computer(geom, basis_set, cabs_set, C_obs, C_cabs, fields, xyz_path, deriv_order, options) 159 | tv = tv.at[:nobs, nobs:nri].set(mo2) # 160 | tv = tv.at[nobs:nri, :nobs].set(mo2.T) # 161 | del mo2 162 | 163 | mo3 = one_body_mo_computer(geom, cabs_set, cabs_set, C_cabs, C_cabs, fields, xyz_path, deriv_order, options) 164 | tv = tv.at[nobs:nri, nobs:nri].set(mo3) # 165 | del mo3 166 | 167 | return tv 168 | 169 | # Two-Electron Integrals 170 | def two_body_mo_computer(geom, int_type, bs1, bs2, bs3, bs4, C1, C2, C3, C4, xyz_path, deriv_order, options): 171 | """ 172 | General two-body MO computer 173 | that computes the AOs in chem notation, 174 | then transforms to MOs, 175 | and returns the MOs in phys notation 176 | """ 177 | AO = compute_f12_teints(geom, bs1, bs3, bs2, bs4, int_type, xyz_path, deriv_order, options) 178 | MO = partial_tei_transformation(AO, C1, C3, C2, C4) 179 | MO = jnp.swapaxes(MO, 1, 2) 180 | return MO 181 | 182 | def form_J(geom, basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options): 183 | ndocc, nobs, nri = spaces 184 | C_occ, C_obs, C_cabs = C_mats 185 | 186 | eri = jnp.zeros((nri, ndocc, nri, ndocc)) 187 | 188 | mo1 = two_body_mo_computer(geom, "eri", basis_set, basis_set, basis_set, basis_set,\ 189 | C_obs, C_occ, C_obs, C_occ, xyz_path, deriv_order, options) 190 | eri = eri.at[:nobs, :, :nobs, :].set(mo1) # 191 | del mo1 192 | 193 | mo2 = two_body_mo_computer(geom, "eri", cabs_set, basis_set, basis_set, basis_set,\ 194 | C_cabs, C_occ, C_obs, C_occ, xyz_path, deriv_order, options) 195 
| eri = eri.at[nobs:nri, :, :nobs, :].set(mo2) # 196 | eri = eri.at[:nobs, :, nobs:nri, :].set(jnp.transpose(mo2, (2,3,0,1))) # 197 | del mo2 198 | 199 | mo3 = two_body_mo_computer(geom, "eri", cabs_set, basis_set, cabs_set, basis_set,\ 200 | C_cabs, C_occ, C_cabs, C_occ, xyz_path, deriv_order, options) 201 | eri = eri.at[nobs:nri, :, nobs:nri, :].set(mo3) # 202 | del mo3 203 | 204 | return eri 205 | 206 | def form_K(geom, basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options): 207 | ndocc, nobs, nri = spaces 208 | C_occ, C_obs, C_cabs = C_mats 209 | 210 | eri = jnp.empty((nri, ndocc, ndocc, nri)) 211 | 212 | mo1 = two_body_mo_computer(geom, "eri", basis_set, basis_set, basis_set, basis_set,\ 213 | C_obs, C_occ, C_occ, C_obs, xyz_path, deriv_order, options) 214 | eri = eri.at[:nobs, :, :, :nobs].set(mo1) # 215 | del mo1 216 | 217 | mo2 = two_body_mo_computer(geom, "eri", cabs_set, basis_set, basis_set, basis_set,\ 218 | C_cabs, C_occ, C_occ, C_obs, xyz_path, deriv_order, options) 219 | eri = eri.at[nobs:nri, :, :, :nobs].set(mo2) # 220 | eri = eri.at[:nobs, :, :, nobs:nri].set(jnp.transpose(mo2, (3,2,1,0))) # 221 | del mo2 222 | 223 | mo3 = two_body_mo_computer(geom, "eri", cabs_set, basis_set, basis_set, cabs_set,\ 224 | C_cabs, C_occ, C_occ, C_cabs, xyz_path, deriv_order, options) 225 | eri = eri.at[nobs:nri, :, :, nobs:nri].set(mo3) # 226 | del mo3 227 | 228 | return eri 229 | 230 | def form_ooO1(geom, int_type, basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options): 231 | ndocc, nobs, nri = spaces 232 | C_occ, C_obs, C_cabs = C_mats 233 | 234 | eri = jnp.zeros((ndocc, ndocc, nobs, nri)) 235 | 236 | mo1 = two_body_mo_computer(geom, int_type, basis_set, basis_set, basis_set, basis_set,\ 237 | C_occ, C_occ, C_obs, C_obs, xyz_path, deriv_order, options) 238 | eri = eri.at[:, :, :, :nobs].set(mo1) # 239 | del mo1 240 | 241 | mo2 = two_body_mo_computer(geom, int_type, basis_set, basis_set, basis_set, cabs_set,\ 242 | C_occ, C_occ, C_obs, 
C_cabs, xyz_path, deriv_order, options) 243 | eri = eri.at[:, :, :, nobs:].set(mo2) # 244 | del mo2 245 | 246 | return eri 247 | 248 | def form_F(geom, basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options): 249 | ndocc, nobs, nri = spaces 250 | C_occ, C_obs, C_cabs = C_mats 251 | 252 | f12 = jnp.zeros((ndocc, ndocc, nri, nri)) 253 | 254 | mo1 = two_body_mo_computer(geom, "f12", basis_set, basis_set, basis_set, basis_set,\ 255 | C_occ, C_occ, C_obs, C_obs, xyz_path, deriv_order, options) 256 | f12 = f12.at[:, :, :nobs, :nobs].set(mo1) # 257 | del mo1 258 | 259 | mo2 = two_body_mo_computer(geom, "f12", basis_set, basis_set, basis_set, cabs_set,\ 260 | C_occ, C_occ, C_obs, C_cabs, xyz_path, deriv_order, options) 261 | f12 = f12.at[:, :, :nobs, nobs:].set(mo2) # 262 | f12 = f12.at[:, :, nobs:, :nobs].set(jnp.transpose(mo2, (1,0,3,2))) # 263 | del mo2 264 | 265 | mo3 = two_body_mo_computer(geom, "f12", basis_set, basis_set, cabs_set, cabs_set,\ 266 | C_occ, C_occ, C_cabs, C_cabs, xyz_path, deriv_order, options) 267 | f12 = f12.at[:, :, nobs:, nobs:].set(mo3) # 268 | del mo3 269 | 270 | return f12 271 | 272 | def form_F2(geom, basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options): 273 | ndocc, nobs, nri = spaces 274 | C_occ, C_obs, C_cabs = C_mats 275 | 276 | f12_squared = jnp.zeros((ndocc, ndocc, ndocc, nri)) 277 | 278 | mo1 = two_body_mo_computer(geom, "f12_squared", basis_set, basis_set, basis_set, basis_set,\ 279 | C_occ, C_occ, C_occ, C_obs, xyz_path, deriv_order, options) 280 | f12_squared = f12_squared.at[:, :, :, :nobs].set(mo1) # 281 | del mo1 282 | 283 | mo2 = two_body_mo_computer(geom, "f12_squared", basis_set, basis_set, basis_set, cabs_set,\ 284 | C_occ, C_occ, C_occ, C_cabs, xyz_path, deriv_order, options) 285 | f12_squared = f12_squared.at[:, :, :, nobs:].set(mo2) # 286 | del mo2 287 | 288 | return f12_squared 289 | 290 | # Fock 291 | def form_Fock(geom, basis_set, cabs_set, C_mats, spaces, fields, xyz_path, deriv_order, 
# F12 Intermediates
def form_V(geom, basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options):
    """
    V intermediate: <ij|f12*g12|kl> minus the RI-projected products of the
    ERI and F12 tensors over the occ x CABS and OBS x OBS sub-blocks.
    """
    C_occ, _, _ = C_mats
    ndocc, nobs, _ = spaces

    FG = two_body_mo_computer(geom, "f12g12", basis_set, basis_set, basis_set, basis_set,
                              C_occ, C_occ, C_occ, C_occ, xyz_path, deriv_order, options)
    G = form_ooO1(geom, "eri", basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options)
    F = form_ooO1(geom, "f12", basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options)

    # occ x CABS projection, plus its pairwise-swapped mirror
    t1 = jnp.einsum('ijmy,klmy->ijkl', G[:, :, :ndocc, nobs:], F[:, :, :ndocc, nobs:], optimize='optimal')
    t2 = jnp.transpose(t1, (1, 0, 3, 2))
    # full OBS x OBS projection
    t3 = jnp.einsum('ijrs,klrs->ijkl', G[:, :, :nobs, :nobs], F[:, :, :nobs, :nobs], optimize='optimal')

    return FG - t1 - t2 - t3

def form_X(geom, basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options):
    """
    X intermediate: <ij|f12^2|kl> minus the RI-projected F12*F12 products
    (same projection pattern as form_V).
    """
    C_occ, _, _ = C_mats
    ndocc, nobs, _ = spaces

    F2 = two_body_mo_computer(geom, "f12_squared", basis_set, basis_set, basis_set, basis_set,
                              C_occ, C_occ, C_occ, C_occ, xyz_path, deriv_order, options)
    F = form_ooO1(geom, "f12", basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options)

    t1 = jnp.einsum('ijmy,klmy->ijkl', F[:, :, :ndocc, nobs:], F[:, :, :ndocc, nobs:], optimize='optimal')
    t2 = jnp.transpose(t1, (1, 0, 3, 2))
    t3 = jnp.einsum('ijrs,klrs->ijkl', F[:, :, :nobs, :nobs], F[:, :, :nobs, :nobs], optimize='optimal')

    return F2 - t1 - t2 - t3

def form_C(geom, basis_set, cabs_set, f_cv, C_mats, spaces, xyz_path, deriv_order, options):
    """
    C intermediate: contract the virt x CABS slice of the F12 tensor with the
    Fock coupling block f_cv, then symmetrize over pair exchange.
    """
    C_occ, C_obs, C_cabs = C_mats
    ndocc, nobs, _ = spaces

    F = two_body_mo_computer(geom, "f12", basis_set, basis_set, basis_set, cabs_set,
                             C_occ, C_occ, C_obs, C_cabs, xyz_path, deriv_order, options)

    klab = jnp.einsum('klax,xb->klab', F[:, :, ndocc:nobs, :], f_cv, optimize='optimal')

    return klab + jnp.transpose(klab, (1, 0, 3, 2))

def form_B(geom, basis_set, cabs_set, f, k, fk_o1, C_mats, spaces, xyz_path, deriv_order, options):
    """
    B intermediate: double-commutator integrals plus Fock/exchange-dressed F12
    contractions, symmetrized twice (pairwise swap, then bra-ket swap).
    """
    C_occ, C_obs, C_cabs = C_mats
    ndocc, nobs, _ = spaces

    Uf = two_body_mo_computer(geom, "f12_double_commutator", basis_set, basis_set, basis_set, basis_set,
                              C_occ, C_occ, C_occ, C_occ, xyz_path, deriv_order, options)
    F2 = form_F2(geom, basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options)
    F = form_F(geom, basis_set, cabs_set, C_mats, spaces, xyz_path, deriv_order, options)

    # Term 2
    acc = jnp.einsum('nmlP,kP->nmlk', F2, fk_o1)

    # Term 3
    acc -= jnp.einsum('nmQP,PR,lkQR->nmlk', F, k, F, optimize='optimal')

    # Term 4
    acc -= jnp.einsum('nmjP,PR,lkjR->nmlk', F[:, :, :ndocc, :], f, F[:, :, :ndocc, :], optimize='optimal')

    # Term 5
    acc += jnp.einsum('nmyi,ij,lkyj->nmlk', F[:, :, nobs:, :ndocc], f[:ndocc, :ndocc],
                      F[:, :, nobs:, :ndocc], optimize='optimal')

    # Term 6
    acc -= jnp.einsum('nmbp,pr,lkbr->nmlk', F[:, :, ndocc:nobs, :nobs], f[:nobs, :nobs],
                      F[:, :, ndocc:nobs, :nobs], optimize='optimal')

    # Term 7
    acc -= 2.0 * jnp.einsum('nmyi,iP,lkyP->nmlk', F[:, :, nobs:, :], f[:, :ndocc],
                            F[:, :, nobs:, :ndocc], optimize='optimal')

    # Term 8
    acc -= 2.0 * jnp.einsum('nmbx,xq,lkbq->nmlk', F[:, :, ndocc:nobs, :nobs], f[:nobs, nobs:],
                            F[:, :, ndocc:nobs, nobs:], optimize='optimal')

    B_nosymm = Uf + acc + jnp.transpose(acc, (1, 0, 3, 2))  # nmlk -> mnkl

    return 0.5 * (B_nosymm + jnp.transpose(B_nosymm, (2, 3, 0, 1)))  # mnkl + klmn

# ---------------------------------------------------------------------------
# quax/methods/ints.py
# ---------------------------------------------------------------------------
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jacfwd
import numpy as np
import h5py
import psi4
import os

# Check for Libint interface
from ..integrals import TEI
from ..integrals import OEI
from ..integrals import libint_interface


def compute_integrals(geom, basis_set, xyz_path, deriv_order, options):
    """Compute S, T, V, G integrals, either in memory ('core') or via disk caching."""
    # Load integral algo, decides to compute integrals in memory or use disk
    algo = options['integral_algo']
    basis_name = basis_set.name()
    libint_interface.initialize(xyz_path, basis_name, basis_name, basis_name, basis_name, options['ints_tolerance'])

    if algo == 'libint_disk':
        # Check disk for currently existing integral derivatives
        check_oei = check_oei_disk("all", basis_set, basis_set, deriv_order)
        check_tei = check_tei_disk("eri", basis_set, basis_set, basis_set, basis_set, deriv_order)

        oei_obj = OEI(basis_set, basis_set, xyz_path, deriv_order, 'disk')
        tei_obj = TEI(basis_set, basis_set, basis_set, basis_set, xyz_path, deriv_order, options, 'disk')

        # If the cached derivatives are not usable, regenerate them first
        if not check_oei:
            libint_interface.oei_deriv_disk(deriv_order)
        S = oei_obj.overlap(geom)
        T = oei_obj.kinetic(geom)
        V = oei_obj.potential(geom)

        if not check_tei:
            libint_interface.compute_2e_deriv_disk("eri", 0., deriv_order)
        G = tei_obj.eri(geom)
    else:
        # Precompute TEI derivatives in memory
        oei_obj = OEI(basis_set, basis_set, xyz_path, deriv_order, 'core')
        tei_obj = TEI(basis_set, basis_set, basis_set, basis_set, xyz_path, deriv_order, options, 'core')
        S = oei_obj.overlap(geom)
        T = oei_obj.kinetic(geom)
        V = oei_obj.potential(geom)
        G = tei_obj.eri(geom)

    libint_interface.finalize()
    return S, T, V, G
oei_obj.kinetic(geom) 38 | V = oei_obj.potential(geom) 39 | 40 | if check_tei: 41 | G = tei_obj.eri(geom) 42 | else: 43 | libint_interface.compute_2e_deriv_disk("eri", 0., deriv_order) 44 | G = tei_obj.eri(geom) 45 | 46 | else: 47 | # Precompute TEI derivatives 48 | oei_obj = OEI(basis_set, basis_set, xyz_path, deriv_order, 'core') 49 | tei_obj = TEI(basis_set, basis_set, basis_set, basis_set, xyz_path, deriv_order, options, 'core') 50 | # Compute integrals 51 | S = oei_obj.overlap(geom) 52 | T = oei_obj.kinetic(geom) 53 | V = oei_obj.potential(geom) 54 | G = tei_obj.eri(geom) 55 | 56 | libint_interface.finalize() 57 | return S, T, V, G 58 | 59 | def compute_dipole_ints(geom, basis1, basis2, xyz_path, deriv_order, options): 60 | # Load integral algo, decides to compute integrals in memory or use disk 61 | algo = options['integral_algo'] 62 | basis1_name = basis1.name() 63 | basis2_name = basis2.name() 64 | libint_interface.initialize(xyz_path, basis1_name, basis2_name, basis1_name, basis2_name, options['ints_tolerance']) 65 | 66 | if algo == 'libint_disk': 67 | # Check disk for currently existing integral derivatives 68 | check_multipole = check_multipole_disk('dipole', basis1, basis2, deriv_order) 69 | 70 | oei_obj = OEI(basis1, basis2, xyz_path, deriv_order, 'disk') 71 | # If disk integral derivs are right, nothing to do 72 | if check_multipole: 73 | Mu_ = oei_obj.dipole(geom) 74 | else: 75 | with open(xyz_path, 'r') as f: 76 | tmp = f.read() 77 | com = psi4.core.Molecule.from_string(tmp, 'xyz+').center_of_mass() 78 | com = list([com[0], com[1], com[2]]) 79 | 80 | libint_interface.compute_dipole_deriv_disk(com, deriv_order) 81 | Mu_ = oei_obj.dipole(geom) 82 | else: 83 | # Precompute TEI derivatives 84 | oei_obj = OEI(basis1, basis2, xyz_path, deriv_order, 'dipole') 85 | # Compute integrals 86 | Mu_ = oei_obj.dipole(geom) 87 | 88 | libint_interface.finalize() 89 | return Mu_ 90 | 91 | def compute_quadrupole_ints(geom, basis1, basis2, xyz_path, deriv_order, 
options): 92 | # Load integral algo, decides to compute integrals in memory or use disk 93 | algo = options['integral_algo'] 94 | basis1_name = basis1.name() 95 | basis2_name = basis2.name() 96 | libint_interface.initialize(xyz_path, basis1_name, basis2_name, basis1_name, basis2_name, options['ints_tolerance']) 97 | 98 | if algo == 'libint_disk': 99 | # Check disk for currently existing integral derivatives 100 | check_multipole = check_multipole_disk('quadrupole', basis1, basis2, deriv_order) 101 | 102 | oei_obj = OEI(basis1, basis2, xyz_path, deriv_order, 'disk') 103 | # If disk integral derivs are right, nothing to do 104 | if check_multipole: 105 | Mu_Th = oei_obj.quadrupole(geom) 106 | else: 107 | libint_interface.compute_quadrupole_deriv_disk(deriv_order) 108 | Mu_Th = oei_obj.quadrupole(geom) 109 | else: 110 | # Precompute TEI derivatives 111 | oei_obj = OEI(basis1, basis2, xyz_path, deriv_order, 'dipole') 112 | # Compute integrals 113 | Mu_Th = oei_obj.quadrupole(geom) 114 | 115 | libint_interface.finalize() 116 | return Mu_Th 117 | 118 | def compute_f12_oeints(geom, basis1, basis2, xyz_path, deriv_order, options, cabs): 119 | # Load integral algo, decides to compute integrals in memory or use disk 120 | algo = options['integral_algo'] 121 | basis1_name = basis1.name() 122 | basis2_name = basis2.name() 123 | libint_interface.initialize(xyz_path, basis1_name, basis2_name, basis1_name, basis2_name, options['ints_tolerance']) 124 | 125 | if cabs: 126 | if algo == 'libint_disk': 127 | # Check disk for currently existing integral derivatives 128 | check = check_oei_disk("overlap", basis1, basis2, deriv_order) 129 | 130 | oei_obj = OEI(basis1, basis2, xyz_path, deriv_order, 'disk') 131 | # If disk integral derivs are right, nothing to do 132 | if check: 133 | S = oei_obj.overlap(geom) 134 | else: 135 | libint_interface.compute_1e_deriv_disk("overlap", deriv_order) 136 | S = oei_obj.overlap(geom) 137 | 138 | else: 139 | # Precompute OEI derivatives 140 | oei_obj = 
OEI(basis1, basis2, xyz_path, deriv_order, 'f12') 141 | # Compute integrals 142 | S = oei_obj.overlap(geom) 143 | 144 | libint_interface.finalize() 145 | return S 146 | 147 | else: 148 | if algo == 'libint_disk': 149 | # Check disk for currently existing integral derivatives 150 | check_T = check_oei_disk("kinetic", basis1, basis2, deriv_order) 151 | check_V = check_oei_disk("potential", basis1, basis2, deriv_order) 152 | 153 | oei_obj = OEI(basis1, basis2, xyz_path, deriv_order, 'disk') 154 | # If disk integral derivs are right, nothing to do 155 | if check_T: 156 | T = oei_obj.kinetic(geom) 157 | else: 158 | libint_interface.compute_1e_deriv_disk("kinetic",deriv_order) 159 | T = oei_obj.kinetic(geom) 160 | 161 | if check_V: 162 | V = oei_obj.potential(geom) 163 | else: 164 | libint_interface.compute_1e_deriv_disk("potential", deriv_order) 165 | V = oei_obj.potential(geom) 166 | 167 | else: 168 | # Precompute OEI derivatives 169 | oei_obj = OEI(basis1, basis2, xyz_path, deriv_order, 'f12') 170 | # Compute integrals 171 | T = oei_obj.kinetic(geom) 172 | V = oei_obj.potential(geom) 173 | 174 | libint_interface.finalize() 175 | return T, V 176 | 177 | def compute_f12_teints(geom, basis1, basis2, basis3, basis4, int_type, xyz_path, deriv_order, options): 178 | # Load integral algo, decides to compute integrals in memory or use disk 179 | algo = options['integral_algo'] 180 | beta = options['beta'] 181 | basis1_name = basis1.name() 182 | basis2_name = basis2.name() 183 | basis3_name = basis3.name() 184 | basis4_name = basis4.name() 185 | libint_interface.initialize(xyz_path, basis1_name, basis2_name, basis3_name, basis4_name, options['ints_tolerance']) 186 | 187 | if algo == 'libint_disk': 188 | # Check disk for currently existing integral derivatives 189 | check = check_tei_disk(int_type, basis1, basis2, basis3, basis4, deriv_order) 190 | 191 | tei_obj = TEI(basis1, basis2, basis3, basis4, xyz_path, deriv_order, options, 'disk') 192 | # If disk integral derivs are 
right, nothing to do 193 | if check: 194 | match int_type: 195 | case "f12": 196 | F = tei_obj.f12(geom, beta) 197 | case "f12_squared": 198 | F = tei_obj.f12_squared(geom, beta) 199 | case "f12g12": 200 | F = tei_obj.f12g12(geom, beta) 201 | case "f12_double_commutator": 202 | F = tei_obj.f12_double_commutator(geom, beta) 203 | case "eri": 204 | F = tei_obj.eri(geom) 205 | else: 206 | match int_type: 207 | case "f12": 208 | libint_interface.compute_2e_deriv_disk(int_type, beta, deriv_order) 209 | F = tei_obj.f12(geom, beta) 210 | case "f12_squared": 211 | libint_interface.compute_2e_deriv_disk(int_type, beta, deriv_order) 212 | F = tei_obj.f12_squared(geom, beta) 213 | case "f12g12": 214 | libint_interface.compute_2e_deriv_disk(int_type, beta, deriv_order) 215 | F = tei_obj.f12g12(geom, beta) 216 | case "f12_double_commutator": 217 | libint_interface.compute_2e_deriv_disk(int_type, beta, deriv_order) 218 | F = tei_obj.f12_double_commutator(geom, beta) 219 | case "eri": 220 | libint_interface.compute_2e_deriv_disk(int_type, 0., deriv_order) 221 | F = tei_obj.eri(geom) 222 | 223 | else: 224 | # Precompute TEI derivatives 225 | tei_obj = TEI(basis1, basis2, basis3, basis4, xyz_path, deriv_order, options, 'f12') 226 | # Compute integrals 227 | match int_type: 228 | case "f12": 229 | F = tei_obj.f12(geom, beta) 230 | case "f12_squared": 231 | F = tei_obj.f12_squared(geom, beta) 232 | case "f12g12": 233 | F = tei_obj.f12g12(geom, beta) 234 | case "f12_double_commutator": 235 | F = tei_obj.f12_double_commutator(geom, beta) 236 | case "eri": 237 | F = tei_obj.eri(geom) 238 | 239 | libint_interface.finalize() 240 | return F 241 | 242 | def check_oei_disk(int_type, basis1, basis2, deriv_order, address=None): 243 | # Check OEI's in compute_integrals 244 | correct_int_derivs = False 245 | correct_nbf1 = correct_nbf2 = correct_deriv_order = False 246 | 247 | if ((os.path.exists("oei_derivs.h5"))): 248 | print("Found currently existing one-electron integral derivatives in your 
working directory. Trying to use them.") 249 | oeifile = h5py.File('oei_derivs.h5', 'r') 250 | nbf1 = basis1.nbf() 251 | nbf2 = basis2.nbf() 252 | 253 | if int_type == "all": 254 | oei_name = ["overlap_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order),\ 255 | "kinetic_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order),\ 256 | "potential_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order)] 257 | else: 258 | oei_name = int_type + "_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order) 259 | 260 | for name in list(oeifile.keys()): 261 | if name in oei_name: 262 | correct_nbf1 = oeifile[name].shape[0] == nbf1 263 | correct_nbf2 = oeifile[name].shape[1] == nbf2 264 | correct_deriv_order = True 265 | oeifile.close() 266 | 267 | correct_int_derivs = correct_deriv_order and correct_nbf1 and correct_nbf2 268 | 269 | if correct_int_derivs: 270 | print("Integral derivatives appear to be correct. Avoiding recomputation.") 271 | return correct_int_derivs 272 | 273 | """ # TODO flesh out this logic for determining if partials file contains all integrals needed 274 | # for particular address 275 | elif (os.path.exists("oei_partials.h5")): 276 | print("Found currently existing partial oei derivatives in working directory. 
Assuming they are correct.") 277 | oeifile = h5py.File('oei_partials.h5', 'r') 278 | with open(xyz_path, 'r') as f: 279 | nbf1 = basis1.nbf() 280 | nbf2 = basis2.nbf() 281 | # Check if there are `deriv_order` datasets in the eri file 282 | correct_deriv_order = len(oeifile) == deriv_order 283 | # Check nbf dimension of integral arrays 284 | sample_dataset_name = list(oeifile.keys())[0] 285 | correct_nbf1 = oeifile[sample_dataset_name].shape[0] == nbf1 286 | correct_nbf2 = oeifile[sample_dataset_name].shape[1] == nbf2 287 | oeifile.close() 288 | correct_int_derivs = correct_deriv_order and correct_nbf1 and correct_nbf2 """ 289 | 290 | def check_multipole_disk(int_type, basis1, basis2, deriv_order, address=None): 291 | # Check OEI's in compute_integrals 292 | correct_int_derivs = False 293 | correct_nbf1 = correct_nbf2 = correct_deriv_order = False 294 | 295 | if ((os.path.exists(int_type + "_derivs.h5"))): 296 | print("Found currently existing multipole integral derivatives in your working directory. 
Trying to use them.") 297 | oeifile = h5py.File(int_type + '_derivs.h5', 'r') 298 | nbf1 = basis1.nbf() 299 | nbf2 = basis2.nbf() 300 | 301 | if int_type == "dipole": 302 | oei_name = ["mu_x_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order), 303 | "mu_y_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order), 304 | "mu_z_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order)] 305 | elif int_type == "quadrupole": 306 | oei_name = ["mu_x_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order), 307 | "mu_y_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order), 308 | "mu_z_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order), 309 | "th_xx_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order), 310 | "th_xy_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order), 311 | "th_xz_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order), 312 | "th_yy_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order), 313 | "th_yz_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order), 314 | "th_zz_" + str(nbf1) + "_" + str(nbf2) + "_deriv" + str(deriv_order)] 315 | else: 316 | raise Exception("Integral type not recognized.") 317 | 318 | for name in list(oeifile.keys()): 319 | if name in oei_name: 320 | correct_nbf1 = oeifile[name].shape[0] == nbf1 321 | correct_nbf2 = oeifile[name].shape[1] == nbf2 322 | correct_deriv_order = True 323 | oeifile.close() 324 | 325 | correct_int_derivs = correct_deriv_order and correct_nbf1 and correct_nbf2 326 | 327 | if correct_int_derivs: 328 | print("Integral derivatives appear to be correct. 
Avoiding recomputation.") 329 | return correct_int_derivs 330 | 331 | def check_tei_disk(int_type, basis1, basis2, basis3, basis4, deriv_order, address=None): 332 | # Check TEI's in compute_integrals 333 | correct_int_derivs = False 334 | correct_nbf1 = correct_nbf2 = correct_nbf3 = correct_nbf4 = correct_deriv_order = False 335 | 336 | if ((os.path.exists(int_type + "_derivs.h5"))): 337 | print("Found currently existing " + int_type + " integral derivatives in your working directory. Trying to use them.") 338 | erifile = h5py.File(int_type + '_derivs.h5', 'r') 339 | nbf1 = basis1.nbf() 340 | nbf2 = basis2.nbf() 341 | nbf3 = basis3.nbf() 342 | nbf4 = basis4.nbf() 343 | 344 | tei_name = int_type + "_" + str(nbf1) + "_" + str(nbf2)\ 345 | + "_" + str(nbf3) + "_" + str(nbf4) + "_deriv" + str(deriv_order) 346 | 347 | # Check nbf dimension of integral arrays 348 | for name in list(erifile.keys()): 349 | if name in tei_name: 350 | correct_nbf1 = erifile[name].shape[0] == nbf1 351 | correct_nbf2 = erifile[name].shape[1] == nbf2 352 | correct_nbf3 = erifile[name].shape[2] == nbf3 353 | correct_nbf4 = erifile[name].shape[3] == nbf4 354 | correct_deriv_order = True 355 | erifile.close() 356 | correct_int_derivs = correct_deriv_order and correct_nbf1 and correct_nbf2 and correct_nbf3 and correct_nbf4 357 | 358 | if correct_int_derivs: 359 | print("Integral derivatives appear to be correct. Avoiding recomputation.") 360 | return correct_int_derivs 361 | 362 | """ # TODO flesh out this logic for determining if partials file contains all integrals needed 363 | # for particular address 364 | elif ((os.path.exists("eri_partials.h5"))): 365 | print("Found currently existing partial tei derivatives in working directory. 
Assuming they are correct.") 366 | erifile = h5py.File('eri_partials.h5', 'r') 367 | nbf1 = basis1.nbf() 368 | nbf2 = basis2.nbf() 369 | nbf3 = basis3.nbf() 370 | nbf4 = basis4.nbf() 371 | sample_dataset_name = list(erifile.keys())[0] 372 | correct_nbf1 = erifile[sample_dataset_name].shape[0] == nbf1 373 | correct_nbf2 = erifile[sample_dataset_name].shape[1] == nbf2 374 | correct_nbf3 = erifile[sample_dataset_name].shape[2] == nbf3 375 | correct_nbf4 = erifile[sample_dataset_name].shape[3] == nbf4 376 | erifile.close() 377 | correct_int_derivs = correct_deriv_order and correct_nbf1 and correct_nbf2 and correct_nbf3 and correct_nbf4 378 | if correct_int_derivs: 379 | print("Integral derivatives appear to be correct. Avoiding recomputation.") 380 | return correct_int_derivs 381 | """ 382 | -------------------------------------------------------------------------------- /quax/integrals/tei.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import h5py 5 | import os 6 | import psi4 7 | from . 
import libint_interface 8 | from ..utils import get_deriv_vec_idx, how_many_derivs 9 | 10 | jax.config.update("jax_enable_x64", True) 11 | 12 | class TEI(object): 13 | 14 | def __init__(self, basis1, basis2, basis3, basis4, xyz_path, max_deriv_order, options, mode): 15 | with open(xyz_path, 'r') as f: 16 | tmp = f.read() 17 | molecule = psi4.core.Molecule.from_string(tmp, 'xyz+') 18 | natoms = molecule.natom() 19 | 20 | nbf1 = basis1.nbf() 21 | nbf2 = basis2.nbf() 22 | nbf3 = basis3.nbf() 23 | nbf4 = basis4.nbf() 24 | 25 | if mode == 'core' and max_deriv_order > 0: 26 | # A list of ERI derivative tensors, containing only unique elements 27 | # corresponding to upper hypertriangle (since derivative tensors are symmetric) 28 | # Length of tuple is maximum deriv order, each array is (upper triangle derivatives,nbf,nbf,nbf,nbf) 29 | # Then when JAX calls JVP, read appropriate slice 30 | self.eri_derivatives = [] 31 | for i in range(max_deriv_order): 32 | n_unique_derivs = how_many_derivs(natoms, i + 1) 33 | eri_deriv = libint_interface.eri_deriv_core(i + 1).reshape(n_unique_derivs, nbf1, nbf2, nbf3, nbf4) 34 | self.eri_derivatives.append(eri_deriv) 35 | 36 | self.mode = mode 37 | self.nbf1 = nbf1 38 | self.nbf2 = nbf2 39 | self.nbf3 = nbf3 40 | self.nbf4 = nbf4 41 | 42 | # Create new JAX primitive for TEI evaluation 43 | self.eri_p = jax.core.Primitive("eri") 44 | self.eri_deriv_p = jax.core.Primitive("eri_deriv") 45 | self.f12_p = jax.core.Primitive("f12") 46 | self.f12_deriv_p = jax.core.Primitive("f12_deriv") 47 | self.f12_squared_p = jax.core.Primitive("f12_squared") 48 | self.f12_squared_deriv_p = jax.core.Primitive("f12_squared_deriv") 49 | self.f12g12_p = jax.core.Primitive("f12g12") 50 | self.f12g12_deriv_p = jax.core.Primitive("f12g12_deriv") 51 | self.f12_double_commutator_p = jax.core.Primitive("f12_double_commutator") 52 | self.f12_double_commutator_deriv_p = jax.core.Primitive("f12_double_commutator_deriv") 53 | 54 | # Register primitive evaluation rules 
55 | self.eri_p.def_impl(self.eri_impl) 56 | self.eri_deriv_p.def_impl(self.eri_deriv_impl) 57 | self.f12_p.def_impl(self.f12_impl) 58 | self.f12_deriv_p.def_impl(self.f12_deriv_impl) 59 | self.f12_squared_p.def_impl(self.f12_squared_impl) 60 | self.f12_squared_deriv_p.def_impl(self.f12_squared_deriv_impl) 61 | self.f12g12_p.def_impl(self.f12g12_impl) 62 | self.f12g12_deriv_p.def_impl(self.f12g12_deriv_impl) 63 | self.f12_double_commutator_p.def_impl(self.f12_double_commutator_impl) 64 | self.f12_double_commutator_deriv_p.def_impl(self.f12_double_commutator_deriv_impl) 65 | 66 | # Register the JVP rules with JAX 67 | jax.interpreters.ad.primitive_jvps[self.eri_p] = self.eri_jvp 68 | jax.interpreters.ad.primitive_jvps[self.eri_deriv_p] = self.eri_deriv_jvp 69 | jax.interpreters.ad.primitive_jvps[self.f12_p] = self.f12_jvp 70 | jax.interpreters.ad.primitive_jvps[self.f12_deriv_p] = self.f12_deriv_jvp 71 | jax.interpreters.ad.primitive_jvps[self.f12_squared_p] = self.f12_squared_jvp 72 | jax.interpreters.ad.primitive_jvps[self.f12_squared_deriv_p] = self.f12_squared_deriv_jvp 73 | jax.interpreters.ad.primitive_jvps[self.f12g12_p] = self.f12g12_jvp 74 | jax.interpreters.ad.primitive_jvps[self.f12g12_deriv_p] = self.f12g12_deriv_jvp 75 | jax.interpreters.ad.primitive_jvps[self.f12_double_commutator_p] = self.f12_double_commutator_jvp 76 | jax.interpreters.ad.primitive_jvps[self.f12_double_commutator_deriv_p] = self.f12_double_commutator_deriv_jvp 77 | 78 | # Register tei_deriv batching rule with JAX 79 | jax.interpreters.batching.primitive_batchers[self.eri_deriv_p] = self.eri_deriv_batch 80 | jax.interpreters.batching.primitive_batchers[self.f12_deriv_p] = self.f12_deriv_batch 81 | jax.interpreters.batching.primitive_batchers[self.f12_squared_deriv_p] = self.f12_squared_deriv_batch 82 | jax.interpreters.batching.primitive_batchers[self.f12g12_deriv_p] = self.f12g12_deriv_batch 83 | jax.interpreters.batching.primitive_batchers[self.f12_double_commutator_deriv_p] = 
self.f12_double_commutator_deriv_batch 84 | 85 | # Create functions to call primitives 86 | def eri(self, geom): 87 | return self.eri_p.bind(geom) 88 | 89 | def eri_deriv(self, geom, deriv_vec): 90 | return self.eri_deriv_p.bind(geom, deriv_vec) 91 | 92 | def f12(self, geom, beta): 93 | return self.f12_p.bind(geom, beta) 94 | 95 | def f12_deriv(self, geom, beta, deriv_vec): 96 | return self.f12_deriv_p.bind(geom, beta, deriv_vec) 97 | 98 | def f12_squared(self, geom, beta): 99 | return self.f12_squared_p.bind(geom, beta) 100 | 101 | def f12_squared_deriv(self, geom, beta, deriv_vec): 102 | return self.f12_squared_deriv_p.bind(geom, beta, deriv_vec) 103 | 104 | def f12g12(self, geom, beta): 105 | return self.f12g12_p.bind(geom, beta) 106 | 107 | def f12g12_deriv(self, geom, beta, deriv_vec): 108 | return self.f12g12_deriv_p.bind(geom, beta, deriv_vec) 109 | 110 | def f12_double_commutator(self, geom, beta): 111 | return self.f12_double_commutator_p.bind(geom, beta) 112 | 113 | def f12_double_commutator_deriv(self, geom, beta, deriv_vec): 114 | return self.f12_double_commutator_deriv_p.bind(geom, beta, deriv_vec) 115 | 116 | # Create primitive evaluation rules 117 | def eri_impl(self, geom): 118 | G = libint_interface.compute_2e_int("eri", 0.) 
119 | G = G.reshape(self.nbf1, self.nbf2, self.nbf3, self.nbf4) 120 | return jnp.asarray(G) 121 | 122 | def f12_impl(self, geom, beta): 123 | F = libint_interface.compute_2e_int("f12", beta) 124 | F = F.reshape(self.nbf1, self.nbf2, self.nbf3, self.nbf4) 125 | return jnp.asarray(F) 126 | 127 | def f12_squared_impl(self, geom, beta): 128 | F = libint_interface.compute_2e_int("f12_squared", beta) 129 | F = F.reshape(self.nbf1, self.nbf2, self.nbf3, self.nbf4) 130 | return jnp.asarray(F) 131 | 132 | def f12g12_impl(self, geom, beta): 133 | F = libint_interface.compute_2e_int("f12g12", beta) 134 | F = F.reshape(self.nbf1, self.nbf2, self.nbf3, self.nbf4) 135 | return jnp.asarray(F) 136 | 137 | def f12_double_commutator_impl(self, geom, beta): 138 | F = libint_interface.compute_2e_int("f12_double_commutator", beta) 139 | F = F.reshape(self.nbf1, self.nbf2, self.nbf3, self.nbf4) 140 | return jnp.asarray(F) 141 | 142 | def eri_deriv_impl(self, geom, deriv_vec): 143 | deriv_vec = np.asarray(deriv_vec, int) 144 | deriv_order = np.sum(deriv_vec) 145 | idx = get_deriv_vec_idx(deriv_vec) 146 | 147 | # Use eri derivatives in memory 148 | if self.mode == 'core': 149 | G = self.eri_derivatives[deriv_order-1][idx,:,:,:,:] 150 | return jnp.asarray(G) 151 | 152 | if self.mode == 'f12': 153 | G = libint_interface.compute_2e_deriv("eri", 0., deriv_vec) 154 | return jnp.asarray(G).reshape(self.nbf1, self.nbf2, self.nbf3, self.nbf4) 155 | 156 | # Read from disk 157 | elif self.mode == 'disk': 158 | # By default, look for full derivative tensor file with datasets named (type)_deriv(order) 159 | if os.path.exists("eri_derivs.h5"): 160 | file_name = "eri_derivs.h5" 161 | dataset_name = "eri_" + str(self.nbf1) + "_" + str(self.nbf2)\ 162 | + "_" + str(self.nbf3) + "_" + str(self.nbf4)\ 163 | + "_deriv" + str(deriv_order) 164 | # if not found, look for partial derivative tensor file with datasets named (type)_deriv(order)_(flattened_uppertri_idx) 165 | elif os.path.exists("eri_partials.h5"): 
166 | file_name = "eri_partials.h5" 167 | dataset_name = "eri_" + str(self.nbf1) + "_" + str(self.nbf2)\ 168 | + "_" + str(self.nbf3) + "_" + str(self.nbf4)\ 169 | + "_deriv" + str(deriv_order) + "_" + str(idx) 170 | else: 171 | raise Exception("ERI derivatives not found on disk") 172 | 173 | with h5py.File(file_name, 'r') as f: 174 | data_set = f[dataset_name] 175 | if len(data_set.shape) == 5: 176 | G = data_set[:,:,:,:,idx] 177 | elif len(data_set.shape) == 4: 178 | G = data_set[:,:,:,:] 179 | else: 180 | raise Exception("Something went wrong reading integral derivative file") 181 | return jnp.asarray(G) 182 | 183 | def f12_deriv_impl(self, geom, beta, deriv_vec): 184 | deriv_vec = np.asarray(deriv_vec, int) 185 | deriv_order = np.sum(deriv_vec) 186 | idx = get_deriv_vec_idx(deriv_vec) 187 | 188 | # Use f12 derivatives in memory 189 | if self.mode == 'f12': 190 | F = libint_interface.compute_2e_deriv("f12", beta, deriv_vec) 191 | return jnp.asarray(F).reshape(self.nbf1, self.nbf2, self.nbf3, self.nbf4) 192 | 193 | # Read from disk 194 | elif self.mode == 'disk': 195 | # By default, look for full derivative tensor file with datasets named (type)_deriv(order) 196 | if os.path.exists("f12_derivs.h5"): 197 | file_name = "f12_derivs.h5" 198 | dataset_name = "f12_" + str(self.nbf1) + "_" + str(self.nbf2)\ 199 | + "_" + str(self.nbf3) + "_" + str(self.nbf4)\ 200 | + "_deriv" + str(deriv_order) 201 | # if not found, look for partial derivative tensor file with datasets named (type)_deriv(order)_(flattened_uppertri_idx) 202 | elif os.path.exists("f12_partials.h5"): 203 | file_name = "f12_partials.h5" 204 | dataset_name = "f12_" + str(self.nbf1) + "_" + str(self.nbf2)\ 205 | + "_" + str(self.nbf3) + "_" + str(self.nbf4)\ 206 | + "_deriv" + str(deriv_order) + "_" + str(idx) 207 | else: 208 | raise Exception("F12 derivatives not found on disk") 209 | 210 | with h5py.File(file_name, 'r') as f: 211 | data_set = f[dataset_name] 212 | if len(data_set.shape) == 5: 213 | F = 
data_set[:,:,:,:,idx] 214 | elif len(data_set.shape) == 4: 215 | F = data_set[:,:,:,:] 216 | else: 217 | raise Exception("Something went wrong reading integral derivative file") 218 | return jnp.asarray(F) 219 | 220 | def f12_squared_deriv_impl(self, geom, beta, deriv_vec): 221 | deriv_vec = np.asarray(deriv_vec, int) 222 | deriv_order = np.sum(deriv_vec) 223 | idx = get_deriv_vec_idx(deriv_vec) 224 | 225 | # Use f12 squared derivatives in memory 226 | if self.mode == 'f12': 227 | F = libint_interface.compute_2e_deriv("f12_squared", beta, deriv_vec) 228 | return jnp.asarray(F).reshape(self.nbf1, self.nbf2, self.nbf3, self.nbf4) 229 | 230 | # Read from disk 231 | elif self.mode == 'disk': 232 | # By default, look for full derivative tensor file with datasets named (type)_deriv(order) 233 | if os.path.exists("f12_squared_derivs.h5"): 234 | file_name = "f12_squared_derivs.h5" 235 | dataset_name = "f12_squared_" + str(self.nbf1) + "_" + str(self.nbf2)\ 236 | + "_" + str(self.nbf3) + "_" + str(self.nbf4)\ 237 | + "_deriv" + str(deriv_order) 238 | # if not found, look for partial derivative tensor file with datasets named (type)_deriv(order)_(flattened_uppertri_idx) 239 | elif os.path.exists("f12_squared_partials.h5"): 240 | file_name = "f12_squared_partials.h5" 241 | dataset_name = "f12_squared_" + str(self.nbf1) + "_" + str(self.nbf2)\ 242 | + "_" + str(self.nbf3) + "_" + str(self.nbf4)\ 243 | + "_deriv" + str(deriv_order) + "_" + str(idx) 244 | else: 245 | raise Exception("F12 Squared derivatives not found on disk") 246 | 247 | with h5py.File(file_name, 'r') as f: 248 | data_set = f[dataset_name] 249 | if len(data_set.shape) == 5: 250 | F = data_set[:,:,:,:,idx] 251 | elif len(data_set.shape) == 4: 252 | F = data_set[:,:,:,:] 253 | else: 254 | raise Exception("Something went wrong reading integral derivative file") 255 | return jnp.asarray(F) 256 | 257 | def f12g12_deriv_impl(self, geom, beta, deriv_vec): 258 | deriv_vec = np.asarray(deriv_vec, int) 259 | deriv_order 
def f12g12_deriv_impl(self, geom, beta, deriv_vec):
    """Primitive evaluation rule for nuclear derivatives of F12G12 integrals.

    Returns a (nbf1, nbf2, nbf3, nbf4) array holding the derivative of the
    integral tensor selected by ``deriv_vec`` (per-Cartesian derivative orders).
    """
    deriv_vec = np.asarray(deriv_vec, int)
    deriv_order = np.sum(deriv_vec)
    idx = get_deriv_vec_idx(deriv_vec)

    if self.mode == 'f12':
        # Compute the requested derivative on the fly and keep it in memory.
        raw = libint_interface.compute_2e_deriv("f12g12", beta, deriv_vec)
        return jnp.asarray(raw).reshape(self.nbf1, self.nbf2, self.nbf3, self.nbf4)

    elif self.mode == 'disk':
        base = "f12g12_" + str(self.nbf1) + "_" + str(self.nbf2) \
               + "_" + str(self.nbf3) + "_" + str(self.nbf4) \
               + "_deriv" + str(deriv_order)
        # Prefer the full derivative tensor file; otherwise fall back to the
        # partial-derivative file keyed by flattened upper-triangle index.
        if os.path.exists("f12g12_derivs.h5"):
            file_name = "f12g12_derivs.h5"
            dataset_name = base
        elif os.path.exists("f12g12_partials.h5"):
            file_name = "f12g12_partials.h5"
            dataset_name = base + "_" + str(idx)
        else:
            raise Exception("F12G12 derivatives not found on disk")

        with h5py.File(file_name, 'r') as f:
            data_set = f[dataset_name]
            if len(data_set.shape) == 5:
                # Full tensor: last axis enumerates the unique derivatives.
                F = data_set[:, :, :, :, idx]
            elif len(data_set.shape) == 4:
                F = data_set[:, :, :, :]
            else:
                raise Exception("Something went wrong reading integral derivative file")
        return jnp.asarray(F)
def f12_double_commutator_deriv_impl(self, geom, beta, deriv_vec):
    """Primitive evaluation rule for nuclear derivatives of the F12
    double-commutator integrals.

    Returns a (nbf1, nbf2, nbf3, nbf4) array holding the derivative of the
    integral tensor selected by ``deriv_vec`` (per-Cartesian derivative orders).
    """
    deriv_vec = np.asarray(deriv_vec, int)
    deriv_order = np.sum(deriv_vec)
    idx = get_deriv_vec_idx(deriv_vec)

    if self.mode == 'f12':
        # Compute the requested derivative on the fly and keep it in memory.
        raw = libint_interface.compute_2e_deriv("f12_double_commutator", beta, deriv_vec)
        return jnp.asarray(raw).reshape(self.nbf1, self.nbf2, self.nbf3, self.nbf4)

    elif self.mode == 'disk':
        base = "f12_double_commutator_" + str(self.nbf1) + "_" + str(self.nbf2) \
               + "_" + str(self.nbf3) + "_" + str(self.nbf4) \
               + "_deriv" + str(deriv_order)
        # Prefer the full derivative tensor file; otherwise fall back to the
        # partial-derivative file keyed by flattened upper-triangle index.
        if os.path.exists("f12_double_commutator_derivs.h5"):
            file_name = "f12_double_commutator_derivs.h5"
            dataset_name = base
        elif os.path.exists("f12_double_commutator_partials.h5"):
            file_name = "f12_double_commutator_partials.h5"
            dataset_name = base + "_" + str(idx)
        else:
            raise Exception("F12 Double Commutator derivatives not found on disk")

        with h5py.File(file_name, 'r') as f:
            data_set = f[dataset_name]
            if len(data_set.shape) == 5:
                # Full tensor: last axis enumerates the unique derivatives.
                F = data_set[:, :, :, :, idx]
            elif len(data_set.shape) == 4:
                F = data_set[:, :, :, :]
            else:
                raise Exception("Something went wrong reading integral derivative file")
        return jnp.asarray(F)

# Jacobian-vector product rules. Given primals and a tangent standard-basis
# vector, return the function value (primals_out) and the matching Jacobian
# slice (tangents_out). For higher-order differentiation the running
# deriv_vec is folded into the incoming tangent vector.

def eri_jvp(self, primals, tangents):
    """JVP rule for the ERI primitive: value plus first-derivative slice."""
    geom, = primals
    return self.eri(geom), self.eri_deriv(geom, tangents[0])

def eri_deriv_jvp(self, primals, tangents):
    """JVP rule for the ERI-derivative primitive: bumps the derivative order."""
    geom, deriv_vec = primals
    value = self.eri_deriv(geom, deriv_vec)
    higher = self.eri_deriv(geom, deriv_vec + tangents[0])
    return value, higher
def f12_jvp(self, primals, tangents):
    """JVP rule for the F12 primitive: value plus first-derivative slice."""
    geom, beta = primals
    return self.f12(geom, beta), self.f12_deriv(geom, beta, tangents[0])

def f12_deriv_jvp(self, primals, tangents):
    """JVP rule for the F12-derivative primitive: bumps the derivative order."""
    geom, beta, deriv_vec = primals
    value = self.f12_deriv(geom, beta, deriv_vec)
    higher = self.f12_deriv(geom, beta, deriv_vec + tangents[0])
    return value, higher

def f12_squared_jvp(self, primals, tangents):
    """JVP rule for the F12-squared primitive."""
    geom, beta = primals
    return self.f12_squared(geom, beta), self.f12_squared_deriv(geom, beta, tangents[0])

def f12_squared_deriv_jvp(self, primals, tangents):
    """JVP rule for the F12-squared-derivative primitive."""
    geom, beta, deriv_vec = primals
    value = self.f12_squared_deriv(geom, beta, deriv_vec)
    higher = self.f12_squared_deriv(geom, beta, deriv_vec + tangents[0])
    return value, higher

def f12g12_jvp(self, primals, tangents):
    """JVP rule for the F12G12 primitive."""
    geom, beta = primals
    return self.f12g12(geom, beta), self.f12g12_deriv(geom, beta, tangents[0])

def f12g12_deriv_jvp(self, primals, tangents):
    """JVP rule for the F12G12-derivative primitive."""
    geom, beta, deriv_vec = primals
    value = self.f12g12_deriv(geom, beta, deriv_vec)
    higher = self.f12g12_deriv(geom, beta, deriv_vec + tangents[0])
    return value, higher

def f12_double_commutator_jvp(self, primals, tangents):
    """JVP rule for the F12 double-commutator primitive."""
    geom, beta = primals
    value = self.f12_double_commutator(geom, beta)
    tangent = self.f12_double_commutator_deriv(geom, beta, tangents[0])
    return value, tangent

def f12_double_commutator_deriv_jvp(self, primals, tangents):
    """JVP rule for the F12 double-commutator-derivative primitive."""
    geom, beta, deriv_vec = primals
    value = self.f12_double_commutator_deriv(geom, beta, deriv_vec)
    higher = self.f12_double_commutator_deriv(geom, beta, deriv_vec + tangents[0])
    return value, higher
# Batching rules. Needed because jax.jacfwd vmaps the JVPs of the integral
# primitives: when the deriv_vec argument is batched along axis 0, evaluate
# each slice, stack the results along a new leading axis, and report axis 0
# as the output batch dimension.

def eri_deriv_batch(self, batched_args, batch_dims):
    """Batching rule for the ERI-derivative primitive (batched deriv_vec)."""
    geom_batch, deriv_batch = batched_args
    geom_dim, deriv_dim = batch_dims
    slices = [self.eri_deriv(geom_batch, dvec) for dvec in deriv_batch]
    return jnp.stack(slices, axis=0), 0

def f12_deriv_batch(self, batched_args, batch_dims):
    """Batching rule for the F12-derivative primitive (batched deriv_vec)."""
    geom_batch, beta_batch, deriv_batch = batched_args
    geom_dim, beta_dim, deriv_dim = batch_dims
    slices = [self.f12_deriv(geom_batch, beta_batch, dvec) for dvec in deriv_batch]
    return jnp.stack(slices, axis=0), 0

def f12_squared_deriv_batch(self, batched_args, batch_dims):
    """Batching rule for the F12-squared-derivative primitive."""
    geom_batch, beta_batch, deriv_batch = batched_args
    geom_dim, beta_dim, deriv_dim = batch_dims
    slices = [self.f12_squared_deriv(geom_batch, beta_batch, dvec) for dvec in deriv_batch]
    return jnp.stack(slices, axis=0), 0

def f12g12_deriv_batch(self, batched_args, batch_dims):
    """Batching rule for the F12G12-derivative primitive."""
    geom_batch, beta_batch, deriv_batch = batched_args
    geom_dim, beta_dim, deriv_dim = batch_dims
    slices = [self.f12g12_deriv(geom_batch, beta_batch, dvec) for dvec in deriv_batch]
    return jnp.stack(slices, axis=0), 0
def f12_double_commutator_deriv_batch(self, batched_args, batch_dims):
    """Batching rule for the F12 double-commutator-derivative primitive.

    Evaluates one 4-index slice per batched deriv_vec, stacks them along a
    new leading axis, and reports axis 0 as the output batch dimension.
    """
    geom_batch, beta_batch, deriv_batch = batched_args
    geom_dim, beta_dim, deriv_dim = batch_dims
    slices = [self.f12_double_commutator_deriv(geom_batch, beta_batch, dvec)
              for dvec in deriv_batch]
    return jnp.stack(slices, axis=0), 0
def __init__(self, basis1, basis2, xyz_path, max_deriv_order, mode):
    """Set up one-electron-integral JAX primitives for a pair of basis sets.

    Reads the molecule from ``xyz_path`` (psi4 'xyz+' string format),
    optionally precomputes derivative tensors in memory ('core' mode), and
    registers impl / JVP / batching rules for every OEI primitive.
    """
    with open(xyz_path, 'r') as f:
        molecule = psi4.core.Molecule.from_string(f.read(), 'xyz+')
    natoms = molecule.natom()

    nbf1 = basis1.nbf()
    nbf2 = basis2.nbf()

    if mode == 'core' and max_deriv_order > 0:
        # One list entry per derivative order; each array holds only the
        # unique upper-hypertriangle derivatives (tensors are symmetric),
        # shaped (n_unique_derivs, nbf1, nbf2). The JVP rules later read
        # the appropriate slice.
        self.overlap_derivatives = []
        self.kinetic_derivatives = []
        self.potential_derivatives = []
        for order in range(1, max_deriv_order + 1):
            n_unique = how_many_derivs(natoms, order)
            oei_deriv = libint_interface.oei_deriv_core(order)
            self.overlap_derivatives.append(oei_deriv[0].reshape(n_unique, nbf1, nbf2))
            self.kinetic_derivatives.append(oei_deriv[1].reshape(n_unique, nbf1, nbf2))
            self.potential_derivatives.append(oei_deriv[2].reshape(n_unique, nbf1, nbf2))

    self.mode = mode
    self.nbf1 = nbf1
    self.nbf2 = nbf2

    com = molecule.center_of_mass()
    self.com = [com[0], com[1], com[2]]

    jvps = jax.interpreters.ad.primitive_jvps
    batchers = jax.interpreters.batching.primitive_batchers

    # Overlap primitives: evaluation, JVP, and batching registration.
    self.overlap_p = jax.core.Primitive("overlap")
    self.overlap_deriv_p = jax.core.Primitive("overlap_deriv")
    self.overlap_p.def_impl(self.overlap_impl)
    self.overlap_deriv_p.def_impl(self.overlap_deriv_impl)
    jvps[self.overlap_p] = self.overlap_jvp
    jvps[self.overlap_deriv_p] = self.overlap_deriv_jvp
    batchers[self.overlap_deriv_p] = self.overlap_deriv_batch

    # Kinetic primitives.
    self.kinetic_p = jax.core.Primitive("kinetic")
    self.kinetic_deriv_p = jax.core.Primitive("kinetic_deriv")
    self.kinetic_p.def_impl(self.kinetic_impl)
    self.kinetic_deriv_p.def_impl(self.kinetic_deriv_impl)
    jvps[self.kinetic_p] = self.kinetic_jvp
    jvps[self.kinetic_deriv_p] = self.kinetic_deriv_jvp
    batchers[self.kinetic_deriv_p] = self.kinetic_deriv_batch

    # Potential primitives.
    self.potential_p = jax.core.Primitive("potential")
    self.potential_deriv_p = jax.core.Primitive("potential_deriv")
    self.potential_p.def_impl(self.potential_impl)
    self.potential_deriv_p.def_impl(self.potential_deriv_impl)
    jvps[self.potential_p] = self.potential_jvp
    jvps[self.potential_deriv_p] = self.potential_deriv_jvp
    batchers[self.potential_deriv_p] = self.potential_deriv_batch

    # Dipole primitives.
    self.dipole_p = jax.core.Primitive("dipole")
    self.dipole_deriv_p = jax.core.Primitive("dipole_deriv")
    self.dipole_p.def_impl(self.dipole_impl)
    self.dipole_deriv_p.def_impl(self.dipole_deriv_impl)
    jvps[self.dipole_p] = self.dipole_jvp
    jvps[self.dipole_deriv_p] = self.dipole_deriv_jvp
    batchers[self.dipole_deriv_p] = self.dipole_deriv_batch

    # Quadrupole primitives.
    self.quadrupole_p = jax.core.Primitive("quadrupole")
    self.quadrupole_deriv_p = jax.core.Primitive("quadrupole_deriv")
    self.quadrupole_p.def_impl(self.quadrupole_impl)
    self.quadrupole_deriv_p.def_impl(self.quadrupole_deriv_impl)
    jvps[self.quadrupole_p] = self.quadrupole_jvp
    jvps[self.quadrupole_deriv_p] = self.quadrupole_deriv_jvp
    batchers[self.quadrupole_deriv_p] = self.quadrupole_deriv_batch

# Thin wrappers that bind arguments to the registered primitives.

def overlap(self, geom):
    """Bind the overlap primitive to ``geom``."""
    return self.overlap_p.bind(geom)

def overlap_deriv(self, geom, deriv_vec):
    """Bind the overlap-derivative primitive."""
    return self.overlap_deriv_p.bind(geom, deriv_vec)

def kinetic(self, geom):
    """Bind the kinetic-energy primitive to ``geom``."""
    return self.kinetic_p.bind(geom)

def kinetic_deriv(self, geom, deriv_vec):
    """Bind the kinetic-derivative primitive."""
    return self.kinetic_deriv_p.bind(geom, deriv_vec)

def potential(self, geom):
    """Bind the nuclear-attraction primitive to ``geom``."""
    return self.potential_p.bind(geom)

def potential_deriv(self, geom, deriv_vec):
    """Bind the potential-derivative primitive."""
    return self.potential_deriv_p.bind(geom, deriv_vec)

def dipole(self, geom):
    """Bind the dipole primitive to ``geom``."""
    return self.dipole_p.bind(geom)

def dipole_deriv(self, geom, deriv_vec):
    """Bind the dipole-derivative primitive."""
    return self.dipole_deriv_p.bind(geom, deriv_vec)

def quadrupole(self, geom):
    """Bind the quadrupole primitive to ``geom``."""
    return self.quadrupole_p.bind(geom)

def quadrupole_deriv(self, geom, deriv_vec):
    """Bind the quadrupole-derivative primitive."""
    return self.quadrupole_deriv_p.bind(geom, deriv_vec)
# Primitive evaluation rules for the undifferentiated integrals. The geom
# argument exists only so JAX can differentiate through these primitives;
# the actual integrals come from libint_interface.

def overlap_impl(self, geom):
    """Evaluate the overlap matrix, shaped (nbf1, nbf2)."""
    raw = libint_interface.compute_1e_int("overlap")
    return jnp.asarray(raw.reshape(self.nbf1, self.nbf2))

def kinetic_impl(self, geom):
    """Evaluate the kinetic-energy matrix, shaped (nbf1, nbf2)."""
    raw = libint_interface.compute_1e_int("kinetic")
    return jnp.asarray(raw.reshape(self.nbf1, self.nbf2))

def potential_impl(self, geom):
    """Evaluate the nuclear-attraction matrix, shaped (nbf1, nbf2)."""
    raw = libint_interface.compute_1e_int("potential")
    return jnp.asarray(raw.reshape(self.nbf1, self.nbf2))

def dipole_impl(self, geom):
    """Evaluate the three dipole component matrices, stacked (3, nbf1, nbf2)."""
    components = libint_interface.compute_dipole_ints(self.com)
    return jnp.stack([c.reshape(self.nbf1, self.nbf2) for c in components])

def quadrupole_impl(self, geom):
    """Evaluate dipole plus quadrupole components, stacked (9, nbf1, nbf2).

    Order: Mu_X, Mu_Y, Mu_Z, Th_XX, Th_XY, Th_XZ, Th_YY, Th_YZ, Th_ZZ.
    """
    components = libint_interface.compute_quadrupole_ints(self.com)
    return jnp.stack([c.reshape(self.nbf1, self.nbf2) for c in components])
def _read_oei_deriv_from_disk(self, prefix, deriv_order, idx):
    """Load one OEI derivative matrix from oei_derivs.h5 or oei_partials.h5.

    Full-tensor files hold datasets named (prefix)_(nbf1)_(nbf2)_deriv(order)
    with the unique-derivative index on the last axis; partial files hold
    2-D datasets with the flattened upper-triangle index in the name.
    """
    tag = prefix + "_" + str(self.nbf1) + "_" + str(self.nbf2) + "_deriv" + str(deriv_order)
    if os.path.exists("oei_derivs.h5"):
        file_name = "oei_derivs.h5"
        dataset_name = tag
    elif os.path.exists("oei_partials.h5"):
        file_name = "oei_partials.h5"
        dataset_name = tag + "_" + str(idx)
    else:
        raise Exception("Something went wrong reading integral derivative file")
    with h5py.File(file_name, 'r') as f:
        data_set = f[dataset_name]
        if len(data_set.shape) == 3:
            out = data_set[:, :, idx]
        elif len(data_set.shape) == 2:
            out = data_set[:, :]
        else:
            raise Exception("Something went wrong reading integral derivative file")
    return out

def overlap_deriv_impl(self, geom, deriv_vec):
    """Evaluation rule for overlap nuclear derivatives, shaped (nbf1, nbf2)."""
    deriv_vec = np.asarray(deriv_vec, int)
    deriv_order = np.sum(deriv_vec)
    idx = get_deriv_vec_idx(deriv_vec)

    if self.mode == 'core':
        # Slice the precomputed in-memory derivative tensor.
        return jnp.asarray(self.overlap_derivatives[deriv_order - 1][idx, :, :])
    if self.mode == 'f12':
        raw = libint_interface.compute_1e_deriv("overlap", deriv_vec)
        return jnp.asarray(raw).reshape(self.nbf1, self.nbf2)
    elif self.mode == 'disk':
        return jnp.asarray(self._read_oei_deriv_from_disk("overlap", deriv_order, idx))

def kinetic_deriv_impl(self, geom, deriv_vec):
    """Evaluation rule for kinetic nuclear derivatives, shaped (nbf1, nbf2)."""
    deriv_vec = np.asarray(deriv_vec, int)
    deriv_order = np.sum(deriv_vec)
    idx = get_deriv_vec_idx(deriv_vec)

    if self.mode == 'core':
        # Slice the precomputed in-memory derivative tensor.
        return jnp.asarray(self.kinetic_derivatives[deriv_order - 1][idx, :, :])
    if self.mode == 'f12':
        raw = libint_interface.compute_1e_deriv("kinetic", deriv_vec)
        return jnp.asarray(raw).reshape(self.nbf1, self.nbf2)
    elif self.mode == 'disk':
        return jnp.asarray(self._read_oei_deriv_from_disk("kinetic", deriv_order, idx))
def potential_deriv_impl(self, geom, deriv_vec):
    """Evaluation rule for potential nuclear derivatives, shaped (nbf1, nbf2)."""
    deriv_vec = np.asarray(deriv_vec, int)
    deriv_order = np.sum(deriv_vec)
    idx = get_deriv_vec_idx(deriv_vec)

    if self.mode == 'core':
        # Slice the precomputed in-memory derivative tensor.
        return jnp.asarray(self.potential_derivatives[deriv_order - 1][idx, :, :])
    if self.mode == 'f12':
        raw = libint_interface.compute_1e_deriv("potential", deriv_vec)
        return jnp.asarray(raw).reshape(self.nbf1, self.nbf2)
    elif self.mode == 'disk':
        tag = "potential_" + str(self.nbf1) + "_" + str(self.nbf2) + "_deriv" + str(deriv_order)
        # Prefer the full derivative tensor file; otherwise fall back to the
        # partial-derivative file keyed by flattened upper-triangle index.
        if os.path.exists("oei_derivs.h5"):
            file_name = "oei_derivs.h5"
            dataset_name = tag
        elif os.path.exists("oei_partials.h5"):
            file_name = "oei_partials.h5"
            dataset_name = tag + "_" + str(idx)
        else:
            raise Exception("Something went wrong reading integral derivative file")
        with h5py.File(file_name, 'r') as f:
            data_set = f[dataset_name]
            if len(data_set.shape) == 3:
                V = data_set[:, :, idx]
            elif len(data_set.shape) == 2:
                V = data_set[:, :]
            else:
                raise Exception("Something went wrong reading integral derivative file")
        return jnp.asarray(V)

def dipole_deriv_impl(self, geom, deriv_vec):
    """Evaluation rule for dipole nuclear derivatives, stacked (3, nbf1, nbf2).

    Component order: Mu_X, Mu_Y, Mu_Z.
    """
    deriv_vec = np.asarray(deriv_vec, int)
    deriv_order = np.sum(deriv_vec)
    idx = get_deriv_vec_idx(deriv_vec)

    if self.mode == 'dipole':
        comps = libint_interface.compute_dipole_derivs(self.com, deriv_vec)
        return jnp.stack([c.reshape(self.nbf1, self.nbf2) for c in comps])
    elif self.mode == 'disk':
        suffix = "_" + str(self.nbf1) + "_" + str(self.nbf2) + "_deriv" + str(deriv_order)
        labels = ["mu_x", "mu_y", "mu_z"]
        if os.path.exists("dipole_derivs.h5"):
            file_name = "dipole_derivs.h5"
            dataset_names = [lab + suffix for lab in labels]
        elif os.path.exists("dipole_partials.h5"):
            file_name = "dipole_partials.h5"
            dataset_names = [lab + suffix + "_" + str(idx) for lab in labels]
        else:
            raise Exception("Something went wrong reading integral derivative file")
        with h5py.File(file_name, 'r') as f:
            data_sets = [f[name] for name in dataset_names]
            if len(data_sets[0].shape) == 3:
                comps = [ds[:, :, idx] for ds in data_sets]
            elif len(data_sets[0].shape) == 2:
                comps = [ds[:, :] for ds in data_sets]
            else:
                raise Exception("Something went wrong reading integral derivative file")
        return jnp.stack(comps)
def quadrupole_deriv_impl(self, geom, deriv_vec):
    """Evaluation rule for quadrupole nuclear derivatives, stacked (9, nbf1, nbf2).

    Component order: Mu_X, Mu_Y, Mu_Z, Th_XX, Th_XY, Th_XZ, Th_YY, Th_YZ, Th_ZZ.

    Bug fix: in 'disk' mode the six quadrupole (th_*) components were read
    from the dipole dataset handles (dataset1/2/3 names), silently returning
    dipole derivatives in place of quadrupole derivatives. Each component now
    reads its own dataset.
    """
    deriv_vec = np.asarray(deriv_vec, int)
    deriv_order = np.sum(deriv_vec)
    idx = get_deriv_vec_idx(deriv_vec)

    if self.mode == 'quadrupole':
        # NOTE(review): compute_dipole_derivs receives (self.com, deriv_vec),
        # but compute_quadrupole_derivs is called with no arguments here —
        # confirm against the libint_interface binding signature.
        comps = libint_interface.compute_quadrupole_derivs()
        return jnp.stack([c.reshape(self.nbf1, self.nbf2) for c in comps])
    elif self.mode == 'disk':
        labels = ["mu_x", "mu_y", "mu_z",
                  "th_xx", "th_xy", "th_xz", "th_yy", "th_yz", "th_zz"]
        suffix = "_" + str(self.nbf1) + "_" + str(self.nbf2) + "_deriv" + str(deriv_order)
        # Prefer the full derivative tensor file; otherwise fall back to the
        # partial-derivative file keyed by flattened upper-triangle index.
        if os.path.exists("quadrupole_derivs.h5"):
            file_name = "quadrupole_derivs.h5"
            dataset_names = [lab + suffix for lab in labels]
        elif os.path.exists("quadrupole_partials.h5"):
            file_name = "quadrupole_partials.h5"
            dataset_names = [lab + suffix + "_" + str(idx) for lab in labels]
        else:
            raise Exception("Something went wrong reading integral derivative file")

        with h5py.File(file_name, 'r') as f:
            # Fix: each component reads its own dataset (previously th_* were
            # read from the mu_* datasets).
            data_sets = [f[name] for name in dataset_names]
            if len(data_sets[0].shape) == 3:
                comps = [ds[:, :, idx] for ds in data_sets]
            elif len(data_sets[0].shape) == 2:
                comps = [ds[:, :] for ds in data_sets]
            else:
                raise Exception("Something went wrong reading integral derivative file")
        return jnp.stack(comps)
# Jacobian-vector product rules: return the primitive's value and the matching
# Jacobian slice. Derivative primitives fold the running deriv_vec into the
# incoming tangent to support higher-order differentiation.

def overlap_jvp(self, primals, tangents):
    """JVP rule for the overlap primitive."""
    geom, = primals
    return self.overlap(geom), self.overlap_deriv(geom, tangents[0])

def overlap_deriv_jvp(self, primals, tangents):
    """JVP rule for the overlap-derivative primitive."""
    geom, deriv_vec = primals
    value = self.overlap_deriv(geom, deriv_vec)
    higher = self.overlap_deriv(geom, deriv_vec + tangents[0])
    return value, higher

def kinetic_jvp(self, primals, tangents):
    """JVP rule for the kinetic primitive."""
    geom, = primals
    return self.kinetic(geom), self.kinetic_deriv(geom, tangents[0])

def kinetic_deriv_jvp(self, primals, tangents):
    """JVP rule for the kinetic-derivative primitive."""
    geom, deriv_vec = primals
    value = self.kinetic_deriv(geom, deriv_vec)
    higher = self.kinetic_deriv(geom, deriv_vec + tangents[0])
    return value, higher

def potential_jvp(self, primals, tangents):
    """JVP rule for the potential primitive."""
    geom, = primals
    return self.potential(geom), self.potential_deriv(geom, tangents[0])

def potential_deriv_jvp(self, primals, tangents):
    """JVP rule for the potential-derivative primitive."""
    geom, deriv_vec = primals
    value = self.potential_deriv(geom, deriv_vec)
    higher = self.potential_deriv(geom, deriv_vec + tangents[0])
    return value, higher

def dipole_jvp(self, primals, tangents):
    """JVP rule for the dipole primitive."""
    geom, = primals
    return self.dipole(geom), self.dipole_deriv(geom, tangents[0])

def dipole_deriv_jvp(self, primals, tangents):
    """JVP rule for the dipole-derivative primitive."""
    geom, deriv_vec = primals
    value = self.dipole_deriv(geom, deriv_vec)
    higher = self.dipole_deriv(geom, deriv_vec + tangents[0])
    return value, higher

def quadrupole_jvp(self, primals, tangents):
    """JVP rule for the quadrupole primitive."""
    geom, = primals
    return self.quadrupole(geom), self.quadrupole_deriv(geom, tangents[0])
def quadrupole_deriv_jvp(self, primals, tangents):
    """JVP rule for the quadrupole-derivative primitive."""
    geom, deriv_vec = primals
    value = self.quadrupole_deriv(geom, deriv_vec)
    higher = self.quadrupole_deriv(geom, deriv_vec + tangents[0])
    return value, higher

# Batching rules. Needed because jax.jacfwd vmaps the JVPs of the OEI
# primitives: when deriv_vec is batched along axis 0, evaluate each 2-D
# slice, stack along a new leading axis, and report axis 0 as the output
# batch dimension.

def overlap_deriv_batch(self, batched_args, batch_dims):
    """Batching rule for the overlap-derivative primitive."""
    geom_batch, deriv_batch = batched_args
    geom_dim, deriv_dim = batch_dims
    slices = [self.overlap_deriv(geom_batch, dvec) for dvec in deriv_batch]
    return jnp.stack(slices, axis=0), 0

def kinetic_deriv_batch(self, batched_args, batch_dims):
    """Batching rule for the kinetic-derivative primitive."""
    geom_batch, deriv_batch = batched_args
    geom_dim, deriv_dim = batch_dims
    slices = [self.kinetic_deriv(geom_batch, dvec) for dvec in deriv_batch]
    return jnp.stack(slices, axis=0), 0

def potential_deriv_batch(self, batched_args, batch_dims):
    """Batching rule for the potential-derivative primitive."""
    geom_batch, deriv_batch = batched_args
    geom_dim, deriv_dim = batch_dims
    slices = [self.potential_deriv(geom_batch, dvec) for dvec in deriv_batch]
    return jnp.stack(slices, axis=0), 0
def dipole_deriv_batch(self, batched_args, batch_dims):
    """Batching rule for the dipole-derivative primitive.

    Output is shaped (n_batched_derivs, 3, nbf1, nbf2), batched on axis 0;
    the three dipole components sit on axis 1.
    """
    geom_batch, deriv_batch = batched_args
    geom_dim, deriv_dim = batch_dims
    per_deriv = [jnp.stack(tuple(self.dipole_deriv(geom_batch, dvec)), axis=0)
                 for dvec in deriv_batch]
    return jnp.stack(per_deriv, axis=0), 0

def quadrupole_deriv_batch(self, batched_args, batch_dims):
    """Batching rule for the quadrupole-derivative primitive.

    Output is shaped (n_batched_derivs, 9, nbf1, nbf2), batched on axis 0;
    the nine dipole + quadrupole components sit on axis 1.
    """
    geom_batch, deriv_batch = batched_args
    geom_dim, deriv_dim = batch_dims
    per_deriv = [jnp.stack(tuple(self.quadrupole_deriv(geom_batch, dvec)), axis=0)
                 for dvec in deriv_batch]
    return jnp.stack(per_deriv, axis=0), 0
def check_options(options):
    """
    Checks user-supplied keyword options and assigns them.

    Unknown keys are ignored with a warning; values whose type does not
    exactly match the default's type are ignored with a warning.

    Parameters
    ----------
    options : dict
        Dictionary of options controlling electronic structure code parameters

    Returns
    -------
    keyword_options : dict
        Defaults overridden by every recognized, correctly-typed user option.
    """
    # Defaults for all supported keywords; add new keywords here.
    keyword_options = {'maxit': 100,
                       'damping': False,
                       'damp_factor': 0.5,
                       'guess_core': False,
                       'spectral_shift': True,
                       'integral_algo': 'libint_core',
                       'ints_tolerance': 1.0e-14,
                       'freeze_core': False,
                       'beta': 1.0,
                       'electric_field': 0
                       }

    for key, value in options.items():
        if key not in keyword_options:
            print("{} keyword option not recognized.".format(key))
        # Exact type match is required deliberately: e.g. an int is not
        # accepted for a float-valued option, and a bool is not accepted
        # for an int-valued one.
        elif type(value) == type(keyword_options[key]):
            keyword_options[key] = value
        else:
            print("Value '{}' for keyword option '{}' not recognized. Ignoring.".format(value, key))
    return keyword_options
def compute_standard(method, method_args, deriv_order=0, partial=None, options=None):
    """
    General function for computing energies, derivatives, and partial derivatives with respect to one input variable.

    Parameters
    ----------
    method : str
        Electronic structure method: 'scf'/'hf'/'rhf', 'mp2', 'mp2-f12', 'ccsd', or 'ccsd(t)'.
    method_args : tuple
        Positional arguments forwarded to the method's energy function. The first
        element is the differentiation variable (geometry or field coordinates).
    deriv_order : int
        Order of differentiation; 0 returns the energy. Full tensors are exposed
        up to order 4, partial derivatives up to order 6.
    partial : tuple of int, optional
        Coordinate indices to differentiate with respect to; its length must
        equal deriv_order. When None/empty, the full derivative tensor is built.
    options : dict, optional
        Keyword options passed through to the energy function.

    Returns
    -------
    The energy (deriv_order=0), a full derivative tensor as a numpy array, or a
    single partial derivative as a jax array.

    Raises
    ------
    Exception
        For an unsupported method, an unsupported derivative order, or a
        partial/deriv_order length mismatch.
    """
    # Energy and full derivative tensor evaluations
    if not partial:
        # Create energy evaluation function
        if method == 'scf' or method == 'hf' or method == 'rhf':
            def electronic_energy(*args):
                return restricted_hartree_fock(*args, options=options, deriv_order=deriv_order)
        elif method == 'mp2':
            def electronic_energy(*args):
                return restricted_mp2(*args, options=options, deriv_order=deriv_order)
        elif method == 'mp2-f12':
            def electronic_energy(*args):
                return restricted_mp2_f12(*args, options=options, deriv_order=deriv_order)
        elif method == 'ccsd':
            def electronic_energy(*args):
                return rccsd(*args, options=options, deriv_order=deriv_order)
        elif method == 'ccsd(t)':
            def electronic_energy(*args):
                return rccsd_t(*args, options=options, deriv_order=deriv_order)
        else:
            raise Exception("Error: Method {} not supported.".format(method))

        # Evaluate energy or derivative
        if deriv_order == 0:
            return electronic_energy(*method_args)
        if deriv_order > 4:
            raise Exception("Error: Order {} derivatives are not exposed to the API.".format(deriv_order))

        # Nest forward-mode differentiation once per order, always wrt argument 0.
        deriv_fn = electronic_energy
        for _ in range(deriv_order):
            deriv_fn = jacfwd(deriv_fn, 0)
        deriv = jnp.round(deriv_fn(*method_args), 10)
        return np.asarray(deriv)

    # Partial derivatives
    else:
        if len(partial) != deriv_order:
            raise Exception("The length of the index coordinates given by 'partial' argument should be the same as the order of differentiation")
        if method not in ('scf', 'hf', 'rhf', 'mp2', 'mp2-f12', 'ccsd', 'ccsd(t)'):
            raise Exception("Error: Method {} not supported.".format(method))
        if deriv_order > 6:
            raise Exception("Error: Order {} partial derivatives are not exposed to the API.".format(deriv_order))

        # For partial derivatives, need to unpack each geometric or electric field coordinate into separate arguments
        # to differentiate wrt specific coordinates using JAX AD utilities.
        param_list = method_args[0]

        #TODO support internal coordinate wrapper function.
        # This will take in internal coordinates, transform them into cartesians, and then compute integrals, energy
        # JAX will then collect the internal coordinate partial derivative instead.
        def partial_wrapper(*args):
            param = jnp.asarray(args)
            full_args = (param,) + method_args[1:]
            if method == 'scf' or method == 'hf' or method == 'rhf':
                # SCF partials skip auxiliary data so only the energy is differentiated.
                return restricted_hartree_fock(*full_args, options=options, deriv_order=deriv_order, return_aux_data=False)
            if method == 'mp2':
                return restricted_mp2(*full_args, options=options, deriv_order=deriv_order)
            if method == 'mp2-f12':
                return restricted_mp2_f12(*full_args, options=options, deriv_order=deriv_order)
            if method == 'ccsd':
                return rccsd(*full_args, options=options, deriv_order=deriv_order)
            return rccsd_t(*full_args, options=options, deriv_order=deriv_order)

        # The innermost jacfwd differentiates wrt the first requested coordinate,
        # matching the original explicit jacfwd(jacfwd(...)) nesting order.
        partial_fn = partial_wrapper
        for coord in partial:
            partial_fn = jacfwd(partial_fn, coord)
        partial_deriv = partial_fn(*param_list)
        return jnp.round(partial_deriv, 14)
def compute_mixed(method, method_args, deriv_order_F=1, deriv_order_R=1, partial_F=None, partial_R=None, options=None):
    """
    General function for computing energies, derivatives, and partial derivatives with respect to two input variables.

    Parameters
    ----------
    method : str
        Electronic structure method: 'scf'/'hf'/'rhf', 'mp2', 'mp2-f12', 'ccsd', or 'ccsd(t)'.
    method_args : tuple
        Positional arguments for the energy function; element 0 is the field
        variable and element 1 the second differentiation variable.
    deriv_order_F, deriv_order_R : int
        Orders of differentiation wrt the field (F) and geometry (R) variables.
    partial_F, partial_R : sequence of int, optional
        Coordinate indices for a single mixed partial derivative; lengths must
        match deriv_order_F / deriv_order_R. If either is None, the full mixed
        derivative tensor is computed.
    options : dict, optional
        Keyword options passed through to the energy function.

    Returns
    -------
    Full mixed derivative tensor (numpy array) or a single partial (jax array).

    Raises
    ------
    Exception
        For an unsupported method, an unsupported total order, or a
        partial/order length mismatch. (Previously unknown methods and bad
        orders only printed a message and then crashed or returned 0.)
    """
    # Number of differentiation calls depends on the total
    total_deriv_order = deriv_order_F + deriv_order_R

    # Energy and full derivative tensor evaluations
    if not partial_F or not partial_R:
        # Creates indices list to decide electric_field or coordinate differentiation:
        # argnum 0 selects the field argument, argnum 1 the geometry argument.
        FR_list = np.append(np.zeros(deriv_order_F, int), np.ones(deriv_order_R, int))

        # Create energy evaluation function
        if method == 'scf' or method == 'hf' or method == 'rhf':
            def electronic_energy(*args):
                return restricted_hartree_fock(*args, options=options, deriv_order=deriv_order_R)
        elif method == 'mp2':
            def electronic_energy(*args):
                return restricted_mp2(*args, options=options, deriv_order=deriv_order_R)
        elif method == 'mp2-f12':
            def electronic_energy(*args):
                return restricted_mp2_f12(*args, options=options, deriv_order=deriv_order_R)
        elif method == 'ccsd':
            def electronic_energy(*args):
                return rccsd(*args, options=options, deriv_order=deriv_order_R)
        elif method == 'ccsd(t)':
            def electronic_energy(*args):
                return rccsd_t(*args, options=options, deriv_order=deriv_order_R)
        else:
            # Raise (consistent with compute_standard) instead of printing and
            # then failing with a NameError on electronic_energy below.
            raise Exception("Error: Method {} not supported.".format(method))

        if total_deriv_order < 2 or total_deriv_order > 6:
            raise Exception("Error: Order {},{} mixed derivatives are not exposed to the API.".format(deriv_order_F, deriv_order_R))

        # Nest jacfwd once per requested order, alternating argnums per FR_list.
        deriv_fn = electronic_energy
        for argnum in FR_list:
            deriv_fn = jacfwd(deriv_fn, int(argnum))
        deriv = jnp.round(deriv_fn(*method_args), 14)
        return np.asarray(deriv)

    # Partial derivatives
    else:
        if len(partial_F) != deriv_order_F or len(partial_R) != deriv_order_R:
            raise Exception("The length of the index coordinates given by 'partial' argument should be the same as the order of differentiation")

        # For partial derivatives, need to unpack each geometric or electric field coordinate into separate arguments
        # to differentiate wrt specific coordinates using JAX AD utilities.
        param_list = (*method_args[0],) + (*method_args[1],)

        if method not in ('scf', 'hf', 'rhf', 'mp2', 'mp2-f12', 'ccsd', 'ccsd(t)'):
            raise Exception("Error: Method {} not supported.".format(method))
        if total_deriv_order < 2 or total_deriv_order > 8:
            raise Exception("Error: Order {},{} mixed derivatives are not exposed to the API.".format(deriv_order_F, deriv_order_R))

        #TODO support internal coordinate wrapper function.
        # This will take in internal coordinates, transform them into cartesians, and then compute integrals, energy
        # JAX will then collect the internal coordinate partial derivative instead.
        def partial_wrapper(*args):
            # First three unpacked scalars are the field components, the rest
            # belong to the second variable (see the +3 offset below).
            param1 = jnp.asarray(args[0:3])
            param2 = jnp.asarray(args[3:])
            full_args = (param1, param2) + method_args[2:]
            if method == 'scf' or method == 'hf' or method == 'rhf':
                return restricted_hartree_fock(*full_args, options=options, deriv_order=deriv_order_R, return_aux_data=False)
            if method == 'mp2':
                return restricted_mp2(*full_args, options=options, deriv_order=deriv_order_R)
            if method == 'mp2-f12':
                return restricted_mp2_f12(*full_args, options=options, deriv_order=deriv_order_R)
            if method == 'ccsd':
                return rccsd(*full_args, options=options, deriv_order=deriv_order_R)
            return rccsd_t(*full_args, options=options, deriv_order=deriv_order_R)

        # Combine partial tuples into one array; R indices are offset past the
        # three field components.
        partial = np.append(np.array(partial_F), np.array(partial_R) + 3)

        partial_fn = partial_wrapper
        for coord in partial:
            partial_fn = jacfwd(partial_fn, int(coord))
        partial_deriv = partial_fn(*param_list)
        return jnp.round(partial_deriv, 14)
def energy(molecule, basis_name, method, options=None):
    """
    Compute the electronic energy of a molecule.

    Parameters
    ----------
    molecule : psi4.core.Molecule
        Psi4 molecule object supplying geometry, charge, and nuclear charges.
    basis_name : str
        Name of the Gaussian basis set.
    method : str
        'scf'/'hf'/'rhf', 'mp2', 'mp2-f12', 'ccsd', or 'ccsd(t)'.
    options : dict, optional
        Keyword options; validated by check_options.

    Returns
    -------
    The electronic energy from compute_standard (deriv_order=0).
    """
    # Set keyword options; plain energies always use the fast core algorithm.
    if options:
        options = check_options(options)
        options['integral_algo'] = 'libint_core'
    else:
        options = check_options({'integral_algo': 'libint_core'})
    print("Using integral method: {}".format(options['integral_algo']))
    print("Number of OMP Threads: {}".format(psi4.core.get_num_threads()))

    # Load molecule data
    geom2d = np.asarray(molecule.geometry())
    geom = jnp.asarray(geom2d.flatten())
    xyz_file_name = "geom.xyz"
    molecule.save_xyz_file(xyz_file_name, True)
    xyz_path = os.path.abspath(os.getcwd()) + "/" + xyz_file_name
    charge = molecule.molecular_charge()
    nuclear_charges = jnp.asarray([molecule.charge(i) for i in range(geom2d.shape[0])])
    nelectrons = int(jnp.sum(nuclear_charges)) - charge
    nfrzn = n_frozen_core(molecule, charge) if options['freeze_core'] else 0

    basis_set = psi4.core.BasisSet.build(molecule, 'BASIS', basis_name, puream=0)
    print("Basis name: ", basis_set.name())
    print("Number of basis functions: ", basis_set.nbf())

    # mp2, ccsd, and ccsd(t) take identical argument tuples, so they share a branch.
    if method == 'scf' or method == 'hf' or method == 'rhf':
        args = (geom, basis_set, nelectrons, nuclear_charges, xyz_path)
    elif method in ('mp2', 'ccsd', 'ccsd(t)'):
        args = (geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path)
    elif method == 'mp2-f12':
        cabs_set = build_RIBS(molecule, basis_set, basis_name + '-cabs')
        args = (geom, basis_set, cabs_set, nelectrons, nfrzn, nuclear_charges, xyz_path)
    else:
        # Raise instead of printing so a typo cannot fall through to a
        # NameError on the undefined 'args' below.
        raise Exception("Desired electronic structure method not understood. Use 'scf' 'hf' 'mp2' 'ccsd' or 'ccsd(t)' ")

    return compute_standard(method, args, deriv_order=0, partial=None, options=options)
def geom_deriv(molecule, basis_name, method, deriv_order=1, partial=None, options=None):
    """
    Compute nuclear (geometric) derivatives of the electronic energy.

    Parameters
    ----------
    molecule : psi4.core.Molecule
        Psi4 molecule object supplying geometry, charge, and nuclear charges.
    basis_name : str
        Name of the Gaussian basis set.
    method : str
        'scf'/'hf'/'rhf', 'mp2', 'mp2-f12', 'ccsd', or 'ccsd(t)'.
    deriv_order : int
        Order of geometric differentiation (1 = gradient, 2 = Hessian, ...).
    partial : tuple of int, optional
        Cartesian coordinate indices for a single partial derivative.
    options : dict, optional
        Keyword options; validated by check_options.

    Returns
    -------
    Derivative tensor (or partial derivative) from compute_standard.
    """
    # Set keyword options; the core algorithm is only forced for plain energies.
    if options:
        options = check_options(options)
        if deriv_order == 0:
            options['integral_algo'] = 'libint_core'
    else:
        options = check_options({})
    print("Using integral method: {}".format(options['integral_algo']))
    print("Number of OMP Threads: {}".format(psi4.core.get_num_threads()))

    # Load molecule data
    geom2d = np.asarray(molecule.geometry())
    geom = jnp.asarray(geom2d.flatten())
    xyz_file_name = "geom.xyz"
    molecule.save_xyz_file(xyz_file_name, True)
    xyz_path = os.path.abspath(os.getcwd()) + "/" + xyz_file_name
    charge = molecule.molecular_charge()
    nuclear_charges = jnp.asarray([molecule.charge(i) for i in range(geom2d.shape[0])])
    nelectrons = int(jnp.sum(nuclear_charges)) - charge
    nfrzn = n_frozen_core(molecule, charge) if options['freeze_core'] else 0

    basis_set = psi4.core.BasisSet.build(molecule, 'BASIS', basis_name, puream=0)
    print("Basis name: ", basis_set.name())
    print("Number of basis functions: ", basis_set.nbf())

    # mp2, ccsd, and ccsd(t) take identical argument tuples, so they share a branch.
    if method == 'scf' or method == 'hf' or method == 'rhf':
        args = (geom, basis_set, nelectrons, nuclear_charges, xyz_path)
    elif method in ('mp2', 'ccsd', 'ccsd(t)'):
        args = (geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path)
    elif method == 'mp2-f12':
        cabs_set = build_RIBS(molecule, basis_set, basis_name + '-cabs')
        args = (geom, basis_set, cabs_set, nelectrons, nfrzn, nuclear_charges, xyz_path)
    else:
        # Raise instead of printing so a typo cannot fall through to a
        # NameError on the undefined 'args' below.
        raise Exception("Desired electronic structure method not understood. Use 'scf' 'hf' 'mp2' 'ccsd' or 'ccsd(t)' ")

    return compute_standard(method, args, deriv_order=deriv_order, partial=partial, options=options)
def efield_deriv(molecule, basis_name, method, efield=None, efield_grad=None,
                 deriv_order=1, partial=None, options=None):
    """
    Compute derivatives of the electronic energy with respect to an external
    electric field (and field gradient, for quadrupole-type properties).

    Parameters
    ----------
    molecule : psi4.core.Molecule
        Psi4 molecule object supplying geometry, charge, and nuclear charges.
    basis_name : str
        Name of the Gaussian basis set.
    method : str
        'scf'/'hf'/'rhf', 'mp2', 'mp2-f12', 'ccsd', or 'ccsd(t)'.
    efield : numpy.ndarray
        External electric field; required.
    efield_grad : numpy.ndarray, optional
        Electric field gradient; supply for quadrupole computations.
    deriv_order : int
        Order of field differentiation.
    partial : tuple of int, optional
        Field component indices for a single partial derivative.
    options : dict, optional
        Keyword options; validated by check_options.

    Returns
    -------
    Derivative tensor (or partial derivative) from compute_standard.

    Raises
    ------
    Exception
        When the required field inputs are missing or are not numpy arrays,
        or when the method is not recognized.
    """
    if efield is None and efield_grad is None:
        raise Exception("Electric field and its gradient must be given for quadrupole computation.")
    elif efield is None:
        raise Exception("Electric field must be given for dipole computation.")

    # BUG FIX: with options=None the old bare try/except crashed with an
    # uncaught TypeError when assigning options['electric_field']; normalize
    # to a dict first and test key presence explicitly.
    if options is None:
        options = {}
    if 'electric_field' not in options:
        if isinstance(efield, np.ndarray) and isinstance(efield_grad, np.ndarray):
            options['electric_field'] = 2  # field + field gradient supplied
        elif isinstance(efield, np.ndarray):
            options['electric_field'] = 1  # field only
        else:
            raise Exception("Electric field and its gradient must be given as numpy arrays.")

    # Set keyword options (options now always holds at least 'electric_field').
    options = check_options(options)
    if deriv_order == 0:
        options['integral_algo'] = 'libint_core'

    print("Using integral method: {}".format(options['integral_algo']))
    print("Number of OMP Threads: {}".format(psi4.core.get_num_threads()))

    # Load molecule data
    geom2d = np.asarray(molecule.geometry())
    geom = jnp.asarray(geom2d.flatten())
    xyz_file_name = "geom.xyz"
    molecule.save_xyz_file(xyz_file_name, True)
    xyz_path = os.path.abspath(os.getcwd()) + "/" + xyz_file_name
    charge = molecule.molecular_charge()
    nuclear_charges = jnp.asarray([molecule.charge(i) for i in range(geom2d.shape[0])])
    nelectrons = int(jnp.sum(nuclear_charges)) - charge
    nfrzn = n_frozen_core(molecule, charge) if options['freeze_core'] else 0

    basis_set = psi4.core.BasisSet.build(molecule, 'BASIS', basis_name, puream=0)
    print("Basis name: ", basis_set.name())
    print("Number of basis functions: ", basis_set.nbf())

    if options['electric_field'] == 2:
        # NOTE(review): the order here is (efield_grad, efield) while
        # mixed_deriv builds (efield, efield_grad) — confirm which order the
        # energy functions expect; kept unchanged to preserve behavior.
        args = (efield_grad, efield)
    else:
        args = (efield,)

    # mp2, ccsd, and ccsd(t) take identical argument tuples, so they share a branch.
    if method == 'scf' or method == 'hf' or method == 'rhf':
        args += (geom, basis_set, nelectrons, nuclear_charges, xyz_path)
    elif method in ('mp2', 'ccsd', 'ccsd(t)'):
        args += (geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path)
    elif method == 'mp2-f12':
        cabs_set = build_RIBS(molecule, basis_set, basis_name + '-cabs')
        args += (geom, basis_set, cabs_set, nelectrons, nfrzn, nuclear_charges, xyz_path)
    else:
        # Raise instead of printing so a typo cannot fall through to a
        # NameError on the partially-built 'args' below.
        raise Exception("Desired electronic structure method not understood. Use 'scf' 'hf' 'mp2' 'ccsd' or 'ccsd(t)' ")

    return compute_standard(method, args, deriv_order=deriv_order, partial=partial, options=options)
def mixed_deriv(molecule, basis_name, method, efield=None, efield_grad=None,
                deriv_order_F=1, deriv_order_R=1, partial_F=None, partial_R=None, options=None):
    """
    Compute mixed electric-field/geometry derivatives of the electronic energy.

    Parameters
    ----------
    molecule : psi4.core.Molecule
        Psi4 molecule object supplying geometry, charge, and nuclear charges.
    basis_name : str
        Name of the Gaussian basis set.
    method : str
        'scf'/'hf'/'rhf', 'mp2', 'mp2-f12', 'ccsd', or 'ccsd(t)'.
    efield : numpy.ndarray
        External electric field; required.
    efield_grad : numpy.ndarray, optional
        Electric field gradient; supply for quadrupole computations.
    deriv_order_F, deriv_order_R : int
        Orders of field and geometric differentiation; both must be nonzero.
    partial_F, partial_R : sequence of int, optional
        Component indices for a single mixed partial derivative.
    options : dict, optional
        Keyword options; validated by check_options.

    Returns
    -------
    Mixed derivative tensor (or partial derivative) from compute_mixed.

    Raises
    ------
    Exception
        When either differentiation order is zero, required field inputs are
        missing or not numpy arrays, or the method is not recognized.
    """
    if deriv_order_F == 0 or deriv_order_R == 0:
        raise Exception("Error: Order of differentiation cannot equal zero. Use energy or geometry_deriv or electric_field instead.")

    if efield is None and efield_grad is None:
        raise Exception("Electric field and its gradient must be given for quadrupole computation.")
    elif efield is None:
        raise Exception("Electric field must be given for dipole computation.")

    # BUG FIX: with options=None the old bare try/except crashed with an
    # uncaught TypeError when assigning options['electric_field']; normalize
    # to a dict first and test key presence explicitly.
    if options is None:
        options = {}
    if 'electric_field' not in options:
        if isinstance(efield, np.ndarray) and isinstance(efield_grad, np.ndarray):
            options['electric_field'] = 2  # field + field gradient supplied
        elif isinstance(efield, np.ndarray):
            options['electric_field'] = 1  # field only
        else:
            raise Exception("Electric field and its gradient must be given as numpy arrays.")

    # Set keyword options. (The old `deriv_order_F == 0 and deriv_order_R == 0`
    # integral_algo override was unreachable after the zero-order guard above
    # and has been removed.)
    options = check_options(options)

    print("Using integral method: {}".format(options['integral_algo']))
    print("Number of OMP Threads: {}".format(psi4.core.get_num_threads()))

    # Load molecule data
    geom2d = np.asarray(molecule.geometry())
    geom = jnp.asarray(geom2d.flatten())
    xyz_file_name = "geom.xyz"
    molecule.save_xyz_file(xyz_file_name, True)
    xyz_path = os.path.abspath(os.getcwd()) + "/" + xyz_file_name
    charge = molecule.molecular_charge()
    nuclear_charges = jnp.asarray([molecule.charge(i) for i in range(geom2d.shape[0])])
    nelectrons = int(jnp.sum(nuclear_charges)) - charge
    nfrzn = n_frozen_core(molecule, charge) if options['freeze_core'] else 0

    basis_set = psi4.core.BasisSet.build(molecule, 'BASIS', basis_name, puream=0)
    print("Basis name: ", basis_set.name())
    print("Number of basis functions: ", basis_set.nbf())

    if options['electric_field'] == 2:
        # NOTE(review): the order here is (efield, efield_grad) while
        # efield_deriv builds (efield_grad, efield) — confirm the intended
        # convention; kept unchanged to preserve behavior.
        args = (efield, efield_grad)
    else:
        args = (efield,)

    # mp2, ccsd, and ccsd(t) take identical argument tuples, so they share a branch.
    if method == 'scf' or method == 'hf' or method == 'rhf':
        args += (geom, basis_set, nelectrons, nuclear_charges, xyz_path)
    elif method in ('mp2', 'ccsd', 'ccsd(t)'):
        args += (geom, basis_set, nelectrons, nfrzn, nuclear_charges, xyz_path)
    elif method == 'mp2-f12':
        cabs_set = build_RIBS(molecule, basis_set, basis_name + '-cabs')
        args += (geom, basis_set, cabs_set, nelectrons, nfrzn, nuclear_charges, xyz_path)
    else:
        # Raise instead of printing so a typo cannot fall through to a
        # NameError on the partially-built 'args' below.
        raise Exception("Desired electronic structure method not understood. Use 'scf' 'hf' 'mp2' 'ccsd' or 'ccsd(t)' ")

    return compute_mixed(method, args, deriv_order_F=deriv_order_F, deriv_order_R=deriv_order_R,
                         partial_F=partial_F, partial_R=partial_R, options=options)