├── .flake8
├── .gitattributes
├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── SECURITY.md
│   └── workflows
│       ├── pypi.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .pylintrc
├── .readthedocs.yaml
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── aehmc
│   ├── __init__.py
│   ├── algorithms.py
│   ├── hmc.py
│   ├── integrators.py
│   ├── mass_matrix.py
│   ├── metrics.py
│   ├── nuts.py
│   ├── proposals.py
│   ├── step_size.py
│   ├── termination.py
│   ├── trajectory.py
│   ├── utils.py
│   └── window_adaptation.py
├── conftest.py
├── environment.yml
├── examples
│   ├── LinearRegression.ipynb
│   └── requirements.txt
├── pyproject.toml
├── requirements.txt
└── tests
    ├── __init__.py
    ├── test_adaptation.py
    ├── test_algorithms.py
    ├── test_hmc.py
    ├── test_integrators.py
    ├── test_mass_matrix.py
    ├── test_metrics.py
    ├── test_step_size.py
    ├── test_termination.py
    ├── test_trajectory.py
    └── test_utils.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 88
3 | select = C,E,F,W
4 | ignore = E203,E231,E501,E741,W503,W504,C901
5 | per-file-ignores =
6 | **/__init__.py:F401,E402,F403
7 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | aehmc/_version.py export-subst
2 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github:
2 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## Description of your problem or feature request
2 |
3 | **Please provide a minimal, self-contained, and reproducible example.**
4 | ```python
5 | [Your code here]
6 | ```
7 |
8 | **Please provide the full traceback of any errors.**
9 | ```python
10 | [The error output here]
11 | ```
12 |
13 | **Please provide any additional information below.**
14 |
15 |
16 | ## Versions and main components
17 |
18 | * Aesara version:
19 | * Aesara config (`python -c "import aesara; print(aesara.config)"`)
20 | * Python version:
21 | * Operating system:
22 | * How did you install Aesara: (conda/pip)
23 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
2 | **Thank you for opening a PR!**
3 |
4 | Here are a few important guidelines and requirements to check before your PR can be merged:
5 | + [ ] There is an informative high-level description of the changes.
6 | + [ ] The description and/or commit message(s) references the relevant GitHub issue(s).
7 | + [ ] [`pre-commit`](https://pre-commit.com/#installation) is installed and [set up](https://pre-commit.com/#3-install-the-git-hook-scripts).
8 | + [ ] The commit messages follow [these guidelines](https://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html).
9 | + [ ] The commits correspond to [_relevant logical changes_](https://wiki.openstack.org/wiki/GitCommitMessages#Structural_split_of_changes), and there are **no commits that fix changes introduced by other commits in the same branch/BR**. If your commit description starts with "Fix...", then you're probably making this mistake.
10 | + [ ] There are tests covering the changes introduced in the PR.
11 |
12 | Don't worry, your PR doesn't need to be in perfect order to submit it. As development progresses and/or reviewers request changes, you can always [rewrite the history](https://git-scm.com/book/en/v2/Git-Tools-Rewriting-History#_rewriting_history) of your feature/PR branches.
13 |
14 | If your PR is an ongoing effort and you would like to involve us in the process, simply make it a [draft PR](https://docs.github.com/en/free-pro-team@latest/github/collaborating-with-issues-and-pull-requests/about-pull-requests#draft-pull-requests).
15 |
--------------------------------------------------------------------------------
/.github/SECURITY.md:
--------------------------------------------------------------------------------
1 | To report a security vulnerability to Aesara, please go to
2 | https://tidelift.com/security and see the instructions there.
3 |
4 |
--------------------------------------------------------------------------------
/.github/workflows/pypi.yml:
--------------------------------------------------------------------------------
1 | name: PyPI
2 | on:
3 | push:
4 | branches:
5 | - main
6 | - auto-release
7 | pull_request:
8 | branches: [main]
9 | release:
10 | types: [published]
11 |
12 | # Cancels all previous workflow runs for pull requests that have not completed.
13 | concurrency:
14 | # The concurrency group contains the workflow name and the branch name for pull requests
15 | # or the commit hash for any other events.
16 | group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.head_ref || github.sha }}
17 | cancel-in-progress: true
18 |
19 | jobs:
20 | build:
21 | name: Build source distribution
22 | runs-on: ubuntu-latest
23 | steps:
24 | - uses: actions/checkout@v3
25 | with:
26 | fetch-depth: 0
27 | - uses: actions/setup-python@v4
28 | with:
29 | python-version: "3.8"
30 | - name: Build the sdist and wheel
31 | run: |
32 | python -m pip install --upgrade pip build
33 | python -m build
34 | - name: Check the sdist installs and imports
35 | run: |
36 | python -m venv venv-sdist
 37 |           # Since the whl distribution is built from the sdist, it suffices
38 | # to only test the wheel installation to ensure both function as expected.
39 | venv-sdist/bin/python -m pip install dist/aehmc-*.whl
40 | venv-sdist/bin/python -c "import aehmc;print(aehmc.__version__)"
41 | - uses: actions/upload-artifact@v3
42 | with:
43 | name: artifact
44 | path: dist
45 | if-no-files-found: error
46 |
47 | upload_pypi:
48 | name: Upload to PyPI on release
49 | needs: [build]
50 | runs-on: ubuntu-latest
51 | if: github.event_name == 'release' && github.event.action == 'published'
52 | steps:
53 | - uses: actions/download-artifact@v2
54 | with:
55 | name: artifact
56 | path: dist
57 | - uses: pypa/gh-action-pypi-publish@release/v1
58 | with:
59 | user: __token__
60 | password: ${{ secrets.pypi_secret }}
61 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - checks
8 | pull_request:
9 | branches:
10 | - main
11 |
12 | # Cancels all previous workflow runs for pull requests that have not completed.
13 | concurrency:
14 | # The concurrency group contains the workflow name and the branch name for pull requests
15 | # or the commit hash for any other events.
16 | group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.head_ref || github.sha }}
17 | cancel-in-progress: true
18 |
19 | jobs:
20 | changes:
21 | name: "Check for changes"
22 | runs-on: ubuntu-latest
23 | outputs:
24 | changes: ${{ steps.changes.outputs.src }}
25 | steps:
26 | - uses: actions/checkout@v3
27 | with:
28 | fetch-depth: 0
29 | - uses: dorny/paths-filter@v2
30 | id: changes
31 | with:
32 | filters: |
33 | python: &python
34 | - 'aehmc/**/*.py'
35 | - 'tests/**/*.py'
36 | - 'aehmc/**/*.pyx'
37 | - 'tests/**/*.pyx'
38 | - '*.py'
39 | src:
40 | - *python
41 | - 'aehmc/**/*.c'
42 | - 'tests/**/*.c'
43 | - 'aehmc/**/*.h'
44 | - 'tests/**/*.h'
45 | - '.github/workflows/*.yml'
46 | - 'setup.cfg'
47 | - 'requirements.txt'
48 | - '.coveragerc'
49 |
50 |
51 | style:
52 | name: Check code style
53 | needs: changes
54 | runs-on: ubuntu-latest
55 | if: ${{ needs.changes.outputs.changes == 'true' }}
56 | steps:
57 | - uses: actions/checkout@v3
58 | - uses: actions/setup-python@v4
59 | - uses: pre-commit/action@v2.0.0
60 |
61 | test:
62 | name: "Test py${{ matrix.python-version }}: ${{ matrix.part }}"
63 | needs:
64 | - changes
65 | - style
66 | runs-on: ubuntu-latest
67 | if: ${{ needs.changes.outputs.changes == 'true' && needs.style.result == 'success' }}
68 | strategy:
69 | fail-fast: true
70 | matrix:
71 | python-version: ["3.8", "3.10"]
72 | fast-compile: [0]
73 | float32: [0]
74 | part:
75 | - "tests"
76 |
77 | steps:
78 | - uses: actions/checkout@v3
79 | with:
80 | fetch-depth: 0
81 | - name: Set up Python ${{ matrix.python-version }}
82 | uses: conda-incubator/setup-miniconda@v2
83 | with:
84 | mamba-version: "*"
85 | channels: conda-forge,defaults
86 | channel-priority: true
87 | python-version: ${{ matrix.python-version }}
88 | auto-update-conda: true
89 |
90 | - name: Create matrix id
91 | id: matrix-id
92 | env:
93 | MATRIX_CONTEXT: ${{ toJson(matrix) }}
94 | run: |
95 | echo $MATRIX_CONTEXT
96 | export MATRIX_ID=`echo $MATRIX_CONTEXT | md5sum | cut -c 1-32`
97 | echo $MATRIX_ID
98 | echo "::set-output name=id::$MATRIX_ID"
99 |
100 | - name: Install dependencies
101 | shell: bash -l {0}
102 | run: |
103 | mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service "aesara>=2.8.11"
104 | pip install -q -r requirements.txt
105 | mamba list && pip freeze
106 | python -c 'import aesara; print(aesara.config.__str__(print_doc=False))'
107 | python -c 'import aesara; assert(aesara.config.blas__ldflags != "")'
108 | env:
109 | PYTHON_VERSION: ${{ matrix.python-version }}
110 |
111 | - name: Run tests
112 | shell: bash -l {0}
113 | run: |
114 | if [[ $FAST_COMPILE == "1" ]]; then export AESARA_FLAGS=$AESARA_FLAGS,mode=FAST_COMPILE; fi
115 | if [[ $FLOAT32 == "1" ]]; then export AESARA_FLAGS=$AESARA_FLAGS,floatX=float32; fi
116 | export AESARA_FLAGS=$AESARA_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
117 | python -m pytest -x -r A --verbose --cov=aehmc --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART
118 | env:
119 | MATRIX_ID: ${{ steps.matrix-id.outputs.id }}
120 | MKL_THREADING_LAYER: GNU
121 | MKL_NUM_THREADS: 1
122 | OMP_NUM_THREADS: 1
123 | PART: ${{ matrix.part }}
124 | FAST_COMPILE: ${{ matrix.fast-compile }}
125 | FLOAT32: ${{ matrix.float32 }}
126 |
127 | - name: Upload coverage file
128 | uses: actions/upload-artifact@v2
129 | with:
130 | name: coverage
131 | path: coverage/coverage-${{ steps.matrix-id.outputs.id }}.xml
132 |
133 | all-checks:
134 | if: ${{ always() }}
135 | runs-on: ubuntu-latest
136 | name: "All tests"
137 | needs: [changes, style, test]
138 | steps:
139 | - name: Check build matrix status
140 | if: ${{ needs.changes.outputs.changes == 'true' && (needs.style.result != 'success' || needs.test.result != 'success') }}
141 | run: exit 1
142 |
143 | upload-coverage:
144 | runs-on: ubuntu-latest
145 | name: "Upload coverage"
146 | needs: [changes, all-checks]
147 | if: ${{ needs.changes.outputs.changes == 'true' && needs.all-checks.result == 'success' }}
148 | steps:
149 | - uses: actions/checkout@v3
150 |
151 | - name: Set up Python
152 | uses: actions/setup-python@v4
153 | with:
154 | python-version: 3.8
155 |
156 | - name: Install dependencies
157 | run: |
158 |         python -m pip install -U "coverage>=5.1" coveralls
159 |
160 | - name: Download coverage file
161 | uses: actions/download-artifact@v2
162 | with:
163 | name: coverage
164 | path: coverage
165 |
166 | - name: Upload coverage to Codecov
167 | uses: codecov/codecov-action@v1
168 | with:
169 | directory: ./coverage/
170 | fail_ci_if_error: true
171 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.gitignore.io/api/vim,emacs,python
2 | # Edit at https://www.gitignore.io/?templates=vim,emacs,python
3 |
4 | ### Emacs ###
5 | # -*- mode: gitignore; -*-
6 | *~
7 | \#*\#
8 | /.emacs.desktop
9 | /.emacs.desktop.lock
10 | *.elc
11 | auto-save-list
12 | tramp
13 | .\#*
14 |
15 | # Org-mode
16 | .org-id-locations
17 | *_archive
18 |
19 | # flymake-mode
20 | *_flymake.*
21 |
22 | # eshell files
23 | /eshell/history
24 | /eshell/lastdir
25 |
26 | # elpa packages
27 | /elpa/
28 |
29 | # reftex files
30 | *.rel
31 |
32 | # AUCTeX auto folder
33 | /auto/
34 |
35 | # cask packages
36 | .cask/
37 | dist/
38 |
39 | # Flycheck
40 | flycheck_*.el
41 |
42 | # server auth directory
43 | /server/
44 |
45 | # projectiles files
46 | .projectile
47 |
48 | # directory configuration
49 | .dir-locals.el
50 |
51 | # network security
52 | /network-security.data
53 |
54 |
55 | ### Python ###
56 | # Byte-compiled / optimized / DLL files
57 | __pycache__/
58 | *.py[cod]
59 | *$py.class
60 |
61 | # C extensions
62 | *.so
63 |
64 | # Distribution / packaging
65 | .Python
66 | build/
67 | develop-eggs/
68 | downloads/
69 | eggs/
70 | .eggs/
71 | lib/
72 | lib64/
73 | parts/
74 | sdist/
75 | var/
76 | wheels/
77 | pip-wheel-metadata/
78 | share/python-wheels/
79 | *.egg-info/
80 | .installed.cfg
81 | *.egg
82 | MANIFEST
83 |
84 | # PyInstaller
85 | # Usually these files are written by a python script from a template
86 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
87 | *.manifest
88 | *.spec
89 |
90 | # Installer logs
91 | pip-log.txt
92 | pip-delete-this-directory.txt
93 |
94 | # Unit test / coverage reports
95 | htmlcov/
96 | .tox/
97 | .nox/
98 | .coverage
99 | .coverage.*
100 | .cache
101 | nosetests.xml
102 | coverage.xml
103 | *.cover
104 | .hypothesis/
105 | .pytest_cache/
106 | testing-report.html
107 |
108 | # Translations
109 | *.mo
110 | *.pot
111 |
112 | # Django stuff:
113 | *.log
114 | local_settings.py
115 | db.sqlite3
116 | db.sqlite3-journal
117 |
118 | # Flask stuff:
119 | instance/
120 | .webassets-cache
121 |
122 | # Scrapy stuff:
123 | .scrapy
124 |
125 | # Sphinx documentation
126 | docs/_build/
127 |
128 | # PyBuilder
129 | target/
130 |
131 | # Jupyter Notebook
132 | .ipynb_checkpoints
133 |
134 | # IPython
135 | profile_default/
136 | ipython_config.py
137 | .ipynb_checkpoints
138 |
139 | # pyenv
140 | .python-version
141 |
142 | # pipenv
143 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
144 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
145 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
146 | # install all needed dependencies.
147 | #Pipfile.lock
148 |
149 | # celery beat schedule file
150 | celerybeat-schedule
151 |
152 | # SageMath parsed files
153 | *.sage.py
154 |
155 | # Environments
156 | .env
157 | .venv
158 | env/
159 | venv/
160 | ENV/
161 | env.bak/
162 | venv.bak/
163 |
164 | # Spyder project settings
165 | .spyderproject
166 | .spyproject
167 |
168 | # Rope project settings
169 | .ropeproject
170 |
171 | # mkdocs documentation
172 | /site
173 |
174 | # mypy
175 | .mypy_cache/
176 | .dmypy.json
177 | dmypy.json
178 |
179 | # Pyre type checker
180 | .pyre/
181 |
182 | # Pycharm/IntelliJ
183 | .idea
184 | *.iml
185 |
186 | ### Vim ###
187 | # Swap
188 | [._]*.s[a-v][a-z]
189 | [._]*.sw[a-p]
190 | [._]s[a-rt-v][a-z]
191 | [._]ss[a-gi-z]
192 | [._]sw[a-p]
193 |
194 | # Session
195 | Session.vim
196 | Sessionx.vim
197 |
198 | # Temporary
199 | .netrwhist
200 | # Auto-generated tag files
201 | tags
202 | # Persistent undo
203 | [._]*.un~
204 |
205 | ### OS X ###
206 | .DS_Store
207 |
208 | ### Visual Studio / VSCode ###
209 | .vs/
210 | .vscode/
211 |
212 | # End of https://www.gitignore.io/api/vim,emacs,python
213 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | exclude: |
2 | (?x)^(
3 | versioneer\.py|
4 | aehmc/_version\.py|
5 | doc/.*|
6 | bin/.*
7 | )$
8 | repos:
9 | - repo: https://github.com/pre-commit/pre-commit-hooks
10 | rev: v4.4.0
11 | hooks:
12 | - id: debug-statements
13 | - id: check-merge-conflict
14 | - repo: https://github.com/psf/black
15 | rev: 23.3.0
16 | hooks:
17 | - id: black
18 | language_version: python3
19 | - repo: https://github.com/pycqa/flake8
20 | rev: 6.0.0
21 | hooks:
22 | - id: flake8
23 | - repo: https://github.com/pycqa/isort
24 | rev: 5.12.0
25 | hooks:
26 | - id: isort
27 | - repo: https://github.com/pre-commit/mirrors-mypy
28 | rev: v1.2.0
29 | hooks:
30 | - id: mypy
31 | args: [--ignore-missing-imports]
32 | - repo: https://github.com/humitos/mirrors-autoflake.git
33 | rev: v1.1
34 | hooks:
35 | - id: autoflake
36 | exclude: |
37 | (?x)^(
38 | .*/?__init__\.py
39 | )$
40 | args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
41 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 | # Use multiple processes to speed up Pylint.
3 | jobs=0
4 |
5 | # Allow loading of arbitrary C extensions. Extensions are imported into the
6 | # active Python interpreter and may run arbitrary code.
7 | unsafe-load-any-extension=no
8 |
9 | # Allow optimization of some AST trees. This will activate a peephole AST
10 | # optimizer, which will apply various small optimizations. For instance, it can
11 | # be used to obtain the result of joining multiple strings with the addition
12 | # operator. Joining a lot of strings can lead to a maximum recursion error in
13 | # Pylint and this flag can prevent that. It has one side effect, the resulting
14 | # AST will be different than the one from reality.
15 | optimize-ast=no
16 |
17 | [MESSAGES CONTROL]
18 |
19 | # Only show warnings with the listed confidence levels. Leave empty to show
20 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
21 | confidence=
22 |
23 | # Disable the message, report, category or checker with the given id(s). You
24 | # can either give multiple identifiers separated by comma (,) or put this
25 | # option multiple times (only on the command line, not in the configuration
26 | # file where it should appear only once).You can also use "--disable=all" to
27 | # disable everything first and then reenable specific checks. For example, if
28 | # you want to run only the similarities checker, you can use "--disable=all
29 | # --enable=similarities". If you want to run only the classes checker, but have
30 | # no Warning level messages displayed, use"--disable=all --enable=classes
31 | # --disable=W"
32 | disable=all
33 |
34 | # Enable the message, report, category or checker with the given id(s). You can
35 | # either give multiple identifier separated by comma (,) or put this option
36 | # multiple time. See also the "--disable" option for examples.
37 | enable=import-error,
38 | import-self,
39 | reimported,
40 | wildcard-import,
41 | misplaced-future,
42 | relative-import,
43 | deprecated-module,
44 | unpacking-non-sequence,
45 | invalid-all-object,
46 | undefined-all-variable,
47 | used-before-assignment,
48 | cell-var-from-loop,
49 | global-variable-undefined,
50 | dangerous-default-value,
51 | # redefined-builtin,
52 | redefine-in-handler,
53 | unused-import,
54 | unused-wildcard-import,
55 | global-variable-not-assigned,
56 | undefined-loop-variable,
57 | global-at-module-level,
58 | bad-open-mode,
59 | redundant-unittest-assert,
60 | boolean-datetime,
61 | # unused-variable
62 |
63 |
64 | [REPORTS]
65 |
66 | # Set the output format. Available formats are text, parseable, colorized, msvs
67 | # (visual studio) and html. You can also give a reporter class, eg
68 | # mypackage.mymodule.MyReporterClass.
69 | output-format=parseable
70 |
71 | # Put messages in a separate file for each module / package specified on the
72 | # command line instead of printing them on stdout. Reports (if any) will be
73 | # written in a file name "pylint_global.[txt|html]".
74 | files-output=no
75 |
76 | # Tells whether to display a full report or only the messages
77 | reports=no
78 |
79 | # Python expression which should return a note less than 10 (10 is the highest
80 | # note). You have access to the variables errors warning, statement which
81 | # respectively contain the number of errors / warnings messages and the total
82 | # number of statements analyzed. This is used by the global evaluation report
83 | # (RP0004).
84 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
85 |
86 | [BASIC]
87 |
88 | # List of builtins function names that should not be used, separated by a comma
89 | bad-functions=map,filter,input
90 |
91 | # Good variable names which should always be accepted, separated by a comma
92 | good-names=i,j,k,ex,Run,_
93 |
94 | # Bad variable names which should always be refused, separated by a comma
95 | bad-names=foo,bar,baz,toto,tutu,tata
96 |
97 | # Colon-delimited sets of names that determine each other's naming style when
98 | # the name regexes allow several styles.
99 | name-group=
100 |
101 | # Include a hint for the correct naming format with invalid-name
102 | include-naming-hint=yes
103 |
104 | # Regular expression matching correct method names
105 | method-rgx=[a-z_][a-z0-9_]{2,30}$
106 |
107 | # Naming hint for method names
108 | method-name-hint=[a-z_][a-z0-9_]{2,30}$
109 |
110 | # Regular expression matching correct function names
111 | function-rgx=[a-z_][a-z0-9_]{2,30}$
112 |
113 | # Naming hint for function names
114 | function-name-hint=[a-z_][a-z0-9_]{2,30}$
115 |
116 | # Regular expression matching correct module names
117 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
118 |
119 | # Naming hint for module names
120 | module-name-hint=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
121 |
122 | # Regular expression matching correct attribute names
123 | attr-rgx=[a-z_][a-z0-9_]{2,30}$
124 |
125 | # Naming hint for attribute names
126 | attr-name-hint=[a-z_][a-z0-9_]{2,30}$
127 |
128 | # Regular expression matching correct class attribute names
129 | class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
130 |
131 | # Naming hint for class attribute names
132 | class-attribute-name-hint=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
133 |
134 | # Regular expression matching correct constant names
135 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
136 |
137 | # Naming hint for constant names
138 | const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$
139 |
140 | # Regular expression matching correct class names
141 | class-rgx=[A-Z_][a-zA-Z0-9]+$
142 |
143 | # Naming hint for class names
144 | class-name-hint=[A-Z_][a-zA-Z0-9]+$
145 |
146 | # Regular expression matching correct argument names
147 | argument-rgx=[a-z_][a-z0-9_]{2,30}$
148 |
149 | # Naming hint for argument names
150 | argument-name-hint=[a-z_][a-z0-9_]{2,30}$
151 |
152 | # Regular expression matching correct inline iteration names
153 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
154 |
155 | # Naming hint for inline iteration names
156 | inlinevar-name-hint=[A-Za-z_][A-Za-z0-9_]*$
157 |
158 | # Regular expression matching correct variable names
159 | variable-rgx=[a-z_][a-z0-9_]{2,30}$
160 |
161 | # Naming hint for variable names
162 | variable-name-hint=[a-z_][a-z0-9_]{2,30}$
163 |
164 | # Regular expression which should only match function or class names that do
165 | # not require a docstring.
166 | no-docstring-rgx=^_
167 |
168 | # Minimum line length for functions/classes that require docstrings, shorter
169 | # ones are exempt.
170 | docstring-min-length=-1
171 |
172 |
173 | [ELIF]
174 |
175 | # Maximum number of nested blocks for function / method body
176 | max-nested-blocks=5
177 |
178 |
179 | [FORMAT]
180 |
181 | # Maximum number of characters on a single line.
182 | max-line-length=100
183 |
184 | # Regexp for a line that is allowed to be longer than the limit.
185 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$
186 |
187 | # Allow the body of an if to be on the same line as the test if there is no
188 | # else.
189 | single-line-if-stmt=no
190 |
191 | # List of optional constructs for which whitespace checking is disabled. `dict-
192 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
193 | # `trailing-comma` allows a space between comma and closing bracket: (a, ).
194 | # `empty-line` allows space-only lines.
195 | no-space-check=trailing-comma,dict-separator
196 |
197 | # Maximum number of lines in a module
198 | max-module-lines=1000
199 |
200 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
201 | # tab).
202 | indent-string=' '
203 |
204 | # Number of spaces of indent required inside a hanging or continued line.
205 | indent-after-paren=4
206 |
207 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
208 | expected-line-ending-format=
209 |
210 |
211 | [LOGGING]
212 |
213 | # Logging modules to check that the string format arguments are in logging
214 | # function parameter format
215 | logging-modules=logging
216 |
217 |
218 | [MISCELLANEOUS]
219 |
220 | # List of note tags to take in consideration, separated by a comma.
221 | notes=FIXME,XXX,TODO
222 |
223 |
224 | [SIMILARITIES]
225 |
226 | # Minimum lines number of a similarity.
227 | min-similarity-lines=4
228 |
229 | # Ignore comments when computing similarities.
230 | ignore-comments=yes
231 |
232 | # Ignore docstrings when computing similarities.
233 | ignore-docstrings=yes
234 |
235 | # Ignore imports when computing similarities.
236 | ignore-imports=no
237 |
238 |
239 | [SPELLING]
240 |
241 | # Spelling dictionary name. Available dictionaries: none. To make it working
242 | # install python-enchant package.
243 | spelling-dict=
244 |
245 | # List of comma separated words that should not be checked.
246 | spelling-ignore-words=
247 |
248 | # A path to a file that contains private dictionary; one word per line.
249 | spelling-private-dict-file=
250 |
251 | # Tells whether to store unknown words to indicated private dictionary in
252 | # --spelling-private-dict-file option instead of raising a message.
253 | spelling-store-unknown-words=no
254 |
255 |
256 | [TYPECHECK]
257 |
258 | # Tells whether missing members accessed in mixin class should be ignored. A
259 | # mixin class is detected if its name ends with "mixin" (case insensitive).
260 | ignore-mixin-members=yes
261 |
262 | # List of module names for which member attributes should not be checked
263 | # (useful for modules/projects where namespaces are manipulated during runtime
264 | # and thus existing member attributes cannot be deduced by static analysis. It
265 | # supports qualified module names, as well as Unix pattern matching.
266 | ignored-modules=
267 |
268 | # List of classes names for which member attributes should not be checked
269 | # (useful for classes with attributes dynamically set). This supports can work
270 | # with qualified names.
271 | ignored-classes=
272 |
273 | # List of members which are set dynamically and missed by pylint inference
274 | # system, and so shouldn't trigger E1101 when accessed. Python regular
275 | # expressions are accepted.
276 | generated-members=
277 |
278 |
279 | [VARIABLES]
280 |
281 | # Tells whether we should check for unused import in __init__ files.
282 | init-import=no
283 |
284 | # A regular expression matching the name of dummy variables (i.e. expectedly
285 | # not used).
286 | dummy-variables-rgx=_$|dummy
287 |
288 | # List of additional names supposed to be defined in builtins. Remember that
289 | # you should avoid to define new builtins when possible.
290 | additional-builtins=
291 |
292 | # List of strings which can identify a callback function by name. A callback
293 | # name must start or end with one of those strings.
294 | callbacks=cb_,_cb
295 |
296 |
297 | [CLASSES]
298 |
299 | # List of method names used to declare (i.e. assign) instance attributes.
300 | defining-attr-methods=__init__,__new__,setUp
301 |
302 | # List of valid names for the first argument in a class method.
303 | valid-classmethod-first-arg=cls
304 |
305 | # List of valid names for the first argument in a metaclass class method.
306 | valid-metaclass-classmethod-first-arg=mcs
307 |
308 | # List of member names, which should be excluded from the protected access
309 | # warning.
310 | exclude-protected=_asdict,_fields,_replace,_source,_make
311 |
312 |
313 | [DESIGN]
314 |
315 | # Maximum number of arguments for function / method
316 | max-args=5
317 |
318 | # Argument names that match this expression will be ignored. Default to name
319 | # with leading underscore
320 | ignored-argument-names=_.*
321 |
322 | # Maximum number of locals for function / method body
323 | max-locals=15
324 |
325 | # Maximum number of return / yield for function / method body
326 | max-returns=6
327 |
328 | # Maximum number of branch for function / method body
329 | max-branches=12
330 |
331 | # Maximum number of statements in function / method body
332 | max-statements=50
333 |
334 | # Maximum number of parents for a class (see R0901).
335 | max-parents=7
336 |
337 | # Maximum number of attributes for a class (see R0902).
338 | max-attributes=7
339 |
340 | # Minimum number of public methods for a class (see R0903).
341 | min-public-methods=2
342 |
343 | # Maximum number of public methods for a class (see R0904).
344 | max-public-methods=20
345 |
346 | # Maximum number of boolean expressions in a if statement
347 | max-bool-expr=5
348 |
349 |
350 | [IMPORTS]
351 |
352 | # Deprecated modules which should not be used, separated by a comma
353 | deprecated-modules=optparse
354 |
355 | # Create a graph of every (i.e. internal and external) dependencies in the
356 | # given file (report RP0402 must not be disabled)
357 | import-graph=
358 |
359 | # Create a graph of external dependencies in the given file (report RP0402 must
360 | # not be disabled)
361 | ext-import-graph=
362 |
363 | # Create a graph of internal dependencies in the given file (report RP0402 must
364 | # not be disabled)
365 | int-import-graph=
366 |
367 |
368 | [EXCEPTIONS]
369 |
370 | # Exceptions that will emit a warning when being caught. Defaults to
371 | # "Exception"
372 | overgeneral-exceptions=Exception
373 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: ubuntu-22.04
5 | tools:
6 | python: "3.8"
7 |
8 | python:
9 | install:
10 | - method: pip
11 | path: .
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021-2023 Aesara Developers
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | prune .github
2 | prune examples
3 | exclude *.yaml
4 | exclude *.yml
5 | exclude .git*
6 | exclude *.py
7 | exclude *rc
8 | exclude .flake8
9 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: help venv conda docker docstyle format style black test lint check coverage pypi
2 | .DEFAULT_GOAL = help
3 |
4 | PROJECT_NAME = aehmc
5 | PROJECT_DIR = aehmc/
6 | PYTHON = python
7 | PIP = pip
8 | CONDA = conda
9 | SHELL = bash
10 |
11 | help:
12 | @printf "Usage:\n"
13 | @grep -E '^[a-zA-Z_-]+:.*?# .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?# "}; {printf "\033[1;34mmake %-10s\033[0m%s\n", $$1, $$2}'
14 |
15 | conda: # Set up a conda environment for development.
16 | @printf "Creating conda environment...\n"
17 | ${CONDA} create --yes --name ${PROJECT_NAME}-env python=3.8
18 | ( \
19 | ${CONDA} activate ${PROJECT_NAME}-env; \
20 | ${PIP} install -U pip; \
21 | ${PIP} install -U setuptools wheel; \
22 | ${PIP} install -r requirements.txt; \
23 | ${CONDA} deactivate; \
24 | )
25 | @printf "\n\nConda environment created! \033[1;34mRun \`conda activate ${PROJECT_NAME}-env\` to activate it.\033[0m\n\n\n"
26 |
27 | venv: # Set up a Python virtual environment for development.
28 | @printf "Creating Python virtual environment...\n"
29 | rm -rf ${PROJECT_NAME}-venv
30 | ${PYTHON} -m venv ${PROJECT_NAME}-venv
31 | ( \
32 | source ${PROJECT_NAME}-venv/bin/activate; \
33 | ${PIP} install -U pip; \
34 | ${PIP} install -U setuptools wheel; \
35 | ${PIP} install -r requirements.txt; \
36 | deactivate; \
37 | )
38 | @printf "\n\nVirtual environment created! \033[1;34mRun \`source ${PROJECT_NAME}-venv/bin/activate\` to activate it.\033[0m\n\n\n"
39 |
40 | docker: # Set up a Docker image for development.
41 | @printf "Creating Docker image...\n"
42 | ${SHELL} ./scripts/container.sh --build
43 |
44 | docstyle:
45 | @printf "Checking documentation with pydocstyle...\n"
46 | pydocstyle ${PROJECT_DIR}
47 | @printf "\033[1;34mPydocstyle passes!\033[0m\n\n"
48 |
49 | format:
50 | @printf "Checking code style with black...\n"
51 | black --check ${PROJECT_DIR} tests/
52 | @printf "\033[1;34mBlack passes!\033[0m\n\n"
53 |
54 | style:
55 | @printf "Checking code style with pylint...\n"
56 | pylint ${PROJECT_DIR} tests/
57 | @printf "\033[1;34mPylint passes!\033[0m\n\n"
58 |
59 | black: # Format code in-place using black.
60 | black ${PROJECT_DIR} tests/
61 |
62 | test: # Test code using pytest.
63 | pytest -v tests/ ${PROJECT_DIR} --cov=${PROJECT_DIR} --cov-report=xml --html=testing-report.html --self-contained-html
64 |
65 | coverage: test
66 | diff-cover coverage.xml --compare-branch=main --fail-under=100
67 |
68 | pypi:
69 | ${PYTHON} setup.py clean --all; \
70 | ${PYTHON} setup.py rotate --match=.tar.gz,.whl,.egg,.zip --keep=0; \
71 | ${PYTHON} setup.py sdist bdist_wheel; \
72 | twine upload --skip-existing dist/*;
73 |
74 | lint: docstyle format style # Lint code using pydocstyle, black and pylint.
75 |
76 | check: lint test coverage # Lint, test and check coverage. Runs `make lint`, `make test` and `make coverage`.
77 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # AeHMC
4 |
5 | [![Pypi][pypi-badge]][pypi]
6 | [![Gitter][gitter-badge]][gitter]
7 | [![Discord][discord-badge]][discord]
8 | [![Twitter][twitter-badge]][twitter]
9 |
10 | AeHMC provides implementations for the HMC and NUTS samplers in [Aesara](https://github.com/aesara-devs/aesara).
11 |
12 | [Features](#features) •
13 | [Get Started](#get-started) •
14 | [Install](#install) •
15 | [Get help](#get-help) •
16 | [Contribute](#contribute)
17 |
18 |
19 |
20 | ## Get started
21 |
22 | ``` python
23 | import aesara
24 | from aesara import tensor as at
25 | from aesara.tensor.random.utils import RandomStream
26 |
27 | from aeppl import joint_logprob
28 |
29 | from aehmc import nuts
30 |
31 | # A simple normal distribution
32 | Y_rv = at.random.normal(0, 1)
33 |
34 |
35 | def logprob_fn(y):
36 | return joint_logprob(realized={Y_rv: y})[0]
37 |
38 |
39 | # Build the transition kernel
40 | srng = RandomStream(seed=0)
41 | kernel = nuts.new_kernel(srng, logprob_fn)
42 |
43 | # Compile a function that updates the chain
44 | y_vv = Y_rv.clone()
45 | initial_state = nuts.new_state(y_vv, logprob_fn)
46 |
47 | step_size = at.as_tensor(1e-2)
48 | inverse_mass_matrix = at.as_tensor(1.0)
49 | chain_info, updates = kernel(initial_state, step_size, inverse_mass_matrix)
50 |
51 | next_step_fn = aesara.function([y_vv], chain_info.state.position, updates=updates)
52 |
53 | print(next_step_fn(0))
54 | # 1.1034719409361107
55 | ```
56 |
57 | ## Install
58 |
59 | The latest release of AeHMC can be installed from PyPI using ``pip``:
60 |
61 | ``` bash
62 | pip install aehmc
63 | ```
64 |
65 | Or via conda-forge:
66 |
67 | ``` bash
68 | conda install -c conda-forge aehmc
69 | ```
70 |
71 | The current development branch of AeHMC can be installed from GitHub using ``pip``:
72 |
73 | ``` bash
74 | pip install git+https://github.com/aesara-devs/aehmc
75 | ```
76 |
77 | ## Get help
78 |
79 | Report bugs by opening an [issue][issues]. If you have a question regarding the usage of AeHMC, start a [discussion][discussions]. For real-time feedback or more general chat about AeHMC use our [Discord server][discord] or [Gitter room][gitter].
80 |
81 | ## Contribute
82 |
83 | AeHMC welcomes contributions. A good place to start contributing is by looking at the [issues][issues].
84 |
85 | If you want to implement a new feature, open a [discussion][discussions] or come chat with us on [Discord][discord] or [Gitter][gitter].
86 |
87 | [contributors]: https://github.com/aesara-devs/aehmc/graphs/contributors
88 | [contributors-badge]: https://img.shields.io/github/contributors/aesara-devs/aehmc?style=flat-square&logo=github&logoColor=white&color=ECEFF4
89 | [discussions]: https://github.com/aesara-devs/aehmc/discussions
90 | [downloads-badge]: https://img.shields.io/pypi/dm/aehmc?style=flat-square&logo=pypi&logoColor=white&color=8FBCBB
91 | [discord]: https://discord.gg/h3sjmPYuGJ
92 | [discord-badge]: https://img.shields.io/discord/1072170173785723041?color=81A1C1&logo=discord&logoColor=white&style=flat-square
93 | [gitter]: https://gitter.im/aesara-devs/aehmc
94 | [gitter-badge]: https://img.shields.io/gitter/room/aesara-devs/aehmc?color=81A1C1&logo=matrix&logoColor=white&style=flat-square
95 | [issues]: https://github.com/aesara-devs/aehmc/issues
96 | [releases]: https://github.com/aesara-devs/aehmc/releases
97 | [twitter]: https://twitter.com/AesaraDevs
98 | [twitter-badge]: https://img.shields.io/twitter/follow/AesaraDevs?style=social
99 | [pypi]: https://pypi.org/project/aehmc/
100 | [pypi-badge]: https://img.shields.io/pypi/v/aehmc?color=ECEFF4&logo=python&logoColor=white&style=flat-square
101 |
--------------------------------------------------------------------------------
/aehmc/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from ._version import __version__, __version_tuple__
3 | except ImportError: # pragma: no cover
4 | raise RuntimeError(
5 | "Unable to find the version number that is generated when either building or "
6 | "installing from source. Please make sure that `aehmc` has been properly "
7 | "installed, e.g. with\n\n pip install -e .\n"
8 | )
9 |
--------------------------------------------------------------------------------
/aehmc/algorithms.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, NamedTuple, Tuple
2 |
3 | import aesara.tensor as at
4 | import numpy as np
5 | from aesara import config
6 | from aesara.tensor.var import TensorVariable
7 |
8 |
9 | class DualAveragingState(NamedTuple):
10 | step: TensorVariable
11 | iterates: TensorVariable
12 | iterates_avg: TensorVariable
13 | gradient_avg: TensorVariable
14 | shrinkage_pts: TensorVariable
15 |
16 |
17 | def dual_averaging(
18 | gamma: float = 0.05, t0: int = 10, kappa: float = 0.75
19 | ) -> Tuple[Callable, Callable]:
20 | """Dual averaging algorithm.
21 |
22 | Dual averaging is an algorithm for stochastic optimization that was
23 | originally proposed by Nesterov in [1]_.
24 |
25 | The update scheme we implement here is more elaborate than the one
 26 |     described in [1]_. We follow [2]_ and add the parameters `t0` and `kappa`,
 27 |     which respectively improve the stability of computations in early
 28 |     iterations and set how fast the algorithm forgets past iterates.
29 |
30 | The default values for the parameters are taken from the Stan implementation [3]_.
31 |
32 | Parameters
33 | ----------
34 | gamma
35 | Controls the amount of shrinkage towards mu.
36 | t0
37 | Improves the stability of computations early on.
38 | kappa
39 | Controls how fast the algorithm should forget past iterates.
40 |
41 | References
42 | ----------
43 | .. [1]: Nesterov, Y. (2009). Primal-dual subgradient methods for convex
44 | problems. Mathematical programming, 120(1), 221-259.
45 |
46 | .. [2]: Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn sampler:
47 | adaptively setting path lengths in Hamiltonian Monte Carlo. J. Mach. Learn.
48 | Res., 15(1), 1593-1623.
49 |
50 | .. [3]: Carpenter, B., Gelman, A., Hoffman, M. D., Lee, D., Goodrich, B.,
51 | Betancourt, M., ... & Riddell, A. (2017). Stan: A probabilistic programming
52 | language. Journal of statistical software, 76(1), 1-32.
53 |
54 | """
55 |
56 | def init(mu: TensorVariable) -> DualAveragingState:
57 | """
58 | Initialize dual averaging state using shrinkage points.
59 |
60 | Parameters
61 | ----------
62 | mu
63 | Chosen points towards which the successive iterates are shrunk.
64 |
65 | """
66 | step = at.as_tensor(1, "step", dtype=np.int64)
67 | gradient_avg = at.as_tensor(0, "gradient_avg", dtype=config.floatX)
68 | x_t = at.as_tensor(0.0, "x_t", dtype=config.floatX)
69 | x_avg = at.as_tensor(0.0, "x_avg", dtype=config.floatX)
70 | return DualAveragingState(
71 | step=step,
72 | iterates=x_t,
73 | iterates_avg=x_avg,
74 | gradient_avg=gradient_avg,
75 | shrinkage_pts=mu,
76 | )
77 |
78 | def update(
79 | gradient: TensorVariable, state: DualAveragingState
80 | ) -> DualAveragingState:
81 | """Update the state of the Dual Averaging algorithm.
82 |
83 | Parameters
84 | ----------
85 | gradient
86 | The current value of the stochastic gradient. Replaced by a
87 | statistic to optimize in the case of MCMC adaptation.
 88 |         state
 89 |             The current state of the dual averaging algorithm: the number
 90 |             of the current step in the optimization process, the current
 91 |             value of the iterate, the current value of the averaged
 92 |             iterates, the current value of the averaged gradients and
 93 |             the points towards which the successive iterates are shrunk.
 94 |             See the `DualAveragingState` definition above for the
 95 |             corresponding fields.
96 |
97 | Returns
98 | -------
99 | Updated values for the step number, iterate, averaged iterates and
100 | averaged gradients.
101 |
102 | """
103 |
104 | eta = 1.0 / (state.step + t0)
105 | new_gradient_avg = (1.0 - eta) * state.gradient_avg + eta * gradient
106 | new_x = state.shrinkage_pts - (at.sqrt(state.step) / gamma) * new_gradient_avg
107 | x_eta = state.step ** (-kappa)
108 |         new_x_avg = x_eta * new_x + (1.0 - x_eta) * state.iterates_avg
109 |
110 | return state._replace(
111 | step=(state.step + 1).astype(np.int64),
112 | iterates=new_x.astype(config.floatX),
113 | iterates_avg=new_x_avg.astype(config.floatX),
114 | gradient_avg=new_gradient_avg.astype(config.floatX),
115 | )
116 |
117 | return init, update
118 |
119 |
120 | def welford_covariance(compute_covariance: bool) -> Tuple[Callable, Callable, Callable]:
121 | """Welford's online estimator of variance/covariance.
122 |
123 | It is possible to compute the variance of a population of values in an
124 | on-line fashion to avoid storing intermediate results. The naive recurrence
125 | relations between the sample mean and variance at a step and the next are
126 | however not numerically stable.
127 |
128 |     Welford's algorithm tracks the sum of squared differences
129 |     :math:`M_{2,n} = \\sum_{i=1}^n \\left(x_i-\\overline{x_n}\\right)^2`, where
130 |     :math:`\\overline{x_n}` is the running sample mean, via the numerically stable
131 |     recurrence :math:`M_{2,n} = M_{2,n-1} + (x_n - \\overline{x_{n-1}})(x_n - \\overline{x_n})`.
132 |
133 | Parameters
134 | ----------
135 | compute_covariance
136 | When True the algorithm returns a covariance matrix, otherwise returns
137 | a variance vector.
138 |
139 | """
140 |
141 | def init(n_dims: int) -> Tuple[TensorVariable, TensorVariable, TensorVariable]:
142 | """Initialize the variance estimation.
143 |
144 | Parameters
145 | ----------
146 | n_dims: int
147 | The number of dimensions of the problem.
148 |
149 | """
150 | sample_size = at.as_tensor(0, dtype=np.int64)
151 |
152 | if n_dims == 0:
153 | return (
154 | at.as_tensor(0.0, dtype=config.floatX),
155 | at.as_tensor(0.0, dtype=config.floatX),
156 | sample_size,
157 | )
158 |
159 | mean = at.zeros((n_dims,), dtype=config.floatX)
160 | if compute_covariance:
161 | m2 = at.zeros((n_dims, n_dims), dtype=config.floatX)
162 | else:
163 | m2 = at.zeros((n_dims,), dtype=config.floatX)
164 |
165 | return mean, m2, sample_size
166 |
167 | def update(
168 | value: TensorVariable,
169 | mean: TensorVariable,
170 | m2: TensorVariable,
171 | sample_size: TensorVariable,
172 | ) -> Tuple[TensorVariable, TensorVariable, TensorVariable]:
173 | """Update the averages and M2 matrix using the new value.
174 |
175 | Parameters
176 | ----------
177 |         value: Array, shape (n_dims,)
178 | The new sample (typically position of the chain) used to update m2
179 | mean
180 | The running average along each dimension
181 | m2
182 | The running value of the unnormalized variance/covariance
183 | sample_size
184 | The number of points that have currently been used to compute `mean` and `m2`.
185 |
186 | """
187 | sample_size = sample_size + 1
188 |
189 | delta = value - mean
190 | mean = mean + delta / sample_size
191 | updated_delta = value - mean
192 | if compute_covariance and mean.ndim > 0:
193 | m2 = m2 + at.outer(updated_delta, delta)
194 | else:
195 | m2 = m2 + updated_delta * delta
196 |
197 | return mean, m2, sample_size
198 |
199 | def final(m2: TensorVariable, sample_size: TensorVariable) -> TensorVariable:
200 | """Compute the covariance"""
201 | variance_or_covariance = m2 / (sample_size - 1)
202 | return variance_or_covariance
203 |
204 | return init, update, final
205 |
--------------------------------------------------------------------------------
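
A minimal sketch of how the closures returned by `dual_averaging` and `welford_covariance` above chain together; this is assumed usage, not repository code, and all variable names are illustrative.

``` python
import aesara
import aesara.tensor as at

from aehmc.algorithms import dual_averaging, welford_covariance

# Dual averaging: iterates are shrunk towards mu = 0 while following the
# running average of the stochastic "gradient".
da_init, da_update = dual_averaging()
da_state = da_init(at.as_tensor(0.0))
gradient = at.scalar("gradient")
da_state = da_update(gradient, da_state)  # one optimization step

# Welford: accumulate a running variance over two symbolic samples.
w_init, w_update, w_final = welford_covariance(compute_covariance=False)
x1, x2 = at.vector("x1"), at.vector("x2")
mean, m2, count = w_init(3)  # three-dimensional positions
for value in (x1, x2):
    mean, m2, count = w_update(value, mean, m2, count)

variance_fn = aesara.function([x1, x2], w_final(m2, count))
print(variance_fn([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))
# [0.5 2.  4.5]
```
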
/aehmc/hmc.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, Tuple
2 |
3 | import aesara
4 | import aesara.tensor as at
5 | import numpy as np
6 | from aesara.ifelse import ifelse
7 | from aesara.tensor.random.utils import RandomStream
8 | from aesara.tensor.var import TensorVariable
9 |
10 | import aehmc.integrators as integrators
11 | import aehmc.metrics as metrics
12 | import aehmc.trajectory as trajectory
13 | from aehmc.integrators import IntegratorState
14 |
15 |
16 | def new_state(q: TensorVariable, logprob_fn: Callable) -> IntegratorState:
17 | """Create a new HMC state from a position.
18 |
19 | Parameters
20 | ----------
21 | q
22 | The chain's position.
23 | logprob_fn
24 | The function that computes the value of the log-probability density
25 | function at any position.
26 |
27 | Returns
28 | -------
29 | A new HMC state, i.e. a tuple with the position, current value of the
30 | potential energy and gradient of the potential energy.
31 |
32 | """
33 | potential_energy = -logprob_fn(q)
34 | potential_energy_grad = aesara.grad(potential_energy, wrt=q)
35 | return IntegratorState(
36 | position=q,
37 | momentum=None,
38 | potential_energy=potential_energy,
39 | potential_energy_grad=potential_energy_grad,
40 | )
41 |
42 |
43 | def new_kernel(
44 | srng: RandomStream,
45 | logprob_fn: Callable,
46 | divergence_threshold: int = 1000,
47 | ) -> Callable:
48 | """Build a HMC kernel.
49 |
50 | Parameters
51 | ----------
52 | srng
53 | A RandomStream object that tracks the changes in a shared random state.
54 | logprob_fn
55 | A function that returns the value of the log-probability density
56 | function of a chain at a given position.
57 | divergence_threshold
58 | The difference in energy above which we say the transition is
59 | divergent.
60 |
61 | Returns
62 | -------
63 | A kernel that takes the current state of the chain and that returns a new
64 | state.
65 |
66 | References
67 | ----------
68 | .. [0]: Neal, Radford M. "MCMC using Hamiltonian dynamics." Handbook of markov
69 | chain monte carlo 2.11 (2011): 2.
70 |
71 |
72 | """
73 |
74 | def potential_fn(x):
75 | return -logprob_fn(x)
76 |
77 | def step(
78 | state: IntegratorState,
79 | step_size: TensorVariable,
80 | inverse_mass_matrix: TensorVariable,
81 | num_integration_steps: int,
82 | ) -> Tuple[trajectory.Diagnostics, Dict]:
83 | """Perform a single step of the HMC algorithm.
84 |
85 | Parameters
86 | ----------
 87 |         state
 88 |             The current state of the chain: the position, the value of the
 89 |             potential energy at that position and the gradient of the
 90 |             potential energy with respect to the position. The momentum is
 91 |             drawn afresh inside the kernel, so the value stored in the
 92 |             state is ignored.
93 | step_size
94 | The step size used in the symplectic integrator
95 | inverse_mass_matrix
96 | One or two-dimensional array used as the inverse mass matrix that
97 | defines the euclidean metric.
98 | num_integration_steps
99 | The number of times we run the integrator at each step.
100 |
101 | Returns
102 | -------
103 | A tuple that contains on the one hand: the new position, value of the
104 | potential energy, gradient of the potential energy and acceptance
105 |         probability. On the other hand a dictionary that contains the update
106 | rules for the shared variables updated in the scan operator.
107 |
108 | """
109 |
110 | momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_metric(
111 | inverse_mass_matrix
112 | )
113 | symplectic_integrator = integrators.velocity_verlet(
114 | potential_fn, kinetic_energy_fn
115 | )
116 | proposal_generator = hmc_proposal(
117 | symplectic_integrator,
118 | kinetic_energy_fn,
119 | num_integration_steps,
120 | divergence_threshold,
121 | )
122 | updated_state = state._replace(momentum=momentum_generator(srng))
123 | chain_info, updates = proposal_generator(srng, updated_state, step_size)
124 | return chain_info, updates
125 |
126 | return step
127 |
128 |
129 | def hmc_proposal(
130 | integrator: Callable,
131 | kinetic_energy: Callable[[TensorVariable], TensorVariable],
132 | num_integration_steps: TensorVariable,
133 | divergence_threshold: int,
134 | ) -> Callable:
135 | """Builds a function that returns a HMC proposal.
136 |
137 | Parameters
138 | --------
139 | integrator
140 | The symplectic integrator used to integrate the hamiltonian dynamics.
141 | kinetic_energy
142 | The function used to compute the kinetic energy.
143 | num_integration_steps
144 | The number of times we need to run the integrator every time the
145 | returned function is called.
146 | divergence_threshold
147 | The difference in energy above which we say the transition is
148 | divergent.
149 |
150 | Returns
151 | -------
152 | A function that generates a new state for the chain.
153 |
154 | """
155 | integrate = trajectory.static_integration(integrator, num_integration_steps)
156 |
157 | def propose(
158 | srng: RandomStream, state: IntegratorState, step_size: TensorVariable
159 | ) -> Tuple[trajectory.Diagnostics, Dict]:
160 | """Use the HMC algorithm to propose a new state.
161 |
162 | Parameters
163 | ----------
164 | srng
165 | A RandomStream object that tracks the changes in a shared random state.
166 |         state
167 |             The initial state of the chain: the position, the freshly
168 |             drawn momentum, the value of the potential energy at the
169 |             position and the gradient of the potential energy with
170 |             respect to the position. The trajectory is integrated
171 |             forward from this state.
172 | step_size
173 | The step size used in the symplectic integrator
174 |
175 | Returns
176 | -------
177 | A tuple that contains on the one hand: the new position, value of the
178 | potential energy, gradient of the potential energy and acceptance
179 | probability. On the other hand a dictionary that contains the update
180 | rules for the shared variables updated in the scan operator.
181 |
182 | """
183 | new_state, updates = integrate(state, step_size)
184 | # flip the momentum to keep detailed balance
185 | new_state = new_state._replace(momentum=-1.0 * new_state.momentum)
186 | # compute transition-related quantities
187 | energy = state.potential_energy + kinetic_energy(state.momentum)
188 | new_energy = new_state.potential_energy + kinetic_energy(new_state.momentum)
189 | delta_energy = energy - new_energy
190 | delta_energy = at.where(at.isnan(delta_energy), -np.inf, delta_energy)
191 | is_transition_divergent = at.abs(delta_energy) > divergence_threshold
192 |
193 | p_accept = at.clip(at.exp(delta_energy), 0, 1.0)
194 | do_accept = srng.bernoulli(p_accept)
195 | final_state = IntegratorState(*ifelse(do_accept, new_state, state))
196 | chain_info = trajectory.Diagnostics(
197 | state=final_state,
198 | acceptance_probability=p_accept,
199 | is_diverging=is_transition_divergent,
200 | num_doublings=None,
201 | is_turning=None,
202 | )
203 |
204 | return chain_info, updates
205 |
206 | return propose
207 |
--------------------------------------------------------------------------------
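
A minimal sketch that mirrors the README's NUTS example with the HMC kernel defined above; as in the README, `aeppl.joint_logprob` is assumed to be available, and the number of integration steps is an arbitrary choice.

``` python
import aesara
import aesara.tensor as at
from aesara.tensor.random.utils import RandomStream

from aeppl import joint_logprob

from aehmc import hmc

# A simple normal distribution
Y_rv = at.random.normal(0, 1)


def logprob_fn(y):
    return joint_logprob(realized={Y_rv: y})[0]


srng = RandomStream(seed=0)
kernel = hmc.new_kernel(srng, logprob_fn)

y_vv = Y_rv.clone()
initial_state = hmc.new_state(y_vv, logprob_fn)

# Unlike NUTS, the HMC kernel needs a fixed number of integration steps.
chain_info, updates = kernel(
    initial_state,
    step_size=at.as_tensor(1e-2),
    inverse_mass_matrix=at.as_tensor(1.0),
    num_integration_steps=10,
)

next_step_fn = aesara.function([y_vv], chain_info.state.position, updates=updates)
```
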
/aehmc/integrators.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, NamedTuple
2 |
3 | import aesara
4 | from aesara.tensor.var import TensorVariable
5 |
6 |
7 | class IntegratorState(NamedTuple):
8 | position: TensorVariable
9 | momentum: TensorVariable
10 | potential_energy: TensorVariable
11 | potential_energy_grad: TensorVariable
12 |
13 |
14 | def new_integrator_state(
15 | potential_fn: Callable, position: TensorVariable, momentum: TensorVariable
16 | ) -> IntegratorState:
17 | """Create a new integrator state from the current values of the position and momentum."""
18 | potential_energy = potential_fn(position)
19 | return IntegratorState(
20 | position=position,
21 | momentum=momentum,
22 | potential_energy=potential_energy,
23 | potential_energy_grad=aesara.grad(potential_energy, position),
24 | )
25 |
26 |
27 | def velocity_verlet(
28 | potential_fn: Callable[[TensorVariable], TensorVariable],
29 | kinetic_energy_fn: Callable[[TensorVariable], TensorVariable],
30 | ) -> Callable[[IntegratorState, TensorVariable], IntegratorState]:
31 | """The velocity Verlet (or Verlet-Störmer) integrator.
32 |
33 | The velocity Verlet is a two-stage palindromic integrator [1]_ of the form
34 | (a1, b1, a2, b1, a1) with a1 = 0. It is numerically stable for values of
35 | the step size that range between 0 and 2 (when the mass matrix is the
36 | identity).
37 |
38 | Parameters
39 | ----------
40 | potential_fn
41 | A function that returns the potential energy of a chain at a given
42 | position.
43 | kinetic_energy_fn
44 | A function that returns the kinetic energy of a chain at a given
45 | position and a given momentum.
46 |
47 | References
48 | ----------
49 | .. [1]: Bou-Rabee, Nawaf, and Jesús Marıa Sanz-Serna. "Geometric
50 | integrators and the Hamiltonian Monte Carlo method." Acta Numerica 27
51 | (2018): 113-206.
52 |
53 | """
54 | a1 = 0
55 | b1 = 0.5
56 | a2 = 1 - 2 * a1
57 |
58 | def one_step(state: IntegratorState, step_size: TensorVariable) -> IntegratorState:
59 | momentum = state.momentum - b1 * step_size * state.potential_energy_grad
60 |
61 | kinetic_grad = aesara.grad(kinetic_energy_fn(momentum), momentum)
62 | position = state.position + a2 * step_size * kinetic_grad
63 |
64 | potential_energy = potential_fn(position)
65 | potential_energy_grad = aesara.grad(potential_energy, position)
66 | momentum = momentum - b1 * step_size * potential_energy_grad
67 |
68 | return IntegratorState(
69 | position=position,
70 | momentum=momentum,
71 | potential_energy=potential_energy,
72 | potential_energy_grad=potential_energy_grad,
73 | )
74 |
75 | return one_step
76 |
--------------------------------------------------------------------------------
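
For intuition, `one_step` above is the familiar leapfrog update: a half step on the momentum, a full step on the position, then a second half step on the momentum. A plain-Python sketch of the same scheme (illustrative only; `grad_potential` and `grad_kinetic` are hypothetical callables, not part of the repository):

``` python
def leapfrog_step(q, p, step_size, grad_potential, grad_kinetic):
    """One velocity Verlet step with coefficients a1 = 0, b1 = 0.5, a2 = 1."""
    p = p - 0.5 * step_size * grad_potential(q)  # b1: half step on the momentum
    q = q + step_size * grad_kinetic(p)          # a2: full step on the position
    p = p - 0.5 * step_size * grad_potential(q)  # b1: second momentum half step
    return q, p
```
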
/aehmc/mass_matrix.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple
2 |
3 | import aesara.tensor as at
4 | from aesara import config
5 | from aesara.tensor.var import TensorVariable
6 |
7 | from aehmc import algorithms
8 |
9 | WelfordAlgorithmState = Tuple[TensorVariable, TensorVariable, TensorVariable]
10 |
11 |
12 | def covariance_adaptation(
13 | is_mass_matrix_full: bool = False,
14 | ) -> Tuple[Callable, Callable, Callable]:
15 | """Adapts the values in the mass matrix by computing the covariance
16 | between parameters.
17 |
18 | Parameters
19 | ----------
20 | is_mass_matrix_full
21 | When False the algorithm adapts and returns a diagonal mass matrix
22 | (default), otherwise adapts and returns a dense mass matrix.
23 |
24 | Returns
25 | -------
26 | init
27 | A function that initializes the step of the mass matrix adaptation.
28 | update
29 | A function that updates the state of the mass matrix.
30 | final
31 | A function that computes the inverse mass matrix based on the current
32 | state.
33 | """
34 |
35 | wc_init, wc_update, wc_final = algorithms.welford_covariance(is_mass_matrix_full)
36 |
37 | def init(
38 | n_dims: int,
39 | ) -> Tuple[TensorVariable, WelfordAlgorithmState]:
40 | """Initialize the mass matrix adaptation.
41 |
42 | Parameters
43 | ----------
44 |         n_dims
45 | The number of dimensions of the mass matrix, which corresponds to
46 | the number of dimensions of the chain position.
47 |
48 | Returns
49 | -------
50 | The initial value of the mass matrix and the initial state of the
51 | Welford covariance algorithm.
52 |
53 | """
54 | if n_dims == 0:
55 | inverse_mass_matrix = at.constant(1.0, dtype=config.floatX)
56 | elif is_mass_matrix_full:
57 | inverse_mass_matrix = at.eye(n_dims, dtype=config.floatX)
58 | else:
59 | inverse_mass_matrix = at.ones((n_dims,), dtype=config.floatX)
60 |
61 | wc_state = wc_init(n_dims)
62 |
63 | return inverse_mass_matrix, wc_state
64 |
65 | def update(
66 | position: TensorVariable, wc_state: WelfordAlgorithmState
67 | ) -> WelfordAlgorithmState:
68 | """Update the algorithm's state.
69 |
70 | Parameters
71 | ----------
72 | position
73 | The current position of the chain.
74 | wc_state
75 | Current state of Welford's algorithm to compute covariance.
76 |
77 | """
78 | new_wc_state = wc_update(position, *wc_state)
79 | return new_wc_state
80 |
81 | def final(wc_state: WelfordAlgorithmState) -> TensorVariable:
82 | """Final iteration of the mass matrix adaptation.
83 |
84 | In this step we compute the mass matrix from the covariance matrix computed
85 | by the Welford algorithm, applying the shrinkage used in Stan [1]_.
86 |
87 | Parameters
88 | ----------
89 | wc_state
90 | Current state of Welford's algorithm to compute covariance.
91 |
92 | Returns
93 | -------
94 | The value of the inverse mass matrix computed from the covariance estimate.
95 |
96 | References
97 | ----------
98 | .. [1]: Carpenter, B., Gelman, A., Hoffman, M. D., Lee, D., Goodrich, B.,
99 | Betancourt, M., ... & Riddell, A. (2017). Stan: A probabilistic programming
100 | language. Journal of statistical software, 76(1), 1-32.
101 |
102 | """
103 | _, m2, sample_size = wc_state
104 | covariance = wc_final(m2, sample_size)
105 |
106 | scaled_covariance = (sample_size / (sample_size + 5)) * covariance
107 | shrinkage = 1e-3 * (5 / (sample_size + 5))
108 | if covariance.ndim > 0:
109 | if is_mass_matrix_full:
110 | new_inverse_mass_matrix = (
111 | scaled_covariance + shrinkage * at.identity_like(covariance)
112 | )
113 | else:
114 | new_inverse_mass_matrix = scaled_covariance + shrinkage
115 | else:
116 | new_inverse_mass_matrix = scaled_covariance + shrinkage
117 |
118 | return new_inverse_mass_matrix
119 |
120 | return init, update, final
121 |
--------------------------------------------------------------------------------
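A hedged usage sketch of how the `init`/`update`/`final` triple returned by `covariance_adaptation` fits together (not part of the repository; the two symbolic positions stand in for successive chain states):

```python
import aesara
import aesara.tensor as at

from aehmc import mass_matrix

init, update, final = mass_matrix.covariance_adaptation()  # diagonal matrix

# Initialize the adaptation for a 2-dimensional chain position
inverse_mass_matrix, wc_state = init(2)

# Feed two (symbolic) chain positions to the Welford covariance estimator
x1, x2 = at.vector("x1"), at.vector("x2")
wc_state = update(x1, wc_state)
wc_state = update(x2, wc_state)

# Compute the shrunk inverse mass matrix from the running covariance
new_inverse_mass_matrix = final(wc_state)
estimate = aesara.function([x1, x2], new_inverse_mass_matrix)
```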
/aehmc/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple
2 |
3 | import aesara.tensor as at
4 | from aesara.tensor.random.utils import RandomStream
5 | from aesara.tensor.shape import shape_tuple
6 | from aesara.tensor.slinalg import cholesky, solve_triangular
7 | from aesara.tensor.var import TensorVariable
8 |
9 |
10 | def gaussian_metric(
11 | inverse_mass_matrix: TensorVariable,
12 | ) -> Tuple[Callable, Callable, Callable]:
13 | r"""Hamiltonian dynamic on euclidean manifold with normally-distributed momentum.
14 |
15 |     The Gaussian Euclidean metric is a Euclidean metric further characterized
16 |     by setting the conditional probability density :math:`\pi(momentum|position)`
17 |     to a zero-mean Gaussian distribution whose covariance is given by the mass
18 |     matrix. Newtonian Hamiltonian dynamics are assumed.
19 |
20 |     Parameters
21 |     ----------
22 | inverse_mass_matrix
23 | One or two-dimensional array corresponding respectively to a diagonal
24 | or dense mass matrix.
25 |
26 | Returns
27 | -------
28 | momentum_generator
29 | A function that generates a value for the momentum at random.
30 | kinetic_energy
31 | A function that returns the kinetic energy given the momentum.
32 | is_turning
33 | A function that determines whether a trajectory is turning back on
34 | itself given the values of the momentum along the trajectory.
35 |
36 | References
37 | ----------
38 | .. [1]: Betancourt, Michael. "A general metric for Riemannian manifold
39 | Hamiltonian Monte Carlo." International Conference on Geometric Science of
40 | Information. Springer, Berlin, Heidelberg, 2013.
41 |
42 | """
43 |
44 | if inverse_mass_matrix.ndim == 0:
45 | shape: Tuple = ()
46 | mass_matrix_sqrt = at.sqrt(at.reciprocal(inverse_mass_matrix))
47 | dot, matmul = lambda x, y: x * y, lambda x, y: x * y
48 | elif inverse_mass_matrix.ndim == 1:
49 | shape = (shape_tuple(inverse_mass_matrix)[0],)
50 | mass_matrix_sqrt = at.sqrt(at.reciprocal(inverse_mass_matrix))
51 | dot, matmul = at.dot, lambda x, y: x * y
52 | elif inverse_mass_matrix.ndim == 2:
53 | # inverse mass matrix can be factored into L*L.T. We want the cholesky
54 | # factor (inverse of L.T) of the mass matrix.
55 | shape = (shape_tuple(inverse_mass_matrix)[0],)
56 | L = cholesky(inverse_mass_matrix)
57 | identity = at.eye(*shape)
58 | mass_matrix_sqrt = solve_triangular(L, identity, lower=True, trans=True)
59 | dot, matmul = at.dot, at.dot
60 | else:
61 | raise ValueError(
62 | f"Expected a mass matrix of dimension 1 (diagonal) or 2, got {inverse_mass_matrix.ndim}"
63 | )
64 |
65 | def momentum_generator(srng: RandomStream) -> TensorVariable:
66 | norm_samples = srng.normal(0, 1, size=shape, name="momentum")
67 | momentum = matmul(mass_matrix_sqrt, norm_samples)
68 | return momentum
69 |
70 | def kinetic_energy(momentum: TensorVariable) -> TensorVariable:
71 | velocity = matmul(inverse_mass_matrix, momentum)
72 | kinetic_energy = 0.5 * dot(velocity, momentum)
73 | return kinetic_energy
74 |
75 | def is_turning(
76 | momentum_left: TensorVariable,
77 | momentum_right: TensorVariable,
78 | momentum_sum: TensorVariable,
79 | ) -> bool:
80 | """Generalized U-turn criterion.
81 |
82 | Parameters
83 | ----------
84 | momentum_left
85 | Momentum of the leftmost point of the trajectory.
86 | momentum_right
87 | Momentum of the rightmost point of the trajectory.
88 | momentum_sum
89 | Sum of the momenta along the trajectory.
90 |
91 | .. [1]: Betancourt, Michael J. "Generalizing the no-U-turn sampler to Riemannian manifolds." arXiv preprint arXiv:1304.1920 (2013).
92 | .. [2]: "NUTS misses U-turn, runs in cicles until max depth", Stan Discourse Forum
93 | https://discourse.mc-stan.org/t/nuts-misses-u-turns-runs-in-circles-until-max-treedepth/9727/46
94 | """
95 | velocity_left = matmul(inverse_mass_matrix, momentum_left)
96 | velocity_right = matmul(inverse_mass_matrix, momentum_right)
97 |
98 | rho = momentum_sum - (momentum_right + momentum_left) / 2
99 | turning_at_left = at.dot(velocity_left, rho) <= 0
100 | turning_at_right = at.dot(velocity_right, rho) <= 0
101 |
102 | is_turning = turning_at_left | turning_at_right
103 |
104 | return is_turning
105 |
106 | return momentum_generator, kinetic_energy, is_turning
107 |
--------------------------------------------------------------------------------
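For illustration only, a small sketch that builds the metric from a diagonal inverse mass matrix and draws a momentum; the matrix values and the seed are arbitrary:

```python
import aesara
import aesara.tensor as at
import numpy as np
from aesara.tensor.random.utils import RandomStream

from aehmc import metrics

# Diagonal inverse mass matrix for a 2-dimensional chain (illustrative values)
inverse_mass_matrix = at.as_tensor(np.array([1.0, 4.0]))
momentum_generator, kinetic_energy, is_turning = metrics.gaussian_metric(
    inverse_mass_matrix
)

srng = RandomStream(seed=0)
momentum = momentum_generator(srng)  # drawn from a N(0, M) distribution
energy = kinetic_energy(momentum)

draw = aesara.function([], [momentum, energy])
```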
/aehmc/nuts.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, Tuple
2 |
3 | import aesara.tensor as at
4 | import numpy as np
5 | from aesara.tensor.random.utils import RandomStream
6 | from aesara.tensor.var import TensorVariable
7 |
8 | from aehmc import hmc, integrators, metrics
9 | from aehmc.integrators import IntegratorState
10 | from aehmc.proposals import ProposalState
11 | from aehmc.termination import iterative_uturn
12 | from aehmc.trajectory import Diagnostics, dynamic_integration, multiplicative_expansion
13 |
14 | new_state = hmc.new_state
15 |
16 |
17 | def new_kernel(
18 | srng: RandomStream,
19 | logprob_fn: Callable[[TensorVariable], TensorVariable],
20 | max_num_expansions: int = 10,
21 | divergence_threshold: int = 1000,
22 | ) -> Callable:
23 | """Build an iterative NUTS kernel.
24 |
25 | Parameters
26 | ----------
27 | srng
28 | A RandomStream object that tracks the changes in a shared random state.
29 | logprob_fn
30 | A function that returns the value of the log-probability density
31 | function of a chain at a given position.
32 | max_num_expansions
33 | The maximum number of times we double the length of the trajectory.
34 | Known as the maximum tree depth in most implementations.
35 | divergence_threshold
36 | The difference in energy above which we say the transition is
37 | divergent.
38 |
39 | Returns
40 | -------
41 | A function which, given a chain state, returns a new chain state.
42 |
43 | References
44 | ----------
45 | .. [0]: Phan, Du, Neeraj Pradhan, and Martin Jankowiak. "Composable effects
46 | for flexible and accelerated probabilistic programming in NumPyro." arXiv
47 | preprint arXiv:1912.11554 (2019).
48 | .. [1]: Lao, Junpeng, et al. "tfp. mcmc: Modern markov chain monte carlo
49 | tools built for modern hardware." arXiv preprint arXiv:2002.01184 (2020).
50 |
51 | """
52 |
53 | def potential_fn(x):
54 | return -logprob_fn(x)
55 |
56 | def step(
57 | state: IntegratorState,
58 | step_size: TensorVariable,
59 | inverse_mass_matrix: TensorVariable,
60 | ) -> Tuple[Diagnostics, Dict]:
61 | """Use the NUTS algorithm to propose a new state.
62 |
63 | Parameters
64 | ----------
65 |         state
66 |             The current state of the chain, an `IntegratorState` namedtuple
67 |             that contains the current position, the momentum that will be
68 |             resampled at the beginning of the step, the value of the
69 |             potential energy at the position, and the gradient of the
70 |             potential energy with respect to the position.
71 |         step_size
72 |             The step size used in the symplectic integrator.
73 | inverse_mass_matrix
74 | One or two-dimensional array used as the inverse mass matrix that
75 | defines the euclidean metric.
76 |
77 | Returns
78 | -------
79 |         A tuple that contains, on the one hand, a `Diagnostics` namedtuple
80 |         with the new chain state, the acceptance probability, the number of
81 |         times the trajectory was doubled, whether the integration diverged
82 |         and whether the trajectory turned back on itself; and on the other
83 |         hand a dictionary that contains the update rules for the shared
84 |         variables updated in the `Scan` operator.
85 |
86 | """
87 | momentum_generator, kinetic_energy_fn, uturn_check_fn = metrics.gaussian_metric(
88 | inverse_mass_matrix
89 | )
90 | symplectic_integrator = integrators.velocity_verlet(
91 | potential_fn, kinetic_energy_fn
92 | )
93 | (
94 | new_termination_state,
95 | update_termination_state,
96 | is_criterion_met,
97 | ) = iterative_uturn(uturn_check_fn)
98 | trajectory_integrator = dynamic_integration(
99 | srng,
100 | symplectic_integrator,
101 | kinetic_energy_fn,
102 | update_termination_state,
103 | is_criterion_met,
104 | divergence_threshold,
105 | )
106 | expand = multiplicative_expansion(
107 | srng,
108 | trajectory_integrator,
109 | uturn_check_fn,
110 | max_num_expansions,
111 | )
112 |
113 | initial_state = state._replace(momentum=momentum_generator(srng))
114 | initial_termination_state = new_termination_state(
115 | initial_state.position, max_num_expansions
116 | )
117 | initial_energy = initial_state.potential_energy + kinetic_energy_fn(
118 | initial_state.momentum
119 | )
120 | initial_proposal = ProposalState(
121 | state=initial_state,
122 | energy=initial_energy,
123 | weight=at.as_tensor(0.0, dtype=np.float64),
124 | sum_log_p_accept=at.as_tensor(-np.inf, dtype=np.float64),
125 | )
126 |
127 | results, updates = expand(
128 | initial_proposal,
129 | initial_state,
130 | initial_state,
131 | initial_state.momentum,
132 | initial_termination_state,
133 | initial_energy,
134 | step_size,
135 | )
136 |
137 | # extract the last iteration from multiplicative_expansion chain diagnostics
138 | chain_info = Diagnostics(
139 | state=IntegratorState(
140 | position=results.diagnostics.state.position[-1],
141 | momentum=results.diagnostics.state.momentum[-1],
142 | potential_energy=results.diagnostics.state.potential_energy[-1],
143 | potential_energy_grad=results.diagnostics.state.potential_energy_grad[
144 | -1
145 | ],
146 | ),
147 | acceptance_probability=results.diagnostics.acceptance_probability[-1],
148 | num_doublings=results.diagnostics.num_doublings[-1],
149 | is_turning=results.diagnostics.is_turning[-1],
150 | is_diverging=results.diagnostics.is_diverging[-1],
151 | )
152 |
153 | return chain_info, updates
154 |
155 | return step
156 |
--------------------------------------------------------------------------------
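A minimal end-to-end sketch for the kernel above, assuming a standard Gaussian target; the step size and the identity mass matrix are illustrative and would normally come from the window adaptation. The zero initial momentum is a placeholder, since the kernel resamples the momentum at the beginning of every step:

```python
import aesara
import aesara.tensor as at
from aesara.tensor.random.utils import RandomStream

from aehmc import integrators, nuts


def logprob_fn(x):
    return -0.5 * at.square(x).sum()  # standard Gaussian log-density


srng = RandomStream(seed=0)
kernel = nuts.new_kernel(srng, logprob_fn)

q = at.vector("q")
initial_state = integrators.new_integrator_state(
    lambda x: -logprob_fn(x), q, at.zeros_like(q)
)

chain_info, updates = kernel(initial_state, at.as_tensor(0.5), at.ones(2))
next_position = aesara.function([q], chain_info.state.position, updates=updates)
```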
/aehmc/proposals.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, NamedTuple, Tuple
2 |
3 | import aesara.tensor as at
4 | import numpy as np
5 | from aesara.tensor.random.utils import RandomStream
6 | from aesara.tensor.var import TensorVariable
7 |
8 | from aehmc.integrators import IntegratorState
9 |
10 |
11 | class ProposalState(NamedTuple):
12 | state: IntegratorState
13 | energy: TensorVariable
14 | weight: TensorVariable
15 | sum_log_p_accept: TensorVariable
16 |
17 |
18 | def proposal_generator(kinetic_energy: Callable, divergence_threshold: float):
19 | def update(initial_energy, state: IntegratorState) -> Tuple[ProposalState, bool]:
20 | """Generate a new proposal from a trajectory state.
21 |
22 |         The trajectory state records information about the position in the state
23 |         space and the corresponding potential energy. A proposal also carries a
24 |         weight, equal to the difference between the initial energy and the energy
25 |         of the new state. It thus carries information about the initial state as
26 |         well as the current state.
27 |
28 | Parameters
29 | ----------
30 | initial_energy:
31 | The initial energy.
32 | state:
33 | The new state.
34 |
35 | Return
36 | ------
37 | A tuple that contains the new proposal and a boolean that indicates
38 | whether the current transition is divergent.
39 |
40 | """
41 | new_energy = state.potential_energy + kinetic_energy(state.momentum)
42 |
43 | delta_energy = initial_energy - new_energy
44 | delta_energy = at.where(at.isnan(delta_energy), -np.inf, delta_energy)
45 | is_transition_divergent = at.abs(delta_energy) > divergence_threshold
46 |
47 | weight = delta_energy
48 | log_p_accept = at.where(
49 | at.gt(delta_energy, 0),
50 | at.as_tensor(0, dtype=delta_energy.dtype),
51 | delta_energy,
52 | )
53 |
54 | return (
55 | ProposalState(
56 | state=state,
57 | energy=new_energy,
58 | weight=weight,
59 | sum_log_p_accept=log_p_accept,
60 | ),
61 | is_transition_divergent,
62 | )
63 |
64 | return update
65 |
66 |
67 | # -------------------------------------------------------------------
68 | # PROGRESSIVE SAMPLING
69 | # -------------------------------------------------------------------
70 |
71 |
72 | def progressive_uniform_sampling(
73 | srng: RandomStream, proposal: ProposalState, new_proposal: ProposalState
74 | ) -> ProposalState:
75 | """Uniform proposal sampling.
76 |
77 | Choose between the current proposal and the proposal built from the last
78 | trajectory state.
79 |
80 | Parameters
81 | ----------
82 | srng
83 | RandomStream object
84 | proposal
85 |         The current proposal; it does not necessarily correspond to the
86 |         previous state on the trajectory.
87 | new_proposal
88 | The proposal built from the last trajectory state.
89 |
90 | Return
91 | ------
92 | Either the current or the new proposal.
93 |
94 | """
95 | # TODO: Make the `at.isnan` check unnecessary
96 | p_accept = at.expit(new_proposal.weight - proposal.weight)
97 | p_accept = at.where(at.isnan(p_accept), 0, p_accept)
98 |
99 | do_accept = srng.bernoulli(p_accept)
100 | updated_proposal = maybe_update_proposal(do_accept, proposal, new_proposal)
101 |
102 | return updated_proposal
103 |
104 |
105 | def progressive_biased_sampling(
106 | srng: RandomStream, proposal: ProposalState, new_proposal: ProposalState
107 | ) -> ProposalState:
108 | """Baised proposal sampling.
109 |
110 | Choose between the current proposal and the proposal built from the last
111 | trajectory state. Unlike uniform sampling, biased sampling favors new
112 | proposals. It thus biases the transition away from the trajectory's initial
113 | state.
114 |
115 | Parameters
116 | ----------
117 | srng
118 | RandomStream object
119 | proposal
120 |         The current proposal; it does not necessarily correspond to the
121 |         previous state on the trajectory.
122 | new_proposal
123 | The proposal built from the last trajectory state.
124 |
125 | Return
126 | ------
127 | Either the current or the new proposal.
128 |
129 | """
130 | p_accept = at.clip(at.exp(new_proposal.weight - proposal.weight), 0.0, 1.0)
131 | do_accept = srng.bernoulli(p_accept)
132 | updated_proposal = maybe_update_proposal(do_accept, proposal, new_proposal)
133 |
134 | return updated_proposal
135 |
136 |
137 | def maybe_update_proposal(
138 | do_accept: bool, proposal: ProposalState, new_proposal: ProposalState
139 | ) -> ProposalState:
140 | """Return either proposal depending on the boolean `do_accept`"""
141 | updated_weight = at.logaddexp(proposal.weight, new_proposal.weight)
142 | updated_log_sum_p_accept = at.logaddexp(
143 | proposal.sum_log_p_accept, new_proposal.sum_log_p_accept
144 | )
145 |
146 | updated_q = at.where(
147 | do_accept, new_proposal.state.position, proposal.state.position
148 | )
149 | updated_p = at.where(
150 | do_accept, new_proposal.state.momentum, proposal.state.momentum
151 | )
152 | updated_potential_energy = at.where(
153 | do_accept, new_proposal.state.potential_energy, proposal.state.potential_energy
154 | )
155 | updated_potential_energy_grad = at.where(
156 | do_accept,
157 | new_proposal.state.potential_energy_grad,
158 | proposal.state.potential_energy_grad,
159 | )
160 | updated_energy = at.where(do_accept, new_proposal.energy, proposal.energy)
161 |
162 | updated_state = IntegratorState(
163 | position=updated_q,
164 | momentum=updated_p,
165 | potential_energy=updated_potential_energy,
166 | potential_energy_grad=updated_potential_energy_grad,
167 | )
168 |
169 | return ProposalState(
170 | state=updated_state,
171 | energy=updated_energy,
172 | weight=updated_weight,
173 | sum_log_p_accept=updated_log_sum_p_accept,
174 | )
175 |
--------------------------------------------------------------------------------
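Since the proposal weights are stored in log-space, the uniform progressive sampler picks the new proposal with probability `W_new / (W_new + W_current)`, which is exactly the `expit` of the weight difference computed above. A small NumPy illustration with arbitrary weights:

```python
import numpy as np

# Log-space weights w = log(W) for the current and the new proposal
w_current, w_new = np.log(2.0), np.log(6.0)

# expit(w_new - w_current) == W_new / (W_new + W_current)
p_accept = 1.0 / (1.0 + np.exp(-(w_new - w_current)))
assert np.isclose(p_accept, 6.0 / (2.0 + 6.0))
```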
/aehmc/step_size.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple
2 |
3 | import aesara.tensor as at
4 | from aesara.tensor.var import TensorVariable
5 |
6 | from aehmc import algorithms
7 |
8 |
9 | def dual_averaging_adaptation(
10 | target_acceptance_rate: TensorVariable = at.as_tensor(0.8),
11 | gamma: float = 0.05,
12 | t0: int = 10,
13 | kappa: float = 0.75,
14 | ) -> Tuple[Callable, Callable]:
15 | r"""Tune the step size to achieve a desired target acceptance rate.
16 |
17 |     Let us note :math:`\epsilon` the current step size, :math:`\alpha_t` the
18 |     Metropolis acceptance rate at time :math:`t` and :math:`\delta` the desired
19 |     acceptance rate. We define:
20 |
21 |     .. math::
22 |         H_t = \delta - \alpha_t
23 |
24 |     the error at time :math:`t`. We would like to find a procedure that adapts
25 |     the value of :math:`\epsilon` such that :math:`h(x) = \mathbb{E}\left[H_t|\epsilon\right] = 0`.
26 |     Following [1]_, the authors of [2]_ proposed the following update scheme. If
27 |     we note :math:`x = \log \epsilon`, we follow:
28 |
29 |     .. math::
30 |         x_{t+1} \gets \mu - \frac{\sqrt{t}}{\gamma} \frac{1}{t+t_0} \sum_{i=1}^t H_i\\
31 |         \overline{x}_{t+1} \gets t^{-\kappa} x_{t+1} + \left(1 - t^{-\kappa}\right)\overline{x}_t
32 |
33 | :math:`\overline{x}_{t}` is guaranteed to converge to a value such that
34 | :math:`h(\overline{x}_t)` converges to 0, i.e. the Metropolis acceptance
35 | rate converges to the desired rate.
36 |
37 | See reference [2]_ (section 3.2.1) for a detailed discussion.
38 |
39 | Parameters
40 | ----------
41 |     target_acceptance_rate:
42 |         Target Metropolis acceptance rate; the scheme adapts the step size
43 |         so that the observed acceptance rate converges to this value. The
44 |         initial value of the logarithm of the step size is passed to the
45 |         returned `init` function, which uses it as the first iterate.
46 | gamma
47 | Controls the speed of convergence of the scheme. The authors of [2]_ recommend
48 | a value of 0.05.
49 | t0: float >= 0
50 | Free parameter that stabilizes the initial iterations of the algorithm.
51 | Large values may slow down convergence. Introduced in [2]_ with a default
52 | value of 10.
53 |     kappa: float in (0.5, 1]
54 |         Controls the weight of past steps in the current update. The scheme
55 |         will quickly forget earlier steps for small values of `kappa`.
56 |         Introduced in [2]_, with a recommended value of 0.75.
57 |
58 | Returns
59 | -------
60 | init
61 | A function that initializes the state of the dual averaging scheme.
62 | update
63 | A function that updates the state of the dual averaging scheme.
64 |
65 | References
66 | ----------
67 | .. [1]: Nesterov, Yurii. "Primal-dual subgradient methods for convex
68 | problems." Mathematical programming 120.1 (2009): 221-259.
69 | .. [2]: Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler:
70 | adaptively setting path lengths in Hamiltonian Monte Carlo." Journal
71 | of Machine Learning Research 15.1 (2014): 1593-1623.
72 |
73 | """
74 | da_init, da_update = algorithms.dual_averaging(gamma, t0, kappa)
75 |
76 | def update(
77 | acceptance_probability: TensorVariable, state: algorithms.DualAveragingState
78 | ) -> algorithms.DualAveragingState:
79 | """Update the dual averaging adaptation state.
80 |
81 | Parameters
82 | ----------
83 | acceptance_probability
84 | The acceptance probability returned by the sampling algorithm.
85 |         state
86 |             The current state of the dual averaging scheme, which contains
87 |             the current step number, the logarithm of the current step size,
88 |             the average of the logarithm of the step size, the averaged
89 |             gradients and the point towards which the iterates are shrunk.
90 |
91 |         Returns
92 |         -------
93 |         The updated state of the dual averaging scheme, computed from the
94 |         gradient `target_acceptance_rate - acceptance_probability`.
95 |
96 | """
97 | gradient = target_acceptance_rate - acceptance_probability
98 | return da_update(gradient, state)
99 |
100 | return da_init, update
101 |
--------------------------------------------------------------------------------
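A short sketch of one adaptation step, assuming (as described in the docstring above) that the returned `init` function takes the initial value of the logarithm of the step size, and that the `iterates` field of the resulting `DualAveragingState` holds the current log step size, as the comments in `window_adaptation.py` indicate:

```python
import aesara.tensor as at

from aehmc import step_size

init, update = step_size.dual_averaging_adaptation(
    target_acceptance_rate=at.as_tensor(0.8)
)

# Initial step size of 1.0, i.e. a log step size of 0.0
da_state = init(at.as_tensor(0.0))

# One adaptation step driven by the acceptance probability of the last
# transition; the gradient fed to dual averaging is `target - p_accept`.
p_accept = at.scalar("p_accept")
new_da_state = update(p_accept, da_state)

# The adapted step size is the exponential of the current iterate
new_step_size = at.exp(new_da_state.iterates)
```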
/aehmc/termination.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, NamedTuple, Tuple
2 |
3 | import aesara
4 | import aesara.tensor as at
5 | import numpy as np
6 | from aesara import config as config
7 | from aesara.ifelse import ifelse
8 | from aesara.scan.utils import until
9 | from aesara.tensor.var import TensorVariable
10 |
11 |
12 | class TerminationState(NamedTuple):
13 | momentum_checkpoints: TensorVariable
14 | momentum_sum_checkpoints: TensorVariable
15 | min_index: TensorVariable
16 | max_index: TensorVariable
17 |
18 |
19 | def iterative_uturn(is_turning_fn: Callable) -> Tuple[Callable, Callable, Callable]:
20 | """U-Turn termination criterion to check reversiblity while expanding
21 | the trajectory.
22 |
23 | The code follows the implementation in Numpyro [0]_, which is equivalent to
24 | that in TFP [1]_.
25 |
26 | Parameter
27 | ---------
28 | is_turning_fn:
29 | A function which, given the new momentum and the sum of the momenta
30 | along the trajectory returns a boolean that indicates whether the
31 | trajectory is turning on itself. Depends on the metric.
32 |
33 | References
34 | ----------
35 | .. [0]: Phan, Du, Neeraj Pradhan, and Martin Jankowiak. "Composable effects
36 | for flexible and accelerated probabilistic programming in NumPyro." arXiv
37 | preprint arXiv:1912.11554 (2019).
38 | .. [1]: Lao, Junpeng, et al. "tfp. mcmc: Modern markov chain monte carlo
39 | tools built for modern hardware." arXiv preprint arXiv:2002.01184 (2020).
40 |
41 | """
42 |
43 | def new_state(
44 | position: TensorVariable, max_num_doublings: TensorVariable
45 | ) -> TerminationState:
46 | """Initialize the termination state
47 |
48 | Parameters
49 | ----------
50 | position
51 | Example chain position. Used to infer the shape of the arrays that
52 |             store the relevant momenta and momentum sums.
53 | max_num_doublings
54 | Maximum number of doublings allowed in the multiplicative
55 | expansion. Determines the maximum number of momenta and momentum
56 | sums to store.
57 |
58 | Returns
59 | -------
60 | A tuple that represents a new state for the termination criterion.
61 |
62 | """
63 | if position.ndim == 0:
64 | return TerminationState(
65 | momentum_checkpoints=at.zeros(max_num_doublings, dtype=config.floatX),
66 | momentum_sum_checkpoints=at.zeros(
67 | max_num_doublings, dtype=config.floatX
68 | ),
69 | min_index=at.constant(0, dtype=np.int64),
70 | max_index=at.constant(0, dtype=np.int64),
71 | )
72 | else:
73 | num_dims = position.shape[0]
74 | return TerminationState(
75 | momentum_checkpoints=at.zeros(
76 | (max_num_doublings, num_dims), dtype=config.floatX
77 | ),
78 | momentum_sum_checkpoints=at.zeros(
79 | (max_num_doublings, num_dims), dtype=config.floatX
80 | ),
81 | min_index=at.constant(0, dtype=np.int64),
82 | max_index=at.constant(0, dtype=np.int64),
83 | )
84 |
85 | def update(
86 | state: TerminationState,
87 | momentum_sum: TensorVariable,
88 | momentum: TensorVariable,
89 | step: TensorVariable,
90 | ) -> TerminationState:
91 | """Update the termination state.
92 |
93 | Parameters
94 | ----------
95 | state
96 | The current termination state
97 | momentum_sum
98 | The sum of all momenta along the trajectory
99 | momentum
100 | The current momentum on the trajectory
101 | step
102 | Current step in the trajectory integration (starting at 0)
103 |
104 | Return
105 | ------
106 | A tuple that represents the updated termination state.
107 |
108 | """
109 | idx_min, idx_max = ifelse(
110 | at.eq(step, 0),
111 | (state.min_index, state.max_index),
112 | _find_storage_indices(step),
113 | )
114 |
115 | momentum_ckpt = at.where(
116 | at.eq(step % 2, 0),
117 | at.set_subtensor(state.momentum_checkpoints[idx_max], momentum),
118 | state.momentum_checkpoints,
119 | )
120 | momentum_sum_ckpt = at.where(
121 | at.eq(step % 2, 0),
122 | at.set_subtensor(state.momentum_sum_checkpoints[idx_max], momentum_sum),
123 | state.momentum_sum_checkpoints,
124 | )
125 |
126 | return TerminationState(
127 | momentum_checkpoints=momentum_ckpt,
128 | momentum_sum_checkpoints=momentum_sum_ckpt,
129 | min_index=idx_min,
130 | max_index=idx_max,
131 | )
132 |
133 | def is_iterative_turning(
134 | state: TerminationState, momentum_sum: TensorVariable, momentum: TensorVariable
135 | ) -> bool:
136 | """Check if any sub-trajectory is making a U-turn.
137 |
138 | If we visualize the trajectory as a balanced binary tree, the
139 | subtrajectories for which we need to check the U-turn criterion are the
140 | ones for which the current node is the rightmost node. The
141 | corresponding momenta and sums of momentum corresponding to the nodes
142 | for which we need to check the U-Turn criterion are stored between
143 | `idx_min` and `idx_max` in `momentum_ckpts` and `momentum_sum_ckpts`
144 | respectively.
145 |
146 | Parameters
147 | ----------
148 | state
149 | The current termination state
150 |         momentum_sum
151 |             The sum of all momenta along the trajectory.
152 |         momentum
153 |             The current momentum on the trajectory, i.e. the momentum of
154 |             the last state that was added (the rightmost node of the
155 |             trajectory tree).
156 |
157 |
158 | Return
159 | ------
160 | True if any sub-trajectory makes a U-turn, False otherwise.
161 |
162 | """
163 |
164 | def body_fn(i):
165 | subtree_momentum_sum = (
166 | momentum_sum
167 | - state.momentum_sum_checkpoints[i]
168 | + state.momentum_checkpoints[i]
169 | )
170 | is_turning = is_turning_fn(
171 | state.momentum_checkpoints[i], momentum, subtree_momentum_sum
172 | )
173 | reached_max_iteration = at.lt(i - 1, state.min_index)
174 | do_stop = at.any(is_turning | reached_max_iteration)
175 | return (i - 1, is_turning), until(do_stop)
176 |
177 | (_, criterion), _ = aesara.scan(
178 | body_fn, outputs_info=(state.max_index, None), n_steps=state.max_index + 2
179 | )
180 |
181 | is_turning = at.where(
182 | at.lt(state.max_index, state.min_index),
183 | at.as_tensor(0, dtype="bool"),
184 | criterion[-1],
185 | )
186 |
187 | return is_turning
188 |
189 | return new_state, update, is_iterative_turning
190 |
191 |
192 | def _find_storage_indices(step: TensorVariable) -> Tuple[int, int]:
193 | """Find the indices between which the momenta and sums are stored.
194 |
195 | Parameter
196 | ---------
197 | step
198 | The current step in the trajectory integration.
199 |
200 | Return
201 | ------
202 | The min and max indices between which the values relevant to check the
203 | U-turn condition for the current step are stored.
204 |
205 | """
206 |
207 | def count_subtrees(nc0, nc1):
208 | do_stop = at.eq(nc0 & 1, 0)
209 | new_nc0 = nc0 // 2
210 | new_nc1 = nc1 + 1
211 | return (new_nc0, new_nc1), until(do_stop)
212 |
213 | (_, nc1), _ = aesara.scan(
214 | count_subtrees,
215 | outputs_info=(step, -1),
216 | n_steps=step + 1,
217 | )
218 | num_subtrees = nc1[-1]
219 |
220 | def find_idx_max(nc0, nc1):
221 | do_stop = at.eq(nc0, 0)
222 | new_nc0 = nc0 // 2
223 | new_nc1 = nc1 + (nc0 & 1)
224 | return (new_nc0, new_nc1), until(do_stop)
225 |
226 | init = at.as_tensor(step // 2, dtype=np.int64)
227 | init_nc1 = at.constant(0, dtype=np.int64)
228 | (nc0, nc1), _ = aesara.scan(
229 | find_idx_max, outputs_info=(init, init_nc1), n_steps=step + 1
230 | )
231 | idx_max = nc1[-1]
232 |
233 | idx_min = idx_max - num_subtrees + 1
234 |
235 | return idx_min, idx_max
236 |
--------------------------------------------------------------------------------
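For intuition, a pure-Python mirror of `_find_storage_indices` (the real function builds an Aesara graph with two `scan` loops): the number of sub-trees that end at `step` is the number of trailing 1-bits in its binary representation, and `idx_max` is the number of 1-bits in `step // 2`. This mirror is illustrative only:

```python
def find_storage_indices(step: int):
    # Count the sub-trees that end at this step: trailing 1-bits of `step`
    num_subtrees = 0
    n = step
    while n & 1:
        num_subtrees += 1
        n //= 2

    # `idx_max` is the population count of `step // 2`
    idx_max = bin(step // 2).count("1")
    idx_min = idx_max - num_subtrees + 1
    return idx_min, idx_max


# Three sub-trees end at step 0b111 = 7; their checkpoints sit at indices 0..2
assert find_storage_indices(7) == (0, 2)
assert find_storage_indices(5) == (1, 1)
```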
/aehmc/trajectory.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, NamedTuple, Tuple
2 |
3 | import aesara
4 | import aesara.tensor as at
5 | import numpy as np
6 | from aesara.ifelse import ifelse
7 | from aesara.scan.utils import until
8 | from aesara.tensor.random.utils import RandomStream
9 | from aesara.tensor.var import TensorVariable
10 |
11 | from aehmc.integrators import IntegratorState
12 | from aehmc.proposals import (
13 | ProposalState,
14 | progressive_biased_sampling,
15 | progressive_uniform_sampling,
16 | proposal_generator,
17 | )
18 | from aehmc.termination import TerminationState
19 |
20 | __all__ = ["static_integration", "dynamic_integration", "multiplicative_expansion"]
21 |
22 |
23 | # -------------------------------------------------------------------
24 | # STATIC INTEGRATION
25 | #
26 | # This section contains algorithms that integrate the trajectory for
27 | # a set number of integrator steps.
28 | # -------------------------------------------------------------------
29 |
30 |
31 | def static_integration(
32 | integrator: Callable,
33 | num_integration_steps: int,
34 | ) -> Callable:
35 | """Build a function that generates fixed-length trajectories.
36 |
37 | Parameters
38 | ----------
39 | integrator
40 | Function that performs one integration step.
41 | num_integration_steps
42 | The number of times we need to run the integrator every time the
43 | returned function is called.
44 |
45 | Returns
46 | -------
47 |     A function that integrates the Hamiltonian dynamics
48 |     `num_integration_steps` times.
49 |
50 | """
51 |
52 | def integrate(
53 | init_state: IntegratorState, step_size: TensorVariable
54 | ) -> Tuple[IntegratorState, Dict]:
55 | """Generate a trajectory by integrating several times in one direction.
56 |
57 | Parameters
58 | ----------
59 |         init_state
60 |             The initial state of the integrator. This is an `IntegratorState`
61 |             namedtuple that contains the initial position of the chain, the
62 |             initial value of the momentum, the initial value of the potential
63 |             energy evaluated at the position, and the initial value of the
64 |             gradient of the potential energy with respect to the position,
65 |             as can be obtained with `new_integrator_state` from the
66 |             `integrators` module.
67 | step_size
68 | The size of each step taken by the integrator.
69 |
70 | Returns
71 | -------
72 |         A tuple with the last `IntegratorState` (the last position, value of
73 |         the momentum, potential energy and gradient of the potential energy
74 |         with respect to the position) as well as a dictionary that contains
75 |         the update rules for all the shared variables updated in `scan`.
76 |
77 | """
78 |
79 | def one_step(q, p, potential_energy, potential_energy_grad):
80 | new_state = integrator(
81 | IntegratorState(q, p, potential_energy, potential_energy_grad),
82 | step_size,
83 | )
84 | return new_state
85 |
86 | [q, p, energy, energy_grad], updates = aesara.scan(
87 | fn=one_step,
88 | outputs_info=[
89 | {"initial": init_state.position},
90 | {"initial": init_state.momentum},
91 | {"initial": init_state.potential_energy},
92 | {"initial": init_state.potential_energy_grad},
93 | ],
94 | n_steps=num_integration_steps,
95 | )
96 |
97 | return (
98 | IntegratorState(
99 | position=q[-1],
100 | momentum=p[-1],
101 | potential_energy=energy[-1],
102 | potential_energy_grad=energy_grad[-1],
103 | ),
104 | updates,
105 | )
106 |
107 | return integrate
108 |
109 |
110 | # -------------------------------------------------------------------
111 | # DYNAMIC INTEGRATION
112 | #
113 | # This section contains algorithms that determine the number of
114 | # integrator steps dynamically using a termination criterion that
115 | # is updated at every step.
116 | # -------------------------------------------------------------------
117 |
118 |
119 | def dynamic_integration(
120 | srng: RandomStream,
121 | integrator: Callable,
122 | kinetic_energy: Callable,
123 | update_termination_state: Callable,
124 | is_criterion_met: Callable,
125 | divergence_threshold: TensorVariable,
126 | ):
127 | """Integrate a trajectory and update the proposal sequentially in one direction
128 | until the termination criterion is met.
129 |
130 | Parameters
131 | ----------
132 | srng
133 | A RandomStream object that tracks the changes in a shared random state.
134 | integrator
135 | The symplectic integrator used to integrate the hamiltonian dynamics.
136 | kinetic_energy
137 | Function to compute the current value of the kinetic energy.
138 | update_termination_state
139 | Updates the state of the termination mechanism.
140 | is_criterion_met
141 | Determines whether the termination criterion has been met.
142 | divergence_threshold
143 | Value of the difference of energy between two consecutive states above which we say a transition is divergent.
144 |
145 | Returns
146 | -------
147 | A function that integrates the trajectory in one direction and updates a
148 | proposal until the termination criterion is met.
149 |
150 | """
151 | generate_proposal = proposal_generator(kinetic_energy, divergence_threshold)
152 | sample_proposal = progressive_uniform_sampling
153 |
154 | def integrate(
155 | previous_last_state: IntegratorState,
156 | direction: TensorVariable,
157 | termination_state: TerminationState,
158 | max_num_steps: TensorVariable,
159 | step_size: TensorVariable,
160 | initial_energy: TensorVariable,
161 | ):
162 | """Integrate the trajectory starting from `initial_state` and update
163 | the proposal sequentially until the termination criterion is met.
164 |
165 | Parameters
166 | ----------
167 | previous_last_state
168 | The last state of the previously integrated trajectory.
169 |         direction: int in {-1, 1}
170 | The direction in which to expand the trajectory.
171 | termination_state
172 | The state that keeps track of the information needed for the termination criterion.
173 | max_num_steps
174 | The maximum number of integration steps. The expansion will stop
175 | when this number is reached if the termination criterion has not
176 | been met.
177 | step_size
178 | The step size of the symplectic integrator.
179 | initial_energy
180 |             Initial energy H0 of the HMC step (not to be confused with the initial energy of the subtree)
181 |
182 | Returns
183 | -------
184 | A tuple with on the one hand: a new proposal (sampled from the states
185 | traversed while building the trajectory), the last state, the sum of the
186 | momenta values along the trajectory (needed for termination criterion),
187 | the updated termination state, the number of integration steps
188 | performed, a boolean that indicates whether the trajectory has diverged,
189 | a boolean that indicates whether the termination criterion was met. And
190 | on the other hand a dictionary that contains the update rules for the
191 | shared variables updated in the internal `Scan` operator.
192 |
193 | """
194 |
195 | def add_one_state(
196 | step,
197 | q_proposal, # current proposal
198 | p_proposal,
199 | potential_energy_proposal,
200 | potential_energy_grad_proposal,
201 | energy_proposal,
202 | weight,
203 | sum_p_accept,
204 | q_last, # state
205 | p_last,
206 | potential_energy_last,
207 | potential_energy_grad_last,
208 | momentum_sum: TensorVariable, # sum of momenta
209 | momentum_ckpts, # termination state
210 | momentum_sum_ckpts,
211 | idx_min,
212 | idx_max,
213 | trajectory_length,
214 | ):
215 | termination_state = TerminationState(
216 | momentum_checkpoints=momentum_ckpts,
217 | momentum_sum_checkpoints=momentum_sum_ckpts,
218 | min_index=idx_min,
219 | max_index=idx_max,
220 | )
221 | proposal = ProposalState(
222 | state=IntegratorState(
223 | position=q_proposal,
224 | momentum=p_proposal,
225 | potential_energy=potential_energy_proposal,
226 | potential_energy_grad=potential_energy_grad_proposal,
227 | ),
228 | energy=energy_proposal,
229 | weight=weight,
230 | sum_log_p_accept=sum_p_accept,
231 | )
232 | last_state = IntegratorState(
233 | position=q_last,
234 | momentum=p_last,
235 | potential_energy=potential_energy_last,
236 | potential_energy_grad=potential_energy_grad_last,
237 | )
238 |
239 | new_state = integrator(last_state, direction * step_size)
240 | new_proposal, is_diverging = generate_proposal(initial_energy, new_state)
241 | sampled_proposal = sample_proposal(srng, proposal, new_proposal)
242 |
243 | new_momentum_sum = momentum_sum + new_state.momentum
244 | new_termination_state = update_termination_state(
245 | termination_state, new_momentum_sum, new_state.momentum, step
246 | )
247 | has_terminated = is_criterion_met(
248 | new_termination_state, new_momentum_sum, new_state.momentum
249 | )
250 |
251 | do_stop_integrating = is_diverging | has_terminated
252 |
253 | return (
254 | sampled_proposal.state.position,
255 | sampled_proposal.state.momentum,
256 | sampled_proposal.state.potential_energy,
257 | sampled_proposal.state.potential_energy_grad,
258 | sampled_proposal.energy,
259 | sampled_proposal.weight,
260 | sampled_proposal.sum_log_p_accept,
261 | new_state.position,
262 | new_state.momentum,
263 | new_state.potential_energy,
264 | new_state.potential_energy_grad,
265 | new_momentum_sum,
266 | new_termination_state.momentum_checkpoints,
267 | new_termination_state.momentum_sum_checkpoints,
268 | new_termination_state.min_index,
269 | new_termination_state.max_index,
270 | trajectory_length + 1,
271 | is_diverging,
272 | has_terminated,
273 | ), until(do_stop_integrating)
274 |
275 | # We take one step away to start building the subtrajectory
276 | state = integrator(previous_last_state, direction * step_size)
277 | proposal, is_diverging = generate_proposal(initial_energy, state)
278 | momentum_sum = state.momentum
279 | termination_state = update_termination_state(
280 | termination_state,
281 | momentum_sum,
282 | state.momentum,
283 | 0,
284 | )
285 | full_initial_state = (
286 | proposal.state.position,
287 | proposal.state.momentum,
288 | proposal.state.potential_energy,
289 | proposal.state.potential_energy_grad,
290 | proposal.energy,
291 | proposal.weight,
292 | proposal.sum_log_p_accept,
293 | state.position,
294 | state.momentum,
295 | state.potential_energy,
296 | state.potential_energy_grad,
297 | momentum_sum,
298 | termination_state.momentum_checkpoints,
299 | termination_state.momentum_sum_checkpoints,
300 | termination_state.min_index,
301 | termination_state.max_index,
302 | at.as_tensor(1, dtype=np.int64),
303 | is_diverging,
304 | np.array(False),
305 | )
306 |
307 | steps = at.arange(1, 1 + max_num_steps)
308 | trajectory, updates = aesara.scan(
309 | add_one_state,
310 | outputs_info=(
311 | proposal.state.position,
312 | proposal.state.momentum,
313 | proposal.state.potential_energy,
314 | proposal.state.potential_energy_grad,
315 | proposal.energy,
316 | proposal.weight,
317 | proposal.sum_log_p_accept,
318 | state.position,
319 | state.momentum,
320 | state.potential_energy,
321 | state.potential_energy_grad,
322 | momentum_sum,
323 | termination_state.momentum_checkpoints,
324 | termination_state.momentum_sum_checkpoints,
325 | termination_state.min_index,
326 | termination_state.max_index,
327 | at.as_tensor(1, dtype=np.int64),
328 | None,
329 | None,
330 | ),
331 | sequences=steps,
332 | )
333 | full_last_state = tuple([_state[-1] for _state in trajectory])
334 |
335 | # We build the trajectory iff the first step is not diverging
336 | full_state = ifelse(is_diverging, full_initial_state, full_last_state)
337 |
338 | new_proposal = ProposalState(
339 | state=IntegratorState(
340 | position=full_state[0],
341 | momentum=full_state[1],
342 | potential_energy=full_state[2],
343 | potential_energy_grad=full_state[3],
344 | ),
345 | energy=full_state[4],
346 | weight=full_state[5],
347 | sum_log_p_accept=full_state[6],
348 | )
349 | new_state = IntegratorState(
350 | position=full_state[7],
351 | momentum=full_state[8],
352 | potential_energy=full_state[9],
353 | potential_energy_grad=full_state[10],
354 | )
355 | subtree_momentum_sum = full_state[11]
356 | new_termination_state = TerminationState(
357 | momentum_checkpoints=full_state[12],
358 | momentum_sum_checkpoints=full_state[13],
359 | min_index=full_state[14],
360 | max_index=full_state[15],
361 | )
362 | trajectory_length = full_state[-3]
363 | is_diverging = full_state[-2]
364 | has_terminated = full_state[-1]
365 |
366 | return (
367 | new_proposal,
368 | new_state,
369 | subtree_momentum_sum,
370 | new_termination_state,
371 | trajectory_length,
372 | is_diverging,
373 | has_terminated,
374 | ), updates
375 |
376 | return integrate
377 |
378 |
379 | class Diagnostics(NamedTuple):
380 | state: IntegratorState
381 | acceptance_probability: TensorVariable
382 | num_doublings: TensorVariable
383 | is_turning: TensorVariable
384 | is_diverging: TensorVariable
385 |
386 |
387 | class MultiplicativeExpansionResult(NamedTuple):
388 | proposals: ProposalState
389 | right_states: IntegratorState
390 | left_states: IntegratorState
391 | momentum_sums: TensorVariable
392 | termination_states: TerminationState
393 | diagnostics: Diagnostics
394 |
395 |
396 | def multiplicative_expansion(
397 | srng: RandomStream,
398 | trajectory_integrator: Callable,
399 | uturn_check_fn: Callable,
400 | max_num_expansions: TensorVariable,
401 | ):
402 | """Sample a trajectory and update the proposal sequentially
403 | until the termination criterion is met.
404 |
405 | The trajectory is sampled with the following procedure:
406 | 1. Pick a direction at random;
407 |     2. Integrate `num_steps` steps in this direction;
408 | 3. If the integration has stopped prematurely, do not update the proposal;
409 | 4. Else if the trajectory is performing a U-turn, return current proposal;
410 |     5. Else update the proposal, double `num_steps` and repeat from (1).
411 |
412 | Parameters
413 | ----------
414 | srng
415 | A RandomStream object that tracks the changes in a shared random state.
416 | trajectory_integrator
417 | A function that runs the symplectic integrators and returns a new proposal
418 | and the integrated trajectory.
419 | uturn_check_fn
420 | Function used to check the U-Turn criterion.
421 | max_num_expansions
422 | The maximum number of trajectory expansions until the proposal is
423 | returned.
424 |
425 | """
426 | proposal_sampler = progressive_biased_sampling
427 |
428 | def expand(
429 | proposal: ProposalState,
430 | left_state: IntegratorState,
431 | right_state: IntegratorState,
432 | momentum_sum,
433 | termination_state: TerminationState,
434 | initial_energy,
435 | step_size,
436 | ) -> Tuple[MultiplicativeExpansionResult, Dict]:
437 | """Expand the current trajectory multiplicatively.
438 |
439 | At each step we draw a direction at random, build a subtrajectory starting
440 | from the leftmost or rightmost point of the current trajectory that is
441 | twice as long as the current trajectory.
442 |
443 | Once that is done, possibly update the current proposal with that of
444 | the subtrajectory.
445 |
446 | Parameters
447 | ----------
448 | proposal
449 | Current new state proposal.
450 | left_state
451 | The current leftmost state of the trajectory.
452 | right_state
453 | The current rightmost state of the trajectory.
454 | momentum_sum
455 | The current value of the sum of momenta along the trajectory.
456 | initial_energy
457 | Potential energy before starting to build the trajectory.
458 | step_size
459 | The size of each step taken by the integrator.
460 |
461 | """
462 |
463 | def expand_once(
464 | step,
465 | q_proposal, # proposal
466 | p_proposal,
467 | potential_energy_proposal,
468 | potential_energy_grad_proposal,
469 | energy_proposal,
470 | weight,
471 | sum_p_accept,
472 | q_left, # trajectory
473 | p_left,
474 | potential_energy_left,
475 | potential_energy_grad_left,
476 | q_right,
477 | p_right,
478 | potential_energy_right,
479 | potential_energy_grad_right,
480 | momentum_sum, # sum of momenta along trajectory
481 | momentum_ckpts, # termination_state
482 | momentum_sum_ckpts,
483 | idx_min,
484 | idx_max,
485 | ) -> Tuple[Tuple[TensorVariable, ...], Dict, until]:
486 | left_state = (
487 | q_left,
488 | p_left,
489 | potential_energy_left,
490 | potential_energy_grad_left,
491 | )
492 | right_state = (
493 | q_right,
494 | p_right,
495 | potential_energy_right,
496 | potential_energy_grad_right,
497 | )
498 | proposal = ProposalState(
499 | state=IntegratorState(
500 | position=q_proposal,
501 | momentum=p_proposal,
502 | potential_energy=potential_energy_proposal,
503 | potential_energy_grad=potential_energy_grad_proposal,
504 | ),
505 | energy=energy_proposal,
506 | weight=weight,
507 | sum_log_p_accept=sum_p_accept,
508 | )
509 | termination_state = TerminationState(
510 | momentum_checkpoints=momentum_ckpts,
511 | momentum_sum_checkpoints=momentum_sum_ckpts,
512 | min_index=idx_min,
513 | max_index=idx_max,
514 | )
515 |
516 | do_go_right = srng.bernoulli(0.5)
517 | direction = at.where(do_go_right, 1.0, -1.0)
518 | start_state = IntegratorState(*ifelse(do_go_right, right_state, left_state))
519 |
520 | (
521 | new_proposal,
522 | new_state,
523 | subtree_momentum_sum,
524 | new_termination_state,
525 | subtrajectory_length,
526 | is_diverging,
527 | has_subtree_terminated,
528 | ), inner_updates = trajectory_integrator(
529 | start_state,
530 | direction,
531 | termination_state,
532 | 2**step,
533 | step_size,
534 | initial_energy,
535 | )
536 |
537 | # Update the trajectory.
538 | # The trajectory integrator always integrates forward in time; we
539 | # thus need to switch the states if the other direction was picked.
540 | new_left_state = IntegratorState(
541 | *ifelse(do_go_right, left_state, new_state)
542 | )
543 | new_right_state = IntegratorState(
544 | *ifelse(do_go_right, new_state, right_state)
545 | )
546 | new_momentum_sum = momentum_sum + subtree_momentum_sum
547 |
548 | # Compute the pseudo-acceptance probability for the NUTS algorithm.
549 |             # It can be understood as the average acceptance probability that a
550 |             # Metropolis-Hastings step would give to the states explored during the final expansion.
551 | acceptance_probability = (
552 | at.exp(new_proposal.sum_log_p_accept) / subtrajectory_length
553 | )
554 |
555 | # Update the proposal
556 | #
557 | # We do not accept proposals that come from diverging or turning subtrajectories.
558 | # However the definition of the acceptance probability is such that the
559 | # acceptance probability needs to be computed across the entire trajectory.
560 | updated_proposal = proposal._replace(
561 | sum_log_p_accept=at.logaddexp(
562 | new_proposal.sum_log_p_accept, proposal.sum_log_p_accept
563 | )
564 | )
565 |
566 | sampled_proposal = where_proposal(
567 | is_diverging | has_subtree_terminated,
568 | updated_proposal,
569 | proposal_sampler(srng, proposal, new_proposal),
570 | )
571 |
572 | # Check if the trajectory is turning and determine whether we need
573 | # to stop expanding the trajectory.
574 | is_turning = uturn_check_fn(
575 | new_left_state.momentum, new_right_state.momentum, new_momentum_sum
576 | )
577 | do_stop_expanding = is_diverging | is_turning | has_subtree_terminated
578 |
579 | return (
580 | (
581 | sampled_proposal.state.position,
582 | sampled_proposal.state.momentum,
583 | sampled_proposal.state.potential_energy,
584 | sampled_proposal.state.potential_energy_grad,
585 | sampled_proposal.energy,
586 | sampled_proposal.weight,
587 | sampled_proposal.sum_log_p_accept,
588 | new_left_state.position,
589 | new_left_state.momentum,
590 | new_left_state.potential_energy,
591 | new_left_state.potential_energy_grad,
592 | new_right_state.position,
593 | new_right_state.momentum,
594 | new_right_state.potential_energy,
595 | new_right_state.potential_energy_grad,
596 | new_momentum_sum,
597 | new_termination_state.momentum_checkpoints,
598 | new_termination_state.momentum_sum_checkpoints,
599 | new_termination_state.min_index,
600 | new_termination_state.max_index,
601 | acceptance_probability,
602 | step + 1,
603 | is_diverging,
604 | is_turning,
605 | ),
606 | inner_updates,
607 | until(do_stop_expanding),
608 | )
609 |
610 | expansion_steps = at.arange(0, max_num_expansions)
611 |
612 | (
613 | proposal_state_position,
614 | proposal_state_momentum,
615 | proposal_state_potential_energy,
616 | proposal_state_potential_energy_grad,
617 | proposal_energy,
618 | proposal_weight,
619 | proposal_sum_log_p_accept,
620 | left_state_position,
621 | left_state_momentum,
622 | left_state_potential_energy,
623 | left_state_potential_energy_grad,
624 | right_state_position,
625 | right_state_momentum,
626 | right_state_potential_energy,
627 | right_state_potential_energy_grad,
628 | momentum_sum,
629 | momentum_checkpoints,
630 | momentum_sum_checkpoints,
631 | min_indices,
632 | max_indices,
633 | acceptance_probability,
634 | num_doublings,
635 | is_diverging,
636 | is_turning,
637 | ), updates = aesara.scan(
638 | expand_once,
639 | outputs_info=(
640 | proposal.state.position,
641 | proposal.state.momentum,
642 | proposal.state.potential_energy,
643 | proposal.state.potential_energy_grad,
644 | proposal.energy,
645 | proposal.weight,
646 | proposal.sum_log_p_accept,
647 | left_state.position,
648 | left_state.momentum,
649 | left_state.potential_energy,
650 | left_state.potential_energy_grad,
651 | right_state.position,
652 | right_state.momentum,
653 | right_state.potential_energy,
654 | right_state.potential_energy_grad,
655 | momentum_sum,
656 | termination_state.momentum_checkpoints,
657 | termination_state.momentum_sum_checkpoints,
658 | termination_state.min_index,
659 | termination_state.max_index,
660 | None,
661 | None,
662 | None,
663 | None,
664 | ),
665 | sequences=expansion_steps,
666 | )
667 | # Ensure each item of the returned result sequence is packed into the appropriate namedtuples.
668 | typed_result = MultiplicativeExpansionResult(
669 | proposals=ProposalState(
670 | state=IntegratorState(
671 | position=proposal_state_position,
672 | momentum=proposal_state_momentum,
673 | potential_energy=proposal_state_potential_energy,
674 | potential_energy_grad=proposal_state_potential_energy_grad,
675 | ),
676 | energy=proposal_energy,
677 | weight=proposal_weight,
678 | sum_log_p_accept=proposal_sum_log_p_accept,
679 | ),
680 | left_states=IntegratorState(
681 | position=left_state_position,
682 | momentum=left_state_momentum,
683 | potential_energy=left_state_potential_energy,
684 | potential_energy_grad=left_state_potential_energy_grad,
685 | ),
686 | right_states=IntegratorState(
687 | position=right_state_position,
688 | momentum=right_state_momentum,
689 | potential_energy=right_state_potential_energy,
690 | potential_energy_grad=right_state_potential_energy_grad,
691 | ),
692 | momentum_sums=momentum_sum,
693 | termination_states=TerminationState(
694 | momentum_checkpoints=momentum_checkpoints,
695 | momentum_sum_checkpoints=momentum_sum_checkpoints,
696 | min_index=min_indices,
697 | max_index=max_indices,
698 | ),
699 | diagnostics=Diagnostics(
700 | state=IntegratorState(
701 | position=proposal_state_position,
702 | momentum=proposal_state_momentum,
703 | potential_energy=proposal_state_potential_energy,
704 | potential_energy_grad=proposal_state_potential_energy_grad,
705 | ),
706 | acceptance_probability=acceptance_probability,
707 | num_doublings=num_doublings,
708 | is_turning=is_turning,
709 | is_diverging=is_diverging,
710 | ),
711 | )
712 | return typed_result, updates
713 |
714 | return expand
715 |
716 |
717 | def where_proposal(
718 | do_pick_left: bool,
719 | left_proposal: ProposalState,
720 | right_proposal: ProposalState,
721 | ) -> ProposalState:
722 | """Represents a switch between two proposals depending on a condition."""
723 | state = ifelse(do_pick_left, left_proposal.state, right_proposal.state)
724 | energy = at.where(do_pick_left, left_proposal.energy, right_proposal.energy)
725 | weight = at.where(do_pick_left, left_proposal.weight, right_proposal.weight)
726 | log_sum_p_accept = ifelse(
727 | do_pick_left, left_proposal.sum_log_p_accept, right_proposal.sum_log_p_accept
728 | )
729 |
730 | return ProposalState(
731 | state=IntegratorState(*state),
732 | energy=energy,
733 | weight=weight,
734 | sum_log_p_accept=log_sum_p_accept,
735 | )
736 |
--------------------------------------------------------------------------------
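A small sketch (not part of the repository) that wires `static_integration` to the velocity Verlet integrator for a quadratic potential; the number of steps and the step size are illustrative:

```python
import aesara
import aesara.tensor as at

from aehmc import integrators
from aehmc.trajectory import static_integration


def potential_fn(q):
    return 0.5 * at.square(q).sum()


def kinetic_fn(p):
    return 0.5 * at.square(p).sum()


integrator = integrators.velocity_verlet(potential_fn, kinetic_fn)
integrate = static_integration(integrator, num_integration_steps=10)

q, p = at.vector("q"), at.vector("p")
init_state = integrators.new_integrator_state(potential_fn, q, p)
last_state, updates = integrate(init_state, at.as_tensor(0.1))

# Compile a function that returns the endpoint of the fixed-length trajectory
run = aesara.function([q, p], list(last_state), updates=updates)
```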
/aehmc/utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from typing import Dict, Iterable, List
3 |
4 | import aesara.tensor as at
5 | from aesara.graph.basic import Variable, ancestors
6 | from aesara.graph.fg import FunctionGraph
7 | from aesara.graph.rewriting.utils import rewrite_graph
8 | from aesara.tensor.rewriting.shape import ShapeFeature
9 | from aesara.tensor.var import TensorVariable
10 |
11 |
12 | def simplify_shapes(graphs: List[Variable]):
13 | """Simply the shape calculations in a list of graphs."""
14 | shape_fg = FunctionGraph(
15 | outputs=graphs,
16 | features=[ShapeFeature()],
17 | clone=False,
18 | )
19 | return rewrite_graph(shape_fg).outputs
20 |
21 |
22 | class RaveledParamsMap:
23 | """Maps a set of tensor variables to a vector of their raveled values."""
24 |
25 | def __init__(self, ref_params: Iterable[TensorVariable]):
26 | self.ref_params = tuple(ref_params)
27 |
28 | self.ref_shapes = [at.shape(p) for p in self.ref_params]
29 | self.ref_shapes = simplify_shapes(self.ref_shapes)
30 |
31 | self.ref_dtypes = [p.dtype for p in self.ref_params]
32 |
33 | ref_shapes_ancestors = set(ancestors(self.ref_shapes))
34 | uninferred_shape_params = [
35 | p for p in self.ref_params if (p in ref_shapes_ancestors)
36 | ]
37 |         if uninferred_shape_params:
38 | # After running the shape optimizations, the graphs in
39 | # `ref_shapes` should not depend on `ref_params` directly.
40 | # If they do, it could imply that we need to sample parts of a
41 | # model in order to get the shapes/sizes of its parameters, and
42 | # that's a worst-case scenario.
43 | warnings.warn(
44 | "The following parameters need to be computed in order to determine "
45 | f"the shapes in this parameter map: {uninferred_shape_params}"
46 | )
47 |
48 | param_sizes = [at.prod(s) for s in self.ref_shapes]
49 | cumsum_sizes = at.cumsum(param_sizes)
50 | # `at.cumsum` doesn't return a tensor of a fixed/known size
51 | cumsum_sizes = [cumsum_sizes[i] for i in range(len(param_sizes))]
52 | self.slice_indices = list(zip([0] + cumsum_sizes[:-1], cumsum_sizes))
53 | self.vec_slices = [slice(*idx) for idx in self.slice_indices]
54 |
55 | def ravel_params(self, params: List[TensorVariable]) -> TensorVariable:
56 | """Concatenate the raveled vectors of each parameter."""
57 | return at.concatenate([at.atleast_1d(p).ravel() for p in params])
58 |
59 | def unravel_params(
60 | self, raveled_params: TensorVariable
61 | ) -> Dict[TensorVariable, TensorVariable]:
62 | """Unravel a concatenated set of raveled parameters."""
63 | return {
64 | k: v.reshape(s).astype(t)
65 | for k, v, s, t in zip(
66 | self.ref_params,
67 | [raveled_params[slc] for slc in self.vec_slices],
68 | self.ref_shapes,
69 | self.ref_dtypes,
70 | )
71 | }
72 |
73 | def __repr__(self):
74 | return f"{type(self).__name__}({self.ref_params})"
75 |
--------------------------------------------------------------------------------
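
A minimal usage sketch for `RaveledParamsMap`, modeled on `tests/test_utils.py`; the random variables here are purely illustrative:

```python
import aesara.tensor as at

from aehmc.utils import RaveledParamsMap

# Two parameters whose shapes are statically inferable
tau_rv = at.random.invgamma(0.5, 0.5, name="tau")             # scalar
beta_rv = at.random.normal(0.0, 1.0, size=(3,), name="beta")  # vector

rp_map = RaveledParamsMap((tau_rv, beta_rv))
q = rp_map.ravel_params((tau_rv, beta_rv))  # flat vector of length 1 + 3
parts = rp_map.unravel_params(q)            # {tau_rv: ..., beta_rv: ...}
assert set(parts) == {tau_rv, beta_rv}
```

A sampler can then operate on the single vector `q` and map back to the original shapes and dtypes afterwards.
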
/aehmc/window_adaptation.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple
2 |
3 | import aesara
4 | import aesara.tensor as at
5 | from aesara import config
6 | from aesara.ifelse import ifelse
7 | from aesara.tensor.shape import shape_tuple
8 | from aesara.tensor.var import TensorVariable
9 |
10 | from aehmc.algorithms import DualAveragingState
11 | from aehmc.integrators import IntegratorState
12 | from aehmc.mass_matrix import covariance_adaptation
13 | from aehmc.nuts import Diagnostics
14 | from aehmc.step_size import dual_averaging_adaptation
15 |
16 |
17 | def run(
18 | kernel,
19 | initial_state: IntegratorState,
20 | num_steps=1000,
21 | *,
22 | is_mass_matrix_full=False,
23 | initial_step_size=at.as_tensor(1.0, dtype=config.floatX),
24 | target_acceptance_rate=0.80,
25 | ) -> Tuple[IntegratorState, Tuple[TensorVariable, TensorVariable], Dict]:
26 | init_adapt, update_adapt = window_adaptation(
27 | num_steps, is_mass_matrix_full, initial_step_size, target_acceptance_rate
28 | )
29 |
30 | def one_step(
31 | warmup_step,
32 | q, # chain state
33 | potential_energy,
34 | potential_energy_grad,
35 | step, # dual averaging adaptation state
36 | log_step_size,
37 | log_step_size_avg,
38 | gradient_avg,
39 | mu,
40 | mean, # mass matrix adaptation state
41 | m2,
42 | sample_size,
43 | step_size, # parameters
44 | inverse_mass_matrix,
45 | ):
46 | chain_state = IntegratorState(
47 | position=q,
48 | momentum=None,
49 | potential_energy=potential_energy,
50 | potential_energy_grad=potential_energy_grad,
51 | )
52 |
53 | warmup_state = (
54 | DualAveragingState(
55 | step=step,
56 | iterates=log_step_size,
57 | iterates_avg=log_step_size_avg,
58 | gradient_avg=gradient_avg,
59 | shrinkage_pts=mu,
60 | ),
61 | (mean, m2, sample_size),
62 | )
63 | parameters = (step_size, inverse_mass_matrix)
64 |
65 | # Advance the chain by one step
66 | chain_info, inner_updates = kernel(chain_state, *parameters)
67 |
68 | # Update the warmup state and parameters
69 | warmup_state, parameters = update_adapt(
70 | warmup_step, warmup_state, parameters, chain_info
71 | )
72 | da_state = warmup_state[0]
73 | return (
74 | chain_info.state.position, # q
75 | chain_info.state.potential_energy, # potential_energy
76 | chain_info.state.potential_energy_grad, # potential_energy_grad
77 | da_state.step,
78 | da_state.iterates, # log_step_size
79 | da_state.iterates_avg, # log_step_size_avg
80 | da_state.gradient_avg,
81 | da_state.shrinkage_pts, # mu
82 | *warmup_state[1],
83 | *parameters,
84 | ), inner_updates
85 |
86 | (da_state, mm_state), parameters = init_adapt(initial_state)
87 |
88 | warmup_steps = at.arange(0, num_steps)
89 | state, updates = aesara.scan(
90 | fn=one_step,
91 | outputs_info=(
92 | initial_state.position,
93 | initial_state.potential_energy,
94 | initial_state.potential_energy_grad,
95 | da_state.step,
96 | da_state.iterates, # log_step_size
97 | da_state.iterates_avg, # log_step_size_avg
98 | da_state.gradient_avg,
99 | da_state.shrinkage_pts, # mu
100 | *mm_state,
101 | *parameters,
102 | ),
103 | sequences=(warmup_steps,),
104 | name="window_adaptation",
105 | )
106 |
107 | last_chain_state = IntegratorState(
108 | position=state[0][-1],
109 | momentum=None,
110 | potential_energy=state[1][-1],
111 | potential_energy_grad=state[2][-1],
112 | )
113 | step_size = state[-2][-1]
114 | inverse_mass_matrix = state[-1][-1]
115 |
116 | return last_chain_state, (step_size, inverse_mass_matrix), updates
117 |
118 |
119 | def window_adaptation(
120 | num_steps: int,
121 | is_mass_matrix_full: bool = False,
122 | initial_step_size: TensorVariable = at.as_tensor(1.0, dtype=config.floatX),
123 | target_acceptance_rate: TensorVariable = 0.80,
124 | ):
125 | mm_init, mm_update, mm_final = covariance_adaptation(is_mass_matrix_full)
126 | da_init, da_update = dual_averaging_adaptation(target_acceptance_rate)
127 | schedule = build_schedule(num_steps)
128 |
129 | schedule_stage = at.as_tensor([s[0] for s in schedule])
130 | schedule_middle_window = at.as_tensor([s[1] for s in schedule])
131 |
132 | def init(initial_chain_state: IntegratorState):
133 | if initial_chain_state.position.ndim == 0:
134 | num_dims = 0
135 | else:
136 | num_dims = shape_tuple(initial_chain_state.position)[0]
137 | inverse_mass_matrix, mm_state = mm_init(num_dims)
138 |
139 | da_state = da_init(initial_step_size)
140 | step_size = at.exp(da_state.iterates)
141 |
142 | warmup_state = (da_state, mm_state)
143 | parameters = (step_size, inverse_mass_matrix)
144 | return warmup_state, parameters
145 |
146 | def fast_update(p_accept, warmup_state, parameters):
147 | da_state, mm_state = warmup_state
148 | _, inverse_mass_matrix = parameters
149 |
150 | new_da_state = da_update(p_accept, da_state)
151 | step_size = at.exp(new_da_state.iterates)
152 |
153 | return (new_da_state, mm_state), (step_size, inverse_mass_matrix)
154 |
155 | def slow_update(position, p_accept, warmup_state, parameters):
156 | da_state, mm_state = warmup_state
157 | _, inverse_mass_matrix = parameters
158 |
159 | new_da_state = da_update(p_accept, da_state)
160 | new_mm_state = mm_update(position, mm_state)
161 | step_size = at.exp(new_da_state.iterates)
162 |
163 | return (new_da_state, new_mm_state), (step_size, inverse_mass_matrix)
164 |
165 | def slow_final(warmup_state):
166 | """We recompute the inverse mass matrix and re-initialize the dual averaging scheme at the end of each 'slow window'."""
167 | da_state, mm_state = warmup_state
168 |
169 | inverse_mass_matrix = mm_final(mm_state)
170 |
171 | if inverse_mass_matrix.ndim == 0:
172 | num_dims = 0
173 | else:
174 | num_dims = shape_tuple(inverse_mass_matrix)[0]
175 | _, new_mm_state = mm_init(num_dims)
176 |
177 | step_size = at.exp(da_state.iterates)
178 | new_da_state = da_init(step_size)
179 |
180 | warmup_state = (new_da_state, new_mm_state)
181 | parameters = (step_size, inverse_mass_matrix)
182 | return warmup_state, parameters
183 |
184 | def final(
185 | warmup_state: Tuple, parameters: Tuple
186 | ) -> Tuple[TensorVariable, TensorVariable]:
187 | da_state, _ = warmup_state
188 | _, inverse_mass_matrix = parameters
189 | step_size = at.exp(da_state.iterates_avg) # return stepsize_avg at the end
190 | return step_size, inverse_mass_matrix
191 |
192 | def update(
193 | step: int, warmup_state: Tuple, parameters: Tuple, chain_state: Diagnostics
194 | ):
195 | stage = schedule_stage[step]
196 | warmup_state, parameters = where_warmup_state(
197 | at.eq(stage, 0),
198 | fast_update(chain_state.acceptance_probability, warmup_state, parameters),
199 | slow_update(
200 | chain_state.state.position,
201 | chain_state.acceptance_probability,
202 | warmup_state,
203 | parameters,
204 | ),
205 | )
206 |
207 | is_middle_window_end = schedule_middle_window[step]
208 | warmup_state, parameters = where_warmup_state(
209 | is_middle_window_end, slow_final(warmup_state), (warmup_state, parameters)
210 | )
211 |
212 | is_last_step = at.eq(step, num_steps - 1)
213 | parameters = ifelse(is_last_step, final(warmup_state, parameters), parameters)
214 |
215 | return warmup_state, parameters
216 |
217 | def where_warmup_state(do_pick_left, left_warmup_state, right_warmup_state):
218 | (left_da_state, left_mm_state), left_params = left_warmup_state
219 | (right_da_state, right_mm_state), right_params = right_warmup_state
220 |
221 | da_state = ifelse(do_pick_left, left_da_state, right_da_state)
222 | mm_state = ifelse(do_pick_left, left_mm_state, right_mm_state)
223 | params = ifelse(do_pick_left, left_params, right_params)
224 |
225 | return (DualAveragingState(*da_state), mm_state), params
226 |
227 | return init, update
228 |
229 |
230 | def build_schedule(
231 | num_steps: int,
232 | initial_buffer_size: int = 75,
233 | final_buffer_size: int = 50,
234 | first_window_size: int = 25,
235 | ) -> List[Tuple[int, bool]]:
236 | """Return the schedule for Stan's warmup.
237 |
238 |     The schedule below is intended to be as close as possible to Stan's [1]_.
239 | The warmup period is split into three stages:
240 | 1. An initial fast interval to reach the typical set. Only the step size is
241 | adapted in this window.
242 | 2. "Slow" parameters that require global information (typically covariance)
243 | are estimated in a series of expanding intervals with no memory; the step
244 | size is re-initialized at the end of each window. Each window is twice the
245 | size of the preceding window.
246 | 3. A final fast interval during which the step size is adapted using the
247 | computed mass matrix.
248 | Schematically:
249 |
250 | ```
251 | +---------+---+------+------------+------------------------+------+
252 | | fast | s | slow | slow | slow | fast |
253 | +---------+---+------+------------+------------------------+------+
254 | ```
255 |
256 | The distinction slow/fast comes from the speed at which the algorithms
257 | converge to a stable value; in the common case, estimation of covariance
258 |     requires more steps than dual averaging to give an accurate value. See [1]_
259 | for a more detailed explanation.
260 |
261 | Fast intervals are given the label 0 and slow intervals the label 1.
262 |
263 | Note
264 | ----
265 | It feels awkward to return a boolean that indicates whether the current
266 | step is the last step of a middle window, but not for other windows. This
267 | should probably be changed to "is_window_end" and we should manage the
268 | distinction upstream.
269 |
270 | Parameters
271 | ----------
272 | num_steps: int
273 | The number of warmup steps to perform.
274 |     initial_buffer_size: int
275 |         The width of the initial fast adaptation interval.
276 |     final_buffer_size: int
277 |         The width of the final fast adaptation interval.
278 |     first_window_size: int
279 |         The width of the first slow adaptation interval.
280 |
281 | Returns
282 | -------
283 | A list of tuples (window_label, is_middle_window_end).
284 |
285 | References
286 | ----------
287 | .. [1]: Stan Reference Manual v2.22 Section 15.2 "HMC Algorithm"
288 |
289 | """
290 | schedule = []
291 |
292 | # Give up on mass matrix adaptation when the number of warmup steps is too small.
293 | if num_steps < 20:
294 | schedule += [(0, False)] * num_steps
295 | else:
296 |         # When the number of warmup steps is smaller than the sum of the provided
297 |         # (or default) window sizes we need to resize the different windows.
298 | if initial_buffer_size + first_window_size + final_buffer_size > num_steps:
299 | initial_buffer_size = int(0.15 * num_steps)
300 | final_buffer_size = int(0.1 * num_steps)
301 | first_window_size = num_steps - initial_buffer_size - final_buffer_size
302 |
303 | # First stage: adaptation of fast parameters
304 | schedule += [(0, False)] * (initial_buffer_size - 1)
305 | schedule.append((0, False))
306 |
307 | # Second stage: adaptation of slow parameters in successive windows
308 | # doubling in size.
309 | final_buffer_start = num_steps - final_buffer_size
310 |
311 | next_window_size = first_window_size
312 | next_window_start = initial_buffer_size
313 | while next_window_start < final_buffer_start:
314 | current_start, current_size = next_window_start, next_window_size
315 | if 3 * current_size <= final_buffer_start - current_start:
316 | next_window_size = 2 * current_size
317 | else:
318 | current_size = final_buffer_start - current_start
319 | next_window_start = current_start + current_size
320 | schedule += [(1, False)] * (next_window_start - 1 - current_start)
321 | schedule.append((1, True))
322 |
323 | # Last stage: adaptation of fast parameters
324 | schedule += [(0, False)] * (num_steps - 1 - final_buffer_start)
325 | schedule.append((0, False))
326 |
327 | return schedule
328 |
--------------------------------------------------------------------------------
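
For intuition, `build_schedule` returns one `(window_label, is_middle_window_end)` pair per warmup step. When the requested number of steps is smaller than the default 75 + 25 + 50 buffer sizes, the buffers are resized to roughly 15% / 75% / 10% of `num_steps`. The resized case, as exercised in `tests/test_adaptation.py`:

```python
from aehmc.window_adaptation import build_schedule

schedule = build_schedule(100)  # 100 < 75 + 25 + 50, so windows are resized
assert schedule == (
    [(0, False)] * 15      # initial fast window: int(0.15 * 100) steps
    + [(1, False)] * 74    # a single slow window, ending with...
    + [(1, True)]          # ...a middle-window-end marker
    + [(0, False)] * 10    # final fast window: int(0.1 * 100) steps
)
```
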
/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def pytest_sessionstart(session):
5 | os.environ["AESARA_FLAGS"] = ",".join(
6 | [
7 | os.environ.setdefault("AESARA_FLAGS", ""),
8 | "floatX=float64,on_opt_error=raise,on_shape_error=raise,cast_policy=numpy+floatX",
9 | ]
10 | )
11 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | # To use:
2 | #
3 | # $ conda env create -f environment.yml # `mamba` works too for this command
4 | # $ conda activate aehmc-dev
5 | #
6 | name: aehmc-dev
7 | channels:
8 | - conda-forge
9 | dependencies:
10 | - python>=3.8
11 | - compilers
12 | - numpy>=1.18.1
13 | - scipy>=1.4.0
14 | - aesara>=2.8.11
15 | - aeppl>=0.1.4
16 | # Intel BLAS
17 | - mkl
18 | - mkl-service
19 | - libblas=*=*mkl
20 | # For testing
21 | - pytest
22 | - coverage>=5.1
23 | - coveralls
24 | - pytest-cov
25 | - pytest-xdist
26 | # For building docs
27 | - sphinx>=1.3
28 | - sphinx_rtd_theme
29 | - pygments
30 | - pydot
31 | - ipython
32 | # developer tools
33 | - pre-commit
34 | - packaging
35 | - typing_extensions
36 |
--------------------------------------------------------------------------------
/examples/requirements.txt:
--------------------------------------------------------------------------------
1 | aeppl
2 | matplotlib
3 | pymc3
4 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["wheel", "setuptools>=61.2", "setuptools-scm[toml]>=6.2"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "aehmc"
7 | authors = [
8 | {name = "Aesara developers", email = "aesara.devs@gmail.com"}
9 | ]
10 | description="HMC samplers in Aesara"
11 | readme = "README.md"
12 | license = {text = "MIT License"}
13 | dynamic = ["version"]
14 | requires-python = ">=3.8"
15 | dependencies = [
16 | "numpy >= 1.18.1",
17 | "scipy >= 1.4.0",
18 | "aesara >= 2.8.11",
19 | "aeppl >= 0.1.4",
20 | ]
21 | classifiers = [
22 | "Development Status :: 4 - Beta",
23 | "Intended Audience :: Science/Research",
24 | "Intended Audience :: Developers",
25 | "License :: OSI Approved :: MIT License",
26 | "Operating System :: OS Independent",
27 | "Programming Language :: Python",
28 | "Programming Language :: Python :: 3",
29 | "Programming Language :: Python :: 3.8",
30 | "Programming Language :: Python :: 3.9",
31 | "Programming Language :: Python :: 3.10",
32 | "Programming Language :: Python :: 3.11",
33 | "Programming Language :: Python :: Implementation :: CPython",
34 | "Programming Language :: Python :: Implementation :: PyPy",
35 | ]
36 | keywords = [
37 | "aesara",
38 | "math",
39 | "symbolic",
40 | "hamiltonian monte carlo",
41 | "nuts sampler",
42 | "No U-turn sampler",
43 | "symplectic integration",
44 | ]
45 |
46 | [project.urls]
47 | source = "http://github.com/aesara-devs/aehmc"
48 | tracker = "http://github.com/aesara-devs/aehmc/issues"
49 |
50 | [tool.setuptools]
51 | packages = ["aehmc"]
52 | include-package-data = false
53 |
54 | [tool.setuptools_scm]
55 | write_to = "aehmc/_version.py"
56 |
57 | [tool.pydocstyle]
58 | # Ignore errors for missing docstrings.
59 | # Ignore D202 (No blank lines allowed after function docstring)
60 | # due to bug in black: https://github.com/ambv/black/issues/355
61 | add-ignore = "D100,D101,D102,D103,D104,D105,D106,D107,D202"
62 | convention = "numpy"
63 |
64 | [tool.pytest.ini_options]
65 | python_files = ["test*.py"]
66 | testpaths = ["tests"]
67 | filterwarnings = [
68 | "error:::aesara",
69 | "error:::aeppl",
70 | "error:::aemcmc",
71 | "ignore:::xarray",
72 | ]
73 |
74 | [tool.coverage.run]
75 | omit = [
76 | "aehmc/_version.py",
77 | "tests/*",
78 | ]
79 | branch = true
80 |
81 | [tool.coverage.report]
82 | exclude_lines = [
83 | "pragma: no cover",
84 | "def __repr__",
85 | "raise AssertionError",
86 | "raise TypeError",
87 | "return NotImplemented",
88 | "raise NotImplementedError",
89 | "if __name__ == .__main__.:",
90 | "assert False",
91 | ]
92 | show_missing = true
93 |
94 | [tool.isort]
95 | profile = "black"
96 |
97 | [tool.pylint]
98 | max-line-length = "88"
99 |
100 | [tool.pylint."messages_control"]
101 | disable = "C0330, C0326"
102 |
103 | [tool.mypy]
104 | ignore_missing_imports = true
105 | no_implicit_optional = true
106 | check_untyped_defs = true
107 | strict_equality = true
108 | warn_redundant_casts = true
109 | warn_unused_ignores = true
110 | show_error_codes = true
111 |
112 | [[tool.mypy.overrides]]
113 | module = ["tests.*"]
114 | ignore_errors = true
115 | check_untyped_defs = false
116 |
117 | [tool.black]
118 | line-length = 88
119 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | arviz
2 | -e ./
3 | coveralls
4 | pydocstyle>=3.0.0
5 | pytest>=5.0.0
6 | pytest-cov>=2.6.1
7 | pytest-html>=1.20.0
8 | pylint>=2.3.1
9 | black==20.8b1; platform_python_implementation!='PyPy'
10 | diff-cover
11 | autoflake
12 | pre-commit
13 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aesara-devs/aehmc/ece0aafa7a62773ff7403827f781fb7453de1cbb/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_adaptation.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from aehmc import window_adaptation
4 |
5 |
6 | @pytest.mark.parametrize(
7 | "num_steps, expected_schedule",
8 | [
9 | (19, [(0, False)] * 19), # no mass matrix adaptation
10 | (
11 | 100,
12 | [(0, False)] * 15 + [(1, False)] * 74 + [(1, True)] + [(0, False)] * 10,
13 | ), # windows are resized
14 | (
15 | 200,
16 | [(0, False)] * 75
17 | + [(1, False)] * 24
18 | + [(1, True)]
19 | + [(1, False)] * 49
20 | + [(1, True)]
21 | + [(0, False)] * 50,
22 | ),
23 | ],
24 | )
25 | def test_adaptation_schedule(num_steps, expected_schedule):
26 | adaptation_schedule = window_adaptation.build_schedule(num_steps)
27 | assert num_steps == len(adaptation_schedule)
28 | assert adaptation_schedule == expected_schedule
29 |
--------------------------------------------------------------------------------
/tests/test_algorithms.py:
--------------------------------------------------------------------------------
1 | import aesara
2 | import aesara.tensor as at
3 | import numpy as np
4 | import pytest
5 | from aesara import config
6 |
7 | from aehmc import algorithms
8 |
9 |
10 | def test_dual_averaging():
11 | """Find the minimum of a simple function using Dual Averaging."""
12 |
13 | def fn(x):
14 | return (x - 1) ** 2
15 |
16 | init, update = algorithms.dual_averaging(gamma=0.5)
17 |
18 | def one_step(step, x, x_avg, gradient_avg, mu):
19 | value = fn(x)
20 | gradient = aesara.grad(value, x)
21 | current_state = algorithms.DualAveragingState(
22 | step=step,
23 | iterates=x,
24 | iterates_avg=x_avg,
25 | gradient_avg=gradient_avg,
26 | shrinkage_pts=mu,
27 | )
28 | da_state = update(gradient, current_state)
29 | return (
30 | da_state.step,
31 | da_state.iterates,
32 | da_state.iterates_avg,
33 | da_state.gradient_avg,
34 | da_state.shrinkage_pts,
35 | )
36 |
37 | shrinkage_pts = at.as_tensor(at.constant(0.5), dtype=config.floatX)
38 | da_state = init(shrinkage_pts)
39 |
40 | states, updates = aesara.scan(
41 | fn=one_step,
42 | outputs_info=[
43 | {"initial": da_state.step},
44 | {"initial": da_state.iterates},
45 | {"initial": da_state.iterates_avg},
46 | {"initial": da_state.gradient_avg},
47 | {"initial": da_state.shrinkage_pts},
48 | ],
49 | n_steps=100,
50 | )
51 |
52 | last_x = states[1].eval()[-1]
53 | last_x_avg = states[2].eval()[-1]
54 | assert last_x_avg == pytest.approx(1.0, 1e-2)
55 | assert last_x == pytest.approx(1.0, 1e-2)
56 |
57 |
58 | @pytest.mark.parametrize("num_dims", [0, 1, 3])
59 | @pytest.mark.parametrize("do_compute_covariance", [True, False])
60 | def test_welford_constant(num_dims, do_compute_covariance):
61 | num_samples = 10
62 |
63 | if num_dims > 0:
64 | sample = at.ones((num_dims,)) # constant samples
65 | else:
66 | sample = at.constant(1.0)
67 |
68 | init, update, final = algorithms.welford_covariance(do_compute_covariance)
69 | state = init(num_dims)
70 | for _ in range(num_samples):
71 | state = update(sample, *state)
72 |
73 | mean = state[0].eval()
74 | if num_dims > 0:
75 | expected = np.ones(num_dims)
76 | assert np.shape(mean) == np.shape(expected)
77 | np.testing.assert_allclose(mean, expected, rtol=1e-1)
78 | else:
79 | assert np.ndim(mean) == 0
80 | assert mean == 1.0
81 |
82 | cov = final(state[1], state[2]).eval()
83 | if num_dims > 0:
84 | if do_compute_covariance:
85 | expected = np.zeros((num_dims, num_dims))
86 | else:
87 | expected = np.zeros(num_dims)
88 |
89 | assert np.shape(cov) == np.shape(expected)
90 | np.testing.assert_allclose(cov, expected)
91 | else:
92 | assert np.ndim(cov) == 0
93 | assert cov == 0
94 |
95 |
96 | @pytest.mark.parametrize("do_compute_covariance", [True, False])
97 | @pytest.mark.parametrize("n_dim", [1, 3])
98 | def test_welford(n_dim, do_compute_covariance):
99 | num_samples = 10
100 |
101 | init, update, final = algorithms.welford_covariance(do_compute_covariance)
102 | state = init(n_dim)
103 | for i in range(num_samples):
104 | sample = i * at.ones(n_dim)
105 | state = update(sample, *state)
106 |
107 | mean = state[0].eval()
108 | expected = (9.0 / 2) * np.ones(n_dim)
109 | np.testing.assert_allclose(mean, expected)
110 |
111 | cov = final(state[1], state[2]).eval()
112 | if do_compute_covariance:
113 | expected = 55.0 / 6.0 * np.ones((n_dim, n_dim))
114 | else:
115 | expected = 55.0 / 6.0 * np.ones(n_dim)
116 | assert np.shape(cov) == np.shape(expected)
117 | np.testing.assert_allclose(cov, expected)
118 |
119 |
120 | @pytest.mark.parametrize("do_compute_covariance", [True, False])
121 | def test_welford_scalar(do_compute_covariance):
122 |     """Test the Welford algorithm when the state is a scalar."""
123 | num_samples = 10
124 |
125 | init, update, final = algorithms.welford_covariance(do_compute_covariance)
126 | state = init(0)
127 | for i in range(num_samples):
128 | sample = at.as_tensor(i)
129 | state = update(sample, *state)
130 |
131 | cov = final(state[1], state[2]).eval()
132 | assert np.ndim(cov) == 0
133 | assert pytest.approx(cov.squeeze()) == 55.0 / 6.0
134 |
--------------------------------------------------------------------------------
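
The constants in `test_welford` above follow directly from the inputs 0, 1, ..., 9: the sample mean is 9/2, and the Bessel-corrected sample variance is 82.5 / 9 = 55/6. A quick NumPy check:

```python
import numpy as np

samples = np.arange(10.0)                         # the Welford inputs
assert samples.mean() == 9.0 / 2                  # expected mean
assert np.isclose(samples.var(ddof=1), 55.0 / 6)  # expected (co)variance
```
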
/tests/test_hmc.py:
--------------------------------------------------------------------------------
1 | import aesara
2 | import aesara.tensor as at
3 | import arviz
4 | import numpy as np
5 | import pytest
6 | import scipy.stats as stats
7 | from aeppl import joint_logprob
8 | from aesara.tensor.var import TensorVariable
9 |
10 | from aehmc import hmc, nuts, window_adaptation
11 |
12 |
13 | def test_warmup_scalar():
14 | """Test the warmup on a univariate normal distribution."""
15 |
16 | srng = at.random.RandomStream(seed=0)
17 | Y_rv = srng.normal(1, 2)
18 |
19 | def logprob_fn(y: TensorVariable):
20 | logprob, _ = joint_logprob(realized={Y_rv: y})
21 | return logprob
22 |
23 | y_vv = Y_rv.clone()
24 | kernel = nuts.new_kernel(srng, logprob_fn)
25 | initial_state = nuts.new_state(y_vv, logprob_fn)
26 |
27 | state, (step_size, inverse_mass_matrix), updates = window_adaptation.run(
28 | kernel, initial_state, num_steps=1000
29 | )
30 |
31 | # Compile the warmup and execute to get a value for the step size and the
32 | # mass matrix.
33 | warmup_fn = aesara.function(
34 | (y_vv,),
35 | (
36 | state.position,
37 | state.potential_energy,
38 | state.potential_energy_grad,
39 | step_size,
40 | inverse_mass_matrix,
41 | ),
42 | updates=updates,
43 | )
44 |
45 | final_position, *_, step_size, inverse_mass_matrix = warmup_fn(3.0)
46 |
47 | assert final_position != 3.0 # the chain has moved
48 | assert np.ndim(step_size) == 0 # scalar step size
49 | assert step_size != 1.0 # step size changed
50 | assert step_size > 0.1 and step_size < 2 # stable range for the step size
51 | assert np.ndim(inverse_mass_matrix) == 0 # scalar mass matrix
52 | assert inverse_mass_matrix == pytest.approx(4, rel=1.0)
53 |
54 |
55 | def test_warmup_vector():
56 | """Test the warmup on a multivariate normal distribution."""
57 |
58 | loc = np.array([0.0, 3.0])
59 | scale = np.array([1.0, 2.0])
60 | cov = np.diag(scale**2)
61 |
62 | srng = at.random.RandomStream(seed=0)
63 | Y_rv = srng.multivariate_normal(loc, cov)
64 |
65 | def logprob_fn(y: TensorVariable):
66 | logprob, _ = joint_logprob(realized={Y_rv: y})
67 | return logprob
68 |
69 | y_vv = Y_rv.clone()
70 | kernel = nuts.new_kernel(srng, logprob_fn)
71 | initial_state = nuts.new_state(y_vv, logprob_fn)
72 |
73 | state, (step_size, inverse_mass_matrix), updates = window_adaptation.run(
74 | kernel, initial_state, num_steps=1000
75 | )
76 |
77 | # Compile the warmup and execute to get a value for the step size and the
78 | # mass matrix.
79 | warmup_fn = aesara.function(
80 | (y_vv,),
81 | (
82 | state.position,
83 | state.potential_energy,
84 | state.potential_energy_grad,
85 | step_size,
86 | inverse_mass_matrix,
87 | ),
88 | updates=updates,
89 | )
90 |
91 | final_position, *_, step_size, inverse_mass_matrix = warmup_fn([1.0, 1.0])
92 |
93 | assert np.all(final_position != np.array([1.0, 1.0])) # the chain has moved
94 | assert np.ndim(step_size) == 0 # scalar step size
95 | assert step_size > 0.1 and step_size < 2 # stable range for the step size
96 |     assert np.ndim(inverse_mass_matrix) == 1  # diagonal mass matrix
97 | np.testing.assert_allclose(inverse_mass_matrix, scale**2, rtol=1.0)
98 |
99 |
100 | @pytest.mark.parametrize("step_size, diverges", [(3.9, False), (4.1, True)])
101 | def test_univariate_hmc(step_size, diverges):
102 | """Test the NUTS kernel on a univariate gaussian target.
103 |
104 | Theory [1]_ says that the integration of the trajectory should be stable as
105 | long as the step size is smaller than twice the standard deviation.
106 |
107 | References
108 | ----------
109 | .. [1]: Neal, R. M. (2011). MCMC using Hamiltonian dynamics. Handbook of markov chain monte carlo, 2(11), 2.
110 |
111 | """
112 | inverse_mass_matrix = at.as_tensor(1.0)
113 | num_integration_steps = 30
114 |
115 | srng = at.random.RandomStream(seed=0)
116 | Y_rv = srng.normal(1, 2)
117 |
118 | def logprob_fn(y):
119 | logprob, _ = joint_logprob(realized={Y_rv: y})
120 | return logprob
121 |
122 | kernel = hmc.new_kernel(srng, logprob_fn)
123 |
124 | y_vv = Y_rv.clone()
125 | initial_state = hmc.new_state(y_vv, logprob_fn)
126 |
127 | def update_hmc_state(pos, energy, energy_grad):
128 | current_state = hmc.IntegratorState(pos, None, energy, energy_grad)
129 | chain_info, _ = kernel(
130 | current_state, step_size, inverse_mass_matrix, num_integration_steps
131 | )
132 | return (
133 | chain_info.state.position,
134 | chain_info.state.potential_energy,
135 | chain_info.state.potential_energy_grad,
136 | )
137 |
138 | trajectory, updates = aesara.scan(
139 | update_hmc_state,
140 | outputs_info=[
141 | {"initial": initial_state.position},
142 | {"initial": initial_state.potential_energy},
143 | {"initial": initial_state.potential_energy_grad},
144 | ],
145 | n_steps=2_000,
146 | )
147 |
148 | trajectory_generator = aesara.function((y_vv,), trajectory[0], updates=updates)
149 |
150 | samples = trajectory_generator(3.0)
151 | if diverges:
152 | assert np.all(samples == 3.0)
153 | else:
154 | assert np.mean(samples[1000:]) == pytest.approx(1.0, rel=1e-1)
155 | assert np.var(samples[1000:]) == pytest.approx(4.0, rel=1e-1)
156 |
157 |
158 | def compute_ess(samples):
159 | d = arviz.convert_to_dataset(np.expand_dims(samples, axis=0))
160 | ess = arviz.ess(d).to_array().to_numpy().squeeze()
161 | return ess
162 |
163 |
164 | def compute_mcse(x):
165 | ess = compute_ess(x)
166 | std_x = np.std(x, axis=0, ddof=1)
167 | return np.mean(x, axis=0), std_x / np.sqrt(ess)
168 |
169 |
170 | def multivariate_normal_model(srng):
171 | loc = np.array([0.0, 3.0])
172 | scale = np.array([1.0, 2.0])
173 | rho = np.array(0.5)
174 |
175 | cov = np.diag(scale**2)
176 | cov[0, 1] = rho * scale[0] * scale[1]
177 | cov[1, 0] = rho * scale[0] * scale[1]
178 |
179 | loc_tt = at.as_tensor(loc)
180 | cov_tt = at.as_tensor(cov)
181 |
182 | Y_rv = srng.multivariate_normal(loc_tt, cov_tt)
183 |
184 | def logprob_fn(y):
185 | return joint_logprob(realized={Y_rv: y})[0]
186 |
187 | return (loc, scale, rho), Y_rv, logprob_fn
188 |
189 |
190 | def test_hmc_mcse():
191 |     """This example is recommended in the Stan documentation [0]_ to find bugs
192 | that introduce bias in the average as well as the variance.
193 |
194 | The example is simple enough to be analytically tractable, but complex enough
195 | to find subtle bugs. It uses the MCMC CLT [1]_ to check that the estimates
196 | of different quantities are within the expected range.
197 |
198 | We set the inverse mass matrix to be the diagonal of the covariance matrix.
199 |
200 | We can also show that trajectory integration will not diverge as long as we
201 | choose any step size that is smaller than 2 [2]_ (section 4.2); We adjusted
202 | the number of integration steps manually in order to get a reasonable number
203 | of effective samples.
204 |
205 | References
206 | ----------
207 | .. [0]: https://github.com/stan-dev/stan/wiki/Testing:-Samplers
208 | .. [1]: Geyer, C. J. (2011). Introduction to markov chain monte carlo. Handbook of markov chain monte carlo, 20116022, 45.
209 | .. [2]: Neal, R. M. (2011). MCMC using Hamiltonian dynamics. Handbook of markov chain monte carlo, 2(11), 2.
210 |
211 | """
212 | srng = at.random.RandomStream(seed=1)
213 | (loc, scale, rho), Y_rv, logprob_fn = multivariate_normal_model(srng)
214 |
215 | step_size = 1.0
216 | L = 30
217 | inverse_mass_matrix = at.as_tensor(scale)
218 | kernel = hmc.new_kernel(srng, logprob_fn)
219 |
220 | y_vv = Y_rv.clone()
221 | initial_state = hmc.new_state(y_vv, logprob_fn)
222 |
223 | def update_hmc_state(pos, energy, energy_grad):
224 | current_state = hmc.IntegratorState(pos, None, energy, energy_grad)
225 | chain_info, _ = kernel(current_state, step_size, inverse_mass_matrix, L)
226 | return (
227 | chain_info.state.position,
228 | chain_info.state.potential_energy,
229 | chain_info.state.potential_energy_grad,
230 | )
231 |
232 | trajectory, updates = aesara.scan(
233 | update_hmc_state,
234 | outputs_info=[
235 | {"initial": initial_state.position},
236 | {"initial": initial_state.potential_energy},
237 | {"initial": initial_state.potential_energy_grad},
238 | ],
239 | n_steps=3000,
240 | )
241 |
242 | trajectory_generator = aesara.function((y_vv,), trajectory, updates=updates)
243 |
244 | rng = np.random.default_rng(seed=0)
245 | trace = trajectory_generator(rng.standard_normal(2))
246 | samples = trace[0][1000:]
247 |
248 | # MCSE on the location
249 | delta_loc = samples - loc
250 | mean, mcse = compute_mcse(delta_loc)
251 | p_greater_error = stats.norm.sf(np.abs(mean) / mcse)
252 | np.testing.assert_array_less(0.01, p_greater_error)
253 |
254 | # MCSE on the variance
255 | delta_var = np.square(samples - loc) - scale**2
256 | mean, mcse = compute_mcse(delta_var)
257 | p_greater_error = stats.norm.sf(np.abs(mean) / mcse)
258 | np.testing.assert_array_less(0.01, p_greater_error)
259 |
260 | # MCSE on the correlation
261 | delta_cor = np.prod(samples - loc, axis=1) / np.prod(scale) - rho
262 | mean, mcse = compute_mcse(delta_cor)
263 | p_greater_error = stats.norm.sf(np.abs(mean) / mcse)
264 | np.testing.assert_array_less(0.01, p_greater_error)
265 |
266 |
267 | def test_nuts_mcse():
268 |     """This example is recommended in the Stan documentation [0]_ to find bugs
269 | that introduce bias in the average as well as the variance.
270 |
271 | The example is simple enough to be analytically tractable, but complex enough
272 | to find subtle bugs. It uses the MCMC CLT [1]_ to check that the estimates
273 | of different quantities are within the expected range.
274 |
275 | We set the inverse mass matrix to be the diagonal of the covariance matrix.
276 |
277 | We can also show that trajectory integration will not diverge as long as we
278 | choose any step size that is smaller than 2 [2]_ (section 4.2); We adjusted
279 | the number of integration steps manually in order to get a reasonable number
280 | of effective samples.
281 |
282 | References
283 | ----------
284 | .. [0]: https://github.com/stan-dev/stan/wiki/Testing:-Samplers
285 | .. [1]: Geyer, C. J. (2011). Introduction to markov chain monte carlo. Handbook of markov chain monte carlo, 20116022, 45.
286 | .. [2]: Neal, R. M. (2011). MCMC using Hamiltonian dynamics. Handbook of markov chain monte carlo, 2(11), 2.
287 |
288 | """
289 | srng = at.random.RandomStream(seed=1)
290 | (loc, scale, rho), Y_rv, logprob_fn = multivariate_normal_model(srng)
291 |
292 | step_size = at.as_tensor(1.0)
293 | inverse_mass_matrix = at.as_tensor(scale)
294 | kernel = nuts.new_kernel(srng, logprob_fn)
295 |
296 | def wrapped_kernel(pos, energy, energy_grad):
297 | state = nuts.IntegratorState(
298 | position=pos,
299 | momentum=None,
300 | potential_energy=energy,
301 | potential_energy_grad=energy_grad,
302 | )
303 | chain_info, updates = kernel(state, step_size, inverse_mass_matrix)
304 |
305 | return (
306 | chain_info.state.position,
307 | chain_info.state.potential_energy,
308 | chain_info.state.potential_energy_grad,
309 | ), updates
310 |
311 | y_vv = Y_rv.clone()
312 | initial_state = nuts.new_state(y_vv, logprob_fn)
313 |
314 | trajectory, updates = aesara.scan(
315 | wrapped_kernel,
316 | outputs_info=[
317 | {"initial": initial_state.position},
318 | {"initial": initial_state.potential_energy},
319 | {"initial": initial_state.potential_energy_grad},
320 | ],
321 | n_steps=3000,
322 | )
323 |
324 | trajectory_generator = aesara.function((y_vv,), trajectory, updates=updates)
325 |
326 | rng = np.random.default_rng(seed=0)
327 | trace = trajectory_generator(rng.standard_normal(2))
328 | samples = trace[0][-1000:]
329 |
330 | # MCSE on the location
331 | delta_loc = samples - loc
332 | mean, mcse = compute_mcse(delta_loc)
333 | p_greater_error = stats.norm.sf(np.abs(mean) / mcse)
334 | np.testing.assert_array_less(0.01, p_greater_error)
335 |
336 | # MCSE on the variance
337 | delta_var = (samples - loc) ** 2 - scale**2
338 | mean, mcse = compute_mcse(delta_var)
339 | p_greater_error = stats.norm.sf(np.abs(mean) / mcse)
340 | np.testing.assert_array_less(0.01, p_greater_error)
341 |
342 | # MCSE on the correlation
343 | delta_cor = np.prod(samples - loc, axis=1) / np.prod(scale) - rho
344 | mean, mcse = compute_mcse(delta_cor)
345 | p_greater_error = stats.norm.sf(np.abs(mean) / mcse)
346 | np.testing.assert_array_less(0.01, p_greater_error)
347 |
--------------------------------------------------------------------------------
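
The MCSE assertions in the two tests above rely on the MCMC central limit theorem: the Monte Carlo standard error of an estimator is `std(x, ddof=1) / sqrt(ESS)`, the standardized error `|mean| / mcse` is approximately standard normal under the null hypothesis of no bias, and `stats.norm.sf` turns it into a one-sided tail probability. Requiring `0.01 < p` therefore asserts that no estimate deviates from its true value by more than sampling error allows at the 1% level. A condensed restatement of the check, equivalent to `compute_mcse` plus the assertions:

```python
import numpy as np
import scipy.stats as stats

def passes_clt_check(delta, ess, level=0.01):
    """`delta` holds per-draw errors (estimate minus true value)."""
    mcse = np.std(delta, axis=0, ddof=1) / np.sqrt(ess)
    p_greater_error = stats.norm.sf(np.abs(delta.mean(axis=0)) / mcse)
    return np.all(level < p_greater_error)
```
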
/tests/test_integrators.py:
--------------------------------------------------------------------------------
1 | import aesara
2 | import aesara.tensor as at
3 | import numpy as np
4 | import pytest
5 | from aesara.tensor.var import TensorVariable
6 |
7 | from aehmc.integrators import IntegratorState, velocity_verlet
8 |
9 |
10 | def HarmonicOscillator(inverse_mass_matrix, k=1.0, m=1.0):
11 | """Potential and Kinetic energy of an harmonic oscillator."""
12 |
13 | def potential_energy(x: TensorVariable) -> TensorVariable:
14 | return at.sum(0.5 * k * at.square(x))
15 |
16 | def kinetic_energy(p: TensorVariable) -> TensorVariable:
17 | v = inverse_mass_matrix * p
18 | return at.sum(0.5 * at.dot(v, p))
19 |
20 | return potential_energy, kinetic_energy
21 |
22 |
23 | def FreeFall(inverse_mass_matrix, g=1.0):
24 | """Potential and kinetic energy of a free-falling object."""
25 |
26 | def potential_energy(h: TensorVariable) -> TensorVariable:
27 | return at.sum(g * h)
28 |
29 | def kinetic_energy(p: TensorVariable) -> TensorVariable:
30 | v = inverse_mass_matrix * p
31 | return at.sum(0.5 * at.dot(v, p))
32 |
33 | return potential_energy, kinetic_energy
34 |
35 |
36 | def CircularMotion(inverse_mass_matrix):
37 | def potential_energy(q: TensorVariable) -> TensorVariable:
38 | return -1.0 / at.power(at.square(q[0]) + at.square(q[1]), 0.5)
39 |
40 | def kinetic_energy(p: TensorVariable) -> TensorVariable:
41 | return 0.5 * at.dot(inverse_mass_matrix, at.square(p))
42 |
43 | return potential_energy, kinetic_energy
44 |
45 |
46 | integration_examples = [
47 | {
48 | "model": FreeFall,
49 | "n_steps": 100,
50 | "step_size": 0.01,
51 | "q_init": np.array([0.0]),
52 | "p_init": np.array([1.0]),
53 | "q_final": np.array([0.5]),
54 | "p_final": np.array([0.0]),
55 | "inverse_mass_matrix": np.array([1.0]),
56 | },
57 | {
58 | "model": HarmonicOscillator,
59 | "n_steps": 100,
60 | "step_size": 0.01,
61 | "q_init": np.array([0.0]),
62 | "p_init": np.array([1.0]),
63 | "q_final": np.array([np.sin(1.0)]),
64 | "p_final": np.array([np.cos(1.0)]),
65 | "inverse_mass_matrix": np.array([1.0]),
66 | },
67 | {
68 | "model": CircularMotion,
69 | "n_steps": 628,
70 | "step_size": 0.01,
71 | "q_init": np.array([1.0, 0.0]),
72 | "p_init": np.array([0.0, 1.0]),
73 | "q_final": np.array([1.0, 0.0]),
74 | "p_final": np.array([0.0, 1.0]),
75 | "inverse_mass_matrix": np.array([1.0, 1.0]),
76 | },
77 | ]
78 |
79 |
80 | def create_integrate_fn(potential, step_fn, n_steps):
81 | q = at.vector("q")
82 | p = at.vector("p")
83 | step_size = at.scalar("step_size")
84 | energy = potential(q)
85 | energy_grad = aesara.grad(energy, q)
86 | trajectory, _ = aesara.scan(
87 | fn=step_fn,
88 | outputs_info=[
89 | {"initial": q},
90 | {"initial": p},
91 | {"initial": energy},
92 | {"initial": energy_grad},
93 | ],
94 | non_sequences=[step_size],
95 | n_steps=n_steps,
96 | )
97 | integrate_fn = aesara.function((q, p, step_size), trajectory)
98 | return integrate_fn
99 |
100 |
101 | @pytest.mark.parametrize("example", integration_examples)
102 | def test_velocity_verlet(example):
103 | model = example["model"]
104 | inverse_mass_matrix = example["inverse_mass_matrix"]
105 | step_size = example["step_size"]
106 | q_init = example["q_init"]
107 | p_init = example["p_init"]
108 |
109 | potential, kinetic_energy = model(inverse_mass_matrix)
110 | step = velocity_verlet(potential, kinetic_energy)
111 |
112 | q = at.vector("q")
113 | p = at.vector("p")
114 | p_final = at.vector("p_final")
115 | energy_at = potential(q) + kinetic_energy(p)
116 | energy_fn = aesara.function((q, p), energy_at)
117 |
118 | def wrapped_step(pos, mom, energy, energy_grad, step_size):
119 | return step(IntegratorState(pos, mom, energy, energy_grad), step_size)
120 |
121 | integrate_fn = create_integrate_fn(potential, wrapped_step, example["n_steps"])
122 | q_final, p_final, energy_final, _ = integrate_fn(q_init, p_init, step_size)
123 |
124 | # Check that the trajectory was correctly integrated
125 | np.testing.assert_allclose(example["q_final"], q_final[-1], atol=1e-2)
126 | np.testing.assert_allclose(example["p_final"], p_final[-1], atol=1e-2)
127 |
128 | # Symplectic integrators conserve energy
129 | energy = energy_fn(q_init, p_init)
130 | new_energy = energy_fn(q_final[-1], p_final[-1])
131 | assert energy == pytest.approx(new_energy, 1e-4)
132 |
--------------------------------------------------------------------------------
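
For reference, one velocity-verlet (leapfrog) step alternates a half momentum update, a full position update, and another half momentum update; the scheme is second-order accurate and symplectic, which is why the energy comparison at the end of `test_velocity_verlet` holds to within ~1e-4. A minimal NumPy sketch of the scheme that `velocity_verlet` implements, assuming a diagonal inverse mass matrix:

```python
import numpy as np

def leapfrog_step(q, p, grad_potential, step_size, inverse_mass_matrix):
    p = p - 0.5 * step_size * grad_potential(q)    # half momentum kick
    q = q + step_size * inverse_mass_matrix * p    # full position drift
    p = p - 0.5 * step_size * grad_potential(q)    # half momentum kick
    return q, p
```
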
/tests/test_mass_matrix.py:
--------------------------------------------------------------------------------
1 | import aesara
2 | import aesara.tensor as at
3 | import numpy as np
4 | import pytest
5 | from aesara.tensor.random.utils import RandomStream
6 | from numpy.testing import assert_allclose
7 |
8 | from aehmc import mass_matrix
9 |
10 |
11 | @pytest.mark.parametrize("is_full_matrix", [True, False])
12 | @pytest.mark.parametrize("n_dims", [0, 1, 3])
13 | def test_mass_matrix_adaptation(is_full_matrix, n_dims):
14 | srng = RandomStream(seed=0)
15 |
16 | if n_dims > 0:
17 | mu = 0.5 * at.ones((n_dims,))
18 | cov = 0.33 * at.ones((n_dims, n_dims))
19 | else:
20 | mu = at.constant(0.5)
21 | cov = at.constant(0.33)
22 |
23 | init, update, final = mass_matrix.covariance_adaptation(is_full_matrix)
24 | _, wc_state = init(n_dims)
25 | if n_dims > 0:
26 | dist = srng.multivariate_normal
27 | else:
28 | dist = srng.normal
29 |
30 | def one_step(*wc_state):
31 | sample = dist(mu, cov)
32 | wc_state = update(sample, wc_state)
33 | return wc_state
34 |
35 | results, updates = aesara.scan(
36 | fn=one_step,
37 | outputs_info=[
38 | {"initial": wc_state[0]},
39 | {"initial": wc_state[1]},
40 | {"initial": wc_state[2]},
41 | ],
42 | n_steps=2_000,
43 | )
44 |
45 | inverse_mass_matrix = final((results[0][-1], results[1][-1], results[2][-1]))
46 |
47 | if n_dims > 0:
48 | if is_full_matrix:
49 | expected = cov.eval()
50 | inverse_mass_matrix = inverse_mass_matrix.eval()
51 | else:
52 | expected = np.diagonal(cov.eval())
53 | inverse_mass_matrix = inverse_mass_matrix.eval()
54 | assert np.shape(inverse_mass_matrix) == np.shape(expected)
55 | assert_allclose(inverse_mass_matrix, expected, rtol=0.1)
56 | else:
57 | sigma = at.sqrt(inverse_mass_matrix).eval()
58 | expected_sigma = cov.eval()
59 | assert np.ndim(expected_sigma) == 0
60 | assert sigma == pytest.approx(expected_sigma, rel=0.1)
61 |
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | import aesara
2 | import aesara.tensor as at
3 | import numpy as np
4 | import pytest
5 | from aesara.tensor.random.utils import RandomStream
6 |
7 | from aehmc.metrics import gaussian_metric
8 |
9 | momentum_test_cases = [
10 | (1.0, 0.144),
11 | (np.array([1.0]), np.array([0.144])),
12 | (np.array([1.0, 1.0]), np.array([0.144, 1.27])),
13 | (np.array([[1.0, 0], [0, 1.0]]), np.array([0.144, 1.27])),
14 | ]
15 |
16 |
17 | @pytest.mark.skip(reason="This test relies on a specific rng implementation and seed.")
18 | @pytest.mark.parametrize("case", momentum_test_cases)
19 | def test_gaussian_metric_momentum(case):
20 | inverse_mass_matrix_val, expected_momentum = case
21 |
22 | # Momentum
23 | if np.ndim(inverse_mass_matrix_val) == 0:
24 | inverse_mass_matrix = at.scalar("inverse_mass_matrix")
25 | elif np.ndim(inverse_mass_matrix_val) == 1:
26 | inverse_mass_matrix = at.vector("inverse_mass_matrix")
27 | else:
28 | inverse_mass_matrix = at.matrix("inverse_mass_matrix")
29 |
30 | momentum_fn, _, _ = gaussian_metric(inverse_mass_matrix)
31 | srng = RandomStream(seed=59)
32 | momentum_generator = aesara.function([inverse_mass_matrix], momentum_fn(srng))
33 |
34 | momentum = momentum_generator(inverse_mass_matrix_val)
35 | assert np.shape(momentum) == np.shape(expected_momentum)
36 | assert momentum == pytest.approx(expected_momentum, 1e-2)
37 |
38 |
39 | kinetic_energy_test_cases = [
40 | (1.0, 1.0, 0.5),
41 | (np.array([1.0]), np.array([1.0]), 0.5),
42 | (np.array([1.0, 1.0]), np.array([1.0, 1.0]), 1.0),
43 | (np.array([[1.0, 0], [0, 1.0]]), np.array([1.0, 1.0]), 1.0),
44 | ]
45 |
46 |
47 | @pytest.mark.parametrize("case", kinetic_energy_test_cases)
48 | def test_gaussian_metric_kinetic_energy(case):
49 | inverse_mass_matrix_val, momentum_val, expected_energy = case
50 |
51 | if np.ndim(inverse_mass_matrix_val) == 0:
52 | inverse_mass_matrix = at.scalar("inverse_mass_matrix")
53 | momentum = at.scalar("momentum")
54 | elif np.ndim(inverse_mass_matrix_val) == 1:
55 | inverse_mass_matrix = at.vector("inverse_mass_matrix")
56 | momentum = at.vector("momentum")
57 | else:
58 | inverse_mass_matrix = at.matrix("inverse_mass_matrix")
59 | momentum = at.vector("momentum")
60 |
61 | _, kinetic_energy_fn, _ = gaussian_metric(inverse_mass_matrix)
62 | kinetic_energy = aesara.function(
63 | (inverse_mass_matrix, momentum), kinetic_energy_fn(momentum)
64 | )
65 |
66 | kinetic = kinetic_energy(inverse_mass_matrix_val, momentum_val)
67 | assert np.ndim(kinetic) == 0
68 | assert kinetic == expected_energy
69 |
70 |
71 | turning_test_cases = [
72 | (1.0, 1.0, 1.0, 1.0),
73 | (
74 | np.array([1.0, 1.0]), # inverse mass matrix
75 | np.array([1.0, 1.0]), # p_left
76 | np.array([1.0, 1.0]), # p_right
77 | np.array([1.0, 1.0]), # p_sum
78 | ),
79 | (
80 | np.array([[1.0, 0.0], [0.0, 1.0]]),
81 | np.array([1.0, 1.0]),
82 | np.array([1.0, 1.0]),
83 | np.array([1.0, 1.0]),
84 | ),
85 | ]
86 |
87 |
88 | @pytest.mark.parametrize("case", turning_test_cases)
89 | def test_turning(case):
90 | inverse_mass_matrix_val, p_left_val, p_right_val, p_sum_val = case
91 |
92 | if np.ndim(inverse_mass_matrix_val) == 0:
93 | inverse_mass_matrix = at.scalar("inverse_mass_matrix")
94 | p_left = at.scalar("p_left")
95 | p_right = at.scalar("p_right")
96 | p_sum = at.scalar("p_sum")
97 | elif np.ndim(inverse_mass_matrix_val) == 1:
98 | inverse_mass_matrix = at.vector("inverse_mass_matrix")
99 | p_left = at.vector("p_left")
100 | p_right = at.vector("p_right")
101 | p_sum = at.vector("p_sum")
102 | else:
103 | inverse_mass_matrix = at.matrix("inverse_mass_matrix")
104 | p_left = at.vector("p_left")
105 | p_right = at.vector("p_right")
106 | p_sum = at.vector("p_sum")
107 |
108 | _, _, turning_fn = gaussian_metric(inverse_mass_matrix)
109 |
110 | is_turning_fn = aesara.function(
111 | (inverse_mass_matrix, p_left, p_right, p_sum),
112 | turning_fn(p_left, p_right, p_sum),
113 | )
114 |
115 | is_turning = is_turning_fn(
116 | inverse_mass_matrix_val, p_left_val, p_right_val, p_sum_val
117 | )
118 |
119 | assert is_turning.ndim == 0
120 | assert is_turning.item() is True
121 |
122 |
123 | def test_fail_wrong_mass_matrix_dimension():
124 | """`gaussian_metric` should fail when the dimension of the mass matrix is greater than 2."""
125 | inverse_mass_matrix = np.array([[[1, 1], [1, 1]], [[1, 1], [1, 1]]])
126 | with pytest.raises(ValueError):
127 | _ = gaussian_metric(inverse_mass_matrix)
128 |
--------------------------------------------------------------------------------
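
The kinetic energy cases above are easy to verify by hand: the Gaussian metric's kinetic energy is `0.5 * p.T @ M_inv @ p`, so with an identity inverse mass matrix and `p = (1, 1)` the energy is 1.0:

```python
import numpy as np

p = np.array([1.0, 1.0])
inv_mass_matrix = np.eye(2)
assert 0.5 * p @ inv_mass_matrix @ p == 1.0  # matches kinetic_energy_test_cases
```
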
/tests/test_step_size.py:
--------------------------------------------------------------------------------
1 | import aesara
2 | import aesara.tensor as at
3 | import numpy as np
4 | import pytest
5 | from aesara import config
6 | from aesara.tensor.random.utils import RandomStream
7 |
8 | from aehmc import hmc
9 | from aehmc.algorithms import DualAveragingState
10 | from aehmc.step_size import dual_averaging_adaptation
11 |
12 |
13 | @pytest.fixture()
14 | def init():
15 | def logprob_fn(x):
16 | return -2 * (x - 1.0) ** 2
17 |
18 | srng = RandomStream(seed=0)
19 | kernel = hmc.new_kernel(srng, logprob_fn)
20 |
21 | initial_position = at.as_tensor(1.0, dtype=config.floatX)
22 | initial_state = hmc.new_state(initial_position, logprob_fn)
23 |
24 | return initial_state, kernel
25 |
26 |
27 | def test_dual_averaging_adaptation(init):
28 | initial_state, kernel = init
29 |
30 | init_stepsize = at.as_tensor(1.0, dtype=config.floatX)
31 | inverse_mass_matrix = at.as_tensor(1.0)
32 | num_integration_steps = 10
33 |
34 | init_fn, update_fn = dual_averaging_adaptation()
35 | da_state = init_fn(init_stepsize)
36 |
37 | def one_step(q, logprob, logprob_grad, step, x_t, x_avg, gradient_avg, mu):
38 | state = hmc.IntegratorState(
39 | position=q,
40 | momentum=None,
41 | potential_energy=logprob,
42 | potential_energy_grad=logprob_grad,
43 | )
44 | chain_info, inner_updates = kernel(
45 | state, at.exp(x_t), inverse_mass_matrix, num_integration_steps
46 | )
47 | current_da_state = DualAveragingState(
48 | step=step,
49 | iterates=x_t,
50 | iterates_avg=x_avg,
51 | gradient_avg=gradient_avg,
52 | shrinkage_pts=mu,
53 | )
54 | da_state = update_fn(chain_info.acceptance_probability, current_da_state)
55 |
56 | return (
57 | chain_info.state.position,
58 | chain_info.state.potential_energy,
59 | chain_info.state.potential_energy_grad,
60 | da_state.step,
61 | da_state.iterates,
62 | da_state.iterates_avg,
63 | da_state.gradient_avg,
64 | da_state.shrinkage_pts,
65 | chain_info.acceptance_probability,
66 | ), inner_updates
67 |
68 | states, updates = aesara.scan(
69 | fn=one_step,
70 | outputs_info=[
71 | {"initial": initial_state.position},
72 | {"initial": initial_state.potential_energy},
73 | {"initial": initial_state.potential_energy_grad},
74 | {"initial": da_state.step},
75 | {"initial": da_state.iterates}, # logstepsize
76 | {"initial": da_state.iterates_avg}, # logstepsize_avg
77 | {"initial": da_state.gradient_avg},
78 | {"initial": da_state.shrinkage_pts}, # mu
79 | None,
80 | ],
81 | n_steps=10_000,
82 | )
83 |
84 | p_accept = aesara.function((), states[-1], updates=updates)
85 | step_size = aesara.function((), at.exp(states[-4][-1]), updates=updates)
86 | assert np.mean(p_accept()) == pytest.approx(0.8, rel=10e-3)
87 | assert step_size() < 10
88 | assert step_size() > 1e-1
89 |
--------------------------------------------------------------------------------
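
The adaptation exercised above follows Nesterov-style dual averaging as described by Hoffman & Gelman (2014): with `H_t = target_acceptance - p_accept_t` as the "gradient", the log step size is pulled towards a shrinkage point `mu` while a decaying average smooths the iterates. A schematic (non-library) version of the update, using the usual constants `gamma=0.05`, `t0=10`, `kappa=0.75`; the library's exact parametrization may differ:

```python
import numpy as np

def dual_averaging_step(t, gradient, gradient_avg, x_avg, mu,
                        gamma=0.05, t0=10.0, kappa=0.75):
    t = t + 1
    gradient_avg = (1 - 1 / (t + t0)) * gradient_avg + gradient / (t + t0)
    x = mu - (np.sqrt(t) / gamma) * gradient_avg   # log step size iterate
    eta = t ** (-kappa)
    x_avg = eta * x + (1 - eta) * x_avg            # averaged iterate
    return t, x, x_avg, gradient_avg
```

Since the test targets an average acceptance rate of 0.8, the final assertions check that the mean acceptance probability converged to that target and that the adapted step size settled in a stable range.
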
/tests/test_termination.py:
--------------------------------------------------------------------------------
1 | """Test dynamic termination criteria."""
2 | import aesara
3 | import aesara.tensor as at
4 | import numpy as np
5 | import pytest
6 | from numpy.testing import assert_array_equal
7 |
8 | from aehmc.metrics import gaussian_metric
9 | from aehmc.termination import TerminationState, _find_storage_indices, iterative_uturn
10 |
11 |
12 | @pytest.mark.parametrize(
13 | "checkpoint_idxs, momentum, momentum_sum, inverse_mass_matrix, expected_turning",
14 | [
15 | ((3, 3), at.as_tensor(1.0), at.as_tensor(3.0), at.as_tensor(1.0), True),
16 | ((3, 2), at.as_tensor(1.0), at.as_tensor(3.0), at.as_tensor(1.0), False),
17 | ((0, 0), at.as_tensor(1.0), at.as_tensor(3.0), at.as_tensor(1.0), False),
18 | ((0, 1), at.as_tensor(1.0), at.as_tensor(3.0), at.as_tensor(1.0), True),
19 | ((1, 3), at.as_tensor(1.0), at.as_tensor(3.0), at.as_tensor(1.0), True),
20 | ((1, 3), at.as_tensor([1.0]), at.as_tensor([3.0]), at.ones(1), True),
21 | ],
22 | )
23 | def test_iterative_turning_termination(
24 | checkpoint_idxs, momentum, momentum_sum, inverse_mass_matrix, expected_turning
25 | ):
26 | _, _, is_turning = gaussian_metric(inverse_mass_matrix)
27 | _, _, is_iterative_turning = iterative_uturn(is_turning)
28 |
29 | idx_min, idx_max = checkpoint_idxs
30 | idx_min = at.as_tensor(idx_min)
31 | idx_max = at.as_tensor(idx_max)
32 | momentum_ckpts = at.as_tensor(np.array([1.0, 2.0, 3.0, -2.0]))
33 | momentum_sum_ckpts = at.as_tensor(np.array([2.0, 4.0, 4.0, -1.0]))
34 | ckpt_state = TerminationState(
35 | momentum_checkpoints=momentum_ckpts,
36 | momentum_sum_checkpoints=momentum_sum_ckpts,
37 | min_index=idx_min,
38 | max_index=idx_max,
39 | )
40 |
41 | _, _, is_iterative_turning_fn = iterative_uturn(is_turning)
42 | is_iterative_turning = is_iterative_turning_fn(ckpt_state, momentum_sum, momentum)
43 | fn = aesara.function((), is_iterative_turning, on_unused_input="ignore")
44 |
45 | actual_turning = fn()
46 |
47 | assert actual_turning.ndim == 0
48 | assert expected_turning == actual_turning
49 |
50 |
51 | @pytest.mark.parametrize(
52 | "step, expected_idx",
53 | [(0, (1, 0)), (6, (3, 2)), (7, (0, 2)), (13, (2, 2)), (15, (0, 3))],
54 | )
55 | def test_leaf_idx_to_ckpt_idx(step, expected_idx):
56 | step_tt = at.scalar("step", dtype=np.int64)
57 | idx_tt = _find_storage_indices(step_tt)
58 | fn = aesara.function((step_tt,), (*idx_tt,))
59 |
60 | idx_vv = fn(step)
61 | assert idx_vv[0].item() == expected_idx[0]
62 | assert idx_vv[1].item() == expected_idx[1]
63 |
64 |
65 | @pytest.mark.parametrize(
66 | "num_dims",
67 | [1, 3],
68 | )
69 | def test_termination_update(num_dims):
70 | inverse_mass_matrix = at.as_tensor(np.ones(1))
71 | _, _, is_turning = gaussian_metric(inverse_mass_matrix)
72 | new_state, update, _ = iterative_uturn(is_turning)
73 |
74 | position = at.as_tensor(np.ones(num_dims))
75 | momentum = at.as_tensor(np.ones(num_dims))
76 | momentum_sum = at.as_tensor(np.ones(num_dims))
77 |
78 | num_doublings = at.as_tensor(4)
79 | termination_state = new_state(position, num_doublings)
80 |
81 | step = at.scalar("step", dtype=np.int64)
82 | updated = update(termination_state, momentum_sum, momentum, step)
83 | update_fn = aesara.function((step,), updated, on_unused_input="ignore")
84 |
85 |     # Make sure the update runs for a single step
86 |     update_fn(1)
87 | 
88 |     # When the step number is odd there should be no checkpoint update
89 |     result_odd = update_fn(5)
90 | assert_array_equal(result_odd[0], np.zeros((4, num_dims)))
91 | assert_array_equal(result_odd[1], np.zeros((4, num_dims)))
92 |
--------------------------------------------------------------------------------
/tests/test_trajectory.py:
--------------------------------------------------------------------------------
1 | import aesara
2 | import aesara.tensor as at
3 | import numpy as np
4 | import pytest
5 | from aeppl.logprob import logprob
6 | from aesara.tensor.random.utils import RandomStream
7 | from aesara.tensor.var import TensorVariable
8 |
9 | from aehmc.integrators import IntegratorState, new_integrator_state, velocity_verlet
10 | from aehmc.metrics import gaussian_metric
11 | from aehmc.termination import iterative_uturn
12 | from aehmc.trajectory import (
13 | ProposalState,
14 | dynamic_integration,
15 | multiplicative_expansion,
16 | static_integration,
17 | )
18 |
19 | aesara.config.optimizer = "fast_compile"
20 | aesara.config.exception_verbosity = "high"
21 |
22 |
23 | def CircularMotion(inverse_mass_matrix):
24 | def potential_energy(q: TensorVariable) -> TensorVariable:
25 | return -1.0 / at.power(at.square(q[0]) + at.square(q[1]), 0.5)
26 |
27 | def kinetic_energy(p: TensorVariable) -> TensorVariable:
28 | return 0.5 * at.dot(inverse_mass_matrix, at.square(p))
29 |
30 | return potential_energy, kinetic_energy
31 |
32 |
33 | examples = [
34 | {
35 | "n_steps": 628,
36 | "step_size": 0.01,
37 | "q_init": np.array([1.0, 0.0]),
38 | "p_init": np.array([0.0, 1.0]),
39 | "q_final": np.array([1.0, 0.0]),
40 | "p_final": np.array([0.0, 1.0]),
41 | "inverse_mass_matrix": np.array([1.0, 1.0]),
42 | },
43 | ]
44 |
45 |
46 | @pytest.mark.parametrize("example", examples)
47 | def test_static_integration(example):
48 | inverse_mass_matrix = example["inverse_mass_matrix"]
49 | step_size = example["step_size"]
50 | num_steps = example["n_steps"]
51 | q_init = example["q_init"]
52 | p_init = example["p_init"]
53 |
54 | potential, kinetic_energy = CircularMotion(inverse_mass_matrix)
55 | step = velocity_verlet(potential, kinetic_energy)
56 | integrator = static_integration(step, num_steps)
57 |
58 | q = at.vector("q")
59 | p = at.vector("p")
60 | energy = potential(q)
61 | energy_grad = aesara.grad(energy, q)
62 | init_state = IntegratorState(
63 | position=q,
64 | momentum=p,
65 | potential_energy=energy,
66 | potential_energy_grad=energy_grad,
67 | )
68 | final_state, updates = integrator(init_state, step_size)
69 | integrate_fn = aesara.function((q, p), final_state, updates=updates)
70 |
71 | q_final, p_final, *_ = integrate_fn(q_init, p_init)
72 |
73 | np.testing.assert_allclose(q_final, example["q_final"], atol=1e-1)
74 | np.testing.assert_allclose(p_final, example["p_final"], atol=1e-1)
75 |
76 |
77 | @pytest.mark.parametrize(
78 | "case",
79 | [
80 | (0.0000001, False, False),
81 | (1000, True, False),
82 | (1e100, True, False),
83 | ],
84 | )
85 | def test_dynamic_integration(case):
86 | srng = RandomStream(seed=59)
87 |
88 | def potential_fn(x):
89 | return -at.sum(logprob(srng.normal(0.0, 1.0), x))
90 |
91 | step_size, should_diverge, should_turn = case
92 |
93 | # Set up the trajectory integrator
94 | inverse_mass_matrix = at.ones(1)
95 |
96 | momentum_generator, kinetic_energy_fn, uturn_check_fn = gaussian_metric(
97 | inverse_mass_matrix
98 | )
99 | integrator = velocity_verlet(potential_fn, kinetic_energy_fn)
100 | (
101 | new_criterion_state,
102 | update_criterion_state,
103 | is_criterion_met,
104 | ) = iterative_uturn(uturn_check_fn)
105 |
106 | trajectory_integrator = dynamic_integration(
107 | srng,
108 | integrator,
109 | kinetic_energy_fn,
110 | update_criterion_state,
111 | is_criterion_met,
112 | divergence_threshold=at.as_tensor(1000),
113 | )
114 |
115 | # Initialize the state
116 | direction = at.as_tensor(1)
117 | step_size = at.as_tensor(step_size)
118 | max_num_steps = at.as_tensor(10)
119 | num_doublings = at.as_tensor(10)
120 | position = at.as_tensor(np.ones(1))
121 |
122 | initial_state = new_integrator_state(
123 | potential_fn, position, momentum_generator(srng)
124 | )
125 |     initial_energy = initial_state.potential_energy + kinetic_energy_fn(initial_state.momentum)
126 |     termination_state = new_criterion_state(initial_state.position, num_doublings)
127 |
128 | state, updates = trajectory_integrator(
129 | initial_state,
130 | direction,
131 | termination_state,
132 | max_num_steps,
133 | step_size,
134 | initial_energy,
135 | )
136 |
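   |     # Compile a single function so that both diagnostics are read from the
   |     # same simulated trajectory; compiling two functions with the same
   |     # `updates` would advance the shared RNG state between calls.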
137 |     trajectory_fn = aesara.function((), (state[-2], state[-1]), updates=updates)
138 |     is_diverging, is_turning = trajectory_fn()
139 |
140 | assert is_diverging.item() is should_diverge
141 | assert is_turning.item() is should_turn
142 |
143 |
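   | # On a standard Gaussian, a unit step size U-turns after a single doubling, a
   | # very large step diverges immediately, and a tiny step runs through all 10
   | # allowed doublings.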
144 | @pytest.mark.parametrize(
145 | "step_size, should_diverge, should_turn, expected_doublings",
146 | [
147 | (100000.0, True, False, 1),
148 | (0.0000001, False, False, 10),
149 | (1.0, False, True, 1),
150 | ],
151 | )
152 | def test_multiplicative_expansion(
153 | step_size, should_diverge, should_turn, expected_doublings
154 | ):
155 | srng = RandomStream(seed=59)
156 |
157 | def potential_fn(x):
158 | return 0.5 * at.sum(at.square(x))
159 |
160 | step_size = at.as_tensor(step_size)
161 | inverse_mass_matrix = at.as_tensor(1.0, dtype=np.float64)
162 | position = at.as_tensor(1.0, dtype=np.float64)
163 |
164 | momentum_generator, kinetic_energy_fn, uturn_check_fn = gaussian_metric(
165 | inverse_mass_matrix
166 | )
167 | integrator = velocity_verlet(potential_fn, kinetic_energy_fn)
168 | (
169 | new_criterion_state,
170 | update_criterion_state,
171 | is_criterion_met,
172 | ) = iterative_uturn(uturn_check_fn)
173 |
174 | trajectory_integrator = dynamic_integration(
175 | srng,
176 | integrator,
177 | kinetic_energy_fn,
178 | update_criterion_state,
179 | is_criterion_met,
180 | divergence_threshold=at.as_tensor(1000),
181 | )
182 |
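   |     # `multiplicative_expansion` repeatedly doubles the trajectory, here at
   |     # most 10 times, stopping early on a divergence or U-turn.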
183 | expand = multiplicative_expansion(srng, trajectory_integrator, uturn_check_fn, 10)
184 |
185 | # Create the initial state
186 | state = new_integrator_state(potential_fn, position, momentum_generator(srng))
187 | energy = state.potential_energy + kinetic_energy_fn(state.momentum)
188 | proposal = ProposalState(
189 | state=state,
190 | energy=energy,
191 | weight=at.as_tensor(0.0, dtype=np.float64),
192 | sum_log_p_accept=at.as_tensor(-np.inf, dtype=np.float64),
193 | )
194 | termination_state = new_criterion_state(state.position, 10)
195 | result, updates = expand(
196 | proposal, state, state, state.momentum, termination_state, energy, step_size
197 | )
198 | outputs = (
199 | result.diagnostics.num_doublings[-1],
200 | result.diagnostics.is_diverging[-1],
201 | result.diagnostics.is_turning[-1],
202 | )
203 | fn = aesara.function((), outputs, updates=updates)
204 | num_doublings, does_diverge, does_turn = fn()
205 |
206 | assert does_diverge == should_diverge
207 | assert does_turn == should_turn
208 | assert expected_doublings == num_doublings
209 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import aesara.tensor as at
2 | import numpy as np
3 | import pytest
4 | from aesara.graph.basic import Apply
5 | from aesara.tensor.exceptions import ShapeError
6 | from aesara.tensor.random.basic import NormalRV
7 |
8 | from aehmc.utils import RaveledParamsMap
9 |
10 |
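   | # `RaveledParamsMap` maps between a single flat vector and a collection of
   | # parameters with different shapes: `ravel_params` concatenates the flattened
   | # values, and `unravel_params` splits a vector back into the original shapes.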
11 | def test_RaveledParamsMap():
12 | tau_rv = at.random.invgamma(0.5, 0.5, name="tau")
13 | tau_vv = tau_rv.clone()
14 | tau_vv.name = "t"
15 |
16 | beta_size = (3, 2)
17 | beta_rv = at.random.normal(0, at.sqrt(tau_rv), size=beta_size, name="beta")
18 | beta_vv = beta_rv.clone()
19 | beta_vv.name = "b"
20 |
21 | kappa_size = (20,)
22 | kappa_rv = at.random.normal(0, 1, size=kappa_size, name="kappa")
23 | kappa_vv = kappa_rv.clone()
24 | kappa_vv.name = "k"
25 |
26 | params_map = {beta_rv: beta_vv, tau_rv: tau_vv, kappa_rv: kappa_vv}
27 |
28 | rp_map = RaveledParamsMap(params_map.keys())
29 |
30 | assert repr(rp_map) == "RaveledParamsMap((beta, tau, kappa))"
31 |
32 | q = at.vector("q")
33 |
34 | exp_beta_part = np.exp(np.arange(np.prod(beta_size)).reshape(beta_size))
35 | exp_tau_part = 1.0
36 | exp_kappa_part = np.exp(np.arange(np.prod(kappa_size)).reshape(kappa_size))
37 | exp_raveled_params = np.concatenate(
38 | [exp_beta_part.ravel(), np.atleast_1d(exp_tau_part), exp_kappa_part.ravel()]
39 | )
40 |
41 | raveled_params_at = rp_map.ravel_params([beta_vv, tau_vv, kappa_vv])
42 | raveled_params_val = raveled_params_at.eval(
43 | {beta_vv: exp_beta_part, tau_vv: exp_tau_part, kappa_vv: exp_kappa_part}
44 | )
45 |
46 | assert np.array_equal(raveled_params_val, exp_raveled_params)
47 |
48 | unraveled_params_at = rp_map.unravel_params(q)
49 |
50 | beta_part = unraveled_params_at[beta_rv]
51 | tau_part = unraveled_params_at[tau_rv]
52 | kappa_part = unraveled_params_at[kappa_rv]
53 |
54 | new_test_point = {q: exp_raveled_params}
55 | assert np.array_equal(beta_part.eval(new_test_point), exp_beta_part)
56 | assert np.array_equal(tau_part.eval(new_test_point), exp_tau_part)
57 | assert np.array_equal(kappa_part.eval(new_test_point), exp_kappa_part)
58 |
59 |
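   | # Unraveling must restore each parameter's original dtype, e.g. the integer
   | # dtype of a binomial variable packed alongside a float.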
60 | def test_RaveledParamsMap_dtype():
61 | tau_rv = at.random.normal(0, 1, name="tau")
62 | tau_vv = tau_rv.clone()
63 | tau_vv.name = "t"
64 |
65 | lambda_rv = at.random.binomial(10, 0.5, name="lmbda")
66 | lambda_vv = lambda_rv.clone()
67 | lambda_vv.name = "l"
68 |
69 | params_map = {tau_rv: tau_vv, lambda_rv: lambda_vv}
70 | rp_map = RaveledParamsMap(params_map.keys())
71 |
72 | q = rp_map.ravel_params((tau_vv, lambda_vv))
73 | unraveled_params = rp_map.unravel_params(q)
74 |
75 | tau_part = unraveled_params[tau_rv]
76 | lambda_part = unraveled_params[lambda_rv]
77 |
78 | assert tau_part.dtype == tau_rv.dtype
79 | assert lambda_part.dtype == lambda_rv.dtype
80 |
81 |
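   | # When an `Op` carries no static shape information and its `infer_shape`
   | # raises, constructing a `RaveledParamsMap` should warn rather than fail.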
82 | def test_RaveledParamsMap_bad_infer_shape():
83 | class BadNormalRV(NormalRV):
84 | def make_node(self, *args, **kwargs):
85 | res = super().make_node(*args, **kwargs)
86 | # Drop static `Type`-level shape information
87 | rv_out = res.outputs[1]
88 | outputs = [
89 | res.outputs[0].clone(),
90 | at.tensor(dtype=rv_out.type.dtype, shape=(None,) * rv_out.type.ndim),
91 | ]
92 | return Apply(
93 | self,
94 | res.inputs,
95 | outputs,
96 | )
97 |
98 | def infer_shape(self, *args, **kwargs):
99 | raise ShapeError()
100 |
101 | bad_normal_op = BadNormalRV()
102 |
103 | size = (3, 2)
104 | beta_rv = bad_normal_op(0, 1, size=size, name="beta")
105 | beta_vv = beta_rv.clone()
106 | beta_vv.name = "b"
107 |
108 | params_map = {beta_rv: beta_vv}
109 |
110 | with pytest.warns(Warning):
111 | RaveledParamsMap(params_map.keys())
112 |
--------------------------------------------------------------------------------