├── .github
│   └── workflows
│       └── ci.yml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE.md
├── Makefile
├── README.md
├── docs
│   ├── Makefile
│   ├── make.bat
│   └── source
│       ├── conf.py
│       ├── dispatch.rst
│       ├── index.rst
│       └── testing.rst
├── pyroapi
│   ├── __init__.py
│   ├── dispatch.py
│   ├── testing.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── test_mcmc.py
│   │   └── test_svi.py
│   └── version.py
├── scripts
│   └── update_headers.py
├── setup.cfg
├── setup.py
└── test
    ├── conftest.py
    ├── test_dispatch.py
    └── test_tests.py

/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
name: CI

on:
  push:
    branches: [ master ]
  pull_request:
    branches: [ master ]

jobs:
  unit:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.6]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install .[test]
        pip freeze
    - name: Lint with flake8
      run: |
        make lint
    - name: Run unit tests
      run: |
        pytest -vs --tb=short test

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
pyro_api.egg-info
__pycache__/
.ipynb_checkpoints/
build

# built / compiled
*.pyc
*.pyo
/build
/dist

# IDE
.idea
.vscode
*~

# data files
*.pdf
pyro-api/.DS_Store

# test related
.pytest_cache

# tmp files
.*.swp

--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Contributor Covenant Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.

## Scope

This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at fritzo@uber.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version]

[homepage]: http://contributor-covenant.org
[version]: http://contributor-covenant.org/version/1/4/

--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
.PHONY: all install docs lint test clean FORCE

all: docs test

install:
	pip install -e .[dev,test]

docs: FORCE
	make -C docs html

lint: FORCE
	flake8

test: lint FORCE
	pytest -vx test

clean: FORCE
	git clean -dfx -e pyro_api.egg-info

FORCE:

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
[![Build Status](https://travis-ci.com/pyro-ppl/pyro-api.svg?branch=master)](https://travis-ci.com/pyro-ppl/pyro-api)
[![Latest Version](https://badge.fury.io/py/pyro-api.svg)](https://pypi.python.org/pypi/pyro-api)
[![Documentation Status](https://readthedocs.org/projects/pyro-api/badge/?version=latest)](http://pyro-api.readthedocs.io/en/latest/?badge=master)

# Pyro API

Generic API for modeling and inference, with dispatch to different Pyro backends.

----------------------------------------------------------------------------------------------------

## Testing

To test API compatibility across backends, install pytest and the other test dependencies, which include backends such as [funsor](https://github.com/pyro-ppl/funsor) and [numpyro](https://github.com/pyro-ppl/numpyro), then run the test suite:

```
pip install -e .[test]
pytest -vs
```

This library has no dependencies, so it can easily be installed to test your own Pyro backend
implementation. Use the following pattern to test your backend on the models in the
`pyroapi.testing` module:

```python
from pyroapi.dispatch import pyro_backend
from pyroapi.testing import MODELS


# Route the generic modules to your backend for the enclosed block
with pyro_backend(handlers='my_backend.handlers',
                  distributions='my_backend.distributions'):  # ... plus any other overridden modules

    # Test on models in pyroapi.testing
    for name, make_example in MODELS.items():
        f = make_example()
        model, model_args = f['model'], f.get('model_args', ())
        model(*model_args, **f.get('model_kwargs', {}))
        ...  # further testing
```

--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS    =
SPHINXBUILD   = sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

if "%SPHINXBUILD%" == "" (
    set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

if "%1" == "" goto help

%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
    echo.
    echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
    echo.installed, then set the SPHINXBUILD environment variable to point
    echo.to the full path of the 'sphinx-build' executable. Alternatively you
    echo.may add the Sphinx directory to PATH.
    echo.
    echo.If you don't have Sphinx installed, grab it from
    echo.http://sphinx-doc.org/
    exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%

:end
popd

--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import os
import sys

import sphinx_rtd_theme
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | # 20 | sys.path.insert(0, os.path.abspath('../..')) 21 | 22 | # -- Project information ----------------------------------------------------- 23 | 24 | project = u'Pyro API' 25 | copyright = u'2019, Uber Technologies, Inc' 26 | author = u'Uber AI Labs' 27 | 28 | # The full version, including alpha/beta/rc tags 29 | release = u'0.0' 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # Add any Sphinx extension module names here, as strings. They can be 35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 36 | # ones. 37 | extensions = [ 38 | 'sphinx.ext.autodoc', 39 | 'sphinx.ext.doctest', 40 | 'sphinx.ext.intersphinx', 41 | 'sphinx.ext.mathjax', 42 | 'sphinx.ext.viewcode', 43 | ] 44 | 45 | # Disable documentation inheritance so as to avoid inheriting 46 | # docstrings in a different format, e.g. when the parent class 47 | # is a PyTorch class. 48 | 49 | autodoc_inherit_docstrings = False 50 | 51 | # Add any paths that contain templates here, relative to this directory. 52 | templates_path = ['_templates'] 53 | 54 | # The suffix(es) of source filenames. 55 | # You can specify multiple suffix as a list of string: 56 | # 57 | # source_suffix = ['.rst', '.md'] 58 | source_suffix = '.rst' 59 | 60 | # The master toctree document. 61 | master_doc = 'index' 62 | 63 | # The language for content autogenerated by Sphinx. Refer to documentation 64 | # for a list of supported languages. 65 | # 66 | # This is also used if you do content translation via gettext catalogs. 67 | # Usually you set "language" from the command line for these cases. 68 | language = None 69 | 70 | # List of patterns, relative to source directory, that match files and 71 | # directories to ignore when looking for source files. 72 | # This pattern also affects html_static_path and html_extra_path. 
exclude_patterns = []

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'


# do not prepend module name to functions
add_module_names = False

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

--------------------------------------------------------------------------------
/docs/source/dispatch.rst:
--------------------------------------------------------------------------------
Dispatch
========

.. automodule:: pyroapi.dispatch
.. autofunction:: pyroapi.dispatch.pyro_backend
.. autofunction:: pyroapi.dispatch.register_backend

Generic Modules
---------------
- pyro - The main pyro module.
- distributions - Includes distributions.transforms and distributions.constraints.
- handlers - Generalizing the original pyro.poutine.
- infer - Inference algorithms.
- optim - Optimization utilities.
- ops - Basic tensor operations (like numpy or torch).

--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
.. pyroapi documentation master file, created by
   sphinx-quickstart on Fri Oct 18 13:54:39 2019.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Pyro API
========

The ``pyroapi`` package dynamically dispatches among multiple Pyro backends, including standard Pyro_, NumPyro_, Funsor_, and custom user-defined backends.
This package includes both **dispatch** mechanisms for use in model and inference code, and **testing** utilities to help develop and test new Pyro backends.

.. _Pyro: https://pyro.ai
.. _NumPyro: https://num.pyro.ai
.. _Funsor: https://funsor.pyro.ai

.. toctree::
   :maxdepth: 2
   :caption: Contents:

   dispatch
   testing


Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

--------------------------------------------------------------------------------
/docs/source/testing.rst:
--------------------------------------------------------------------------------
Testing
=======

The pyroapi package includes tests to ensure that new backends conform to the standard API; indeed, these tests serve as the formal API description.
To add these tests to a new backend, say in ``project/test/``, follow these steps (or see the example_ in funsor):

.. _example: https://github.com/pyro-ppl/funsor/tree/master/test/pyroapi

1. Create a new directory ``project/test/pyroapi/``.

2. Create a file ``project/test/pyroapi/conftest.py`` and add a hook that treats missing features as xfail:

   .. code-block:: python

       import pytest


       def pytest_runtest_call(item):
           try:
               item.runtest()
           except NotImplementedError as e:
               pytest.xfail(str(e))

3. Create a file ``project/test/pyroapi/test_pyroapi.py`` and define a ``backend`` fixture (here ``"my_backend"`` is an alias registered with ``register_backend``):

   .. code-block:: python

       import pytest
       from pyroapi import pyro_backend
       from pyroapi.tests import *  # noqa F401

       @pytest.fixture
       def backend():
           with pyro_backend("my_backend"):
               yield

4. Test your backend with pytest:

   .. code-block:: bash

       pytest -vx test/pyroapi

--------------------------------------------------------------------------------
/pyroapi/__init__.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from pyroapi.dispatch import distributions, handlers, infer, ops, optim, pyro, pyro_backend, register_backend

__all__ = [
    'distributions',
    'handlers',
    'infer',
    'ops',
    'optim',
    'pyro',
    'pyro_backend',
    'register_backend',
]

--------------------------------------------------------------------------------
/pyroapi/dispatch.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
Dispatching allows you to dynamically set a backend using :func:`pyro_backend`
and to register new backends using :func:`register_backend`. It's easiest to
see how to use these by example:

.. code-block:: python

    from pyroapi import distributions as dist
    from pyroapi import infer, ops, optim, pyro, pyro_backend

    # This model and guide are backend-agnostic; ``data`` is a 1-D tensor of observations.
    def model(data):
        locs = pyro.param("locs", ops.tensor([0.2, 0.3, 0.5]))
        p = ops.tensor([0.2, 0.3, 0.5])
        with pyro.plate("plate", len(data), dim=-1):
            x = pyro.sample("x", dist.Categorical(p))
            pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data)

    def guide(data):
        p = pyro.param("p", ops.tensor([0.5, 0.3, 0.2]))
        with pyro.plate("plate", len(data), dim=-1):
            pyro.sample("x", dist.Categorical(p))

    # We can now set a backend at inference time.
    with pyro_backend("numpy"):
        elbo = infer.Trace_ELBO(ignore_jit_warnings=True)
        adam = optim.Adam({"lr": 1e-6})
        inference = infer.SVI(model, guide, adam, elbo)
        for step in range(10):
            loss = inference.step(data)
            print("step {} loss = {}".format(step, loss))

"""
import importlib
from contextlib import contextmanager

DEFAULT_RNG_SEED = 1
_ALIASES = {}


class GenericModule(object):
    """
    Wrapper for a module that can be dynamically routed to a custom backend.
    """
    current_backend = {}
    _modules = {}

    def __init__(self, name, default_backend):
        assert isinstance(name, str)
        assert isinstance(default_backend, str)
        self._name = name
        GenericModule.current_backend[name] = default_backend

    def __getattribute__(self, name):
        module_name = super(GenericModule, self).__getattribute__('_name')
        backend = GenericModule.current_backend[module_name]
        try:
            module = GenericModule._modules[backend]
        except KeyError:
            module = importlib.import_module(backend)
            GenericModule._modules[backend] = module
        if name.startswith('__'):
            return getattr(module, name)  # allow magic attributes to return AttributeError
        try:
            return getattr(module, name)
        except AttributeError:
            raise NotImplementedError('This Pyro backend does not implement {}.{}'
                                      .format(module_name, name))


@contextmanager
def pyro_backend(*aliases, **new_backends):
    """
    Context manager to set a custom backend for Pyro models.

    Backends can be specified either by name (for standard backends or backends
    registered through :func:`register_backend`) or by providing kwargs
    mapping module name to backend module name. Standard backends include:
    pyro, minipyro, funsor, and numpy.
    """
    if aliases:
        assert len(aliases) == 1
        assert not new_backends
        new_backends = _ALIASES[aliases[0]]

    old_backends = {}
    for name, new_backend in new_backends.items():
        old_backends[name] = GenericModule.current_backend[name]
        GenericModule.current_backend[name] = new_backend
    try:
        with handlers.seed(rng_seed=DEFAULT_RNG_SEED):
            yield
    finally:
        for name, old_backend in old_backends.items():
            GenericModule.current_backend[name] = old_backend


def register_backend(alias, new_backends):
    """
    Register a new backend alias. For example::

        register_backend("minipyro", {
            "infer": "pyro.contrib.minipyro",
            "optim": "pyro.contrib.minipyro",
            "pyro": "pyro.contrib.minipyro",
        })

    :param str alias: The name of the new backend.
    :param dict new_backends: A dict mapping standard module name (str) to new
        module name (str). This needs to include only nonstandard backends
        (e.g. if your backend uses torch ops, you need not override ``ops``).
    """
    assert isinstance(new_backends, dict)
    assert all(isinstance(key, str) for key in new_backends.keys())
    assert all(isinstance(value, str) for value in new_backends.values())
    _ALIASES[alias] = new_backends.copy()


# These modules can be overridden.
pyro = GenericModule('pyro', 'pyro')
distributions = GenericModule('distributions', 'pyro.distributions')
handlers = GenericModule('handlers', 'pyro.poutine')
infer = GenericModule('infer', 'pyro.infer')
optim = GenericModule('optim', 'pyro.optim')
ops = GenericModule('ops', 'torch')


# These are standard backends.
132 | register_backend('pyro', { 133 | 'distributions': 'pyro.distributions', 134 | 'handlers': 'pyro.poutine', 135 | 'infer': 'pyro.infer', 136 | 'ops': 'torch', 137 | 'optim': 'pyro.optim', 138 | 'pyro': 'pyro', 139 | }) 140 | register_backend('minipyro', { 141 | 'distributions': 'pyro.distributions', 142 | 'handlers': 'pyro.poutine', 143 | 'infer': 'pyro.contrib.minipyro', 144 | 'ops': 'torch', 145 | 'optim': 'pyro.contrib.minipyro', 146 | 'pyro': 'pyro.contrib.minipyro', 147 | }) 148 | register_backend('funsor', { 149 | 'distributions': 'funsor.torch.distributions', 150 | 'handlers': 'funsor.minipyro', 151 | 'infer': 'funsor.minipyro', 152 | 'ops': 'funsor.compat.ops', 153 | 'optim': 'funsor.minipyro', 154 | 'pyro': 'funsor.minipyro', 155 | }) 156 | register_backend('numpy', { 157 | 'distributions': 'numpyro.compat.distributions', 158 | 'handlers': 'numpyro.compat.handlers', 159 | 'infer': 'numpyro.compat.infer', 160 | 'ops': 'numpyro.compat.ops', 161 | 'optim': 'numpyro.compat.optim', 162 | 'pyro': 'numpyro.compat.pyro', 163 | }) 164 | -------------------------------------------------------------------------------- /pyroapi/testing.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | """ 5 | Models for testing the generic interface. 6 | 7 | For specifying the arguments to model functions, the convention followed is 8 | that positional arguments are inputs to the model and keyword arguments denote 9 | observed data. 
10 | """ 11 | 12 | from collections import OrderedDict 13 | 14 | from pyroapi.dispatch import distributions as dist, handlers, ops, pyro 15 | 16 | MODELS = OrderedDict() 17 | 18 | 19 | def register(rng_seed=None): 20 | def _register_fn(fn): 21 | MODELS[fn.__name__] = handlers.seed(fn, rng_seed) 22 | 23 | return _register_fn 24 | 25 | 26 | @register(rng_seed=1) 27 | def logistic_regression(): 28 | N, dim = 3000, 3 29 | # generic way to sample from distributions 30 | data = pyro.sample('data', dist.Normal(0., 1.), sample_shape=(N, dim)) 31 | true_coefs = ops.arange(1., dim + 1.) 32 | logits = ops.sum(true_coefs * data, axis=-1) 33 | labels = pyro.sample('labels', dist.Bernoulli(logits=logits)) 34 | 35 | def model(x, y=None): 36 | coefs = pyro.sample('coefs', dist.Normal(ops.zeros(dim), ops.ones(dim))) 37 | intercept = pyro.sample('intercept', dist.Normal(0., 1.)) 38 | logits = ops.sum(coefs * x, axis=-1) + intercept 39 | return pyro.sample('obs', dist.Bernoulli(logits=logits), obs=y) 40 | 41 | return {'model': model, 'model_args': (data,), 'model_kwargs': {'y': labels}} 42 | 43 | 44 | @register(rng_seed=1) 45 | def neals_funnel(): 46 | def model(dim): 47 | y = pyro.sample('y', dist.Normal(0, 3)) 48 | pyro.sample('x', dist.TransformedDistribution( 49 | dist.Normal(ops.zeros(dim - 1), 1), dist.transforms.AffineTransform(0, ops.exp(y / 2)))) 50 | 51 | return {'model': model, 'model_args': (10,)} 52 | 53 | 54 | @register(rng_seed=1) 55 | def eight_schools(): 56 | J = 8 57 | y = ops.tensor([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]) 58 | sigma = ops.tensor([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]) 59 | 60 | def model(J, sigma, y=None): 61 | mu = pyro.sample('mu', dist.Normal(0, 5)) 62 | tau = pyro.sample('tau', dist.HalfCauchy(5)) 63 | with pyro.plate('J', J): 64 | theta = pyro.sample('theta', dist.Normal(mu, tau)) 65 | pyro.sample('obs', dist.Normal(theta, sigma), obs=y) 66 | 67 | return {'model': model, 'model_args': (J, sigma), 'model_kwargs': {'y': y}} 68 | 
69 | 70 | @register(rng_seed=1) 71 | def beta_binomial(): 72 | N, D1, D2 = 10, 2, 2 73 | true_probs = ops.tensor([[0.7, 0.4], [0.6, 0.4]]) 74 | total_count = ops.tensor([[1000, 600], [400, 800]]) 75 | 76 | data = pyro.sample('data', dist.Binomial(total_count=total_count, probs=true_probs), 77 | sample_shape=(N,)) 78 | 79 | def model(N, D1, D2, data=None): 80 | with pyro.plate("plate_0", D1): 81 | alpha = pyro.sample("alpha", dist.HalfCauchy(1.)) 82 | beta = pyro.sample("beta", dist.HalfCauchy(1.)) 83 | with pyro.plate("plate_1", D2): 84 | probs = pyro.sample("probs", dist.Beta(alpha, beta)) 85 | with pyro.plate("data", N): 86 | pyro.sample("binomial", dist.Binomial(probs=probs, total_count=total_count), obs=data) 87 | 88 | return {'model': model, 'model_args': (N, D1, D2), 'model_kwargs': {'data': data}} 89 | -------------------------------------------------------------------------------- /pyroapi/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from .test_mcmc import * # noqa F401 5 | from .test_svi import * # noqa F401 6 | -------------------------------------------------------------------------------- /pyroapi/tests/test_mcmc.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from pyroapi.dispatch import distributions as dist 5 | from pyroapi.dispatch import infer, pyro 6 | 7 | # Note that the backend arg to these tests must be provided as a 8 | # user-defined fixture that sets the pyro_backend. For demonstration, 9 | # see test/test_tests.py. 10 | 11 | 12 | def assert_ok(model, *args, **kwargs): 13 | """ 14 | Assert that inference works without warnings or errors.
15 | """ 16 | pyro.get_param_store().clear() 17 | kernel = infer.NUTS(model) 18 | mcmc = infer.MCMC(kernel, num_samples=2, warmup_steps=2) 19 | mcmc.run(*args, **kwargs) 20 | 21 | 22 | def test_mcmc_run_ok(backend): 23 | if backend not in ["pyro", "numpy"]: 24 | return 25 | 26 | def model(): 27 | pyro.sample("x", dist.Normal(0, 1)) 28 | 29 | assert_ok(model) 30 | -------------------------------------------------------------------------------- /pyroapi/tests/test_svi.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pytest 5 | 6 | from pyroapi.dispatch import distributions as dist 7 | from pyroapi.dispatch import infer, ops, optim, pyro 8 | 9 | # This file tests a variety of model/guide pairs with valid and invalid structure. 10 | # See https://github.com/pyro-ppl/pyro/blob/0.3.1/tests/infer/test_valid_models.py 11 | # 12 | # Note that the backend arg to these tests must be provided as a 13 | # user-defined fixture that sets the pyro_backend. For demonstration, 14 | # see test/test_tests.py. 15 | 16 | 17 | def assert_ok(model, guide, elbo, *args, **kwargs): 18 | """ 19 | Assert that inference works without warnings or errors.
20 | """ 21 | pyro.get_param_store().clear() 22 | adam = optim.Adam({"lr": 1e-6}) 23 | inference = infer.SVI(model, guide, adam, elbo) 24 | for i in range(2): 25 | inference.step(*args, **kwargs) 26 | 27 | 28 | def test_generate_data(backend): 29 | 30 | def model(data=None): 31 | loc = pyro.param("loc", ops.tensor(2.0)) 32 | scale = pyro.param("scale", ops.tensor(1.0)) 33 | x = pyro.sample("x", dist.Normal(loc, scale), obs=data) 34 | return x 35 | 36 | data = model() 37 | assert data.shape == () 38 | 39 | 40 | def test_generate_data_plate(backend): 41 | num_points = 1000 42 | 43 | def model(data=None): 44 | loc = pyro.param("loc", ops.tensor(2.0)) 45 | scale = pyro.param("scale", ops.tensor(1.0)) 46 | with pyro.plate("data", num_points, dim=-1): 47 | x = pyro.sample("x", dist.Normal(loc, scale), obs=data) 48 | return x 49 | 50 | data = model() 51 | if type(data).__module__.startswith('funsor'): 52 | pytest.xfail(reason='plate is an input, and does not appear in .shape') 53 | assert data.shape == (num_points,) 54 | mean = data.sum().item() / num_points 55 | assert 1.9 <= mean <= 2.1 56 | 57 | 58 | @pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) 59 | @pytest.mark.parametrize("optim_name, optim_kwargs", [ 60 | ("Adam", {"lr": 1e-6}), 61 | ("ClippedAdam", {"lr": 1e-6, "lrd": 0.999}), 62 | ]) 63 | def test_optimizer(backend, optim_name, optim_kwargs, jit): 64 | 65 | def model(data): 66 | p = pyro.param("p", ops.tensor(0.5)) 67 | pyro.sample("x", dist.Bernoulli(p), obs=data) 68 | 69 | def guide(data): 70 | pass 71 | 72 | data = ops.tensor(0.)
73 | pyro.get_param_store().clear() 74 | Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO 75 | elbo = Elbo(ignore_jit_warnings=True) 76 | optimizer = getattr(optim, optim_name)(optim_kwargs.copy()) 77 | inference = infer.SVI(model, guide, optimizer, elbo) 78 | for i in range(2): 79 | inference.step(data) 80 | 81 | 82 | @pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) 83 | def test_nonempty_model_empty_guide_ok(backend, jit): 84 | 85 | def model(data): 86 | loc = pyro.param("loc", ops.tensor(0.0)) 87 | pyro.sample("x", dist.Normal(loc, 1.), obs=data) 88 | 89 | def guide(data): 90 | pass 91 | 92 | data = ops.tensor(2.) 93 | Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO 94 | elbo = Elbo(ignore_jit_warnings=True) 95 | assert_ok(model, guide, elbo, data) 96 | 97 | 98 | @pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) 99 | def test_plate_ok(backend, jit): 100 | data = ops.randn(10) 101 | 102 | def model(): 103 | locs = pyro.param("locs", ops.tensor([0.2, 0.3, 0.5])) 104 | p = ops.tensor([0.2, 0.3, 0.5]) 105 | with pyro.plate("plate", len(data), dim=-1): 106 | x = pyro.sample("x", dist.Categorical(p)) 107 | pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data) 108 | 109 | def guide(): 110 | p = pyro.param("p", ops.tensor([0.5, 0.3, 0.2])) 111 | with pyro.plate("plate", len(data), dim=-1): 112 | pyro.sample("x", dist.Categorical(p)) 113 | 114 | Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO 115 | elbo = Elbo(ignore_jit_warnings=True) 116 | assert_ok(model, guide, elbo) 117 | 118 | 119 | @pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) 120 | def test_nested_plate_plate_ok(backend, jit): 121 | data = ops.randn(2, 3) 122 | 123 | def model(): 124 | loc = ops.tensor(3.0) 125 | with pyro.plate("plate_outer", data.shape[-1], dim=-1): 126 | x = pyro.sample("x", dist.Normal(loc, 1.)) 127 | with pyro.plate("plate_inner", data.shape[-2], dim=-2): 128 | pyro.sample("y", dist.Normal(x, 1.), obs=data) 129 | 
130 | def guide(): 131 | loc = pyro.param("loc", ops.tensor(0.)) 132 | scale = pyro.param("scale", ops.tensor(1.)) 133 | with pyro.plate("plate_outer", data.shape[-1], dim=-1): 134 | pyro.sample("x", dist.Normal(loc, scale)) 135 | 136 | Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO 137 | elbo = Elbo(ignore_jit_warnings=True) 138 | assert_ok(model, guide, elbo) 139 | 140 | 141 | @pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) 142 | def test_local_param_ok(backend, jit): 143 | data = ops.randn(10) 144 | 145 | def model(): 146 | locs = pyro.param("locs", ops.tensor([-1., 0., 1.])) 147 | with pyro.plate("plate", len(data), dim=-1): 148 | x = pyro.sample("x", dist.Categorical(ops.ones(3) / 3)) 149 | pyro.sample("obs", dist.Normal(locs[x], 1.), obs=data) 150 | 151 | def guide(): 152 | with pyro.plate("plate", len(data), dim=-1): 153 | p = pyro.param("p", ops.ones(len(data), 3) / 3, event_dim=1) 154 | pyro.sample("x", dist.Categorical(p)) 155 | return p 156 | 157 | Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO 158 | elbo = Elbo(ignore_jit_warnings=True) 159 | assert_ok(model, guide, elbo) 160 | 161 | # Check that pyro.param() can be called without init_value. 
162 | expected = guide() 163 | actual = pyro.param("p") 164 | assert ops.allclose(actual, expected) 165 | 166 | 167 | @pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) 168 | def test_constraints(backend, jit): 169 | data = ops.tensor(0.5) 170 | 171 | def model(): 172 | locs = pyro.param("locs", ops.randn(3), 173 | constraint=dist.constraints.real) 174 | scales = pyro.param("scales", ops.exp(ops.randn(3)), 175 | constraint=dist.constraints.positive) 176 | p = ops.tensor([0.5, 0.3, 0.2]) 177 | x = pyro.sample("x", dist.Categorical(p)) 178 | pyro.sample("obs", dist.Normal(locs[x], scales[x]), obs=data) 179 | 180 | def guide(): 181 | q = pyro.param("q", ops.exp(ops.randn(3)), 182 | constraint=dist.constraints.simplex) 183 | pyro.sample("x", dist.Categorical(q)) 184 | 185 | Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO 186 | elbo = Elbo(ignore_jit_warnings=True) 187 | assert_ok(model, guide, elbo) 188 | 189 | 190 | def test_mean_field_ok(backend): 191 | 192 | def model(): 193 | x = pyro.sample("x", dist.Normal(0., 1.)) 194 | pyro.sample("y", dist.Normal(x, 1.)) 195 | 196 | def guide(): 197 | loc = pyro.param("loc", ops.tensor(0.)) 198 | x = pyro.sample("x", dist.Normal(loc, 1.)) 199 | pyro.sample("y", dist.Normal(x, 1.)) 200 | 201 | elbo = infer.TraceMeanField_ELBO() 202 | assert_ok(model, guide, elbo) 203 | -------------------------------------------------------------------------------- /pyroapi/version.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | __version__ = '0.1.2' 5 | -------------------------------------------------------------------------------- /scripts/update_headers.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro project. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import glob 6 | 7 | root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 8 | blacklist = ["/build/", "/dist/", "/pyro_api.egg"] 9 | file_types = [ 10 | ("*.py", "# {}"), 11 | ("*.cpp", "// {}"), 12 | ] 13 | 14 | for basename, comment in file_types: 15 | copyright_line = comment.format("Copyright Contributors to the Pyro project.\n") 16 | # See https://spdx.org/ids-how 17 | spdx_line = comment.format("SPDX-License-Identifier: Apache-2.0\n") 18 | 19 | filenames = glob.glob(os.path.join(root, "**", basename), recursive=True) 20 | filenames.sort() 21 | filenames = [ 22 | filename 23 | for filename in filenames 24 | if not any(word in filename for word in blacklist) 25 | ] 26 | for filename in filenames: 27 | with open(filename) as f: 28 | lines = f.readlines() 29 | 30 | # Ignore empty files like __init__.py. 31 | if all(line.isspace() for line in lines): 32 | continue 33 | 34 | # Ensure the first few lines are copyright notices. 35 | lineno = 0 36 | if not lines[lineno].startswith(comment.format("Copyright")): 37 | lines.insert(lineno, copyright_line) 38 | else: 39 | lines[lineno] = copyright_line 40 | lineno += 1 41 | while lines[lineno].startswith(comment.format("Copyright")): 42 | lineno += 1 43 | 44 | # Ensure next line is an SPDX short identifier. 45 | if not lines[lineno].startswith(comment.format("SPDX-License-Identifier")): 46 | lines.insert(lineno, spdx_line) 47 | else: 48 | lines[lineno] = spdx_line 49 | lineno += 1 50 | 51 | # Ensure next line is blank.
52 | if not lines[lineno].isspace(): 53 | lines.insert(lineno, "\n") 54 | 55 | with open(filename, "w") as f: 56 | f.write("".join(lines)) 57 | 58 | print("updated {}".format(filename[len(root) + 1:])) 59 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | exclude = docs/src, build, dist 4 | 5 | [isort] 6 | line_length = 120 7 | multi_line_output=3 8 | not_skip = __init__.py 9 | known_first_party = pyroapi 10 | 11 | [tool:pytest] 12 | filterwarnings = error 13 | ignore::DeprecationWarning 14 | once::DeprecationWarning 15 | 16 | doctest_optionflags = ELLIPSIS NORMALIZE_WHITESPACE IGNORE_EXCEPTION_DETAIL 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import sys 6 | 7 | from setuptools import find_packages, setup 8 | 9 | PROJECT_PATH = os.path.dirname(os.path.abspath(__file__)) 10 | 11 | # Find version 12 | for line in open(os.path.join(PROJECT_PATH, 'pyroapi', 'version.py')): 13 | if line.startswith('__version__ = '): 14 | version = line.strip().split()[2][1:-1] 15 | 16 | # Read README.md for the long description on PyPI.
17 | try: 18 | long_description = open('README.md', encoding='utf-8').read() 19 | except Exception as e: 20 | sys.stderr.write('Failed to read README.md:\n {}\n'.format(e)) 21 | sys.stderr.flush() 22 | long_description = '' 23 | 24 | 25 | setup( 26 | name='pyro-api', 27 | version=version, 28 | description='Generic API for dispatch to Pyro backends.', 29 | packages=find_packages(include=['pyroapi', 'pyroapi.*']), 30 | url='https://github.com/pyro-ppl/pyro-api', 31 | author='Uber AI Labs', 32 | author_email='npradhan@uber.com', 33 | install_requires=[], 34 | extras_require={ 35 | # PyPI does not accept @ (direct URL) requirements, 36 | # so please comment out the 'test' section when uploading to PyPI. 37 | 'test': [ 38 | 'flake8', 39 | 'pytest>=5.0', 40 | 'pyro-ppl@https://api.github.com/repos/pyro-ppl/pyro/tarball/dev', 41 | 'numpyro@https://api.github.com/repos/pyro-ppl/numpyro/tarball/master', 42 | 'funsor@https://api.github.com/repos/pyro-ppl/funsor/tarball/master', 43 | ], 44 | 'dev': [ 45 | 'sphinx>=2.0', 46 | 'sphinx_rtd_theme', 47 | 'ipython', 48 | ], 49 | }, 50 | long_description=long_description, 51 | long_description_content_type='text/markdown', 52 | tests_require=['flake8', 'pytest>=4.1'], 53 | keywords='probabilistic machine learning bayesian statistics', 54 | license='Apache License 2.0', 55 | classifiers=[ 56 | 'Intended Audience :: Developers', 57 | 'Intended Audience :: Education', 58 | 'Intended Audience :: Science/Research', 59 | 'License :: OSI Approved :: Apache Software License', 60 | 'Operating System :: POSIX :: Linux', 61 | 'Operating System :: MacOS :: MacOS X', 62 | 'Programming Language :: Python :: 3.6', 63 | ], 64 | ) 65 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro project.
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pytest 5 | 6 | 7 | def pytest_configure(config): 8 | try: 9 | import funsor 10 | except ImportError: 11 | pass 12 | else: 13 | funsor.set_backend("torch") 14 | 15 | 16 | def pytest_runtest_call(item): 17 | try: 18 | item.runtest() 19 | except NotImplementedError as e: 20 | pytest.xfail(str(e)) 21 | -------------------------------------------------------------------------------- /test/test_dispatch.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pytest 5 | 6 | from pyroapi import handlers, infer, pyro, pyro_backend, register_backend 7 | from pyroapi.testing import MODELS 8 | 9 | PACKAGE_NAME = { 10 | "pyro": "pyro", 11 | "minipyro": "pyro", 12 | "numpy": "numpyro", 13 | "funsor": "funsor", 14 | } 15 | 16 | 17 | @pytest.mark.filterwarnings("ignore", category=UserWarning) 18 | @pytest.mark.parametrize('model', MODELS) 19 | @pytest.mark.parametrize('backend', [ 20 | "pyro", 21 | pytest.param('numpy', marks=[pytest.mark.xfail( 22 | reason="Signature of numpyro MCMC does not match, numpyro/issues/1321")])]) 23 | def test_mcmc_interface(model, backend): 24 | pytest.importorskip(PACKAGE_NAME[backend]) 25 | with pyro_backend(backend), handlers.seed(rng_seed=20): 26 | f = MODELS[model]() 27 | model, args, kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {}) 28 | nuts_kernel = infer.NUTS(model=model) 29 | mcmc = infer.MCMC(nuts_kernel, num_samples=10, warmup_steps=10) 30 | mcmc.run(*args, **kwargs) 31 | mcmc.summary() 32 | 33 | 34 | @pytest.mark.parametrize('backend', ['funsor', 'minipyro', 'numpy', 'pyro']) 35 | def test_not_implemented(backend): 36 | pytest.importorskip(PACKAGE_NAME[backend]) 37 | with pyro_backend(backend): 38 | pyro.sample # should be implemented 39 | pyro.param # should be implemented 40 | with pytest.raises(NotImplementedError): 41 | 
pyro.nonexistent_primitive 42 | 43 | 44 | @pytest.mark.parametrize('model', MODELS) 45 | @pytest.mark.parametrize("backend", ["pyro", "minipyro", "numpy", "funsor"]) 46 | @pytest.mark.xfail(reason='Not supported by backend.') 47 | def test_model_sample(model, backend): 48 | pytest.importorskip(PACKAGE_NAME[backend]) 49 | with pyro_backend(backend), handlers.seed(rng_seed=2): 50 | f = MODELS[model]() 51 | model, model_args, model_kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {}) 52 | model(*model_args, **model_kwargs) 53 | 54 | 55 | @pytest.mark.parametrize('model', MODELS) 56 | @pytest.mark.parametrize('backend', [ 57 | pytest.param("funsor", marks=[pytest.mark.xfail(reason="not implemented")]), 58 | 'minipyro', 59 | 'numpy', 60 | 'pyro', 61 | ]) 62 | def test_trace_handler(model, backend): 63 | pytest.importorskip(PACKAGE_NAME[backend]) 64 | with pyro_backend(backend), handlers.seed(rng_seed=2): 65 | f = MODELS[model]() 66 | model, model_args, model_kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {}) 67 | # should be implemented 68 | handlers.trace(model).get_trace(*model_args, **model_kwargs) 69 | 70 | 71 | @pytest.mark.parametrize('model', MODELS) 72 | def test_register_backend(model): 73 | pytest.importorskip("pyro") 74 | register_backend("foo", { 75 | "infer": "pyro.contrib.minipyro", 76 | "optim": "pyro.contrib.minipyro", 77 | "pyro": "pyro.contrib.minipyro", 78 | }) 79 | with pyro_backend("foo"): 80 | f = MODELS[model]() 81 | model, model_args, model_kwargs = f['model'], f.get('model_args', ()), f.get('model_kwargs', {}) 82 | handlers.trace(model).get_trace(*model_args, **model_kwargs) 83 | -------------------------------------------------------------------------------- /test/test_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro project. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pytest 5 | 6 | from pyroapi import pyro_backend 7 | from pyroapi.tests import * # noqa F401 8 | 9 | pytestmark = pytest.mark.filterwarnings( 10 | "ignore::numpyro.compat.util.UnsupportedAPIWarning", 11 | "ignore:.*loss does not support models with discrete latent variables:UserWarning", 12 | # The behavior of // with negative numbers is changing in PyTorch, 13 | # which does not affect these tests. This UserWarning is expected to be 14 | # removed in a future version of PyTorch. 15 | "ignore:.*floordiv.* is deprecated, and its behavior will change in a future version:UserWarning", 16 | ) 17 | 18 | PACKAGE_NAME = { 19 | "pyro": "pyro", 20 | "minipyro": "pyro", 21 | "numpy": "numpyro", 22 | "funsor": "funsor", 23 | } 24 | 25 | 26 | @pytest.fixture(params=["pyro", "minipyro", "numpy", "funsor"]) 27 | def backend(request): 28 | pytest.importorskip(PACKAGE_NAME[request.param]) 29 | with pyro_backend(request.param): 30 | yield 31 | 32 | 33 | # TODO(fehiepsi): Remove the following when the test passes in numpyro. 34 | _test_mcmc_run_ok = test_mcmc_run_ok # noqa F405 35 | 36 | 37 | @pytest.mark.parametrize("backend", [ 38 | "pyro", 39 | pytest.param("numpy", marks=[ 40 | pytest.mark.xfail(reason="numpyro signature for MCMC is not consistent.")])]) 41 | def test_mcmc_run_ok(backend): 42 | pytest.importorskip(PACKAGE_NAME[backend]) 43 | with pyro_backend(backend): 44 | _test_mcmc_run_ok(backend) 45 | --------------------------------------------------------------------------------
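The dynamic routing in `pyroapi/dispatch.py` can be illustrated without Pyro installed. The block below is a minimal sketch, not pyro-api's implementation. It keeps the `GenericModule` idea: lazy `importlib.import_module` with a per-backend cache, and missing attributes surfacing as `NotImplementedError`. It drops the alias lookup and the `handlers.seed` wrapper that `pyro_backend` adds. The proxy name `mathlib`, the helper `use_backends`, and the use of the stdlib modules `math` and `cmath` as stand-in "backends" are all hypothetical.

```python
import importlib
from contextlib import contextmanager


class GenericModule:
    """Proxy that forwards attribute lookups to whichever module is currently
    selected for its name (simplified from the pattern in pyroapi/dispatch.py)."""
    current_backend = {}  # proxy name -> module name currently routed to
    _modules = {}         # module name -> imported module (cache)

    def __init__(self, name, default_backend):
        self._name = name
        GenericModule.current_backend[name] = default_backend

    def __getattribute__(self, attr):
        # Bypass this override to read the proxy's own name.
        name = object.__getattribute__(self, "_name")
        backend = GenericModule.current_backend[name]
        module = GenericModule._modules.get(backend)
        if module is None:
            # Import lazily, so a backend is imported only when first used.
            module = importlib.import_module(backend)
            GenericModule._modules[backend] = module
        if attr.startswith("__"):
            # Let magic attributes raise a plain AttributeError.
            return getattr(module, attr)
        try:
            return getattr(module, attr)
        except AttributeError:
            raise NotImplementedError(
                "This backend does not implement {}.{}".format(name, attr))


@contextmanager
def use_backends(**new_backends):
    """Hypothetical helper: temporarily reroute proxies by name, restoring
    the previous routing on exit even if the body raises."""
    old = {name: GenericModule.current_backend[name] for name in new_backends}
    GenericModule.current_backend.update(new_backends)
    try:
        yield
    finally:
        GenericModule.current_backend.update(old)


# Hypothetical proxy that routes to the stdlib math module by default.
mathlib = GenericModule("mathlib", "math")
```

Because the proxy resolves its module on every attribute access, code written against `mathlib` uses whatever routing is active when it runs. For example, `mathlib.sqrt(-4.0)` returns `2j` inside `use_backends(mathlib="cmath")`, but raises `ValueError` under the default `math` routing. This is the same property that lets backend-agnostic code such as the models in `pyroapi/testing.py` run under several backends.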
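`scripts/update_headers.py` enforces a header of three parts: the canonical copyright line, then an SPDX identifier, then a blank line. Those rules can be checked in isolation on a list of lines. The block below is a hypothetical pure-function sketch; `normalize_header` is not part of the repository. It handles only the `# ` comment style used for `*.py` files. It also adds bounds checks that the script lacks, so a file consisting only of header comments does not raise `IndexError`.

```python
# Hypothetical constants matching the strings the script writes for *.py files.
COPYRIGHT = "# Copyright Contributors to the Pyro project.\n"
SPDX = "# SPDX-License-Identifier: Apache-2.0\n"


def normalize_header(lines):
    """Return a new list of lines whose header is copyright, SPDX, blank.

    The input list is not modified.
    """
    lines = list(lines)
    # Leave empty or whitespace-only files untouched, as the script does.
    if all(line.isspace() for line in lines):
        return lines
    lineno = 0
    # First line: replace an existing copyright notice, or insert one.
    if not lines[lineno].startswith("# Copyright"):
        lines.insert(lineno, COPYRIGHT)
    else:
        lines[lineno] = COPYRIGHT
    lineno += 1
    # Keep any additional copyright lines that follow.
    while lineno < len(lines) and lines[lineno].startswith("# Copyright"):
        lineno += 1
    # Next line: replace an existing SPDX identifier, or insert one.
    if lineno >= len(lines) or not lines[lineno].startswith("# SPDX-License-Identifier"):
        lines.insert(lineno, SPDX)
    else:
        lines[lineno] = SPDX
    lineno += 1
    # Finally, make sure a blank line separates the header from the code.
    if lineno >= len(lines) or not lines[lineno].isspace():
        lines.insert(lineno, "\n")
    return lines
```

Because the function returns a new list instead of rewriting files, properties such as idempotence are easy to assert. A second pass over already-normalized lines changes nothing.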