├── .copier-answers.yml ├── .github ├── dependabot.yml └── workflows │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── noxfile.py ├── pyproject.toml ├── readthedocs.yaml ├── setup.py ├── src └── numpyro_ext │ ├── __init__.py │ ├── distributions.py │ ├── infer.py │ ├── info.py │ ├── linear_op.py │ └── optim.py └── tests ├── test_distributions.py ├── test_infer.py ├── test_info.py ├── test_linear_op.py └── test_optim.py /.copier-answers.yml: -------------------------------------------------------------------------------- 1 | _commit: v0.0.14 2 | _src_path: gh:dfm/copier-python 3 | author_email: foreman.mackey@gmail.com 4 | author_fullname: Dan Foreman-Mackey 5 | author_username: dfm 6 | code_of_conduct_email: foreman.mackey@gmail.com 7 | copyright_holder: Simons Foundation, Inc. 8 | copyright_license: Apache 9 | copyright_year: '2022' 10 | documentation_url: https://github.com/dfm/numpyro-ext 11 | enable_mypy: false 12 | enable_pybind11: false 13 | enable_windows: false 14 | general_nox_sessions: '"lint"' 15 | project_description: Some extensions to numpyro that I find useful 16 | project_development_status: beta 17 | project_line_length: 79 18 | project_name: numpyro-ext 19 | python_formal_min_version: 7 20 | python_max_version: 10 21 | python_min_version: 7 22 | python_nox_sessions: '"tests"' 23 | python_package_distribution_name: numpyro-ext 24 | python_package_import_name: numpyro_ext 25 | readthedocs_project_name: numpyro-ext 26 | repository_name: numpyro-ext 27 | repository_namespace: dfm 28 | 29 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "monthly" 7 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - "*" 9 | pull_request: 10 | workflow_dispatch: 11 | inputs: 12 | prerelease: 13 | description: "Run a pre-release, testing the build" 14 | required: false 15 | type: boolean 16 | default: false 17 | 18 | jobs: 19 | tests: 20 | runs-on: ${{ matrix.os }} 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | os: ["ubuntu-latest"] 25 | python-version: ["3.9", "3.10", "3.11", "3.12"] 26 | nox-session: ["tests"] 27 | include: 28 | - os: macos-latest 29 | python-version: "3.10" 30 | nox-session: "tests" 31 | - os: ubuntu-latest 32 | python-version: "3.10" 33 | nox-session: "doctest" 34 | - os: ubuntu-latest 35 | python-version: "3.10" 36 | nox-session: "lint" 37 | 38 | steps: 39 | - name: Checkout 40 | uses: actions/checkout@v4 41 | with: 42 | submodules: true 43 | fetch-depth: 0 44 | - name: Configure Python 45 | uses: actions/setup-python@v5 46 | with: 47 | python-version: ${{ matrix.python-version }} 48 | - name: Cache pip packages 49 | uses: actions/cache@v4 50 | with: 51 | path: ~/.cache/pip 52 | key: ${{ runner.os }}-pip-${{ matrix.nox-session }}-${{ hashFiles('**/noxfile.py') }} 53 | restore-keys: | 54 | ${{ runner.os }}-pip-${{ matrix.nox-session }}- 55 | - name: Cache nox session files 56 | uses: actions/cache@v4 57 | with: 58 | path: .nox 59 | key: ${{ runner.os }}-nox-${{ matrix.nox-session }}-${{ hashFiles('**/noxfile.py') }} 60 
| restore-keys: | 61 | ${{ runner.os }}-nox-${{ matrix.nox-session }}- 62 | - name: Cache pre-commit environments 63 | if: ${{ matrix.nox-session == 'lint' }} 64 | uses: actions/cache@v4 65 | with: 66 | path: ~/.cache/pre-commit 67 | key: ${{ runner.os }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} 68 | restore-keys: | 69 | ${{ runner.os }}-pre-commit- 70 | - name: Install nox 71 | run: | 72 | python -m pip install -U pip 73 | python -m pip install -U nox 74 | - name: Run tests 75 | run: python -m nox --non-interactive -s ${{ matrix.nox-session }} 76 | 77 | build: 78 | runs-on: ubuntu-latest 79 | steps: 80 | - uses: actions/checkout@v4 81 | with: 82 | fetch-depth: 0 83 | - uses: actions/setup-python@v5 84 | name: Install Python 85 | with: 86 | python-version: "3.9" 87 | - name: Build sdist and wheel 88 | run: | 89 | python -m pip install -U pip 90 | python -m pip install -U build 91 | python -m build . 92 | - uses: actions/upload-artifact@v4 93 | with: 94 | path: dist/* 95 | 96 | upload_pypi: 97 | environment: 98 | name: pypi 99 | url: https://pypi.org/p/numpyro-ext 100 | permissions: 101 | id-token: write 102 | needs: [tests, build] 103 | runs-on: ubuntu-latest 104 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') 105 | steps: 106 | - uses: actions/download-artifact@v4 107 | with: 108 | name: artifact 109 | path: dist 110 | - uses: pypa/gh-action-pypi-publish@v1.12.3 111 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /src/*/version.py 2 | *.pyc 3 | __pycache__ 4 | *.egg-info 5 | _skbuild 6 | *.so 7 | .mypy_cache 8 | .tox 9 | .nox 10 | .coverage 11 | .coverage* 12 | dist 13 | build 14 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: forbidden-files 5 | name: forbidden files 6 | entry: found copier update rejection files; review them and remove them 7 | language: fail 8 | files: "\\.rej$" 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v4.4.0 11 | hooks: 12 | - id: trailing-whitespace 13 | - id: end-of-file-fixer 14 | exclude_types: [json, binary] 15 | exclude: ".copier-answers.yml" 16 | - id: check-yaml 17 | - repo: https://github.com/charliermarsh/ruff-pre-commit 18 | rev: "v0.0.265" 19 | hooks: 20 | - id: ruff 21 | exclude: "^docs/" 22 | args: [--fix, --exit-non-zero-on-fix] 23 | - repo: https://github.com/psf/black 24 | rev: "23.3.0" 25 | hooks: 26 | - id: black-jupyter 27 | - repo: https://github.com/kynan/nbstripout 28 | rev: "0.6.1" 29 | hooks: 30 | - id: nbstripout 31 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual identity 10 | and orientation. 
11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | - Demonstrating empathy and kindness toward other people 21 | - Being respectful of differing opinions, viewpoints, and experiences 22 | - Giving and gracefully accepting constructive feedback 23 | - Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | - Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | - The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | - Trolling, insulting or derogatory comments, and personal or political attacks 33 | - Public or private harassment 34 | - Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | - Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | foreman.mackey@gmail.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. 
No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][mozilla coc]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][faq]. Translations are available 126 | at [https://www.contributor-covenant.org/translations][translations]. 127 | 128 | [homepage]: https://www.contributor-covenant.org 129 | [v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html 130 | [mozilla coc]: https://github.com/mozilla/diversity 131 | [faq]: https://www.contributor-covenant.org/faq 132 | [translations]: https://www.contributor-covenant.org/translations 133 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributor Guide 2 | 3 | Thank you for your interest in improving this project. This project is 4 | open-source under the Apache License and welcomes contributions in the form of bug 5 | reports, feature requests, and pull requests. 6 | 7 | Here is a list of important resources for contributors: 8 | 9 | - [Source Code](https://github.com/dfm/numpyro-ext) 10 | - [Documentation](https://github.com/dfm/numpyro-ext) 11 | - [Issue Tracker](https://github.com/dfm/numpyro-ext/issues) 12 | 13 | ## How to report a bug 14 | 15 | Report bugs on the [Issue Tracker](https://github.com/dfm/numpyro-ext/issues). 16 | 17 | When filing an issue, make sure to answer these questions: 18 | 19 | - Which operating system and Python version are you using? 20 | - Which version of this project are you using? 21 | - What did you do? 22 | - What did you expect to see? 23 | - What did you see instead? 24 | 25 | The best way to get your bug fixed is to provide a test case, and/or steps to 26 | reproduce the issue. In particular, please include a [Minimal, Reproducible 27 | Example](https://stackoverflow.com/help/minimal-reproducible-example).
28 | 29 | ## How to request a feature 30 | 31 | Feel free to request features on the [Issue 32 | Tracker](https://github.com/dfm/numpyro-ext/issues). 33 | 34 | ## How to set up your development environment 35 | 36 | TODO 37 | 38 | ## How to test the project 39 | 40 | ```bash 41 | python -m pip install nox 42 | python -m nox 43 | ``` 44 | 45 | ## How to submit changes 46 | 47 | Open a [Pull Request](https://github.com/dfm/numpyro-ext/pulls). 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Extensions for NumPyro 2 | 3 | This library includes a miscellaneous set of helper functions, custom 4 | distributions, and other utilities that I find useful when using 5 | [NumPyro](https://num.pyro.ai) in my work. 
6 | 7 | ## Installation 8 | 9 | Since NumPyro, and hence this library, are built on top of JAX, it's typically 10 | good practice to start by installing JAX following [the installation 11 | instructions](https://jax.readthedocs.io/en/latest/#installation). Then, you can 12 | install this library using pip: 13 | 14 | ```bash 15 | python -m pip install numpyro-ext 16 | ``` 17 | 18 | ## Usage 19 | 20 | Since this README is checked using `doctest`, let's start by importing some 21 | common modules that we'll need in all our examples: 22 | 23 | ```python 24 | >>> import jax 25 | >>> import jax.numpy as jnp 26 | >>> import numpyro 27 | >>> import numpyro_ext 28 | 29 | ``` 30 | 31 | ### Distributions 32 | 33 | The tradition is to import `numpyro_ext.distributions` as `distx` to 34 | differentiate from `numpyro.distributions`, which is imported as `dist`: 35 | 36 | ```python 37 | >>> from numpyro import distributions as dist 38 | >>> from numpyro_ext import distributions as distx 39 | >>> key = jax.random.PRNGKey(0) 40 | 41 | ``` 42 | 43 | #### Angle 44 | 45 | A uniform distribution over angles in radians. The actual sampling is performed 46 | in the two-dimensional vector space proportional to `(sin(theta), cos(theta))` 47 | so that the sampler doesn't see a discontinuity at pi. 48 | 49 | ```python 50 | >>> angle = distx.Angle() 51 | >>> print(angle.sample(key, (2, 3))) 52 | [[ 0.4...] 53 | [ 2.4...]] 54 | 55 | ``` 56 | 57 | #### UnitDisk 58 | 59 | A uniform distribution over two-dimensional points within the disk of radius 1. 60 | This means that the sum over squares of the last dimension of a random variable 61 | generated from this distribution will always be less than 1. 62 | 63 | ```python 64 | >>> unit_disk = distx.UnitDisk() 65 | >>> u = unit_disk.sample(key, (5,)) 66 | >>> print(jnp.sum(u**2, axis=-1)) 67 | [0.07...] 68 | 69 | ``` 70 | 71 | #### NoncentralChi2 72 | 73 | A [non-central chi-squared 74 | distribution](https://en.wikipedia.org/wiki/Noncentral_chi-squared_distribution). 75 | To use this distribution, you'll need to install the optional 76 | `tensorflow-probability` dependency. 77 | 78 | ```python 79 | >>> ncx2 = distx.NoncentralChi2(df=3, nc=2.) 80 | >>> print(ncx2.sample(key, (5,))) 81 | [2.19...] 82 | 83 | ``` 84 | 85 | #### MarginalizedLinear 86 | 87 | The marginalized product of two (possibly multivariate) normal distributions 88 | with a linear relationship between them. The mathematical details of these 89 | models are discussed in detail in [this note](https://arxiv.org/abs/2005.14199), 90 | and this distribution implements the math presented there, in a computationally 91 | efficient way, assuming that the number of marginalized parameters is small 92 | compared to the size of the dataset. 93 | 94 | The following example shows a particularly simple example of a 95 | fully-marginalized model for fitting a line to data: 96 | 97 | ```python 98 | >>> def model(x, y=None): 99 | ... design_matrix = jnp.vander(x, 2) 100 | ... prior = dist.Normal(0.0, 1.0) 101 | ... data = dist.Normal(0.0, 2.0) 102 | ... numpyro.sample( 103 | ... "y", 104 | ... distx.MarginalizedLinear(design_matrix, prior, data), 105 | ... obs=y 106 | ... ) 107 | ... 108 | 109 | ``` 110 | 111 | Things get a little more interesting when the design matrix and/or the 112 | distributions are functions of non-linear parameters. 
For example, if we want to 113 | find the period of a sinusoidal signal, also fitting for some unknown excess 114 | measurement uncertainty (often called "jitter") we can use the following model: 115 | 116 | ```python 117 | >>> def model(x, y_err, y=None): 118 | ... period = numpyro.sample("period", dist.Uniform(1.0, 250.0)) 119 | ... ln_jitter = numpyro.sample("ln_jitter", dist.Normal(0.0, 2.0)) 120 | ... design_matrix = jnp.stack( 121 | ... [ 122 | ... jnp.sin(2 * jnp.pi * x / period), 123 | ... jnp.cos(2 * jnp.pi * x / period), 124 | ... jnp.ones_like(x), 125 | ... ], 126 | ... axis=-1, 127 | ... ) 128 | ... prior = dist.Normal(0.0, 10.0).expand([3]) 129 | ... data = dist.Normal(0.0, jnp.sqrt(y_err**2 + jnp.exp(2*ln_jitter))) 130 | ... numpyro.sample( 131 | ... "y", 132 | ... distx.MarginalizedLinear(design_matrix, prior, data), 133 | ... obs=y 134 | ... ) 135 | ... 136 | >>> x = jnp.linspace(-1.0, 1.0, 5) 137 | >>> samples = numpyro.infer.Predictive(model, num_samples=2)(key, x, 0.1) 138 | >>> print(samples["period"]) 139 | [... ...] 140 | >>> print(samples["y"]) 141 | [[... ... ...] 142 | [... ... ...]] 143 | 144 | ``` 145 | 146 | It's often useful to also track conditional samples of the marginalized 147 | parameters during inference. The conditional distribution can be accessed using 148 | the `conditional` method on `MarginalizedLinear`: 149 | 150 | ```python 151 | >>> x = jnp.linspace(-1.0, 1.0, 5) 152 | >>> y = jnp.sin(x) # just some fake data 153 | >>> design_matrix = jnp.vander(x, 2) 154 | >>> prior = dist.Normal(0.0, 1.0) 155 | >>> data = dist.Normal(0.0, 2.0) 156 | >>> marg = distx.MarginalizedLinear(design_matrix, prior, data) 157 | >>> cond = marg.conditional(y) 158 | >>> print(type(cond).__name__) 159 | MultivariateNormal 160 | >>> print(cond.sample(key, (3,))) 161 | [[...] 162 | [...] 163 | [...]] 164 | 165 | ``` 166 | 167 | ### Optimization 168 | 169 | The inference lore is a little mixed on the benefits of optimization as an 170 | initialization tool for MCMC, but I find that at least in a lot of astronomy 171 | applications, an initial optimization can make a huge difference in performance. 172 | Even if you don't want to use the optimization results as an initialization, it 173 | can still sometimes be useful to numerically search for the maximum _a 174 | posteriori_ parameters for your model. However, the NumPyro interface for these 175 | types of optimization isn't terribly user-friendly, so this library provides 176 | some helpers to make it a little more straightforward. 177 | 178 | By default, this optimization uses the wrappers of scipy's optimization routines 179 | provided by the [JAXopt](https://github.com/google/jaxopt) library, so you'll 180 | need to install JAXopt: 181 | 182 | ```bash 183 | python -m pip install jaxopt 184 | ``` 185 | 186 | before running these examples. 187 | 188 | The following example shows a simple optimization of a model with a single 189 | parameter: 190 | 191 | ```python 192 | >>> from numpyro_ext import optim as optimx 193 | >>> 194 | >>> def model(y=None): 195 | ... x = numpyro.sample("x", dist.Normal(0.0, 1.0)) 196 | ... numpyro.sample("y", dist.Normal(x, 2.0), obs=y) 197 | ... 
198 | >>> soln = optimx.optimize(model)(key, y=0.5) 199 | 200 | ``` 201 | 202 | By default, the optimization starts from a prior sample, but you can provide 203 | custom initial coordinates as follows: 204 | 205 | ```python 206 | >>> soln = optimx.optimize(model, start={"x": 12.3})(key, y=0.5) 207 | 208 | ``` 209 | 210 | Similarly, if you only want to optimize a subset of the parameters, you can 211 | provide a list of parameters to target: 212 | 213 | ```python 214 | >>> soln = optimx.optimize(model, sites=["x"])(key, y=0.5) 215 | 216 | ``` 217 | 218 | ### Information matrix computation 219 | 220 | The Fisher information matrix for models with Gaussian likelihoods is 221 | [straightforward to 222 | compute](https://en.wikipedia.org/wiki/Fisher_information#Multivariate_normal_distribution), 223 | and this library provides a helper function for automating this computation: 224 | 225 | ```python 226 | >>> from numpyro_ext import information 227 | >>> 228 | >>> def model(x, y=None): 229 | ... a = numpyro.sample("a", dist.Normal(0.0, 1.0)) 230 | ... b = numpyro.sample("b", dist.Normal(0.0, 1.0)) 231 | ... log_alpha = numpyro.sample("log_alpha", dist.Normal(0.0, 1.0)) 232 | ... cov = jnp.exp(log_alpha - 0.5 * (x[:, None] - x[None, :])**2) 233 | ... cov += 0.1 * jnp.eye(len(x)) 234 | ... numpyro.sample( 235 | ... "y", 236 | ... dist.MultivariateNormal(loc=a * x + b, covariance_matrix=cov), 237 | ... obs=y, 238 | ... ) 239 | ... 240 | >>> x = jnp.linspace(-1.0, 1.0, 5) 241 | >>> y = jnp.sin(x) # the input data just needs to have the right shape 242 | >>> params = {"a": 0.5, "b": -0.2, "log_alpha": -0.5} 243 | >>> info = information(model)(params, x, y=y) 244 | >>> print(info) 245 | {'a': {'a': ..., 'b': ... 'log_alpha': ...}, 'b': ...} 246 | 247 | ``` 248 | 249 | The returned information matrix is a nested dictionary of dictionaries, indexed 250 | by pairs of parameter names, where the values are the corresponding blocks of 251 | the information matrix. 
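252 | 253 | The `information` helper also accepts an `invert` flag that returns the inverse of the information matrix in the same nested format, which is often used as a quick estimate of the parameter covariances. As a minimal example, reusing the `model` and `params` defined above (each block is a scalar here, since all of the parameters are scalars): 254 | 255 | ```python 256 | >>> inv_info = information(model, invert=True)(params, x, y=y) 257 | >>> print(inv_info["a"]["a"].shape) 258 | () 259 | 260 | ``` 261 |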
252 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | import nox 2 | 3 | 4 | @nox.session 5 | def tests(session): 6 | session.install("-e", ".[test]") 7 | session.run("pytest", "-v", "tests") 8 | 9 | 10 | @nox.session 11 | def doctest(session): 12 | session.install("-e", ".[test,ncx2]") 13 | session.run( 14 | "python", 15 | "-m", 16 | "doctest", 17 | "-o", 18 | "ELLIPSIS", 19 | "-o", 20 | "NORMALIZE_WHITESPACE", 21 | "-v", 22 | "README.md", 23 | ) 24 | 25 | 26 | @nox.session 27 | def lint(session): 28 | session.install("pre-commit") 29 | session.run("pre-commit", "run", "--all-files") 30 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=62.0", "setuptools_scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "numpyro-ext" 7 | description = "A miscellaneous set of helper functions, custom distributions, and other utilities that I find useful when using NumPyro in my work" 8 | authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] 9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | license = { text = "Apache License" } 12 | classifiers = [ 13 | "Operating System :: OS Independent", 14 | "Programming Language :: Python :: 3", 15 | "Development Status :: 4 - Beta", 16 | "License :: OSI Approved :: Apache Software License", 17 | ] 18 | dynamic = ["version"] 19 | dependencies = ["numpyro>=0.13.1"] 20 | 21 | [project.urls] 22 | "Homepage" = "https://github.com/dfm/numpyro-ext" 23 | "Source" = "https://github.com/dfm/numpyro-ext" 24 | "Bug Tracker" = "https://github.com/dfm/numpyro-ext/issues" 25 | 26 | [project.optional-dependencies] 27 | test = ["pytest", "jaxopt", "typing_extensions"] 28 | docs = [] 29 | ncx2 = ["tensorflow-probability"] 30 | 31 | [tool.setuptools_scm] 32 | write_to = "src/numpyro_ext/version.py" 33 | 34 | [tool.black] 35 | target-version = ["py39"] 36 | line-length = 88 37 | 38 | [tool.ruff] 39 | src = ["src"] 40 | line-length = 89 41 | target-version = "py38" 42 | select = ["F", "I", "E", "W", "YTT", "B", "Q", "PLE", "PLW"] 43 | ignore = [ 44 | "E741", # Allow ambiguous variable names (e.g. "l" in starry) 45 | "B023", # Allow using global variables in lambdas 46 | ] 47 | exclude = [] 48 | -------------------------------------------------------------------------------- /readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | submodules: 4 | include: all 5 | 6 | build: 7 | os: ubuntu-20.04 8 | tools: 9 | python: "3.10" 10 | 11 | python: 12 | install: 13 | - method: pip 14 | path: . 
15 | extra_requirements: 16 | - docs 17 | 18 | sphinx: 19 | builder: dirhtml 20 | configuration: docs/conf.py 21 | fail_on_warning: true 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="numpyro-ext", 5 | packages=find_packages(where="src"), 6 | package_dir={"": "src"}, 7 | ) 8 | -------------------------------------------------------------------------------- /src/numpyro_ext/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Simons Foundation, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from numpyro_ext import distributions as distributions 16 | from numpyro_ext import infer as infer 17 | from numpyro_ext import optim as optim 18 | from numpyro_ext.info import information as information 19 | from numpyro_ext.version import __version__ as __version__ 20 | -------------------------------------------------------------------------------- /src/numpyro_ext/distributions.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "QuadLDParams", 3 | "UnitDisk", 4 | "Angle", 5 | "MixtureGeneral", 6 | "NoncentralChi2", 7 | "MarginalizedLinear", 8 | ] 9 | 10 | import jax 11 | import jax.numpy as jnp 12 | import jax.scipy as jsp 13 | import numpyro.distributions as dist 14 | from jax import lax 15 | from jax.scipy.linalg import cho_factor, cho_solve 16 | from numpyro.distributions import MixtureGeneral as MixtureGeneral 17 | from numpyro.distributions import constraints, transforms 18 | from numpyro.distributions.util import promote_shapes, validate_sample 19 | from numpyro.util import is_prng_key 20 | 21 | from numpyro_ext.linear_op import to_linear_op 22 | 23 | # --------------- 24 | # - Constraints - 25 | # --------------- 26 | 27 | 28 | class _QuadLDConstraint(constraints.Constraint): 29 | event_dim = 1 30 | 31 | def tree_flatten(self): 32 | return (), ((), dict()) 33 | 34 | def __call__(self, u): 35 | assert jnp.shape(u) == (2,) 36 | a = u[0] + u[1] < 1.0 37 | b = u[0] > 0.0 38 | c = u[0] + 2 * u[1] > 0.0 39 | return a & b & c 40 | 41 | def feasible_like(self, prototype): 42 | assert jnp.shape(prototype)[-1] == 2 43 | return QuadLDParams.q2u(jnp.full_like(prototype, 0.5)) 44 | 45 | 46 | class _UnitDiskConstraint(constraints.Constraint): 47 | event_dim = 1 48 | 49 | def tree_flatten(self): 50 | return (), ((), dict()) 51 | 52 | def __call__(self, x): 53 | assert jnp.shape(x) == (2,) 54 | return x[0] ** 2 + x[1] ** 2 <= 1.0 55 | 56 | def feasible_like(self, prototype): 57 | assert jnp.shape(prototype)[-1] == 2 58 | return jnp.zeros_like(prototype) 59 | 60 | 61 | class _AngleConstraint(constraints._Interval): 62 | def __init__(self, regularized=10.0): 63 | self.regularized = regularized 64 | super().__init__(-jnp.pi, jnp.pi) 65 | 66 | def 
tree_flatten(self): 67 | return (self.regularized,), (("regularized",), dict()) 68 | 69 | 70 | quad_ld = _QuadLDConstraint() 71 | unit_disk = _UnitDiskConstraint() 72 | angle = _AngleConstraint() 73 | 74 | # -------------- 75 | # - Transforms - 76 | # -------------- 77 | 78 | 79 | class QuadLDTransform(transforms.Transform): 80 | domain = constraints.independent(constraints.unit_interval, 1) 81 | codomain = quad_ld 82 | 83 | def tree_flatten(self): 84 | return (), ((), dict()) 85 | 86 | def __eq__(self, other): 87 | return isinstance(other, QuadLDTransform) 88 | 89 | def __call__(self, q): 90 | return QuadLDParams.q2u(q) 91 | 92 | def _inverse(self, u): 93 | return QuadLDParams.u2q(u) 94 | 95 | def log_abs_det_jacobian(self, x, y, intermediates=None): 96 | del y, intermediates 97 | return jnp.zeros_like(x[..., 0]) 98 | 99 | 100 | class UnitDiskTransform(transforms.Transform): 101 | domain = constraints.independent(constraints.interval(-1.0, 1.0), 1) 102 | codomain = unit_disk 103 | 104 | def tree_flatten(self): 105 | return (), ((), dict()) 106 | 107 | def __eq__(self, other): 108 | return isinstance(other, UnitDiskTransform) 109 | 110 | def __call__(self, x): 111 | assert jnp.ndim(x) >= 1 and jnp.shape(x)[-1] == 2 112 | return jnp.stack( 113 | ( 114 | x[..., 0], 115 | x[..., 1] * jnp.sqrt(1 - jnp.clip(x[..., 0], -1, 1) ** 2), 116 | ), 117 | axis=-1, 118 | ) 119 | 120 | def _inverse(self, y): 121 | assert jnp.ndim(y) >= 1 and jnp.shape(y)[-1] == 2 122 | return jnp.stack( 123 | ( 124 | y[..., 0], 125 | y[..., 1] / jnp.sqrt(1 - jnp.clip(y[..., 0], -1.0, 1.0) ** 2), 126 | ), 127 | axis=-1, 128 | ) 129 | 130 | def log_abs_det_jacobian(self, x, y, intermediates=None): 131 | del y, intermediates 132 | return 0.5 * jnp.log(1 - jnp.clip(x[..., 0], -1.0, 1.0) ** 2) 133 | 134 | 135 | class Arctan2Transform(transforms.Transform): 136 | domain = constraints.real_vector 137 | codomain = angle 138 | 139 | def __init__(self, regularized=None): 140 | self.regularized = regularized 141 | 142 | def tree_flatten(self): 143 | return (self.regularized,), (("regularized",), dict()) 144 | 145 | def __eq__(self, other): 146 | return isinstance(other, Arctan2Transform) 147 | 148 | def __call__(self, x): 149 | assert jnp.ndim(x) >= 1 and jnp.shape(x)[-1] == 2 150 | return jnp.arctan2(x[..., 0], x[..., 1]) 151 | 152 | def _inverse(self, y): 153 | return jnp.stack((jnp.sin(y), jnp.cos(y)), axis=-1) 154 | 155 | def log_abs_det_jacobian(self, x, y, intermediates=None): 156 | del y, intermediates 157 | sm = jnp.sum(jnp.square(x), axis=-1) 158 | if self.regularized is None: 159 | return -0.5 * sm 160 | return self.regularized * jnp.log(sm) - 0.5 * sm 161 | 162 | def forward_shape(self, shape): 163 | return shape[:-1] 164 | 165 | def inverse_shape(self, shape): 166 | return shape + (2,) 167 | 168 | 169 | @transforms.biject_to.register(quad_ld) 170 | def _(constraint): 171 | del constraint 172 | return transforms.ComposeTransform( 173 | [ 174 | transforms.SigmoidTransform(), 175 | QuadLDTransform(), 176 | ] 177 | ) 178 | 179 | 180 | @transforms.biject_to.register(unit_disk) 181 | def _(constraint): 182 | del constraint 183 | return transforms.ComposeTransform( 184 | [ 185 | transforms.SigmoidTransform(), 186 | transforms.AffineTransform(-1.0, 2.0, domain=constraints.unit_interval), 187 | UnitDiskTransform(), 188 | ] 189 | ) 190 | 191 | 192 | @transforms.biject_to.register(angle) 193 | def _(constraint): 194 | return Arctan2Transform(regularized=constraint.regularized) 195 | 196 | 197 | # ----------------- 198 | # - 
Distributions - 199 | # ----------------- 200 | 201 | 202 | class QuadLDParams(dist.Distribution): 203 | """An uninformative prior for quadratic limb darkening parameters 204 | 205 | This is an implementation of the `Kipping (2013) 206 | `_ re-parameterization of the two-parameter 207 | limb darkening model to allow for efficient and uninformative sampling. 208 | """ 209 | 210 | support = quad_ld 211 | 212 | def __init__(self, *, validate_args=None): 213 | super().__init__(batch_shape=(), event_shape=(2,), validate_args=validate_args) 214 | 215 | def sample(self, key, sample_shape=()): 216 | assert is_prng_key(key) 217 | q = jax.random.uniform(key, shape=sample_shape + (2,), minval=0, maxval=1) 218 | return QuadLDParams.q2u(q) 219 | 220 | @validate_sample 221 | def log_prob(self, value): 222 | return jnp.zeros_like(value[..., 0]) 223 | 224 | @property 225 | def mean(self): 226 | return jnp.array([2.0 / 3.0, 0.0]) 227 | 228 | @property 229 | def variance(self): 230 | return jnp.array([2.0 / 9.0, 1.0 / 6.0]) 231 | 232 | @staticmethod 233 | def q2u(q): 234 | assert jnp.ndim(q) >= 1 and jnp.shape(q)[-1] == 2 235 | q1 = jnp.sqrt(q[..., 0]) 236 | q2 = 2 * q[..., 1] 237 | u1 = q1 * q2 238 | u2 = q1 * (1 - q2) 239 | return jnp.stack((u1, u2), axis=-1) 240 | 241 | @staticmethod 242 | def u2q(u): 243 | assert jnp.ndim(u) >= 1 and jnp.shape(u)[-1] == 2 244 | u1 = u[..., 0] 245 | u2 = u1 + u[..., 1] 246 | q1 = jnp.square(u2) 247 | q2 = 0.5 * u1 / u2 248 | return jnp.stack((q1, q2), axis=-1) 249 | 250 | 251 | class UnitDisk(dist.Distribution): 252 | """Two dimensional parameters constrained to live within the unit disk""" 253 | 254 | support = unit_disk 255 | 256 | def __init__(self, *, validate_args=None): 257 | super().__init__(batch_shape=(), event_shape=(2,), validate_args=validate_args) 258 | 259 | def sample(self, key, sample_shape=()): 260 | assert is_prng_key(key) 261 | key1, key2 = jax.random.split(key) 262 | theta = jax.random.uniform( 263 | key1, shape=sample_shape, minval=-jnp.pi, maxval=jnp.pi 264 | ) 265 | r = jnp.sqrt( 266 | jax.random.uniform(key2, shape=sample_shape, minval=0.0, maxval=1.0) 267 | ) 268 | return jnp.stack((r * jnp.cos(theta), r * jnp.sin(theta)), axis=-1) 269 | 270 | @validate_sample 271 | def log_prob(self, value): 272 | del value 273 | return -jnp.log(jnp.pi) 274 | 275 | @property 276 | def mean(self): 277 | return jnp.array([0.0, 0.0]) 278 | 279 | @property 280 | def variance(self): 281 | return jnp.array([0.25, 0.25]) 282 | 283 | 284 | class Angle(dist.Distribution): 285 | """An angle constrained to be in the range -pi to pi 286 | 287 | The actual sampling is performed in the two dimensional vector space 288 | proportional to ``(sin(theta), cos(theta))`` so that the sampler doesn't see 289 | a discontinuity at pi. 290 | 291 | The ``regularized`` parameter can be used to improve sampling performance 292 | when the value of the angle is well constrained. It removes prior mass near 293 | the origin in the sampling space, which can lead to bad geometry when the 294 | angle is poorly constrained, but better performance when it is. The default 295 | value of ``10.0`` is a good starting point. 
296 | """ 297 | 298 | def __init__(self, *, regularized=10.0, validate_args=None): 299 | self.regularized = regularized 300 | super().__init__(batch_shape=(), event_shape=(), validate_args=validate_args) 301 | 302 | @constraints.dependent_property 303 | def support(self): 304 | return _AngleConstraint(self.regularized) 305 | 306 | def sample(self, key, sample_shape=()): 307 | assert is_prng_key(key) 308 | return jax.random.uniform( 309 | key, shape=sample_shape, minval=-jnp.pi, maxval=jnp.pi 310 | ) 311 | 312 | @validate_sample 313 | def log_prob(self, value): 314 | del value 315 | return -jnp.log(jnp.pi) 316 | 317 | @property 318 | def mean(self): 319 | return 0.0 320 | 321 | @property 322 | def variance(self): 323 | return jnp.pi**2 / 12.0 324 | 325 | def cdf(self, value): 326 | cdf = (value + 0.5 * jnp.pi) / jnp.pi 327 | return jnp.clip(cdf, a_min=0.0, a_max=1.0) 328 | 329 | def icdf(self, value): 330 | return (value - 0.5) * jnp.pi 331 | 332 | 333 | class NoncentralChi2(dist.Distribution): 334 | arg_constraints = { 335 | "df": constraints.positive, 336 | "nc": constraints.positive, 337 | } 338 | support = constraints.positive 339 | reparametrized_params = ["df", "nc"] 340 | 341 | def __init__(self, df, nc, validate_args=None): 342 | self.df, self.nc = promote_shapes(df, nc) 343 | batch_shape = lax.broadcast_shapes(jnp.shape(df), jnp.shape(nc)) 344 | super(NoncentralChi2, self).__init__( 345 | batch_shape=batch_shape, validate_args=validate_args 346 | ) 347 | 348 | def sample(self, key, sample_shape=()): 349 | # Ref: https://github.com/numpy/numpy/blob/ 350 | # 89c80ba606f4346f8df2a31cfcc0e967045a68ed/numpy/ 351 | # random/src/distributions/distributions.c#L797-L813 352 | 353 | def _random_chi2(key, df, shape=(), dtype=jnp.float_): 354 | return 2.0 * jax.random.gamma(key, 0.5 * df, shape=shape, dtype=dtype) 355 | 356 | assert is_prng_key(key) 357 | shape = sample_shape + self.batch_shape + self.event_shape 358 | 359 | key1, key2, key3 = jax.random.split(key, 3) 360 | i = jax.random.poisson(key1, 0.5 * self.nc, shape=shape) 361 | n = jax.random.normal(key2, shape=shape) + jnp.sqrt(self.nc) 362 | cond = jnp.greater(self.df, 1.0) 363 | chi2 = _random_chi2( 364 | key3, 365 | jnp.where(cond, self.df - 1.0, self.df + 2.0 * i), 366 | shape=shape, 367 | ) 368 | return jnp.where(cond, chi2 + n * n, chi2) 369 | 370 | @validate_sample 371 | def log_prob(self, value): 372 | try: 373 | import tensorflow_probability.substrates.jax as tfp 374 | except ImportError as e: 375 | raise ImportError( 376 | "tensorflow-probability is must be installed to use the " 377 | "NoncentralChi2 distribution." 
378 | ) from e 379 | 380 | # Ref: https://github.com/scipy/scipy/blob/ 381 | # 500878e88eacddc7edba93dda7d9ee5f784e50e6/scipy/ 382 | # stats/_distn_infrastructure.py#L597-L610 383 | df2 = self.df / 2.0 - 1.0 384 | xs, ns = jnp.sqrt(value), jnp.sqrt(self.nc) 385 | res = jsp.special.xlogy(df2 / 2.0, value / self.nc) - 0.5 * (xs - ns) ** 2 386 | corr = tfp.math.bessel_ive(df2, xs * ns) / 2.0 387 | return jnp.where( 388 | jnp.greater(corr, 0.0), 389 | res + jnp.log(corr), 390 | -jnp.inf, 391 | ) 392 | 393 | @property 394 | def mean(self): 395 | return self.df + self.nc 396 | 397 | @property 398 | def variance(self): 399 | return 2.0 * (self.df + 2.0 * self.nc) 400 | 401 | 402 | class MarginalizedLinear(dist.Distribution): 403 | arg_constraints = {"design_matrix": constraints.real} 404 | support = constraints.real_vector 405 | reparametrized_params = ["design_matrix"] 406 | 407 | def __init__( 408 | self, 409 | design_matrix, 410 | prior_distribution, 411 | data_distribution, 412 | *, 413 | validate_args=None, 414 | ): 415 | # We treat the trailing dimensions of the design matrix as "ground 416 | # truth" for the dimensions of the problem. 417 | if jnp.ndim(design_matrix) < 2: 418 | raise ValueError("The design matrix must have at least 2 dimensions") 419 | data_size, latent_size = jnp.shape(design_matrix)[-2:] 420 | 421 | # We don't really care about the batch vs. event shapes of the input 422 | # distributions, so instead we just check that the trailing dimensions 423 | # are correct. 424 | prior_dist_shape = tuple(prior_distribution.batch_shape) + tuple( 425 | prior_distribution.event_shape 426 | ) 427 | if len(prior_dist_shape) != 0 and prior_dist_shape[-1] != latent_size: 428 | raise ValueError( 429 | "The trailing dimensions of the prior distribution must match " 430 | "the latent dimension defined by the design matrix; expected " 431 | f"(..., {latent_size}), got {prior_dist_shape}" 432 | ) 433 | 434 | data_dist_shape = tuple(data_distribution.batch_shape) + tuple( 435 | data_distribution.event_shape 436 | ) 437 | if len(data_dist_shape) != 0 and data_dist_shape[-1] != data_size: 438 | raise ValueError( 439 | "The trailing dimensions of the data distribution must match " 440 | "the data dimension defined by the design matrix; expected " 441 | f"(..., {data_size}), got {data_dist_shape}" 442 | ) 443 | 444 | # We broadcast the relevant batch shapes to find the batch shape of this 445 | # distribution, and expand or reshape all the members to match. 
446 | batch_shape = lax.broadcast_shapes( 447 | design_matrix.shape[:-2], 448 | prior_dist_shape[:-1], 449 | data_dist_shape[:-1], 450 | ) 451 | event_shape = (data_size,) 452 | self.design_matrix = jnp.broadcast_to( 453 | design_matrix, batch_shape + (data_size, latent_size) 454 | ) 455 | 456 | if prior_distribution.event_shape == (): 457 | self.prior_distribution = prior_distribution.expand( 458 | batch_shape + (latent_size,) 459 | ) 460 | else: 461 | self.prior_distribution = prior_distribution.expand(batch_shape) 462 | 463 | if data_distribution.event_shape == (): 464 | self.data_distribution = data_distribution.expand( 465 | batch_shape + (data_size,) 466 | ) 467 | else: 468 | self.data_distribution = data_distribution.expand(batch_shape) 469 | 470 | super().__init__( 471 | batch_shape=batch_shape, 472 | event_shape=event_shape, 473 | validate_args=validate_args, 474 | ) 475 | 476 | # Convert the distributions to linear ops 477 | self.prior_linear_op = to_linear_op(self.prior_distribution) 478 | self.data_linear_op = to_linear_op(self.data_distribution) 479 | 480 | # This inner matrix is used for both the matrix determinant and inverse 481 | self.projected_design_matrix = self.data_linear_op.solve_tril( 482 | self.design_matrix, False 483 | ) 484 | self.conditional_precision_matrix = ( 485 | self.prior_linear_op.inverse() 486 | + _inner_product(self.projected_design_matrix, self.projected_design_matrix) 487 | ) 488 | self.conditional_inv_tril, _ = cho_factor( 489 | self.conditional_precision_matrix, lower=True 490 | ) 491 | 492 | def sample_with_intermediates(self, key, sample_shape=()): 493 | assert is_prng_key(key) 494 | prior_key, data_key = jax.random.split(key) 495 | prior_sample = self.prior_distribution.sample( 496 | prior_key, sample_shape=sample_shape 497 | ) 498 | data_sample = self.data_distribution.sample(data_key, sample_shape=sample_shape) 499 | delta = jnp.einsum("...ij,...j->...i", self.design_matrix, prior_sample) 500 | return data_sample + delta, [prior_sample] 501 | 502 | def sample(self, key, sample_shape=()): 503 | return self.sample_with_intermediates(key, sample_shape)[0] 504 | 505 | def _get_alpha(self, value): 506 | data_size = jnp.shape(self.design_matrix)[-2] 507 | assert jnp.shape(value)[-1] == data_size 508 | return self.data_linear_op.solve_tril((value - self.mean)[..., None], False) 509 | 510 | @validate_sample 511 | def log_prob(self, value): 512 | data_size = jnp.shape(self.design_matrix)[-2] 513 | 514 | # Use the matrix determinant lemma to compute the full determinant 515 | hld = jnp.sum( 516 | jnp.log(jnp.diagonal(self.conditional_inv_tril, axis1=-2, axis2=-1)), -1 517 | ) 518 | norm = ( 519 | hld 520 | + self.prior_linear_op.half_log_det() 521 | + self.data_linear_op.half_log_det() 522 | ) 523 | 524 | # Use the Woodbury matrix identity to solve the linear system 525 | alpha = self._get_alpha(value) 526 | result = _inner_product(alpha, alpha) 527 | alpha = _inner_product(self.projected_design_matrix, alpha) 528 | result -= _inner_product( 529 | alpha, cho_solve((self.conditional_inv_tril, True), alpha) 530 | ) 531 | log_prob = ( 532 | -0.5 * result[..., 0, 0] - norm - 0.5 * data_size * jnp.log(2 * jnp.pi) 533 | ) 534 | 535 | return log_prob 536 | 537 | def conditional(self, value=None): 538 | if value is None: 539 | return self.prior_distribution 540 | 541 | # TODO(dfm): The following two lines are also used in `log_prob`. 
542 | # They're probably not the bottleneck, but it is interesting to think 543 | # about how we could avoid the duplication, since we typically want both 544 | # distributions for the same `value`. 545 | alpha = self._get_alpha(value) 546 | alpha = _inner_product(self.projected_design_matrix, alpha) 547 | a = self.prior_linear_op.loc()[..., None] 548 | a = self.prior_linear_op.solve_tril(a, False) 549 | a = self.prior_linear_op.solve_tril(a, True) 550 | a = cho_solve((self.conditional_inv_tril, True), (a + alpha)[..., 0]) 551 | 552 | return dist.MultivariateNormal( 553 | loc=a, precision_matrix=self.conditional_precision_matrix 554 | ) 555 | 556 | def tree_flatten(self): 557 | prior_flat, prior_aux = self.prior_distribution.tree_flatten() 558 | data_flat, data_aux = self.data_distribution.tree_flatten() 559 | return (self.design_matrix, prior_flat, data_flat), ( 560 | type(self.prior_distribution), 561 | prior_aux, 562 | type(self.data_distribution), 563 | data_aux, 564 | ) 565 | 566 | @classmethod 567 | def tree_unflatten(cls, aux_data, params): 568 | design_matrix, prior_flat, data_flat = params 569 | prior_dist = aux_data[0].tree_unflatten(aux_data[1], prior_flat) 570 | data_dist = aux_data[2].tree_unflatten(aux_data[3], data_flat) 571 | return cls( 572 | design_matrix=design_matrix, 573 | prior_distribution=prior_dist, 574 | data_distribution=data_dist, 575 | ) 576 | 577 | @property 578 | def mean(self): 579 | mu = self.prior_distribution.mean[..., None] 580 | mu = jnp.broadcast_to(mu, self.batch_shape + mu.shape[-2:]) 581 | return self.data_distribution.mean + (self.design_matrix @ mu)[..., 0] 582 | 583 | @property 584 | def covariance_matrix(self): 585 | prior = to_linear_op(self.prior_distribution) 586 | data = to_linear_op(self.data_distribution) 587 | return data.cov() + self.design_matrix @ prior.cov() @ jnp.swapaxes( 588 | self.design_matrix, -2, -1 589 | ) 590 | 591 | 592 | def _inner_product(a, b): 593 | aT = jnp.swapaxes(a, -2, -1) 594 | return aT @ b 595 | -------------------------------------------------------------------------------- /src/numpyro_ext/infer.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from numpyro import handlers, infer 3 | 4 | 5 | def prior_sample(model, num_samples): 6 | pred = infer.Predictive(model, num_samples=num_samples) 7 | 8 | def sample(rng_key, *args, **kwargs): 9 | # Generate samples from the prior 10 | samples = pred(rng_key, *args, **kwargs) 11 | 12 | # The log likelihood function for a single sample, which we will vmap 13 | # over the prior samples. Note that we could potentially also use 14 | # numpyro.infer.util.log_likelihood, but that seems to fail when the 15 | # model includes "factor" nodes, so here we just implement it ourselves.
16 |         def log_like_fn(sample):
17 |             trace = handlers.trace(handlers.substitute(model, sample)).get_trace(
18 |                 *args, **kwargs
19 |             )
20 |             result = 0.0
21 |             for site in trace.values():
22 |                 if site["type"] == "sample" and site["is_observed"]:
23 |                     result += site["fn"].log_prob(site["value"]).sum()
24 |             return result
25 | 
26 |         log_like = jax.vmap(log_like_fn)(samples)
27 |         return samples, log_like
28 | 
29 |     return sample
30 | 
--------------------------------------------------------------------------------
/src/numpyro_ext/info.py:
--------------------------------------------------------------------------------
1 | __all__ = ["information"]
2 | 
3 | from functools import partial
4 | 
5 | import jax
6 | import jax.numpy as jnp
7 | from jax.flatten_util import ravel_pytree
8 | from numpyro import handlers, infer
9 | 
10 | from numpyro_ext.linear_op import to_linear_op
11 | 
12 | 
13 | def standardize(d):
14 |     op = to_linear_op(d)
15 |     return op.solve_tril(op.loc()[..., None], False)[..., 0]
16 | 
17 | 
18 | def _is_conditioned(site):
19 |     return site["is_observed"] and not site["infer"].get("is_auxiliary", False)
20 | 
21 | 
22 | def _information_and_log_prior_hessian(
23 |     model,
24 |     params,
25 |     model_args=(),
26 |     model_kwargs=None,
27 |     invert=False,
28 |     include_prior=False,
29 |     unconstrained=False,
30 | ):
31 |     model_kwargs = {} if model_kwargs is None else model_kwargs
32 | 
33 |     # Determine which parameters are sampled
34 |     trace = handlers.trace(handlers.substitute(model, data=params)).get_trace(
35 |         *model_args, **model_kwargs
36 |     )
37 |     base_params = {}
38 |     for site in trace.values():
39 |         if site["type"] != "sample" or site["is_observed"]:
40 |             continue
41 |         if site["name"] not in params:
42 |             raise KeyError(f"Input params is missing the site called '{site['name']}'")
43 |         base_params[site["name"]] = params[site["name"]]
44 | 
45 |     # This function computes the terms of the likelihood that we will
46 |     # differentiate to get the information matrix, and the log prior
47 |     def impl(params, unravel, model, model_args, model_kwargs):
48 |         params = unravel(params)
49 |         if unconstrained:
50 |             substituted_model = handlers.substitute(
51 |                 model,
52 |                 substitute_fn=partial(infer.util._unconstrain_reparam, params),
53 |             )
54 |         else:
55 |             substituted_model = handlers.substitute(model, data=params)
56 | 
57 |         trace = handlers.trace(substituted_model).get_trace(*model_args, **model_kwargs)
58 | 
59 |         info_terms = []
60 |         log_prior = jnp.zeros(())
61 |         for site in trace.values():
62 |             if site["type"] != "sample":
63 |                 continue
64 | 
65 |             # If a site is observed, we need to include it in the information
66 |             # computation, but some sites will be labeled as observed when
67 |             # they're actually the Jacobians of transforms. In these cases, we
68 |             # want to include them in the prior instead.
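            # Sites flagged with ``is_auxiliary`` (e.g. the base sites
            # introduced by reparameterization) are not treated as conditioned,
            # so their contribution lands in the log prior branch below.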
69 |             if _is_conditioned(site):
70 |                 info_terms.append(standardize(site["fn"]))
71 |             else:
72 |                 log_prior += jnp.sum(site["fn"].log_prob(site["value"]))
73 | 
74 |         return tuple(info_terms) or (0.0,), log_prior
75 | 
76 |     flat_params, unravel = ravel_pytree(base_params)
77 | 
78 |     # Compute the Jacobian of the model to evaluate the information matrix
79 |     Js = jax.jacobian(lambda *args: impl(*args)[0])(
80 |         flat_params, unravel, model, model_args, model_kwargs
81 |     )
82 |     F = jnp.zeros((flat_params.shape[0], flat_params.shape[0]))
83 |     for J in Js:
84 |         F += jnp.einsum("...n,...m->nm", J, J)
85 | 
86 |     # Compute the Hessian of the log prior function
87 |     if include_prior:
88 |         H = jax.hessian(lambda *args: impl(*args)[1])(
89 |             flat_params, unravel, model, model_args, model_kwargs
90 |         )
91 | 
92 |         # Combine the two
93 |         F = F - H
94 | 
95 |     if invert:
96 |         F = jnp.linalg.inv(F)
97 | 
98 |     def unravel_batched(row):
99 |         if jnp.ndim(row) == 1:
100 |             return unravel(row)
101 |         func = unravel
102 |         for n in range(1, jnp.ndim(row)):
103 |             func = jax.vmap(func, in_axes=(n,))
104 |         return func(row)
105 | 
106 |     return jax.tree_util.tree_map(unravel_batched, jax.vmap(unravel)(F))
107 | 
108 | 
109 | def information(model, invert=False, include_prior=False, unconstrained=False):
110 |     """Compute the Fisher information matrix for a NumPyro model
111 | 
112 |     Note that this only supports a limited set of observation sites. By default,
113 |     observed sites must be ``Normal`` or ``MultivariateNormal`` distributions,
114 |     but custom distributions can be supported by registering them with the
115 |     ``numpyro_ext.linear_op.to_linear_op`` singledispatch function, which is
116 |     used under the hood by ``numpyro_ext.info.standardize``.
117 | 
118 |     Args:
119 |         model: The NumPyro model definition.
120 |         invert: If ``True``, the inverse information matrix will be returned.
121 |         include_prior: If ``True``, the Hessian of the log prior will be
122 |             subtracted from the information matrix.
123 |         unconstrained: If ``True``, the parameters are assumed to be in the
124 |             unconstrained space and the information is computed in that space.
125 | 
126 |     Returns:
127 |         A callable with the signature ``def info(params, *args, **kwargs)`` that
128 |         computes the information matrix, where ``params`` is a dictionary of
129 |         the parameter values at which the information will be evaluated, and
130 |         the other arguments are the static arguments for ``model``.
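
    Example (an illustrative sketch mirroring ``tests/test_info.py``; ``x``
    and ``y`` are assumed to be data arrays):

    .. code-block:: python

        def model(x, y=None):
            A = jnp.vander(x, 2)
            w = numpyro.sample("w", dist.Normal(0.0, 1.0).expand([2]))
            numpyro.sample("y", dist.Normal(A @ w, 1.0), obs=y)

        F = information(model)({"w": jnp.zeros(2)}, x, y=y)
        # F["w"]["w"] is the (2, 2) Fisher information block for "w"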
131 | 132 | """ 133 | 134 | return lambda params, *args, **kwargs: _information_and_log_prior_hessian( 135 | model, 136 | params, 137 | model_args=args, 138 | model_kwargs=kwargs, 139 | invert=invert, 140 | include_prior=include_prior, 141 | unconstrained=unconstrained, 142 | ) 143 | -------------------------------------------------------------------------------- /src/numpyro_ext/linear_op.py: -------------------------------------------------------------------------------- 1 | from functools import singledispatch 2 | from typing import Any, Callable, NamedTuple 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | from jax import lax 7 | from jax.scipy.linalg import cho_solve, solve_triangular 8 | from numpyro.distributions import ExpandedDistribution, MultivariateNormal, Normal 9 | 10 | if hasattr(jax, "Array"): 11 | Array = jax.Array 12 | else: 13 | Array = Any 14 | 15 | 16 | class LinearOp(NamedTuple): 17 | loc: Callable[[], Array] 18 | covariance: Callable[[], Array] 19 | inverse: Callable[[], Array] 20 | solve_tril: Callable[[Array, bool], Array] 21 | half_log_det: Callable[[], Array] 22 | 23 | 24 | @singledispatch 25 | def to_linear_op(dist) -> LinearOp: 26 | raise ValueError(f"{type(dist)} doesn't support the 'to_linear_op' interface") 27 | 28 | 29 | @to_linear_op.register(Normal) 30 | def _(dist): 31 | scale = jnp.broadcast_to(dist.scale, dist.shape()) 32 | 33 | def loc(): 34 | return jnp.broadcast_to(dist.loc, dist.shape()) 35 | 36 | def covariance(): 37 | return jnp.vectorize(jnp.diag, signature="(n)->(n,n)")( 38 | jnp.square(jnp.atleast_1d(scale)) 39 | ) 40 | 41 | def inverse(): 42 | return jnp.vectorize(jnp.diag, signature="(n)->(n,n)")( 43 | 1.0 / jnp.square(jnp.atleast_1d(scale)) 44 | ) 45 | 46 | def solve_tril(y, transpose): 47 | del transpose 48 | return y / jnp.atleast_1d(scale)[..., None] 49 | 50 | def half_log_det(): 51 | return jnp.sum(jnp.log(jnp.atleast_1d(scale)), axis=-1) 52 | 53 | return LinearOp(loc, covariance, inverse, solve_tril, half_log_det) 54 | 55 | 56 | @to_linear_op.register(MultivariateNormal) 57 | def _(dist): 58 | def loc(): 59 | return jnp.broadcast_to(dist.loc, dist.shape()) 60 | 61 | def covariance(): 62 | return dist.covariance_matrix 63 | 64 | def inverse(): 65 | y = jnp.broadcast_to( 66 | jnp.eye(dist.scale_tril.shape[-1]), dist.covariance_matrix.shape 67 | ) 68 | return cho_solve((dist.scale_tril, True), y) 69 | 70 | def solve_tril(y, transpose): 71 | return solve_triangular(dist.scale_tril, y, trans=transpose, lower=True) 72 | 73 | def half_log_det(): 74 | return jnp.sum(jnp.log(jnp.diagonal(dist.scale_tril, axis1=-2, axis2=-1)), -1) 75 | 76 | return LinearOp(loc, covariance, inverse, solve_tril, half_log_det) 77 | 78 | 79 | @to_linear_op.register(ExpandedDistribution) 80 | def _(dist): 81 | ( 82 | base_loc, 83 | base_covariance, 84 | base_inverse, 85 | base_solve_tril, 86 | base_half_log_det, 87 | ) = to_linear_op(dist.base_dist) 88 | shape = dist.batch_shape + dist.event_shape 89 | batch_shape = shape[:-1] 90 | event_shape = shape[-1:] 91 | 92 | def loc(): 93 | mu = base_loc() 94 | return jnp.broadcast_to(mu, shape) 95 | 96 | def covariance(): 97 | cov = base_covariance() 98 | if cov.shape[-1:] != event_shape: 99 | assert cov.shape[-1] == 1 100 | cov = jnp.eye(event_shape[0]) * cov 101 | return jnp.broadcast_to(cov, batch_shape + event_shape + event_shape) 102 | 103 | def inverse(): 104 | inv = base_inverse() 105 | if inv.shape[-1:] != event_shape: 106 | assert inv.shape[-1] == 1 107 | inv = jnp.eye(event_shape[0]) * inv 108 | return 
jnp.broadcast_to(inv, batch_shape + event_shape + event_shape)
109 | 
110 |     def solve_tril(y, transpose):
111 |         if jnp.ndim(y) < 2:
112 |             raise ValueError(
113 |                 "An expanded linear operator's inverse is only defined for matrices"
114 |             )
115 | 
116 |         shape = lax.broadcast_shapes(
117 |             batch_shape,
118 |             jnp.shape(y)[: max(jnp.ndim(y) - 2, 0)],
119 |         )
120 |         alpha = jnp.vectorize(
121 |             lambda x: base_solve_tril(x, transpose), signature="(m,k)->(m,k)"
122 |         )(y)
123 |         return jnp.broadcast_to(alpha, shape + y.shape[-2:])
124 | 
125 |     def half_log_det():
126 |         hld = base_half_log_det()
127 |         # Special case for scalar base distribution
128 |         if dist.base_dist.shape() == ():
129 |             hld *= event_shape[0]
130 |         return jnp.broadcast_to(hld, batch_shape)
131 | 
132 |     return LinearOp(loc, covariance, inverse, solve_tril, half_log_det)
133 | 
--------------------------------------------------------------------------------
/src/numpyro_ext/optim.py:
--------------------------------------------------------------------------------
1 | __all__ = ["optimize", "JAXOptMinimize"]
2 | 
3 | from contextlib import ExitStack
4 | 
5 | import jax
6 | import jax.numpy as jnp
7 | import numpyro
8 | from jax.tree_util import tree_map
9 | from numpyro import distributions as dist
10 | from numpyro import infer
11 | from numpyro.infer.initialization import init_to_median, init_to_value
12 | from numpyro.optim import _NumPyroOptim
13 | 
14 | 
15 | def optimize(
16 |     model,
17 |     sites=None,
18 |     start=None,
19 |     *,
20 |     init_strategy=None,
21 |     optimizer=None,
22 |     num_steps=1,
23 |     include_deterministics=True,
24 |     return_info=False,
25 | ):
26 |     """Numerically maximize the log probability of a NumPyro model
27 | 
28 |     The main feature of this interface is that it supports optimizing only a
29 |     subset of the parameters in an automated fashion, something which can be
30 |     tricky with the built-in NumPyro functions.
31 | 
32 |     Example:
33 | 
34 |     .. code-block:: python
35 | 
36 |         def model(x, yerr, y=None):
37 |             A = jnp.vander(x, 2)
38 |             w = numpyro.sample("w", dist.Normal(0.0, 1.0).expand([2]))
39 |             mu = numpyro.deterministic("mu", A @ w)
40 |             numpyro.sample("y", dist.Normal(mu, yerr), obs=y)
41 | 
42 |         run_optim = optim.optimize(model)
43 |         param = run_optim(jax.random.PRNGKey(0), x, yerr, y=y)
44 | 
45 |     Args:
46 |         model: The NumPyro model definition.
47 |         sites: A list of the site names to vary, keeping the others fixed. By
48 |             default, all parameters are varied.
49 |         start: A dictionary of initial site values keyed by site name. For
50 |             sites not included in ``sites``, this will be the fixed value used
51 |             for that site.
52 |         init_strategy: If ``start`` is provided, this will be ignored.
53 |             Otherwise, this specifies the initial values for the sites in the
54 |             optimization. By default, this takes the value ``init_to_median``.
55 |         optimizer: A NumPyro optimizer object to use as the optimization engine.
56 |             By default, this uses a ``JAXOptMinimize`` optimizer.
57 |         num_steps: The number of optimization steps to run. This defaults to
58 |             1 because the default ``JAXOptMinimize`` optimizer only requires a
59 |             single step.
60 |         include_deterministics: If ``True``, return the values of the
61 |             deterministics computed at the optimized parameters, in addition to
62 |             the parameter values.
63 |         return_info: If ``True``, the returned function will return a tuple with
64 |             the parameters as the first element, and scipy's minimization status
65 |             as the second element.
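
    For example, to vary only ``w`` in a model that also has a second site
    (here a hypothetical ``logs``), keeping that site fixed at a chosen value:

    .. code-block:: python

        run_optim = optim.optimize(model, sites=["w"], start={"logs": 0.0})
        param = run_optim(jax.random.PRNGKey(0), x, yerr, y=y)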
66 | 
67 |     Returns:
68 |         A callable that will execute the optimization routine, with the
69 |         signature ``run(random_key, *args, **kwargs)`` where ``random_key`` is a
70 |         ``jax.random.PRNGKey``, and ``*args`` and ``**kwargs`` are the static
71 |         arguments for ``model``.
72 |     """
73 |     if start is not None:
74 |         init_strategy = init_to_value(values=start)
75 |     elif init_strategy is None:
76 |         init_strategy = init_to_median()
77 | 
78 |     optimizer = JAXOptMinimize() if optimizer is None else optimizer
79 |     guide = AutoDelta(model, sites=sites, init_loc_fn=init_strategy)
80 |     svi = infer.SVI(model, guide, optimizer, loss=infer.Trace_ELBO())
81 | 
82 |     def run(rng_key, *args, **kwargs):
83 |         init_key, sample_key, pred_key = jax.random.split(rng_key, 3)
84 |         state = svi.init(init_key, *args, **kwargs)
85 |         for _ in range(num_steps):
86 |             state, _ = svi.update(state, *args, **kwargs)
87 |         info = getattr(state.optim_state[1], "state", None)
88 |         params = svi.get_params(state)
89 |         sample = guide.sample_posterior(sample_key, params)
90 |         if include_deterministics:
91 |             pred = tree_map(
92 |                 lambda x: x[0],
93 |                 infer.Predictive(model, tree_map(lambda x: x[None], sample))(
94 |                     pred_key, *args, **kwargs
95 |                 ),
96 |             )
97 |             sample = dict(sample, **pred)
98 |         if return_info:
99 |             return sample, info
100 |         return sample
101 | 
102 |     return run
103 | 
104 | 
105 | def _jaxopt_wrapper():
106 |     def init_fn(params):
107 |         from jaxopt._src.scipy_wrappers import ScipyMinimizeInfo
108 |         from jaxopt.base import OptStep
109 | 
110 |         return OptStep(
111 |             params=params,
112 |             state=ScipyMinimizeInfo(
113 |                 fun_val=jnp.zeros(()),
114 |                 success=False,
115 |                 status=0,
116 |                 iter_num=0,
117 |                 hess_inv=None,
118 |             ),
119 |         )
120 | 
121 |     def update_fn(i, grad_tree, opt_state):
122 |         return opt_state
123 | 
124 |     def get_params_fn(opt_state):
125 |         return opt_state.params
126 | 
127 |     return init_fn, update_fn, get_params_fn
128 | 
129 | 
130 | class JAXOptMinimize(_NumPyroOptim):
131 |     """A NumPyro-compatible optimizer built using jaxopt.ScipyMinimize
132 | 
133 |     This exposes the ``ScipyMinimize`` optimizer from ``jaxopt`` to NumPyro. All
134 |     keyword arguments are passed directly to ``jaxopt.ScipyMinimize``.
135 |     """
136 | 
137 |     def __init__(self, **kwargs):
138 |         try:
139 |             import jaxopt  # noqa: F401
140 |         except ImportError as e:
141 |             raise ImportError("jaxopt must be installed to use JAXOptMinimize") from e
142 | 
143 |         super().__init__(_jaxopt_wrapper)
144 |         self.solver_kwargs = kwargs
145 | 
146 |     def eval_and_update(self, fn, in_state, forward_mode_differentiation=False):
147 |         import scipy.optimize  # noqa
148 |         from jaxopt import ScipyMinimize
149 | 
150 |         def loss(p):
151 |             out, aux = fn(p)
152 |             if aux is not None:
153 |                 raise ValueError(
154 |                     "JAXOptMinimize does not support models with mutable states."
155 |                 )
156 |             return out
157 | 
158 |         if forward_mode_differentiation:
159 |             raise ValueError(
160 |                 "Forward mode differentiation is not implemented for JAXOptMinimize"
161 |             )
162 | 
163 |         solver = ScipyMinimize(fun=loss, **self.solver_kwargs)
164 |         out_state = solver.run(self.get_params(in_state))
165 |         return (out_state.state.fun_val, None), (in_state[0] + 1, out_state)
166 | 
167 | 
168 | class AutoDelta(infer.autoguide.AutoDelta):
169 |     """A MAP autoguide with support for keeping some sites fixed
170 | 
171 |     This is an extension of ``numpyro.infer.autoguide.AutoDelta`` that adds
172 |     support for only varying a subset of sites.
All arguments except for 173 | ``sites`` are passed directly to the upstream implementation. 174 | 175 | Args: 176 | sites: A list of the site names to vary, keeping the others fixed. By 177 | default, all parameters are varied. 178 | """ 179 | 180 | def __init__( 181 | self, 182 | model, 183 | sites=None, 184 | *, 185 | prefix="auto", 186 | init_loc_fn=infer.init_to_median, 187 | create_plates=None, 188 | ): 189 | self._sites = sites 190 | super().__init__( 191 | model, 192 | prefix=prefix, 193 | init_loc_fn=init_loc_fn, 194 | create_plates=create_plates, 195 | ) 196 | 197 | def __call__(self, *args, **kwargs): 198 | if self.prototype_trace is None: 199 | self._setup_prototype(*args, **kwargs) 200 | 201 | plates = self._create_plates(*args, **kwargs) 202 | result = {} 203 | for name, site in self.prototype_trace.items(): 204 | if site["type"] != "sample" or site["is_observed"]: 205 | continue 206 | 207 | event_dim = self._event_dims[name] 208 | init_loc = self._init_locs[name] 209 | with ExitStack() as stack: 210 | for frame in site["cond_indep_stack"]: 211 | stack.enter_context(plates[frame.name]) 212 | 213 | if self._sites is None or name in self._sites: 214 | site_loc = numpyro.param( 215 | "{}_{}_loc".format(name, self.prefix), 216 | init_loc, 217 | constraint=site["fn"].support, 218 | event_dim=event_dim, 219 | ) 220 | 221 | site_fn = dist.Delta(site_loc).to_event(event_dim) 222 | 223 | else: 224 | site_fn = dist.Delta(self._init_locs[name]).to_event(event_dim) 225 | 226 | result[name] = numpyro.sample(name, site_fn) 227 | 228 | return result 229 | 230 | def sample_posterior(self, rng_key, params, sample_shape=()): 231 | del rng_key 232 | locs = { 233 | k: params.get("{}_{}_loc".format(k, self.prefix), v) 234 | for k, v in self._init_locs.items() 235 | } 236 | latent_samples = { 237 | k: jnp.broadcast_to(v, sample_shape + jnp.shape(v)) for k, v in locs.items() 238 | } 239 | return latent_samples 240 | -------------------------------------------------------------------------------- /tests/test_distributions.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import numpyro 5 | from numpyro import distributions as dist 6 | from numpyro import infer 7 | from scipy.stats import ks_2samp 8 | 9 | from numpyro_ext import distributions as distx 10 | 11 | 12 | def test_log_prob(): 13 | x = jnp.linspace(-1, 1, 100) 14 | y = jnp.sin(x) 15 | yerr = 0.1 16 | design = jnp.vander(x, 2) 17 | 18 | prior = dist.Normal(0.0, 1.0).expand([2]) 19 | data = dist.Normal(jnp.zeros_like(x), yerr) 20 | full = distx.MarginalizedLinear(jnp.vander(x, 2), prior, data) 21 | 22 | b = jnp.zeros_like(y) 23 | B = yerr**2 * jnp.eye(len(x)) + design @ jnp.eye(2) @ design.T 24 | expect = dist.MultivariateNormal(b, B) 25 | 26 | np.testing.assert_allclose(full.log_prob(y), expect.log_prob(y), rtol=1e-5) 27 | 28 | 29 | def test_numerical_posterior(): 30 | # First sample the full model 31 | x = jnp.linspace(-1, 1, 100) 32 | yerr = 0.1 33 | 34 | def model(x, yerr, y=None): 35 | logs = numpyro.sample("logs", dist.Normal(0.0, 1.0)) 36 | w = numpyro.sample("w", dist.Normal(0.0, 1.0).expand([2])) 37 | numpyro.sample( 38 | "y", 39 | dist.Normal( 40 | jnp.dot(jnp.vander(x, 2), w), jnp.sqrt(yerr**2 + jnp.exp(2 * logs)) 41 | ), 42 | obs=y, 43 | ) 44 | 45 | y = infer.Predictive(model, num_samples=1)(jax.random.PRNGKey(0), x, yerr)["y"][0] 46 | mcmc = infer.MCMC( 47 | infer.NUTS(model), 48 | num_warmup=1000, 49 | num_samples=1000, 50 | 
num_chains=1, 51 | progress_bar=False, 52 | ) 53 | mcmc.run(jax.random.PRNGKey(1), x, yerr, y=y) 54 | 55 | # Then sample the marginalized model 56 | def marg_model(x, yerr, y=None): 57 | logs = numpyro.sample("logs", dist.Normal(0.0, 1.0)) 58 | prior = dist.Normal(0.0, 1.0).expand([2]) 59 | data = dist.Normal(jnp.zeros_like(x), jnp.sqrt(yerr**2 + jnp.exp(2 * logs))) 60 | marg = distx.MarginalizedLinear(jnp.vander(x, 2), prior, data) 61 | numpyro.sample("y", marg, obs=y) 62 | if y is not None: 63 | numpyro.sample("w", marg.conditional(y)) 64 | 65 | marg_mcmc = infer.MCMC( 66 | infer.NUTS(marg_model), 67 | num_warmup=1000, 68 | num_samples=1000, 69 | num_chains=1, 70 | progress_bar=False, 71 | ) 72 | marg_mcmc.run(jax.random.PRNGKey(2), x, yerr, y=y) 73 | 74 | # Check that the results are K-S consistent 75 | a = mcmc.get_samples()["logs"] 76 | b = marg_mcmc.get_samples()["logs"] 77 | np.testing.assert_allclose(jnp.mean(a), jnp.mean(b), rtol=0.01) 78 | kstest = ks_2samp(a, b) 79 | assert kstest.pvalue > 0.01 80 | 81 | # Check the conditional distribution 82 | a = mcmc.get_samples()["w"] 83 | b = marg_mcmc.get_samples()["w"] 84 | kstest = ks_2samp(a[:, 0], b[:, 0]) 85 | assert kstest.pvalue > 0.01 86 | kstest = ks_2samp(a[:, 1], b[:, 1]) 87 | assert kstest.pvalue > 0.01 88 | -------------------------------------------------------------------------------- /tests/test_infer.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import numpyro 5 | import pytest 6 | from numpyro import distributions as dist 7 | 8 | from numpyro_ext import infer as inferx 9 | 10 | 11 | @pytest.mark.parametrize("prior", [dist.Normal(0.0, 1.0), dist.Uniform(0.0, 1.0)]) 12 | def test_prior_sample(prior): 13 | def model(y=None): 14 | x = numpyro.sample("x", prior) 15 | numpyro.sample("y", dist.Normal(x, 2.0), obs=y) 16 | 17 | samples, log_like = inferx.prior_sample(model, 100)(jax.random.PRNGKey(0), y=1.5) 18 | expect = -0.5 * ((samples["x"] - samples["y"]) / 2.0) ** 2 - 0.5 * jnp.log( 19 | 2 * jnp.pi * 2.0**2 20 | ) 21 | np.testing.assert_allclose(log_like, expect) 22 | -------------------------------------------------------------------------------- /tests/test_info.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import numpyro 5 | import pytest 6 | from numpyro import distributions as dist 7 | 8 | from numpyro_ext import info 9 | 10 | 11 | def assert_allclose(a, b, **kwargs): 12 | kwargs["rtol"] = kwargs.get("rtol", 2e-5) 13 | return np.testing.assert_allclose(a, b, **kwargs) 14 | 15 | 16 | @pytest.fixture 17 | def linear_data(): 18 | x = jnp.linspace(0, 10, 100) 19 | y = 3 * x + 2 20 | K = 1.5 * jnp.exp( 21 | -0.5 * (x[:, None] - x[None, :]) ** 2 / 0.5**2 22 | ) + 1e-3 * jnp.eye(len(x)) 23 | A = jnp.vander(x, 2) 24 | expect = A.T @ jnp.linalg.solve(K, A) 25 | return x, y, K, expect 26 | 27 | 28 | def test_linear(linear_data): 29 | x, y, K, expect = linear_data 30 | 31 | def model(x, y=None): 32 | A = jnp.vander(x, 2) 33 | w = numpyro.sample("w", dist.Normal(0.0, 1.0).expand([2])) 34 | numpyro.sample( 35 | "y", dist.MultivariateNormal(loc=A @ w, covariance_matrix=K), obs=y 36 | ) 37 | 38 | params = {"w": jnp.zeros(2)} 39 | calc = info.information(model)(params, x, y=y) 40 | assert_allclose(calc["w"]["w"], expect) 41 | 42 | calc = info.information(model, invert=True)(params, x, y=y) 43 | 
assert_allclose(calc["w"]["w"], jnp.linalg.inv(expect)) 44 | 45 | 46 | def test_linear_multi_in(linear_data): 47 | x, y, K, expect = linear_data 48 | 49 | def model(x, y=None): 50 | m = numpyro.sample("m", dist.Normal(0.0, 1.0)) 51 | b = numpyro.sample("b", dist.Normal(0.0, 1.0)) 52 | numpyro.sample( 53 | "y", 54 | dist.MultivariateNormal(loc=m * x + b, covariance_matrix=K), 55 | obs=y, 56 | ) 57 | 58 | params = {"m": 0.0, "b": 0.0} 59 | calc = info.information(model)(params, x, y=y) 60 | calc = jnp.array( 61 | [ 62 | [calc["m"]["m"], calc["m"]["b"]], 63 | [calc["b"]["m"], calc["b"]["b"]], 64 | ] 65 | ) 66 | assert_allclose(calc, expect) 67 | 68 | calc = info.information(model, invert=True)(params, x, y=y) 69 | calc = jnp.array( 70 | [ 71 | [calc["m"]["m"], calc["m"]["b"]], 72 | [calc["b"]["m"], calc["b"]["b"]], 73 | ] 74 | ) 75 | assert_allclose(calc, jnp.linalg.inv(expect)) 76 | 77 | 78 | def test_linear_multi_out(linear_data): 79 | x, y, _, _ = linear_data 80 | yerr = 1.0 81 | A = jnp.vander(x, 2) 82 | expect = A.T @ (A / yerr**2) 83 | 84 | def model(x, y=None): 85 | A = jnp.vander(x, 2) 86 | w = numpyro.sample("w", dist.Normal(0.0, 1.0).expand([2])) 87 | mu = A @ w 88 | for n in range(len(x)): 89 | numpyro.sample(f"y{n}", dist.Normal(mu[n], yerr), obs=y[n]) 90 | 91 | params = {"w": jnp.zeros(2)} 92 | calc = info.information(model)(params, x, y=y) 93 | assert_allclose(calc["w"]["w"], expect) 94 | 95 | calc = info.information(model, invert=True)(params, x, y=y) 96 | assert_allclose(calc["w"]["w"], jnp.linalg.inv(expect)) 97 | 98 | 99 | def test_factor(): 100 | def rosenbrock(x, y, a=1.0, b=100.0): 101 | return jnp.log(jnp.square(a - x) + b * jnp.square(y - x**2)) 102 | 103 | def model(): 104 | x = numpyro.sample("x", dist.Uniform(-2, 2)) 105 | y = numpyro.sample("y", dist.Uniform(-1, 3)) 106 | numpyro.factor("prior", rosenbrock(x, y)) 107 | 108 | params = {"x": 0.0, "y": 2.0} 109 | expect = jax.hessian(rosenbrock, argnums=(0, 1))(params["x"], params["y"]) 110 | calc = info.information(model, include_prior=True)(params) 111 | assert_allclose(calc["x"]["x"], -expect[0][0]) 112 | assert_allclose(calc["x"]["y"], -expect[0][1]) 113 | assert_allclose(calc["y"]["x"], -expect[1][0]) 114 | assert_allclose(calc["y"]["y"], -expect[1][1]) 115 | 116 | 117 | def test_unconstrained(): 118 | def model1(y=None): 119 | x_ = numpyro.sample( 120 | "x", dist.ImproperUniform(dist.constraints.real, (), event_shape=(1,)) 121 | ) 122 | x = dist.transforms.SigmoidTransform()(x_) 123 | numpyro.sample("y", dist.Normal(x, 1.0), obs=y) 124 | 125 | def model2(y=None): 126 | x = numpyro.sample("x", dist.Uniform(0.0, 1.0)) 127 | numpyro.sample("y", dist.Normal(x, 1.0), obs=y) 128 | 129 | y = 0.1 130 | info1 = info.information(model1)({"x": 0.1}, y=y) 131 | info2 = info.information(model2, unconstrained=True)({"x": 0.1}, y=y) 132 | assert_allclose(info1["x"]["x"], info2["x"]["x"]) 133 | -------------------------------------------------------------------------------- /tests/test_linear_op.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import numpy as np 3 | import pytest 4 | from numpyro import distributions 5 | 6 | from numpyro_ext.linear_op import to_linear_op 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "dist, expected_cov", 11 | [ 12 | (distributions.Normal(0.1, 0.5), 0.5**2 * jnp.ones((1, 1))), 13 | ( 14 | distributions.Normal(0.1 * jnp.ones(10), jnp.linspace(0.1, 0.5, 10)), 15 | jnp.diag(jnp.linspace(0.1, 0.5, 10) ** 2), 16 | ), 17 | ( 18 | 
distributions.Normal(0.1, jnp.linspace(0.1, 0.5, 10)), 19 | jnp.diag(jnp.linspace(0.1, 0.5, 10) ** 2), 20 | ), 21 | (distributions.Normal(0.1 * jnp.ones(10), 0.5), 0.5**2 * jnp.eye(10)), 22 | ( 23 | distributions.MultivariateNormal( 24 | 0.1 * jnp.ones(10), jnp.diag(jnp.linspace(0.1, 0.5, 10)) 25 | ), 26 | jnp.diag(jnp.linspace(0.1, 0.5, 10)), 27 | ), 28 | ( 29 | distributions.MultivariateNormal(0.1, jnp.diag(jnp.linspace(0.1, 0.5, 10))), 30 | jnp.diag(jnp.linspace(0.1, 0.5, 10)), 31 | ), 32 | ], 33 | ) 34 | def test_linear_op(dist, expected_cov): 35 | event_shape = expected_cov.shape[:1] 36 | mu = np.broadcast_to(dist.loc, event_shape) 37 | alpha = np.linalg.solve(np.linalg.cholesky(expected_cov), mu[:, None])[:, 0] 38 | expected_inv = np.linalg.inv(expected_cov) 39 | op = to_linear_op(dist) 40 | np.testing.assert_allclose(op.loc(), mu) 41 | np.testing.assert_allclose(op.covariance(), expected_cov) 42 | np.testing.assert_allclose(op.inverse(), expected_inv, rtol=5e-6) 43 | np.testing.assert_allclose( 44 | op.solve_tril(mu[:, None], False)[:, 0], alpha, rtol=5e-6 45 | ) 46 | np.testing.assert_allclose( 47 | op.half_log_det(), 0.5 * np.linalg.slogdet(expected_cov)[1] 48 | ) 49 | 50 | shape = [2, 3, 4] + list(event_shape) 51 | if dist.event_shape: 52 | exp = dist.expand(shape[:-1]) 53 | else: 54 | exp = dist.expand(shape) 55 | mu = np.broadcast_to(mu, shape) 56 | op = to_linear_op(exp) 57 | np.testing.assert_allclose(op.loc(), mu) 58 | np.testing.assert_allclose( 59 | op.covariance(), np.broadcast_to(expected_cov, shape + list(event_shape)) 60 | ) 61 | np.testing.assert_allclose( 62 | op.inverse(), 63 | np.broadcast_to(expected_inv, shape + list(event_shape)), 64 | rtol=5e-6, 65 | ) 66 | np.testing.assert_allclose( 67 | op.solve_tril(mu[..., None], False)[..., 0], 68 | np.broadcast_to(alpha, shape), 69 | rtol=5e-6, 70 | ) 71 | np.testing.assert_allclose( 72 | op.half_log_det(), 73 | jnp.broadcast_to(0.5 * np.linalg.slogdet(expected_cov)[1], shape[:-1]), 74 | ) 75 | -------------------------------------------------------------------------------- /tests/test_optim.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | import numpyro 5 | import pytest 6 | from numpyro import distributions 7 | 8 | from numpyro_ext import optim 9 | 10 | 11 | @pytest.mark.parametrize("mu", [0.0, jnp.linspace(-1.0, 1.0, 5)]) 12 | def test_optimize(mu): 13 | def model(): 14 | x = numpyro.sample("x", distributions.Normal(mu, 1.0)) 15 | y = numpyro.sample("y", distributions.Normal(mu, 2.0)) 16 | numpyro.deterministic("sm", x + y) 17 | 18 | soln = optim.optimize(model)(jax.random.PRNGKey(0)) 19 | np.testing.assert_allclose(soln["x"], mu, atol=1e-3) 20 | np.testing.assert_allclose(soln["y"], mu, atol=1e-3) 21 | np.testing.assert_allclose(soln["sm"], mu + mu, atol=1e-3) 22 | 23 | soln = optim.optimize(model, ["x"], include_deterministics=False)( 24 | jax.random.PRNGKey(0) 25 | ) 26 | np.testing.assert_allclose(soln["x"], mu, atol=1e-3) 27 | assert not np.allclose(soln["y"], mu, atol=1e-3) 28 | assert "sm" not in soln 29 | 30 | soln, info = optim.optimize(model, ["y"], return_info=True)(jax.random.PRNGKey(0)) 31 | assert not np.allclose(soln["x"], mu, atol=1e-3) 32 | np.testing.assert_allclose(soln["y"], mu, atol=1e-3) 33 | assert info.success 34 | --------------------------------------------------------------------------------
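A minimal usage sketch for ``MarginalizedLinear``, adapted from
tests/test_distributions.py; ``x``, ``yerr``, and ``y`` are assumed to be data
arrays, and the snippet is illustrative rather than part of the package:

    import jax.numpy as jnp
    import numpyro
    from numpyro import distributions as dist
    from numpyro_ext import distributions as distx

    def model(x, yerr, y=None):
        # Analytically marginalize over the linear weights w ~ Normal(0, 1)
        prior = dist.Normal(0.0, 1.0).expand([2])
        data = dist.Normal(jnp.zeros_like(x), yerr)
        marg = distx.MarginalizedLinear(jnp.vander(x, 2), prior, data)
        numpyro.sample("y", marg, obs=y)
        if y is not None:
            # Conditional (posterior) distribution of the marginalized weights
            numpyro.sample("w", marg.conditional(y))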