├── .flake8 ├── .gitattributes ├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE.md ├── PULL_REQUEST_TEMPLATE.md ├── SECURITY.md └── workflows │ ├── pypi.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── .readthedocs.yaml ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── aehmc ├── __init__.py ├── algorithms.py ├── hmc.py ├── integrators.py ├── mass_matrix.py ├── metrics.py ├── nuts.py ├── proposals.py ├── step_size.py ├── termination.py ├── trajectory.py ├── utils.py └── window_adaptation.py ├── conftest.py ├── environment.yml ├── examples ├── LinearRegression.ipynb └── requirements.txt ├── pyproject.toml ├── requirements.txt └── tests ├── __init__.py ├── test_adaptation.py ├── test_algorithms.py ├── test_hmc.py ├── test_integrators.py ├── test_mass_matrix.py ├── test_metrics.py ├── test_step_size.py ├── test_termination.py ├── test_trajectory.py └── test_utils.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | select = C,E,F,W 4 | ignore = E203,E231,E501,E741,W503,W504,C901 5 | per-file-ignores = 6 | **/__init__.py:F401,E402,F403 7 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | aehmc/_version.py export-subst 2 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description of your problem or feature request 2 | 3 | **Please provide a minimal, self-contained, and reproducible example.** 4 | ```python 5 | [Your code here] 6 | ``` 7 | 8 | **Please provide the full traceback of any errors.** 9 | ```python 10 | [The error output here] 11 | ``` 12 | 13 | **Please provide any additional information below.** 14 | 15 | 16 | ## Versions and main components 17 | 18 | * Aesara version: 19 | * Aesara config (`python -c "import aesara; print(aesara.config)"`) 20 | * Python version: 21 | * Operating system: 22 | * How did you install Aesara: (conda/pip) 23 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | **Thank you for opening a PR!** 3 | 4 | Here are a few important guidelines and requirements to check before your PR can be merged: 5 | + [ ] There is an informative high-level description of the changes. 6 | + [ ] The description and/or commit message(s) references the relevant GitHub issue(s). 7 | + [ ] [`pre-commit`](https://pre-commit.com/#installation) is installed and [set up](https://pre-commit.com/#3-install-the-git-hook-scripts). 8 | + [ ] The commit messages follow [these guidelines](https://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html). 9 | + [ ] The commits correspond to [_relevant logical changes_](https://wiki.openstack.org/wiki/GitCommitMessages#Structural_split_of_changes), and there are **no commits that fix changes introduced by other commits in the same branch/BR**. If your commit description starts with "Fix...", then you're probably making this mistake. 
10 | + [ ] There are tests covering the changes introduced in the PR. 11 | 12 | Don't worry, your PR doesn't need to be in perfect order to submit it. As development progresses and/or reviewers request changes, you can always [rewrite the history](https://git-scm.com/book/en/v2/Git-Tools-Rewriting-History#_rewriting_history) of your feature/PR branches. 13 | 14 | If your PR is an ongoing effort and you would like to involve us in the process, simply make it a [draft PR](https://docs.github.com/en/free-pro-team@latest/github/collaborating-with-issues-and-pull-requests/about-pull-requests#draft-pull-requests). 15 | -------------------------------------------------------------------------------- /.github/SECURITY.md: -------------------------------------------------------------------------------- 1 | To report a security vulnerability to Aesara, please go to 2 | https://tidelift.com/security and see the instructions there. 3 | 4 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: PyPI 2 | on: 3 | push: 4 | branches: 5 | - main 6 | - auto-release 7 | pull_request: 8 | branches: [main] 9 | release: 10 | types: [published] 11 | 12 | # Cancels all previous workflow runs for pull requests that have not completed. 13 | concurrency: 14 | # The concurrency group contains the workflow name and the branch name for pull requests 15 | # or the commit hash for any other events. 16 | group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.head_ref || github.sha }} 17 | cancel-in-progress: true 18 | 19 | jobs: 20 | build: 21 | name: Build source distribution 22 | runs-on: ubuntu-latest 23 | steps: 24 | - uses: actions/checkout@v3 25 | with: 26 | fetch-depth: 0 27 | - uses: actions/setup-python@v4 28 | with: 29 | python-version: "3.8" 30 | - name: Build the sdist and wheel 31 | run: | 32 | python -m pip install --upgrade pip build 33 | python -m build 34 | - name: Check the sdist installs and imports 35 | run: | 36 | python -m venv venv-sdist 37 | # Since the whl distribution is build using sdist, it suffices 38 | # to only test the wheel installation to ensure both function as expected. 39 | venv-sdist/bin/python -m pip install dist/aehmc-*.whl 40 | venv-sdist/bin/python -c "import aehmc;print(aehmc.__version__)" 41 | - uses: actions/upload-artifact@v3 42 | with: 43 | name: artifact 44 | path: dist 45 | if-no-files-found: error 46 | 47 | upload_pypi: 48 | name: Upload to PyPI on release 49 | needs: [build] 50 | runs-on: ubuntu-latest 51 | if: github.event_name == 'release' && github.event.action == 'published' 52 | steps: 53 | - uses: actions/download-artifact@v2 54 | with: 55 | name: artifact 56 | path: dist 57 | - uses: pypa/gh-action-pypi-publish@release/v1 58 | with: 59 | user: __token__ 60 | password: ${{ secrets.pypi_secret }} 61 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - checks 8 | pull_request: 9 | branches: 10 | - main 11 | 12 | # Cancels all previous workflow runs for pull requests that have not completed. 13 | concurrency: 14 | # The concurrency group contains the workflow name and the branch name for pull requests 15 | # or the commit hash for any other events. 
16 | group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.head_ref || github.sha }} 17 | cancel-in-progress: true 18 | 19 | jobs: 20 | changes: 21 | name: "Check for changes" 22 | runs-on: ubuntu-latest 23 | outputs: 24 | changes: ${{ steps.changes.outputs.src }} 25 | steps: 26 | - uses: actions/checkout@v3 27 | with: 28 | fetch-depth: 0 29 | - uses: dorny/paths-filter@v2 30 | id: changes 31 | with: 32 | filters: | 33 | python: &python 34 | - 'aehmc/**/*.py' 35 | - 'tests/**/*.py' 36 | - 'aehmc/**/*.pyx' 37 | - 'tests/**/*.pyx' 38 | - '*.py' 39 | src: 40 | - *python 41 | - 'aehmc/**/*.c' 42 | - 'tests/**/*.c' 43 | - 'aehmc/**/*.h' 44 | - 'tests/**/*.h' 45 | - '.github/workflows/*.yml' 46 | - 'setup.cfg' 47 | - 'requirements.txt' 48 | - '.coveragerc' 49 | 50 | 51 | style: 52 | name: Check code style 53 | needs: changes 54 | runs-on: ubuntu-latest 55 | if: ${{ needs.changes.outputs.changes == 'true' }} 56 | steps: 57 | - uses: actions/checkout@v3 58 | - uses: actions/setup-python@v4 59 | - uses: pre-commit/action@v2.0.0 60 | 61 | test: 62 | name: "Test py${{ matrix.python-version }}: ${{ matrix.part }}" 63 | needs: 64 | - changes 65 | - style 66 | runs-on: ubuntu-latest 67 | if: ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }} 68 | strategy: 69 | fail-fast: true 70 | matrix: 71 | python-version: ["3.8", "3.10"] 72 | fast-compile: [0] 73 | float32: [0] 74 | part: 75 | - "tests" 76 | 77 | steps: 78 | - uses: actions/checkout@v3 79 | with: 80 | fetch-depth: 0 81 | - name: Set up Python ${{ matrix.python-version }} 82 | uses: conda-incubator/setup-miniconda@v2 83 | with: 84 | mamba-version: "*" 85 | channels: conda-forge,defaults 86 | channel-priority: true 87 | python-version: ${{ matrix.python-version }} 88 | auto-update-conda: true 89 | 90 | - name: Create matrix id 91 | id: matrix-id 92 | env: 93 | MATRIX_CONTEXT: ${{ toJson(matrix) }} 94 | run: | 95 | echo $MATRIX_CONTEXT 96 | export MATRIX_ID=`echo $MATRIX_CONTEXT | md5sum | cut -c 1-32` 97 | echo $MATRIX_ID 98 | echo "::set-output name=id::$MATRIX_ID" 99 | 100 | - name: Install dependencies 101 | shell: bash -l {0} 102 | run: | 103 | mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service "aesara>=2.8.11" 104 | pip install -q -r requirements.txt 105 | mamba list && pip freeze 106 | python -c 'import aesara; print(aesara.config.__str__(print_doc=False))' 107 | python -c 'import aesara; assert(aesara.config.blas__ldflags != "")' 108 | env: 109 | PYTHON_VERSION: ${{ matrix.python-version }} 110 | 111 | - name: Run tests 112 | shell: bash -l {0} 113 | run: | 114 | if [[ $FAST_COMPILE == "1" ]]; then export AESARA_FLAGS=$AESARA_FLAGS,mode=FAST_COMPILE; fi 115 | if [[ $FLOAT32 == "1" ]]; then export AESARA_FLAGS=$AESARA_FLAGS,floatX=float32; fi 116 | export AESARA_FLAGS=$AESARA_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe 117 | python -m pytest -x -r A --verbose --cov=aehmc --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART 118 | env: 119 | MATRIX_ID: ${{ steps.matrix-id.outputs.id }} 120 | MKL_THREADING_LAYER: GNU 121 | MKL_NUM_THREADS: 1 122 | OMP_NUM_THREADS: 1 123 | PART: ${{ matrix.part }} 124 | FAST_COMPILE: ${{ matrix.fast-compile }} 125 | FLOAT32: ${{ matrix.float32 }} 126 | 127 | - name: Upload coverage file 128 | uses: actions/upload-artifact@v2 129 | with: 130 | name: coverage 131 | path: coverage/coverage-${{ steps.matrix-id.outputs.id }}.xml 132 | 133 | all-checks: 
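    # Gate job (a descriptive note, inferred from the step below): provides a
    # single status check for branch protection; the "Check build matrix
    # status" step fails it whenever the style or test jobs did not succeed.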
134 | if: ${{ always() }} 135 | runs-on: ubuntu-latest 136 | name: "All tests" 137 | needs: [changes, style, test] 138 | steps: 139 | - name: Check build matrix status 140 | if: ${{ needs.changes.outputs.changes == 'true' && (needs.style.result != 'success' || needs.test.result != 'success') }} 141 | run: exit 1 142 | 143 | upload-coverage: 144 | runs-on: ubuntu-latest 145 | name: "Upload coverage" 146 | needs: [changes, all-checks] 147 | if: ${{ needs.changes.outputs.changes == 'true' && needs.all-checks.result == 'success' }} 148 | steps: 149 | - uses: actions/checkout@v3 150 | 151 | - name: Set up Python 152 | uses: actions/setup-python@v4 153 | with: 154 | python-version: 3.8 155 | 156 | - name: Install dependencies 157 | run: | 158 | python -m pip install -U coverage>=5.1 coveralls 159 | 160 | - name: Download coverage file 161 | uses: actions/download-artifact@v2 162 | with: 163 | name: coverage 164 | path: coverage 165 | 166 | - name: Upload coverage to Codecov 167 | uses: codecov/codecov-action@v1 168 | with: 169 | directory: ./coverage/ 170 | fail_ci_if_error: true 171 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/vim,emacs,python 2 | # Edit at https://www.gitignore.io/?templates=vim,emacs,python 3 | 4 | ### Emacs ### 5 | # -*- mode: gitignore; -*- 6 | *~ 7 | \#*\# 8 | /.emacs.desktop 9 | /.emacs.desktop.lock 10 | *.elc 11 | auto-save-list 12 | tramp 13 | .\#* 14 | 15 | # Org-mode 16 | .org-id-locations 17 | *_archive 18 | 19 | # flymake-mode 20 | *_flymake.* 21 | 22 | # eshell files 23 | /eshell/history 24 | /eshell/lastdir 25 | 26 | # elpa packages 27 | /elpa/ 28 | 29 | # reftex files 30 | *.rel 31 | 32 | # AUCTeX auto folder 33 | /auto/ 34 | 35 | # cask packages 36 | .cask/ 37 | dist/ 38 | 39 | # Flycheck 40 | flycheck_*.el 41 | 42 | # server auth directory 43 | /server/ 44 | 45 | # projectiles files 46 | .projectile 47 | 48 | # directory configuration 49 | .dir-locals.el 50 | 51 | # network security 52 | /network-security.data 53 | 54 | 55 | ### Python ### 56 | # Byte-compiled / optimized / DLL files 57 | __pycache__/ 58 | *.py[cod] 59 | *$py.class 60 | 61 | # C extensions 62 | *.so 63 | 64 | # Distribution / packaging 65 | .Python 66 | build/ 67 | develop-eggs/ 68 | downloads/ 69 | eggs/ 70 | .eggs/ 71 | lib/ 72 | lib64/ 73 | parts/ 74 | sdist/ 75 | var/ 76 | wheels/ 77 | pip-wheel-metadata/ 78 | share/python-wheels/ 79 | *.egg-info/ 80 | .installed.cfg 81 | *.egg 82 | MANIFEST 83 | 84 | # PyInstaller 85 | # Usually these files are written by a python script from a template 86 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
87 | *.manifest 88 | *.spec 89 | 90 | # Installer logs 91 | pip-log.txt 92 | pip-delete-this-directory.txt 93 | 94 | # Unit test / coverage reports 95 | htmlcov/ 96 | .tox/ 97 | .nox/ 98 | .coverage 99 | .coverage.* 100 | .cache 101 | nosetests.xml 102 | coverage.xml 103 | *.cover 104 | .hypothesis/ 105 | .pytest_cache/ 106 | testing-report.html 107 | 108 | # Translations 109 | *.mo 110 | *.pot 111 | 112 | # Django stuff: 113 | *.log 114 | local_settings.py 115 | db.sqlite3 116 | db.sqlite3-journal 117 | 118 | # Flask stuff: 119 | instance/ 120 | .webassets-cache 121 | 122 | # Scrapy stuff: 123 | .scrapy 124 | 125 | # Sphinx documentation 126 | docs/_build/ 127 | 128 | # PyBuilder 129 | target/ 130 | 131 | # Jupyter Notebook 132 | .ipynb_checkpoints 133 | 134 | # IPython 135 | profile_default/ 136 | ipython_config.py 137 | .ipynb_checkpoints 138 | 139 | # pyenv 140 | .python-version 141 | 142 | # pipenv 143 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 144 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 145 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 146 | # install all needed dependencies. 147 | #Pipfile.lock 148 | 149 | # celery beat schedule file 150 | celerybeat-schedule 151 | 152 | # SageMath parsed files 153 | *.sage.py 154 | 155 | # Environments 156 | .env 157 | .venv 158 | env/ 159 | venv/ 160 | ENV/ 161 | env.bak/ 162 | venv.bak/ 163 | 164 | # Spyder project settings 165 | .spyderproject 166 | .spyproject 167 | 168 | # Rope project settings 169 | .ropeproject 170 | 171 | # mkdocs documentation 172 | /site 173 | 174 | # mypy 175 | .mypy_cache/ 176 | .dmypy.json 177 | dmypy.json 178 | 179 | # Pyre type checker 180 | .pyre/ 181 | 182 | # Pycharm/IntelliJ 183 | .idea 184 | *.iml 185 | 186 | ### Vim ### 187 | # Swap 188 | [._]*.s[a-v][a-z] 189 | [._]*.sw[a-p] 190 | [._]s[a-rt-v][a-z] 191 | [._]ss[a-gi-z] 192 | [._]sw[a-p] 193 | 194 | # Session 195 | Session.vim 196 | Sessionx.vim 197 | 198 | # Temporary 199 | .netrwhist 200 | # Auto-generated tag files 201 | tags 202 | # Persistent undo 203 | [._]*.un~ 204 | 205 | ### OS X ### 206 | .DS_Store 207 | 208 | ### Visual Studio / VSCode ### 209 | .vs/ 210 | .vscode/ 211 | 212 | # End of https://www.gitignore.io/api/vim,emacs,python 213 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: | 2 | (?x)^( 3 | versioneer\.py| 4 | aehmc/_version\.py| 5 | doc/.*| 6 | bin/.* 7 | )$ 8 | repos: 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v4.4.0 11 | hooks: 12 | - id: debug-statements 13 | - id: check-merge-conflict 14 | - repo: https://github.com/psf/black 15 | rev: 23.3.0 16 | hooks: 17 | - id: black 18 | language_version: python3 19 | - repo: https://github.com/pycqa/flake8 20 | rev: 6.0.0 21 | hooks: 22 | - id: flake8 23 | - repo: https://github.com/pycqa/isort 24 | rev: 5.12.0 25 | hooks: 26 | - id: isort 27 | - repo: https://github.com/pre-commit/mirrors-mypy 28 | rev: v1.2.0 29 | hooks: 30 | - id: mypy 31 | args: [--ignore-missing-imports] 32 | - repo: https://github.com/humitos/mirrors-autoflake.git 33 | rev: v1.1 34 | hooks: 35 | - id: autoflake 36 | exclude: | 37 | (?x)^( 38 | .*/?__init__\.py 39 | )$ 40 | args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable'] 41 | 
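# Usage note (assuming a standard pre-commit setup): run `pre-commit install`
# once to install the git hooks, and `pre-commit run --all-files` to apply
# all of the hooks configured above to the entire repository.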
-------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | # Use multiple processes to speed up Pylint. 3 | jobs=0 4 | 5 | # Allow loading of arbitrary C extensions. Extensions are imported into the 6 | # active Python interpreter and may run arbitrary code. 7 | unsafe-load-any-extension=no 8 | 9 | # Allow optimization of some AST trees. This will activate a peephole AST 10 | # optimizer, which will apply various small optimizations. For instance, it can 11 | # be used to obtain the result of joining multiple strings with the addition 12 | # operator. Joining a lot of strings can lead to a maximum recursion error in 13 | # Pylint and this flag can prevent that. It has one side effect, the resulting 14 | # AST will be different than the one from reality. 15 | optimize-ast=no 16 | 17 | [MESSAGES CONTROL] 18 | 19 | # Only show warnings with the listed confidence levels. Leave empty to show 20 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 21 | confidence= 22 | 23 | # Disable the message, report, category or checker with the given id(s). You 24 | # can either give multiple identifiers separated by comma (,) or put this 25 | # option multiple times (only on the command line, not in the configuration 26 | # file where it should appear only once).You can also use "--disable=all" to 27 | # disable everything first and then reenable specific checks. For example, if 28 | # you want to run only the similarities checker, you can use "--disable=all 29 | # --enable=similarities". If you want to run only the classes checker, but have 30 | # no Warning level messages displayed, use"--disable=all --enable=classes 31 | # --disable=W" 32 | disable=all 33 | 34 | # Enable the message, report, category or checker with the given id(s). You can 35 | # either give multiple identifier separated by comma (,) or put this option 36 | # multiple time. See also the "--disable" option for examples. 37 | enable=import-error, 38 | import-self, 39 | reimported, 40 | wildcard-import, 41 | misplaced-future, 42 | relative-import, 43 | deprecated-module, 44 | unpacking-non-sequence, 45 | invalid-all-object, 46 | undefined-all-variable, 47 | used-before-assignment, 48 | cell-var-from-loop, 49 | global-variable-undefined, 50 | dangerous-default-value, 51 | # redefined-builtin, 52 | redefine-in-handler, 53 | unused-import, 54 | unused-wildcard-import, 55 | global-variable-not-assigned, 56 | undefined-loop-variable, 57 | global-at-module-level, 58 | bad-open-mode, 59 | redundant-unittest-assert, 60 | boolean-datetime, 61 | # unused-variable 62 | 63 | 64 | [REPORTS] 65 | 66 | # Set the output format. Available formats are text, parseable, colorized, msvs 67 | # (visual studio) and html. You can also give a reporter class, eg 68 | # mypackage.mymodule.MyReporterClass. 69 | output-format=parseable 70 | 71 | # Put messages in a separate file for each module / package specified on the 72 | # command line instead of printing them on stdout. Reports (if any) will be 73 | # written in a file name "pylint_global.[txt|html]". 74 | files-output=no 75 | 76 | # Tells whether to display a full report or only the messages 77 | reports=no 78 | 79 | # Python expression which should return a note less than 10 (10 is the highest 80 | # note). 
You have access to the variables errors warning, statement which 81 | # respectively contain the number of errors / warnings messages and the total 82 | # number of statements analyzed. This is used by the global evaluation report 83 | # (RP0004). 84 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 85 | 86 | [BASIC] 87 | 88 | # List of builtins function names that should not be used, separated by a comma 89 | bad-functions=map,filter,input 90 | 91 | # Good variable names which should always be accepted, separated by a comma 92 | good-names=i,j,k,ex,Run,_ 93 | 94 | # Bad variable names which should always be refused, separated by a comma 95 | bad-names=foo,bar,baz,toto,tutu,tata 96 | 97 | # Colon-delimited sets of names that determine each other's naming style when 98 | # the name regexes allow several styles. 99 | name-group= 100 | 101 | # Include a hint for the correct naming format with invalid-name 102 | include-naming-hint=yes 103 | 104 | # Regular expression matching correct method names 105 | method-rgx=[a-z_][a-z0-9_]{2,30}$ 106 | 107 | # Naming hint for method names 108 | method-name-hint=[a-z_][a-z0-9_]{2,30}$ 109 | 110 | # Regular expression matching correct function names 111 | function-rgx=[a-z_][a-z0-9_]{2,30}$ 112 | 113 | # Naming hint for function names 114 | function-name-hint=[a-z_][a-z0-9_]{2,30}$ 115 | 116 | # Regular expression matching correct module names 117 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 118 | 119 | # Naming hint for module names 120 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 121 | 122 | # Regular expression matching correct attribute names 123 | attr-rgx=[a-z_][a-z0-9_]{2,30}$ 124 | 125 | # Naming hint for attribute names 126 | attr-name-hint=[a-z_][a-z0-9_]{2,30}$ 127 | 128 | # Regular expression matching correct class attribute names 129 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ 130 | 131 | # Naming hint for class attribute names 132 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$ 133 | 134 | # Regular expression matching correct constant names 135 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 136 | 137 | # Naming hint for constant names 138 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 139 | 140 | # Regular expression matching correct class names 141 | class-rgx=[A-Z_][a-zA-Z0-9]+$ 142 | 143 | # Naming hint for class names 144 | class-name-hint=[A-Z_][a-zA-Z0-9]+$ 145 | 146 | # Regular expression matching correct argument names 147 | argument-rgx=[a-z_][a-z0-9_]{2,30}$ 148 | 149 | # Naming hint for argument names 150 | argument-name-hint=[a-z_][a-z0-9_]{2,30}$ 151 | 152 | # Regular expression matching correct inline iteration names 153 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ 154 | 155 | # Naming hint for inline iteration names 156 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$ 157 | 158 | # Regular expression matching correct variable names 159 | variable-rgx=[a-z_][a-z0-9_]{2,30}$ 160 | 161 | # Naming hint for variable names 162 | variable-name-hint=[a-z_][a-z0-9_]{2,30}$ 163 | 164 | # Regular expression which should only match function or class names that do 165 | # not require a docstring. 166 | no-docstring-rgx=^_ 167 | 168 | # Minimum line length for functions/classes that require docstrings, shorter 169 | # ones are exempt. 
170 | docstring-min-length=-1 171 | 172 | 173 | [ELIF] 174 | 175 | # Maximum number of nested blocks for function / method body 176 | max-nested-blocks=5 177 | 178 | 179 | [FORMAT] 180 | 181 | # Maximum number of characters on a single line. 182 | max-line-length=100 183 | 184 | # Regexp for a line that is allowed to be longer than the limit. 185 | ignore-long-lines=^\s*(# )??$ 186 | 187 | # Allow the body of an if to be on the same line as the test if there is no 188 | # else. 189 | single-line-if-stmt=no 190 | 191 | # List of optional constructs for which whitespace checking is disabled. `dict- 192 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 193 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 194 | # `empty-line` allows space-only lines. 195 | no-space-check=trailing-comma,dict-separator 196 | 197 | # Maximum number of lines in a module 198 | max-module-lines=1000 199 | 200 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 201 | # tab). 202 | indent-string=' ' 203 | 204 | # Number of spaces of indent required inside a hanging or continued line. 205 | indent-after-paren=4 206 | 207 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 208 | expected-line-ending-format= 209 | 210 | 211 | [LOGGING] 212 | 213 | # Logging modules to check that the string format arguments are in logging 214 | # function parameter format 215 | logging-modules=logging 216 | 217 | 218 | [MISCELLANEOUS] 219 | 220 | # List of note tags to take in consideration, separated by a comma. 221 | notes=FIXME,XXX,TODO 222 | 223 | 224 | [SIMILARITIES] 225 | 226 | # Minimum lines number of a similarity. 227 | min-similarity-lines=4 228 | 229 | # Ignore comments when computing similarities. 230 | ignore-comments=yes 231 | 232 | # Ignore docstrings when computing similarities. 233 | ignore-docstrings=yes 234 | 235 | # Ignore imports when computing similarities. 236 | ignore-imports=no 237 | 238 | 239 | [SPELLING] 240 | 241 | # Spelling dictionary name. Available dictionaries: none. To make it working 242 | # install python-enchant package. 243 | spelling-dict= 244 | 245 | # List of comma separated words that should not be checked. 246 | spelling-ignore-words= 247 | 248 | # A path to a file that contains private dictionary; one word per line. 249 | spelling-private-dict-file= 250 | 251 | # Tells whether to store unknown words to indicated private dictionary in 252 | # --spelling-private-dict-file option instead of raising a message. 253 | spelling-store-unknown-words=no 254 | 255 | 256 | [TYPECHECK] 257 | 258 | # Tells whether missing members accessed in mixin class should be ignored. A 259 | # mixin class is detected if its name ends with "mixin" (case insensitive). 260 | ignore-mixin-members=yes 261 | 262 | # List of module names for which member attributes should not be checked 263 | # (useful for modules/projects where namespaces are manipulated during runtime 264 | # and thus existing member attributes cannot be deduced by static analysis. It 265 | # supports qualified module names, as well as Unix pattern matching. 266 | ignored-modules= 267 | 268 | # List of classes names for which member attributes should not be checked 269 | # (useful for classes with attributes dynamically set). This supports can work 270 | # with qualified names. 271 | ignored-classes= 272 | 273 | # List of members which are set dynamically and missed by pylint inference 274 | # system, and so shouldn't trigger E1101 when accessed. 
Python regular 275 | # expressions are accepted. 276 | generated-members= 277 | 278 | 279 | [VARIABLES] 280 | 281 | # Tells whether we should check for unused import in __init__ files. 282 | init-import=no 283 | 284 | # A regular expression matching the name of dummy variables (i.e. expectedly 285 | # not used). 286 | dummy-variables-rgx=_$|dummy 287 | 288 | # List of additional names supposed to be defined in builtins. Remember that 289 | # you should avoid to define new builtins when possible. 290 | additional-builtins= 291 | 292 | # List of strings which can identify a callback function by name. A callback 293 | # name must start or end with one of those strings. 294 | callbacks=cb_,_cb 295 | 296 | 297 | [CLASSES] 298 | 299 | # List of method names used to declare (i.e. assign) instance attributes. 300 | defining-attr-methods=__init__,__new__,setUp 301 | 302 | # List of valid names for the first argument in a class method. 303 | valid-classmethod-first-arg=cls 304 | 305 | # List of valid names for the first argument in a metaclass class method. 306 | valid-metaclass-classmethod-first-arg=mcs 307 | 308 | # List of member names, which should be excluded from the protected access 309 | # warning. 310 | exclude-protected=_asdict,_fields,_replace,_source,_make 311 | 312 | 313 | [DESIGN] 314 | 315 | # Maximum number of arguments for function / method 316 | max-args=5 317 | 318 | # Argument names that match this expression will be ignored. Default to name 319 | # with leading underscore 320 | ignored-argument-names=_.* 321 | 322 | # Maximum number of locals for function / method body 323 | max-locals=15 324 | 325 | # Maximum number of return / yield for function / method body 326 | max-returns=6 327 | 328 | # Maximum number of branch for function / method body 329 | max-branches=12 330 | 331 | # Maximum number of statements in function / method body 332 | max-statements=50 333 | 334 | # Maximum number of parents for a class (see R0901). 335 | max-parents=7 336 | 337 | # Maximum number of attributes for a class (see R0902). 338 | max-attributes=7 339 | 340 | # Minimum number of public methods for a class (see R0903). 341 | min-public-methods=2 342 | 343 | # Maximum number of public methods for a class (see R0904). 344 | max-public-methods=20 345 | 346 | # Maximum number of boolean expressions in a if statement 347 | max-bool-expr=5 348 | 349 | 350 | [IMPORTS] 351 | 352 | # Deprecated modules which should not be used, separated by a comma 353 | deprecated-modules=optparse 354 | 355 | # Create a graph of every (i.e. internal and external) dependencies in the 356 | # given file (report RP0402 must not be disabled) 357 | import-graph= 358 | 359 | # Create a graph of external dependencies in the given file (report RP0402 must 360 | # not be disabled) 361 | ext-import-graph= 362 | 363 | # Create a graph of internal dependencies in the given file (report RP0402 must 364 | # not be disabled) 365 | int-import-graph= 366 | 367 | 368 | [EXCEPTIONS] 369 | 370 | # Exceptions that will emit a warning when being caught. Defaults to 371 | # "Exception" 372 | overgeneral-exceptions=Exception 373 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.8" 7 | 8 | python: 9 | install: 10 | - method: pip 11 | path: . 
12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021-2023 Aesara Developers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune .github 2 | prune examples 3 | exclude *.yaml 4 | exclude *.yml 5 | exclude .git* 6 | exclude *.py 7 | exclude *rc 8 | exclude .flake8 9 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help venv conda docker docstyle format style black test lint check coverage pypi 2 | .DEFAULT_GOAL = help 3 | 4 | PROJECT_NAME = aehmc 5 | PROJECT_DIR = aehmc/ 6 | PYTHON = python 7 | PIP = pip 8 | CONDA = conda 9 | SHELL = bash 10 | 11 | help: 12 | @printf "Usage:\n" 13 | @grep -E '^[a-zA-Z_-]+:.*?# .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?# "}; {printf "\033[1;34mmake %-10s\033[0m%s\n", $$1, $$2}' 14 | 15 | conda: # Set up a conda environment for development. 16 | @printf "Creating conda environment...\n" 17 | ${CONDA} create --yes --name ${PROJECT_NAME}-env python=3.8 18 | ( \ 19 | ${CONDA} activate ${PROJECT_NAME}-env; \ 20 | ${PIP} install -U pip; \ 21 | ${PIP} install -U setuptools wheel; \ 22 | ${PIP} install -r requirements.txt; \ 23 | ${CONDA} deactivate; \ 24 | ) 25 | @printf "\n\nConda environment created! \033[1;34mRun \`conda activate ${PROJECT_NAME}-env\` to activate it.\033[0m\n\n\n" 26 | 27 | venv: # Set up a Python virtual environment for development. 28 | @printf "Creating Python virtual environment...\n" 29 | rm -rf ${PROJECT_NAME}-venv 30 | ${PYTHON} -m venv ${PROJECT_NAME}-venv 31 | ( \ 32 | source ${PROJECT_NAME}-venv/bin/activate; \ 33 | ${PIP} install -U pip; \ 34 | ${PIP} install -U setuptools wheel; \ 35 | ${PIP} install -r requirements.txt; \ 36 | deactivate; \ 37 | ) 38 | @printf "\n\nVirtual environment created! \033[1;34mRun \`source ${PROJECT_NAME}-venv/bin/activate\` to activate it.\033[0m\n\n\n" 39 | 40 | docker: # Set up a Docker image for development. 
41 | @printf "Creating Docker image...\n" 42 | ${SHELL} ./scripts/container.sh --build 43 | 44 | docstyle: 45 | @printf "Checking documentation with pydocstyle...\n" 46 | pydocstyle ${PROJECT_DIR} 47 | @printf "\033[1;34mPydocstyle passes!\033[0m\n\n" 48 | 49 | format: 50 | @printf "Checking code style with black...\n" 51 | black --check ${PROJECT_DIR} tests/ 52 | @printf "\033[1;34mBlack passes!\033[0m\n\n" 53 | 54 | style: 55 | @printf "Checking code style with pylint...\n" 56 | pylint ${PROJECT_DIR} tests/ 57 | @printf "\033[1;34mPylint passes!\033[0m\n\n" 58 | 59 | black: # Format code in-place using black. 60 | black ${PROJECT_DIR} tests/ 61 | 62 | test: # Test code using pytest. 63 | pytest -v tests/ ${PROJECT_DIR} --cov=${PROJECT_DIR} --cov-report=xml --html=testing-report.html --self-contained-html 64 | 65 | coverage: test 66 | diff-cover coverage.xml --compare-branch=main --fail-under=100 67 | 68 | pypi: 69 | ${PYTHON} setup.py clean --all; \ 70 | ${PYTHON} setup.py rotate --match=.tar.gz,.whl,.egg,.zip --keep=0; \ 71 | ${PYTHON} setup.py sdist bdist_wheel; \ 72 | twine upload --skip-existing dist/*; 73 | 74 | lint: docstyle format style # Lint code using pydocstyle, black and pylint. 75 | 76 | check: lint test coverage # Both lint and test code. Runs `make lint` followed by `make test`. 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 
3 | # AeHMC
4 | 
5 | [![Pypi][pypi-badge]][pypi]
6 | [![Gitter][gitter-badge]][gitter]
7 | [![Discord][discord-badge]][discord]
8 | [![Twitter][twitter-badge]][twitter]
9 | 
10 | AeHMC provides implementations of the HMC and NUTS samplers in [Aesara](https://github.com/aesara-devs/aesara).
11 | 
12 | [Features](#features) •
13 | [Get Started](#get-started) •
14 | [Install](#install) •
15 | [Get help](#get-help) •
16 | [Contribute](#contribute)
17 | 
18 | 
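## Features

The navigation above links here; a short summary of what the repository contains (drawn from the module layout shown in the tree at the top):

- HMC and NUTS transition kernels built as Aesara graphs (`aehmc.hmc`, `aehmc.nuts`)
- Step size and mass matrix adaptation (`aehmc.step_size`, `aehmc.mass_matrix`, `aehmc.window_adaptation`)
- Reusable building blocks for custom samplers: symplectic integrators, metrics and dual averaging (`aehmc.integrators`, `aehmc.metrics`, `aehmc.algorithms`)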
19 | 20 | ## Get started 21 | 22 | ``` python 23 | import aesara 24 | from aesara import tensor as at 25 | from aesara.tensor.random.utils import RandomStream 26 | 27 | from aeppl import joint_logprob 28 | 29 | from aehmc import nuts 30 | 31 | # A simple normal distribution 32 | Y_rv = at.random.normal(0, 1) 33 | 34 | 35 | def logprob_fn(y): 36 | return joint_logprob(realized={Y_rv: y})[0] 37 | 38 | 39 | # Build the transition kernel 40 | srng = RandomStream(seed=0) 41 | kernel = nuts.new_kernel(srng, logprob_fn) 42 | 43 | # Compile a function that updates the chain 44 | y_vv = Y_rv.clone() 45 | initial_state = nuts.new_state(y_vv, logprob_fn) 46 | 47 | step_size = at.as_tensor(1e-2) 48 | inverse_mass_matrix=at.as_tensor(1.0) 49 | chain_info, updates = kernel(initial_state, step_size, inverse_mass_matrix) 50 | 51 | next_step_fn = aesara.function([y_vv], chain_info.state.position, updates=updates) 52 | 53 | print(next_step_fn(0)) 54 | # 1.1034719409361107 55 | ``` 56 | 57 | ## Install 58 | 59 | The latest release of AeHMC can be installed from PyPI using ``pip``: 60 | 61 | ``` bash 62 | pip install aehmc 63 | ``` 64 | 65 | Or via conda-forge: 66 | 67 | ``` bash 68 | conda install -c conda-forge aehmc 69 | ``` 70 | 71 | The current development branch of AeHMC can be installed from GitHub using ``pip``: 72 | 73 | ``` bash 74 | pip install git+https://github.com/aesara-devs/aehmc 75 | ``` 76 | 77 | ## Get help 78 | 79 | Report bugs by opening an [issue][issues]. If you have a question regarding the usage of AeHMC, start a [discussion][discussions]. For real-time feedback or more general chat about AeHMC use our [Discord server][discord] or [Gitter room][gitter]. 80 | 81 | ## Contribute 82 | 83 | AeHMC welcomes contributions. A good place to start contributing is by looking at the [issues][issues]. 84 | 85 | If you want to implement a new feature, open a [discussion][discussions] or come chat with us on [Discord][discord] or [Gitter][gitter]. 
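As a follow-up to the example in [Get started](#get-started): to draw a whole
chain rather than a single step, you can call the compiled update function in a
loop, feeding each draw back in as the next position. The RNG updates are
applied at every call, so successive calls produce distinct draws. A minimal
sketch, reusing `next_step_fn` from the example above:

``` python
samples = []
position = 0.0
for _ in range(1000):
    # Each call advances the chain by one NUTS step and updates the RNG state
    position = next_step_fn(position)
    samples.append(position)
```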
86 | 87 | [contributors]: https://github.com/aesara-devs/aehmc/graphs/contributors 88 | [contributors-badge]: https://img.shields.io/github/contributors/aesara-devs/aehmc?style=flat-square&logo=github&logoColor=white&color=ECEFF4 89 | [discussions]: https://github.com/aesara-devs/aehmc/discussions 90 | [downloads-badge]: https://img.shields.io/pypi/dm/aehmc?style=flat-square&logo=pypi&logoColor=white&color=8FBCBB 91 | [discord]: https://discord.gg/h3sjmPYuGJ 92 | [discord-badge]: https://img.shields.io/discord/1072170173785723041?color=81A1C1&logo=discord&logoColor=white&style=flat-square 93 | [gitter]: https://gitter.im/aesara-devs/aehmc 94 | [gitter-badge]: https://img.shields.io/gitter/room/aesara-devs/aehmc?color=81A1C1&logo=matrix&logoColor=white&style=flat-square 95 | [issues]: https://github.com/aesara-devs/aehmc/issues 96 | [releases]: https://github.com/aesara-devs/aehmc/releases 97 | [twitter]: https://twitter.com/AesaraDevs 98 | [twitter-badge]: https://img.shields.io/twitter/follow/AesaraDevs?style=social 99 | [pypi]: https://pypi.org/project/aehmc/ 100 | [pypi-badge]: https://img.shields.io/pypi/v/aehmc?color=ECEFF4&logo=python&logoColor=white&style=flat-square 101 | -------------------------------------------------------------------------------- /aehmc/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._version import __version__, __version_tuple__ 3 | except ImportError: # pragma: no cover 4 | raise RuntimeError( 5 | "Unable to find the version number that is generated when either building or " 6 | "installing from source. Please make sure that `aehmc` has been properly " 7 | "installed, e.g. with\n\n pip install -e .\n" 8 | ) 9 | -------------------------------------------------------------------------------- /aehmc/algorithms.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, NamedTuple, Tuple 2 | 3 | import aesara.tensor as at 4 | import numpy as np 5 | from aesara import config 6 | from aesara.tensor.var import TensorVariable 7 | 8 | 9 | class DualAveragingState(NamedTuple): 10 | step: TensorVariable 11 | iterates: TensorVariable 12 | iterates_avg: TensorVariable 13 | gradient_avg: TensorVariable 14 | shrinkage_pts: TensorVariable 15 | 16 | 17 | def dual_averaging( 18 | gamma: float = 0.05, t0: int = 10, kappa: float = 0.75 19 | ) -> Tuple[Callable, Callable]: 20 | """Dual averaging algorithm. 21 | 22 | Dual averaging is an algorithm for stochastic optimization that was 23 | originally proposed by Nesterov in [1]_. 24 | 25 | The update scheme we implement here is more elaborate than the one 26 | described in [1]_. We follow [2]_ and add the parameters `t_0` and `kappa` 27 | which respectively improves the stability of computations in early 28 | iterations and set how fast the algorithm should forget past iterates. 29 | 30 | The default values for the parameters are taken from the Stan implementation [3]_. 31 | 32 | Parameters 33 | ---------- 34 | gamma 35 | Controls the amount of shrinkage towards mu. 36 | t0 37 | Improves the stability of computations early on. 38 | kappa 39 | Controls how fast the algorithm should forget past iterates. 40 | 41 | References 42 | ---------- 43 | .. [1]: Nesterov, Y. (2009). Primal-dual subgradient methods for convex 44 | problems. Mathematical programming, 120(1), 221-259. 45 | 46 | .. [2]: Hoffman, M. D., & Gelman, A. (2014). 
The No-U-Turn sampler:
47 |        adaptively setting path lengths in Hamiltonian Monte Carlo. J. Mach. Learn.
48 |        Res., 15(1), 1593-1623.
49 | 
50 |     .. [3]: Carpenter, B., Gelman, A., Hoffman, M. D., Lee, D., Goodrich, B.,
51 |        Betancourt, M., ... & Riddell, A. (2017). Stan: A probabilistic programming
52 |        language. Journal of statistical software, 76(1), 1-32.
53 | 
54 |     """
55 | 
56 |     def init(mu: TensorVariable) -> DualAveragingState:
57 |         """
58 |         Initialize the dual averaging state using the shrinkage point.
59 | 
60 |         Parameters
61 |         ----------
62 |         mu
63 |             Chosen point towards which the successive iterates are shrunk.
64 | 
65 |         """
66 |         step = at.as_tensor(1, "step", dtype=np.int64)
67 |         gradient_avg = at.as_tensor(0, "gradient_avg", dtype=config.floatX)
68 |         x_t = at.as_tensor(0.0, "x_t", dtype=config.floatX)
69 |         x_avg = at.as_tensor(0.0, "x_avg", dtype=config.floatX)
70 |         return DualAveragingState(
71 |             step=step,
72 |             iterates=x_t,
73 |             iterates_avg=x_avg,
74 |             gradient_avg=gradient_avg,
75 |             shrinkage_pts=mu,
76 |         )
77 | 
78 |     def update(
79 |         gradient: TensorVariable, state: DualAveragingState
80 |     ) -> DualAveragingState:
81 |         """Update the state of the Dual Averaging algorithm.
82 | 
83 |         Parameters
84 |         ----------
85 |         gradient
86 |             The current value of the stochastic gradient. Replaced by a
87 |             statistic to optimize in the case of MCMC adaptation.
88 |         state
89 |             The current state of the dual averaging algorithm. It contains
90 |             the number of the current step in the optimization process, the
91 |             current value of the iterate, the current value of the averaged
92 |             iterates, the current value of the averaged gradients and the
93 |             point towards which the successive iterates are shrunk. Every
94 |             field except the shrinkage point is updated when `update` is
95 |             called.
96 | 
97 |         Returns
98 |         -------
99 |         Updated values for the step number, iterate, averaged iterates and
100 |         averaged gradients.
101 | 
102 |         """
103 | 
104 |         eta = 1.0 / (state.step + t0)
105 |         new_gradient_avg = (1.0 - eta) * state.gradient_avg + eta * gradient
106 |         new_x = state.shrinkage_pts - (at.sqrt(state.step) / gamma) * new_gradient_avg
107 |         x_eta = state.step ** (-kappa)
108 |         new_x_avg = x_eta * new_x + (1.0 - x_eta) * state.iterates_avg
109 | 
110 |         return state._replace(
111 |             step=(state.step + 1).astype(np.int64),
112 |             iterates=new_x.astype(config.floatX),
113 |             iterates_avg=new_x_avg.astype(config.floatX),
114 |             gradient_avg=new_gradient_avg.astype(config.floatX),
115 |         )
116 | 
117 |     return init, update
118 | 
119 | 
120 | def welford_covariance(compute_covariance: bool) -> Tuple[Callable, Callable, Callable]:
121 |     """Welford's online estimator of variance/covariance.
122 | 
123 |     It is possible to compute the variance of a population of values in an
124 |     on-line fashion to avoid storing intermediate results. The naive recurrence
125 |     relations between the sample mean and variance at a step and the next are
126 |     however not numerically stable.
127 | 
128 |     Welford's algorithm uses the sum of squared differences
129 |     :math:`M_{2,n} = \\sum_{i=1}^n \\left(x_i-\\overline{x_n}\\right)^2`
130 |     for updating, where :math:`\\overline{x_n}` is the current sample mean;
131 |     both quantities are updated with numerically stable recurrence relations.
132 | 
133 |     Parameters
134 |     ----------
135 |     compute_covariance
136 |         When True the algorithm returns a covariance matrix, otherwise returns
137 |         a variance vector.
138 | 
139 |     """
140 | 
141 |     def init(n_dims: int) -> Tuple[TensorVariable, TensorVariable, TensorVariable]:
142 |         """Initialize the variance estimation.
143 | 
144 |         Parameters
145 |         ----------
146 |         n_dims: int
147 |             The number of dimensions of the problem.
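
        Returns
        -------
        The initial mean, the initial value of the unnormalized
        variance/covariance :math:`M_{2,n}`, and a sample size of zero.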
148 | 149 | """ 150 | sample_size = at.as_tensor(0, dtype=np.int64) 151 | 152 | if n_dims == 0: 153 | return ( 154 | at.as_tensor(0.0, dtype=config.floatX), 155 | at.as_tensor(0.0, dtype=config.floatX), 156 | sample_size, 157 | ) 158 | 159 | mean = at.zeros((n_dims,), dtype=config.floatX) 160 | if compute_covariance: 161 | m2 = at.zeros((n_dims, n_dims), dtype=config.floatX) 162 | else: 163 | m2 = at.zeros((n_dims,), dtype=config.floatX) 164 | 165 | return mean, m2, sample_size 166 | 167 | def update( 168 | value: TensorVariable, 169 | mean: TensorVariable, 170 | m2: TensorVariable, 171 | sample_size: TensorVariable, 172 | ) -> Tuple[TensorVariable, TensorVariable, TensorVariable]: 173 | """Update the averages and M2 matrix using the new value. 174 | 175 | Parameters 176 | ---------- 177 | value: Array, shape (1,) 178 | The new sample (typically position of the chain) used to update m2 179 | mean 180 | The running average along each dimension 181 | m2 182 | The running value of the unnormalized variance/covariance 183 | sample_size 184 | The number of points that have currently been used to compute `mean` and `m2`. 185 | 186 | """ 187 | sample_size = sample_size + 1 188 | 189 | delta = value - mean 190 | mean = mean + delta / sample_size 191 | updated_delta = value - mean 192 | if compute_covariance and mean.ndim > 0: 193 | m2 = m2 + at.outer(updated_delta, delta) 194 | else: 195 | m2 = m2 + updated_delta * delta 196 | 197 | return mean, m2, sample_size 198 | 199 | def final(m2: TensorVariable, sample_size: TensorVariable) -> TensorVariable: 200 | """Compute the covariance""" 201 | variance_or_covariance = m2 / (sample_size - 1) 202 | return variance_or_covariance 203 | 204 | return init, update, final 205 | -------------------------------------------------------------------------------- /aehmc/hmc.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, Tuple 2 | 3 | import aesara 4 | import aesara.tensor as at 5 | import numpy as np 6 | from aesara.ifelse import ifelse 7 | from aesara.tensor.random.utils import RandomStream 8 | from aesara.tensor.var import TensorVariable 9 | 10 | import aehmc.integrators as integrators 11 | import aehmc.metrics as metrics 12 | import aehmc.trajectory as trajectory 13 | from aehmc.integrators import IntegratorState 14 | 15 | 16 | def new_state(q: TensorVariable, logprob_fn: Callable) -> IntegratorState: 17 | """Create a new HMC state from a position. 18 | 19 | Parameters 20 | ---------- 21 | q 22 | The chain's position. 23 | logprob_fn 24 | The function that computes the value of the log-probability density 25 | function at any position. 26 | 27 | Returns 28 | ------- 29 | A new HMC state, i.e. a tuple with the position, current value of the 30 | potential energy and gradient of the potential energy. 31 | 32 | """ 33 | potential_energy = -logprob_fn(q) 34 | potential_energy_grad = aesara.grad(potential_energy, wrt=q) 35 | return IntegratorState( 36 | position=q, 37 | momentum=None, 38 | potential_energy=potential_energy, 39 | potential_energy_grad=potential_energy_grad, 40 | ) 41 | 42 | 43 | def new_kernel( 44 | srng: RandomStream, 45 | logprob_fn: Callable, 46 | divergence_threshold: int = 1000, 47 | ) -> Callable: 48 | """Build a HMC kernel. 49 | 50 | Parameters 51 | ---------- 52 | srng 53 | A RandomStream object that tracks the changes in a shared random state. 
54 |     logprob_fn
55 |         A function that returns the value of the log-probability density
56 |         function of a chain at a given position.
57 |     divergence_threshold
58 |         The difference in energy above which we say the transition is
59 |         divergent.
60 | 
61 |     Returns
62 |     -------
63 |     A kernel that takes the current state of the chain and that returns a new
64 |     state.
65 | 
66 |     References
67 |     ----------
68 |     .. [0]: Neal, Radford M. "MCMC using Hamiltonian dynamics." Handbook of Markov
69 |        Chain Monte Carlo 2.11 (2011): 2.
70 | 
71 | 
72 |     """
73 | 
74 |     def potential_fn(x):
75 |         return -logprob_fn(x)
76 | 
77 |     def step(
78 |         state: IntegratorState,
79 |         step_size: TensorVariable,
80 |         inverse_mass_matrix: TensorVariable,
81 |         num_integration_steps: int,
82 |     ) -> Tuple[trajectory.Diagnostics, Dict]:
83 |         """Perform a single step of the HMC algorithm.
84 | 
85 |         Parameters
86 |         ----------
87 |         state
88 |             The current state of the chain: the position, the momentum, the
89 |             value of the potential energy at the position and the gradient
90 |             of the potential energy with respect to the position. The
91 |             momentum is resampled inside this function, so the value that
92 |             is passed in is not used.
93 |         step_size
94 |             The step size used in the symplectic integrator.
95 |         inverse_mass_matrix
96 |             One or two-dimensional array used as the inverse mass matrix that
97 |             defines the euclidean metric.
98 |         num_integration_steps
99 |             The number of times we run the integrator at each step.
100 | 
101 |         Returns
102 |         -------
103 |         A tuple that contains, on the one hand, the chain diagnostics (the
104 |         new state, the acceptance probability and whether the transition
105 |         diverged) and, on the other hand, a dictionary that contains the
106 |         update rules for the shared variables updated in the scan operator.
107 | 
108 |         """
109 | 
110 |         momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_metric(
111 |             inverse_mass_matrix
112 |         )
113 |         symplectic_integrator = integrators.velocity_verlet(
114 |             potential_fn, kinetic_energy_fn
115 |         )
116 |         proposal_generator = hmc_proposal(
117 |             symplectic_integrator,
118 |             kinetic_energy_fn,
119 |             num_integration_steps,
120 |             divergence_threshold,
121 |         )
122 |         updated_state = state._replace(momentum=momentum_generator(srng))
123 |         chain_info, updates = proposal_generator(srng, updated_state, step_size)
124 |         return chain_info, updates
125 | 
126 |     return step
127 | 
128 | 
129 | def hmc_proposal(
130 |     integrator: Callable,
131 |     kinetic_energy: Callable[[TensorVariable], TensorVariable],
132 |     num_integration_steps: TensorVariable,
133 |     divergence_threshold: int,
134 | ) -> Callable:
135 |     """Build a function that returns an HMC proposal.
136 | 
137 |     Parameters
138 |     ----------
139 |     integrator
140 |         The symplectic integrator used to integrate the Hamiltonian dynamics.
141 |     kinetic_energy
142 |         The function used to compute the kinetic energy.
143 |     num_integration_steps
144 |         The number of times we need to run the integrator every time the
145 |         returned function is called.
146 |     divergence_threshold
147 |         The difference in energy above which we say the transition is
148 |         divergent.
149 | 
150 |     Returns
151 |     -------
152 |     A function that generates a new state for the chain.
153 | 
154 |     """
155 |     integrate = trajectory.static_integration(integrator, num_integration_steps)
156 | 
157 |     def propose(
158 |         srng: RandomStream, state: IntegratorState, step_size: TensorVariable
159 |     ) -> Tuple[trajectory.Diagnostics, Dict]:
160 |         """Use the HMC algorithm to propose a new state.
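
        The proposal integrates the trajectory for a fixed number of steps,
        negates the momentum at the end of the trajectory so the proposal is
        reversible, and accepts or rejects the proposed state with a
        Metropolis-Hastings step.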
161 | 162 | Parameters 163 | ---------- 164 | srng 165 | A RandomStream object that tracks the changes in a shared random state. 166 | q 167 | The initial position. 168 | potential_energy 169 | The initial value of the potential energy. 170 | potential_energy_grad 171 | The initial value of the gradient of the potential energy wrt the position. 172 | step_size 173 | The step size used in the symplectic integrator 174 | 175 | Returns 176 | ------- 177 | A tuple that contains on the one hand: the new position, value of the 178 | potential energy, gradient of the potential energy and acceptance 179 | probability. On the other hand a dictionary that contains the update 180 | rules for the shared variables updated in the scan operator. 181 | 182 | """ 183 | new_state, updates = integrate(state, step_size) 184 | # flip the momentum to keep detailed balance 185 | new_state = new_state._replace(momentum=-1.0 * new_state.momentum) 186 | # compute transition-related quantities 187 | energy = state.potential_energy + kinetic_energy(state.momentum) 188 | new_energy = new_state.potential_energy + kinetic_energy(new_state.momentum) 189 | delta_energy = energy - new_energy 190 | delta_energy = at.where(at.isnan(delta_energy), -np.inf, delta_energy) 191 | is_transition_divergent = at.abs(delta_energy) > divergence_threshold 192 | 193 | p_accept = at.clip(at.exp(delta_energy), 0, 1.0) 194 | do_accept = srng.bernoulli(p_accept) 195 | final_state = IntegratorState(*ifelse(do_accept, new_state, state)) 196 | chain_info = trajectory.Diagnostics( 197 | state=final_state, 198 | acceptance_probability=p_accept, 199 | is_diverging=is_transition_divergent, 200 | num_doublings=None, 201 | is_turning=None, 202 | ) 203 | 204 | return chain_info, updates 205 | 206 | return propose 207 | -------------------------------------------------------------------------------- /aehmc/integrators.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, NamedTuple 2 | 3 | import aesara 4 | from aesara.tensor.var import TensorVariable 5 | 6 | 7 | class IntegratorState(NamedTuple): 8 | position: TensorVariable 9 | momentum: TensorVariable 10 | potential_energy: TensorVariable 11 | potential_energy_grad: TensorVariable 12 | 13 | 14 | def new_integrator_state( 15 | potential_fn: Callable, position: TensorVariable, momentum: TensorVariable 16 | ) -> IntegratorState: 17 | """Create a new integrator state from the current values of the position and momentum.""" 18 | potential_energy = potential_fn(position) 19 | return IntegratorState( 20 | position=position, 21 | momentum=momentum, 22 | potential_energy=potential_energy, 23 | potential_energy_grad=aesara.grad(potential_energy, position), 24 | ) 25 | 26 | 27 | def velocity_verlet( 28 | potential_fn: Callable[[TensorVariable], TensorVariable], 29 | kinetic_energy_fn: Callable[[TensorVariable], TensorVariable], 30 | ) -> Callable[[IntegratorState, TensorVariable], IntegratorState]: 31 | """The velocity Verlet (or Verlet-Störmer) integrator. 32 | 33 | The velocity Verlet is a two-stage palindromic integrator [1]_ of the form 34 | (a1, b1, a2, b1, a1) with a1 = 0. It is numerically stable for values of 35 | the step size that range between 0 and 2 (when the mass matrix is the 36 | identity). 37 | 38 | Parameters 39 | ---------- 40 | potential_fn 41 | A function that returns the potential energy of a chain at a given 42 | position. 
43 | kinetic_energy_fn 44 | A function that returns the kinetic energy of a chain at a given 45 | position and a given momentum. 46 | 47 | References 48 | ---------- 49 | .. [1]: Bou-Rabee, Nawaf, and Jesús Marıa Sanz-Serna. "Geometric 50 | integrators and the Hamiltonian Monte Carlo method." Acta Numerica 27 51 | (2018): 113-206. 52 | 53 | """ 54 | a1 = 0 55 | b1 = 0.5 56 | a2 = 1 - 2 * a1 57 | 58 | def one_step(state: IntegratorState, step_size: TensorVariable) -> IntegratorState: 59 | momentum = state.momentum - b1 * step_size * state.potential_energy_grad 60 | 61 | kinetic_grad = aesara.grad(kinetic_energy_fn(momentum), momentum) 62 | position = state.position + a2 * step_size * kinetic_grad 63 | 64 | potential_energy = potential_fn(position) 65 | potential_energy_grad = aesara.grad(potential_energy, position) 66 | momentum = momentum - b1 * step_size * potential_energy_grad 67 | 68 | return IntegratorState( 69 | position=position, 70 | momentum=momentum, 71 | potential_energy=potential_energy, 72 | potential_energy_grad=potential_energy_grad, 73 | ) 74 | 75 | return one_step 76 | -------------------------------------------------------------------------------- /aehmc/mass_matrix.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple 2 | 3 | import aesara.tensor as at 4 | from aesara import config 5 | from aesara.tensor.var import TensorVariable 6 | 7 | from aehmc import algorithms 8 | 9 | WelfordAlgorithmState = Tuple[TensorVariable, TensorVariable, TensorVariable] 10 | 11 | 12 | def covariance_adaptation( 13 | is_mass_matrix_full: bool = False, 14 | ) -> Tuple[Callable, Callable, Callable]: 15 | """Adapts the values in the mass matrix by computing the covariance 16 | between parameters. 17 | 18 | Parameters 19 | ---------- 20 | is_mass_matrix_full 21 | When False the algorithm adapts and returns a diagonal mass matrix 22 | (default), otherwise adapts and returns a dense mass matrix. 23 | 24 | Returns 25 | ------- 26 | init 27 | A function that initializes the step of the mass matrix adaptation. 28 | update 29 | A function that updates the state of the mass matrix. 30 | final 31 | A function that computes the inverse mass matrix based on the current 32 | state. 33 | """ 34 | 35 | wc_init, wc_update, wc_final = algorithms.welford_covariance(is_mass_matrix_full) 36 | 37 | def init( 38 | n_dims: int, 39 | ) -> Tuple[TensorVariable, WelfordAlgorithmState]: 40 | """Initialize the mass matrix adaptation. 41 | 42 | Parameters 43 | ---------- 44 | ndims 45 | The number of dimensions of the mass matrix, which corresponds to 46 | the number of dimensions of the chain position. 47 | 48 | Returns 49 | ------- 50 | The initial value of the mass matrix and the initial state of the 51 | Welford covariance algorithm. 52 | 53 | """ 54 | if n_dims == 0: 55 | inverse_mass_matrix = at.constant(1.0, dtype=config.floatX) 56 | elif is_mass_matrix_full: 57 | inverse_mass_matrix = at.eye(n_dims, dtype=config.floatX) 58 | else: 59 | inverse_mass_matrix = at.ones((n_dims,), dtype=config.floatX) 60 | 61 | wc_state = wc_init(n_dims) 62 | 63 | return inverse_mass_matrix, wc_state 64 | 65 | def update( 66 | position: TensorVariable, wc_state: WelfordAlgorithmState 67 | ) -> WelfordAlgorithmState: 68 | """Update the algorithm's state. 69 | 70 | Parameters 71 | ---------- 72 | position 73 | The current position of the chain. 74 | wc_state 75 | Current state of Welford's algorithm to compute covariance. 
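
        Returns
        -------
        The updated state of Welford's algorithm after seeing the new
        position.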
76 | 
77 |         """
78 |         new_wc_state = wc_update(position, *wc_state)
79 |         return new_wc_state
80 | 
81 |     def final(wc_state: WelfordAlgorithmState) -> TensorVariable:
82 |         """Final iteration of the mass matrix adaptation.
83 | 
84 |         In this step we compute the inverse mass matrix from the covariance
85 |         matrix estimated by the Welford algorithm, applying the shrinkage used in Stan [1]_.
86 | 
87 |         Parameters
88 |         ----------
89 |         wc_state
90 |             Current state of Welford's algorithm to compute covariance.
91 | 
92 |         Returns
93 |         -------
94 |         The value of the inverse mass matrix computed from the covariance estimate.
95 | 
96 |         References
97 |         ----------
98 |         .. [1]: Carpenter, B., Gelman, A., Hoffman, M. D., Lee, D., Goodrich, B.,
99 |         Betancourt, M., ... & Riddell, A. (2017). Stan: A probabilistic programming
100 |         language. Journal of statistical software, 76(1), 1-32.
101 | 
102 |         """
103 |         _, m2, sample_size = wc_state
104 |         covariance = wc_final(m2, sample_size)
105 | 
106 |         scaled_covariance = (sample_size / (sample_size + 5)) * covariance
107 |         shrinkage = 1e-3 * (5 / (sample_size + 5))
108 |         if covariance.ndim > 0:
109 |             if is_mass_matrix_full:
110 |                 new_inverse_mass_matrix = (
111 |                     scaled_covariance + shrinkage * at.identity_like(covariance)
112 |                 )
113 |             else:
114 |                 new_inverse_mass_matrix = scaled_covariance + shrinkage
115 |         else:
116 |             new_inverse_mass_matrix = scaled_covariance + shrinkage
117 | 
118 |         return new_inverse_mass_matrix
119 | 
120 |     return init, update, final
-------------------------------------------------------------------------------- /aehmc/metrics.py: --------------------------------------------------------------------------------
1 | from typing import Callable, Tuple
2 | 
3 | import aesara.tensor as at
4 | from aesara.tensor.random.utils import RandomStream
5 | from aesara.tensor.shape import shape_tuple
6 | from aesara.tensor.slinalg import cholesky, solve_triangular
7 | from aesara.tensor.var import TensorVariable
8 | 
9 | 
10 | def gaussian_metric(
11 |     inverse_mass_matrix: TensorVariable,
12 | ) -> Tuple[Callable, Callable, Callable]:
13 |     r"""Hamiltonian dynamics on a euclidean manifold with normally-distributed momentum.
14 | 
15 |     The gaussian euclidean metric is a euclidean metric further characterized
16 |     by setting the conditional probability density :math:`\pi(momentum|position)`
17 |     to follow a standard gaussian distribution. Newtonian hamiltonian
18 |     dynamics is assumed.
19 | 
20 |     Parameters
21 |     ----------
22 |     inverse_mass_matrix
23 |         Zero, one or two-dimensional array corresponding respectively to a
24 |         scalar, a diagonal or a dense inverse mass matrix.
25 | 
26 |     Returns
27 |     -------
28 |     momentum_generator
29 |         A function that generates a value for the momentum at random.
30 |     kinetic_energy
31 |         A function that returns the kinetic energy given the momentum.
32 |     is_turning
33 |         A function that determines whether a trajectory is turning back on
34 |         itself given the values of the momentum along the trajectory.
35 | 
36 |     References
37 |     ----------
38 |     .. [1]: Betancourt, Michael. "A general metric for Riemannian manifold
39 |     Hamiltonian Monte Carlo." International Conference on Geometric Science of
40 |     Information. Springer, Berlin, Heidelberg, 2013.
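To make the shrinkage applied in `final` above concrete, here is a NumPy sketch; the batch covariance stands in for what the Welford state accumulates online, and the sample data is ours:

```python
import numpy as np

rng = np.random.default_rng(0)
samples = rng.normal(size=(50, 3))  # 50 positions of a 3-dimensional chain
n = samples.shape[0]
covariance = np.cov(samples, rowvar=False)

# Pull the raw estimate towards a small multiple of the identity so the
# inverse mass matrix stays well-conditioned at small sample sizes.
scaled_covariance = (n / (n + 5)) * covariance
shrinkage = 1e-3 * (5 / (n + 5))
inverse_mass_matrix = scaled_covariance + shrinkage * np.eye(3)  # dense case
```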
41 | 
42 |     """
43 | 
44 |     if inverse_mass_matrix.ndim == 0:
45 |         shape: Tuple = ()
46 |         mass_matrix_sqrt = at.sqrt(at.reciprocal(inverse_mass_matrix))
47 |         dot, matmul = lambda x, y: x * y, lambda x, y: x * y
48 |     elif inverse_mass_matrix.ndim == 1:
49 |         shape = (shape_tuple(inverse_mass_matrix)[0],)
50 |         mass_matrix_sqrt = at.sqrt(at.reciprocal(inverse_mass_matrix))
51 |         dot, matmul = at.dot, lambda x, y: x * y
52 |     elif inverse_mass_matrix.ndim == 2:
53 |         # The inverse mass matrix factors as L @ L.T. A square root of the
54 |         # mass matrix itself is then inv(L).T, which we compute below.
55 |         shape = (shape_tuple(inverse_mass_matrix)[0],)
56 |         L = cholesky(inverse_mass_matrix)
57 |         identity = at.eye(*shape)
58 |         mass_matrix_sqrt = solve_triangular(L, identity, lower=True, trans=True)
59 |         dot, matmul = at.dot, at.dot
60 |     else:
61 |         raise ValueError(
62 |             f"Expected an inverse mass matrix of dimension 0 (scalar), 1 (diagonal) or 2 (dense), got {inverse_mass_matrix.ndim}"
63 |         )
64 | 
65 |     def momentum_generator(srng: RandomStream) -> TensorVariable:
66 |         norm_samples = srng.normal(0, 1, size=shape, name="momentum")
67 |         momentum = matmul(mass_matrix_sqrt, norm_samples)
68 |         return momentum
69 | 
70 |     def kinetic_energy(momentum: TensorVariable) -> TensorVariable:
71 |         velocity = matmul(inverse_mass_matrix, momentum)
72 |         kinetic_energy = 0.5 * dot(velocity, momentum)
73 |         return kinetic_energy
74 | 
75 |     def is_turning(
76 |         momentum_left: TensorVariable,
77 |         momentum_right: TensorVariable,
78 |         momentum_sum: TensorVariable,
79 |     ) -> TensorVariable:
80 |         """Generalized U-turn criterion.
81 | 
82 |         Parameters
83 |         ----------
84 |         momentum_left
85 |             Momentum of the leftmost point of the trajectory.
86 |         momentum_right
87 |             Momentum of the rightmost point of the trajectory.
88 |         momentum_sum
89 |             Sum of the momenta along the trajectory.
90 | 
91 |         .. [1]: Betancourt, Michael J. "Generalizing the no-U-turn sampler to Riemannian manifolds." arXiv preprint arXiv:1304.1920 (2013).
92 |         .. [2]: "NUTS misses U-turns, runs in circles until max treedepth", Stan Discourse Forum
93 |         https://discourse.mc-stan.org/t/nuts-misses-u-turns-runs-in-circles-until-max-treedepth/9727/46
94 |         """
95 |         velocity_left = matmul(inverse_mass_matrix, momentum_left)
96 |         velocity_right = matmul(inverse_mass_matrix, momentum_right)
97 | 
98 |         rho = momentum_sum - (momentum_right + momentum_left) / 2
99 |         turning_at_left = at.dot(velocity_left, rho) <= 0
100 |         turning_at_right = at.dot(velocity_right, rho) <= 0
101 | 
102 |         is_turning = turning_at_left | turning_at_right
103 | 
104 |         return is_turning
105 | 
106 |     return momentum_generator, kinetic_energy, is_turning
-------------------------------------------------------------------------------- /aehmc/nuts.py: --------------------------------------------------------------------------------
1 | from typing import Callable, Dict, Tuple
2 | 
3 | import aesara.tensor as at
4 | import numpy as np
5 | from aesara.tensor.random.utils import RandomStream
6 | from aesara.tensor.var import TensorVariable
7 | 
8 | from aehmc import hmc, integrators, metrics
9 | from aehmc.integrators import IntegratorState
10 | from aehmc.proposals import ProposalState
11 | from aehmc.termination import iterative_uturn
12 | from aehmc.trajectory import Diagnostics, dynamic_integration, multiplicative_expansion
13 | 
14 | new_state = hmc.new_state
15 | 
16 | 
17 | def new_kernel(
18 |     srng: RandomStream,
19 |     logprob_fn: Callable[[TensorVariable], TensorVariable],
20 |     max_num_expansions: int = 10,
21 |     divergence_threshold: int = 1000,
22 | ) -> Callable:
23 |     """Build an iterative NUTS kernel.
24 | 
25 |     Parameters
26 |     ----------
27 |     srng
28 |         A RandomStream object that tracks the changes in a shared random state.
29 |     logprob_fn
30 |         A function that returns the value of the log-probability density
31 |         function of a chain at a given position.
32 |     max_num_expansions
33 |         The maximum number of times we double the length of the trajectory.
34 |         Known as the maximum tree depth in most implementations.
35 |     divergence_threshold
36 |         The difference in energy above which we say the transition is
37 |         divergent.
38 | 
39 |     Returns
40 |     -------
41 |     A function which, given a chain state, returns a new chain state.
42 | 
43 |     References
44 |     ----------
45 |     .. [0]: Phan, Du, Neeraj Pradhan, and Martin Jankowiak. "Composable effects
46 |     for flexible and accelerated probabilistic programming in NumPyro." arXiv
47 |     preprint arXiv:1912.11554 (2019).
48 |     .. [1]: Lao, Junpeng, et al. "tfp. mcmc: Modern markov chain monte carlo
49 |     tools built for modern hardware." arXiv preprint arXiv:2002.01184 (2020).
50 | 
51 |     """
52 | 
53 |     def potential_fn(x):
54 |         return -logprob_fn(x)
55 | 
56 |     def step(
57 |         state: IntegratorState,
58 |         step_size: TensorVariable,
59 |         inverse_mass_matrix: TensorVariable,
60 |     ) -> Tuple[Diagnostics, Dict]:
61 |         """Use the NUTS algorithm to propose a new state.
62 | 
63 |         Parameters
64 |         ----------
65 |         state
66 |             The current state of the chain: the position, the momentum, the
67 |             value of the potential energy and the value of the gradient of
68 |             the potential energy with respect to the position. The momentum
69 |             component is ignored here; a fresh momentum is sampled at the
70 |             beginning of every step.
71 |         step_size
72 |             The step size used in the symplectic integrator.
73 |         inverse_mass_matrix
74 |             One or two-dimensional array used as the inverse mass matrix that
75 |             defines the euclidean metric.
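A NumPy sketch of the generalized U-turn check that `is_turning` implements above, for an identity mass matrix (so velocity equals momentum); the example values are ours:

```python
import numpy as np


def is_turning(momentum_left, momentum_right, momentum_sum):
    # Sum of momenta over the trajectory, with the endpoint contributions halved.
    rho = momentum_sum - (momentum_right + momentum_left) / 2
    return np.dot(momentum_left, rho) <= 0 or np.dot(momentum_right, rho) <= 0


# Endpoint momenta pointing in opposite directions: the trajectory doubles back.
p_left = np.array([1.0, 0.0])
p_right = np.array([-1.0, 0.2])
print(is_turning(p_left, p_right, p_left + p_right))  # True
```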
76 | 77 | Returns 78 | ------- 79 | A tuple that contains on the one hand: the new position, value of the 80 | potential energy, gradient of the potential energy, the acceptance 81 | probability, the number of times the trajectory expanded, whether the 82 | integration diverged, whether the trajectory turned on itself. On the 83 | other hand a dictionary that contains the update rules for the shared 84 | variables updated in the scan operator. 85 | 86 | """ 87 | momentum_generator, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_metric( 88 | inverse_mass_matrix 89 | ) 90 | symplectic_integrator = integrators.velocity_verlet( 91 | potential_fn, kinetic_energy_fn 92 | ) 93 | ( 94 | new_termination_state, 95 | update_termination_state, 96 | is_criterion_met, 97 | ) = iterative_uturn(uturn_check_fn) 98 | trajectory_integrator = dynamic_integration( 99 | srng, 100 | symplectic_integrator, 101 | kinetic_energy_fn, 102 | update_termination_state, 103 | is_criterion_met, 104 | divergence_threshold, 105 | ) 106 | expand = multiplicative_expansion( 107 | srng, 108 | trajectory_integrator, 109 | uturn_check_fn, 110 | max_num_expansions, 111 | ) 112 | 113 | initial_state = state._replace(momentum=momentum_generator(srng)) 114 | initial_termination_state = new_termination_state( 115 | initial_state.position, max_num_expansions 116 | ) 117 | initial_energy = initial_state.potential_energy + kinetic_energy_fn( 118 | initial_state.momentum 119 | ) 120 | initial_proposal = ProposalState( 121 | state=initial_state, 122 | energy=initial_energy, 123 | weight=at.as_tensor(0.0, dtype=np.float64), 124 | sum_log_p_accept=at.as_tensor(-np.inf, dtype=np.float64), 125 | ) 126 | 127 | results, updates = expand( 128 | initial_proposal, 129 | initial_state, 130 | initial_state, 131 | initial_state.momentum, 132 | initial_termination_state, 133 | initial_energy, 134 | step_size, 135 | ) 136 | 137 | # extract the last iteration from multiplicative_expansion chain diagnostics 138 | chain_info = Diagnostics( 139 | state=IntegratorState( 140 | position=results.diagnostics.state.position[-1], 141 | momentum=results.diagnostics.state.momentum[-1], 142 | potential_energy=results.diagnostics.state.potential_energy[-1], 143 | potential_energy_grad=results.diagnostics.state.potential_energy_grad[ 144 | -1 145 | ], 146 | ), 147 | acceptance_probability=results.diagnostics.acceptance_probability[-1], 148 | num_doublings=results.diagnostics.num_doublings[-1], 149 | is_turning=results.diagnostics.is_turning[-1], 150 | is_diverging=results.diagnostics.is_diverging[-1], 151 | ) 152 | 153 | return chain_info, updates 154 | 155 | return step 156 | -------------------------------------------------------------------------------- /aehmc/proposals.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, NamedTuple, Tuple 2 | 3 | import aesara.tensor as at 4 | import numpy as np 5 | from aesara.tensor.random.utils import RandomStream 6 | from aesara.tensor.var import TensorVariable 7 | 8 | from aehmc.integrators import IntegratorState 9 | 10 | 11 | class ProposalState(NamedTuple): 12 | state: IntegratorState 13 | energy: TensorVariable 14 | weight: TensorVariable 15 | sum_log_p_accept: TensorVariable 16 | 17 | 18 | def proposal_generator(kinetic_energy: Callable, divergence_threshold: float): 19 | def update(initial_energy, state: IntegratorState) -> Tuple[ProposalState, bool]: 20 | """Generate a new proposal from a trajectory state. 
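Before the proposal machinery below, here is a rough usage sketch of the kernel built by `new_kernel` above. The target, the variable names and the manual state construction are ours (the library also re-exports `hmc.new_state` as `nuts.new_state` for this purpose), so treat this as an illustration rather than the canonical API:

```python
import aesara
import aesara.tensor as at
from aesara.tensor.random.utils import RandomStream

from aehmc import nuts
from aehmc.integrators import IntegratorState

srng = RandomStream(seed=0)


def logprob_fn(x):  # standard normal target
    return -0.5 * at.sum(x**2)


kernel = nuts.new_kernel(srng, logprob_fn)

q = at.vector("q")
potential_energy = -logprob_fn(q)
state = IntegratorState(
    position=q,
    momentum=None,  # a fresh momentum is sampled inside the kernel
    potential_energy=potential_energy,
    potential_energy_grad=aesara.grad(potential_energy, q),
)

step_size = at.scalar("step_size")
inverse_mass_matrix = at.ones_like(q)  # diagonal euclidean metric
chain_info, updates = kernel(state, step_size, inverse_mass_matrix)

one_step = aesara.function([q, step_size], chain_info.state.position, updates=updates)
```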
21 | 
22 |         The trajectory state records information about the position in the state
23 |         space and corresponding potential energy. A proposal also carries a
24 |         weight that is equal to the difference between the current energy and
25 |         the previous one. It thus carries information about the previous state
26 |         as well as the current state.
27 | 
28 |         Parameters
29 |         ----------
30 |         initial_energy:
31 |             The initial energy.
32 |         state:
33 |             The new state.
34 | 
35 |         Returns
36 |         -------
37 |         A tuple that contains the new proposal and a boolean that indicates
38 |         whether the current transition is divergent.
39 | 
40 |         """
41 |         new_energy = state.potential_energy + kinetic_energy(state.momentum)
42 | 
43 |         delta_energy = initial_energy - new_energy
44 |         delta_energy = at.where(at.isnan(delta_energy), -np.inf, delta_energy)
45 |         is_transition_divergent = at.abs(delta_energy) > divergence_threshold
46 | 
47 |         weight = delta_energy
48 |         log_p_accept = at.where(
49 |             at.gt(delta_energy, 0),
50 |             at.as_tensor(0, dtype=delta_energy.dtype),
51 |             delta_energy,
52 |         )
53 | 
54 |         return (
55 |             ProposalState(
56 |                 state=state,
57 |                 energy=new_energy,
58 |                 weight=weight,
59 |                 sum_log_p_accept=log_p_accept,
60 |             ),
61 |             is_transition_divergent,
62 |         )
63 | 
64 |     return update
65 | 
66 | 
67 | # -------------------------------------------------------------------
68 | #                     PROGRESSIVE SAMPLING
69 | # -------------------------------------------------------------------
70 | 
71 | 
72 | def progressive_uniform_sampling(
73 |     srng: RandomStream, proposal: ProposalState, new_proposal: ProposalState
74 | ) -> ProposalState:
75 |     """Uniform proposal sampling.
76 | 
77 |     Choose between the current proposal and the proposal built from the last
78 |     trajectory state.
79 | 
80 |     Parameters
81 |     ----------
82 |     srng
83 |         RandomStream object
84 |     proposal
85 |         The current proposal; it does not necessarily correspond to the
86 |         previous state on the trajectory.
87 |     new_proposal
88 |         The proposal built from the last trajectory state.
89 | 
90 |     Returns
91 |     -------
92 |     Either the current or the new proposal.
93 | 
94 |     """
95 |     # TODO: Make the `at.isnan` check unnecessary
96 |     p_accept = at.expit(new_proposal.weight - proposal.weight)
97 |     p_accept = at.where(at.isnan(p_accept), 0, p_accept)
98 | 
99 |     do_accept = srng.bernoulli(p_accept)
100 |     updated_proposal = maybe_update_proposal(do_accept, proposal, new_proposal)
101 | 
102 |     return updated_proposal
103 | 
104 | 
105 | def progressive_biased_sampling(
106 |     srng: RandomStream, proposal: ProposalState, new_proposal: ProposalState
107 | ) -> ProposalState:
108 |     """Biased proposal sampling.
109 | 
110 |     Choose between the current proposal and the proposal built from the last
111 |     trajectory state. Unlike uniform sampling, biased sampling favors new
112 |     proposals. It thus biases the transition away from the trajectory's initial
113 |     state.
114 | 
115 |     Parameters
116 |     ----------
117 |     srng
118 |         RandomStream object
119 |     proposal
120 |         The current proposal; it does not necessarily correspond to the
121 |         previous state on the trajectory.
122 |     new_proposal
123 |         The proposal built from the last trajectory state.
124 | 
125 |     Returns
126 |     -------
127 |     Either the current or the new proposal.
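The two sampling rules above differ only in the acceptance probability they assign to the newer proposal; a NumPy sketch over log weights:

```python
import numpy as np


def p_accept_uniform(log_w_old, log_w_new):
    # expit(w_new - w_old): sampling proportional to the proposals' weights
    return 1.0 / (1.0 + np.exp(-(log_w_new - log_w_old)))


def p_accept_biased(log_w_old, log_w_new):
    # min(1, exp(w_new - w_old)): ties and better proposals always move
    return min(1.0, np.exp(log_w_new - log_w_old))


print(p_accept_uniform(0.0, 0.0))  # 0.5 -- equal weights, fair coin
print(p_accept_biased(0.0, 0.0))   # 1.0 -- ties go to the new proposal
```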
128 | 
129 |     """
130 |     p_accept = at.clip(at.exp(new_proposal.weight - proposal.weight), 0.0, 1.0)
131 |     do_accept = srng.bernoulli(p_accept)
132 |     updated_proposal = maybe_update_proposal(do_accept, proposal, new_proposal)
133 | 
134 |     return updated_proposal
135 | 
136 | 
137 | def maybe_update_proposal(
138 |     do_accept: bool, proposal: ProposalState, new_proposal: ProposalState
139 | ) -> ProposalState:
140 |     """Return either proposal depending on the boolean `do_accept`."""
141 |     updated_weight = at.logaddexp(proposal.weight, new_proposal.weight)
142 |     updated_log_sum_p_accept = at.logaddexp(
143 |         proposal.sum_log_p_accept, new_proposal.sum_log_p_accept
144 |     )
145 | 
146 |     updated_q = at.where(
147 |         do_accept, new_proposal.state.position, proposal.state.position
148 |     )
149 |     updated_p = at.where(
150 |         do_accept, new_proposal.state.momentum, proposal.state.momentum
151 |     )
152 |     updated_potential_energy = at.where(
153 |         do_accept, new_proposal.state.potential_energy, proposal.state.potential_energy
154 |     )
155 |     updated_potential_energy_grad = at.where(
156 |         do_accept,
157 |         new_proposal.state.potential_energy_grad,
158 |         proposal.state.potential_energy_grad,
159 |     )
160 |     updated_energy = at.where(do_accept, new_proposal.energy, proposal.energy)
161 | 
162 |     updated_state = IntegratorState(
163 |         position=updated_q,
164 |         momentum=updated_p,
165 |         potential_energy=updated_potential_energy,
166 |         potential_energy_grad=updated_potential_energy_grad,
167 |     )
168 | 
169 |     return ProposalState(
170 |         state=updated_state,
171 |         energy=updated_energy,
172 |         weight=updated_weight,
173 |         sum_log_p_accept=updated_log_sum_p_accept,
174 |     )
-------------------------------------------------------------------------------- /aehmc/step_size.py: --------------------------------------------------------------------------------
1 | from typing import Callable, Tuple
2 | 
3 | import aesara.tensor as at
4 | from aesara.tensor.var import TensorVariable
5 | 
6 | from aehmc import algorithms
7 | 
8 | 
9 | def dual_averaging_adaptation(
10 |     target_acceptance_rate: TensorVariable = at.as_tensor(0.8),
11 |     gamma: float = 0.05,
12 |     t0: int = 10,
13 |     kappa: float = 0.75,
14 | ) -> Tuple[Callable, Callable]:
15 |     r"""Tune the step size to achieve a desired target acceptance rate.
16 | 
17 |     Let :math:`\epsilon` denote the current step size, :math:`\alpha_t` the
18 |     Metropolis acceptance rate at time :math:`t` and :math:`\delta` the desired
19 |     acceptance rate. We define:
20 | 
21 |     .. math::
22 |         H_t = \delta - \alpha_t
23 | 
24 |     the error at time t. We would like to find a procedure that adapts the
25 |     value of :math:`\epsilon` such that :math:`h(x) = \mathbb{E}\left[H_t|\epsilon\right] = 0`.
26 |     Following [1]_, the authors of [2]_ proposed the following update scheme. If
27 |     we write :math:`x = \log \epsilon`, the update reads:
28 | 
29 |     .. math::
30 |         x_{t+1} \leftarrow \mu - \frac{\sqrt{t}}{\gamma} \frac{1}{t+t_0} \sum_{i=1}^t H_i\\
31 |         \overline{x}_{t+1} \leftarrow t^{-\kappa} x_{t+1} + \left(1 - t^{-\kappa}\right) \overline{x}_t
32 | 
33 |     :math:`\overline{x}_{t}` is guaranteed to converge to a value such that
34 |     :math:`h(\overline{x}_t)` converges to 0, i.e. the Metropolis acceptance
35 |     rate converges to the desired rate.
36 | 
37 |     See reference [2]_ (section 3.2.1) for a detailed discussion.
38 | 
39 |     Parameters
40 |     ----------
41 |     target_acceptance_rate:
42 |         The Metropolis acceptance rate targeted by the dual averaging
43 |         scheme: the step size is adapted so that the observed acceptance
44 |         rate converges to this value. Defaults to 0.8, the value used by
45 |         Stan.
46 |     gamma
47 |         Controls the speed of convergence of the scheme. The authors of [2]_ recommend
48 |         a value of 0.05.
49 |     t0: float >= 0
50 |         Free parameter that stabilizes the initial iterations of the algorithm.
51 |         Large values may slow down convergence. Introduced in [2]_ with a default
52 |         value of 10.
53 |     kappa: float in ]0.5, 1]
54 |         Controls the weights of past steps in the current update. The scheme will
55 |         quickly forget earlier steps for a small value of `kappa`. Introduced
56 |         in [2]_, with a recommended value of 0.75.
57 | 
58 |     Returns
59 |     -------
60 |     init
61 |         A function that initializes the state of the dual averaging scheme.
62 |     update
63 |         A function that updates the state of the dual averaging scheme.
64 | 
65 |     References
66 |     ----------
67 |     .. [1]: Nesterov, Yurii. "Primal-dual subgradient methods for convex
68 |     problems." Mathematical programming 120.1 (2009): 221-259.
69 |     .. [2]: Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler:
70 |     adaptively setting path lengths in Hamiltonian Monte Carlo." Journal
71 |     of Machine Learning Research 15.1 (2014): 1593-1623.
72 | 
73 |     """
74 |     da_init, da_update = algorithms.dual_averaging(gamma, t0, kappa)
75 | 
76 |     def update(
77 |         acceptance_probability: TensorVariable, state: algorithms.DualAveragingState
78 |     ) -> algorithms.DualAveragingState:
79 |         """Update the dual averaging adaptation state.
80 | 
81 |         Parameters
82 |         ----------
83 |         acceptance_probability
84 |             The acceptance probability returned by the sampling algorithm.
85 |         state
86 |             The current state of the dual averaging scheme, which bundles
87 |             the current step number, the logarithm of the step size, the
88 |             running average of the logarithm of the step size, the running
89 |             average of the gradients and the point :math:`\mu` towards
90 |             which the iterates are shrunk.
91 | 
92 |         Returns
93 |         -------
94 |         The updated state of the dual averaging scheme.
95 | 
96 |         """
97 |         gradient = target_acceptance_rate - acceptance_probability
98 |         return da_update(gradient, state)
99 | 
100 |     return da_init, update
-------------------------------------------------------------------------------- /aehmc/termination.py: --------------------------------------------------------------------------------
1 | from typing import Callable, NamedTuple, Tuple
2 | 
3 | import aesara
4 | import aesara.tensor as at
5 | import numpy as np
6 | from aesara import config as config
7 | from aesara.ifelse import ifelse
8 | from aesara.scan.utils import until
9 | from aesara.tensor.var import TensorVariable
10 | 
11 | 
12 | class TerminationState(NamedTuple):
13 |     momentum_checkpoints: TensorVariable
14 |     momentum_sum_checkpoints: TensorVariable
15 |     min_index: TensorVariable
16 |     max_index: TensorVariable
17 | 
18 | 
19 | def iterative_uturn(is_turning_fn: Callable) -> Tuple[Callable, Callable, Callable]:
20 |     """U-Turn termination criterion to check reversibility while expanding
21 |     the trajectory.
22 | 
23 |     The code follows the implementation in Numpyro [0]_, which is equivalent to
24 |     that in TFP [1]_.
25 | 
26 |     Parameter
27 |     ---------
28 |     is_turning_fn:
29 |         A function which, given the new momentum and the sum of the momenta
30 |         along the trajectory, returns a boolean that indicates whether the
31 |         trajectory is turning on itself. Depends on the metric.
32 | 
33 |     References
34 |     ----------
35 |     .. [0]: Phan, Du, Neeraj Pradhan, and Martin Jankowiak. "Composable effects
36 |     for flexible and accelerated probabilistic programming in NumPyro." arXiv
37 |     preprint arXiv:1912.11554 (2019).
38 |     .. [1]: Lao, Junpeng, et al. "tfp. mcmc: Modern markov chain monte carlo
39 |     tools built for modern hardware." arXiv preprint arXiv:2002.01184 (2020).
40 | 
41 |     """
42 | 
43 |     def new_state(
44 |         position: TensorVariable, max_num_doublings: TensorVariable
45 |     ) -> TerminationState:
46 |         """Initialize the termination state.
47 | 
48 |         Parameters
49 |         ----------
50 |         position
51 |             Example chain position. Used to infer the shape of the arrays that
52 |             store the relevant momenta and momentum sums.
53 |         max_num_doublings
54 |             Maximum number of doublings allowed in the multiplicative
55 |             expansion. Determines the maximum number of momenta and momentum
56 |             sums to store.
57 | 
58 |         Returns
59 |         -------
60 |         A tuple that represents a new state for the termination criterion.
61 | 
62 |         """
63 |         if position.ndim == 0:
64 |             return TerminationState(
65 |                 momentum_checkpoints=at.zeros(max_num_doublings, dtype=config.floatX),
66 |                 momentum_sum_checkpoints=at.zeros(
67 |                     max_num_doublings, dtype=config.floatX
68 |                 ),
69 |                 min_index=at.constant(0, dtype=np.int64),
70 |                 max_index=at.constant(0, dtype=np.int64),
71 |             )
72 |         else:
73 |             num_dims = position.shape[0]
74 |             return TerminationState(
75 |                 momentum_checkpoints=at.zeros(
76 |                     (max_num_doublings, num_dims), dtype=config.floatX
77 |                 ),
78 |                 momentum_sum_checkpoints=at.zeros(
79 |                     (max_num_doublings, num_dims), dtype=config.floatX
80 |                 ),
81 |                 min_index=at.constant(0, dtype=np.int64),
82 |                 max_index=at.constant(0, dtype=np.int64),
83 |             )
84 | 
85 |     def update(
86 |         state: TerminationState,
87 |         momentum_sum: TensorVariable,
88 |         momentum: TensorVariable,
89 |         step: TensorVariable,
90 |     ) -> TerminationState:
91 |         """Update the termination state.
92 | 
93 |         Parameters
94 |         ----------
95 |         state
96 |             The current termination state
97 |         momentum_sum
98 |             The sum of all momenta along the trajectory
99 |         momentum
100 |             The current momentum on the trajectory
101 |         step
102 |             Current step in the trajectory integration (starting at 0)
103 | 
104 |         Returns
105 |         -------
106 |         A tuple that represents the updated termination state.
107 | 
108 |         """
109 |         idx_min, idx_max = ifelse(
110 |             at.eq(step, 0),
111 |             (state.min_index, state.max_index),
112 |             _find_storage_indices(step),
113 |         )
114 | 
115 |         momentum_ckpt = at.where(
116 |             at.eq(step % 2, 0),
117 |             at.set_subtensor(state.momentum_checkpoints[idx_max], momentum),
118 |             state.momentum_checkpoints,
119 |         )
120 |         momentum_sum_ckpt = at.where(
121 |             at.eq(step % 2, 0),
122 |             at.set_subtensor(state.momentum_sum_checkpoints[idx_max], momentum_sum),
123 |             state.momentum_sum_checkpoints,
124 |         )
125 | 
126 |         return TerminationState(
127 |             momentum_checkpoints=momentum_ckpt,
128 |             momentum_sum_checkpoints=momentum_sum_ckpt,
129 |             min_index=idx_min,
130 |             max_index=idx_max,
131 |         )
132 | 
133 |     def is_iterative_turning(
134 |         state: TerminationState, momentum_sum: TensorVariable, momentum: TensorVariable
135 |     ) -> TensorVariable:
136 |         """Check if any sub-trajectory is making a U-turn.
137 | 
138 |         If we visualize the trajectory as a balanced binary tree, the
139 |         subtrajectories for which we need to check the U-turn criterion are the
140 |         ones for which the current node is the rightmost node. The momenta
141 |         and momentum sums corresponding to these subtrajectories are
142 |         stored between `idx_min` and `idx_max` in `momentum_ckpts`
143 |         and `momentum_sum_ckpts`
144 |         respectively.
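The scan-based helper `_find_storage_indices` (defined at the end of this module) boils down to bit arithmetic on the step number; a plain-Python reading of what its two scans compute:

```python
def find_storage_indices(step: int) -> tuple:
    # `count_subtrees` counts the trailing 1-bits of `step` ...
    num_subtrees = 0
    n = step
    while n & 1:
        n //= 2
        num_subtrees += 1

    # ... and `find_idx_max` computes the population count of `step // 2`.
    idx_max = bin(step // 2).count("1")
    idx_min = idx_max - num_subtrees + 1
    return idx_min, idx_max


for step in range(1, 8):
    print(step, find_storage_indices(step))
```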
145 | 146 | Parameters 147 | ---------- 148 | state 149 | The current termination state 150 | momentum_sum 151 | The sum of all momenta along the trajectory 152 | momentum 153 | The current momentum on the trajectory 154 | step 155 | Current step in the trajectory integration (starting at 0) 156 | 157 | 158 | Return 159 | ------ 160 | True if any sub-trajectory makes a U-turn, False otherwise. 161 | 162 | """ 163 | 164 | def body_fn(i): 165 | subtree_momentum_sum = ( 166 | momentum_sum 167 | - state.momentum_sum_checkpoints[i] 168 | + state.momentum_checkpoints[i] 169 | ) 170 | is_turning = is_turning_fn( 171 | state.momentum_checkpoints[i], momentum, subtree_momentum_sum 172 | ) 173 | reached_max_iteration = at.lt(i - 1, state.min_index) 174 | do_stop = at.any(is_turning | reached_max_iteration) 175 | return (i - 1, is_turning), until(do_stop) 176 | 177 | (_, criterion), _ = aesara.scan( 178 | body_fn, outputs_info=(state.max_index, None), n_steps=state.max_index + 2 179 | ) 180 | 181 | is_turning = at.where( 182 | at.lt(state.max_index, state.min_index), 183 | at.as_tensor(0, dtype="bool"), 184 | criterion[-1], 185 | ) 186 | 187 | return is_turning 188 | 189 | return new_state, update, is_iterative_turning 190 | 191 | 192 | def _find_storage_indices(step: TensorVariable) -> Tuple[int, int]: 193 | """Find the indices between which the momenta and sums are stored. 194 | 195 | Parameter 196 | --------- 197 | step 198 | The current step in the trajectory integration. 199 | 200 | Return 201 | ------ 202 | The min and max indices between which the values relevant to check the 203 | U-turn condition for the current step are stored. 204 | 205 | """ 206 | 207 | def count_subtrees(nc0, nc1): 208 | do_stop = at.eq(nc0 & 1, 0) 209 | new_nc0 = nc0 // 2 210 | new_nc1 = nc1 + 1 211 | return (new_nc0, new_nc1), until(do_stop) 212 | 213 | (_, nc1), _ = aesara.scan( 214 | count_subtrees, 215 | outputs_info=(step, -1), 216 | n_steps=step + 1, 217 | ) 218 | num_subtrees = nc1[-1] 219 | 220 | def find_idx_max(nc0, nc1): 221 | do_stop = at.eq(nc0, 0) 222 | new_nc0 = nc0 // 2 223 | new_nc1 = nc1 + (nc0 & 1) 224 | return (new_nc0, new_nc1), until(do_stop) 225 | 226 | init = at.as_tensor(step // 2, dtype=np.int64) 227 | init_nc1 = at.constant(0, dtype=np.int64) 228 | (nc0, nc1), _ = aesara.scan( 229 | find_idx_max, outputs_info=(init, init_nc1), n_steps=step + 1 230 | ) 231 | idx_max = nc1[-1] 232 | 233 | idx_min = idx_max - num_subtrees + 1 234 | 235 | return idx_min, idx_max 236 | -------------------------------------------------------------------------------- /aehmc/trajectory.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, NamedTuple, Tuple 2 | 3 | import aesara 4 | import aesara.tensor as at 5 | import numpy as np 6 | from aesara.ifelse import ifelse 7 | from aesara.scan.utils import until 8 | from aesara.tensor.random.utils import RandomStream 9 | from aesara.tensor.var import TensorVariable 10 | 11 | from aehmc.integrators import IntegratorState 12 | from aehmc.proposals import ( 13 | ProposalState, 14 | progressive_biased_sampling, 15 | progressive_uniform_sampling, 16 | proposal_generator, 17 | ) 18 | from aehmc.termination import TerminationState 19 | 20 | __all__ = ["static_integration", "dynamic_integration", "multiplicative_expansion"] 21 | 22 | 23 | # ------------------------------------------------------------------- 24 | # STATIC INTEGRATION 25 | # 26 | # This section contains algorithms that integrate the trajectory for 
27 | # a set number of integrator steps.
28 | # -------------------------------------------------------------------
29 | 
30 | 
31 | def static_integration(
32 |     integrator: Callable,
33 |     num_integration_steps: int,
34 | ) -> Callable:
35 |     """Build a function that generates fixed-length trajectories.
36 | 
37 |     Parameters
38 |     ----------
39 |     integrator
40 |         Function that performs one integration step.
41 |     num_integration_steps
42 |         The number of times we need to run the integrator every time the
43 |         returned function is called.
44 | 
45 |     Returns
46 |     -------
47 |     A function that integrates the hamiltonian dynamics a
48 |     `num_integration_steps` times.
49 | 
50 |     """
51 | 
52 |     def integrate(
53 |         init_state: IntegratorState, step_size: TensorVariable
54 |     ) -> Tuple[IntegratorState, Dict]:
55 |         """Generate a trajectory by integrating several times in one direction.
56 | 
57 |         Parameters
58 |         ----------
59 |         init_state
60 |             The initial state of the integrator, which bundles the initial
61 |             position, the initial value of the momentum, the initial value
62 |             of the potential energy, and the initial value of the gradient
63 |             of the potential energy with respect to the position. All four
64 |             components must correspond to the same point in phase space
65 |             when the integration starts, as they are advanced together by
66 |             the integrator.
67 |         step_size
68 |             The size of each step taken by the integrator.
69 | 
70 |         Returns
71 |         -------
72 |         The last state of the trajectory (position, momentum, potential
73 |         energy and gradient of the potential energy), as well as a
74 |         dictionary that contains the update rules for all the shared
75 |         variables updated in `scan`.
76 | 
77 |         """
78 | 
79 |         def one_step(q, p, potential_energy, potential_energy_grad):
80 |             new_state = integrator(
81 |                 IntegratorState(q, p, potential_energy, potential_energy_grad),
82 |                 step_size,
83 |             )
84 |             return new_state
85 | 
86 |         [q, p, energy, energy_grad], updates = aesara.scan(
87 |             fn=one_step,
88 |             outputs_info=[
89 |                 {"initial": init_state.position},
90 |                 {"initial": init_state.momentum},
91 |                 {"initial": init_state.potential_energy},
92 |                 {"initial": init_state.potential_energy_grad},
93 |             ],
94 |             n_steps=num_integration_steps,
95 |         )
96 | 
97 |         return (
98 |             IntegratorState(
99 |                 position=q[-1],
100 |                 momentum=p[-1],
101 |                 potential_energy=energy[-1],
102 |                 potential_energy_grad=energy_grad[-1],
103 |             ),
104 |             updates,
105 |         )
106 | 
107 |     return integrate
108 | 
109 | 
110 | # -------------------------------------------------------------------
111 | #                    DYNAMIC INTEGRATION
112 | #
113 | # This section contains algorithms that determine the number of
114 | # integrator steps dynamically using a termination criterion that
115 | # is updated at every step.
116 | # -------------------------------------------------------------------
117 | 
118 | 
119 | def dynamic_integration(
120 |     srng: RandomStream,
121 |     integrator: Callable,
122 |     kinetic_energy: Callable,
123 |     update_termination_state: Callable,
124 |     is_criterion_met: Callable,
125 |     divergence_threshold: TensorVariable,
126 | ):
127 |     """Integrate a trajectory and update the proposal sequentially in one direction
128 |     until the termination criterion is met.
129 | 
130 |     Parameters
131 |     ----------
132 |     srng
133 |         A RandomStream object that tracks the changes in a shared random state.
134 |     integrator
135 |         The symplectic integrator used to integrate the hamiltonian dynamics.
136 |     kinetic_energy
137 |         Function to compute the current value of the kinetic energy.
138 |     update_termination_state
139 |         Updates the state of the termination mechanism.
140 |     is_criterion_met
141 |         Determines whether the termination criterion has been met.
142 |     divergence_threshold
143 |         Value of the difference of energy between two consecutive states above which we say a transition is divergent.
144 | 
145 |     Returns
146 |     -------
147 |     A function that integrates the trajectory in one direction and updates a
148 |     proposal until the termination criterion is met.
149 | 
150 |     """
151 |     generate_proposal = proposal_generator(kinetic_energy, divergence_threshold)
152 |     sample_proposal = progressive_uniform_sampling
153 | 
154 |     def integrate(
155 |         previous_last_state: IntegratorState,
156 |         direction: TensorVariable,
157 |         termination_state: TerminationState,
158 |         max_num_steps: TensorVariable,
159 |         step_size: TensorVariable,
160 |         initial_energy: TensorVariable,
161 |     ):
162 |         """Integrate the trajectory starting from `previous_last_state` and update
163 |         the proposal sequentially until the termination criterion is met.
164 | 
165 |         Parameters
166 |         ----------
167 |         previous_last_state
168 |             The last state of the previously integrated trajectory.
169 |         direction: int in {-1, 1}
170 |             The direction in which to expand the trajectory.
171 |         termination_state
172 |             The state that keeps track of the information needed for the termination criterion.
173 |         max_num_steps
174 |             The maximum number of integration steps. The expansion will stop
175 |             when this number is reached if the termination criterion has not
176 |             been met.
177 |         step_size
178 |             The step size of the symplectic integrator.
179 |         initial_energy
180 |             Initial energy H0 of the HMC step (not to be confused with the initial energy of the subtree).
181 | 
182 |         Returns
183 |         -------
184 |         A tuple with on the one hand: a new proposal (sampled from the states
185 |         traversed while building the trajectory), the last state, the sum of the
186 |         momenta values along the trajectory (needed for termination criterion),
187 |         the updated termination state, the number of integration steps
188 |         performed, a boolean that indicates whether the trajectory has diverged,
189 |         a boolean that indicates whether the termination criterion was met. And
190 |         on the other hand a dictionary that contains the update rules for the
191 |         shared variables updated in the internal `Scan` operator.
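Schematically (outside of Aesara), the loop that `integrate` expresses with `aesara.scan` looks like the sketch below; all the callables are stand-ins for the functions built above, so this is an illustration of the control flow rather than the implementation:

```python
def integrate_sketch(state, direction, step_size, max_num_steps, one_step,
                     generate_proposal, sample_proposal, update_termination,
                     is_criterion_met, termination_state, initial_energy):
    proposal, momentum_sum = None, None
    is_diverging = has_terminated = False
    for step in range(max_num_steps):
        # One integrator step away from the end of the current trajectory.
        state = one_step(state, direction * step_size)
        new_proposal, is_diverging = generate_proposal(initial_energy, state)
        # Fold the new state into the running proposal (progressive sampling).
        proposal = (new_proposal if proposal is None
                    else sample_proposal(proposal, new_proposal))
        momentum_sum = (state.momentum if momentum_sum is None
                        else momentum_sum + state.momentum)
        termination_state = update_termination(
            termination_state, momentum_sum, state.momentum, step
        )
        has_terminated = is_criterion_met(
            termination_state, momentum_sum, state.momentum
        )
        if is_diverging or has_terminated:
            break
    return (proposal, state, momentum_sum, termination_state,
            is_diverging, has_terminated)
```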
192 | 193 | """ 194 | 195 | def add_one_state( 196 | step, 197 | q_proposal, # current proposal 198 | p_proposal, 199 | potential_energy_proposal, 200 | potential_energy_grad_proposal, 201 | energy_proposal, 202 | weight, 203 | sum_p_accept, 204 | q_last, # state 205 | p_last, 206 | potential_energy_last, 207 | potential_energy_grad_last, 208 | momentum_sum: TensorVariable, # sum of momenta 209 | momentum_ckpts, # termination state 210 | momentum_sum_ckpts, 211 | idx_min, 212 | idx_max, 213 | trajectory_length, 214 | ): 215 | termination_state = TerminationState( 216 | momentum_checkpoints=momentum_ckpts, 217 | momentum_sum_checkpoints=momentum_sum_ckpts, 218 | min_index=idx_min, 219 | max_index=idx_max, 220 | ) 221 | proposal = ProposalState( 222 | state=IntegratorState( 223 | position=q_proposal, 224 | momentum=p_proposal, 225 | potential_energy=potential_energy_proposal, 226 | potential_energy_grad=potential_energy_grad_proposal, 227 | ), 228 | energy=energy_proposal, 229 | weight=weight, 230 | sum_log_p_accept=sum_p_accept, 231 | ) 232 | last_state = IntegratorState( 233 | position=q_last, 234 | momentum=p_last, 235 | potential_energy=potential_energy_last, 236 | potential_energy_grad=potential_energy_grad_last, 237 | ) 238 | 239 | new_state = integrator(last_state, direction * step_size) 240 | new_proposal, is_diverging = generate_proposal(initial_energy, new_state) 241 | sampled_proposal = sample_proposal(srng, proposal, new_proposal) 242 | 243 | new_momentum_sum = momentum_sum + new_state.momentum 244 | new_termination_state = update_termination_state( 245 | termination_state, new_momentum_sum, new_state.momentum, step 246 | ) 247 | has_terminated = is_criterion_met( 248 | new_termination_state, new_momentum_sum, new_state.momentum 249 | ) 250 | 251 | do_stop_integrating = is_diverging | has_terminated 252 | 253 | return ( 254 | sampled_proposal.state.position, 255 | sampled_proposal.state.momentum, 256 | sampled_proposal.state.potential_energy, 257 | sampled_proposal.state.potential_energy_grad, 258 | sampled_proposal.energy, 259 | sampled_proposal.weight, 260 | sampled_proposal.sum_log_p_accept, 261 | new_state.position, 262 | new_state.momentum, 263 | new_state.potential_energy, 264 | new_state.potential_energy_grad, 265 | new_momentum_sum, 266 | new_termination_state.momentum_checkpoints, 267 | new_termination_state.momentum_sum_checkpoints, 268 | new_termination_state.min_index, 269 | new_termination_state.max_index, 270 | trajectory_length + 1, 271 | is_diverging, 272 | has_terminated, 273 | ), until(do_stop_integrating) 274 | 275 | # We take one step away to start building the subtrajectory 276 | state = integrator(previous_last_state, direction * step_size) 277 | proposal, is_diverging = generate_proposal(initial_energy, state) 278 | momentum_sum = state.momentum 279 | termination_state = update_termination_state( 280 | termination_state, 281 | momentum_sum, 282 | state.momentum, 283 | 0, 284 | ) 285 | full_initial_state = ( 286 | proposal.state.position, 287 | proposal.state.momentum, 288 | proposal.state.potential_energy, 289 | proposal.state.potential_energy_grad, 290 | proposal.energy, 291 | proposal.weight, 292 | proposal.sum_log_p_accept, 293 | state.position, 294 | state.momentum, 295 | state.potential_energy, 296 | state.potential_energy_grad, 297 | momentum_sum, 298 | termination_state.momentum_checkpoints, 299 | termination_state.momentum_sum_checkpoints, 300 | termination_state.min_index, 301 | termination_state.max_index, 302 | at.as_tensor(1, dtype=np.int64), 
303 |             is_diverging,
304 |             np.array(False),
305 |         )
306 | 
307 |         steps = at.arange(1, 1 + max_num_steps)
308 |         trajectory, updates = aesara.scan(
309 |             add_one_state,
310 |             outputs_info=(
311 |                 proposal.state.position,
312 |                 proposal.state.momentum,
313 |                 proposal.state.potential_energy,
314 |                 proposal.state.potential_energy_grad,
315 |                 proposal.energy,
316 |                 proposal.weight,
317 |                 proposal.sum_log_p_accept,
318 |                 state.position,
319 |                 state.momentum,
320 |                 state.potential_energy,
321 |                 state.potential_energy_grad,
322 |                 momentum_sum,
323 |                 termination_state.momentum_checkpoints,
324 |                 termination_state.momentum_sum_checkpoints,
325 |                 termination_state.min_index,
326 |                 termination_state.max_index,
327 |                 at.as_tensor(1, dtype=np.int64),
328 |                 None,
329 |                 None,
330 |             ),
331 |             sequences=steps,
332 |         )
333 |         full_last_state = tuple([_state[-1] for _state in trajectory])
334 | 
335 |         # We build the trajectory iff the first step is not diverging
336 |         full_state = ifelse(is_diverging, full_initial_state, full_last_state)
337 | 
338 |         new_proposal = ProposalState(
339 |             state=IntegratorState(
340 |                 position=full_state[0],
341 |                 momentum=full_state[1],
342 |                 potential_energy=full_state[2],
343 |                 potential_energy_grad=full_state[3],
344 |             ),
345 |             energy=full_state[4],
346 |             weight=full_state[5],
347 |             sum_log_p_accept=full_state[6],
348 |         )
349 |         new_state = IntegratorState(
350 |             position=full_state[7],
351 |             momentum=full_state[8],
352 |             potential_energy=full_state[9],
353 |             potential_energy_grad=full_state[10],
354 |         )
355 |         subtree_momentum_sum = full_state[11]
356 |         new_termination_state = TerminationState(
357 |             momentum_checkpoints=full_state[12],
358 |             momentum_sum_checkpoints=full_state[13],
359 |             min_index=full_state[14],
360 |             max_index=full_state[15],
361 |         )
362 |         trajectory_length = full_state[-3]
363 |         is_diverging = full_state[-2]
364 |         has_terminated = full_state[-1]
365 | 
366 |         return (
367 |             new_proposal,
368 |             new_state,
369 |             subtree_momentum_sum,
370 |             new_termination_state,
371 |             trajectory_length,
372 |             is_diverging,
373 |             has_terminated,
374 |         ), updates
375 | 
376 |     return integrate
377 | 
378 | 
379 | class Diagnostics(NamedTuple):
380 |     state: IntegratorState
381 |     acceptance_probability: TensorVariable
382 |     num_doublings: TensorVariable
383 |     is_turning: TensorVariable
384 |     is_diverging: TensorVariable
385 | 
386 | 
387 | class MultiplicativeExpansionResult(NamedTuple):
388 |     proposals: ProposalState
389 |     right_states: IntegratorState
390 |     left_states: IntegratorState
391 |     momentum_sums: TensorVariable
392 |     termination_states: TerminationState
393 |     diagnostics: Diagnostics
394 | 
395 | 
396 | def multiplicative_expansion(
397 |     srng: RandomStream,
398 |     trajectory_integrator: Callable,
399 |     uturn_check_fn: Callable,
400 |     max_num_expansions: TensorVariable,
401 | ):
402 |     """Sample a trajectory and update the proposal sequentially
403 |     until the termination criterion is met.
404 | 
405 |     The trajectory is sampled with the following procedure:
406 |     1. Pick a direction at random;
407 |     2. Integrate `num_steps` steps in this direction;
408 |     3. If the integration has stopped prematurely, do not update the proposal;
409 |     4. Else if the trajectory is performing a U-turn, return the current proposal;
410 |     5. Else update the proposal, double `num_steps` and repeat from (1).
411 | 
412 |     Parameters
413 |     ----------
414 |     srng
415 |         A RandomStream object that tracks the changes in a shared random state.
416 | trajectory_integrator 417 | A function that runs the symplectic integrators and returns a new proposal 418 | and the integrated trajectory. 419 | uturn_check_fn 420 | Function used to check the U-Turn criterion. 421 | max_num_expansions 422 | The maximum number of trajectory expansions until the proposal is 423 | returned. 424 | 425 | """ 426 | proposal_sampler = progressive_biased_sampling 427 | 428 | def expand( 429 | proposal: ProposalState, 430 | left_state: IntegratorState, 431 | right_state: IntegratorState, 432 | momentum_sum, 433 | termination_state: TerminationState, 434 | initial_energy, 435 | step_size, 436 | ) -> Tuple[MultiplicativeExpansionResult, Dict]: 437 | """Expand the current trajectory multiplicatively. 438 | 439 | At each step we draw a direction at random, build a subtrajectory starting 440 | from the leftmost or rightmost point of the current trajectory that is 441 | twice as long as the current trajectory. 442 | 443 | Once that is done, possibly update the current proposal with that of 444 | the subtrajectory. 445 | 446 | Parameters 447 | ---------- 448 | proposal 449 | Current new state proposal. 450 | left_state 451 | The current leftmost state of the trajectory. 452 | right_state 453 | The current rightmost state of the trajectory. 454 | momentum_sum 455 | The current value of the sum of momenta along the trajectory. 456 | initial_energy 457 | Potential energy before starting to build the trajectory. 458 | step_size 459 | The size of each step taken by the integrator. 460 | 461 | """ 462 | 463 | def expand_once( 464 | step, 465 | q_proposal, # proposal 466 | p_proposal, 467 | potential_energy_proposal, 468 | potential_energy_grad_proposal, 469 | energy_proposal, 470 | weight, 471 | sum_p_accept, 472 | q_left, # trajectory 473 | p_left, 474 | potential_energy_left, 475 | potential_energy_grad_left, 476 | q_right, 477 | p_right, 478 | potential_energy_right, 479 | potential_energy_grad_right, 480 | momentum_sum, # sum of momenta along trajectory 481 | momentum_ckpts, # termination_state 482 | momentum_sum_ckpts, 483 | idx_min, 484 | idx_max, 485 | ) -> Tuple[Tuple[TensorVariable, ...], Dict, until]: 486 | left_state = ( 487 | q_left, 488 | p_left, 489 | potential_energy_left, 490 | potential_energy_grad_left, 491 | ) 492 | right_state = ( 493 | q_right, 494 | p_right, 495 | potential_energy_right, 496 | potential_energy_grad_right, 497 | ) 498 | proposal = ProposalState( 499 | state=IntegratorState( 500 | position=q_proposal, 501 | momentum=p_proposal, 502 | potential_energy=potential_energy_proposal, 503 | potential_energy_grad=potential_energy_grad_proposal, 504 | ), 505 | energy=energy_proposal, 506 | weight=weight, 507 | sum_log_p_accept=sum_p_accept, 508 | ) 509 | termination_state = TerminationState( 510 | momentum_checkpoints=momentum_ckpts, 511 | momentum_sum_checkpoints=momentum_sum_ckpts, 512 | min_index=idx_min, 513 | max_index=idx_max, 514 | ) 515 | 516 | do_go_right = srng.bernoulli(0.5) 517 | direction = at.where(do_go_right, 1.0, -1.0) 518 | start_state = IntegratorState(*ifelse(do_go_right, right_state, left_state)) 519 | 520 | ( 521 | new_proposal, 522 | new_state, 523 | subtree_momentum_sum, 524 | new_termination_state, 525 | subtrajectory_length, 526 | is_diverging, 527 | has_subtree_terminated, 528 | ), inner_updates = trajectory_integrator( 529 | start_state, 530 | direction, 531 | termination_state, 532 | 2**step, 533 | step_size, 534 | initial_energy, 535 | ) 536 | 537 | # Update the trajectory. 
538 | # The trajectory integrator always integrates forward in time; we 539 | # thus need to switch the states if the other direction was picked. 540 | new_left_state = IntegratorState( 541 | *ifelse(do_go_right, left_state, new_state) 542 | ) 543 | new_right_state = IntegratorState( 544 | *ifelse(do_go_right, new_state, right_state) 545 | ) 546 | new_momentum_sum = momentum_sum + subtree_momentum_sum 547 | 548 | # Compute the pseudo-acceptance probability for the NUTS algorithm. 549 | # It can be understood as the average acceptance probability MC would give to 550 | # the states explored during the final expansion. 551 | acceptance_probability = ( 552 | at.exp(new_proposal.sum_log_p_accept) / subtrajectory_length 553 | ) 554 | 555 | # Update the proposal 556 | # 557 | # We do not accept proposals that come from diverging or turning subtrajectories. 558 | # However the definition of the acceptance probability is such that the 559 | # acceptance probability needs to be computed across the entire trajectory. 560 | updated_proposal = proposal._replace( 561 | sum_log_p_accept=at.logaddexp( 562 | new_proposal.sum_log_p_accept, proposal.sum_log_p_accept 563 | ) 564 | ) 565 | 566 | sampled_proposal = where_proposal( 567 | is_diverging | has_subtree_terminated, 568 | updated_proposal, 569 | proposal_sampler(srng, proposal, new_proposal), 570 | ) 571 | 572 | # Check if the trajectory is turning and determine whether we need 573 | # to stop expanding the trajectory. 574 | is_turning = uturn_check_fn( 575 | new_left_state.momentum, new_right_state.momentum, new_momentum_sum 576 | ) 577 | do_stop_expanding = is_diverging | is_turning | has_subtree_terminated 578 | 579 | return ( 580 | ( 581 | sampled_proposal.state.position, 582 | sampled_proposal.state.momentum, 583 | sampled_proposal.state.potential_energy, 584 | sampled_proposal.state.potential_energy_grad, 585 | sampled_proposal.energy, 586 | sampled_proposal.weight, 587 | sampled_proposal.sum_log_p_accept, 588 | new_left_state.position, 589 | new_left_state.momentum, 590 | new_left_state.potential_energy, 591 | new_left_state.potential_energy_grad, 592 | new_right_state.position, 593 | new_right_state.momentum, 594 | new_right_state.potential_energy, 595 | new_right_state.potential_energy_grad, 596 | new_momentum_sum, 597 | new_termination_state.momentum_checkpoints, 598 | new_termination_state.momentum_sum_checkpoints, 599 | new_termination_state.min_index, 600 | new_termination_state.max_index, 601 | acceptance_probability, 602 | step + 1, 603 | is_diverging, 604 | is_turning, 605 | ), 606 | inner_updates, 607 | until(do_stop_expanding), 608 | ) 609 | 610 | expansion_steps = at.arange(0, max_num_expansions) 611 | # results, updates = aesara.scan( 612 | ( 613 | proposal_state_position, 614 | proposal_state_momentum, 615 | proposal_state_potential_energy, 616 | proposal_state_potential_energy_grad, 617 | proposal_energy, 618 | proposal_weight, 619 | proposal_sum_log_p_accept, 620 | left_state_position, 621 | left_state_momentum, 622 | left_state_potential_energy, 623 | left_state_potential_energy_grad, 624 | right_state_position, 625 | right_state_momentum, 626 | right_state_potential_energy, 627 | right_state_potential_energy_grad, 628 | momentum_sum, 629 | momentum_checkpoints, 630 | momentum_sum_checkpoints, 631 | min_indices, 632 | max_indices, 633 | acceptance_probability, 634 | num_doublings, 635 | is_diverging, 636 | is_turning, 637 | ), updates = aesara.scan( 638 | expand_once, 639 | outputs_info=( 640 | proposal.state.position, 641 | 
proposal.state.momentum, 642 | proposal.state.potential_energy, 643 | proposal.state.potential_energy_grad, 644 | proposal.energy, 645 | proposal.weight, 646 | proposal.sum_log_p_accept, 647 | left_state.position, 648 | left_state.momentum, 649 | left_state.potential_energy, 650 | left_state.potential_energy_grad, 651 | right_state.position, 652 | right_state.momentum, 653 | right_state.potential_energy, 654 | right_state.potential_energy_grad, 655 | momentum_sum, 656 | termination_state.momentum_checkpoints, 657 | termination_state.momentum_sum_checkpoints, 658 | termination_state.min_index, 659 | termination_state.max_index, 660 | None, 661 | None, 662 | None, 663 | None, 664 | ), 665 | sequences=expansion_steps, 666 | ) 667 | # Ensure each item of the returned result sequence is packed into the appropriate namedtuples. 668 | typed_result = MultiplicativeExpansionResult( 669 | proposals=ProposalState( 670 | state=IntegratorState( 671 | position=proposal_state_position, 672 | momentum=proposal_state_momentum, 673 | potential_energy=proposal_state_potential_energy, 674 | potential_energy_grad=proposal_state_potential_energy_grad, 675 | ), 676 | energy=proposal_energy, 677 | weight=proposal_weight, 678 | sum_log_p_accept=proposal_sum_log_p_accept, 679 | ), 680 | left_states=IntegratorState( 681 | position=left_state_position, 682 | momentum=left_state_momentum, 683 | potential_energy=left_state_potential_energy, 684 | potential_energy_grad=left_state_potential_energy_grad, 685 | ), 686 | right_states=IntegratorState( 687 | position=right_state_position, 688 | momentum=right_state_momentum, 689 | potential_energy=right_state_potential_energy, 690 | potential_energy_grad=right_state_potential_energy_grad, 691 | ), 692 | momentum_sums=momentum_sum, 693 | termination_states=TerminationState( 694 | momentum_checkpoints=momentum_checkpoints, 695 | momentum_sum_checkpoints=momentum_sum_checkpoints, 696 | min_index=min_indices, 697 | max_index=max_indices, 698 | ), 699 | diagnostics=Diagnostics( 700 | state=IntegratorState( 701 | position=proposal_state_position, 702 | momentum=proposal_state_momentum, 703 | potential_energy=proposal_state_potential_energy, 704 | potential_energy_grad=proposal_state_potential_energy_grad, 705 | ), 706 | acceptance_probability=acceptance_probability, 707 | num_doublings=num_doublings, 708 | is_turning=is_turning, 709 | is_diverging=is_diverging, 710 | ), 711 | ) 712 | return typed_result, updates 713 | 714 | return expand 715 | 716 | 717 | def where_proposal( 718 | do_pick_left: bool, 719 | left_proposal: ProposalState, 720 | right_proposal: ProposalState, 721 | ) -> ProposalState: 722 | """Represents a switch between two proposals depending on a condition.""" 723 | state = ifelse(do_pick_left, left_proposal.state, right_proposal.state) 724 | energy = at.where(do_pick_left, left_proposal.energy, right_proposal.energy) 725 | weight = at.where(do_pick_left, left_proposal.weight, right_proposal.weight) 726 | log_sum_p_accept = ifelse( 727 | do_pick_left, left_proposal.sum_log_p_accept, right_proposal.sum_log_p_accept 728 | ) 729 | 730 | return ProposalState( 731 | state=IntegratorState(*state), 732 | energy=energy, 733 | weight=weight, 734 | sum_log_p_accept=log_sum_p_accept, 735 | ) 736 | -------------------------------------------------------------------------------- /aehmc/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Dict, Iterable, List 3 | 4 | import aesara.tensor as at 5 | 
from aesara.graph.basic import Variable, ancestors 6 | from aesara.graph.fg import FunctionGraph 7 | from aesara.graph.rewriting.utils import rewrite_graph 8 | from aesara.tensor.rewriting.shape import ShapeFeature 9 | from aesara.tensor.var import TensorVariable 10 | 11 | 12 | def simplify_shapes(graphs: List[Variable]): 13 | """Simply the shape calculations in a list of graphs.""" 14 | shape_fg = FunctionGraph( 15 | outputs=graphs, 16 | features=[ShapeFeature()], 17 | clone=False, 18 | ) 19 | return rewrite_graph(shape_fg).outputs 20 | 21 | 22 | class RaveledParamsMap: 23 | """Maps a set of tensor variables to a vector of their raveled values.""" 24 | 25 | def __init__(self, ref_params: Iterable[TensorVariable]): 26 | self.ref_params = tuple(ref_params) 27 | 28 | self.ref_shapes = [at.shape(p) for p in self.ref_params] 29 | self.ref_shapes = simplify_shapes(self.ref_shapes) 30 | 31 | self.ref_dtypes = [p.dtype for p in self.ref_params] 32 | 33 | ref_shapes_ancestors = set(ancestors(self.ref_shapes)) 34 | uninferred_shape_params = [ 35 | p for p in self.ref_params if (p in ref_shapes_ancestors) 36 | ] 37 | if any(uninferred_shape_params): 38 | # After running the shape optimizations, the graphs in 39 | # `ref_shapes` should not depend on `ref_params` directly. 40 | # If they do, it could imply that we need to sample parts of a 41 | # model in order to get the shapes/sizes of its parameters, and 42 | # that's a worst-case scenario. 43 | warnings.warn( 44 | "The following parameters need to be computed in order to determine " 45 | f"the shapes in this parameter map: {uninferred_shape_params}" 46 | ) 47 | 48 | param_sizes = [at.prod(s) for s in self.ref_shapes] 49 | cumsum_sizes = at.cumsum(param_sizes) 50 | # `at.cumsum` doesn't return a tensor of a fixed/known size 51 | cumsum_sizes = [cumsum_sizes[i] for i in range(len(param_sizes))] 52 | self.slice_indices = list(zip([0] + cumsum_sizes[:-1], cumsum_sizes)) 53 | self.vec_slices = [slice(*idx) for idx in self.slice_indices] 54 | 55 | def ravel_params(self, params: List[TensorVariable]) -> TensorVariable: 56 | """Concatenate the raveled vectors of each parameter.""" 57 | return at.concatenate([at.atleast_1d(p).ravel() for p in params]) 58 | 59 | def unravel_params( 60 | self, raveled_params: TensorVariable 61 | ) -> Dict[TensorVariable, TensorVariable]: 62 | """Unravel a concatenated set of raveled parameters.""" 63 | return { 64 | k: v.reshape(s).astype(t) 65 | for k, v, s, t in zip( 66 | self.ref_params, 67 | [raveled_params[slc] for slc in self.vec_slices], 68 | self.ref_shapes, 69 | self.ref_dtypes, 70 | ) 71 | } 72 | 73 | def __repr__(self): 74 | return f"{type(self).__name__}({self.ref_params})" 75 | -------------------------------------------------------------------------------- /aehmc/window_adaptation.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | import aesara 4 | import aesara.tensor as at 5 | from aesara import config 6 | from aesara.ifelse import ifelse 7 | from aesara.tensor.shape import shape_tuple 8 | from aesara.tensor.var import TensorVariable 9 | 10 | from aehmc.algorithms import DualAveragingState 11 | from aehmc.integrators import IntegratorState 12 | from aehmc.mass_matrix import covariance_adaptation 13 | from aehmc.nuts import Diagnostics 14 | from aehmc.step_size import dual_averaging_adaptation 15 | 16 | 17 | def run( 18 | kernel, 19 | initial_state: IntegratorState, 20 | num_steps=1000, 21 | *, 22 | is_mass_matrix_full=False, 23 
| initial_step_size=at.as_tensor(1.0, dtype=config.floatX), 24 | target_acceptance_rate=0.80, 25 | ) -> Tuple[IntegratorState, Tuple[TensorVariable, TensorVariable], Dict]: 26 | init_adapt, update_adapt = window_adaptation( 27 | num_steps, is_mass_matrix_full, initial_step_size, target_acceptance_rate 28 | ) 29 | 30 | def one_step( 31 | warmup_step, 32 | q, # chain state 33 | potential_energy, 34 | potential_energy_grad, 35 | step, # dual averaging adaptation state 36 | log_step_size, 37 | log_step_size_avg, 38 | gradient_avg, 39 | mu, 40 | mean, # mass matrix adaptation state 41 | m2, 42 | sample_size, 43 | step_size, # parameters 44 | inverse_mass_matrix, 45 | ): 46 | chain_state = IntegratorState( 47 | position=q, 48 | momentum=None, 49 | potential_energy=potential_energy, 50 | potential_energy_grad=potential_energy_grad, 51 | ) 52 | 53 | warmup_state = ( 54 | DualAveragingState( 55 | step=step, 56 | iterates=log_step_size, 57 | iterates_avg=log_step_size_avg, 58 | gradient_avg=gradient_avg, 59 | shrinkage_pts=mu, 60 | ), 61 | (mean, m2, sample_size), 62 | ) 63 | parameters = (step_size, inverse_mass_matrix) 64 | 65 | # Advance the chain by one step 66 | chain_info, inner_updates = kernel(chain_state, *parameters) 67 | 68 | # Update the warmup state and parameters 69 | warmup_state, parameters = update_adapt( 70 | warmup_step, warmup_state, parameters, chain_info 71 | ) 72 | da_state = warmup_state[0] 73 | return ( 74 | chain_info.state.position, # q 75 | chain_info.state.potential_energy, # potential_energy 76 | chain_info.state.potential_energy_grad, # potential_energy_grad 77 | da_state.step, 78 | da_state.iterates, # log_step_size 79 | da_state.iterates_avg, # log_step_size_avg 80 | da_state.gradient_avg, 81 | da_state.shrinkage_pts, # mu 82 | *warmup_state[1], 83 | *parameters, 84 | ), inner_updates 85 | 86 | (da_state, mm_state), parameters = init_adapt(initial_state) 87 | 88 | warmup_steps = at.arange(0, num_steps) 89 | state, updates = aesara.scan( 90 | fn=one_step, 91 | outputs_info=( 92 | initial_state.position, 93 | initial_state.potential_energy, 94 | initial_state.potential_energy_grad, 95 | da_state.step, 96 | da_state.iterates, # log_step_size 97 | da_state.iterates_avg, # log_step_size_avg 98 | da_state.gradient_avg, 99 | da_state.shrinkage_pts, # mu 100 | *mm_state, 101 | *parameters, 102 | ), 103 | sequences=(warmup_steps,), 104 | name="window_adaptation", 105 | ) 106 | 107 | last_chain_state = IntegratorState( 108 | position=state[0][-1], 109 | momentum=None, 110 | potential_energy=state[1][-1], 111 | potential_energy_grad=state[2][-1], 112 | ) 113 | step_size = state[-2][-1] 114 | inverse_mass_matrix = state[-1][-1] 115 | 116 | return last_chain_state, (step_size, inverse_mass_matrix), updates 117 | 118 | 119 | def window_adaptation( 120 | num_steps: int, 121 | is_mass_matrix_full: bool = False, 122 | initial_step_size: TensorVariable = at.as_tensor(1.0, dtype=config.floatX), 123 | target_acceptance_rate: TensorVariable = 0.80, 124 | ): 125 | mm_init, mm_update, mm_final = covariance_adaptation(is_mass_matrix_full) 126 | da_init, da_update = dual_averaging_adaptation(target_acceptance_rate) 127 | schedule = build_schedule(num_steps) 128 | 129 | schedule_stage = at.as_tensor([s[0] for s in schedule]) 130 | schedule_middle_window = at.as_tensor([s[1] for s in schedule]) 131 | 132 | def init(initial_chain_state: IntegratorState): 133 | if initial_chain_state.position.ndim == 0: 134 | num_dims = 0 135 | else: 136 | num_dims = shape_tuple(initial_chain_state.position)[0] 
137 | inverse_mass_matrix, mm_state = mm_init(num_dims) 138 | 139 | da_state = da_init(initial_step_size) 140 | step_size = at.exp(da_state.iterates) 141 | 142 | warmup_state = (da_state, mm_state) 143 | parameters = (step_size, inverse_mass_matrix) 144 | return warmup_state, parameters 145 | 146 | def fast_update(p_accept, warmup_state, parameters): 147 | da_state, mm_state = warmup_state 148 | _, inverse_mass_matrix = parameters 149 | 150 | new_da_state = da_update(p_accept, da_state) 151 | step_size = at.exp(new_da_state.iterates) 152 | 153 | return (new_da_state, mm_state), (step_size, inverse_mass_matrix) 154 | 155 | def slow_update(position, p_accept, warmup_state, parameters): 156 | da_state, mm_state = warmup_state 157 | _, inverse_mass_matrix = parameters 158 | 159 | new_da_state = da_update(p_accept, da_state) 160 | new_mm_state = mm_update(position, mm_state) 161 | step_size = at.exp(new_da_state.iterates) 162 | 163 | return (new_da_state, new_mm_state), (step_size, inverse_mass_matrix) 164 | 165 | def slow_final(warmup_state): 166 | """We recompute the inverse mass matrix and re-initialize the dual averaging scheme at the end of each 'slow window'.""" 167 | da_state, mm_state = warmup_state 168 | 169 | inverse_mass_matrix = mm_final(mm_state) 170 | 171 | if inverse_mass_matrix.ndim == 0: 172 | num_dims = 0 173 | else: 174 | num_dims = shape_tuple(inverse_mass_matrix)[0] 175 | _, new_mm_state = mm_init(num_dims) 176 | 177 | step_size = at.exp(da_state.iterates) 178 | new_da_state = da_init(step_size) 179 | 180 | warmup_state = (new_da_state, new_mm_state) 181 | parameters = (step_size, inverse_mass_matrix) 182 | return warmup_state, parameters 183 | 184 | def final( 185 | warmup_state: Tuple, parameters: Tuple 186 | ) -> Tuple[TensorVariable, TensorVariable]: 187 | da_state, _ = warmup_state 188 | _, inverse_mass_matrix = parameters 189 | step_size = at.exp(da_state.iterates_avg) # return stepsize_avg at the end 190 | return step_size, inverse_mass_matrix 191 | 192 | def update( 193 | step: int, warmup_state: Tuple, parameters: Tuple, chain_state: Diagnostics 194 | ): 195 | stage = schedule_stage[step] 196 | warmup_state, parameters = where_warmup_state( 197 | at.eq(stage, 0), 198 | fast_update(chain_state.acceptance_probability, warmup_state, parameters), 199 | slow_update( 200 | chain_state.state.position, 201 | chain_state.acceptance_probability, 202 | warmup_state, 203 | parameters, 204 | ), 205 | ) 206 | 207 | is_middle_window_end = schedule_middle_window[step] 208 | warmup_state, parameters = where_warmup_state( 209 | is_middle_window_end, slow_final(warmup_state), (warmup_state, parameters) 210 | ) 211 | 212 | is_last_step = at.eq(step, num_steps - 1) 213 | parameters = ifelse(is_last_step, final(warmup_state, parameters), parameters) 214 | 215 | return warmup_state, parameters 216 | 217 | def where_warmup_state(do_pick_left, left_warmup_state, right_warmup_state): 218 | (left_da_state, left_mm_state), left_params = left_warmup_state 219 | (right_da_state, right_mm_state), right_params = right_warmup_state 220 | 221 | da_state = ifelse(do_pick_left, left_da_state, right_da_state) 222 | mm_state = ifelse(do_pick_left, left_mm_state, right_mm_state) 223 | params = ifelse(do_pick_left, left_params, right_params) 224 | 225 | return (DualAveragingState(*da_state), mm_state), params 226 | 227 | return init, update 228 | 229 | 230 | def build_schedule( 231 | num_steps: int, 232 | initial_buffer_size: int = 75, 233 | final_buffer_size: int = 50, 234 | first_window_size: int = 25, 
235 | ) -> List[Tuple[int, bool]]: 236 | """Return the schedule for Stan's warmup. 237 | 238 | The schedule below is intended to be as close as possible to Stan's _[1]. 239 | The warmup period is split into three stages: 240 | 1. An initial fast interval to reach the typical set. Only the step size is 241 | adapted in this window. 242 | 2. "Slow" parameters that require global information (typically covariance) 243 | are estimated in a series of expanding intervals with no memory; the step 244 | size is re-initialized at the end of each window. Each window is twice the 245 | size of the preceding window. 246 | 3. A final fast interval during which the step size is adapted using the 247 | computed mass matrix. 248 | Schematically: 249 | 250 | ``` 251 | +---------+---+------+------------+------------------------+------+ 252 | | fast | s | slow | slow | slow | fast | 253 | +---------+---+------+------------+------------------------+------+ 254 | ``` 255 | 256 | The distinction slow/fast comes from the speed at which the algorithms 257 | converge to a stable value; in the common case, estimation of covariance 258 | requires more steps than dual averaging to give an accurate value. See _[1] 259 | for a more detailed explanation. 260 | 261 | Fast intervals are given the label 0 and slow intervals the label 1. 262 | 263 | Note 264 | ---- 265 | It feels awkward to return a boolean that indicates whether the current 266 | step is the last step of a middle window, but not for other windows. This 267 | should probably be changed to "is_window_end" and we should manage the 268 | distinction upstream. 269 | 270 | Parameters 271 | ---------- 272 | num_steps: int 273 | The number of warmup steps to perform. 274 | initial_buffer_size: int 275 | The width of the initial fast adaptation interval. 276 | final_buffer_size: int 277 | The width of the final fast adaptation interval. 278 | first_window_size: int 279 | The width of the first slow adaptation interval. 280 | 281 | Returns 282 | ------- 283 | A list of tuples (window_label, is_middle_window_end). 284 | 285 | References 286 | ---------- 287 | .. [1]: Stan Reference Manual v2.22 Section 15.2 "HMC Algorithm" 288 | 289 | """ 290 | schedule = [] 291 | 292 | # Give up on mass matrix adaptation when the number of warmup steps is too small. 293 | if num_steps < 20: 294 | schedule += [(0, False)] * num_steps 295 | else: 296 | # When the number of warmup steps is smaller than the sum of the provided (or default) 297 | # window sizes we need to resize the different windows. 298 | if initial_buffer_size + first_window_size + final_buffer_size > num_steps: 299 | initial_buffer_size = int(0.15 * num_steps) 300 | final_buffer_size = int(0.1 * num_steps) 301 | first_window_size = num_steps - initial_buffer_size - final_buffer_size 302 | 303 | # First stage: adaptation of fast parameters 304 | schedule += [(0, False)] * (initial_buffer_size - 1) 305 | schedule.append((0, False)) 306 | 307 | # Second stage: adaptation of slow parameters in successive windows 308 | # doubling in size.
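# As a concrete illustration, mirroring the `num_steps=200` case in tests/test_adaptation.py: with the default buffer sizes the slow stage starts at step 75 with a 25-step window; the next window doubles to 50 steps, which exactly fills the space left before the 50-step final fast buffer, and each slow window ends with a `(1, True)` entry.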
309 | final_buffer_start = num_steps - final_buffer_size 310 | 311 | next_window_size = first_window_size 312 | next_window_start = initial_buffer_size 313 | while next_window_start < final_buffer_start: 314 | current_start, current_size = next_window_start, next_window_size 315 | if 3 * current_size <= final_buffer_start - current_start: 316 | next_window_size = 2 * current_size 317 | else: 318 | current_size = final_buffer_start - current_start 319 | next_window_start = current_start + current_size 320 | schedule += [(1, False)] * (next_window_start - 1 - current_start) 321 | schedule.append((1, True)) 322 | 323 | # Last stage: adaptation of fast parameters 324 | schedule += [(0, False)] * (num_steps - 1 - final_buffer_start) 325 | schedule.append((0, False)) 326 | 327 | return schedule 328 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def pytest_sessionstart(session): 5 | os.environ["AESARA_FLAGS"] = ",".join( 6 | [ 7 | os.environ.setdefault("AESARA_FLAGS", ""), 8 | "floatX=float64,on_opt_error=raise,on_shape_error=raise,cast_policy=numpy+floatX", 9 | ] 10 | ) 11 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | # To use: 2 | # 3 | # $ conda env create -f environment.yml # `mamba` works too for this command 4 | # $ conda activate aehmc-dev 5 | # 6 | name: aehmc-dev 7 | channels: 8 | - conda-forge 9 | dependencies: 10 | - python>=3.8 11 | - compilers 12 | - numpy>=1.18.1 13 | - scipy>=1.4.0 14 | - aesara>=2.8.11 15 | - aeppl>=0.1.4 16 | # Intel BLAS 17 | - mkl 18 | - mkl-service 19 | - libblas=*=*mkl 20 | # For testing 21 | - pytest 22 | - coverage>=5.1 23 | - coveralls 24 | - pytest-cov 25 | - pytest-xdist 26 | # For building docs 27 | - sphinx>=1.3 28 | - sphinx_rtd_theme 29 | - pygments 30 | - pydot 31 | - ipython 32 | # developer tools 33 | - pre-commit 34 | - packaging 35 | - typing_extensions 36 | -------------------------------------------------------------------------------- /examples/requirements.txt: -------------------------------------------------------------------------------- 1 | aeppl 2 | matplotlib 3 | pymc3 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["wheel", "setuptools>=61.2", "setuptools-scm[toml]>=6.2"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "aehmc" 7 | authors = [ 8 | {name = "Aesara developers", email = "aesara.devs@gmail.com"} 9 | ] 10 | description="HMC samplers in Aesara" 11 | readme = "README.md" 12 | license = {text = "MIT License"} 13 | dynamic = ["version"] 14 | requires-python = ">=3.8" 15 | dependencies = [ 16 | "numpy >= 1.18.1", 17 | "scipy >= 1.4.0", 18 | "aesara >= 2.8.11", 19 | "aeppl >= 0.1.4", 20 | ] 21 | classifiers = [ 22 | "Development Status :: 4 - Beta", 23 | "Intended Audience :: Science/Research", 24 | "Intended Audience :: Developers", 25 | "License :: OSI Approved :: MIT License", 26 | "Operating System :: OS Independent", 27 | "Programming Language :: Python", 28 | "Programming Language :: Python :: 3", 29 | "Programming Language :: Python :: 3.8", 30 | "Programming Language :: Python :: 3.9", 31 | "Programming Language :: Python :: 3.10", 32 | 
"Programming Language :: Python :: 3.11", 33 | "Programming Language :: Python :: Implementation :: CPython", 34 | "Programming Language :: Python :: Implementation :: PyPy", 35 | ] 36 | keywords = [ 37 | "aesara", 38 | "math", 39 | "symbolic", 40 | "hamiltonian monte carlo", 41 | "nuts sampler", 42 | "No U-turn sampler", 43 | "symplectic integration", 44 | ] 45 | 46 | [project.urls] 47 | source = "http://github.com/aesara-devs/aehmc" 48 | tracker = "http://github.com/aesara-devs/aehmc/issues" 49 | 50 | [tool.setuptools] 51 | packages = ["aehmc"] 52 | include-package-data = false 53 | 54 | [tool.setuptools_scm] 55 | write_to = "aehmc/_version.py" 56 | 57 | [tool.pydocstyle] 58 | # Ignore errors for missing docstrings. 59 | # Ignore D202 (No blank lines allowed after function docstring) 60 | # due to bug in black: https://github.com/ambv/black/issues/355 61 | add-ignore = "D100,D101,D102,D103,D104,D105,D106,D107,D202" 62 | convention = "numpy" 63 | 64 | [tool.pytest.ini_options] 65 | python_files = ["test*.py"] 66 | testpaths = ["tests"] 67 | filterwarnings = [ 68 | "error:::aesara", 69 | "error:::aeppl", 70 | "error:::aemcmc", 71 | "ignore:::xarray", 72 | ] 73 | 74 | [tool.coverage.run] 75 | omit = [ 76 | "aehmc/_version.py", 77 | "tests/*", 78 | ] 79 | branch = true 80 | 81 | [tool.coverage.report] 82 | exclude_lines = [ 83 | "pragma: no cover", 84 | "def __repr__", 85 | "raise AssertionError", 86 | "raise TypeError", 87 | "return NotImplemented", 88 | "raise NotImplementedError", 89 | "if __name__ == .__main__.:", 90 | "assert False", 91 | ] 92 | show_missing = true 93 | 94 | [tool.isort] 95 | profile = "black" 96 | 97 | [tool.pylint] 98 | max-line-length = "88" 99 | 100 | [tool."pylint.messages_control"] 101 | disable = "C0330, C0326" 102 | 103 | [tool.mypy] 104 | ignore_missing_imports = true 105 | no_implicit_optional = true 106 | check_untyped_defs = true 107 | strict_equality = true 108 | warn_redundant_casts = true 109 | warn_unused_ignores = true 110 | show_error_codes = true 111 | 112 | [[tool.mypy.overrides]] 113 | module = ["tests.*"] 114 | ignore_errors = true 115 | check_untyped_defs = false 116 | 117 | [tool.black] 118 | line-length = 88 119 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | arviz 2 | -e ./ 3 | coveralls 4 | pydocstyle>=3.0.0 5 | pytest>=5.0.0 6 | pytest-cov>=2.6.1 7 | pytest-html>=1.20.0 8 | pylint>=2.3.1 9 | black==20.8b1; platform.python_implementation!='PyPy' 10 | diff-cover 11 | autoflake 12 | pre-commit 13 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aehmc/ece0aafa7a62773ff7403827f781fb7453de1cbb/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_adaptation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aehmc import window_adaptation 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "num_steps, expected_schedule", 8 | [ 9 | (19, [(0, False)] * 19), # no mass matrix adaptation 10 | ( 11 | 100, 12 | [(0, False)] * 15 + [(1, False)] * 74 + [(1, True)] + [(0, False)] * 10, 13 | ), # windows are resized 14 | ( 15 | 200, 16 | [(0, False)] * 75 17 | + [(1, False)] * 24 18 | + [(1, True)] 19 | + [(1, False)] * 49 20 | + [(1, 
True)] 21 | + [(0, False)] * 50, 22 | ), 23 | ], 24 | ) 25 | def test_adaptation_schedule(num_steps, expected_schedule): 26 | adaptation_schedule = window_adaptation.build_schedule(num_steps) 27 | assert num_steps == len(adaptation_schedule) 28 | assert adaptation_schedule == expected_schedule 29 | -------------------------------------------------------------------------------- /tests/test_algorithms.py: -------------------------------------------------------------------------------- 1 | import aesara 2 | import aesara.tensor as at 3 | import numpy as np 4 | import pytest 5 | from aesara import config 6 | 7 | from aehmc import algorithms 8 | 9 | 10 | def test_dual_averaging(): 11 | """Find the minimum of a simple function using Dual Averaging.""" 12 | 13 | def fn(x): 14 | return (x - 1) ** 2 15 | 16 | init, update = algorithms.dual_averaging(gamma=0.5) 17 | 18 | def one_step(step, x, x_avg, gradient_avg, mu): 19 | value = fn(x) 20 | gradient = aesara.grad(value, x) 21 | current_state = algorithms.DualAveragingState( 22 | step=step, 23 | iterates=x, 24 | iterates_avg=x_avg, 25 | gradient_avg=gradient_avg, 26 | shrinkage_pts=mu, 27 | ) 28 | da_state = update(gradient, current_state) 29 | return ( 30 | da_state.step, 31 | da_state.iterates, 32 | da_state.iterates_avg, 33 | da_state.gradient_avg, 34 | da_state.shrinkage_pts, 35 | ) 36 | 37 | shrinkage_pts = at.as_tensor(at.constant(0.5), dtype=config.floatX) 38 | da_state = init(shrinkage_pts) 39 | 40 | states, updates = aesara.scan( 41 | fn=one_step, 42 | outputs_info=[ 43 | {"initial": da_state.step}, 44 | {"initial": da_state.iterates}, 45 | {"initial": da_state.iterates_avg}, 46 | {"initial": da_state.gradient_avg}, 47 | {"initial": da_state.shrinkage_pts}, 48 | ], 49 | n_steps=100, 50 | ) 51 | 52 | last_x = states[1].eval()[-1] 53 | last_x_avg = states[2].eval()[-1] 54 | assert last_x_avg == pytest.approx(1.0, 1e-2) 55 | assert last_x == pytest.approx(1.0, 1e-2) 56 | 57 | 58 | @pytest.mark.parametrize("num_dims", [0, 1, 3]) 59 | @pytest.mark.parametrize("do_compute_covariance", [True, False]) 60 | def test_welford_constant(num_dims, do_compute_covariance): 61 | num_samples = 10 62 | 63 | if num_dims > 0: 64 | sample = at.ones((num_dims,)) # constant samples 65 | else: 66 | sample = at.constant(1.0) 67 | 68 | init, update, final = algorithms.welford_covariance(do_compute_covariance) 69 | state = init(num_dims) 70 | for _ in range(num_samples): 71 | state = update(sample, *state) 72 | 73 | mean = state[0].eval() 74 | if num_dims > 0: 75 | expected = np.ones(num_dims) 76 | assert np.shape(mean) == np.shape(expected) 77 | np.testing.assert_allclose(mean, expected, rtol=1e-1) 78 | else: 79 | assert np.ndim(mean) == 0 80 | assert mean == 1.0 81 | 82 | cov = final(state[1], state[2]).eval() 83 | if num_dims > 0: 84 | if do_compute_covariance: 85 | expected = np.zeros((num_dims, num_dims)) 86 | else: 87 | expected = np.zeros(num_dims) 88 | 89 | assert np.shape(cov) == np.shape(expected) 90 | np.testing.assert_allclose(cov, expected) 91 | else: 92 | assert np.ndim(cov) == 0 93 | assert cov == 0 94 | 95 | 96 | @pytest.mark.parametrize("do_compute_covariance", [True, False]) 97 | @pytest.mark.parametrize("n_dim", [1, 3]) 98 | def test_welford(n_dim, do_compute_covariance): 99 | num_samples = 10 100 | 101 | init, update, final = algorithms.welford_covariance(do_compute_covariance) 102 | state = init(n_dim) 103 | for i in range(num_samples): 104 | sample = i * at.ones(n_dim) 105 | state = update(sample, *state) 106 | 107 | mean = state[0].eval() 108 | 
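# The samples are i * ones(n_dim) for i = 0, ..., 9, so the sample mean is (0 + ... + 9) / 10 = 4.5 and the unbiased sample variance is sum((i - 4.5) ** 2) / 9 = 82.5 / 9 = 55 / 6.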
expected = (9.0 / 2) * np.ones(n_dim) 109 | np.testing.assert_allclose(mean, expected) 110 | 111 | cov = final(state[1], state[2]).eval() 112 | if do_compute_covariance: 113 | expected = 55.0 / 6.0 * np.ones((n_dim, n_dim)) 114 | else: 115 | expected = 55.0 / 6.0 * np.ones(n_dim) 116 | assert np.shape(cov) == np.shape(expected) 117 | np.testing.assert_allclose(cov, expected) 118 | 119 | 120 | @pytest.mark.parametrize("do_compute_covariance", [True, False]) 121 | def test_welford_scalar(do_compute_covariance): 122 | """Test the Welford algorithm when the state is a scalar.""" 123 | num_samples = 10 124 | 125 | init, update, final = algorithms.welford_covariance(do_compute_covariance) 126 | state = init(0) 127 | for i in range(num_samples): 128 | sample = at.as_tensor(i) 129 | state = update(sample, *state) 130 | 131 | cov = final(state[1], state[2]).eval() 132 | assert np.ndim(cov) == 0 133 | assert pytest.approx(cov.squeeze()) == 55.0 / 6.0 134 | -------------------------------------------------------------------------------- /tests/test_hmc.py: -------------------------------------------------------------------------------- 1 | import aesara 2 | import aesara.tensor as at 3 | import arviz 4 | import numpy as np 5 | import pytest 6 | import scipy.stats as stats 7 | from aeppl import joint_logprob 8 | from aesara.tensor.var import TensorVariable 9 | 10 | from aehmc import hmc, nuts, window_adaptation 11 | 12 | 13 | def test_warmup_scalar(): 14 | """Test the warmup on a univariate normal distribution.""" 15 | 16 | srng = at.random.RandomStream(seed=0) 17 | Y_rv = srng.normal(1, 2) 18 | 19 | def logprob_fn(y: TensorVariable): 20 | logprob, _ = joint_logprob(realized={Y_rv: y}) 21 | return logprob 22 | 23 | y_vv = Y_rv.clone() 24 | kernel = nuts.new_kernel(srng, logprob_fn) 25 | initial_state = nuts.new_state(y_vv, logprob_fn) 26 | 27 | state, (step_size, inverse_mass_matrix), updates = window_adaptation.run( 28 | kernel, initial_state, num_steps=1000 29 | ) 30 | 31 | # Compile the warmup and execute to get a value for the step size and the 32 | # mass matrix. 33 | warmup_fn = aesara.function( 34 | (y_vv,), 35 | ( 36 | state.position, 37 | state.potential_energy, 38 | state.potential_energy_grad, 39 | step_size, 40 | inverse_mass_matrix, 41 | ), 42 | updates=updates, 43 | ) 44 | 45 | final_position, *_, step_size, inverse_mass_matrix = warmup_fn(3.0) 46 | 47 | assert final_position != 3.0 # the chain has moved 48 | assert np.ndim(step_size) == 0 # scalar step size 49 | assert step_size != 1.0 # step size changed 50 | assert step_size > 0.1 and step_size < 2 # stable range for the step size 51 | assert np.ndim(inverse_mass_matrix) == 0 # scalar mass matrix 52 | assert inverse_mass_matrix == pytest.approx(4, rel=1.0) 53 | 54 | 55 | def test_warmup_vector(): 56 | """Test the warmup on a multivariate normal distribution.""" 57 | 58 | loc = np.array([0.0, 3.0]) 59 | scale = np.array([1.0, 2.0]) 60 | cov = np.diag(scale**2) 61 | 62 | srng = at.random.RandomStream(seed=0) 63 | Y_rv = srng.multivariate_normal(loc, cov) 64 | 65 | def logprob_fn(y: TensorVariable): 66 | logprob, _ = joint_logprob(realized={Y_rv: y}) 67 | return logprob 68 | 69 | y_vv = Y_rv.clone() 70 | kernel = nuts.new_kernel(srng, logprob_fn) 71 | initial_state = nuts.new_state(y_vv, logprob_fn) 72 | 73 | state, (step_size, inverse_mass_matrix), updates = window_adaptation.run( 74 | kernel, initial_state, num_steps=1000 75 | ) 76 | 77 | # Compile the warmup and execute to get a value for the step size and the 78 | # mass matrix.
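# (Passing the scan's `updates` to `aesara.function` below is required for the random stream's state to be updated when the compiled function runs.)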
79 | warmup_fn = aesara.function( 80 | (y_vv,), 81 | ( 82 | state.position, 83 | state.potential_energy, 84 | state.potential_energy_grad, 85 | step_size, 86 | inverse_mass_matrix, 87 | ), 88 | updates=updates, 89 | ) 90 | 91 | final_position, *_, step_size, inverse_mass_matrix = warmup_fn([1.0, 1.0]) 92 | 93 | assert np.all(final_position != np.array([1.0, 1.0])) # the chain has moved 94 | assert np.ndim(step_size) == 0 # scalar step size 95 | assert step_size > 0.1 and step_size < 2 # stable range for the step size 96 | assert np.ndim(inverse_mass_matrix) == 1 # diagonal mass matrix 97 | np.testing.assert_allclose(inverse_mass_matrix, scale**2, rtol=1.0) 98 | 99 | 100 | @pytest.mark.parametrize("step_size, diverges", [(3.9, False), (4.1, True)]) 101 | def test_univariate_hmc(step_size, diverges): 102 | """Test the HMC kernel on a univariate Gaussian target. 103 | 104 | Theory [1]_ says that the integration of the trajectory should be stable as 105 | long as the step size is smaller than twice the standard deviation. 106 | 107 | References 108 | ---------- 109 | .. [1]: Neal, R. M. (2011). MCMC using Hamiltonian dynamics. Handbook of markov chain monte carlo, 2(11), 2. 110 | 111 | """ 112 | inverse_mass_matrix = at.as_tensor(1.0) 113 | num_integration_steps = 30 114 | 115 | srng = at.random.RandomStream(seed=0) 116 | Y_rv = srng.normal(1, 2) 117 | 118 | def logprob_fn(y): 119 | logprob, _ = joint_logprob(realized={Y_rv: y}) 120 | return logprob 121 | 122 | kernel = hmc.new_kernel(srng, logprob_fn) 123 | 124 | y_vv = Y_rv.clone() 125 | initial_state = hmc.new_state(y_vv, logprob_fn) 126 | 127 | def update_hmc_state(pos, energy, energy_grad): 128 | current_state = hmc.IntegratorState(pos, None, energy, energy_grad) 129 | chain_info, _ = kernel( 130 | current_state, step_size, inverse_mass_matrix, num_integration_steps 131 | ) 132 | return ( 133 | chain_info.state.position, 134 | chain_info.state.potential_energy, 135 | chain_info.state.potential_energy_grad, 136 | ) 137 | 138 | trajectory, updates = aesara.scan( 139 | update_hmc_state, 140 | outputs_info=[ 141 | {"initial": initial_state.position}, 142 | {"initial": initial_state.potential_energy}, 143 | {"initial": initial_state.potential_energy_grad}, 144 | ], 145 | n_steps=2_000, 146 | ) 147 | 148 | trajectory_generator = aesara.function((y_vv,), trajectory[0], updates=updates) 149 | 150 | samples = trajectory_generator(3.0) 151 | if diverges: 152 | assert np.all(samples == 3.0) 153 | else: 154 | assert np.mean(samples[1000:]) == pytest.approx(1.0, rel=1e-1) 155 | assert np.var(samples[1000:]) == pytest.approx(4.0, rel=1e-1) 156 | 157 | 158 | def compute_ess(samples): 159 | d = arviz.convert_to_dataset(np.expand_dims(samples, axis=0)) 160 | ess = arviz.ess(d).to_array().to_numpy().squeeze() 161 | return ess 162 | 163 | 164 | def compute_mcse(x): 165 | ess = compute_ess(x) 166 | std_x = np.std(x, axis=0, ddof=1) 167 | return np.mean(x, axis=0), std_x / np.sqrt(ess) 168 | 169 | 170 | def multivariate_normal_model(srng): 171 | loc = np.array([0.0, 3.0]) 172 | scale = np.array([1.0, 2.0]) 173 | rho = np.array(0.5) 174 | 175 | cov = np.diag(scale**2) 176 | cov[0, 1] = rho * scale[0] * scale[1] 177 | cov[1, 0] = rho * scale[0] * scale[1] 178 | 179 | loc_tt = at.as_tensor(loc) 180 | cov_tt = at.as_tensor(cov) 181 | 182 | Y_rv = srng.multivariate_normal(loc_tt, cov_tt) 183 | 184 | def logprob_fn(y): 185 | return joint_logprob(realized={Y_rv: y})[0] 186 | 187 | return (loc, scale, rho), Y_rv, logprob_fn 188 | 189 | 190 | def test_hmc_mcse(): 191 |
"""This examples is recommanded in the Stan documentation [0]_ to find bugs 192 | that introduce bias in the average as well as the variance. 193 | 194 | The example is simple enough to be analytically tractable, but complex enough 195 | to find subtle bugs. It uses the MCMC CLT [1]_ to check that the estimates 196 | of different quantities are within the expected range. 197 | 198 | We set the inverse mass matrix to be the diagonal of the covariance matrix. 199 | 200 | We can also show that trajectory integration will not diverge as long as we 201 | choose any step size that is smaller than 2 [2]_ (section 4.2); We adjusted 202 | the number of integration steps manually in order to get a reasonable number 203 | of effective samples. 204 | 205 | References 206 | ---------- 207 | .. [0]: https://github.com/stan-dev/stan/wiki/Testing:-Samplers 208 | .. [1]: Geyer, C. J. (2011). Introduction to markov chain monte carlo. Handbook of markov chain monte carlo, 20116022, 45. 209 | .. [2]: Neal, R. M. (2011). MCMC using Hamiltonian dynamics. Handbook of markov chain monte carlo, 2(11), 2. 210 | 211 | """ 212 | srng = at.random.RandomStream(seed=1) 213 | (loc, scale, rho), Y_rv, logprob_fn = multivariate_normal_model(srng) 214 | 215 | step_size = 1.0 216 | L = 30 217 | inverse_mass_matrix = at.as_tensor(scale) 218 | kernel = hmc.new_kernel(srng, logprob_fn) 219 | 220 | y_vv = Y_rv.clone() 221 | initial_state = hmc.new_state(y_vv, logprob_fn) 222 | 223 | def update_hmc_state(pos, energy, energy_grad): 224 | current_state = hmc.IntegratorState(pos, None, energy, energy_grad) 225 | chain_info, _ = kernel(current_state, step_size, inverse_mass_matrix, L) 226 | return ( 227 | chain_info.state.position, 228 | chain_info.state.potential_energy, 229 | chain_info.state.potential_energy_grad, 230 | ) 231 | 232 | trajectory, updates = aesara.scan( 233 | update_hmc_state, 234 | outputs_info=[ 235 | {"initial": initial_state.position}, 236 | {"initial": initial_state.potential_energy}, 237 | {"initial": initial_state.potential_energy_grad}, 238 | ], 239 | n_steps=3000, 240 | ) 241 | 242 | trajectory_generator = aesara.function((y_vv,), trajectory, updates=updates) 243 | 244 | rng = np.random.default_rng(seed=0) 245 | trace = trajectory_generator(rng.standard_normal(2)) 246 | samples = trace[0][1000:] 247 | 248 | # MCSE on the location 249 | delta_loc = samples - loc 250 | mean, mcse = compute_mcse(delta_loc) 251 | p_greater_error = stats.norm.sf(np.abs(mean) / mcse) 252 | np.testing.assert_array_less(0.01, p_greater_error) 253 | 254 | # MCSE on the variance 255 | delta_var = np.square(samples - loc) - scale**2 256 | mean, mcse = compute_mcse(delta_var) 257 | p_greater_error = stats.norm.sf(np.abs(mean) / mcse) 258 | np.testing.assert_array_less(0.01, p_greater_error) 259 | 260 | # MCSE on the correlation 261 | delta_cor = np.prod(samples - loc, axis=1) / np.prod(scale) - rho 262 | mean, mcse = compute_mcse(delta_cor) 263 | p_greater_error = stats.norm.sf(np.abs(mean) / mcse) 264 | np.testing.assert_array_less(0.01, p_greater_error) 265 | 266 | 267 | def test_nuts_mcse(): 268 | """This examples is recommanded in the Stan documentation [0]_ to find bugs 269 | that introduce bias in the average as well as the variance. 270 | 271 | The example is simple enough to be analytically tractable, but complex enough 272 | to find subtle bugs. It uses the MCMC CLT [1]_ to check that the estimates 273 | of different quantities are within the expected range. 
274 | 275 | We set the inverse mass matrix to be the diagonal of the covariance matrix. 276 | 277 | We can also show that trajectory integration will not diverge as long as we 278 | choose any step size that is smaller than 2 [2]_ (section 4.2); we adjusted 279 | the number of integration steps manually in order to get a reasonable number 280 | of effective samples. 281 | 282 | References 283 | ---------- 284 | .. [0]: https://github.com/stan-dev/stan/wiki/Testing:-Samplers 285 | .. [1]: Geyer, C. J. (2011). Introduction to markov chain monte carlo. Handbook of markov chain monte carlo, 20116022, 45. 286 | .. [2]: Neal, R. M. (2011). MCMC using Hamiltonian dynamics. Handbook of markov chain monte carlo, 2(11), 2. 287 | 288 | """ 289 | srng = at.random.RandomStream(seed=1) 290 | (loc, scale, rho), Y_rv, logprob_fn = multivariate_normal_model(srng) 291 | 292 | step_size = at.as_tensor(1.0) 293 | inverse_mass_matrix = at.as_tensor(scale) 294 | kernel = nuts.new_kernel(srng, logprob_fn) 295 | 296 | def wrapped_kernel(pos, energy, energy_grad): 297 | state = nuts.IntegratorState( 298 | position=pos, 299 | momentum=None, 300 | potential_energy=energy, 301 | potential_energy_grad=energy_grad, 302 | ) 303 | chain_info, updates = kernel(state, step_size, inverse_mass_matrix) 304 | 305 | return ( 306 | chain_info.state.position, 307 | chain_info.state.potential_energy, 308 | chain_info.state.potential_energy_grad, 309 | ), updates 310 | 311 | y_vv = Y_rv.clone() 312 | initial_state = nuts.new_state(y_vv, logprob_fn) 313 | 314 | trajectory, updates = aesara.scan( 315 | wrapped_kernel, 316 | outputs_info=[ 317 | {"initial": initial_state.position}, 318 | {"initial": initial_state.potential_energy}, 319 | {"initial": initial_state.potential_energy_grad}, 320 | ], 321 | n_steps=3000, 322 | ) 323 | 324 | trajectory_generator = aesara.function((y_vv,), trajectory, updates=updates) 325 | 326 | rng = np.random.default_rng(seed=0) 327 | trace = trajectory_generator(rng.standard_normal(2)) 328 | samples = trace[0][-1000:] 329 | 330 | # MCSE on the location 331 | delta_loc = samples - loc 332 | mean, mcse = compute_mcse(delta_loc) 333 | p_greater_error = stats.norm.sf(np.abs(mean) / mcse) 334 | np.testing.assert_array_less(0.01, p_greater_error) 335 | 336 | # MCSE on the variance 337 | delta_var = (samples - loc) ** 2 - scale**2 338 | mean, mcse = compute_mcse(delta_var) 339 | p_greater_error = stats.norm.sf(np.abs(mean) / mcse) 340 | np.testing.assert_array_less(0.01, p_greater_error) 341 | 342 | # MCSE on the correlation 343 | delta_cor = np.prod(samples - loc, axis=1) / np.prod(scale) - rho 344 | mean, mcse = compute_mcse(delta_cor) 345 | p_greater_error = stats.norm.sf(np.abs(mean) / mcse) 346 | np.testing.assert_array_less(0.01, p_greater_error) 347 | -------------------------------------------------------------------------------- /tests/test_integrators.py: -------------------------------------------------------------------------------- 1 | import aesara 2 | import aesara.tensor as at 3 | import numpy as np 4 | import pytest 5 | from aesara.tensor.var import TensorVariable 6 | 7 | from aehmc.integrators import IntegratorState, velocity_verlet 8 | 9 | 10 | def HarmonicOscillator(inverse_mass_matrix, k=1.0, m=1.0): 11 | """Potential and kinetic energy of a harmonic oscillator.""" 12 | 13 | def potential_energy(x: TensorVariable) -> TensorVariable: 14 | return at.sum(0.5 * k * at.square(x)) 15 | 16 | def kinetic_energy(p: TensorVariable) -> TensorVariable: 17 | v = inverse_mass_matrix * p 18 | return
at.sum(0.5 * at.dot(v, p)) 19 | 20 | return potential_energy, kinetic_energy 21 | 22 | 23 | def FreeFall(inverse_mass_matrix, g=1.0): 24 | """Potential and kinetic energy of a free-falling object.""" 25 | 26 | def potential_energy(h: TensorVariable) -> TensorVariable: 27 | return at.sum(g * h) 28 | 29 | def kinetic_energy(p: TensorVariable) -> TensorVariable: 30 | v = inverse_mass_matrix * p 31 | return at.sum(0.5 * at.dot(v, p)) 32 | 33 | return potential_energy, kinetic_energy 34 | 35 | 36 | def CircularMotion(inverse_mass_matrix): 37 | def potential_energy(q: TensorVariable) -> TensorVariable: 38 | return -1.0 / at.power(at.square(q[0]) + at.square(q[1]), 0.5) 39 | 40 | def kinetic_energy(p: TensorVariable) -> TensorVariable: 41 | return 0.5 * at.dot(inverse_mass_matrix, at.square(p)) 42 | 43 | return potential_energy, kinetic_energy 44 | 45 | 46 | integration_examples = [ 47 | { 48 | "model": FreeFall, 49 | "n_steps": 100, 50 | "step_size": 0.01, 51 | "q_init": np.array([0.0]), 52 | "p_init": np.array([1.0]), 53 | "q_final": np.array([0.5]), 54 | "p_final": np.array([0.0]), 55 | "inverse_mass_matrix": np.array([1.0]), 56 | }, 57 | { 58 | "model": HarmonicOscillator, 59 | "n_steps": 100, 60 | "step_size": 0.01, 61 | "q_init": np.array([0.0]), 62 | "p_init": np.array([1.0]), 63 | "q_final": np.array([np.sin(1.0)]), 64 | "p_final": np.array([np.cos(1.0)]), 65 | "inverse_mass_matrix": np.array([1.0]), 66 | }, 67 | { 68 | "model": CircularMotion, 69 | "n_steps": 628, 70 | "step_size": 0.01, 71 | "q_init": np.array([1.0, 0.0]), 72 | "p_init": np.array([0.0, 1.0]), 73 | "q_final": np.array([1.0, 0.0]), 74 | "p_final": np.array([0.0, 1.0]), 75 | "inverse_mass_matrix": np.array([1.0, 1.0]), 76 | }, 77 | ] 78 | 79 | 80 | def create_integrate_fn(potential, step_fn, n_steps): 81 | q = at.vector("q") 82 | p = at.vector("p") 83 | step_size = at.scalar("step_size") 84 | energy = potential(q) 85 | energy_grad = aesara.grad(energy, q) 86 | trajectory, _ = aesara.scan( 87 | fn=step_fn, 88 | outputs_info=[ 89 | {"initial": q}, 90 | {"initial": p}, 91 | {"initial": energy}, 92 | {"initial": energy_grad}, 93 | ], 94 | non_sequences=[step_size], 95 | n_steps=n_steps, 96 | ) 97 | integrate_fn = aesara.function((q, p, step_size), trajectory) 98 | return integrate_fn 99 | 100 | 101 | @pytest.mark.parametrize("example", integration_examples) 102 | def test_velocity_verlet(example): 103 | model = example["model"] 104 | inverse_mass_matrix = example["inverse_mass_matrix"] 105 | step_size = example["step_size"] 106 | q_init = example["q_init"] 107 | p_init = example["p_init"] 108 | 109 | potential, kinetic_energy = model(inverse_mass_matrix) 110 | step = velocity_verlet(potential, kinetic_energy) 111 | 112 | q = at.vector("q") 113 | p = at.vector("p") 114 | p_final = at.vector("p_final") 115 | energy_at = potential(q) + kinetic_energy(p) 116 | energy_fn = aesara.function((q, p), energy_at) 117 | 118 | def wrapped_step(pos, mom, energy, energy_grad, step_size): 119 | return step(IntegratorState(pos, mom, energy, energy_grad), step_size) 120 | 121 | integrate_fn = create_integrate_fn(potential, wrapped_step, example["n_steps"]) 122 | q_final, p_final, energy_final, _ = integrate_fn(q_init, p_init, step_size) 123 | 124 | # Check that the trajectory was correctly integrated 125 | np.testing.assert_allclose(example["q_final"], q_final[-1], atol=1e-2) 126 | np.testing.assert_allclose(example["p_final"], p_final[-1], atol=1e-2) 127 | 128 | # Symplectic integrators conserve energy 129 | energy = energy_fn(q_init, p_init) 
130 | new_energy = energy_fn(q_final[-1], p_final[-1]) 131 | assert energy == pytest.approx(new_energy, 1e-4) 132 | -------------------------------------------------------------------------------- /tests/test_mass_matrix.py: -------------------------------------------------------------------------------- 1 | import aesara 2 | import aesara.tensor as at 3 | import numpy as np 4 | import pytest 5 | from aesara.tensor.random.utils import RandomStream 6 | from numpy.testing import assert_allclose 7 | 8 | from aehmc import mass_matrix 9 | 10 | 11 | @pytest.mark.parametrize("is_full_matrix", [True, False]) 12 | @pytest.mark.parametrize("n_dims", [0, 1, 3]) 13 | def test_mass_matrix_adaptation(is_full_matrix, n_dims): 14 | srng = RandomStream(seed=0) 15 | 16 | if n_dims > 0: 17 | mu = 0.5 * at.ones((n_dims,)) 18 | cov = 0.33 * at.ones((n_dims, n_dims)) 19 | else: 20 | mu = at.constant(0.5) 21 | cov = at.constant(0.33) 22 | 23 | init, update, final = mass_matrix.covariance_adaptation(is_full_matrix) 24 | _, wc_state = init(n_dims) 25 | if n_dims > 0: 26 | dist = srng.multivariate_normal 27 | else: 28 | dist = srng.normal 29 | 30 | def one_step(*wc_state): 31 | sample = dist(mu, cov) 32 | wc_state = update(sample, wc_state) 33 | return wc_state 34 | 35 | results, updates = aesara.scan( 36 | fn=one_step, 37 | outputs_info=[ 38 | {"initial": wc_state[0]}, 39 | {"initial": wc_state[1]}, 40 | {"initial": wc_state[2]}, 41 | ], 42 | n_steps=2_000, 43 | ) 44 | 45 | inverse_mass_matrix = final((results[0][-1], results[1][-1], results[2][-1])) 46 | 47 | if n_dims > 0: 48 | if is_full_matrix: 49 | expected = cov.eval() 50 | inverse_mass_matrix = inverse_mass_matrix.eval() 51 | else: 52 | expected = np.diagonal(cov.eval()) 53 | inverse_mass_matrix = inverse_mass_matrix.eval() 54 | assert np.shape(inverse_mass_matrix) == np.shape(expected) 55 | assert_allclose(inverse_mass_matrix, expected, rtol=0.1) 56 | else: 57 | sigma = at.sqrt(inverse_mass_matrix).eval() 58 | expected_sigma = cov.eval() 59 | assert np.ndim(expected_sigma) == 0 60 | assert sigma == pytest.approx(expected_sigma, rel=0.1) 61 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import aesara 2 | import aesara.tensor as at 3 | import numpy as np 4 | import pytest 5 | from aesara.tensor.random.utils import RandomStream 6 | 7 | from aehmc.metrics import gaussian_metric 8 | 9 | momentum_test_cases = [ 10 | (1.0, 0.144), 11 | (np.array([1.0]), np.array([0.144])), 12 | (np.array([1.0, 1.0]), np.array([0.144, 1.27])), 13 | (np.array([[1.0, 0], [0, 1.0]]), np.array([0.144, 1.27])), 14 | ] 15 | 16 | 17 | @pytest.mark.skip(reason="This test relies on a specific rng implementation and seed.") 18 | @pytest.mark.parametrize("case", momentum_test_cases) 19 | def test_gaussian_metric_momentum(case): 20 | inverse_mass_matrix_val, expected_momentum = case 21 | 22 | # Momentum 23 | if np.ndim(inverse_mass_matrix_val) == 0: 24 | inverse_mass_matrix = at.scalar("inverse_mass_matrix") 25 | elif np.ndim(inverse_mass_matrix_val) == 1: 26 | inverse_mass_matrix = at.vector("inverse_mass_matrix") 27 | else: 28 | inverse_mass_matrix = at.matrix("inverse_mass_matrix") 29 | 30 | momentum_fn, _, _ = gaussian_metric(inverse_mass_matrix) 31 | srng = RandomStream(seed=59) 32 | momentum_generator = aesara.function([inverse_mass_matrix], momentum_fn(srng)) 33 | 34 | momentum = momentum_generator(inverse_mass_matrix_val) 35 | assert 
np.shape(momentum) == np.shape(expected_momentum) 36 | assert momentum == pytest.approx(expected_momentum, 1e-2) 37 | 38 | 39 | kinetic_energy_test_cases = [ 40 | (1.0, 1.0, 0.5), 41 | (np.array([1.0]), np.array([1.0]), 0.5), 42 | (np.array([1.0, 1.0]), np.array([1.0, 1.0]), 1.0), 43 | (np.array([[1.0, 0], [0, 1.0]]), np.array([1.0, 1.0]), 1.0), 44 | ] 45 | 46 | 47 | @pytest.mark.parametrize("case", kinetic_energy_test_cases) 48 | def test_gaussian_metric_kinetic_energy(case): 49 | inverse_mass_matrix_val, momentum_val, expected_energy = case 50 | 51 | if np.ndim(inverse_mass_matrix_val) == 0: 52 | inverse_mass_matrix = at.scalar("inverse_mass_matrix") 53 | momentum = at.scalar("momentum") 54 | elif np.ndim(inverse_mass_matrix_val) == 1: 55 | inverse_mass_matrix = at.vector("inverse_mass_matrix") 56 | momentum = at.vector("momentum") 57 | else: 58 | inverse_mass_matrix = at.matrix("inverse_mass_matrix") 59 | momentum = at.vector("momentum") 60 | 61 | _, kinetic_energy_fn, _ = gaussian_metric(inverse_mass_matrix) 62 | kinetic_energy = aesara.function( 63 | (inverse_mass_matrix, momentum), kinetic_energy_fn(momentum) 64 | ) 65 | 66 | kinetic = kinetic_energy(inverse_mass_matrix_val, momentum_val) 67 | assert np.ndim(kinetic) == 0 68 | assert kinetic == expected_energy 69 | 70 | 71 | turning_test_cases = [ 72 | (1.0, 1.0, 1.0, 1.0), 73 | ( 74 | np.array([1.0, 1.0]), # inverse mass matrix 75 | np.array([1.0, 1.0]), # p_left 76 | np.array([1.0, 1.0]), # p_right 77 | np.array([1.0, 1.0]), # p_sum 78 | ), 79 | ( 80 | np.array([[1.0, 0.0], [0.0, 1.0]]), 81 | np.array([1.0, 1.0]), 82 | np.array([1.0, 1.0]), 83 | np.array([1.0, 1.0]), 84 | ), 85 | ] 86 | 87 | 88 | @pytest.mark.parametrize("case", turning_test_cases) 89 | def test_turning(case): 90 | inverse_mass_matrix_val, p_left_val, p_right_val, p_sum_val = case 91 | 92 | if np.ndim(inverse_mass_matrix_val) == 0: 93 | inverse_mass_matrix = at.scalar("inverse_mass_matrix") 94 | p_left = at.scalar("p_left") 95 | p_right = at.scalar("p_right") 96 | p_sum = at.scalar("p_sum") 97 | elif np.ndim(inverse_mass_matrix_val) == 1: 98 | inverse_mass_matrix = at.vector("inverse_mass_matrix") 99 | p_left = at.vector("p_left") 100 | p_right = at.vector("p_right") 101 | p_sum = at.vector("p_sum") 102 | else: 103 | inverse_mass_matrix = at.matrix("inverse_mass_matrix") 104 | p_left = at.vector("p_left") 105 | p_right = at.vector("p_right") 106 | p_sum = at.vector("p_sum") 107 | 108 | _, _, turning_fn = gaussian_metric(inverse_mass_matrix) 109 | 110 | is_turning_fn = aesara.function( 111 | (inverse_mass_matrix, p_left, p_right, p_sum), 112 | turning_fn(p_left, p_right, p_sum), 113 | ) 114 | 115 | is_turning = is_turning_fn( 116 | inverse_mass_matrix_val, p_left_val, p_right_val, p_sum_val 117 | ) 118 | 119 | assert is_turning.ndim == 0 120 | assert is_turning.item() is True 121 | 122 | 123 | def test_fail_wrong_mass_matrix_dimension(): 124 | """`gaussian_metric` should fail when the dimension of the mass matrix is greater than 2.""" 125 | inverse_mass_matrix = np.array([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]) 126 | with pytest.raises(ValueError): 127 | _ = gaussian_metric(inverse_mass_matrix) 128 | -------------------------------------------------------------------------------- /tests/test_step_size.py: -------------------------------------------------------------------------------- 1 | import aesara 2 | import aesara.tensor as at 3 | import numpy as np 4 | import pytest 5 | from aesara import config 6 | from aesara.tensor.random.utils import RandomStream 7 | 8 | 
from aehmc import hmc 9 | from aehmc.algorithms import DualAveragingState 10 | from aehmc.step_size import dual_averaging_adaptation 11 | 12 | 13 | @pytest.fixture() 14 | def init(): 15 | def logprob_fn(x): 16 | return -2 * (x - 1.0) ** 2 17 | 18 | srng = RandomStream(seed=0) 19 | kernel = hmc.new_kernel(srng, logprob_fn) 20 | 21 | initial_position = at.as_tensor(1.0, dtype=config.floatX) 22 | initial_state = hmc.new_state(initial_position, logprob_fn) 23 | 24 | return initial_state, kernel 25 | 26 | 27 | def test_dual_averaging_adaptation(init): 28 | initial_state, kernel = init 29 | 30 | init_stepsize = at.as_tensor(1.0, dtype=config.floatX) 31 | inverse_mass_matrix = at.as_tensor(1.0) 32 | num_integration_steps = 10 33 | 34 | init_fn, update_fn = dual_averaging_adaptation() 35 | da_state = init_fn(init_stepsize) 36 | 37 | def one_step(q, logprob, logprob_grad, step, x_t, x_avg, gradient_avg, mu): 38 | state = hmc.IntegratorState( 39 | position=q, 40 | momentum=None, 41 | potential_energy=logprob, 42 | potential_energy_grad=logprob_grad, 43 | ) 44 | chain_info, inner_updates = kernel( 45 | state, at.exp(x_t), inverse_mass_matrix, num_integration_steps 46 | ) 47 | current_da_state = DualAveragingState( 48 | step=step, 49 | iterates=x_t, 50 | iterates_avg=x_avg, 51 | gradient_avg=gradient_avg, 52 | shrinkage_pts=mu, 53 | ) 54 | da_state = update_fn(chain_info.acceptance_probability, current_da_state) 55 | 56 | return ( 57 | chain_info.state.position, 58 | chain_info.state.potential_energy, 59 | chain_info.state.potential_energy_grad, 60 | da_state.step, 61 | da_state.iterates, 62 | da_state.iterates_avg, 63 | da_state.gradient_avg, 64 | da_state.shrinkage_pts, 65 | chain_info.acceptance_probability, 66 | ), inner_updates 67 | 68 | states, updates = aesara.scan( 69 | fn=one_step, 70 | outputs_info=[ 71 | {"initial": initial_state.position}, 72 | {"initial": initial_state.potential_energy}, 73 | {"initial": initial_state.potential_energy_grad}, 74 | {"initial": da_state.step}, 75 | {"initial": da_state.iterates}, # logstepsize 76 | {"initial": da_state.iterates_avg}, # logstepsize_avg 77 | {"initial": da_state.gradient_avg}, 78 | {"initial": da_state.shrinkage_pts}, # mu 79 | None, 80 | ], 81 | n_steps=10_000, 82 | ) 83 | 84 | p_accept = aesara.function((), states[-1], updates=updates) 85 | step_size = aesara.function((), at.exp(states[-4][-1]), updates=updates) 86 | assert np.mean(p_accept()) == pytest.approx(0.8, rel=10e-3) 87 | assert step_size() < 10 88 | assert step_size() > 1e-1 89 | -------------------------------------------------------------------------------- /tests/test_termination.py: -------------------------------------------------------------------------------- 1 | """Test dynamic termination criteria.""" 2 | import aesara 3 | import aesara.tensor as at 4 | import numpy as np 5 | import pytest 6 | from numpy.testing import assert_array_equal 7 | 8 | from aehmc.metrics import gaussian_metric 9 | from aehmc.termination import TerminationState, _find_storage_indices, iterative_uturn 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "checkpoint_idxs, momentum, momentum_sum, inverse_mass_matrix, expected_turning", 14 | [ 15 | ((3, 3), at.as_tensor(1.0), at.as_tensor(3.0), at.as_tensor(1.0), True), 16 | ((3, 2), at.as_tensor(1.0), at.as_tensor(3.0), at.as_tensor(1.0), False), 17 | ((0, 0), at.as_tensor(1.0), at.as_tensor(3.0), at.as_tensor(1.0), False), 18 | ((0, 1), at.as_tensor(1.0), at.as_tensor(3.0), at.as_tensor(1.0), True), 19 | ((1, 3), at.as_tensor(1.0), at.as_tensor(3.0), 
at.as_tensor(1.0), True), 20 | ((1, 3), at.as_tensor([1.0]), at.as_tensor([3.0]), at.ones(1), True), 21 | ], 22 | ) 23 | def test_iterative_turning_termination( 24 | checkpoint_idxs, momentum, momentum_sum, inverse_mass_matrix, expected_turning 25 | ): 26 | _, _, is_turning = gaussian_metric(inverse_mass_matrix) 27 | _, _, is_iterative_turning = iterative_uturn(is_turning) 28 | 29 | idx_min, idx_max = checkpoint_idxs 30 | idx_min = at.as_tensor(idx_min) 31 | idx_max = at.as_tensor(idx_max) 32 | momentum_ckpts = at.as_tensor(np.array([1.0, 2.0, 3.0, -2.0])) 33 | momentum_sum_ckpts = at.as_tensor(np.array([2.0, 4.0, 4.0, -1.0])) 34 | ckpt_state = TerminationState( 35 | momentum_checkpoints=momentum_ckpts, 36 | momentum_sum_checkpoints=momentum_sum_ckpts, 37 | min_index=idx_min, 38 | max_index=idx_max, 39 | ) 40 | 41 | _, _, is_iterative_turning_fn = iterative_uturn(is_turning) 42 | is_iterative_turning = is_iterative_turning_fn(ckpt_state, momentum_sum, momentum) 43 | fn = aesara.function((), is_iterative_turning, on_unused_input="ignore") 44 | 45 | actual_turning = fn() 46 | 47 | assert actual_turning.ndim == 0 48 | assert expected_turning == actual_turning 49 | 50 | 51 | @pytest.mark.parametrize( 52 | "step, expected_idx", 53 | [(0, (1, 0)), (6, (3, 2)), (7, (0, 2)), (13, (2, 2)), (15, (0, 3))], 54 | ) 55 | def test_leaf_idx_to_ckpt_idx(step, expected_idx): 56 | step_tt = at.scalar("step", dtype=np.int64) 57 | idx_tt = _find_storage_indices(step_tt) 58 | fn = aesara.function((step_tt,), (*idx_tt,)) 59 | 60 | idx_vv = fn(step) 61 | assert idx_vv[0].item() == expected_idx[0] 62 | assert idx_vv[1].item() == expected_idx[1] 63 | 64 | 65 | @pytest.mark.parametrize( 66 | "num_dims", 67 | [1, 3], 68 | ) 69 | def test_termination_update(num_dims): 70 | inverse_mass_matrix = at.as_tensor(np.ones(1)) 71 | _, _, is_turning = gaussian_metric(inverse_mass_matrix) 72 | new_state, update, _ = iterative_uturn(is_turning) 73 | 74 | position = at.as_tensor(np.ones(num_dims)) 75 | momentum = at.as_tensor(np.ones(num_dims)) 76 | momentum_sum = at.as_tensor(np.ones(num_dims)) 77 | 78 | num_doublings = at.as_tensor(4) 79 | termination_state = new_state(position, num_doublings) 80 | 81 | step = at.scalar("step", dtype=np.int64) 82 | updated = update(termination_state, momentum_sum, momentum, step) 83 | update_fn = aesara.function((step,), updated, on_unused_input="ignore") 84 | 85 | # Make sure this works for a single step 86 | result_odd = update_fn(1) 87 | 88 | # When the number of steps is odd there should be no update 89 | result_odd = update_fn(5) 90 | assert_array_equal(result_odd[0], np.zeros((4, num_dims))) 91 | assert_array_equal(result_odd[1], np.zeros((4, num_dims))) 92 | -------------------------------------------------------------------------------- /tests/test_trajectory.py: -------------------------------------------------------------------------------- 1 | import aesara 2 | import aesara.tensor as at 3 | import numpy as np 4 | import pytest 5 | from aeppl.logprob import logprob 6 | from aesara.tensor.random.utils import RandomStream 7 | from aesara.tensor.var import TensorVariable 8 | 9 | from aehmc.integrators import IntegratorState, new_integrator_state, velocity_verlet 10 | from aehmc.metrics import gaussian_metric 11 | from aehmc.termination import iterative_uturn 12 | from aehmc.trajectory import ( 13 | ProposalState, 14 | dynamic_integration, 15 | multiplicative_expansion, 16 | static_integration, 17 | ) 18 | 19 | aesara.config.optimizer = "fast_compile" 20 | 
aesara.config.exception_verbosity = "high" 21 | 22 | 23 | def CircularMotion(inverse_mass_matrix): 24 | def potential_energy(q: TensorVariable) -> TensorVariable: 25 | return -1.0 / at.power(at.square(q[0]) + at.square(q[1]), 0.5) 26 | 27 | def kinetic_energy(p: TensorVariable) -> TensorVariable: 28 | return 0.5 * at.dot(inverse_mass_matrix, at.square(p)) 29 | 30 | return potential_energy, kinetic_energy 31 | 32 | 33 | examples = [ 34 | { 35 | "n_steps": 628, 36 | "step_size": 0.01, 37 | "q_init": np.array([1.0, 0.0]), 38 | "p_init": np.array([0.0, 1.0]), 39 | "q_final": np.array([1.0, 0.0]), 40 | "p_final": np.array([0.0, 1.0]), 41 | "inverse_mass_matrix": np.array([1.0, 1.0]), 42 | }, 43 | ] 44 | 45 | 46 | @pytest.mark.parametrize("example", examples) 47 | def test_static_integration(example): 48 | inverse_mass_matrix = example["inverse_mass_matrix"] 49 | step_size = example["step_size"] 50 | num_steps = example["n_steps"] 51 | q_init = example["q_init"] 52 | p_init = example["p_init"] 53 | 54 | potential, kinetic_energy = CircularMotion(inverse_mass_matrix) 55 | step = velocity_verlet(potential, kinetic_energy) 56 | integrator = static_integration(step, num_steps) 57 | 58 | q = at.vector("q") 59 | p = at.vector("p") 60 | energy = potential(q) 61 | energy_grad = aesara.grad(energy, q) 62 | init_state = IntegratorState( 63 | position=q, 64 | momentum=p, 65 | potential_energy=energy, 66 | potential_energy_grad=energy_grad, 67 | ) 68 | final_state, updates = integrator(init_state, step_size) 69 | integrate_fn = aesara.function((q, p), final_state, updates=updates) 70 | 71 | q_final, p_final, *_ = integrate_fn(q_init, p_init) 72 | 73 | np.testing.assert_allclose(q_final, example["q_final"], atol=1e-1) 74 | np.testing.assert_allclose(p_final, example["p_final"], atol=1e-1) 75 | 76 | 77 | @pytest.mark.parametrize( 78 | "case", 79 | [ 80 | (0.0000001, False, False), 81 | (1000, True, False), 82 | (1e100, True, False), 83 | ], 84 | ) 85 | def test_dynamic_integration(case): 86 | srng = RandomStream(seed=59) 87 | 88 | def potential_fn(x): 89 | return -at.sum(logprob(srng.normal(0.0, 1.0), x)) 90 | 91 | step_size, should_diverge, should_turn = case 92 | 93 | # Set up the trajectory integrator 94 | inverse_mass_matrix = at.ones(1) 95 | 96 | momentum_generator, kinetic_energy_fn, uturn_check_fn = gaussian_metric( 97 | inverse_mass_matrix 98 | ) 99 | integrator = velocity_verlet(potential_fn, kinetic_energy_fn) 100 | ( 101 | new_criterion_state, 102 | update_criterion_state, 103 | is_criterion_met, 104 | ) = iterative_uturn(uturn_check_fn) 105 | 106 | trajectory_integrator = dynamic_integration( 107 | srng, 108 | integrator, 109 | kinetic_energy_fn, 110 | update_criterion_state, 111 | is_criterion_met, 112 | divergence_threshold=at.as_tensor(1000), 113 | ) 114 | 115 | # Initialize the state 116 | direction = at.as_tensor(1) 117 | step_size = at.as_tensor(step_size) 118 | max_num_steps = at.as_tensor(10) 119 | num_doublings = at.as_tensor(10) 120 | position = at.as_tensor(np.ones(1)) 121 | 122 | initial_state = new_integrator_state( 123 | potential_fn, position, momentum_generator(srng) 124 | ) 125 | initial_energy = initial_state[2] + kinetic_energy_fn(initial_state[1]) 126 | termination_state = new_criterion_state(initial_state[0], num_doublings) 127 | 128 | state, updates = trajectory_integrator( 129 | initial_state, 130 | direction, 131 | termination_state, 132 | max_num_steps, 133 | step_size, 134 | initial_energy, 135 | ) 136 | 137 | is_turning = aesara.function((), state[-1], updates=updates)() 
138 | is_diverging = aesara.function((), state[-2], updates=updates)() 139 | 140 | assert is_diverging.item() is should_diverge 141 | assert is_turning.item() is should_turn 142 | 143 | 144 | @pytest.mark.parametrize( 145 | "step_size, should_diverge, should_turn, expected_doublings", 146 | [ 147 | (100000.0, True, False, 1), 148 | (0.0000001, False, False, 10), 149 | (1.0, False, True, 1), 150 | ], 151 | ) 152 | def test_multiplicative_expansion( 153 | step_size, should_diverge, should_turn, expected_doublings 154 | ): 155 | srng = RandomStream(seed=59) 156 | 157 | def potential_fn(x): 158 | return 0.5 * at.sum(at.square(x)) 159 | 160 | step_size = at.as_tensor(step_size) 161 | inverse_mass_matrix = at.as_tensor(1.0, dtype=np.float64) 162 | position = at.as_tensor(1.0, dtype=np.float64) 163 | 164 | momentum_generator, kinetic_energy_fn, uturn_check_fn = gaussian_metric( 165 | inverse_mass_matrix 166 | ) 167 | integrator = velocity_verlet(potential_fn, kinetic_energy_fn) 168 | ( 169 | new_criterion_state, 170 | update_criterion_state, 171 | is_criterion_met, 172 | ) = iterative_uturn(uturn_check_fn) 173 | 174 | trajectory_integrator = dynamic_integration( 175 | srng, 176 | integrator, 177 | kinetic_energy_fn, 178 | update_criterion_state, 179 | is_criterion_met, 180 | divergence_threshold=at.as_tensor(1000), 181 | ) 182 | 183 | expand = multiplicative_expansion(srng, trajectory_integrator, uturn_check_fn, 10) 184 | 185 | # Create the initial state 186 | state = new_integrator_state(potential_fn, position, momentum_generator(srng)) 187 | energy = state.potential_energy + kinetic_energy_fn(state.momentum) 188 | proposal = ProposalState( 189 | state=state, 190 | energy=energy, 191 | weight=at.as_tensor(0.0, dtype=np.float64), 192 | sum_log_p_accept=at.as_tensor(-np.inf, dtype=np.float64), 193 | ) 194 | termination_state = new_criterion_state(state.position, 10) 195 | result, updates = expand( 196 | proposal, state, state, state.momentum, termination_state, energy, step_size 197 | ) 198 | outputs = ( 199 | result.diagnostics.num_doublings[-1], 200 | result.diagnostics.is_diverging[-1], 201 | result.diagnostics.is_turning[-1], 202 | ) 203 | fn = aesara.function((), outputs, updates=updates) 204 | num_doublings, does_diverge, does_turn = fn() 205 | 206 | assert does_diverge == should_diverge 207 | assert does_turn == should_turn 208 | assert expected_doublings == num_doublings 209 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import aesara.tensor as at 2 | import numpy as np 3 | import pytest 4 | from aesara.graph.basic import Apply 5 | from aesara.tensor.exceptions import ShapeError 6 | from aesara.tensor.random.basic import NormalRV 7 | 8 | from aehmc.utils import RaveledParamsMap 9 | 10 | 11 | def test_RaveledParamsMap(): 12 | tau_rv = at.random.invgamma(0.5, 0.5, name="tau") 13 | tau_vv = tau_rv.clone() 14 | tau_vv.name = "t" 15 | 16 | beta_size = (3, 2) 17 | beta_rv = at.random.normal(0, at.sqrt(tau_rv), size=beta_size, name="beta") 18 | beta_vv = beta_rv.clone() 19 | beta_vv.name = "b" 20 | 21 | kappa_size = (20,) 22 | kappa_rv = at.random.normal(0, 1, size=kappa_size, name="kappa") 23 | kappa_vv = kappa_rv.clone() 24 | kappa_vv.name = "k" 25 | 26 | params_map = {beta_rv: beta_vv, tau_rv: tau_vv, kappa_rv: kappa_vv} 27 | 28 | rp_map = RaveledParamsMap(params_map.keys()) 29 | 30 | assert repr(rp_map) == "RaveledParamsMap((beta, tau, kappa))" 31 | 32 
| q = at.vector("q") 33 | 34 | exp_beta_part = np.exp(np.arange(np.prod(beta_size)).reshape(beta_size)) 35 | exp_tau_part = 1.0 36 | exp_kappa_part = np.exp(np.arange(np.prod(kappa_size)).reshape(kappa_size)) 37 | exp_raveled_params = np.concatenate( 38 | [exp_beta_part.ravel(), np.atleast_1d(exp_tau_part), exp_kappa_part.ravel()] 39 | ) 40 | 41 | raveled_params_at = rp_map.ravel_params([beta_vv, tau_vv, kappa_vv]) 42 | raveled_params_val = raveled_params_at.eval( 43 | {beta_vv: exp_beta_part, tau_vv: exp_tau_part, kappa_vv: exp_kappa_part} 44 | ) 45 | 46 | assert np.array_equal(raveled_params_val, exp_raveled_params) 47 | 48 | unraveled_params_at = rp_map.unravel_params(q) 49 | 50 | beta_part = unraveled_params_at[beta_rv] 51 | tau_part = unraveled_params_at[tau_rv] 52 | kappa_part = unraveled_params_at[kappa_rv] 53 | 54 | new_test_point = {q: exp_raveled_params} 55 | assert np.array_equal(beta_part.eval(new_test_point), exp_beta_part) 56 | assert np.array_equal(tau_part.eval(new_test_point), exp_tau_part) 57 | assert np.array_equal(kappa_part.eval(new_test_point), exp_kappa_part) 58 | 59 | 60 | def test_RaveledParamsMap_dtype(): 61 | tau_rv = at.random.normal(0, 1, name="tau") 62 | tau_vv = tau_rv.clone() 63 | tau_vv.name = "t" 64 | 65 | lambda_rv = at.random.binomial(10, 0.5, name="lmbda") 66 | lambda_vv = lambda_rv.clone() 67 | lambda_vv.name = "l" 68 | 69 | params_map = {tau_rv: tau_vv, lambda_rv: lambda_vv} 70 | rp_map = RaveledParamsMap(params_map.keys()) 71 | 72 | q = rp_map.ravel_params((tau_vv, lambda_vv)) 73 | unraveled_params = rp_map.unravel_params(q) 74 | 75 | tau_part = unraveled_params[tau_rv] 76 | lambda_part = unraveled_params[lambda_rv] 77 | 78 | assert tau_part.dtype == tau_rv.dtype 79 | assert lambda_part.dtype == lambda_rv.dtype 80 | 81 | 82 | def test_RaveledParamsMap_bad_infer_shape(): 83 | class BadNormalRV(NormalRV): 84 | def make_node(self, *args, **kwargs): 85 | res = super().make_node(*args, **kwargs) 86 | # Drop static `Type`-level shape information 87 | rv_out = res.outputs[1] 88 | outputs = [ 89 | res.outputs[0].clone(), 90 | at.tensor(dtype=rv_out.type.dtype, shape=(None,) * rv_out.type.ndim), 91 | ] 92 | return Apply( 93 | self, 94 | res.inputs, 95 | outputs, 96 | ) 97 | 98 | def infer_shape(self, *args, **kwargs): 99 | raise ShapeError() 100 | 101 | bad_normal_op = BadNormalRV() 102 | 103 | size = (3, 2) 104 | beta_rv = bad_normal_op(0, 1, size=size, name="beta") 105 | beta_vv = beta_rv.clone() 106 | beta_vv.name = "b" 107 | 108 | params_map = {beta_rv: beta_vv} 109 | 110 | with pytest.warns(Warning): 111 | RaveledParamsMap(params_map.keys()) 112 | --------------------------------------------------------------------------------